From ff91cabb4886fa9dc206c8ad3f877a7257cf8978 Mon Sep 17 00:00:00 2001 From: Shaanjot Gill Date: Tue, 30 Nov 2021 11:59:45 -0800 Subject: [PATCH] fix Signed-off-by: Shaanjot Gill --- .../connections/models/conn_record.py | 19 ++++++ .../models/tests/test_conn_record.py | 23 ++++++++ .../protocols/connections/v1_0/manager.py | 36 ++++++----- .../connections/v1_0/tests/test_manager.py | 59 +++++++++++++++++-- .../protocols/didexchange/v1_0/manager.py | 28 +++++---- .../didexchange/v1_0/tests/test_manager.py | 24 ++++---- 6 files changed, 149 insertions(+), 40 deletions(-) diff --git a/aries_cloudagent/connections/models/conn_record.py b/aries_cloudagent/connections/models/conn_record.py index c456b2d050..4e5981e777 100644 --- a/aries_cloudagent/connections/models/conn_record.py +++ b/aries_cloudagent/connections/models/conn_record.py @@ -322,6 +322,25 @@ async def retrieve_by_invitation_key( return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + @classmethod + async def retrieve_by_invitation_msg_id( + cls, session: ProfileSession, invitation_msg_id: str, their_role: str = None + ) -> "ConnRecord": + """Retrieve a connection record by invitation_msg_id. + + Args: + session: The active profile session + invitation_msg_id: Invitation message identifier + initiator: Filter by the initiator value + """ + post_filter = { + "state": cls.State.INVITATION.rfc160, + "invitation_msg_id": invitation_msg_id, + } + if their_role: + post_filter["their_role"] = cls.Role.get(their_role).rfc160 + return await cls.query(session, post_filter_positive=post_filter) + @classmethod async def retrieve_by_request_id( cls, session: ProfileSession, request_id: str diff --git a/aries_cloudagent/connections/models/tests/test_conn_record.py b/aries_cloudagent/connections/models/tests/test_conn_record.py index 2923065cfa..d11caf9804 100644 --- a/aries_cloudagent/connections/models/tests/test_conn_record.py +++ b/aries_cloudagent/connections/models/tests/test_conn_record.py @@ -178,6 +178,29 @@ async def test_retrieve_by_invitation_key(self): their_role=ConnRecord.Role.REQUESTER.rfc23, ) + async def test_retrieve_by_invitation_msg_id(self): + record = ConnRecord( + my_did=self.test_did, + their_did=self.test_target_did, + their_role=ConnRecord.Role.RESPONDER.rfc160, + state=ConnRecord.State.INVITATION.rfc160, + invitation_msg_id="test123", + ) + await record.save(self.session) + results = await ConnRecord.retrieve_by_invitation_msg_id( + session=self.session, + invitation_msg_id="test123", + their_role=ConnRecord.Role.RESPONDER.rfc160, + ) + assert len(results) == 1 + assert results[0] == record + results = await ConnRecord.retrieve_by_invitation_msg_id( + session=self.session, + invitation_msg_id="test123", + their_role=ConnRecord.Role.REQUESTER.rfc160, + ) + assert len(results) == 0 + async def test_retrieve_by_request_id(self): record = ConnRecord( my_did=self.test_did, diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index 12a44315fb..a2441132ad 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -439,6 +439,7 @@ async def create_request( connection=ConnectionDetail(did=connection.my_did, did_doc=did_doc), image_url=self.profile.settings.get("image_url"), ) + request.assign_thread_id(thid=request._id, pthid=connection.invitation_msg_id) # Update connection state connection.request_id = request._id @@ -601,21 +602,28 @@ async def receive_request( # Add mapping for multitenant relay if multitenant_mgr and wallet_id: await multitenant_mgr.add_key(wallet_id, my_info.verkey) - - connection = ConnRecord( - invitation_key=connection_key, - my_did=my_info.did, - their_role=ConnRecord.Role.RESPONDER.rfc160, - their_did=request.connection.did, - their_label=request.label, - accept=( - ConnRecord.ACCEPT_AUTO - if self.profile.settings.get("debug.auto_accept_requests") - else ConnRecord.ACCEPT_MANUAL - ), - state=ConnRecord.State.REQUEST.rfc160, - connection_protocol=CONN_PROTO, + async with self.profile.session() as session: + conn_records = await ConnRecord.retrieve_by_invitation_msg_id( + session=session, + invitation_msg_id=request._thread.pthid, + their_role=ConnRecord.Role.REQUESTER.rfc160, + ) + if len(conn_records) == 1: + connection = conn_records[0] + else: + connection = ConnRecord() + connection.invitation_key = connection_key + connection.my_did = my_info.did + connection.their_role = ConnRecord.Role.RESPONDER.rfc160 + connection.their_did = request.connection.did + connection.their_label = request.label + connection.accept = ( + ConnRecord.ACCEPT_AUTO + if self.profile.settings.get("debug.auto_accept_requests") + else ConnRecord.ACCEPT_MANUAL ) + connection.state = ConnRecord.State.REQUEST.rfc160 + connection.connection_protocol = CONN_PROTO async with self.profile.session() as session: await connection.save( session, reason="Received connection request from public DID" diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index 8201927cac..33743caa1f 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -677,7 +677,7 @@ async def test_create_request_mediation_not_granted(self): record, mediation_id=mediation_record.mediation_id ) - async def test_receive_request_public_did(self): + async def test_receive_request_public_did_oob_invite(self): async with self.profile.session() as session: mock_request = async_mock.MagicMock() mock_request.connection = async_mock.MagicMock() @@ -706,7 +706,52 @@ async def test_receive_request_public_did(self): ConnRecord, "retrieve_by_id", autospec=True ) as mock_conn_retrieve_by_id, async_mock.patch.object( ConnRecord, "retrieve_request", autospec=True - ): + ), async_mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id: + mock_conn_retrieve_by_invitation_msg_id.return_value = [ConnRecord()] + conn_rec = await self.manager.receive_request(mock_request, receipt) + assert conn_rec + + messages = self.responder.messages + assert len(messages) == 1 + (result, target) = messages[0] + assert type(result) == ConnectionResponse + assert "connection_id" in target + + async def test_receive_request_public_did_conn_invite(self): + async with self.profile.session() as session: + mock_request = async_mock.MagicMock() + mock_request.connection = async_mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = async_mock.MagicMock() + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + await session.wallet.create_local_did( + method=DIDMethod.SOV, + key_type=KeyType.ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": True}) + with async_mock.patch.object( + ConnRecord, "connection_id", autospec=True + ), async_mock.patch.object( + ConnRecord, "save", autospec=True + ) as mock_conn_rec_save, async_mock.patch.object( + ConnRecord, "attach_request", autospec=True + ) as mock_conn_attach_request, async_mock.patch.object( + ConnRecord, "retrieve_by_id", autospec=True + ) as mock_conn_retrieve_by_id, async_mock.patch.object( + ConnRecord, "retrieve_request", autospec=True + ), async_mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id: + mock_conn_retrieve_by_invitation_msg_id.return_value = [] conn_rec = await self.manager.receive_request(mock_request, receipt) assert conn_rec @@ -800,7 +845,10 @@ async def test_receive_request_public_multitenant(self): InMemoryWallet, "create_local_did", autospec=True ) as mock_wallet_create_local_did, async_mock.patch.object( InMemoryWallet, "get_local_did", autospec=True - ) as mock_wallet_get_local_did: + ) as mock_wallet_get_local_did, async_mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id: + mock_conn_retrieve_by_invitation_msg_id.return_value = [ConnRecord()] mock_wallet_create_local_did.return_value = DIDInfo( new_info.did, new_info.verkey, @@ -940,7 +988,10 @@ async def test_receive_request_public_did_no_auto_accept(self): ConnRecord, "retrieve_by_id", autospec=True ) as mock_conn_retrieve_by_id, async_mock.patch.object( ConnRecord, "retrieve_request", autospec=True - ): + ), async_mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id: + mock_conn_retrieve_by_invitation_msg_id.return_value = [ConnRecord()] conn_rec = await self.manager.receive_request(mock_request, receipt) assert conn_rec diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 6754b2d345..0e123b3961 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -375,6 +375,20 @@ async def receive_request( # Determine what key will need to sign the response if recipient_verkey: # peer DID connection_key = recipient_verkey + try: + async with self.profile.session() as session: + conn_rec = await ConnRecord.retrieve_by_invitation_key( + session=session, + invitation_key=connection_key, + their_role=ConnRecord.Role.REQUESTER.rfc23, + ) + except StorageNotFoundError: + if recipient_verkey: + raise DIDXManagerError( + "No explicit invitation found for pairwise connection " + f"in state {ConnRecord.State.INVITATION.rfc23}: " + "a prior connection request may have updated the connection state" + ) else: if not self.profile.settings.get("public_invites"): raise DIDXManagerError( @@ -391,20 +405,14 @@ async def receive_request( raise DIDXManagerError(f"Request DID {recipient_did} is not public") connection_key = my_info.verkey - try: async with self.profile.session() as session: - conn_rec = await ConnRecord.retrieve_by_invitation_key( + conn_records = await ConnRecord.retrieve_by_invitation_msg_id( session=session, - invitation_key=connection_key, + invitation_msg_id=request._thread.pthid, their_role=ConnRecord.Role.REQUESTER.rfc23, ) - except StorageNotFoundError: - if recipient_verkey: - raise DIDXManagerError( - "No explicit invitation found for pairwise connection " - f"in state {ConnRecord.State.INVITATION.rfc23}: " - "a prior connection request may have updated the connection state" - ) + if len(conn_records) == 1: + conn_rec = conn_records[0] if conn_rec: # invitation was explicit connection_key = conn_rec.invitation_key diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index 08344c22dd..9a8049b1d0 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -507,8 +507,8 @@ async def test_receive_request_explicit_public_did(self): mock_conn_rec_cls.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock(save=async_mock.CoroutineMock()) ) - mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( - return_value=mock_conn_record + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=[mock_conn_record]) ) mock_conn_rec_cls.return_value = mock_conn_record @@ -783,8 +783,8 @@ async def test_receive_request_public_did_no_did_doc_attachment(self): save=async_mock.CoroutineMock(), ) mock_conn_rec_cls.return_value = mock_conn_record - mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( - return_value=mock_conn_record + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=[mock_conn_record]) ) mock_did_posture.get = async_mock.MagicMock( @@ -890,8 +890,8 @@ async def test_receive_request_public_did_x_wrong_did(self): save=async_mock.CoroutineMock(), ) mock_conn_rec_cls.return_value = mock_conn_record - mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( - return_value=mock_conn_record + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=[mock_conn_record]) ) mock_did_doc_from_json.return_value = async_mock.MagicMock( did="wrong-did" @@ -950,8 +950,8 @@ async def test_receive_request_public_did_x_did_doc_attach_bad_sig(self): save=async_mock.CoroutineMock(), ) mock_conn_rec_cls.return_value = mock_conn_record - mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( - return_value=mock_conn_record + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=[mock_conn_record]) ) mock_did_posture.get = async_mock.MagicMock( @@ -1068,8 +1068,8 @@ async def test_receive_request_public_did_no_auto_accept(self): save=async_mock.CoroutineMock(), ) mock_conn_rec_cls.return_value = mock_conn_record - mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( - return_value=mock_conn_record + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=[mock_conn_record]) ) mock_did_posture.get = async_mock.MagicMock( @@ -1306,8 +1306,8 @@ async def test_receive_request_implicit_multitenant(self): retrieve_request=async_mock.CoroutineMock(), ) mock_conn_rec_cls.return_value = mock_conn_rec - mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( - side_effect=StorageNotFoundError() + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=[]) ) mock_did_posture.get = async_mock.MagicMock(