Skip to content

Commit

Permalink
Merge pull request #1914 from sicpa-dlab/fix/conn-id-in-keylist-webhook
Browse files Browse the repository at this point in the history
feat: include connection ids in keylist update webhook
  • Loading branch information
ianco authored Oct 27, 2022
2 parents d0c4742 + 1a3774c commit 4696a26
Show file tree
Hide file tree
Showing 11 changed files with 293 additions and 56 deletions.
13 changes: 13 additions & 0 deletions aries_cloudagent/multitenant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,19 @@ async def _get_wallet_by_key(self, recipient_key: str) -> Optional[WalletRecord]
except (RouteNotFoundError):
pass

async def get_profile_for_key(
self, context: InjectionContext, recipient_key: str
) -> Optional[Profile]:
"""Retrieve a wallet profile by recipient key."""
wallet = await self._get_wallet_by_key(recipient_key)
if not wallet:
return None

if wallet.requires_external_key:
raise WalletKeyMissingError()

return await self.get_wallet_profile(context, wallet)

async def get_wallets_by_message(
self, message_body, wire_format: BaseWireFormat = None
) -> List[WalletRecord]:
Expand Down
36 changes: 35 additions & 1 deletion aries_cloudagent/multitenant/route_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@
import logging
from typing import List, Optional, Tuple

