Skip to content

Commit

Permalink
fix: multiuse invites with did peer 4
Browse files Browse the repository at this point in the history
This corrects an issue where did:peer:4 connection records were failing
to be resolved on inbound message, resulting in the multiuse invitation
that created the connection being resolved instead.

Fixes #3111.

Signed-off-by: Daniel Bluhm <[email protected]>
  • Loading branch information
dbluhm committed Jul 21, 2024
1 parent e1dea37 commit bdc7ccf
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 27 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ repos:
# Run the formatter
- id: ruff-format
stages: [commit]
args: [--fix, --exit-non-zero-on-fix, --formatter]
40 changes: 18 additions & 22 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,6 @@ async def find_did_for_key(self, key: str) -> str:
storage: BaseStorage = session.inject(BaseStorage)
record = await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key})
ret_did = record.tags["did"]
if ret_did.startswith("did:peer:4"):
ret_did = self.long_did_peer_to_short(ret_did)
return ret_did

async def remove_keys_for_did(self, did: str):
Expand All @@ -452,9 +450,7 @@ async def resolve_didcomm_services(
doc_dict: dict = await resolver.resolve(self._profile, did, service_accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
raise BaseConnectionManagerError(
"Failed to resolve DID services"
) from error
raise BaseConnectionManagerError("Failed to resolve DID services") from error

if not doc.service:
raise BaseConnectionManagerError(
Expand Down Expand Up @@ -523,10 +519,7 @@ async def resolve_invitation(

return (
endpoint,
[
self._extract_key_material_in_base58_format(key)
for key in recipient_keys
],
[self._extract_key_material_in_base58_format(key) for key in recipient_keys],
[self._extract_key_material_in_base58_format(key) for key in routing_keys],
)

Expand Down Expand Up @@ -800,9 +793,7 @@ async def get_connection_targets(
async with cache.acquire(cache_key) as entry:
if entry.result:
self._logger.debug("Connection targets retrieved from cache")
targets = [
ConnectionTarget.deserialize(row) for row in entry.result
]
targets = [ConnectionTarget.deserialize(row) for row in entry.result]
else:
if not connection:
async with self._profile.session() as session:
Expand All @@ -817,9 +808,7 @@ async def get_connection_targets(
# Otherwise, a replica that participated early in exchange
# may have bad data set in cache.
self._logger.debug("Caching connection targets")
await entry.set_result(
[row.serialize() for row in targets], 3600
)
await entry.set_result([row.serialize() for row in targets], 3600)
else:
self._logger.debug(
"Not caching connection targets for connection in "
Expand Down Expand Up @@ -878,12 +867,8 @@ def diddoc_connection_targets(
did=doc.did,
endpoint=service.endpoint,
label=their_label,
recipient_keys=[
key.value for key in (service.recip_keys or ())
],
routing_keys=[
key.value for key in (service.routing_keys or ())
],
recipient_keys=[key.value for key in (service.recip_keys or ())],
routing_keys=[key.value for key in (service.routing_keys or ())],
sender_key=sender_verkey,
)
)
Expand Down Expand Up @@ -920,7 +905,18 @@ async def find_connection(
"""
connection = None
if their_did:
if their_did and their_did.startswith("did:peer:4"):
# did:peer:4 always recorded as long
long = their_did
short = self.long_did_peer_to_short(their_did)
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did_peer_4(
session, long, short, my_did
)
except StorageNotFoundError:
pass
elif their_did:
try:
async with self._profile.session() as session:
connection = await ConnRecord.retrieve_by_did(
Expand Down
46 changes: 42 additions & 4 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def __init__(
self.their_role = (
ConnRecord.Role.get(their_role).rfc160
if isinstance(their_role, str)
else None if their_role is None else their_role.rfc160
else None
if their_role is None
else their_role.rfc160
)
self.invitation_key = invitation_key
self.invitation_msg_id = invitation_msg_id
Expand Down Expand Up @@ -293,6 +295,44 @@ async def retrieve_by_did(

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_did_peer_4(
cls,
session: ProfileSession,
their_did_long: Optional[str] = None,
their_did_short: Optional[str] = None,
my_did: Optional[str] = None,
their_role: Optional[str] = None,
) -> "ConnRecord":
"""Retrieve a connection record by target DID.
Args:
session: The active profile session
their_did_long: The target DID to filter by, in long form
their_did_short: The target DID to filter by, in short form
my_did: One of our DIDs to filter by
my_role: Filter connections by their role
their_role: Filter connections by their role
"""
tag_filter = {}
if their_did_long and their_did_short:
tag_filter["$or"] = [
{"their_did": their_did_long},
{"their_did": their_did_short},
]
elif their_did_short:
tag_filter["their_did"] = their_did_short
elif their_did_long:
tag_filter["their_did"] = their_did_long
if my_did:
tag_filter["my_did"] = my_did

post_filter = {}
if their_role:
post_filter["their_role"] = cls.Role.get(their_role).rfc160

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_invitation_key(
cls, session: ProfileSession, invitation_key: str, their_role: str = None
Expand Down Expand Up @@ -375,9 +415,7 @@ async def retrieve_by_request_id(
return await cls.retrieve_by_tag_filter(session, tag_filter)

@classmethod
async def retrieve_by_alias(
cls, session: ProfileSession, alias: str
) -> "ConnRecord":
async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRecord":
"""Retrieve a connection record from an alias.
Args:
Expand Down

0 comments on commit bdc7ccf

Please sign in to comment.