diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py index 5663dc92b4..1fc7dff655 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/hangup_handler.py @@ -21,9 +21,8 @@ async def handle(self, context: RequestContext, responder: BaseResponder): assert isinstance(context.message, Hangup) connection_record = context.connection_record - hangup = context.message profile = context.profile did_rotate_mgr = DIDRotateManager(profile) - await did_rotate_mgr.receive_hangup(connection_record, hangup) + await did_rotate_mgr.receive_hangup(connection_record) diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py index e4dfb4eb93..4d0bbc8a75 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/problem_report_handler.py @@ -20,10 +20,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("ProblemReportHandler called with context %s", context) assert isinstance(context.message, RotateProblemReport) - connection_record = context.connection_record problem_report = context.message profile = context.profile did_rotate_mgr = DIDRotateManager(profile) - await did_rotate_mgr.receive_problem_report(connection_record, problem_report) + await did_rotate_mgr.receive_problem_report(problem_report) diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py index f3d51fedf1..764a5ee5b8 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_hangup_handler.py @@ -29,5 +29,5 @@ async def test_handle(self, MockDIDRotateManager, request_context): await handler.handle(request_context, responder) MockDIDRotateManager.return_value.receive_hangup.assert_called_once_with( - request_context.connection_record, request_context.message + request_context.connection_record ) diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py index ba00576aba..e461d6c8dc 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/handlers/tests/test_problem_report_handler.py @@ -33,5 +33,5 @@ async def test_handle(self, MockDIDRotateManager, request_context): await handler.handle(request_context, responder) MockDIDRotateManager.return_value.receive_problem_report.assert_called_once_with( - request_context.connection_record, request_context.message + request_context.message ) diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/manager.py b/aries_cloudagent/protocols/did_rotate/v1_0/manager.py index cee81a82c9..c03509f389 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/manager.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/manager.py @@ -79,7 +79,8 @@ async def hangup(self, conn: ConnRecord) -> Hangup: hangup = Hangup() - # TODO: Should the connection be terminated here? + async with self.profile.session() as session: + await conn.delete_record(session) responder = self.profile.inject(BaseResponder) await responder.send(hangup, connection_id=conn.connection_id) @@ -204,9 +205,7 @@ async def receive_ack(self, conn: ConnRecord, ack: RotateAck): conn_mgr = BaseConnectionManager(self.profile) await conn_mgr.clear_connection_targets_cache(conn.connection_id) - async def receive_problem_report( - self, conn: ConnRecord, problem_report: RotateProblemReport - ): + async def receive_problem_report(self, problem_report: RotateProblemReport): """Receive problem report message. Args: @@ -225,14 +224,15 @@ async def receive_problem_report( async with self.profile.session() as session: await record.save(session, reason="Received problem report") - async def receive_hangup(self, conn: ConnRecord, hangup: Hangup): + async def receive_hangup(self, conn: ConnRecord): """Receive hangup message. Args: conn (ConnRecord): The connection to rotate the DID for. hangup (Hangup): The received hangup message. """ - # TODO: Should the connection be terminated here? + async with self.profile.session() as session: + await conn.delete_record(session) async def _ensure_supported_did(self, did: str): """Check if the DID is supported.""" diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/tests/__init__.py b/aries_cloudagent/protocols/did_rotate/v1_0/tests/__init__.py index e69de29bb2..9cb2a9fb90 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/tests/__init__.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/tests/__init__.py @@ -0,0 +1,9 @@ +from .....messaging.valid import UUID4_EXAMPLE + +test_conn_id = UUID4_EXAMPLE + + +class MockConnRecord: + def __init__(self, connection_id, is_ready) -> None: + self.connection_id = connection_id + self.is_ready = is_ready diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py new file mode 100644 index 0000000000..f217637673 --- /dev/null +++ b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_manager.py @@ -0,0 +1,224 @@ +from unittest import IsolatedAsyncioTestCase + +from .....connections.base_manager import ( + BaseConnectionManager, +) +from .....core.in_memory.profile import InMemoryProfile +from .....messaging.responder import BaseResponder, MockResponder +from .....protocols.coordinate_mediation.v1_0.route_manager import RouteManager +from .....protocols.did_rotate.v1_0.manager import ( + DIDRotateManager, + ReportableDIDRotateError, + UnrecordableKeysError, +) +from .....protocols.did_rotate.v1_0.messages.ack import RotateAck +from .....protocols.did_rotate.v1_0.messages.problem_report import ( + RotateProblemReport, +) +from .....protocols.did_rotate.v1_0.messages.rotate import Rotate +from .....protocols.did_rotate.v1_0.models.rotate_record import RotateRecord +from .....protocols.didcomm_prefix import DIDCommPrefix +from .....resolver.did_resolver import DIDResolver +from .....tests import mock +from .. import message_types as test_message_types +from ..tests import MockConnRecord, test_conn_id + + +class TestDIDRotateManager(IsolatedAsyncioTestCase): + test_endpoint = "http://localhost" + + async def asyncSetUp(self) -> None: + self.responder = MockResponder() + + self.route_manager = mock.MagicMock(RouteManager) + self.route_manager.routing_info = mock.CoroutineMock( + return_value=([], self.test_endpoint) + ) + self.route_manager.mediation_record_if_id = mock.CoroutineMock( + return_value=None + ) + self.route_manager.mediation_record_for_connection = mock.CoroutineMock( + return_value=None + ) + + self.profile = InMemoryProfile.test_profile( + bind={ + BaseResponder: self.responder, + RouteManager: self.route_manager, + DIDResolver: DIDResolver(), + } + ) + + self.manager = DIDRotateManager(self.profile) + assert self.manager.profile + + async def test_hangup(self): + mock_conn_record = MockConnRecord(test_conn_id, True) + mock_conn_record.delete_record = mock.CoroutineMock() + + with mock.patch.object( + self.responder, "send", mock.CoroutineMock() + ) as mock_send: + msg = await self.manager.hangup(mock_conn_record) + mock_conn_record.delete_record.assert_called_once() + mock_send.assert_called_once() + assert ( + msg._type == DIDCommPrefix.OLD.value + "/" + test_message_types.HANGUP + ) + + async def test_receive_hangup(self): + mock_conn_record = MockConnRecord(test_conn_id, True) + mock_conn_record.delete_record = mock.CoroutineMock() + + await self.manager.receive_hangup(mock_conn_record) + mock_conn_record.delete_record.assert_called_once() + + async def test_rotate_my_did(self): + mock_conn_record = MockConnRecord(test_conn_id, True) + test_to_did = "did:peer:2:testdid" + + with mock.patch.object( + self.responder, "send", mock.CoroutineMock() + ) as mock_send: + msg = await self.manager.rotate_my_did(mock_conn_record, test_to_did) + mock_send.assert_called_once() + assert ( + msg._type == DIDCommPrefix.OLD.value + "/" + test_message_types.ROTATE + ) + + async def test_receive_rotate(self): + mock_conn_record = MockConnRecord(test_conn_id, True) + + test_to_did = "did:peer:2:testdid" + + record = await self.manager.receive_rotate( + mock_conn_record, Rotate(to_did=test_to_did) + ) + + assert record.RECORD_TYPE == RotateRecord.RECORD_TYPE + assert record.role == record.ROLE_OBSERVING + assert record.state == record.STATE_ROTATE_RECEIVED + assert record.connection_id == mock_conn_record.connection_id + + async def test_receive_rotate_x(self): + mock_conn_record = MockConnRecord(test_conn_id, True) + + test_to_did = "did:badmethod:1:testdid" + test_problem_report = ReportableDIDRotateError( + RotateProblemReport(problem_items=[{"did": test_to_did}]) + ) + + with mock.patch.object( + self.manager, "_ensure_supported_did", side_effect=test_problem_report + ), mock.patch.object(self.responder, "send", mock.CoroutineMock()) as mock_send: + await self.manager.receive_rotate( + mock_conn_record, Rotate(to_did=test_to_did) + ) + mock_send.assert_called_once_with( + test_problem_report.message, + connection_id=mock_conn_record.connection_id, + ) + + @mock.patch.object( + BaseConnectionManager, + "record_keys_for_resolvable_did", + mock.CoroutineMock(), + ) + async def test_commit_rotate(self, *_): + mock_conn_record = MockConnRecord(test_conn_id, True) + mock_conn_record.save = mock.CoroutineMock() + + test_to_did = "did:peer:2:testdid" + + record = await self.manager.receive_rotate( + mock_conn_record, Rotate(to_did=test_to_did) + ) + await self.manager.commit_rotate(mock_conn_record, record) + + assert record.state == RotateRecord.STATE_ACK_SENT + + @mock.patch.object( + BaseConnectionManager, + "record_keys_for_resolvable_did", + mock.CoroutineMock(), + ) + async def test_commit_rotate_x_no_new_did(self, *_): + mock_conn_record = MockConnRecord(test_conn_id, True) + mock_conn_record.save = mock.CoroutineMock() + + test_to_did = "did:peer:2:testdid" + + record = await self.manager.receive_rotate( + mock_conn_record, Rotate(to_did=test_to_did) + ) + record.new_did = None + + with self.assertRaises(ValueError): + await self.manager.commit_rotate(mock_conn_record, record) + + async def test_commit_rotate_x_unrecordable_keys(self): + mock_conn_record = MockConnRecord(test_conn_id, True) + mock_conn_record.save = mock.CoroutineMock() + + test_to_did = "did:peer:2:testdid" + + record = await self.manager.receive_rotate( + mock_conn_record, Rotate(to_did=test_to_did) + ) + + with self.assertRaises(UnrecordableKeysError): + await self.manager.commit_rotate(mock_conn_record, record) + + @mock.patch.object( + BaseConnectionManager, + "clear_connection_targets_cache", + mock.CoroutineMock(), + ) + async def test_receive_ack(self, *_): + mock_conn_record = MockConnRecord(test_conn_id, True) + mock_conn_record.save = mock.CoroutineMock() + + with mock.patch.object( + RotateRecord, + "retrieve_by_thread_id", + return_value=mock.CoroutineMock( + return_value=mock.MagicMock( + new_did="did:peer:2:testdid", delete_record=mock.CoroutineMock() + ) + ), + ) as mock_rotate_record: + await self.manager.receive_ack(mock_conn_record, RotateAck()) + + mock_conn_record.save.assert_called_once() + mock_rotate_record.return_value.delete_record.assert_called_once() + + async def test_receive_ack_x(self): + mock_conn_record = MockConnRecord(test_conn_id, True) + mock_conn_record.save = mock.CoroutineMock() + + with mock.patch.object( + RotateRecord, + "retrieve_by_thread_id", + return_value=mock.CoroutineMock(), + ) as mock_rotate_record: + mock_rotate_record.return_value.new_did = None + with self.assertRaises(ValueError): + await self.manager.receive_ack(mock_conn_record, RotateAck()) + + async def test_receive_problem_report(self): + test_to_did = "did:badmethod:1:testdid" + mock_problem_report = RotateProblemReport( + description={"code": 123}, + problem_items=[{"did": test_to_did}], + ) + + with mock.patch.object( + RotateRecord, + "retrieve_by_thread_id", + return_value=mock.CoroutineMock( + return_value=mock.MagicMock(save=mock.CoroutineMock()) + ), + ) as mock_rotate_record: + await self.manager.receive_problem_report(mock_problem_report) + + mock_rotate_record.return_value.save.assert_called_once() diff --git a/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py index db5b483da1..7cb9658b86 100644 --- a/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/did_rotate/v1_0/tests/test_routes.py @@ -2,28 +2,18 @@ from unittest import IsolatedAsyncioTestCase from .....admin.request_context import AdminRequestContext -from .....messaging.valid import UUID4_EXAMPLE from .....protocols.didcomm_prefix import DIDCommPrefix from .....storage.error import StorageNotFoundError from .....tests import mock from .. import message_types as test_message_types from .. import routes as test_module +from ..tests import MockConnRecord, test_conn_id -test_conn_id = UUID4_EXAMPLE test_valid_rotate_request = { "to_did": "did:example:newdid", } -def generate_mock_rotate_message(): - schema = test_module.RotateMesageSchema() - msg = schema.load(test_valid_rotate_request) - - msg._id = "test-message-id" - msg._type = test_message_types.ROTATE - return msg - - def generate_mock_hangup_message(): schema = test_module.HangupMessageSchema() msg = schema.load({}) @@ -33,10 +23,13 @@ def generate_mock_hangup_message(): return msg -class MockConnRecord: - def __init__(self, connection_id, is_ready) -> None: - self.connection_id = connection_id - self.is_ready = is_ready +def generate_mock_rotate_message(): + schema = test_module.RotateMesageSchema() + msg = schema.load(test_valid_rotate_request) + + msg._id = "test-message-id" + msg._type = test_message_types.ROTATE + return msg class TestDIDRotateRoutes(IsolatedAsyncioTestCase):