Skip to content

Commit

Permalink
Merge branch 'main' into m1-build-issue
Browse files Browse the repository at this point in the history
  • Loading branch information
swcurran authored Mar 19, 2024
2 parents 915b60e + 9da989c commit 997a20f
Show file tree
Hide file tree
Showing 17 changed files with 504 additions and 103 deletions.
52 changes: 47 additions & 5 deletions aries_cloudagent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import json
import logging
from typing import List, Optional, Sequence, Text, Tuple, Union
from typing import Dict, List, Optional, Sequence, Text, Tuple, Union

import pydid
from base58 import b58decode
Expand Down Expand Up @@ -52,7 +52,7 @@
from ..utils.multiformats import multibase, multicodec
from ..wallet.base import BaseWallet
from ..wallet.crypto import create_keypair, seed_to_did
from ..wallet.did_info import DIDInfo, KeyInfo
from ..wallet.did_info import DIDInfo, KeyInfo, INVITATION_REUSE_KEY
from ..wallet.did_method import PEER2, PEER4, SOV
from ..wallet.error import WalletNotFoundError
from ..wallet.key_type import ED25519
Expand Down Expand Up @@ -89,6 +89,12 @@ def _key_info_to_multikey(key_info: KeyInfo) -> str:
multicodec.wrap("ed25519-pub", b58decode(key_info.verkey)), "base58btc"
)

def long_did_peer_to_short(self, long_did: str) -> DIDInfo:
"""Convert did:peer:4 long format to short format and return."""

short_did_peer = long_to_short(long_did)
return short_did_peer

async def long_did_peer_4_to_short(self, long_dp4: str) -> DIDInfo:
"""Convert did:peer:4 long format to short format and store in wallet."""

Expand All @@ -113,6 +119,7 @@ async def create_did_peer_4(
self,
svc_endpoints: Optional[Sequence[str]] = None,
mediation_records: Optional[List[MediationRecord]] = None,
metadata: Optional[Dict] = None,
) -> DIDInfo:
"""Create a did:peer:4 DID for a connection.
Expand Down Expand Up @@ -159,8 +166,13 @@ async def create_did_peer_4(
)
did = encode(input_doc)

did_metadata = metadata if metadata else {}
did_info = DIDInfo(
did=did, method=PEER4, verkey=key.verkey, metadata={}, key_type=ED25519
did=did,
method=PEER4,
verkey=key.verkey,
metadata=did_metadata,
key_type=ED25519,
)
await wallet.store_did(did_info)

Expand All @@ -170,6 +182,7 @@ async def create_did_peer_2(
self,
svc_endpoints: Optional[Sequence[str]] = None,
mediation_records: Optional[List[MediationRecord]] = None,
metadata: Optional[Dict] = None,
) -> DIDInfo:
"""Create a did:peer:2 DID for a connection.
Expand Down Expand Up @@ -215,13 +228,39 @@ async def create_did_peer_2(
[KeySpec.verification(self._key_info_to_multikey(key))], services
)

did_metadata = metadata if metadata else {}
did_info = DIDInfo(
did=did, method=PEER2, verkey=key.verkey, metadata={}, key_type=ED25519
did=did,
method=PEER2,
verkey=key.verkey,
metadata=did_metadata,
key_type=ED25519,
)
await wallet.store_did(did_info)

return did_info

async def fetch_invitation_reuse_did(
self,
did_method: str,
) -> DIDDoc:
"""Fetch a DID from the wallet to use across multiple invitations.
Args:
did_method: The DID method used (e.g. PEER2 or PEER4)
Returns:
The `DIDDoc` instance, or "None" if no DID is found
"""
did_info = None
async with self._profile.session() as session:
wallet = session.inject(BaseWallet)
did_list = await wallet.get_local_dids()
for did in did_list:
if did.method == did_method and INVITATION_REUSE_KEY in did.metadata:
return did
return did_info

