Skip to content

Commit

Permalink
feat: tests for did rotate manager
Browse files Browse the repository at this point in the history
Signed-off-by: Akiff Manji <[email protected]>
  • Loading branch information
amanji committed Mar 14, 2024
1 parent fa1290b commit bdb2600
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
assert isinstance(context.message, Hangup)

connection_record = context.connection_record
hangup = context.message

profile = context.profile
did_rotate_mgr = DIDRotateManager(profile)

await did_rotate_mgr.receive_hangup(connection_record, hangup)
await did_rotate_mgr.receive_hangup(connection_record)
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
self._logger.debug("ProblemReportHandler called with context %s", context)
assert isinstance(context.message, RotateProblemReport)

connection_record = context.connection_record
problem_report = context.message

profile = context.profile
did_rotate_mgr = DIDRotateManager(profile)

await did_rotate_mgr.receive_problem_report(connection_record, problem_report)
await did_rotate_mgr.receive_problem_report(problem_report)
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ async def test_handle(self, MockDIDRotateManager, request_context):
await handler.handle(request_context, responder)

MockDIDRotateManager.return_value.receive_hangup.assert_called_once_with(
request_context.connection_record, request_context.message
request_context.connection_record
)
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ async def test_handle(self, MockDIDRotateManager, request_context):
await handler.handle(request_context, responder)

MockDIDRotateManager.return_value.receive_problem_report.assert_called_once_with(
request_context.connection_record, request_context.message
request_context.message
)
12 changes: 6 additions & 6 deletions aries_cloudagent/protocols/did_rotate/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ async def hangup(self, conn: ConnRecord) -> Hangup:

hangup = Hangup()

# TODO: Should the connection be terminated here?
async with self.profile.session() as session:
await conn.delete_record(session)

responder = self.profile.inject(BaseResponder)
await responder.send(hangup, connection_id=conn.connection_id)
Expand Down Expand Up @@ -204,9 +205,7 @@ async def receive_ack(self, conn: ConnRecord, ack: RotateAck):
conn_mgr = BaseConnectionManager(self.profile)
await conn_mgr.clear_connection_targets_cache(conn.connection_id)

