From 83a3681cc4f5e8dc83584d9927fe7a82de811e10 Mon Sep 17 00:00:00 2001 From: Daniel Bluhm Date: Wed, 11 May 2022 09:35:45 -0400 Subject: [PATCH 1/2] feat: event and webhook on keylist update stored Signed-off-by: Daniel Bluhm --- aries_cloudagent/admin/server.py | 1 + aries_cloudagent/core/event_bus.py | 3 +- .../coordinate_mediation/v1_0/manager.py | 116 ++++++++++-------- .../v1_0/tests/test_mediation_manager.py | 39 ++++-- 4 files changed, 95 insertions(+), 64 deletions(-) diff --git a/aries_cloudagent/admin/server.py b/aries_cloudagent/admin/server.py index 0bbada1da6..70ec98c308 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -52,6 +52,7 @@ "acapy::actionmenu::received": "actionmenu", "acapy::actionmenu::get-active-menu": "get-active-menu", "acapy::actionmenu::perform-menu-action": "perform-menu-action", + "acapy::keylist::updated": "keylist", } diff --git a/aries_cloudagent/core/event_bus.py b/aries_cloudagent/core/event_bus.py index 180f4130dc..22d7c8f922 100644 --- a/aries_cloudagent/core/event_bus.py +++ b/aries_cloudagent/core/event_bus.py @@ -15,6 +15,7 @@ Optional, Pattern, TYPE_CHECKING, + Tuple, ) from functools import partial @@ -193,7 +194,7 @@ class MockEventBus(EventBus): def __init__(self): """Initialize MockEventBus.""" super().__init__() - self.events = [] + self.events: List[Tuple[Profile, Event]] = [] async def notify(self, profile: "Profile", event: Event): """Append the event to MockEventBus.events.""" diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py index 8fb2d2a449..6da7845a94 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py @@ -60,6 +60,7 @@ class MediationManager: SET_TO_DEFAULT_ON_GRANTED = "set_to_default_on_granted" METADATA_KEY = "mediation" METADATA_ID = "id" + KEYLIST_UPDATED_EVENT = "acapy::keylist::updated" def __init__(self, profile: Profile): """Initialize Mediation Manager. @@ -534,63 +535,74 @@ async def store_update_results( session: An active profile session """ - session = await self._profile.session() to_save: Sequence[RouteRecord] = [] to_remove: Sequence[RouteRecord] = [] - for updated in results: - if updated.result != KeylistUpdated.RESULT_SUCCESS: - # TODO better handle different results? - LOGGER.warning( - "Keylist update failure: %s(%s): %s", - updated.action, - updated.recipient_key, - updated.result, - ) - continue - if updated.action == KeylistUpdateRule.RULE_ADD: - # Multi-tenancy uses route record for internal relaying of wallets - # So the record could already exist. We update in that case - try: - record = await RouteRecord.retrieve_by_recipient_key( - session, updated.recipient_key - ) - record.connection_id = connection_id - record.role = RouteRecord.ROLE_CLIENT - except StorageNotFoundError: - record = RouteRecord( - role=RouteRecord.ROLE_CLIENT, - recipient_key=updated.recipient_key, - connection_id=connection_id, - ) - to_save.append(record) - elif updated.action == KeylistUpdateRule.RULE_REMOVE: - try: - records = await RouteRecord.query( - session, - { - "role": RouteRecord.ROLE_CLIENT, - "connection_id": connection_id, - "recipient_key": updated.recipient_key, - }, - ) - except StorageNotFoundError as err: - LOGGER.error( - "No route found while processing keylist update response: %s", - err, + + async with self._profile.session() as session: + for updated in results: + if updated.result != KeylistUpdated.RESULT_SUCCESS: + # TODO better handle different results? + LOGGER.warning( + "Keylist update failure: %s(%s): %s", + updated.action, + updated.recipient_key, + updated.result, ) - else: - if len(records) > 1: + continue + if updated.action == KeylistUpdateRule.RULE_ADD: + # Multi-tenancy uses route record for internal relaying of wallets + # So the record could already exist. We update in that case + try: + record = await RouteRecord.retrieve_by_recipient_key( + session, updated.recipient_key + ) + record.connection_id = connection_id + record.role = RouteRecord.ROLE_CLIENT + except StorageNotFoundError: + record = RouteRecord( + role=RouteRecord.ROLE_CLIENT, + recipient_key=updated.recipient_key, + connection_id=connection_id, + ) + to_save.append(record) + elif updated.action == KeylistUpdateRule.RULE_REMOVE: + try: + records = await RouteRecord.query( + session, + { + "role": RouteRecord.ROLE_CLIENT, + "connection_id": connection_id, + "recipient_key": updated.recipient_key, + }, + ) + except StorageNotFoundError as err: LOGGER.error( - f"Too many ({len(records)}) routes found " - "while processing keylist update response" + "No route found while processing keylist update response: %s", + err, ) - record = records[0] - to_remove.append(record) - - for record_for_saving in to_save: - await record_for_saving.save(session, reason="Route successfully added.") - for record_for_removal in to_remove: - await record_for_removal.delete_record(session) + else: + if len(records) > 1: + LOGGER.error( + f"Too many ({len(records)}) routes found " + "while processing keylist update response" + ) + record = records[0] + to_remove.append(record) + + for record_for_saving in to_save: + await record_for_saving.save( + session, reason="Route successfully added." + ) + for record_for_removal in to_remove: + await record_for_removal.delete_record(session) + + await self._profile.notify( + self.KEYLIST_UPDATED_EVENT, + { + "connection_id": connection_id, + "updated": [update.serialize() for update in results], + }, + ) async def get_my_keylist( self, connection_id: Optional[str] = None diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py index d2505ccfb9..3f834f589d 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py @@ -1,11 +1,14 @@ """Test MediationManager.""" import logging +from typing import AsyncIterable, Iterable import pytest from asynctest import mock as async_mock from .....core.profile import Profile, ProfileSession +from .....core.in_memory import InMemoryProfile +from .....core.event_bus import EventBus, MockEventBus from .....connections.models.conn_record import ConnRecord from .....messaging.request_context import RequestContext from .....storage.error import StorageNotFoundError @@ -36,29 +39,32 @@ @pytest.fixture -async def profile() -> Profile: +def profile() -> Iterable[Profile]: """Fixture for profile used in tests.""" # pylint: disable=W0621 - context = RequestContext.test_context() - context.message_receipt = MessageReceipt(sender_verkey=TEST_VERKEY) - context.connection_record = ConnRecord(connection_id=TEST_CONN_ID) - yield context.profile + yield InMemoryProfile.test_profile(bind={EventBus: MockEventBus()}) @pytest.fixture -async def session(profile) -> ProfileSession: # pylint: disable=W0621 +def mock_event_bus(profile: Profile): + yield profile.inject(EventBus) + + +@pytest.fixture +async def session(profile) -> AsyncIterable[ProfileSession]: # pylint: disable=W0621 """Fixture for profile session used in tests.""" - yield await profile.session() + async with profile.session() as session: + yield session @pytest.fixture -async def manager(profile) -> MediationManager: # pylint: disable=W0621 +def manager(profile) -> Iterable[MediationManager]: # pylint: disable=W0621 """Fixture for manager used in tests.""" yield MediationManager(profile) @pytest.fixture -def record() -> MediationRecord: +def record() -> Iterable[MediationRecord]: """Fixture for record used in tests.""" yield MediationRecord( state=MediationRecord.STATE_GRANTED, connection_id=TEST_CONN_ID @@ -71,7 +77,7 @@ class TestMediationManager: # pylint: disable=R0904,W0621 async def test_create_manager_no_profile(self): """test_create_manager_no_profile.""" with pytest.raises(MediationManagerError): - await MediationManager(None) + MediationManager(None) async def test_create_did(self, manager, session): """test_create_did.""" @@ -363,7 +369,12 @@ async def test_add_remove_key_mix(self, manager): assert update.updates[0].recipient_key == TEST_VERKEY assert update.updates[1].recipient_key == TEST_ROUTE_VERKEY - async def test_store_update_results(self, session, manager): + async def test_store_update_results( + self, + session: ProfileSession, + manager: MediationManager, + mock_event_bus: MockEventBus, + ): """test_store_update_results.""" await RouteRecord( role=RouteRecord.ROLE_CLIENT, @@ -383,6 +394,12 @@ async def test_store_update_results(self, session, manager): ), ] await manager.store_update_results(TEST_CONN_ID, results) + assert mock_event_bus.events + assert mock_event_bus.events[0][1].topic == manager.KEYLIST_UPDATED_EVENT + assert mock_event_bus.events[0][1].payload == { + "connection_id": TEST_CONN_ID, + "updated": [result.serialize() for result in results], + } routes = await RouteRecord.query(session) assert len(routes) == 1 From e180aa0431f1418511add34256d30ceed36fad82 Mon Sep 17 00:00:00 2001 From: Daniel Bluhm Date: Wed, 22 Jun 2022 11:44:09 -0400 Subject: [PATCH 2/2] feat: include thread id in keylist update recv webhooks Signed-off-by: Daniel Bluhm --- .../v1_0/handlers/keylist_update_handler.py | 1 + .../keylist_update_response_handler.py | 3 ++ .../test_keylist_update_response_handler.py | 7 ++- .../coordinate_mediation/v1_0/manager.py | 7 ++- .../v1_0/tests/test_mediation_manager.py | 51 ++++++++++++------- 5 files changed, 49 insertions(+), 20 deletions(-) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py index 8e8f8922a5..20c63a5e15 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_handler.py @@ -32,6 +32,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder): session, context.connection_record.connection_id ) response = await mgr.update_keylist(record, updates=context.message.updates) + response.assign_thread_from(context.message) await responder.send_reply(response) except (StorageNotFoundError, MediationNotGrantedError): reply = CMProblemReport( diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py index bcca040072..79e2b56aec 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/keylist_update_response_handler.py @@ -25,3 +25,6 @@ async def handle(self, context: RequestContext, responder: BaseResponder): await mgr.store_update_results( context.connection_record.connection_id, context.message.updated ) + await mgr.notify_keylist_updated( + context.connection_record.connection_id, context.message + ) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py index c9e3bb868f..75218935d0 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_keylist_update_response_handler.py @@ -46,6 +46,9 @@ async def test_handler(self): handler, responder = KeylistUpdateResponseHandler(), MockResponder() with async_mock.patch.object( MediationManager, "store_update_results" - ) as mock_method: + ) as mock_store, async_mock.patch.object( + MediationManager, "notify_keylist_updated" + ) as mock_notify: await handler.handle(self.context, responder) - mock_method.assert_called_once_with(TEST_CONN_ID, self.updated) + mock_store.assert_called_once_with(TEST_CONN_ID, self.updated) + mock_notify.assert_called_once_with(TEST_CONN_ID, self.context.message) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py index 6da7845a94..195dd5fd5a 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py @@ -596,11 +596,16 @@ async def store_update_results( for record_for_removal in to_remove: await record_for_removal.delete_record(session) + async def notify_keylist_updated( + self, connection_id: str, response: KeylistUpdateResponse + ): + """Notify of keylist update response received.""" await self._profile.notify( self.KEYLIST_UPDATED_EVENT, { "connection_id": connection_id, - "updated": [update.serialize() for update in results], + "thread_id": response._thread_id, + "updated": [update.serialize() for update in response.updated], }, ) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py index 3f834f589d..f3a03c0a30 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py @@ -2,21 +2,15 @@ import logging from typing import AsyncIterable, Iterable -import pytest - from asynctest import mock as async_mock +import pytest -from .....core.profile import Profile, ProfileSession -from .....core.in_memory import InMemoryProfile +from .. import manager as test_module from .....core.event_bus import EventBus, MockEventBus -from .....connections.models.conn_record import ConnRecord -from .....messaging.request_context import RequestContext +from .....core.in_memory import InMemoryProfile +from .....core.profile import Profile, ProfileSession from .....storage.error import StorageNotFoundError -from .....transport.inbound.receipt import MessageReceipt - from ....routing.v1_0.models.route_record import RouteRecord - -from .. import manager as test_module from ..manager import ( MediationAlreadyExists, MediationManager, @@ -25,12 +19,14 @@ ) from ..messages.inner.keylist_update_rule import KeylistUpdateRule from ..messages.inner.keylist_updated import KeylistUpdated +from ..messages.keylist_update_response import KeylistUpdateResponse from ..messages.mediate_deny import MediationDeny from ..messages.mediate_grant import MediationGrant from ..messages.mediate_request import MediationRequest from ..models.mediation_record import MediationRecord TEST_CONN_ID = "conn-id" +TEST_THREAD_ID = "thread-id" TEST_ENDPOINT = "https://example.com" TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" TEST_ROUTE_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" @@ -373,7 +369,6 @@ async def test_store_update_results( self, session: ProfileSession, manager: MediationManager, - mock_event_bus: MockEventBus, ): """test_store_update_results.""" await RouteRecord( @@ -394,12 +389,6 @@ async def test_store_update_results( ), ] await manager.store_update_results(TEST_CONN_ID, results) - assert mock_event_bus.events - assert mock_event_bus.events[0][1].topic == manager.KEYLIST_UPDATED_EVENT - assert mock_event_bus.events[0][1].payload == { - "connection_id": TEST_CONN_ID, - "updated": [result.serialize() for result in results], - } routes = await RouteRecord.query(session) assert len(routes) == 1 @@ -484,6 +473,34 @@ async def test_store_update_results_errors(self, caplog, manager): assert "server_error" in caplog.text print(caplog.text) + async def test_notify_keylist_updated( + self, manager: MediationManager, mock_event_bus: MockEventBus + ): + """test notify_keylist_updated.""" + response = KeylistUpdateResponse( + updated=[ + KeylistUpdated( + recipient_key=TEST_ROUTE_VERKEY, + action=KeylistUpdateRule.RULE_ADD, + result=KeylistUpdated.RESULT_SUCCESS, + ), + KeylistUpdated( + recipient_key=TEST_VERKEY, + action=KeylistUpdateRule.RULE_REMOVE, + result=KeylistUpdated.RESULT_SUCCESS, + ), + ], + ) + response.assign_thread_id(TEST_THREAD_ID) + await manager.notify_keylist_updated(TEST_CONN_ID, response) + assert mock_event_bus.events + assert mock_event_bus.events[0][1].topic == manager.KEYLIST_UPDATED_EVENT + assert mock_event_bus.events[0][1].payload == { + "connection_id": TEST_CONN_ID, + "thread_id": TEST_THREAD_ID, + "updated": [result.serialize() for result in response.updated], + } + async def test_get_my_keylist(self, session, manager): """test_get_my_keylist.""" await RouteRecord(