diff --git a/aries_cloudagent/connections/models/conn_record.py b/aries_cloudagent/connections/models/conn_record.py index 672d9359e0..eb0342970a 100644 --- a/aries_cloudagent/connections/models/conn_record.py +++ b/aries_cloudagent/connections/models/conn_record.py @@ -174,6 +174,7 @@ def __eq__(self, other: Union[str, "ConnRecord.State"]) -> bool: "invitation_key", "their_public_did", "invitation_msg_id", + "their_role", } RECORD_TYPE = "connection" @@ -373,7 +374,7 @@ async def find_existing_connection( @classmethod async def retrieve_by_request_id( - cls, session: ProfileSession, request_id: str + cls, session: ProfileSession, request_id: str, their_role: str = None ) -> "ConnRecord": """Retrieve a connection record from our previous request ID. @@ -382,6 +383,8 @@ async def retrieve_by_request_id( request_id: The ID of the originating connection request """ tag_filter = {"request_id": request_id} + if their_role: + tag_filter["their_role"] = their_role return await cls.retrieve_by_tag_filter(session, tag_filter) @classmethod diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 282bbb8335..c7c9e51eaa 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -702,13 +702,24 @@ async def accept_response( conn_rec = None if response._thread: # identify the request by the thread ID - try: - async with self.profile.session() as session: + async with self.profile.session() as session: + try: conn_rec = await ConnRecord.retrieve_by_request_id( - session, response._thread_id + session, + response._thread_id, + their_role=ConnRecord.Role.RESPONDER.rfc23, ) - except StorageNotFoundError: - pass + except StorageNotFoundError: + pass + if not conn_rec: + try: + conn_rec = await ConnRecord.retrieve_by_request_id( + session, + response._thread_id, + their_role=ConnRecord.Role.RESPONDER.rfc160, + ) + except StorageNotFoundError: + pass if not conn_rec and receipt.sender_did: # identify connection by the DID they used for us @@ -809,12 +820,27 @@ async def accept_complete( conn_rec = None # identify the request by the thread ID - try: - async with self.profile.session() as session: + async with self.profile.session() as session: + try: conn_rec = await ConnRecord.retrieve_by_request_id( - session, complete._thread_id + session, + complete._thread_id, + their_role=ConnRecord.Role.REQUESTER.rfc23, ) - except StorageNotFoundError: + except StorageNotFoundError: + pass + + if not conn_rec: + try: + conn_rec = await ConnRecord.retrieve_by_request_id( + session, + complete._thread_id, + their_role=ConnRecord.Role.REQUESTER.rfc160, + ) + except StorageNotFoundError: + pass + + if not conn_rec: raise DIDXManagerError( "No corresponding connection request found", error_code=ProblemReportReason.COMPLETE_NOT_ACCEPTED.value,