Skip to content

Commit

Permalink
rework based upon feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Shaanjot Gill <[email protected]>
  • Loading branch information
shaangill025 committed Oct 3, 2022
1 parent e0caf06 commit 348e665
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 50 deletions.
6 changes: 3 additions & 3 deletions aries_cloudagent/core/protocol_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging

from string import Template
from re import sub
from typing import Mapping, Sequence

from ..config.injection_context import InjectionContext
Expand Down Expand Up @@ -102,8 +102,8 @@ def create_msg_types_for_minor_version(self, typesets, version_definition):
def _get_updated_tyoeset_dict(self, typesets, to_check, updated_typeset) -> dict:
for typeset in typesets:
for msg_type_string, module_path in typeset.items():
updated_msg_type_string = Template(msg_type_string).substitute(
version=to_check
updated_msg_type_string = sub(
r"(\d+\.)?(\*|\d+)", to_check, msg_type_string
)
updated_typeset[updated_msg_type_string] = module_path
return updated_typeset
Expand Down
26 changes: 12 additions & 14 deletions aries_cloudagent/messaging/agent_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import uuid

from collections import OrderedDict
from re import sub
from typing import Mapping, Union
from string import Template

from marshmallow import (
EXCLUDE,
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(self, _id: str = None, _decorators: BaseDecoratorSet = None):
self.__class__.__name__
)
)
self._message_type = self.Meta.message_type
# Not required for now
# if not self.Meta.handler_class:
# raise TypeError(
Expand All @@ -101,18 +102,6 @@ def _get_handler_class(cls):
"""
return resolve_class(cls.Meta.handler_class, cls)

@classmethod
def assign_version_to_message_type(cls, version: str):
"""Assign version to Meta.message_type."""
if "$version" in cls.Meta.message_type:
cls.Meta.message_type = Template(cls.Meta.message_type).substitute(
version=version
)
else:
cls.Meta.message_type = re.sub(
r"(\d+\.)?(\*|\d+)", version, cls.Meta.message_type
)

@property
def Handler(self) -> type:
"""
Expand All @@ -133,7 +122,12 @@ def _type(self) -> str:
Current DIDComm prefix, slash, message type defined on `Meta.message_type`
"""
return DIDCommPrefix.qualify_current(self.Meta.message_type)
return DIDCommPrefix.qualify_current(self._message_type)

@_type.setter
def _type(self, msg_type: str):
"""Set the message type identifier."""
self._message_type = msg_type

@property
def _id(self) -> str:
Expand Down Expand Up @@ -161,6 +155,10 @@ def _decorators(self, value: BaseDecoratorSet):
"""Fetch the message's decorator set."""
self._message_decorators = value

def get_updated_msg_type(self, version: str) -> str:
"""Update version to Meta.message_type."""
return sub(r"(\d+\.)?(\*|\d+)", version, self.Meta.message_type)

def get_signature(self, field_name: str) -> SignatureDecorator:
"""
Get the signature for a named field.
Expand Down
7 changes: 1 addition & 6 deletions aries_cloudagent/messaging/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,10 @@ def make_model(self, data: dict, **kwargs):
try:
cls_inst = self.Model(**data)
except TypeError as err:
msg_type_version = None
if "_type" in str(err) and "_type" in data:
match = re.search(r"(\d+\.)?(\*|\d+)", data["_type"])
if match:
msg_type_version = match.group()
data["msg_type"] = data["_type"]
del data["_type"]
cls_inst = self.Model(**data)
if msg_type_version:
cls_inst.assign_version_to_message_type(msg_type_version)
return cls_inst

@post_dump
Expand Down
13 changes: 7 additions & 6 deletions aries_cloudagent/protocols/out_of_band/v1_0/message_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
"2da7fc4ee043effa3a9960150e7ba8c9a4628b68/features/0434-outofband"
)

# Message types
INVITATION = "out-of-band/$version/invitation"
MESSAGE_REUSE = "out-of-band/$version/handshake-reuse"
MESSAGE_REUSE_ACCEPT = "out-of-band/$version/handshake-reuse-accepted"
PROBLEM_REPORT = "out-of-band/$version/problem_report"

# Default Version
DEFAULT_VERSION = get_proto_default_version(
"aries_cloudagent.protocols.out_of_band.definition", 1
)

# Message types
INVITATION = f"out-of-band/{DEFAULT_VERSION}/invitation"
MESSAGE_REUSE = f"out-of-band/{DEFAULT_VERSION}/handshake-reuse"
MESSAGE_REUSE_ACCEPT = f"out-of-band/{DEFAULT_VERSION}/handshake-reuse-accepted"
PROBLEM_REPORT = f"out-of-band/{DEFAULT_VERSION}/problem_report"


PROTOCOL_PACKAGE = "aries_cloudagent.protocols.out_of_band.v1_0"

MESSAGE_TYPES = DIDCommPrefix.qualify_all(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
services: Sequence[Union[Service, Text]] = None,
accept: Optional[Sequence[Text]] = None,
version: str = DEFAULT_VERSION,
msg_type: Optional[Text] = None,
**kwargs,
):
"""
Expand All @@ -142,7 +143,10 @@ def __init__(
)
self.requests_attach = list(requests_attach) if requests_attach else []
self.services = services
self.assign_version_to_message_type(version)
if msg_type:
self._type = msg_type
else:
self._type = self.get_updated_msg_type(version)
self.accept = accept

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging

