diff --git a/aries_cloudagent/config/wallet.py b/aries_cloudagent/config/wallet.py index 85e080b0fa..61e34b2a73 100644 --- a/aries_cloudagent/config/wallet.py +++ b/aries_cloudagent/config/wallet.py @@ -2,7 +2,6 @@ import logging from typing import Tuple -import weakref from ..core.error import ProfileNotFoundError from ..core.profile import Profile, ProfileManager, ProfileSession @@ -137,8 +136,6 @@ async def wallet_config( await txn.commit() - context.injector.bind_instance(Profile, weakref.ref(profile)) - return (profile, public_did_info) diff --git a/aries_cloudagent/multitenant/base.py b/aries_cloudagent/multitenant/base.py index 4c1b8e8ba7..cecd2314dd 100644 --- a/aries_cloudagent/multitenant/base.py +++ b/aries_cloudagent/multitenant/base.py @@ -200,7 +200,7 @@ async def create_wallet( if public_did_info: await profile.inject(RouteManager).route_public_did( - public_did_info.verkey + profile, public_did_info.verkey ) except Exception: await wallet_record.delete_record(session) diff --git a/aries_cloudagent/multitenant/route_manager.py b/aries_cloudagent/multitenant/route_manager.py index e20094bf71..a7d6cf6878 100644 --- a/aries_cloudagent/multitenant/route_manager.py +++ b/aries_cloudagent/multitenant/route_manager.py @@ -22,16 +22,12 @@ class MultitenantRouteManager(RouteManager): """Multitenancy route manager.""" - def __init__(self, root_profile: Profile, sub_profile: Profile, wallet_id: str): + def __init__( + self, + root_profile: Profile, + ): """Initialize multitenant route manager.""" self.root_profile = root_profile - self.wallet_id = wallet_id - super().__init__(sub_profile) - - @property - def sub_profile(self) -> Profile: - """Return reference to sub wallet profile.""" - return self.profile async def get_base_wallet_mediator(self) -> Optional[MediationRecord]: """Get base wallet's default mediator.""" @@ -39,14 +35,16 @@ async def get_base_wallet_mediator(self) -> Optional[MediationRecord]: async def _route_for_key( self, + profile: Profile, recipient_key: str, mediation_record: Optional[MediationRecord] = None, *, skip_if_exists: bool = False, replace_key: Optional[str] = None, ): + wallet_id = profile.settings["wallet.id"] LOGGER.info( - f"Add route record for recipient {recipient_key} to wallet {self.wallet_id}" + f"Add route record for recipient {recipient_key} to wallet {wallet_id}" ) routing_mgr = RoutingManager(self.root_profile) mediation_mgr = MediationManager(self.root_profile) @@ -66,7 +64,7 @@ async def _route_for_key( pass await routing_mgr.create_route_record( - recipient_key=recipient_key, internal_wallet_id=self.wallet_id + recipient_key=recipient_key, internal_wallet_id=wallet_id ) # External mediation @@ -86,7 +84,10 @@ async def _route_for_key( return keylist_updates async def routing_info( - self, my_endpoint: str, mediation_record: Optional[MediationRecord] = None + self, + profile: Profile, + my_endpoint: str, + mediation_record: Optional[MediationRecord] = None, ) -> Tuple[List[str], str]: """Return routing info.""" routing_keys = [] diff --git a/aries_cloudagent/multitenant/tests/test_route_manager.py b/aries_cloudagent/multitenant/tests/test_route_manager.py index 856ce7d839..87bfef7129 100644 --- a/aries_cloudagent/multitenant/tests/test_route_manager.py +++ b/aries_cloudagent/multitenant/tests/test_route_manager.py @@ -47,13 +47,7 @@ def sub_profile(mock_responder: MockResponder, wallet_id: str): @pytest.fixture def route_manager(root_profile: Profile, sub_profile: Profile, wallet_id: str): - yield MultitenantRouteManager(root_profile, sub_profile, wallet_id) - - -def test_sub_profile_access( - route_manager: MultitenantRouteManager, sub_profile: Profile -): - assert route_manager.sub_profile == sub_profile + yield MultitenantRouteManager(root_profile) @pytest.mark.asyncio @@ -61,6 +55,7 @@ async def test_route_for_key_sub_mediator_no_base_mediator( route_manager: MultitenantRouteManager, mock_responder: MockResponder, wallet_id: str, + sub_profile: Profile, ): mediation_record = MediationRecord( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id" @@ -72,6 +67,7 @@ async def test_route_for_key_sub_mediator_no_base_mediator( RoutingManager, "create_route_record", mock.CoroutineMock() ) as mock_create_route_record: keylist_update = await route_manager._route_for_key( + sub_profile, "test-recipient-key", mediation_record, skip_if_exists=False, @@ -94,6 +90,7 @@ async def test_route_for_key_sub_mediator_no_base_mediator( @pytest.mark.asyncio async def test_route_for_key_sub_mediator_and_base_mediator( + sub_profile: Profile, route_manager: MultitenantRouteManager, mock_responder: MockResponder, wallet_id: str, @@ -114,6 +111,7 @@ async def test_route_for_key_sub_mediator_and_base_mediator( RoutingManager, "create_route_record", mock.CoroutineMock() ) as mock_create_route_record: keylist_update = await route_manager._route_for_key( + sub_profile, "test-recipient-key", mediation_record, skip_if_exists=False, @@ -136,6 +134,7 @@ async def test_route_for_key_sub_mediator_and_base_mediator( @pytest.mark.asyncio async def test_route_for_key_base_mediator_no_sub_mediator( + sub_profile: Profile, route_manager: MultitenantRouteManager, mock_responder: MockResponder, wallet_id: str, @@ -153,7 +152,11 @@ async def test_route_for_key_base_mediator_no_sub_mediator( RoutingManager, "create_route_record", mock.CoroutineMock() ) as mock_create_route_record: keylist_update = await route_manager._route_for_key( - "test-recipient-key", None, skip_if_exists=False, replace_key=None + sub_profile, + "test-recipient-key", + None, + skip_if_exists=False, + replace_key=None, ) mock_create_route_record.assert_called_once_with( @@ -172,6 +175,7 @@ async def test_route_for_key_base_mediator_no_sub_mediator( @pytest.mark.asyncio async def test_route_for_key_skip_if_exists_and_exists( + sub_profile: Profile, route_manager: MultitenantRouteManager, mock_responder: MockResponder, ): @@ -182,6 +186,7 @@ async def test_route_for_key_skip_if_exists_and_exists( RouteRecord, "retrieve_by_recipient_key", mock.CoroutineMock() ): keylist_update = await route_manager._route_for_key( + sub_profile, "test-recipient-key", mediation_record, skip_if_exists=True, @@ -193,6 +198,7 @@ async def test_route_for_key_skip_if_exists_and_exists( @pytest.mark.asyncio async def test_route_for_key_skip_if_exists_and_absent( + sub_profile: Profile, route_manager: MultitenantRouteManager, mock_responder: MockResponder, ): @@ -205,6 +211,7 @@ async def test_route_for_key_skip_if_exists_and_absent( mock.CoroutineMock(side_effect=StorageNotFoundError), ): keylist_update = await route_manager._route_for_key( + sub_profile, "test-recipient-key", mediation_record, skip_if_exists=True, @@ -223,6 +230,7 @@ async def test_route_for_key_skip_if_exists_and_absent( @pytest.mark.asyncio async def test_route_for_key_replace_key( + sub_profile: Profile, route_manager: MultitenantRouteManager, mock_responder: MockResponder, ): @@ -230,6 +238,7 @@ async def test_route_for_key_replace_key( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id" ) keylist_update = await route_manager._route_for_key( + sub_profile, "test-recipient-key", mediation_record, skip_if_exists=False, @@ -249,10 +258,12 @@ async def test_route_for_key_replace_key( @pytest.mark.asyncio async def test_route_for_key_no_mediator( + sub_profile: Profile, route_manager: MultitenantRouteManager, ): assert ( await route_manager._route_for_key( + sub_profile, "test-recipient-key", None, skip_if_exists=True, @@ -264,6 +275,7 @@ async def test_route_for_key_no_mediator( @pytest.mark.asyncio async def test_routing_info_with_mediator( + sub_profile: Profile, route_manager: MultitenantRouteManager, ): mediation_record = MediationRecord( @@ -273,7 +285,7 @@ async def test_routing_info_with_mediator( endpoint="http://mediator.example.com", ) keys, endpoint = await route_manager.routing_info( - "http://example.com", mediation_record + sub_profile, "http://example.com", mediation_record ) assert keys == mediation_record.routing_keys assert endpoint == mediation_record.endpoint @@ -281,15 +293,19 @@ async def test_routing_info_with_mediator( @pytest.mark.asyncio async def test_routing_info_no_mediator( + sub_profile: Profile, route_manager: MultitenantRouteManager, ): - keys, endpoint = await route_manager.routing_info("http://example.com", None) + keys, endpoint = await route_manager.routing_info( + sub_profile, "http://example.com", None + ) assert keys == [] assert endpoint == "http://example.com" @pytest.mark.asyncio async def test_routing_info_with_base_mediator( + sub_profile: Profile, route_manager: MultitenantRouteManager, ): base_mediation_record = MediationRecord( @@ -304,13 +320,16 @@ async def test_routing_info_with_base_mediator( "get_base_wallet_mediator", mock.CoroutineMock(return_value=base_mediation_record), ): - keys, endpoint = await route_manager.routing_info("http://example.com", None) + keys, endpoint = await route_manager.routing_info( + sub_profile, "http://example.com", None + ) assert keys == base_mediation_record.routing_keys assert endpoint == base_mediation_record.endpoint @pytest.mark.asyncio async def test_routing_info_with_base_mediator_and_sub_mediator( + sub_profile: Profile, route_manager: MultitenantRouteManager, ): mediation_record = MediationRecord( @@ -332,7 +351,7 @@ async def test_routing_info_with_base_mediator_and_sub_mediator( mock.CoroutineMock(return_value=base_mediation_record), ): keys, endpoint = await route_manager.routing_info( - "http://example.com", mediation_record + sub_profile, "http://example.com", mediation_record ) assert keys == [*base_mediation_record.routing_keys, *mediation_record.routing_keys] assert endpoint == mediation_record.endpoint diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index 4cdf1b8d77..d937b7ecb2 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -122,6 +122,7 @@ async def create_invitation( # Mediation Record can still be None after this operation if no # mediation id passed and no default mediation_record = await self._route_manager.mediation_record_if_id( + self.profile, mediation_id, or_default=True, ) @@ -159,7 +160,7 @@ async def create_invitation( # Add mapping for multitenant relaying. # Mediation of public keys is not supported yet - await self._route_manager.route_public_did(public_did.verkey) + await self._route_manager.route_public_did(self.profile, public_did.verkey) return None, invitation @@ -206,8 +207,11 @@ async def create_invitation( async with self.profile.session() as session: await connection.save(session, reason="Created new invitation") - await self._route_manager.route_invitation(connection, mediation_record) + await self._route_manager.route_invitation( + self.profile, connection, mediation_record + ) routing_keys, my_endpoint = await self._route_manager.routing_info( + self.profile, my_endpoint or cast(str, self.profile.settings.get("default_endpoint")), mediation_record, ) @@ -297,7 +301,7 @@ async def receive_invitation( await connection.attach_invitation(session, invitation) await self._route_manager.save_mediator_for_connection( - connection, mediation_id=mediation_id + self.profile, connection, mediation_id=mediation_id ) if connection.accept == ConnRecord.ACCEPT_AUTO: @@ -335,6 +339,7 @@ async def create_request( """ mediation_record = await self._route_manager.mediation_record_for_connection( + self.profile, connection, mediation_id, or_default=True, @@ -360,7 +365,7 @@ async def create_request( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_invitee( - connection, mediation_record + self.profile, connection, mediation_record ) # Create connection request message @@ -579,7 +584,7 @@ async def create_response( ) mediation_record = await self._route_manager.mediation_record_for_connection( - connection, mediation_id + self.profile, connection, mediation_id ) # Multitenancy setup @@ -613,7 +618,7 @@ async def create_response( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( - connection, mediation_record + self.profile, connection, mediation_record ) # Create connection response message @@ -863,7 +868,7 @@ async def create_static_connection( # Routing mediation_record = await self._route_manager.mediation_record_if_id( - mediation_id, or_default=True + self.profile, mediation_id, or_default=True ) multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) @@ -873,7 +878,9 @@ async def create_static_connection( if multitenant_mgr and wallet_id: base_mediation_record = await multitenant_mgr.get_default_mediator() - await self._route_manager.route_static(connection, mediation_record) + await self._route_manager.route_static( + self.profile, connection, mediation_record + ) # Synthesize their DID doc did_doc = await self.create_did_document( diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index 66d284123d..d7c3836236 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -177,7 +177,7 @@ async def test_create_invitation_public(self): assert connect_record is None assert connect_invite.did.endswith(self.test_did) self.route_manager.route_public_did.assert_called_once_with( - self.test_verkey + self.profile, self.test_verkey ) async def test_create_invitation_public_no_public_invites(self): @@ -356,7 +356,7 @@ async def test_create_invitation_mediation_using_default(self): assert invite.routing_keys == self.test_mediator_routing_keys assert invite.endpoint == self.test_mediator_endpoint self.route_manager.routing_info.assert_awaited_once_with( - self.test_endpoint, mediation_record + self.profile, self.test_endpoint, mediation_record ) async def test_receive_invitation(self): diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py index 6197c88fd8..89fbf92081 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py @@ -32,21 +32,19 @@ class RouteManagerError(Exception): class RouteManager(ABC): """Base Route Manager.""" - def __init__(self, profile: Profile): - """Initialize route manager.""" - self.profile = profile - - async def get_or_create_my_did(self, conn_record: ConnRecord) -> DIDInfo: + async def get_or_create_my_did( + self, profile: Profile, conn_record: ConnRecord + ) -> DIDInfo: """Create or retrieve DID info for a conneciton.""" if not conn_record.my_did: - async with self.profile.session() as session: + async with profile.session() as session: wallet = session.inject(BaseWallet) # Create new DID for connection my_info = await wallet.create_local_did(DIDMethod.SOV, KeyType.ED25519) conn_record.my_did = my_info.did await conn_record.save(session, reason="Connection my did created") else: - async with self.profile.session() as session: + async with profile.session() as session: wallet = session.inject(BaseWallet) my_info = await wallet.get_local_did(conn_record.my_did) @@ -62,12 +60,13 @@ def _validate_mediation_state(self, mediation_record: MediationRecord): async def mediation_record_for_connection( self, + profile: Profile, conn_record: ConnRecord, mediation_id: Optional[str] = None, or_default: bool = False, ): """Return relevant mediator for connection.""" - async with self.profile.session() as session: + async with profile.session() as session: mediation_metadata = await conn_record.metadata_get( session, MediationManager.METADATA_KEY, {} ) @@ -75,13 +74,20 @@ async def mediation_record_for_connection( mediation_metadata.get(MediationManager.METADATA_ID) or mediation_id ) - mediation_record = await self.mediation_record_if_id(mediation_id, or_default) + mediation_record = await self.mediation_record_if_id( + profile, mediation_id, or_default + ) if mediation_record: - await self.save_mediator_for_connection(conn_record, mediation_record) + await self.save_mediator_for_connection( + profile, conn_record, mediation_record + ) return mediation_record async def mediation_record_if_id( - self, mediation_id: Optional[str] = None, or_default: bool = False + self, + profile: Profile, + mediation_id: Optional[str] = None, + or_default: bool = False, ): """Validate mediation and return record. @@ -91,14 +97,12 @@ async def mediation_record_if_id( """ mediation_record = None if mediation_id: - async with self.profile.session() as session: + async with profile.session() as session: mediation_record = await MediationRecord.retrieve_by_id( session, mediation_id ) elif or_default: - mediation_record = await MediationManager( - self.profile - ).get_default_mediator() + mediation_record = await MediationManager(profile).get_default_mediator() if mediation_record: self._validate_mediation_state(mediation_record) @@ -107,6 +111,7 @@ async def mediation_record_if_id( @abstractmethod async def _route_for_key( self, + profile: Profile, recipient_key: str, mediation_record: Optional[MediationRecord] = None, *, @@ -117,25 +122,28 @@ async def _route_for_key( async def route_connection_as_invitee( self, + profile: Profile, conn_record: ConnRecord, mediation_record: Optional[MediationRecord] = None, ) -> Optional[KeylistUpdate]: """Set up routing for a new connection when we are the invitee.""" LOGGER.debug("Routing connection as invitee") - my_info = await self.get_or_create_my_did(conn_record) + my_info = await self.get_or_create_my_did(profile, conn_record) return await self._route_for_key( - my_info.verkey, mediation_record, skip_if_exists=True + profile, my_info.verkey, mediation_record, skip_if_exists=True ) async def route_connection_as_inviter( self, + profile: Profile, conn_record: ConnRecord, mediation_record: Optional[MediationRecord] = None, ) -> Optional[KeylistUpdate]: """Set up routing for a new connection when we are the inviter.""" LOGGER.debug("Routing connection as inviter") - my_info = await self.get_or_create_my_did(conn_record) + my_info = await self.get_or_create_my_did(profile, conn_record) return await self._route_for_key( + profile, my_info.verkey, mediation_record, replace_key=conn_record.invitation_key, @@ -144,6 +152,7 @@ async def route_connection_as_inviter( async def route_connection( self, + profile: Profile, conn_record: ConnRecord, mediation_record: Optional[MediationRecord] = None, ) -> Optional[KeylistUpdate]: @@ -154,53 +163,63 @@ async def route_connection( if conn_record.rfc23_state == ConnRecord.State.INVITATION.rfc23strict( ConnRecord.Role.RESPONDER ): - return await self.route_connection_as_invitee(conn_record, mediation_record) + return await self.route_connection_as_invitee( + profile, conn_record, mediation_record + ) if conn_record.rfc23_state == ConnRecord.State.REQUEST.rfc23strict( ConnRecord.Role.REQUESTER ): - return await self.route_connection_as_inviter(conn_record, mediation_record) + return await self.route_connection_as_inviter( + profile, conn_record, mediation_record + ) return None async def route_invitation( self, + profile: Profile, conn_record: ConnRecord, mediation_record: Optional[MediationRecord] = None, ) -> Optional[KeylistUpdate]: """Set up routing for receiving a response to an invitation.""" - await self.save_mediator_for_connection(conn_record, mediation_record) + await self.save_mediator_for_connection(profile, conn_record, mediation_record) if conn_record.invitation_key: return await self._route_for_key( - conn_record.invitation_key, mediation_record, skip_if_exists=True + profile, + conn_record.invitation_key, + mediation_record, + skip_if_exists=True, ) raise ValueError("Expected connection to have invitation_key") - async def route_public_did(self, verkey: str): + async def route_public_did(self, profile: Profile, verkey: str): """Establish routing for a public DID.""" - return await self._route_for_key(verkey, skip_if_exists=True) + return await self._route_for_key(profile, verkey, skip_if_exists=True) async def route_static( self, + profile: Profile, conn_record: ConnRecord, mediation_record: Optional[MediationRecord] = None, ) -> Optional[KeylistUpdate]: """Establish routing for a static connection.""" - my_info = await self.get_or_create_my_did(conn_record) + my_info = await self.get_or_create_my_did(profile, conn_record) return await self._route_for_key( - my_info.verkey, mediation_record, skip_if_exists=True + profile, my_info.verkey, mediation_record, skip_if_exists=True ) async def save_mediator_for_connection( self, + profile: Profile, conn_record: ConnRecord, mediation_record: Optional[MediationRecord] = None, mediation_id: Optional[str] = None, ): """Save mediator info to connection metadata.""" - async with self.profile.session() as session: + async with profile.session() as session: if mediation_id: mediation_record = await MediationRecord.retrieve_by_id( session, mediation_id @@ -216,6 +235,7 @@ async def save_mediator_for_connection( @abstractmethod async def routing_info( self, + profile: Profile, my_endpoint: str, mediation_record: Optional[MediationRecord] = None, ) -> Tuple[List[str], str]: @@ -227,6 +247,7 @@ class CoordinateMediationV1RouteManager(RouteManager): async def _route_for_key( self, + profile: Profile, recipient_key: str, mediation_record: Optional[MediationRecord] = None, *, @@ -238,7 +259,7 @@ async def _route_for_key( if skip_if_exists: try: - async with self.profile.session() as session: + async with profile.session() as session: await RouteRecord.retrieve_by_recipient_key(session, recipient_key) return None @@ -246,19 +267,22 @@ async def _route_for_key( pass # Keylist update is idempotent, skip_if_exists ignored - mediation_mgr = MediationManager(self.profile) + mediation_mgr = MediationManager(profile) keylist_update = await mediation_mgr.add_key(recipient_key) if replace_key: keylist_update = await mediation_mgr.remove_key(replace_key, keylist_update) - responder = self.profile.inject(BaseResponder) + responder = profile.inject(BaseResponder) await responder.send( keylist_update, connection_id=mediation_record.connection_id ) return keylist_update async def routing_info( - self, my_endpoint: str, mediation_record: Optional[MediationRecord] = None + self, + profile: Profile, + my_endpoint: str, + mediation_record: Optional[MediationRecord] = None, ) -> Tuple[List[str], str]: """Return routing info for mediator.""" if mediation_record: diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager_provider.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager_provider.py index d053094a52..693766c922 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager_provider.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager_provider.py @@ -20,8 +20,7 @@ def provide(self, settings: BaseSettings, injector: BaseInjector): """Create the appropriate route manager instance.""" wallet_id = settings.get("wallet.id") multitenant_mgr = injector.inject_or(BaseMultitenantManager) - profile = injector.inject(Profile) if multitenant_mgr and wallet_id: - return MultitenantRouteManager(self.root_profile, profile, wallet_id) + return MultitenantRouteManager(self.root_profile) - return CoordinateMediationV1RouteManager(profile) + return CoordinateMediationV1RouteManager() diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py index 61b72bb975..dcb315337b 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py @@ -37,8 +37,8 @@ def profile(mock_responder: MockResponder): @pytest.fixture -def route_manager(profile: Profile): - manager = MockRouteManager(profile) +def route_manager(): + manager = MockRouteManager() manager._route_for_key = mock.CoroutineMock( return_value=mock.MagicMock(KeylistUpdate) ) @@ -47,8 +47,8 @@ def route_manager(profile: Profile): @pytest.fixture -def mediation_route_manager(profile: Profile): - yield CoordinateMediationV1RouteManager(profile) +def mediation_route_manager(): + yield CoordinateMediationV1RouteManager() @pytest.fixture @@ -61,7 +61,7 @@ def conn_record(): @pytest.mark.asyncio async def test_get_or_create_my_did_no_did( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): conn_record.my_did = None mock_did_info = mock.MagicMock() @@ -72,7 +72,7 @@ async def test_get_or_create_my_did_no_did( ) as mock_create_local_did, mock.patch.object( conn_record, "save", mock.CoroutineMock() ) as mock_save: - info = await route_manager.get_or_create_my_did(conn_record) + info = await route_manager.get_or_create_my_did(profile, conn_record) assert mock_did_info == info mock_create_local_did.assert_called_once() mock_save.assert_called_once() @@ -80,21 +80,21 @@ async def test_get_or_create_my_did_no_did( @pytest.mark.asyncio async def test_get_or_create_my_did_existing_did( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): conn_record.my_did = "test-did" mock_did_info = mock.MagicMock(DIDInfo) with mock.patch.object( InMemoryWallet, "get_local_did", mock.CoroutineMock(return_value=mock_did_info) ) as mock_get_local_did: - info = await route_manager.get_or_create_my_did(conn_record) + info = await route_manager.get_or_create_my_did(profile, conn_record) assert mock_did_info == info mock_get_local_did.assert_called_once() @pytest.mark.asyncio async def test_mediation_record_for_connection_mediation_id( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") with mock.patch.object( @@ -106,18 +106,18 @@ async def test_mediation_record_for_connection_mediation_id( ): assert ( await route_manager.mediation_record_for_connection( - conn_record, mediation_record.mediation_id + profile, conn_record, mediation_record.mediation_id ) == mediation_record ) mock_mediation_record_if_id.assert_called_once_with( - mediation_record.mediation_id, False + profile, mediation_record.mediation_id, False ) @pytest.mark.asyncio async def test_mediation_record_for_connection_mediation_metadata( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") conn_record.metadata_get.return_value = { @@ -132,18 +132,18 @@ async def test_mediation_record_for_connection_mediation_metadata( ): assert ( await route_manager.mediation_record_for_connection( - conn_record, "another-mediation-id" + profile, conn_record, "another-mediation-id" ) == mediation_record ) mock_mediation_record_if_id.assert_called_once_with( - mediation_record.mediation_id, False + profile, mediation_record.mediation_id, False ) @pytest.mark.asyncio async def test_mediation_record_for_connection_default( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") with mock.patch.object( @@ -155,15 +155,17 @@ async def test_mediation_record_for_connection_default( ): assert ( await route_manager.mediation_record_for_connection( - conn_record, None, or_default=True + profile, conn_record, None, or_default=True ) == mediation_record ) - mock_mediation_record_if_id.assert_called_once_with(None, True) + mock_mediation_record_if_id.assert_called_once_with(profile, None, True) @pytest.mark.asyncio -async def test_mediation_record_if_id_with_id(route_manager: RouteManager): +async def test_mediation_record_if_id_with_id( + profile: Profile, route_manager: RouteManager +): mediation_record = MediationRecord( mediation_id="test-mediation-id", state=MediationRecord.STATE_GRANTED ) @@ -173,14 +175,16 @@ async def test_mediation_record_if_id_with_id(route_manager: RouteManager): mock.CoroutineMock(return_value=mediation_record), ) as mock_retrieve_by_id: actual = await route_manager.mediation_record_if_id( - mediation_id=mediation_record.mediation_id + profile, mediation_id=mediation_record.mediation_id ) assert mediation_record == actual mock_retrieve_by_id.assert_called_once() @pytest.mark.asyncio -async def test_mediation_record_if_id_with_id_bad_state(route_manager: RouteManager): +async def test_mediation_record_if_id_with_id_bad_state( + profile: Profile, route_manager: RouteManager +): mediation_record = MediationRecord( mediation_id="test-mediation-id", state=MediationRecord.STATE_DENIED ) @@ -191,12 +195,14 @@ async def test_mediation_record_if_id_with_id_bad_state(route_manager: RouteMana ): with pytest.raises(RouteManagerError): await route_manager.mediation_record_if_id( - mediation_id=mediation_record.mediation_id + profile, mediation_id=mediation_record.mediation_id ) @pytest.mark.asyncio -async def test_mediation_record_if_id_with_id_and_default(route_manager: RouteManager): +async def test_mediation_record_if_id_with_id_and_default( + profile: Profile, route_manager: RouteManager +): mediation_record = MediationRecord( mediation_id="test-mediation-id", state=MediationRecord.STATE_GRANTED ) @@ -208,7 +214,7 @@ async def test_mediation_record_if_id_with_id_and_default(route_manager: RouteMa MediationManager, "get_default_mediator", mock.CoroutineMock() ) as mock_get_default_mediator: actual = await route_manager.mediation_record_if_id( - mediation_id=mediation_record.mediation_id, or_default=True + profile, mediation_id=mediation_record.mediation_id, or_default=True ) assert mediation_record == actual mock_retrieve_by_id.assert_called_once() @@ -217,6 +223,7 @@ async def test_mediation_record_if_id_with_id_and_default(route_manager: RouteMa @pytest.mark.asyncio async def test_mediation_record_if_id_without_id_and_default( + profile: Profile, route_manager: RouteManager, ): mediation_record = MediationRecord( @@ -230,7 +237,7 @@ async def test_mediation_record_if_id_without_id_and_default( mock.CoroutineMock(return_value=mediation_record), ) as mock_get_default_mediator: actual = await route_manager.mediation_record_if_id( - mediation_id=None, or_default=True + profile, mediation_id=None, or_default=True ) assert mediation_record == actual mock_retrieve_by_id.assert_not_called() @@ -239,6 +246,7 @@ async def test_mediation_record_if_id_without_id_and_default( @pytest.mark.asyncio async def test_mediation_record_if_id_without_id_and_no_default( + profile: Profile, route_manager: RouteManager, ): with mock.patch.object( @@ -248,7 +256,7 @@ async def test_mediation_record_if_id_without_id_and_no_default( ) as mock_get_default_mediator: assert ( await route_manager.mediation_record_if_id( - mediation_id=None, or_default=True + profile, mediation_id=None, or_default=True ) is None ) @@ -258,7 +266,7 @@ async def test_mediation_record_if_id_without_id_and_no_default( @pytest.mark.asyncio async def test_route_connection_as_invitee( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") mock_did_info = mock.MagicMock(DIDInfo) @@ -267,15 +275,17 @@ async def test_route_connection_as_invitee( "get_or_create_my_did", mock.CoroutineMock(return_value=mock_did_info), ): - await route_manager.route_connection_as_invitee(conn_record, mediation_record) + await route_manager.route_connection_as_invitee( + profile, conn_record, mediation_record + ) route_manager._route_for_key.assert_called_once_with( - mock_did_info.verkey, mediation_record, skip_if_exists=True + profile, mock_did_info.verkey, mediation_record, skip_if_exists=True ) @pytest.mark.asyncio async def test_route_connection_as_inviter( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") mock_did_info = mock.MagicMock(DIDInfo) @@ -285,8 +295,11 @@ async def test_route_connection_as_inviter( "get_or_create_my_did", mock.CoroutineMock(return_value=mock_did_info), ): - await route_manager.route_connection_as_inviter(conn_record, mediation_record) + await route_manager.route_connection_as_inviter( + profile, conn_record, mediation_record + ) route_manager._route_for_key.assert_called_once_with( + profile, mock_did_info.verkey, mediation_record, replace_key="test-invitation-key", @@ -296,7 +309,7 @@ async def test_route_connection_as_inviter( @pytest.mark.asyncio async def test_route_connection_state_invitee( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") conn_record.state = "invitation" @@ -306,14 +319,14 @@ async def test_route_connection_state_invitee( ) as mock_route_connection_as_invitee, mock.patch.object( route_manager, "route_connection_as_inviter", mock.CoroutineMock() ) as mock_route_connection_as_inviter: - await route_manager.route_connection(conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, mediation_record) mock_route_connection_as_invitee.assert_called_once() mock_route_connection_as_inviter.assert_not_called() @pytest.mark.asyncio async def test_route_connection_state_inviter( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") conn_record.state = "request" @@ -323,57 +336,62 @@ async def test_route_connection_state_inviter( ) as mock_route_connection_as_invitee, mock.patch.object( route_manager, "route_connection_as_inviter", mock.CoroutineMock() ) as mock_route_connection_as_inviter: - await route_manager.route_connection(conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, mediation_record) mock_route_connection_as_inviter.assert_called_once() mock_route_connection_as_invitee.assert_not_called() @pytest.mark.asyncio async def test_route_connection_state_other( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") conn_record.state = "response" conn_record.their_role = "requester" - assert await route_manager.route_connection(conn_record, mediation_record) is None + assert ( + await route_manager.route_connection(profile, conn_record, mediation_record) + is None + ) @pytest.mark.asyncio async def test_route_invitation_with_key( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") conn_record.invitation_key = "test-invitation-key" with mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): - await route_manager.route_invitation(conn_record, mediation_record) + await route_manager.route_invitation(profile, conn_record, mediation_record) route_manager._route_for_key.assert_called_once() @pytest.mark.asyncio async def test_route_invitation_without_key( - route_manager: RouteManager, conn_record: ConnRecord + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord ): mediation_record = MediationRecord(mediation_id="test-mediation-id") with mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): with pytest.raises(ValueError): - await route_manager.route_invitation(conn_record, mediation_record) + await route_manager.route_invitation(profile, conn_record, mediation_record) route_manager._route_for_key.assert_not_called() @pytest.mark.asyncio -async def test_route_public_did(route_manager: RouteManager): - await route_manager.route_public_did("test-verkey") +async def test_route_public_did(profile: Profile, route_manager: RouteManager): + await route_manager.route_public_did(profile, "test-verkey") route_manager._route_for_key.assert_called_once_with( - "test-verkey", skip_if_exists=True + profile, "test-verkey", skip_if_exists=True ) @pytest.mark.asyncio -async def test_route_static(route_manager: RouteManager, conn_record: ConnRecord): +async def test_route_static( + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord +): mediation_record = MediationRecord(mediation_id="test-mediation-id") mock_did_info = mock.MagicMock(DIDInfo) conn_record.invitation_key = "test-invitation-key" @@ -382,8 +400,9 @@ async def test_route_static(route_manager: RouteManager, conn_record: ConnRecord "get_or_create_my_did", mock.CoroutineMock(return_value=mock_did_info), ): - await route_manager.route_static(conn_record, mediation_record) + await route_manager.route_static(profile, conn_record, mediation_record) route_manager._route_for_key.assert_called_once_with( + profile, mock_did_info.verkey, mediation_record, skip_if_exists=True, @@ -392,7 +411,9 @@ async def test_route_static(route_manager: RouteManager, conn_record: ConnRecord @pytest.mark.asyncio async def test_save_mediator_for_connection_record( - route_manager: RouteManager, conn_record: ConnRecord, profile: Profile + profile: Profile, + route_manager: RouteManager, + conn_record: ConnRecord, ): mediation_record = MediationRecord(mediation_id="test-mediation-id") session = mock.MagicMock() @@ -402,7 +423,9 @@ async def test_save_mediator_for_connection_record( with mock.patch.object( MediationRecord, "retrieve_by_id", mock.CoroutineMock() ) as mock_retrieve_by_id: - await route_manager.save_mediator_for_connection(conn_record, mediation_record) + await route_manager.save_mediator_for_connection( + profile, conn_record, mediation_record + ) mock_retrieve_by_id.assert_not_called() conn_record.metadata_set.assert_called_once_with( session, @@ -413,7 +436,9 @@ async def test_save_mediator_for_connection_record( @pytest.mark.asyncio async def test_save_mediator_for_connection_id( - route_manager: RouteManager, conn_record: ConnRecord, profile: Profile + profile: Profile, + route_manager: RouteManager, + conn_record: ConnRecord, ): mediation_record = MediationRecord(mediation_id="test-mediation-id") session = mock.MagicMock() @@ -426,7 +451,7 @@ async def test_save_mediator_for_connection_id( mock.CoroutineMock(return_value=mediation_record), ) as mock_retrieve_by_id: await route_manager.save_mediator_for_connection( - conn_record, mediation_id=mediation_record.mediation_id + profile, conn_record, mediation_id=mediation_record.mediation_id ) mock_retrieve_by_id.assert_called_once() conn_record.metadata_set.assert_called_once_with( @@ -438,18 +463,21 @@ async def test_save_mediator_for_connection_id( @pytest.mark.asyncio async def test_save_mediator_for_connection_no_mediator( - route_manager: RouteManager, conn_record: ConnRecord, profile: Profile + profile: Profile, + route_manager: RouteManager, + conn_record: ConnRecord, ): with mock.patch.object( MediationRecord, "retrieve_by_id", mock.CoroutineMock() ) as mock_retrieve_by_id: - await route_manager.save_mediator_for_connection(conn_record) + await route_manager.save_mediator_for_connection(profile, conn_record) mock_retrieve_by_id.assert_not_called() conn_record.metadata_set.assert_not_called() @pytest.mark.asyncio async def test_mediation_route_for_key( + profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, mock_responder: MockResponder, ): @@ -457,7 +485,11 @@ async def test_mediation_route_for_key( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id" ) keylist_update = await mediation_route_manager._route_for_key( - "test-recipient-key", mediation_record, skip_if_exists=False, replace_key=None + profile, + "test-recipient-key", + mediation_record, + skip_if_exists=False, + replace_key=None, ) assert keylist_update assert keylist_update.serialize()["updates"] == [ @@ -472,6 +504,7 @@ async def test_mediation_route_for_key( @pytest.mark.asyncio async def test_mediation_route_for_key_skip_if_exists_and_exists( + profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, mock_responder: MockResponder, ): @@ -482,6 +515,7 @@ async def test_mediation_route_for_key_skip_if_exists_and_exists( RouteRecord, "retrieve_by_recipient_key", mock.CoroutineMock() ): keylist_update = await mediation_route_manager._route_for_key( + profile, "test-recipient-key", mediation_record, skip_if_exists=True, @@ -493,6 +527,7 @@ async def test_mediation_route_for_key_skip_if_exists_and_exists( @pytest.mark.asyncio async def test_mediation_route_for_key_skip_if_exists_and_absent( + profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, mock_responder: MockResponder, ): @@ -505,6 +540,7 @@ async def test_mediation_route_for_key_skip_if_exists_and_absent( mock.CoroutineMock(side_effect=StorageNotFoundError), ): keylist_update = await mediation_route_manager._route_for_key( + profile, "test-recipient-key", mediation_record, skip_if_exists=True, @@ -523,6 +559,7 @@ async def test_mediation_route_for_key_skip_if_exists_and_absent( @pytest.mark.asyncio async def test_mediation_route_for_key_replace_key( + profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, mock_responder: MockResponder, ): @@ -530,6 +567,7 @@ async def test_mediation_route_for_key_replace_key( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id" ) keylist_update = await mediation_route_manager._route_for_key( + profile, "test-recipient-key", mediation_record, skip_if_exists=False, @@ -549,10 +587,12 @@ async def test_mediation_route_for_key_replace_key( @pytest.mark.asyncio async def test_mediation_route_for_key_no_mediator( + profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, ): assert ( await mediation_route_manager._route_for_key( + profile, "test-recipient-key", None, skip_if_exists=True, @@ -564,6 +604,7 @@ async def test_mediation_route_for_key_no_mediator( @pytest.mark.asyncio async def test_mediation_routing_info_with_mediator( + profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, ): mediation_record = MediationRecord( @@ -573,7 +614,7 @@ async def test_mediation_routing_info_with_mediator( endpoint="http://mediator.example.com", ) keys, endpoint = await mediation_route_manager.routing_info( - "http://example.com", mediation_record + profile, "http://example.com", mediation_record ) assert keys == mediation_record.routing_keys assert endpoint == mediation_record.endpoint @@ -581,10 +622,11 @@ async def test_mediation_routing_info_with_mediator( @pytest.mark.asyncio async def test_mediation_routing_info_no_mediator( + profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, ): keys, endpoint = await mediation_route_manager.routing_info( - "http://example.com", None + profile, "http://example.com", None ) assert keys == [] assert endpoint == "http://example.com" diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 5ad6c83642..ec80923e9a 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -155,7 +155,7 @@ async def receive_invitation( ].public_key_base58 await self._route_manager.save_mediator_for_connection( - conn_rec, mediation_id=mediation_id + self.profile, conn_rec, mediation_id=mediation_id ) if conn_rec.accept == ConnRecord.ACCEPT_AUTO: @@ -259,6 +259,7 @@ async def create_request( """ # Mediation Support mediation_record = await self._route_manager.mediation_record_for_connection( + self.profile, conn_rec, mediation_id, or_default=True, @@ -290,7 +291,7 @@ async def create_request( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_invitee( - conn_rec, mediation_record + self.profile, conn_rec, mediation_record ) # Create connection request message @@ -550,7 +551,7 @@ async def create_response( ) mediation_record = await self._route_manager.mediation_record_for_connection( - conn_rec, mediation_id + self.profile, conn_rec, mediation_id ) # Multitenancy setup @@ -583,7 +584,7 @@ async def create_response( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( - conn_rec, mediation_record + self.profile, conn_rec, mediation_record ) # Create connection response message diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index e3f4548d47..7b02b328f7 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -110,6 +110,7 @@ async def create_invitation( """ mediation_record = await self._route_manager.mediation_record_if_id( + self.profile, mediation_id, or_default=True, ) @@ -296,7 +297,7 @@ async def create_invitation( await conn_rec.save(session, reason="Created new connection") routing_keys, my_endpoint = await self._route_manager.routing_info( - my_endpoint, mediation_record + self.profile, my_endpoint, mediation_record ) if not conn_rec: @@ -358,7 +359,9 @@ async def create_invitation( async with self.profile.session() as session: await oob_record.save(session, reason="Created new oob invitation") - await self._route_manager.route_invitation(conn_rec, mediation_record) + await self._route_manager.route_invitation( + self.profile, conn_rec, mediation_record + ) return InvitationRecord( # for return via admin API, not storage oob_id=oob_record.oob_id, @@ -392,7 +395,9 @@ async def receive_invitation( """ if mediation_id: try: - await self._route_manager.mediation_record_if_id(mediation_id) + await self._route_manager.mediation_record_if_id( + self.profile, mediation_id + ) except StorageNotFoundError: mediation_id = None diff --git a/aries_cloudagent/wallet/routes.py b/aries_cloudagent/wallet/routes.py index ff8da41b7a..fd26fe8c45 100644 --- a/aries_cloudagent/wallet/routes.py +++ b/aries_cloudagent/wallet/routes.py @@ -570,7 +570,7 @@ async def promote_wallet_public_did( # Route the public DID route_manager = profile.inject(RouteManager) - await route_manager.route_public_did(info.verkey) + await route_manager.route_public_did(profile, info.verkey) return info, attrib_def