diff --git a/aries_cloudagent/admin/server.py b/aries_cloudagent/admin/server.py index 7330699f7c..c6eda4a5a6 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -53,6 +53,7 @@ "acapy::actionmenu::received": "actionmenu", "acapy::actionmenu::get-active-menu": "get-active-menu", "acapy::actionmenu::perform-menu-action": "perform-menu-action", + "acapy::keylist::updated": "keylist", } diff --git a/aries_cloudagent/core/event_bus.py b/aries_cloudagent/core/event_bus.py index 180f4130dc..22d7c8f922 100644 --- a/aries_cloudagent/core/event_bus.py +++ b/aries_cloudagent/core/event_bus.py @@ -15,6 +15,7 @@ Optional, Pattern, TYPE_CHECKING, + Tuple, ) from functools import partial @@ -193,7 +194,7 @@ class MockEventBus(EventBus): def __init__(self): """Initialize MockEventBus.""" super().__init__() - self.events = [] + self.events: List[Tuple[Profile, Event]] = [] async def notify(self, profile: "Profile", event: Event): """Append the event to MockEventBus.events.""" diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py index 8e8f8922a5..20c63a5e15 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py @@ -32,6 +32,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder): session, context.connection_record.connection_id ) response = await mgr.update_keylist(record, updates=context.message.updates) + response.assign_thread_from(context.message) await responder.send_reply(response) except (StorageNotFoundError, MediationNotGrantedError): reply = CMProblemReport( diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py index bcca040072..79e2b56aec 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py @@ -25,3 +25,6 @@ async def handle(self, context: RequestContext, responder: BaseResponder): await mgr.store_update_results( context.connection_record.connection_id, context.message.updated ) + await mgr.notify_keylist_updated( + context.connection_record.connection_id, context.message + ) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py index c9e3bb868f..75218935d0 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py @@ -46,6 +46,9 @@ async def test_handler(self): handler, responder = KeylistUpdateResponseHandler(), MockResponder() with async_mock.patch.object( MediationManager, "store_update_results" - ) as mock_method: + ) as mock_store, async_mock.patch.object( + MediationManager, "notify_keylist_updated" + ) as mock_notify: await handler.handle(self.context, responder) - mock_method.assert_called_once_with(TEST_CONN_ID, self.updated) + mock_store.assert_called_once_with(TEST_CONN_ID, self.updated) + mock_notify.assert_called_once_with(TEST_CONN_ID, self.context.message) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py index 8fb2d2a449..195dd5fd5a 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py @@ -60,6 +60,7 @@ class MediationManager: SET_TO_DEFAULT_ON_GRANTED = "set_to_default_on_granted" METADATA_KEY = "mediation" METADATA_ID = "id" + KEYLIST_UPDATED_EVENT = "acapy::keylist::updated" def __init__(self, profile: Profile): """Initialize Mediation Manager. @@ -534,63 +535,79 @@ async def store_update_results( session: An active profile session """ - session = await self._profile.session() to_save: Sequence[RouteRecord] = [] to_remove: Sequence[RouteRecord] = [] - for updated in results: - if updated.result != KeylistUpdated.RESULT_SUCCESS: - # TODO better handle different results? - LOGGER.warning( - "Keylist update failure: %s(%s): %s", - updated.action, - updated.recipient_key, - updated.result, - ) - continue - if updated.action == KeylistUpdateRule.RULE_ADD: - # Multi-tenancy uses route record for internal relaying of wallets - # So the record could already exist. We update in that case - try: - record = await RouteRecord.retrieve_by_recipient_key( - session, updated.recipient_key - ) - record.connection_id = connection_id - record.role = RouteRecord.ROLE_CLIENT - except StorageNotFoundError: - record = RouteRecord( - role=RouteRecord.ROLE_CLIENT, - recipient_key=updated.recipient_key, - connection_id=connection_id, - ) - to_save.append(record) - elif updated.action == KeylistUpdateRule.RULE_REMOVE: - try: - records = await RouteRecord.query( - session, - { - "role": RouteRecord.ROLE_CLIENT, - "connection_id": connection_id, - "recipient_key": updated.recipient_key, - }, - ) - except StorageNotFoundError as err: - LOGGER.error( - "No route found while processing keylist update response: %s", - err, + + async with self._profile.session() as session: + for updated in results: + if updated.result != KeylistUpdated.RESULT_SUCCESS: + # TODO better handle different results? + LOGGER.warning( + "Keylist update failure: %s(%s): %s", + updated.action, + updated.recipient_key, + updated.result, ) - else: - if len(records) > 1: + continue + if updated.action == KeylistUpdateRule.RULE_ADD: + # Multi-tenancy uses route record for internal relaying of wallets + # So the record could already exist. We update in that case + try: + record = await RouteRecord.retrieve_by_recipient_key( + session, updated.recipient_key + ) + record.connection_id = connection_id + record.role = RouteRecord.ROLE_CLIENT + except StorageNotFoundError: + record = RouteRecord( + role=RouteRecord.ROLE_CLIENT, + recipient_key=updated.recipient_key, + connection_id=connection_id, + ) + to_save.append(record) + elif updated.action == KeylistUpdateRule.RULE_REMOVE: + try: + records = await RouteRecord.query( + session, + { + "role": RouteRecord.ROLE_CLIENT, + "connection_id": connection_id, + "recipient_key": updated.recipient_key, + }, + ) + except StorageNotFoundError as err: LOGGER.error( - f"Too many ({len(records)}) routes found " - "while processing keylist update response" + "No route found while processing keylist update response: %s", + err, ) - record = records[0] - to_remove.append(record) + else: + if len(records) > 1: + LOGGER.error( + f"Too many ({len(records)}) routes found " + "while processing keylist update response" + ) + record = records[0] + to_remove.append(record) + + for record_for_saving in to_save: + await record_for_saving.save( + session, reason="Route successfully added." + ) + for record_for_removal in to_remove: + await record_for_removal.delete_record(session) - for record_for_saving in to_save: - await record_for_saving.save(session, reason="Route successfully added.") - for record_for_removal in to_remove: - await record_for_removal.delete_record(session) + async def notify_keylist_updated( + self, connection_id: str, response: KeylistUpdateResponse + ): + """Notify of keylist update response received.""" + await self._profile.notify( + self.KEYLIST_UPDATED_EVENT, + { + "connection_id": connection_id, + "thread_id": response._thread_id, + "updated": [update.serialize() for update in response.updated], + }, + ) async def get_my_keylist( self, connection_id: Optional[str] = None diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py index d2505ccfb9..f3a03c0a30 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py @@ -1,19 +1,16 @@ """Test MediationManager.""" import logging - -import pytest +from typing import AsyncIterable, Iterable from asynctest import mock as async_mock +import pytest +from .. import manager as test_module +from .....core.event_bus import EventBus, MockEventBus +from .....core.in_memory import InMemoryProfile from .....core.profile import Profile, ProfileSession -from .....connections.models.conn_record import ConnRecord -from .....messaging.request_context import RequestContext from .....storage.error import StorageNotFoundError -from .....transport.inbound.receipt import MessageReceipt - from ....routing.v1_0.models.route_record import RouteRecord - -from .. import manager as test_module from ..manager import ( MediationAlreadyExists, MediationManager, @@ -22,12 +19,14 @@ ) from ..messages.inner.keylist_update_rule import KeylistUpdateRule from ..messages.inner.keylist_updated import KeylistUpdated +from ..messages.keylist_update_response import KeylistUpdateResponse from ..messages.mediate_deny import MediationDeny from ..messages.mediate_grant import MediationGrant from ..messages.mediate_request import MediationRequest from ..models.mediation_record import MediationRecord TEST_CONN_ID = "conn-id" +TEST_THREAD_ID = "thread-id" TEST_ENDPOINT = "https://example.com" TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" TEST_ROUTE_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" @@ -36,29 +35,32 @@ @pytest.fixture -async def profile() -> Profile: +def profile() -> Iterable[Profile]: """Fixture for profile used in tests.""" # pylint: disable=W0621 - context = RequestContext.test_context() - context.message_receipt = MessageReceipt(sender_verkey=TEST_VERKEY) - context.connection_record = ConnRecord(connection_id=TEST_CONN_ID) - yield context.profile + yield InMemoryProfile.test_profile(bind={EventBus: MockEventBus()}) + + +@pytest.fixture +def mock_event_bus(profile: Profile): + yield profile.inject(EventBus) @pytest.fixture -async def session(profile) -> ProfileSession: # pylint: disable=W0621 +async def session(profile) -> AsyncIterable[ProfileSession]: # pylint: disable=W0621 """Fixture for profile session used in tests.""" - yield await profile.session() + async with profile.session() as session: + yield session @pytest.fixture -async def manager(profile) -> MediationManager: # pylint: disable=W0621 +def manager(profile) -> Iterable[MediationManager]: # pylint: disable=W0621 """Fixture for manager used in tests.""" yield MediationManager(profile) @pytest.fixture -def record() -> MediationRecord: +def record() -> Iterable[MediationRecord]: """Fixture for record used in tests.""" yield MediationRecord( state=MediationRecord.STATE_GRANTED, connection_id=TEST_CONN_ID @@ -71,7 +73,7 @@ class TestMediationManager: # pylint: disable=R0904,W0621 async def test_create_manager_no_profile(self): """test_create_manager_no_profile.""" with pytest.raises(MediationManagerError): - await MediationManager(None) + MediationManager(None) async def test_create_did(self, manager, session): """test_create_did.""" @@ -363,7 +365,11 @@ async def test_add_remove_key_mix(self, manager): assert update.updates[0].recipient_key == TEST_VERKEY assert update.updates[1].recipient_key == TEST_ROUTE_VERKEY - async def test_store_update_results(self, session, manager): + async def test_store_update_results( + self, + session: ProfileSession, + manager: MediationManager, + ): """test_store_update_results.""" await RouteRecord( role=RouteRecord.ROLE_CLIENT, @@ -467,6 +473,34 @@ async def test_store_update_results_errors(self, caplog, manager): assert "server_error" in caplog.text print(caplog.text) + async def test_notify_keylist_updated( + self, manager: MediationManager, mock_event_bus: MockEventBus + ): + """test notify_keylist_updated.""" + response = KeylistUpdateResponse( + updated=[ + KeylistUpdated( + recipient_key=TEST_ROUTE_VERKEY, + action=KeylistUpdateRule.RULE_ADD, + result=KeylistUpdated.RESULT_SUCCESS, + ), + KeylistUpdated( + recipient_key=TEST_VERKEY, + action=KeylistUpdateRule.RULE_REMOVE, + result=KeylistUpdated.RESULT_SUCCESS, + ), + ], + ) + response.assign_thread_id(TEST_THREAD_ID) + await manager.notify_keylist_updated(TEST_CONN_ID, response) + assert mock_event_bus.events + assert mock_event_bus.events[0][1].topic == manager.KEYLIST_UPDATED_EVENT + assert mock_event_bus.events[0][1].payload == { + "connection_id": TEST_CONN_ID, + "thread_id": TEST_THREAD_ID, + "updated": [result.serialize() for result in response.updated], + } + async def test_get_my_keylist(self, session, manager): """test_get_my_keylist.""" await RouteRecord(