Skip to content

Commit

Permalink
Merge pull request openwallet-foundation#2060 from andrewwhitehead/fi…
Browse files Browse the repository at this point in the history
…x/accept-unknown-hsproto

Do not reject OOB invitation with unknown handshake protocol(s)
  • Loading branch information
ianco authored Jan 5, 2023
2 parents 8a80f71 + 5aac8c0 commit 6f2ef55
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 49 deletions.
49 changes: 15 additions & 34 deletions aries_cloudagent/messaging/valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,60 +23,41 @@
class StrOrDictField(Field):
"""URI or Dict field for Marshmallow."""

def _serialize(self, value, attr, obj, **kwargs):
return value

def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, (str, dict)):
return value
else:
if not isinstance(value, (str, dict)):
raise ValidationError("Field should be str or dict")
return super()._deserialize(value, attr, data, **kwargs)


class StrOrNumberField(Field):
"""String or Number field for Marshmallow."""

def _serialize(self, value, attr, obj, **kwargs):
return value

def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, (str, float, int)):
return value
else:
if not isinstance(value, (str, float, int)):
raise ValidationError("Field should be str or int or float")
return super()._deserialize(value, attr, data, **kwargs)


class DictOrDictListField(Field):
"""Dict or Dict List field for Marshmallow."""

def _serialize(self, value, attr, obj, **kwargs):
return value

def _deserialize(self, value, attr, data, **kwargs):
# dict
if isinstance(value, dict):
return value
# list of dicts
elif isinstance(value, list) and all(isinstance(item, dict) for item in value):
return value
else:
raise ValidationError("Field should be dict or list of dicts")
if not isinstance(value, dict):
if not isinstance(value, list) or not all(
isinstance(item, dict) for item in value
):
raise ValidationError("Field should be dict or list of dicts")
return super()._deserialize(value, attr, data, **kwargs)


class UriOrDictField(StrOrDictField):
"""URI or Dict field for Marshmallow."""

def __init__(self, *args, **kwargs):
"""Initialize new UriOrDictField instance."""
super().__init__(*args, **kwargs)

# Insert validation into self.validators so that multiple errors can be stored.
self.validators.insert(0, self._uri_validator)

def _uri_validator(self, value):
# Check if URI when
def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, str):
return Uri()(value)
# Check regex
Uri()(value)
return super()._deserialize(value, attr, data, **kwargs)


class IntEpoch(Range):
Expand Down Expand Up @@ -775,7 +756,7 @@ def __call__(self, value):
except ValidationError:
raise ValidationError(
f"credential subject id {value[0]} must be URI"
)
) from None

return value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _serialize(self, value, attr, obj, **kwargs):
"""
return value.serialize()

def _deserialize(self, value, attr, data, **kwargs):
def _deserialize(self, value, attr=None, data=None, **kwargs):
"""
Deserialize a value into a DIDDoc.
Expand Down
20 changes: 9 additions & 11 deletions aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,17 @@ def _serialize(self, value, attr, obj, **kwargs):
def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, dict):
return Service.deserialize(value)
elif isinstance(value, Service):
return value
elif isinstance(value, str):
if bool(DIDValidation.PATTERN.match(value)):
return value
else:
if not DIDValidation.PATTERN.match(value):
raise ValidationError(
"Service item must be a valid decentralized identifier (DID)"
)
return value
raise ValidationError(
"Service item must be a valid decentralized identifier (DID) or object"
)


class InvitationMessage(AgentMessage):
Expand Down Expand Up @@ -221,9 +225,6 @@ class Meta:
fields.Str(
description="Handshake protocol",
example=DIDCommPrefix.qualify_current(HSProto.RFC23.name),
validate=lambda hsp: (
DIDCommPrefix.unqualify(hsp) in [p.name for p in HSProto]
),
),
required=False,
)
Expand Down Expand Up @@ -276,13 +277,10 @@ def validate_fields(self, data, **kwargs):
"""
handshake_protocols = data.get("handshake_protocols")
requests_attach = data.get("requests_attach")
if not (
(handshake_protocols and len(handshake_protocols) > 0)
or (requests_attach and len(requests_attach) > 0)
):
if not handshake_protocols and not requests_attach:
raise ValidationError(
"Model must include non-empty "
"handshake_protocols or requests_attach or both"
"handshake_protocols or requests~attach or both"
)

# services = data.get("services")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def test_invalid_invi_wrong_type_services(self):
"services": [123],
}

invi_schema = InvitationMessageSchema()
with pytest.raises(test_module.ValidationError):
invi_schema.validate_fields(obj_x)
errs = InvitationMessageSchema().validate(obj_x)
assert errs and "services" in errs

def test_assign_msg_type_version_to_model_inst(self):
test_msg = InvitationMessage()
Expand Down

0 comments on commit 6f2ef55

Please sign in to comment.