from ..connections.models.conn_record import ConnRecord
from ..core.profile import Profile
from ..messaging.responder import BaseResponder
from ..protocols.coordinate_mediation.v1_0.manager import MediationManager
from ..protocols.coordinate_mediation.v1_0.models.mediation_record import (
MediationRecord,
)
from ..protocols.coordinate_mediation.v1_0.route_manager import RouteManager
from ..protocols.coordinate_mediation.v1_0.normalization import normalize_from_did_key
from ..protocols.coordinate_mediation.v1_0.route_manager import (
CoordinateMediationV1RouteManager,
RouteManager,
)
from ..protocols.routing.v1_0.manager import RoutingManager
from ..protocols.routing.v1_0.models.route_record import RouteRecord
from ..storage.error import StorageNotFoundError
from .base import BaseMultitenantManager


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -103,3 +109,31 @@ async def routing_info(
my_endpoint = mediation_record.endpoint

return routing_keys, my_endpoint


class BaseWalletRouteManager(CoordinateMediationV1RouteManager):
"""Route manager for operations specific to the base wallet."""

async def connection_from_recipient_key(
self, profile: Profile, recipient_key: str
) -> ConnRecord:
"""Retrieve a connection by recipient key.
The recipient key is expected to be a local key owned by this agent.
Since the multi-tenant base wallet can receive and send keylist updates
for sub wallets, we check the sub wallet's connections before the base
wallet.
"""
LOGGER.debug("Retrieving connection for recipient key for multitenant wallet")
manager = profile.inject(BaseMultitenantManager)
profile_to_search = (
await manager.get_profile_for_key(
profile.context, normalize_from_did_key(recipient_key)
)
or profile
)

return await super().connection_from_recipient_key(
profile_to_search, recipient_key
)
15 changes: 15 additions & 0 deletions aries_cloudagent/multitenant/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,18 @@ async def test_get_wallets_by_message(self):
assert wallets[0] == return_wallets[0]
assert wallets[1] == return_wallets[3]
assert get_wallet_by_key.call_count == 4

async def test_get_profile_for_key(self):
mock_wallet = async_mock.MagicMock()
mock_wallet.requires_external_key = False
with async_mock.patch.object(
self.manager,
"_get_wallet_by_key",
async_mock.CoroutineMock(return_value=mock_wallet),
), async_mock.patch.object(
self.manager, "get_wallet_profile", async_mock.CoroutineMock()
) as mock_get_wallet_profile:
profile = await self.manager.get_profile_for_key(
self.context, "test-verkey"
)
assert profile == mock_get_wallet_profile.return_value
25 changes: 24 additions & 1 deletion aries_cloudagent/multitenant/tests/test_route_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from ...protocols.coordinate_mediation.v1_0.models.mediation_record import (
MediationRecord,
)
from ...protocols.coordinate_mediation.v1_0.route_manager import RouteManager
from ...protocols.routing.v1_0.manager import RoutingManager
from ...protocols.routing.v1_0.models.route_record import RouteRecord
from ...storage.error import StorageNotFoundError
from ..route_manager import MultitenantRouteManager
from ..base import BaseMultitenantManager
from ..route_manager import BaseWalletRouteManager, MultitenantRouteManager

TEST_RECORD_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx"
TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL"
Expand Down Expand Up @@ -55,6 +57,11 @@ def route_manager(root_profile: Profile, sub_profile: Profile, wallet_id: str):
yield MultitenantRouteManager(root_profile)


@pytest.fixture
def base_route_manager():
yield BaseWalletRouteManager()


@pytest.mark.asyncio
async def test_route_for_key_sub_mediator_no_base_mediator(
route_manager: MultitenantRouteManager,
Expand Down Expand Up @@ -360,3 +367,19 @@ async def test_routing_info_with_base_mediator_and_sub_mediator(
)
assert keys == [*base_mediation_record.routing_keys, *mediation_record.routing_keys]
assert endpoint == mediation_record.endpoint


@pytest.mark.asyncio
async def test_connection_from_recipient_key(
sub_profile: Profile, base_route_manager: BaseWalletRouteManager
):
manager = mock.MagicMock()
manager.get_profile_for_key = mock.CoroutineMock(return_value=sub_profile)
sub_profile.context.injector.bind_instance(BaseMultitenantManager, manager)
with mock.patch.object(
RouteManager, "connection_from_recipient_key", mock.CoroutineMock()
) as mock_conn_for_recip:
result = await base_route_manager.connection_from_recipient_key(
sub_profile, TEST_VERKEY
)
assert result == mock_conn_for_recip.return_value
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Handler for keylist-update-response message."""

from .....core.profile import Profile
from .....messaging.base_handler import BaseHandler, HandlerException
from .....messaging.request_context import RequestContext
from .....messaging.responder import BaseResponder

from ..messages.keylist_update_response import KeylistUpdateResponse
from .....storage.error import StorageNotFoundError
from .....wallet.error import WalletNotFoundError
from ..manager import MediationManager
from ..messages.keylist_update_response import KeylistUpdateResponse
from ..route_manager import RouteManager


class KeylistUpdateResponseHandler(BaseHandler):
Expand All @@ -25,6 +28,39 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
await mgr.store_update_results(
context.connection_record.connection_id, context.message.updated
)
await mgr.notify_keylist_updated(
context.connection_record.connection_id, context.message
await self.notify_keylist_updated(
context.profile, context.connection_record.connection_id, context.message
)

async def notify_keylist_updated(
self, profile: Profile, connection_id: str, response: KeylistUpdateResponse
):
"""Notify of keylist update response received."""
route_manager = profile.inject(RouteManager)
self._logger.debug(
"Retrieving connection ID from route manager of type %s",
type(route_manager).__name__,
)
try:
key_to_connection = {
updated.recipient_key: await route_manager.connection_from_recipient_key(
profile, updated.recipient_key
)
for updated in response.updated
}
except (StorageNotFoundError, WalletNotFoundError) as err:
raise HandlerException(
"Unknown recipient key received in keylist update response"
) from err

await profile.notify(
MediationManager.KEYLIST_UPDATED_EVENT,
{
"connection_id": connection_id,
"thread_id": response._thread_id,
"updated": [update.serialize() for update in response.updated],
"mediated_connections": {
key: conn.connection_id for key, conn in key_to_connection.items()
},
},
)
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
"""Test handler for keylist-update-response message."""

from functools import partial
from typing import AsyncGenerator
import pytest
from asynctest import TestCase as AsyncTestCase
from asynctest import mock as async_mock


from ......connections.models.conn_record import ConnRecord
from ......core.event_bus import EventBus, MockEventBus
from ......messaging.base_handler import HandlerException
from ......messaging.request_context import RequestContext
from ......messaging.responder import MockResponder
from ...messages.inner.keylist_update_rule import KeylistUpdateRule
from ...messages.inner.keylist_updated import KeylistUpdated
from ...messages.keylist_update_response import KeylistUpdateResponse
from ...manager import MediationManager
from ...route_manager import RouteManager
from ...tests.test_route_manager import MockRouteManager
from ..keylist_update_response_handler import KeylistUpdateResponseHandler

TEST_CONN_ID = "conn-id"
TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx"
TEST_THREAD_ID = "thread-id"
TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL"
TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya"


class TestKeylistUpdateResponseHandler(AsyncTestCase):
Expand All @@ -34,6 +42,14 @@ async def setUp(self):
self.context.message = KeylistUpdateResponse(updated=self.updated)
self.context.connection_ready = True
self.context.connection_record = ConnRecord(connection_id=TEST_CONN_ID)
self.mock_event_bus = MockEventBus()
self.context.profile.context.injector.bind_instance(
EventBus, self.mock_event_bus
)
self.route_manager = MockRouteManager()
self.context.profile.context.injector.bind_instance(
RouteManager, self.route_manager
)

async def test_handler_no_active_connection(self):
handler, responder = KeylistUpdateResponseHandler(), MockResponder()
Expand All @@ -47,8 +63,86 @@ async def test_handler(self):
with async_mock.patch.object(
MediationManager, "store_update_results"
) as mock_store, async_mock.patch.object(
MediationManager, "notify_keylist_updated"
handler, "notify_keylist_updated"
) as mock_notify:
await handler.handle(self.context, responder)
mock_store.assert_called_once_with(TEST_CONN_ID, self.updated)
mock_notify.assert_called_once_with(TEST_CONN_ID, self.context.message)
mock_notify.assert_called_once_with(
self.context.profile, TEST_CONN_ID, self.context.message
)

async def test_notify_keylist_updated(self):
"""test notify_keylist_updated."""
handler = KeylistUpdateResponseHandler()

async def _result_generator():
yield ConnRecord(connection_id="conn_id_1")
yield ConnRecord(connection_id="conn_id_2")

async def _retrieve_by_invitation_key(
generator: AsyncGenerator, *args, **kwargs
):
return await generator.__anext__()

with async_mock.patch.object(
self.route_manager,
"connection_from_recipient_key",
partial(_retrieve_by_invitation_key, _result_generator()),
):
response = KeylistUpdateResponse(
updated=[
KeylistUpdated(
recipient_key=TEST_ROUTE_VERKEY,
action=KeylistUpdateRule.RULE_ADD,
result=KeylistUpdated.RESULT_SUCCESS,
),
KeylistUpdated(
recipient_key=TEST_VERKEY,
action=KeylistUpdateRule.RULE_REMOVE,
result=KeylistUpdated.RESULT_SUCCESS,
),
],
)

response.assign_thread_id(TEST_THREAD_ID)
await handler.notify_keylist_updated(
self.context.profile, TEST_CONN_ID, response
)
assert self.mock_event_bus.events
assert (
self.mock_event_bus.events[0][1].topic
== MediationManager.KEYLIST_UPDATED_EVENT
)
assert self.mock_event_bus.events[0][1].payload == {
"connection_id": TEST_CONN_ID,
"thread_id": TEST_THREAD_ID,
"updated": [result.serialize() for result in response.updated],
"mediated_connections": {
TEST_ROUTE_VERKEY: "conn_id_1",
TEST_VERKEY: "conn_id_2",
},
}

async def test_notify_keylist_updated_x_unknown_recip_key(self):
"""test notify_keylist_updated."""
handler = KeylistUpdateResponseHandler()
response = KeylistUpdateResponse(
updated=[
KeylistUpdated(
recipient_key=TEST_ROUTE_VERKEY,
action=KeylistUpdateRule.RULE_ADD,
result=KeylistUpdated.RESULT_SUCCESS,
),
KeylistUpdated(
recipient_key=TEST_VERKEY,
action=KeylistUpdateRule.RULE_REMOVE,
result=KeylistUpdated.RESULT_SUCCESS,
),
],
)

response.assign_thread_id(TEST_THREAD_ID)
with pytest.raises(HandlerException):
await handler.notify_keylist_updated(
self.context.profile, TEST_CONN_ID, response
)
16 changes: 2 additions & 14 deletions aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
from typing import Optional, Sequence, Tuple


from ....core.error import BaseError
from ....core.profile import Profile, ProfileSession
from ....storage.base import BaseStorage
Expand Down Expand Up @@ -539,6 +538,8 @@ async def store_update_results(
session: An active profile session
"""
# TODO The stored recipient keys are did:key!

to_save: Sequence[RouteRecord] = []
to_remove: Sequence[RouteRecord] = []

Expand Down Expand Up @@ -600,19 +601,6 @@ async def store_update_results(
for record_for_removal in to_remove:
await record_for_removal.delete_record(session)

async def notify_keylist_updated(
self, connection_id: str, response: KeylistUpdateResponse
):
"""Notify of keylist update response received."""
await self._profile.notify(
self.KEYLIST_UPDATED_EVENT,
{
"connection_id": connection_id,
"thread_id": response._thread_id,
"updated": [update.serialize() for update in response.updated],
},
)

async def get_my_keylist(
self, connection_id: Optional[str] = None
) -> Sequence[RouteRecord]:
Expand Down
Loading

0 comments on commit 4696a26

Please sign in to comment.