Skip to content

Commit

Permalink
Merge pull request openwallet-foundation#1670 from ianco/fix/did-exch…
Browse files Browse the repository at this point in the history
…-self-connect

Qualify did exch connection lookup by role
  • Loading branch information
swcurran authored Mar 17, 2022
2 parents d4c60ff + 6f8644e commit 3774e85
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
5 changes: 4 additions & 1 deletion aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __eq__(self, other: Union[str, "ConnRecord.State"]) -> bool:
"invitation_key",
"their_public_did",
"invitation_msg_id",
"their_role",
}

RECORD_TYPE = "connection"
Expand Down Expand Up @@ -373,7 +374,7 @@ async def find_existing_connection(

@classmethod
async def retrieve_by_request_id(
cls, session: ProfileSession, request_id: str
cls, session: ProfileSession, request_id: str, their_role: str = None
) -> "ConnRecord":
"""Retrieve a connection record from our previous request ID.
Expand All @@ -382,6 +383,8 @@ async def retrieve_by_request_id(
request_id: The ID of the originating connection request
"""
tag_filter = {"request_id": request_id}
if their_role:
tag_filter["their_role"] = their_role
return await cls.retrieve_by_tag_filter(session, tag_filter)

@classmethod
Expand Down
44 changes: 35 additions & 9 deletions aries_cloudagent/protocols/didexchange/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,13 +702,24 @@ async def accept_response(
conn_rec = None
if response._thread:
# identify the request by the thread ID
try:
async with self.profile.session() as session:
async with self.profile.session() as session:
try:
conn_rec = await ConnRecord.retrieve_by_request_id(
session, response._thread_id
session,
response._thread_id,
their_role=ConnRecord.Role.RESPONDER.rfc23,
)
except StorageNotFoundError:
pass
except StorageNotFoundError:
pass
if not conn_rec:
try:
conn_rec = await ConnRecord.retrieve_by_request_id(
session,
response._thread_id,
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
except StorageNotFoundError:
pass

if not conn_rec and receipt.sender_did:
# identify connection by the DID they used for us
Expand Down Expand Up @@ -809,12 +820,27 @@ async def accept_complete(
conn_rec = None

# identify the request by the thread ID
try:
async with self.profile.session() as session:
async with self.profile.session() as session:
try:
conn_rec = await ConnRecord.retrieve_by_request_id(
session, complete._thread_id
session,
complete._thread_id,
their_role=ConnRecord.Role.REQUESTER.rfc23,
)
except StorageNotFoundError:
except StorageNotFoundError:
pass

if not conn_rec:
try:
conn_rec = await ConnRecord.retrieve_by_request_id(
session,
complete._thread_id,
their_role=ConnRecord.Role.REQUESTER.rfc160,
)
except StorageNotFoundError:
pass

if not conn_rec:
raise DIDXManagerError(
"No corresponding connection request found",
error_code=ProblemReportReason.COMPLETE_NOT_ACCEPTED.value,
Expand Down

0 comments on commit 3774e85

Please sign in to comment.