async def receive_problem_report(
self, conn: ConnRecord, problem_report: RotateProblemReport
):
async def receive_problem_report(self, problem_report: RotateProblemReport):
"""Receive problem report message.
Args:
Expand All @@ -225,14 +224,15 @@ async def receive_problem_report(
async with self.profile.session() as session:
await record.save(session, reason="Received problem report")

async def receive_hangup(self, conn: ConnRecord, hangup: Hangup):
async def receive_hangup(self, conn: ConnRecord):
"""Receive hangup message.
Args:
conn (ConnRecord): The connection to rotate the DID for.
hangup (Hangup): The received hangup message.
"""
# TODO: Should the connection be terminated here?
async with self.profile.session() as session:
await conn.delete_record(session)

async def _ensure_supported_did(self, did: str):
"""Check if the DID is supported."""
Expand Down
9 changes: 9 additions & 0 deletions aries_cloudagent/protocols/did_rotate/v1_0/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .....messaging.valid import UUID4_EXAMPLE

test_conn_id = UUID4_EXAMPLE


class MockConnRecord:
def __init__(self, connection_id, is_ready) -> None:
self.connection_id = connection_id
self.is_ready = is_ready
224 changes: 224 additions & 0 deletions aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from unittest import IsolatedAsyncioTestCase

from .....connections.base_manager import (
BaseConnectionManager,
)
from .....core.in_memory.profile import InMemoryProfile
from .....messaging.responder import BaseResponder, MockResponder
from .....protocols.coordinate_mediation.v1_0.route_manager import RouteManager
from .....protocols.did_rotate.v1_0.manager import (
DIDRotateManager,
ReportableDIDRotateError,
UnrecordableKeysError,
)
from .....protocols.did_rotate.v1_0.messages.ack import RotateAck
from .....protocols.did_rotate.v1_0.messages.problem_report import (
RotateProblemReport,
)
from .....protocols.did_rotate.v1_0.messages.rotate import Rotate
from .....protocols.did_rotate.v1_0.models.rotate_record import RotateRecord
from .....protocols.didcomm_prefix import DIDCommPrefix
from .....resolver.did_resolver import DIDResolver
from .....tests import mock
from .. import message_types as test_message_types
from ..tests import MockConnRecord, test_conn_id


class TestDIDRotateManager(IsolatedAsyncioTestCase):
test_endpoint = "http://localhost"

async def asyncSetUp(self) -> None:
self.responder = MockResponder()

self.route_manager = mock.MagicMock(RouteManager)
self.route_manager.routing_info = mock.CoroutineMock(
return_value=([], self.test_endpoint)
)
self.route_manager.mediation_record_if_id = mock.CoroutineMock(
return_value=None
)
self.route_manager.mediation_record_for_connection = mock.CoroutineMock(
return_value=None
)

self.profile = InMemoryProfile.test_profile(
bind={
BaseResponder: self.responder,
RouteManager: self.route_manager,
DIDResolver: DIDResolver(),
}
)

self.manager = DIDRotateManager(self.profile)
assert self.manager.profile

async def test_hangup(self):
mock_conn_record = MockConnRecord(test_conn_id, True)
mock_conn_record.delete_record = mock.CoroutineMock()

with mock.patch.object(
self.responder, "send", mock.CoroutineMock()
) as mock_send:
msg = await self.manager.hangup(mock_conn_record)
mock_conn_record.delete_record.assert_called_once()
mock_send.assert_called_once()
assert (
msg._type == DIDCommPrefix.OLD.value + "/" + test_message_types.HANGUP
)

async def test_receive_hangup(self):
mock_conn_record = MockConnRecord(test_conn_id, True)
mock_conn_record.delete_record = mock.CoroutineMock()

await self.manager.receive_hangup(mock_conn_record)
mock_conn_record.delete_record.assert_called_once()

async def test_rotate_my_did(self):
mock_conn_record = MockConnRecord(test_conn_id, True)
test_to_did = "did:peer:2:testdid"

with mock.patch.object(
self.responder, "send", mock.CoroutineMock()
) as mock_send:
msg = await self.manager.rotate_my_did(mock_conn_record, test_to_did)
mock_send.assert_called_once()
assert (
msg._type == DIDCommPrefix.OLD.value + "/" + test_message_types.ROTATE
)

async def test_receive_rotate(self):
mock_conn_record = MockConnRecord(test_conn_id, True)

test_to_did = "did:peer:2:testdid"

record = await self.manager.receive_rotate(
mock_conn_record, Rotate(to_did=test_to_did)
)

assert record.RECORD_TYPE == RotateRecord.RECORD_TYPE
assert record.role == record.ROLE_OBSERVING
assert record.state == record.STATE_ROTATE_RECEIVED
assert record.connection_id == mock_conn_record.connection_id

async def test_receive_rotate_x(self):
mock_conn_record = MockConnRecord(test_conn_id, True)

test_to_did = "did:badmethod:1:testdid"
test_problem_report = ReportableDIDRotateError(
RotateProblemReport(problem_items=[{"did": test_to_did}])
)

with mock.patch.object(
self.manager, "_ensure_supported_did", side_effect=test_problem_report
), mock.patch.object(self.responder, "send", mock.CoroutineMock()) as mock_send:
await self.manager.receive_rotate(
mock_conn_record, Rotate(to_did=test_to_did)
)
mock_send.assert_called_once_with(
test_problem_report.message,
connection_id=mock_conn_record.connection_id,
)

@mock.patch.object(
BaseConnectionManager,
"record_keys_for_resolvable_did",
mock.CoroutineMock(),
)
async def test_commit_rotate(self, *_):
mock_conn_record = MockConnRecord(test_conn_id, True)
mock_conn_record.save = mock.CoroutineMock()

test_to_did = "did:peer:2:testdid"

record = await self.manager.receive_rotate(
mock_conn_record, Rotate(to_did=test_to_did)
)
await self.manager.commit_rotate(mock_conn_record, record)

assert record.state == RotateRecord.STATE_ACK_SENT

@mock.patch.object(
BaseConnectionManager,
"record_keys_for_resolvable_did",
mock.CoroutineMock(),
)
async def test_commit_rotate_x_no_new_did(self, *_):
mock_conn_record = MockConnRecord(test_conn_id, True)
mock_conn_record.save = mock.CoroutineMock()

test_to_did = "did:peer:2:testdid"

record = await self.manager.receive_rotate(
mock_conn_record, Rotate(to_did=test_to_did)
)
record.new_did = None

with self.assertRaises(ValueError):
await self.manager.commit_rotate(mock_conn_record, record)

async def test_commit_rotate_x_unrecordable_keys(self):
mock_conn_record = MockConnRecord(test_conn_id, True)
mock_conn_record.save = mock.CoroutineMock()

test_to_did = "did:peer:2:testdid"

record = await self.manager.receive_rotate(
mock_conn_record, Rotate(to_did=test_to_did)
)

with self.assertRaises(UnrecordableKeysError):
await self.manager.commit_rotate(mock_conn_record, record)

@mock.patch.object(
BaseConnectionManager,
"clear_connection_targets_cache",
mock.CoroutineMock(),
)
async def test_receive_ack(self, *_):
mock_conn_record = MockConnRecord(test_conn_id, True)
mock_conn_record.save = mock.CoroutineMock()

with mock.patch.object(
RotateRecord,
"retrieve_by_thread_id",
return_value=mock.CoroutineMock(
return_value=mock.MagicMock(
new_did="did:peer:2:testdid", delete_record=mock.CoroutineMock()
)
),
) as mock_rotate_record:
await self.manager.receive_ack(mock_conn_record, RotateAck())

mock_conn_record.save.assert_called_once()
mock_rotate_record.return_value.delete_record.assert_called_once()

async def test_receive_ack_x(self):
mock_conn_record = MockConnRecord(test_conn_id, True)
mock_conn_record.save = mock.CoroutineMock()

with mock.patch.object(
RotateRecord,
"retrieve_by_thread_id",
return_value=mock.CoroutineMock(),
) as mock_rotate_record:
mock_rotate_record.return_value.new_did = None
with self.assertRaises(ValueError):
await self.manager.receive_ack(mock_conn_record, RotateAck())

async def test_receive_problem_report(self):
test_to_did = "did:badmethod:1:testdid"
mock_problem_report = RotateProblemReport(
description={"code": 123},
problem_items=[{"did": test_to_did}],
)

with mock.patch.object(
RotateRecord,
"retrieve_by_thread_id",
return_value=mock.CoroutineMock(
return_value=mock.MagicMock(save=mock.CoroutineMock())
),
) as mock_rotate_record:
await self.manager.receive_problem_report(mock_problem_report)

mock_rotate_record.return_value.save.assert_called_once()
23 changes: 8 additions & 15 deletions aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,18 @@
from unittest import IsolatedAsyncioTestCase

from .....admin.request_context import AdminRequestContext
from .....messaging.valid import UUID4_EXAMPLE
from .....protocols.didcomm_prefix import DIDCommPrefix
from .....storage.error import StorageNotFoundError
from .....tests import mock
from .. import message_types as test_message_types
from .. import routes as test_module
from ..tests import MockConnRecord, test_conn_id

test_conn_id = UUID4_EXAMPLE
test_valid_rotate_request = {
"to_did": "did:example:newdid",
}


def generate_mock_rotate_message():
schema = test_module.RotateMesageSchema()
msg = schema.load(test_valid_rotate_request)

msg._id = "test-message-id"
msg._type = test_message_types.ROTATE
return msg


def generate_mock_hangup_message():
schema = test_module.HangupMessageSchema()
msg = schema.load({})
Expand All @@ -33,10 +23,13 @@ def generate_mock_hangup_message():
return msg


class MockConnRecord:
def __init__(self, connection_id, is_ready) -> None:
self.connection_id = connection_id
self.is_ready = is_ready
def generate_mock_rotate_message():
schema = test_module.RotateMesageSchema()
msg = schema.load(test_valid_rotate_request)

msg._id = "test-message-id"
msg._type = test_message_types.ROTATE
return msg


class TestDIDRotateRoutes(IsolatedAsyncioTestCase):
Expand Down

0 comments on commit bdb2600

Please sign in to comment.