async def create_did_document(
self,
did_info: DIDInfo,
Expand Down Expand Up @@ -346,7 +385,10 @@ async def find_did_for_key(self, key: str) -> str:
async with self._profile.session() as session:
storage: BaseStorage = session.inject(BaseStorage)
record = await storage.find_record(self.RECORD_TYPE_DID_KEY, {"key": key})
return record.tags["did"]
ret_did = record.tags["did"]
if ret_did.startswith("did:peer:4"):
ret_did = self.long_did_peer_to_short(ret_did)
return ret_did

async def remove_keys_for_did(self, did: str):
"""Remove all keys associated with a DID.
Expand Down
9 changes: 6 additions & 3 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,16 @@ async def retrieve_by_invitation_msg_id(
async def find_existing_connection(
cls, session: ProfileSession, their_public_did: str
) -> Optional["ConnRecord"]:
"""Retrieve existing active connection records (public did).
"""Retrieve existing active connection records (public did or did:peer).
Args:
session: The active profile session
their_public_did: Inviter public DID
their_public_did: Inviter public DID (or did:peer)
"""
tag_filter = {"their_public_did": their_public_did}
if their_public_did.startswith("did:peer"):
tag_filter = {"their_did": their_public_did}
else:
tag_filter = {"their_public_did": their_public_did}
conn_records = await cls.query(
session,
tag_filter=tag_filter,
Expand Down
117 changes: 112 additions & 5 deletions aries_cloudagent/protocols/out_of_band/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from ....storage.error import StorageNotFoundError
from ....transport.inbound.receipt import MessageReceipt
from ....wallet.base import BaseWallet
from ....wallet.did_info import INVITATION_REUSE_KEY
from ....wallet.did_method import PEER2, PEER4
from ....wallet.key_type import ED25519
from ...connections.v1_0.manager import ConnectionManager
from ...connections.v1_0.messages.connection_invitation import ConnectionInvitation
Expand Down Expand Up @@ -81,8 +83,11 @@ async def create_invitation(
my_endpoint: str = None,
auto_accept: bool = None,
public: bool = False,
did_peer_2: bool = False,
did_peer_4: bool = False,
hs_protos: Sequence[HSProto] = None,
multi_use: bool = False,
create_unique_did: bool = False,
alias: str = None,
attachments: Sequence[Mapping] = None,
metadata: dict = None,
Expand Down Expand Up @@ -279,6 +284,99 @@ async def create_invitation(
routing_keys=[],
).serialize()

elif did_peer_4 or did_peer_2:
mediation_records = [mediation_record] if mediation_record else []

if my_endpoint:
my_endpoints = [my_endpoint]
else:
my_endpoints = []
default_endpoint = self.profile.settings.get("default_endpoint")
if default_endpoint:
my_endpoints.append(default_endpoint)
my_endpoints.extend(
self.profile.settings.get("additional_endpoints", [])
)

my_info = None
my_did = None
if not create_unique_did:
# check wallet to see if there is an existing "invitation" DID available
did_method = PEER4 if did_peer_4 else PEER2
my_info = await self.fetch_invitation_reuse_did(did_method)
if my_info:
my_did = my_info.did
else:
LOGGER.warn("No invitation DID found, creating new DID")

if not my_did:
did_metadata = (
{INVITATION_REUSE_KEY: "true"} if not create_unique_did else {}
)
if did_peer_4:
my_info = await self.create_did_peer_4(
my_endpoints, mediation_records, did_metadata
)
my_did = my_info.did
else:
my_info = await self.create_did_peer_2(
my_endpoints, mediation_records, did_metadata
)
my_did = my_info.did

invi_msg = InvitationMessage( # create invitation message
_id=invitation_message_id,
label=my_label or self.profile.settings.get("default_label"),
handshake_protocols=handshake_protocols,
requests_attach=message_attachments,
services=[my_did],
accept=service_accept if protocol_version != "1.0" else None,
version=protocol_version or DEFAULT_VERSION,
image_url=image_url,
)
invi_url = invi_msg.to_url()

our_recipient_key = my_info.verkey

# Only create connection record if hanshake_protocols is defined
if handshake_protocols:
invitation_mode = (
ConnRecord.INVITATION_MODE_MULTI
if multi_use
else ConnRecord.INVITATION_MODE_ONCE
)
conn_rec = ConnRecord( # create connection record
invitation_key=our_recipient_key,
invitation_msg_id=invi_msg._id,
invitation_mode=invitation_mode,
their_role=ConnRecord.Role.REQUESTER.rfc23,
state=ConnRecord.State.INVITATION.rfc23,
accept=(
ConnRecord.ACCEPT_AUTO
if auto_accept
else ConnRecord.ACCEPT_MANUAL
),
alias=alias,
connection_protocol=connection_protocol,
my_did=my_did,
)

async with self.profile.session() as session:
await conn_rec.save(session, reason="Created new invitation")
await conn_rec.attach_invitation(session, invi_msg)

await conn_rec.attach_invitation(session, invi_msg)

if metadata:
for key, value in metadata.items():
await conn_rec.metadata_set(session, key, value)
else:
our_service = ServiceDecorator(
recipient_keys=[our_recipient_key],
endpoint=endpoint,
routing_keys=[],
).serialize()

else:
if not my_endpoint:
my_endpoint = self.profile.settings.get("default_endpoint")
Expand Down Expand Up @@ -454,7 +552,7 @@ async def receive_invitation(
# service_accept
service_accept = invitation.accept

# Get the DID public did, if any
# Get the DID public did, if any (might also be a did:peer)
public_did = None
if isinstance(oob_service_item, str):
if bool(IndyDID.PATTERN.match(oob_service_item)):
Expand All @@ -465,17 +563,21 @@ async def receive_invitation(
conn_rec = None

# Find existing connection - only if started by an invitation with Public DID
# and use_existing_connection is true
# (or did:peer) and use_existing_connection is true
if (
public_did is not None and use_existing_connection
): # invite has public DID: seek existing connection
LOGGER.debug(
"Trying to find existing connection for oob invitation with "
f"did {public_did}"
)
if public_did.startswith("did:peer:4"):
search_public_did = self.long_did_peer_to_short(public_did)
else:
search_public_did = public_did
async with self._profile.session() as session:
conn_rec = await ConnRecord.find_existing_connection(
session=session, their_public_did=public_did
session=session, their_public_did=search_public_did
)

oob_record = OobRecord(
Expand Down Expand Up @@ -809,8 +911,10 @@ async def _perform_handshake(
# If it's in the did format, we need to convert to a full service block
# An existing connection can only be reused based on a public DID
# in an out-of-band message (RFC 0434).
# OR did:peer:2 or did:peer:4.

public_did = service.split(":")[-1]
if not service.startswith("did:peer"):
public_did = service.split(":")[-1]

# TODO: resolve_invitation should resolve key_info objects
# or something else that includes the key type. We now assume
Expand All @@ -835,7 +939,10 @@ async def _perform_handshake(
}
)

LOGGER.debug(f"Creating connection with public did {public_did}")
if public_did:
LOGGER.debug(f"Creating connection with public did {public_did}")
else:
LOGGER.debug(f"Creating connection with service {service}")

conn_record = None
for protocol in supported_handshake_protocols:
Expand Down
24 changes: 24 additions & 0 deletions aries_cloudagent/protocols/out_of_band/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class InvitationCreateQueryStringSchema(OpenAPISchema):
required=False,
metadata={"description": "Create invitation for multiple use (default false)"},
)
create_unique_did = fields.Boolean(
required=False,
metadata={
"description": "Create unique DID for this invitation (default false)"
},
)


class InvitationCreateRequestSchema(OpenAPISchema):
Expand Down Expand Up @@ -231,18 +237,36 @@ async def invitation_create(request: web.BaseRequest):

multi_use = json.loads(request.query.get("multi_use", "false"))
auto_accept = json.loads(request.query.get("auto_accept", "null"))
create_unique_did = json.loads(request.query.get("create_unique_did", "false"))

if create_unique_did and use_public_did:
raise web.HTTPBadRequest(
reason="create_unique_did cannot be used with use_public_did"
)

profile = context.profile

emit_did_peer_4 = profile.settings.get("emit_did_peer_4", False)
emit_did_peer_2 = profile.settings.get("emit_did_peer_2", False)
if emit_did_peer_2 and emit_did_peer_4:
LOGGER.warning(
"emit_did_peer_2 and emit_did_peer_4 both set, \
using did:peer:4"
)

oob_mgr = OutOfBandManager(profile)
try:
invi_rec = await oob_mgr.create_invitation(
my_label=my_label,
auto_accept=auto_accept,
public=use_public_did,
did_peer_2=emit_did_peer_2,
did_peer_4=emit_did_peer_4,
hs_protos=[
h for h in [HSProto.get(hsp) for hsp in handshake_protocols] if h
],
multi_use=multi_use,
create_unique_did=create_unique_did,
attachments=attachments,
metadata=metadata,
alias=alias,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ async def test_invitation_create(self):
my_label=None,
auto_accept=True,
public=True,
did_peer_2=False,
did_peer_4=False,
multi_use=True,
create_unique_did=False,
hs_protos=[test_module.HSProto.RFC23],
attachments=body["attachments"],
metadata=body["metadata"],
Expand Down Expand Up @@ -109,7 +112,10 @@ async def test_invitation_create_with_accept(self):
my_label=None,
auto_accept=True,
public=True,
did_peer_2=False,
did_peer_4=False,
multi_use=True,
create_unique_did=False,
hs_protos=[test_module.HSProto.RFC23],
attachments=body["attachments"],
metadata=body["metadata"],
Expand Down
8 changes: 6 additions & 2 deletions aries_cloudagent/resolver/default/peer3.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ async def remove_record_for_deleted_conn(self, profile: Profile, event: Event):
if not their_did and not my_did:
return
dids = [
*(did for did in (their_did, my_did) if PEER3_PATTERN.match(did)),
*(peer2to3(did) for did in (their_did, my_did) if PEER2_PATTERN.match(did)),
*(did for did in (their_did, my_did) if did and PEER3_PATTERN.match(did)),
*(
peer2to3(did)
for did in (their_did, my_did)
if did and PEER2_PATTERN.match(did)
),
]
if dids:
LOGGER.debug(
Expand Down
Loading

0 comments on commit 997a20f

Please sign in to comment.