Skip to content

Commit

Permalink
Allow checking constraint type by struct name (#24802)
Browse files Browse the repository at this point in the history
  • Loading branch information
tehampson authored and pull[bot] committed Nov 30, 2023
1 parent b5475e2 commit 6076238
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 22 deletions.
44 changes: 23 additions & 21 deletions scripts/py_matter_yamltests/matter_yamltests/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, types: list, is_null_allowed: bool = False):
self._types = types
self._is_null_allowed = is_null_allowed

def is_met(self, value):
def is_met(self, value, value_type_name):
if value is None:
return self._is_null_allowed

Expand All @@ -43,10 +43,10 @@ def is_met(self, value):
if not found_type_match:
return False

return self.check_response(value)
return self.check_response(value, value_type_name)

@abstractmethod
def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
pass


Expand All @@ -55,13 +55,13 @@ def __init__(self, has_value):
super().__init__(types=[])
self._has_value = has_value

def is_met(self, value):
def is_met(self, value, value_type_name):
# We are overriding the BaseConstraint of is_met since has value is a special case where
# we might not be expecting a value at all, but the basic null check in BaseConstraint
# is not what we want.
return self.check_response(value)
return self.check_response(value, value_type_name)

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
has_value = value is not None
return self._has_value == has_value

Expand All @@ -71,7 +71,7 @@ def __init__(self, type):
super().__init__(types=[], is_null_allowed=True)
self._type = type

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
success = False
if self._type == 'boolean' and type(value) is bool:
success = True
Expand Down Expand Up @@ -195,6 +195,8 @@ def check_response(self, value) -> bool:
success = value >= -36028797018963967 and value <= 36028797018963967
elif self._type == 'nullable_int64s' and type(value) is int:
success = value >= -9223372036854775807 and value <= 9223372036854775807
else:
success = self._type == value_type_name
return success


Expand All @@ -203,7 +205,7 @@ def __init__(self, min_length):
super().__init__(types=[str, bytes, list])
self._min_length = min_length

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return len(value) >= self._min_length


Expand All @@ -212,7 +214,7 @@ def __init__(self, max_length):
super().__init__(types=[str, bytes, list])
self._max_length = max_length

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return len(value) <= self._max_length


Expand All @@ -221,7 +223,7 @@ def __init__(self, is_hex_string: bool):
super().__init__(types=[str])
self._is_hex_string = is_hex_string

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return all(c in string.hexdigits for c in value) == self._is_hex_string


Expand All @@ -230,7 +232,7 @@ def __init__(self, starts_with):
super().__init__(types=[str])
self._starts_with = starts_with

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return value.startswith(self._starts_with)


Expand All @@ -239,7 +241,7 @@ def __init__(self, ends_with):
super().__init__(types=[str])
self._ends_with = ends_with

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return value.endswith(self._ends_with)


Expand All @@ -248,7 +250,7 @@ def __init__(self, is_upper_case):
super().__init__(types=[str])
self._is_upper_case = is_upper_case

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return value.isupper() == self._is_upper_case


Expand All @@ -257,7 +259,7 @@ def __init__(self, is_lower_case):
super().__init__(types=[str])
self._is_lower_case = is_lower_case

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return value.islower() == self._is_lower_case


Expand All @@ -266,7 +268,7 @@ def __init__(self, min_value):
super().__init__(types=[int, float], is_null_allowed=True)
self._min_value = min_value

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return value >= self._min_value


Expand All @@ -275,7 +277,7 @@ def __init__(self, max_value):
super().__init__(types=[int, float], is_null_allowed=True)
self._max_value = max_value

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return value <= self._max_value


Expand All @@ -284,7 +286,7 @@ def __init__(self, contains):
super().__init__(types=[list])
self._contains = contains

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return set(self._contains).issubset(value)


Expand All @@ -293,7 +295,7 @@ def __init__(self, excludes):
super().__init__(types=[list])
self._excludes = excludes

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return set(self._excludes).isdisjoint(value)


Expand All @@ -302,7 +304,7 @@ def __init__(self, has_masks_set):
super().__init__(types=[int])
self._has_masks_set = has_masks_set

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return all([(value & mask) == mask for mask in self._has_masks_set])


Expand All @@ -311,7 +313,7 @@ def __init__(self, has_masks_clear):
super().__init__(types=[int])
self._has_masks_clear = has_masks_clear

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return all([(value & mask) == 0 for mask in self._has_masks_clear])


Expand All @@ -320,7 +322,7 @@ def __init__(self, not_value):
super().__init__(types=[], is_null_allowed=True)
self._not_value = not_value

def check_response(self, value) -> bool:
def check_response(self, value, value_type_name) -> bool:
return value != self._not_value


Expand Down
9 changes: 8 additions & 1 deletion scripts/py_matter_yamltests/matter_yamltests/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(self, test: dict, config: dict, definitions: SpecDefinitions, pics_

argument_mapping = None
response_mapping = None
response_mapping_name = None

if self.is_attribute:
attribute = definitions.get_attribute_by_name(
Expand All @@ -242,6 +243,7 @@ def __init__(self, test: dict, config: dict, definitions: SpecDefinitions, pics_
attribute.definition.data_type.name)
argument_mapping = attribute_mapping
response_mapping = attribute_mapping
response_mapping_name = attribute.definition.data_type.name
else:
command = definitions.get_command_by_name(
self.cluster, self.command)
Expand All @@ -250,9 +252,11 @@ def __init__(self, test: dict, config: dict, definitions: SpecDefinitions, pics_
definitions, self.cluster, command.input_param)
response_mapping = self._as_mapping(
definitions, self.cluster, command.output_param)
response_mapping_name = command.output_param

self.argument_mapping = argument_mapping
self.response_mapping = response_mapping
self.response_mapping_name = response_mapping_name
self.update_arguments(self.arguments_with_placeholders)
self.update_response(self.response_with_placeholders)

Expand Down Expand Up @@ -678,13 +682,16 @@ def _response_constraints_validation(self, response, result):
error_success = 'Constraints check passed'
error_failure = 'Constraints check failed'

response_type_name = self._test.response_mapping_name
for value in self.response['values']:
if 'constraints' not in value:
continue

received_value = response.get('value')
if not self.is_attribute:
expected_name = value.get('name')
response_type_name = self._test.response_mapping.get(
expected_name)
if received_value is None or expected_name not in received_value:
received_value = None
else:
Expand All @@ -693,7 +700,7 @@ def _response_constraints_validation(self, response, result):

constraints = get_constraints(value['constraints'])

if all([constraint.is_met(received_value) for constraint in constraints]):
if all([constraint.is_met(received_value, response_type_name) for constraint in constraints]):
result.success(check_type, error_success)
else:
# TODO would be helpful to be more verbose here
Expand Down

0 comments on commit 6076238

Please sign in to comment.