diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py index 645ca45adf..990252cd5a 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py @@ -144,11 +144,21 @@ async def route_connection_as_inviter( """Set up routing for a new connection when we are the inviter.""" LOGGER.debug("Routing connection as inviter") my_info = await self.get_or_create_my_did(profile, conn_record) + + replace_key = conn_record.invitation_key + async with profile.session() as session: + wallet = session.inject(BaseWallet) + public_did = await wallet.get_public_did() + + # Do not replace key, if it is public + if public_did and public_did.verkey == conn_record.invitation_key: + replace_key = None + return await self._route_for_key( profile, my_info.verkey, mediation_record, - replace_key=conn_record.invitation_key, + replace_key=replace_key, skip_if_exists=True, ) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py index f64ecd0cfe..4d9efceb72 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py @@ -8,7 +8,9 @@ from .....messaging.responder import BaseResponder, MockResponder from .....storage.error import StorageNotFoundError from .....wallet.did_info import DIDInfo +from .....wallet.did_method import SOV from .....wallet.in_memory import InMemoryWallet +from .....wallet.key_type import ED25519 from ....routing.v1_0.models.route_record import RouteRecord from ..manager import MediationManager from ..messages.keylist_update import KeylistUpdate @@ -19,6 +21,7 @@ RouteManagerError, ) +TEST_RECORD_DID = "55GkHamhTU1ZbTbV2ab9DE" TEST_RECORD_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ROUTE_RECORD_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" @@ -313,6 +316,43 @@ async def test_route_connection_as_inviter( ) +@pytest.mark.asyncio +async def test_route_connection_state_inviter_replace_key_none( + profile: Profile, route_manager: RouteManager, conn_record: ConnRecord +): + mediation_record = MediationRecord(mediation_id="test-mediation-id") + mock_did_info = mock.MagicMock(DIDInfo) + conn_record.invitation_key = TEST_RECORD_VERKEY + + with mock.patch.object( + route_manager, + "get_or_create_my_did", + mock.CoroutineMock(return_value=mock_did_info), + ), mock.patch.object( + InMemoryWallet, + "get_public_did", + mock.CoroutineMock( + return_value=DIDInfo( + TEST_RECORD_DID, + TEST_RECORD_VERKEY, + None, + method=SOV, + key_type=ED25519, + ) + ), + ): + await route_manager.route_connection_as_inviter( + profile, conn_record, mediation_record + ) + route_manager._route_for_key.assert_called_once_with( + profile, + mock_did_info.verkey, + mediation_record, + replace_key=None, + skip_if_exists=True, + ) + + @pytest.mark.asyncio async def test_route_connection_state_invitee( profile: Profile, route_manager: RouteManager, conn_record: ConnRecord