From aa356c0ae2e18ba7785949dca6c2787ad43edd73 Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Wed, 17 Jul 2019 15:55:13 -0700 Subject: [PATCH] skip decorators with names matching the 'data_key' of defined schema fields Signed-off-by: Andrew Whitehead --- aries_cloudagent/messaging/agent_message.py | 2 +- aries_cloudagent/messaging/decorators/base.py | 27 +++++++++++++++---- .../decorators/tests/test_decorator_set.py | 25 ++++++++++++++--- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/aries_cloudagent/messaging/agent_message.py b/aries_cloudagent/messaging/agent_message.py index 3714f2f1d8..923efbd2b0 100644 --- a/aries_cloudagent/messaging/agent_message.py +++ b/aries_cloudagent/messaging/agent_message.py @@ -340,7 +340,7 @@ def extract_decorators(self, data): ValidationError: If there is a missing field signature """ - processed = self._decorators.extract_decorators(data) + processed = self._decorators.extract_decorators(data, self.__class__) expect_fields = resolve_meta_property(self, "signed_fields") or () found_signatures = {} diff --git a/aries_cloudagent/messaging/decorators/base.py b/aries_cloudagent/messaging/decorators/base.py index ca3e55a496..717f8ed3ec 100644 --- a/aries_cloudagent/messaging/decorators/base.py +++ b/aries_cloudagent/messaging/decorators/base.py @@ -1,7 +1,10 @@ """Classes for managing a collection of decorators.""" from collections import OrderedDict -from typing import Mapping, Type +from typing import Mapping, Sequence, Type + +from marshmallow import Schema +from marshmallow.fields import Field from ...error import BaseError @@ -96,20 +99,34 @@ def load_decorator(self, key: str, value, serialized=False): elif key in self: del self[key] - def extract_decorators(self, message: Mapping, serialized=True) -> OrderedDict: + def extract_decorators( + self, + message: Mapping, + schema: Type[Schema] = None, + serialized: bool = True, + skip_attrs: Sequence[str] = None, + ) -> OrderedDict: """Extract decorators and return the remaining properties.""" remain = OrderedDict() + skip_attrs = set(skip_attrs) if skip_attrs else set() + if schema: + for field_name, field_def in schema._declared_fields.items(): + if isinstance(field_def, Field) and field_def.data_key: + skip_attrs.add(field_def.data_key) if message: pfx_len = len(self._prefix) for key, value in message.items(): - if key.startswith(self._prefix): + if key in skip_attrs: + pass + elif key.startswith(self._prefix): key = key[pfx_len:] self.load_decorator(key, value, serialized) + continue elif self._prefix in key: field, key = key.split(self._prefix, 1) self.field(field).load_decorator(key, value, serialized) - else: - remain[key] = value + continue + remain[key] = value return remain def to_dict(self, prefix: str = None) -> OrderedDict: diff --git a/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py b/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py index 4c7e7e86f1..134d635b5d 100644 --- a/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py +++ b/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py @@ -10,8 +10,9 @@ class SimpleModel(BaseModel): class Meta: schema_class = "SimpleModelSchema" - def __init__(self, *, value: str = None, **kwargs): + def __init__(self, *, value: str = None, handled_decorator: str = None, **kwargs): super().__init__(**kwargs) + self.handled_decorator = handled_decorator self.value = value @@ -20,6 +21,7 @@ class Meta: model_class = SimpleModel value = fields.Str(required=True) + handled_decorator = fields.Str(required=False, data_key="handled~decorator") class TestDecoratorSet(TestCase): @@ -52,7 +54,7 @@ def test_decorator_model(self): decors = BaseDecoratorSet() decors.add_model("test", SimpleModel) - remain = decors.extract_decorators(message) + remain = decors.extract_decorators(message, SimpleModelSchema) tested = decors["test"] assert isinstance(tested, SimpleModel) and tested.value == "TEST" @@ -66,7 +68,7 @@ def test_field_decorator(self): message = {"test~decorator": decor_value, "one": "TWO"} decors = BaseDecoratorSet() - remain = decors.extract_decorators(message) + remain = decors.extract_decorators(message, SimpleModelSchema) # check original is unmodified assert "test~decorator" in message @@ -74,3 +76,20 @@ def test_field_decorator(self): assert decors.field("test") assert decors.field("test")["decorator"] is decor_value assert remain == {"one": "TWO"} + + def test_skip_decorator(self): + + decor_value = {} + message = {"handled~decorator": decor_value, "one": "TWO"} + + decors = BaseDecoratorSet() + remain = decors.extract_decorators(message, SimpleModelSchema) + print(SimpleModelSchema.__dict__) + + # check original is unmodified + assert "handled~decorator" in message + + print(remain) + assert not decors.field("handled") + assert remain == message +