From bdc7ccf84a165b3827a944afbc517a42435922a4 Mon Sep 17 00:00:00 2001 From: Daniel Bluhm Date: Sun, 21 Jul 2024 15:47:38 -0400 Subject: [PATCH] fix: multiuse invites with did peer 4 This corrects an issue where did:peer:4 connection records were failing to be resolved on inbound message, resulting in the multiuse invitation that created the connection being resolved instead. Fixes #3111. Signed-off-by: Daniel Bluhm --- .pre-commit-config.yaml | 1 - aries_cloudagent/connections/base_manager.py | 40 ++++++++-------- .../connections/models/conn_record.py | 46 +++++++++++++++++-- 3 files changed, 60 insertions(+), 27 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cded6ee9f6..5f019d9070 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,4 +17,3 @@ repos: # Run the formatter - id: ruff-format stages: [commit] - args: [--fix, --exit-non-zero-on-fix, --formatter] diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index 12b826cf5c..19b9439223 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -424,8 +424,6 @@ async def find_did_for_key(self, key: str) -> str: storage: BaseStorage = session.inject(BaseStorage) record = await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key}) ret_did = record.tags["did"] - if ret_did.startswith("did:peer:4"): - ret_did = self.long_did_peer_to_short(ret_did) return ret_did async def remove_keys_for_did(self, did: str): @@ -452,9 +450,7 @@ async def resolve_didcomm_services( doc_dict: dict = await resolver.resolve(self._profile, did, service_accept) doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True) except ResolverError as error: - raise BaseConnectionManagerError( - "Failed to resolve DID services" - ) from error + raise BaseConnectionManagerError("Failed to resolve DID services") from error if not doc.service: raise BaseConnectionManagerError( @@ -523,10 +519,7 @@ async def resolve_invitation( return ( endpoint, - [ - self._extract_key_material_in_base58_format(key) - for key in recipient_keys - ], + [self._extract_key_material_in_base58_format(key) for key in recipient_keys], [self._extract_key_material_in_base58_format(key) for key in routing_keys], ) @@ -800,9 +793,7 @@ async def get_connection_targets( async with cache.acquire(cache_key) as entry: if entry.result: self._logger.debug("Connection targets retrieved from cache") - targets = [ - ConnectionTarget.deserialize(row) for row in entry.result - ] + targets = [ConnectionTarget.deserialize(row) for row in entry.result] else: if not connection: async with self._profile.session() as session: @@ -817,9 +808,7 @@ async def get_connection_targets( # Otherwise, a replica that participated early in exchange # may have bad data set in cache. self._logger.debug("Caching connection targets") - await entry.set_result( - [row.serialize() for row in targets], 3600 - ) + await entry.set_result([row.serialize() for row in targets], 3600) else: self._logger.debug( "Not caching connection targets for connection in " @@ -878,12 +867,8 @@ def diddoc_connection_targets( did=doc.did, endpoint=service.endpoint, label=their_label, - recipient_keys=[ - key.value for key in (service.recip_keys or ()) - ], - routing_keys=[ - key.value for key in (service.routing_keys or ()) - ], + recipient_keys=[key.value for key in (service.recip_keys or ())], + routing_keys=[key.value for key in (service.routing_keys or ())], sender_key=sender_verkey, ) ) @@ -920,7 +905,18 @@ async def find_connection( """ connection = None - if their_did: + if their_did and their_did.startswith("did:peer:4"): + # did:peer:4 always recorded as long + long = their_did + short = self.long_did_peer_to_short(their_did) + try: + async with self._profile.session() as session: + connection = await ConnRecord.retrieve_by_did_peer_4( + session, long, short, my_did + ) + except StorageNotFoundError: + pass + elif their_did: try: async with self._profile.session() as session: connection = await ConnRecord.retrieve_by_did( diff --git a/aries_cloudagent/connections/models/conn_record.py b/aries_cloudagent/connections/models/conn_record.py index 9bf137b7f0..7ec90f3d73 100644 --- a/aries_cloudagent/connections/models/conn_record.py +++ b/aries_cloudagent/connections/models/conn_record.py @@ -217,7 +217,9 @@ def __init__( self.their_role = ( ConnRecord.Role.get(their_role).rfc160 if isinstance(their_role, str) - else None if their_role is None else their_role.rfc160 + else None + if their_role is None + else their_role.rfc160 ) self.invitation_key = invitation_key self.invitation_msg_id = invitation_msg_id @@ -293,6 +295,44 @@ async def retrieve_by_did( return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + @classmethod + async def retrieve_by_did_peer_4( + cls, + session: ProfileSession, + their_did_long: Optional[str] = None, + their_did_short: Optional[str] = None, + my_did: Optional[str] = None, + their_role: Optional[str] = None, + ) -> "ConnRecord": + """Retrieve a connection record by target DID. + + Args: + session: The active profile session + their_did_long: The target DID to filter by, in long form + their_did_short: The target DID to filter by, in short form + my_did: One of our DIDs to filter by + my_role: Filter connections by their role + their_role: Filter connections by their role + """ + tag_filter = {} + if their_did_long and their_did_short: + tag_filter["$or"] = [ + {"their_did": their_did_long}, + {"their_did": their_did_short}, + ] + elif their_did_short: + tag_filter["their_did"] = their_did_short + elif their_did_long: + tag_filter["their_did"] = their_did_long + if my_did: + tag_filter["my_did"] = my_did + + post_filter = {} + if their_role: + post_filter["their_role"] = cls.Role.get(their_role).rfc160 + + return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter) + @classmethod async def retrieve_by_invitation_key( cls, session: ProfileSession, invitation_key: str, their_role: str = None @@ -375,9 +415,7 @@ async def retrieve_by_request_id( return await cls.retrieve_by_tag_filter(session, tag_filter) @classmethod - async def retrieve_by_alias( - cls, session: ProfileSession, alias: str - ) -> "ConnRecord": + async def retrieve_by_alias(cls, session: ProfileSession, alias: str) -> "ConnRecord": """Retrieve a connection record from an alias. Args: