diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index dc59d54c49..200f6c1c00 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -46,7 +46,7 @@ from ..resolver.base import ResolverError from ..resolver.did_resolver import DIDResolver from ..storage.base import BaseStorage -from ..storage.error import StorageDuplicateError, StorageError, StorageNotFoundError +from ..storage.error import StorageDuplicateError, StorageNotFoundError from ..storage.record import StorageRecord from ..transport.inbound.receipt import MessageReceipt from ..utils.multiformats import multibase, multicodec @@ -854,9 +854,9 @@ async def fetch_did_document(self, did: str) -> Tuple[dict, StorageRecord]: async def find_connection( self, - their_did: str, + their_did: Optional[str], my_did: Optional[str] = None, - my_verkey: Optional[str] = None, + parent_thread_id: Optional[str] = None, auto_complete=False, ) -> Optional[ConnRecord]: """Look up existing connection information for a sender verkey. @@ -864,7 +864,7 @@ async def find_connection( Args: their_did: Their DID my_did: My DID - my_verkey: My verkey + parent_thread_id: Parent thread ID auto_complete: Should this connection automatically be promoted to active Returns: @@ -895,16 +895,13 @@ async def find_connection( connection_id=connection.connection_id ) - if not connection and my_verkey: - try: - async with self._profile.session() as session: - connection = await ConnRecord.retrieve_by_invitation_key( - session, - my_verkey, - their_role=ConnRecord.Role.REQUESTER.rfc160, - ) - except StorageError: - pass + if not connection and parent_thread_id: + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_invitation_msg_id( + session, + parent_thread_id, + their_role=ConnRecord.Role.REQUESTER.rfc160, + ) return connection @@ -1001,7 +998,7 @@ async def resolve_inbound_connection( ) return await self.find_connection( - receipt.sender_did, receipt.recipient_did, receipt.recipient_verkey, True + receipt.sender_did, receipt.recipient_did, receipt.parent_thread_id, True ) async def get_endpoints(self, conn_id: str) -> Tuple[Optional[str], Optional[str]]: diff --git a/aries_cloudagent/connections/tests/test_base_manager.py b/aries_cloudagent/connections/tests/test_base_manager.py index 38c5477c27..21f5615b69 100644 --- a/aries_cloudagent/connections/tests/test_base_manager.py +++ b/aries_cloudagent/connections/tests/test_base_manager.py @@ -79,6 +79,8 @@ async def asyncSetUp(self): self.test_target_did = "GbuDUYXaUZRfHD2jeDuQuP" self.test_target_verkey = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" + self.test_pthid = "test-pthid" + self.responder = MockResponder() self.oob_mock = mock.MagicMock( @@ -1645,7 +1647,7 @@ async def test_find_connection_retrieve_by_did(self): conn_rec = await self.manager.find_connection( their_did=self.test_target_did, my_did=self.test_did, - my_verkey=self.test_verkey, + parent_thread_id=self.test_pthid, auto_complete=True, ) assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED @@ -1665,7 +1667,7 @@ async def test_find_connection_retrieve_by_did_auto_disclose_features(self): conn_rec = await self.manager.find_connection( their_did=self.test_target_did, my_did=self.test_did, - my_verkey=self.test_verkey, + parent_thread_id=self.test_pthid, auto_complete=True, ) assert ConnRecord.State.get(conn_rec.state) is ConnRecord.State.COMPLETED @@ -1675,10 +1677,10 @@ async def test_find_connection_retrieve_by_invitation_key(self): with mock.patch.object( ConnRecord, "retrieve_by_did", mock.CoroutineMock() ) as mock_conn_retrieve_by_did, mock.patch.object( - ConnRecord, "retrieve_by_invitation_key", mock.CoroutineMock() - ) as mock_conn_retrieve_by_invitation_key: + ConnRecord, "retrieve_by_invitation_msg_id", mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id: mock_conn_retrieve_by_did.side_effect = StorageNotFoundError() - mock_conn_retrieve_by_invitation_key.return_value = mock.MagicMock( + mock_conn_retrieve_by_invitation_msg_id.return_value = mock.MagicMock( state=ConnRecord.State.RESPONSE, save=mock.CoroutineMock(), ) @@ -1686,7 +1688,7 @@ async def test_find_connection_retrieve_by_invitation_key(self): conn_rec = await self.manager.find_connection( their_did=self.test_target_did, my_did=self.test_did, - my_verkey=self.test_verkey, + parent_thread_id=self.test_pthid, ) assert conn_rec @@ -1695,14 +1697,14 @@ async def test_find_connection_retrieve_none_by_invitation_key(self): ConnRecord, "retrieve_by_did", mock.CoroutineMock() ) as mock_conn_retrieve_by_did, mock.patch.object( ConnRecord, "retrieve_by_invitation_key", mock.CoroutineMock() - ) as mock_conn_retrieve_by_invitation_key: + ) as mock_conn_retrieve_by_invitation_msg_id: mock_conn_retrieve_by_did.side_effect = StorageNotFoundError() - mock_conn_retrieve_by_invitation_key.side_effect = StorageNotFoundError() + mock_conn_retrieve_by_invitation_msg_id.return_value = None conn_rec = await self.manager.find_connection( their_did=self.test_target_did, my_did=self.test_did, - my_verkey=self.test_verkey, + parent_thread_id=self.test_pthid, ) assert conn_rec is None diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 0b12931e12..6ae8fd5da1 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -522,9 +522,7 @@ async def receive_request( ) if recipient_verkey: - conn_rec = await self._receive_request_pairwise_did( - request, recipient_verkey, alias - ) + conn_rec = await self._receive_request_pairwise_did(request, alias) else: conn_rec = await self._receive_request_public_did( request, recipient_did, alias, auto_accept_implicit @@ -539,22 +537,23 @@ async def receive_request( async def _receive_request_pairwise_did( self, request: DIDXRequest, - recipient_verkey: str, alias: Optional[str] = None, ) -> ConnRecord: """Receive a DID Exchange request against a pairwise (not public) DID.""" - try: - async with self.profile.session() as session: - conn_rec = await ConnRecord.retrieve_by_invitation_key( - session=session, - invitation_key=recipient_verkey, - their_role=ConnRecord.Role.REQUESTER.rfc23, - ) - except StorageNotFoundError: + if not request._thread.pthid: + raise DIDXManagerError("DID Exchange request missing parent thread ID") + + async with self.profile.session() as session: + conn_rec = await ConnRecord.retrieve_by_invitation_msg_id( + session=session, + invitation_msg_id=request._thread.pthid, + their_role=ConnRecord.Role.REQUESTER.rfc23, + ) + + if not conn_rec: raise DIDXManagerError( - "No explicit invitation found for pairwise connection " - f"in state {ConnRecord.State.INVITATION.rfc23}: " - "a prior connection request may have updated the connection state" + "Pairwise requests must be against explicit invitations that have not " + "been previously consumed" ) if conn_rec.is_multiuse_invitation: diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index 5066aaf9dd..f3c5d6d969 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -721,8 +721,8 @@ async def test_receive_request_invi_not_found(self): with mock.patch.object( test_module, "ConnRecord", mock.MagicMock() ) as mock_conn_rec_cls: - mock_conn_rec_cls.retrieve_by_invitation_key = mock.CoroutineMock( - side_effect=StorageNotFoundError() + mock_conn_rec_cls.retrieve_by_invitation_msg_id = mock.CoroutineMock( + return_value=None ) with self.assertRaises(DIDXManagerError) as context: await self.manager.receive_request( @@ -732,7 +732,7 @@ async def test_receive_request_invi_not_found(self): alias=None, auto_accept_implicit=None, ) - assert "No explicit invitation found" in str(context.exception) + assert "explicit invitations" in str(context.exception) async def test_receive_request_public_did_no_did_doc_attachment(self): async with self.profile.session() as session: @@ -1376,7 +1376,7 @@ async def test_receive_request_peer_did(self): ), mock.patch.object( self.manager, "store_did_document", mock.CoroutineMock() ): - mock_conn_rec_cls.retrieve_by_invitation_key = mock.CoroutineMock( + mock_conn_rec_cls.retrieve_by_invitation_msg_id = mock.CoroutineMock( return_value=mock_conn ) mock_conn_rec_cls.return_value = mock.MagicMock( @@ -1435,8 +1435,8 @@ async def test_receive_request_peer_did_not_found_x(self): with mock.patch.object( test_module, "ConnRecord", mock.MagicMock() ) as mock_conn_rec_cls: - mock_conn_rec_cls.retrieve_by_invitation_key = mock.CoroutineMock( - side_effect=StorageNotFoundError() + mock_conn_rec_cls.retrieve_by_invitation_msg_id = mock.CoroutineMock( + return_value=None ) with self.assertRaises(DIDXManagerError): await self.manager.receive_request(