Skip to content

Commit

Permalink
skip decorators with names matching the 'data_key' of defined schema …
Browse files Browse the repository at this point in the history
…fields

Signed-off-by: Andrew Whitehead <[email protected]>
  • Loading branch information
andrewwhitehead committed Jul 17, 2019
1 parent b043d48 commit aa356c0
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 9 deletions.
2 changes: 1 addition & 1 deletion aries_cloudagent/messaging/agent_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def extract_decorators(self, data):
ValidationError: If there is a missing field signature
"""
processed = self._decorators.extract_decorators(data)
processed = self._decorators.extract_decorators(data, self.__class__)

expect_fields = resolve_meta_property(self, "signed_fields") or ()
found_signatures = {}
Expand Down
27 changes: 22 additions & 5 deletions aries_cloudagent/messaging/decorators/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Classes for managing a collection of decorators."""

from collections import OrderedDict
from typing import Mapping, Type
from typing import Mapping, Sequence, Type

from marshmallow import Schema
from marshmallow.fields import Field

from ...error import BaseError

Expand Down Expand Up @@ -96,20 +99,34 @@ def load_decorator(self, key: str, value, serialized=False):
elif key in self:
del self[key]

def extract_decorators(self, message: Mapping, serialized=True) -> OrderedDict:
def extract_decorators(
self,
message: Mapping,
schema: Type[Schema] = None,
serialized: bool = True,
skip_attrs: Sequence[str] = None,
) -> OrderedDict:
"""Extract decorators and return the remaining properties."""
remain = OrderedDict()
skip_attrs = set(skip_attrs) if skip_attrs else set()
if schema:
for field_name, field_def in schema._declared_fields.items():
if isinstance(field_def, Field) and field_def.data_key:
skip_attrs.add(field_def.data_key)
if message:
pfx_len = len(self._prefix)
for key, value in message.items():
if key.startswith(self._prefix):
if key in skip_attrs:
pass
elif key.startswith(self._prefix):
key = key[pfx_len:]
self.load_decorator(key, value, serialized)
continue
elif self._prefix in key:
field, key = key.split(self._prefix, 1)
self.field(field).load_decorator(key, value, serialized)
else:
remain[key] = value
continue
remain[key] = value
return remain

def to_dict(self, prefix: str = None) -> OrderedDict:
Expand Down
25 changes: 22 additions & 3 deletions aries_cloudagent/messaging/decorators/tests/test_decorator_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ class SimpleModel(BaseModel):
class Meta:
schema_class = "SimpleModelSchema"

def __init__(self, *, value: str = None, **kwargs):
def __init__(self, *, value: str = None, handled_decorator: str = None, **kwargs):
super().__init__(**kwargs)
self.handled_decorator = handled_decorator
self.value = value


Expand All @@ -20,6 +21,7 @@ class Meta:
model_class = SimpleModel

value = fields.Str(required=True)
handled_decorator = fields.Str(required=False, data_key="handled~decorator")


class TestDecoratorSet(TestCase):
Expand Down Expand Up @@ -52,7 +54,7 @@ def test_decorator_model(self):

decors = BaseDecoratorSet()
decors.add_model("test", SimpleModel)
remain = decors.extract_decorators(message)
remain = decors.extract_decorators(message, SimpleModelSchema)

tested = decors["test"]
assert isinstance(tested, SimpleModel) and tested.value == "TEST"
Expand All @@ -66,11 +68,28 @@ def test_field_decorator(self):
message = {"test~decorator": decor_value, "one": "TWO"}

decors = BaseDecoratorSet()
remain = decors.extract_decorators(message)
remain = decors.extract_decorators(message, SimpleModelSchema)

# check original is unmodified
assert "test~decorator" in message

assert decors.field("test")
assert decors.field("test")["decorator"] is decor_value
assert remain == {"one": "TWO"}

def test_skip_decorator(self):

decor_value = {}
message = {"handled~decorator": decor_value, "one": "TWO"}

decors = BaseDecoratorSet()
remain = decors.extract_decorators(message, SimpleModelSchema)
print(SimpleModelSchema.__dict__)

# check original is unmodified
assert "handled~decorator" in message

print(remain)
assert not decors.field("handled")
assert remain == message

0 comments on commit aa356c0

Please sign in to comment.