from enum import Enum
from typing import Optional, Text

from marshmallow import (
EXCLUDE,
Expand Down Expand Up @@ -41,10 +42,19 @@ class Meta:
message_type = PROBLEM_REPORT
schema_class = "OOBProblemReportSchema"

def __init__(self, version: str = DEFAULT_VERSION, *args, **kwargs):
def __init__(
self,
version: str = DEFAULT_VERSION,
msg_type: Optional[Text] = None,
*args,
**kwargs,
):
"""Initialize a ProblemReport message instance."""
super().__init__(*args, **kwargs)
self.assign_version_to_message_type(version=version)
if msg_type:
self._type = msg_type
else:
self._type = self.get_updated_msg_type(version)


class OOBProblemReportSchema(ProblemReportSchema):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Represents a Handshake Reuse message under RFC 0434."""

from marshmallow import EXCLUDE, fields, pre_dump, ValidationError
from typing import Optional, Text

from .....messaging.agent_message import AgentMessage, AgentMessageSchema

Expand All @@ -24,11 +25,15 @@ class Meta:
def __init__(
self,
version: str = DEFAULT_VERSION,
msg_type: Optional[Text] = None,
**kwargs,
):
"""Initialize Handshake Reuse message object."""
super().__init__(**kwargs)
self.assign_version_to_message_type(version=version)
if msg_type:
self._type = msg_type
else:
self._type = self.get_updated_msg_type(version)


class HandshakeReuseSchema(AgentMessageSchema):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Represents a Handshake Reuse Accept message under RFC 0434."""

from marshmallow import EXCLUDE, fields, pre_dump, ValidationError
from typing import Optional, Text

from .....messaging.agent_message import AgentMessage, AgentMessageSchema

Expand All @@ -25,11 +26,15 @@ class Meta:
def __init__(
self,
version: str = DEFAULT_VERSION,
msg_type: Optional[Text] = None,
**kwargs,
):
"""Initialize Handshake Reuse Accept object."""
super().__init__(**kwargs)
self.assign_version_to_message_type(version=version)
if msg_type:
self._type = msg_type
else:
self._type = self.get_updated_msg_type(version)


class HandshakeReuseAcceptSchema(AgentMessageSchema):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

from string import Template
from unittest import TestCase

from ......messaging.models.base import BaseModelError
Expand All @@ -10,6 +9,7 @@
from .....connections.v1_0.message_types import ARIES_PROTOCOL as CONN_PROTO
from .....didcomm_prefix import DIDCommPrefix
from .....didexchange.v1_0.message_types import ARIES_PROTOCOL as DIDX_PROTO
from .....didexchange.v1_0.messages.request import DIDXRequest

from ...message_types import INVITATION

Expand Down Expand Up @@ -45,16 +45,14 @@ def test_properties(self):
class TestInvitationMessage(TestCase):
def test_init(self):
"""Test initialization message."""
invi = InvitationMessage(
invi_msg = InvitationMessage(
comment="Hello",
label="A label",
handshake_protocols=[DIDCommPrefix.qualify_current(DIDX_PROTO)],
services=[TEST_DID],
)
assert invi.services == [TEST_DID]
assert invi._type == DIDCommPrefix.qualify_current(
Template(INVITATION).substitute(version="1.1")
)
assert invi_msg.services == [TEST_DID]
assert "out-of-band/1.1/invitation" in invi_msg._type

service = Service(_id="#inline", _type=DID_COMM, did=TEST_DID)
invi_msg = InvitationMessage(
Expand All @@ -65,9 +63,7 @@ def test_init(self):
version="1.0",
)
assert invi_msg.services == [service]
assert invi_msg._type == DIDCommPrefix.qualify_current(
Template(INVITATION).substitute(version="1.0")
)
assert "out-of-band/1.0/invitation" in invi_msg._type

def test_wrap_serde(self):
"""Test conversion of aries message to attachment decorator."""
Expand Down Expand Up @@ -150,3 +146,15 @@ def test_invalid_invi_wrong_type_services(self):
invi_schema = InvitationMessageSchema()
with pytest.raises(test_module.ValidationError):
invi_schema.validate_fields(obj_x)

def test_assign_msg_type_version_to_model_inst(self):
test_msg = InvitationMessage()
assert "1.1" in test_msg._type
assert "1.1" in InvitationMessage.Meta.message_type
test_msg = InvitationMessage(version="1.2")
assert "1.2" in test_msg._type
assert "1.1" in InvitationMessage.Meta.message_type
test_req = DIDXRequest()
assert "1.0" in test_req._type
assert "1.2" in test_msg._type
assert "1.1" in InvitationMessage.Meta.message_type
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,11 @@ def test_validate_and_logger(self):
self._caplog.set_level(logging.WARNING)
OOBProblemReportSchema().validate_fields(data)
assert "Unexpected error code received" in self._caplog.text

def test_assign_msg_type_version_to_model_inst(self):
test_msg = OOBProblemReport()
assert "1.1" in test_msg._type
assert "1.1" in OOBProblemReport.Meta.message_type
test_msg = OOBProblemReport(version="1.2")
assert "1.2" in test_msg._type
assert "1.1" in OOBProblemReport.Meta.message_type
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,11 @@ def test_pre_dump_x(self):
"""Exercise pre-dump serialization requirements."""
with pytest.raises(BaseModelError):
data = self.reuse_msg.serialize()

def test_assign_msg_type_version_to_model_inst(self):
test_msg = HandshakeReuse()
assert "1.1" in test_msg._type
assert "1.1" in HandshakeReuse.Meta.message_type
test_msg = HandshakeReuse(version="1.2")
assert "1.2" in test_msg._type
assert "1.1" in HandshakeReuse.Meta.message_type
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from ......messaging.models.base import BaseModelError

from .....didcomm_prefix import DIDCommPrefix

from ..reuse_accept import HandshakeReuseAccept, HandshakeReuseAcceptSchema


Expand All @@ -31,7 +33,25 @@ def test_make_model(self):
model_instance = HandshakeReuseAccept.deserialize(data)
assert isinstance(model_instance, HandshakeReuseAccept)

def test_make_model_backward_comp(self):
"""Make reuse-accept model."""
self.reuse_accept_msg.assign_thread_id(thid="test_thid", pthid="test_pthid")
data = self.reuse_accept_msg.serialize()
data["@type"] = DIDCommPrefix.qualify_current(
"out-of-band/1.0/handshake-reuse-accepted"
)
model_instance = HandshakeReuseAccept.deserialize(data)
assert isinstance(model_instance, HandshakeReuseAccept)

def test_pre_dump_x(self):
"""Exercise pre-dump serialization requirements."""
with pytest.raises(BaseModelError):
data = self.reuse_accept_msg.serialize()

def test_assign_msg_type_version_to_model_inst(self):
test_msg = HandshakeReuseAccept()
assert "1.1" in test_msg._type
assert "1.1" in HandshakeReuseAccept.Meta.message_type
test_msg = HandshakeReuseAccept(version="1.2")
assert "1.2" in test_msg._type
assert "1.1" in HandshakeReuseAccept.Meta.message_type
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from string import Template
from typing import List
from unittest.mock import ANY

Expand Down Expand Up @@ -390,7 +389,7 @@ async def test_create_invitation_handshake_succeeds(self):
)

assert invi_rec.invitation._type == DIDCommPrefix.qualify_current(
Template(INVITATION).substitute(version="1.1")
"out-of-band/1.1/invitation"
)
assert not invi_rec.invitation.requests_attach
assert (
Expand Down Expand Up @@ -477,7 +476,7 @@ async def test_create_invitation_mediation_overwrites_routing_and_endpoint(self)
)
assert isinstance(invite, InvitationRecord)
assert invite.invitation._type == DIDCommPrefix.qualify_current(
Template(INVITATION).substitute(version="1.1")
"out-of-band/1.1/invitation"
)
assert invite.invitation.label == "test123"
assert (
Expand Down Expand Up @@ -796,9 +795,7 @@ async def test_create_invitation_peer_did(self):

assert invi_rec._invitation.ser[
"@type"
] == DIDCommPrefix.qualify_current(
Template(INVITATION).substitute(version="1.1")
)
] == DIDCommPrefix.qualify_current("out-of-band/1.1/invitation")
assert not invi_rec._invitation.ser.get("requests~attach")
assert invi_rec.invitation.label == "That guy"
assert (
Expand Down Expand Up @@ -906,7 +903,7 @@ async def test_create_handshake_reuse_msg(self):

# Assert responder has been called with the reuse message
assert reuse_message._type == DIDCommPrefix.qualify_current(
Template(MESSAGE_REUSE).substitute(version="1.1")
"out-of-band/1.1/handshake-reuse"
)
assert oob_record.reuse_msg_id == reuse_message._id

Expand Down

0 comments on commit 348e665

Please sign in to comment.