diff --git a/aries_cloudagent/connections/base_manager.py b/aries_cloudagent/connections/base_manager.py index bdbe1dd00e..8e11d3b796 100644 --- a/aries_cloudagent/connections/base_manager.py +++ b/aries_cloudagent/connections/base_manager.py @@ -80,7 +80,6 @@ def __init__(self, profile: Profile): async def create_did_document( self, did_info: DIDInfo, - inbound_connection_id: Optional[str] = None, svc_endpoints: Optional[Sequence[str]] = None, mediation_records: Optional[List[MediationRecord]] = None, ) -> DIDDoc: @@ -88,7 +87,6 @@ async def create_did_document( Args: did_info: The DID information (DID and verkey) used in the connection - inbound_connection_id: The ID of the inbound routing connection to use svc_endpoints: Custom endpoints for the DID Document mediation_record: The record for mediation that contains routing_keys and service endpoint @@ -111,61 +109,18 @@ async def create_did_document( ) did_doc.set(pk) - router_id = inbound_connection_id - routing_keys = [] - router_idx = 1 - while router_id: - # look up routing connection information - async with self._profile.session() as session: - router = await ConnRecord.retrieve_by_id(session, router_id) - if ConnRecord.State.get(router.state) != ConnRecord.State.COMPLETED: - raise BaseConnectionManagerError( - f"Router connection not completed: {router_id}" - ) - routing_doc, _ = await self.fetch_did_document(router.their_did) - assert isinstance(routing_doc, DIDDoc) - if not routing_doc.service: - raise BaseConnectionManagerError( - f"No services defined by routing DIDDoc: {router_id}" - ) - for service in routing_doc.service.values(): - if not service.endpoint: - raise BaseConnectionManagerError( - "Routing DIDDoc service has no service endpoint" - ) - if not service.recip_keys: - raise BaseConnectionManagerError( - "Routing DIDDoc service has no recipient key(s)" - ) - rk = PublicKey( - did_info.did, - f"routing-{router_idx}", - service.recip_keys[0].value, - PublicKeyType.ED25519_SIG_2018, - did_controller, - True, - ) - routing_keys.append(rk) - svc_endpoints = [service.endpoint] - break - router_id = router.inbound_connection_id - + routing_keys: List[str] = [] if mediation_records: for mediation_record in mediation_records: - mediator_routing_keys = [ - PublicKey( - did_info.did, - f"routing-{idx}", - key, - PublicKeyType.ED25519_SIG_2018, - did_controller, # TODO: get correct controller did_info - True, # TODO: should this be true? - ) - for idx, key in enumerate(mediation_record.routing_keys) - ] - - routing_keys = [*routing_keys, *mediator_routing_keys] - svc_endpoints = [mediation_record.endpoint] + ( + mediator_routing_keys, + endpoint, + ) = await self._route_manager.routing_info( + self._profile, mediation_record + ) + routing_keys = [*routing_keys, *(mediator_routing_keys or [])] + if endpoint: + svc_endpoints = [endpoint] for endpoint_index, svc_endpoint in enumerate(svc_endpoints or []): endpoint_ident = "indy" if endpoint_index == 0 else f"indy{endpoint_index}" @@ -938,7 +893,6 @@ async def create_static_connection( # Synthesize their DID doc did_doc = await self.create_did_document( their_info, - None, [their_endpoint or ""], mediation_records=list( filter(None, [base_mediation_record, mediation_record]) diff --git a/aries_cloudagent/connections/models/diddoc/diddoc.py b/aries_cloudagent/connections/models/diddoc/diddoc.py index 0970a2fb0b..e4ced90108 100644 --- a/aries_cloudagent/connections/models/diddoc/diddoc.py +++ b/aries_cloudagent/connections/models/diddoc/diddoc.py @@ -22,6 +22,8 @@ from typing import List, Sequence, Union +from ....did.did_key import DIDKey + from .publickey import PublicKey, PublicKeyType from .service import Service from .util import canon_did, canon_ref, ok_did, resource @@ -116,13 +118,36 @@ def set(self, item: Union[Service, PublicKey]) -> "DIDDoc": "Cannot add item {} to DIDDoc on DID {}".format(item, self.did) ) - def serialize(self) -> dict: + @staticmethod + def _normalize_routing_keys(service: dict) -> dict: + """Normalize routing keys in service. + + Args: + service: service dict + + Returns: service dict with routing keys normalized + """ + routing_keys = service.get("routingKeys") + if routing_keys: + routing_keys = [ + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + for key in routing_keys + ] + service["routingKeys"] = routing_keys + return service + + def serialize(self, normalize_routing_keys: bool = False) -> dict: """Dump current object to a JSON-compatible dictionary. Returns: dict representation of current DIDDoc """ + service = [service.to_dict() for service in self.service.values()] + if normalize_routing_keys: + service = [self._normalize_routing_keys(s) for s in service] return { "@context": DIDDoc.CONTEXT, @@ -136,7 +161,7 @@ def serialize(self) -> dict: for pubkey in self.pubkey.values() if pubkey.authn ], - "service": [service.to_dict() for service in self.service.values()], + "service": service, } def to_json(self) -> str: @@ -285,7 +310,7 @@ def deserialize(cls, did_doc: dict) -> "DIDDoc": ), service["type"], rv.add_service_pubkeys(service, "recipientKeys"), - rv.add_service_pubkeys(service, ["mediatorKeys", "routingKeys"]), + service.get("routingKeys", []), canon_ref(rv.did, endpoint, ";") if ";" in endpoint else endpoint, service.get("priority", None), ) diff --git a/aries_cloudagent/connections/models/diddoc/service.py b/aries_cloudagent/connections/models/diddoc/service.py index 27d9564d5e..c9d2a8f7a0 100644 --- a/aries_cloudagent/connections/models/diddoc/service.py +++ b/aries_cloudagent/connections/models/diddoc/service.py @@ -36,7 +36,7 @@ def __init__( ident: str, typ: str, recip_keys: Union[Sequence, PublicKey], - routing_keys: Union[Sequence, PublicKey], + routing_keys: List[str], endpoint: str, priority: int = 0, ): @@ -69,13 +69,7 @@ def __init__( if recip_keys else None ) - self._routing_keys = ( - [routing_keys] - if isinstance(routing_keys, PublicKey) - else list(routing_keys) - if routing_keys - else None - ) + self._routing_keys = routing_keys or [] self._endpoint = endpoint self._priority = priority @@ -104,7 +98,7 @@ def recip_keys(self) -> List[PublicKey]: return self._recip_keys @property - def routing_keys(self) -> List[PublicKey]: + def routing_keys(self) -> List[str]: """Accessor for the routing keys.""" return self._routing_keys @@ -128,7 +122,7 @@ def to_dict(self) -> dict: if self.recip_keys: rv["recipientKeys"] = [k.value for k in self.recip_keys] if self.routing_keys: - rv["routingKeys"] = [k.value for k in self.routing_keys] + rv["routingKeys"] = self.routing_keys rv["serviceEndpoint"] = self.endpoint return rv diff --git a/aries_cloudagent/connections/tests/test_base_manager.py b/aries_cloudagent/connections/tests/test_base_manager.py index 61ff118cd2..aa136cdfbb 100644 --- a/aries_cloudagent/connections/tests/test_base_manager.py +++ b/aries_cloudagent/connections/tests/test_base_manager.py @@ -31,7 +31,10 @@ from ...protocols.coordinate_mediation.v1_0.models.mediation_record import ( MediationRecord, ) -from ...protocols.coordinate_mediation.v1_0.route_manager import RouteManager +from ...protocols.coordinate_mediation.v1_0.route_manager import ( + RouteManager, + CoordinateMediationV1RouteManager, +) from ...protocols.discovery.v2_0.manager import V20DiscoveryMgr from ...resolver.default.key import KeyDIDResolver from ...resolver.default.legacy_peer import LegacyPeerDIDResolver @@ -82,13 +85,7 @@ async def setUp(self): self.oob_mock = async_mock.MagicMock( clean_finished_oob_record=async_mock.CoroutineMock(return_value=None) ) - self.route_manager = async_mock.MagicMock(RouteManager) - self.route_manager.routing_info = async_mock.CoroutineMock( - return_value=([], self.test_endpoint) - ) - self.route_manager.mediation_record_if_id = async_mock.CoroutineMock( - return_value=None - ) + self.route_manager = CoordinateMediationV1RouteManager() self.resolver = DIDResolver() self.resolver.register_resolver(LegacyPeerDIDResolver()) self.resolver.register_resolver(KeyDIDResolver()) @@ -118,7 +115,7 @@ async def setUp(self): ) self.test_mediator_routing_keys = [ - "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRR" + "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" ] self.test_mediator_conn_id = "mediator-conn-id" self.test_mediator_endpoint = "http://mediator.example.com" @@ -135,176 +132,10 @@ async def test_create_did_document(self): key_type=ED25519, ) - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - did_doc = await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_not_active(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.ABANDONED.rfc23, - ) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_services(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_service_endpoint(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service(self.test_target_did, "dummy", "IndyAgent", [], [], "", 0) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) - - async def test_create_did_document_no_service_recip_keys(self): - did_info = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=self.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=self.test_target_did, verkey=self.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service( - self.test_target_did, - "dummy", - "IndyAgent", - [], - [], - self.test_endpoint, - 0, - ) + did_doc = await self.manager.create_did_document( + did_info=did_info, + svc_endpoints=[self.test_endpoint], ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[self.test_endpoint], - ) async def test_create_did_document_mediation(self): did_info = DIDInfo( @@ -328,8 +159,9 @@ async def test_create_did_document_mediation(self): services = list(doc.service.values()) assert len(services) == 1 (service,) = services - service_public_keys = service.routing_keys[0] - assert service_public_keys.value == mediation_record.routing_keys[0] + assert service.routing_keys + service_routing_key = service.routing_keys[0] + assert service_routing_key == mediation_record.routing_keys[0] assert service.endpoint == mediation_record.endpoint async def test_create_did_document_multiple_mediators(self): @@ -351,7 +183,9 @@ async def test_create_did_document_multiple_mediators(self): role=MediationRecord.ROLE_CLIENT, state=MediationRecord.STATE_GRANTED, connection_id="mediator-conn-id2", - routing_keys=["05e8afd1-b4f0-46b7-a285-7a08c8a37caf"], + routing_keys=[ + "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDz#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDz" + ], endpoint="http://mediatorw.example.com", ) doc = await self.manager.create_did_document( @@ -361,8 +195,8 @@ async def test_create_did_document_multiple_mediators(self): services = list(doc.service.values()) assert len(services) == 1 (service,) = services - assert service.routing_keys[0].value == mediation_record1.routing_keys[0] - assert service.routing_keys[1].value == mediation_record2.routing_keys[0] + assert service.routing_keys[0] == mediation_record1.routing_keys[0] + assert service.routing_keys[1] == mediation_record2.routing_keys[0] assert service.endpoint == mediation_record2.endpoint async def test_create_did_document_mediation_svc_endpoints_overwritten(self): @@ -380,6 +214,9 @@ async def test_create_did_document_mediation_svc_endpoints_overwritten(self): routing_keys=self.test_mediator_routing_keys, endpoint=self.test_mediator_endpoint, ) + self.route_manager.routing_info = async_mock.CoroutineMock( + return_value=(mediation_record.routing_keys, mediation_record.endpoint) + ) doc = await self.manager.create_did_document( did_info, svc_endpoints=[self.test_endpoint], @@ -390,7 +227,7 @@ async def test_create_did_document_mediation_svc_endpoints_overwritten(self): assert len(services) == 1 (service,) = services service_public_keys = service.routing_keys[0] - assert service_public_keys.value == mediation_record.routing_keys[0] + assert service_public_keys == mediation_record.routing_keys[0] assert service.endpoint == mediation_record.endpoint async def test_did_key_storage(self): @@ -436,7 +273,13 @@ async def test_store_did_document_with_routing_keys(self): "controller": "YQwDgq9vdAbB3fk1tkeXmg", "type": "Ed25519VerificationKey2018", "publicKeyBase58": "J81x9zdJa8CGSbTYpoYQaNrV6yv13M1Lgz4tmkNPKwZn", - } + }, + { + "id": "YQwDgq9vdAbB3fk1tkeXmg#1", + "controller": "YQwDgq9vdAbB3fk1tkeXmg", + "type": "Ed25519VerificationKey2018", + "publicKeyBase58": routing_key, + }, ], "service": [ { @@ -447,7 +290,7 @@ async def test_store_did_document_with_routing_keys(self): "recipientKeys": [ "J81x9zdJa8CGSbTYpoYQaNrV6yv13M1Lgz4tmkNPKwZn" ], - "routingKeys": ["cK7fwfjpakMuv8QKVv2y6qouZddVw4TxZNQPUs2fFTd"], + "routingKeys": [routing_key], } ], "authentication": [ @@ -1729,6 +1572,7 @@ async def test_create_static_connection_multitenant(self): ) self.multitenant_mgr.get_default_mediator.return_value = None + self.route_manager.route_static = async_mock.CoroutineMock() with async_mock.patch.object( ConnRecord, "save", autospec=True @@ -1761,6 +1605,7 @@ async def test_create_static_connection_multitenant_auto_disclose_features(self) } ) self.multitenant_mgr.get_default_mediator.return_value = None + self.route_manager.route_static = async_mock.CoroutineMock() with async_mock.patch.object( ConnRecord, "save", autospec=True ), async_mock.patch.object( @@ -1790,6 +1635,7 @@ async def test_create_static_connection_multitenant_mediator(self): ) default_mediator = async_mock.MagicMock() + self.route_manager.route_static = async_mock.CoroutineMock() with async_mock.patch.object( ConnRecord, "save", autospec=True @@ -1839,11 +1685,10 @@ async def test_create_static_connection_multitenant_mediator(self): [ call( their_info, - None, [self.test_endpoint], mediation_records=[default_mediator], ), - call(their_info, None, [self.test_endpoint], mediation_records=[]), + call(their_info, [self.test_endpoint], mediation_records=[]), ] ) diff --git a/aries_cloudagent/messaging/valid.py b/aries_cloudagent/messaging/valid.py index 08bf05f8a7..0838f84af6 100644 --- a/aries_cloudagent/messaging/valid.py +++ b/aries_cloudagent/messaging/valid.py @@ -275,6 +275,37 @@ def __init__(self): ) +class DIDKeyOrRef(Regexp): + """Validate value against DID key specification.""" + + EXAMPLE = "did:key:z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH" + PATTERN = re.compile(rf"^did:key:z[{B58}]+(?:#z[{B58}]+)?$") + + def __init__(self): + """Initialize the instance.""" + + super().__init__( + DIDKeyOrRef.PATTERN, error="Value {input} is not a did:key or did:key ref" + ) + + +class DIDKeyRef(Regexp): + """Validate value as DID key reference.""" + + EXAMPLE = ( + "did:key:z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH" + "#z6MkpTHR8VNsBxYAAWHut2Geadd9jSwuBV8xRoAnwWsdvktH" + ) + PATTERN = re.compile(rf"^did:key:z[{B58}]+#z[{B58}]+$") + + def __init__(self): + """Initialize the instance.""" + + super().__init__( + DIDKeyRef.PATTERN, error="Value {input} is not a did:key reference" + ) + + class DIDWeb(Regexp): """Validate value against did:web specification.""" @@ -854,6 +885,12 @@ def __init__( DID_KEY_VALIDATE = DIDKey() DID_KEY_EXAMPLE = DIDKey.EXAMPLE +DID_KEY_OR_REF_VALIDATE = DIDKeyOrRef() +DID_KEY_OR_REF_EXAMPLE = DIDKeyOrRef.EXAMPLE + +DID_KEY_REF_VALIDATE = DIDKeyRef() +DID_KEY_REF_EXAMPLE = DIDKeyRef.EXAMPLE + DID_POSTURE_VALIDATE = DIDPosture() DID_POSTURE_EXAMPLE = DIDPosture.EXAMPLE diff --git a/aries_cloudagent/multitenant/route_manager.py b/aries_cloudagent/multitenant/route_manager.py index 954b3c98f9..03798f47ce 100644 --- a/aries_cloudagent/multitenant/route_manager.py +++ b/aries_cloudagent/multitenant/route_manager.py @@ -2,7 +2,7 @@ import logging -from typing import List, Optional, Tuple +from typing import List, Optional from ..connections.models.conn_record import ConnRecord from ..core.profile import Profile @@ -11,10 +11,14 @@ from ..protocols.coordinate_mediation.v1_0.models.mediation_record import ( MediationRecord, ) -from ..protocols.coordinate_mediation.v1_0.normalization import normalize_from_did_key +from ..protocols.coordinate_mediation.v1_0.normalization import ( + normalize_from_did_key, + normalize_to_did_key, +) from ..protocols.coordinate_mediation.v1_0.route_manager import ( CoordinateMediationV1RouteManager, RouteManager, + RoutingInfo, ) from ..protocols.routing.v1_0.manager import RoutingManager from ..protocols.routing.v1_0.models.route_record import RouteRecord @@ -98,17 +102,35 @@ async def _route_for_key( return keylist_updates + async def mediation_records_for_connection( + self, + profile: Profile, + conn_record: ConnRecord, + mediation_id: Optional[str] = None, + or_default: bool = False, + ) -> List[MediationRecord]: + """Determine mediation records for a connection.""" + conn_specific = await super().mediation_records_for_connection( + profile, conn_record, mediation_id, or_default + ) + base_mediation_record = await self.get_base_wallet_mediator() + return [ + record + for record in (base_mediation_record, *conn_specific) + if record is not None + ] + async def routing_info( self, profile: Profile, - my_endpoint: str, mediation_record: Optional[MediationRecord] = None, - ) -> Tuple[List[str], str]: + ) -> RoutingInfo: """Return routing info.""" routing_keys = [] base_mediation_record = await self.get_base_wallet_mediator() + my_endpoint = None if base_mediation_record: routing_keys = base_mediation_record.routing_keys my_endpoint = base_mediation_record.endpoint @@ -117,7 +139,9 @@ async def routing_info( routing_keys = [*routing_keys, *mediation_record.routing_keys] my_endpoint = mediation_record.endpoint - return routing_keys, my_endpoint + routing_keys = [normalize_to_did_key(key).key_id for key in routing_keys] + + return RoutingInfo(routing_keys or None, my_endpoint) class BaseWalletRouteManager(CoordinateMediationV1RouteManager): diff --git a/aries_cloudagent/multitenant/tests/test_route_manager.py b/aries_cloudagent/multitenant/tests/test_route_manager.py index e4a537b7d1..2aebc0b2e9 100644 --- a/aries_cloudagent/multitenant/tests/test_route_manager.py +++ b/aries_cloudagent/multitenant/tests/test_route_manager.py @@ -18,6 +18,8 @@ TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ROUTE_RECORD_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF2 = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz" @pytest.fixture @@ -292,12 +294,10 @@ async def test_routing_info_with_mediator( mediation_record = MediationRecord( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://mediator.example.com", ) - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", mediation_record - ) + keys, endpoint = await route_manager.routing_info(sub_profile, mediation_record) assert keys == mediation_record.routing_keys assert endpoint == mediation_record.endpoint @@ -307,11 +307,9 @@ async def test_routing_info_no_mediator( sub_profile: Profile, route_manager: MultitenantRouteManager, ): - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", None - ) - assert keys == [] - assert endpoint == "http://example.com" + keys, endpoint = await route_manager.routing_info(sub_profile, None) + assert keys is None + assert endpoint is None @pytest.mark.asyncio @@ -322,7 +320,7 @@ async def test_routing_info_with_base_mediator( base_mediation_record = MediationRecord( mediation_id="test-base-mediation-id", connection_id="test-base-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://base.mediator.example.com", ) @@ -331,9 +329,7 @@ async def test_routing_info_with_base_mediator( "get_base_wallet_mediator", mock.CoroutineMock(return_value=base_mediation_record), ): - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", None - ) + keys, endpoint = await route_manager.routing_info(sub_profile, None) assert keys == base_mediation_record.routing_keys assert endpoint == base_mediation_record.endpoint @@ -346,13 +342,13 @@ async def test_routing_info_with_base_mediator_and_sub_mediator( mediation_record = MediationRecord( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF2], endpoint="http://mediator.example.com", ) base_mediation_record = MediationRecord( mediation_id="test-base-mediation-id", connection_id="test-base-mediator-conn-id", - routing_keys=["test-base-key-0", "test-base-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://base.mediator.example.com", ) @@ -361,9 +357,7 @@ async def test_routing_info_with_base_mediator_and_sub_mediator( "get_base_wallet_mediator", mock.CoroutineMock(return_value=base_mediation_record), ): - keys, endpoint = await route_manager.routing_info( - sub_profile, "http://example.com", mediation_record - ) + keys, endpoint = await route_manager.routing_info(sub_profile, mediation_record) assert keys == [*base_mediation_record.routing_keys, *mediation_record.routing_keys] assert endpoint == mediation_record.endpoint diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index fe33e1dac3..7a77cd6739 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -11,7 +11,6 @@ from ....core.profile import Profile from ....messaging.responder import BaseResponder from ....messaging.valid import IndyDID -from ....multitenant.base import BaseMultitenantManager from ....storage.error import StorageNotFoundError from ....transport.inbound.receipt import MessageReceipt from ....wallet.base import BaseWallet @@ -55,16 +54,16 @@ def profile(self) -> Profile: async def create_invitation( self, - my_label: str = None, - my_endpoint: str = None, - auto_accept: bool = None, + my_label: Optional[str] = None, + my_endpoint: Optional[str] = None, + auto_accept: Optional[bool] = None, public: bool = False, multi_use: bool = False, - alias: str = None, - routing_keys: Sequence[str] = None, - recipient_keys: Sequence[str] = None, - metadata: dict = None, - mediation_id: str = None, + alias: Optional[str] = None, + routing_keys: Optional[Sequence[str]] = None, + recipient_keys: Optional[Sequence[str]] = None, + metadata: Optional[dict] = None, + mediation_id: Optional[str] = None, ) -> Tuple[ConnRecord, ConnectionInvitation]: """Generate new connection invitation. @@ -208,11 +207,15 @@ async def create_invitation( await self._route_manager.route_invitation( self.profile, connection, mediation_record ) - routing_keys, my_endpoint = await self._route_manager.routing_info( + routing_keys, routing_endpoint = await self._route_manager.routing_info( self.profile, - my_endpoint or cast(str, self.profile.settings.get("default_endpoint")), mediation_record, ) + my_endpoint = ( + routing_endpoint + or my_endpoint + or cast(str, self.profile.settings.get("default_endpoint")) + ) # Create connection invitation message # Note: Need to split this into two stages @@ -336,20 +339,13 @@ async def create_request( """ - mediation_record = await self._route_manager.mediation_record_for_connection( + mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, connection, mediation_id, or_default=True, ) - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - if connection.my_did: async with self.profile.session() as session: wallet = session.inject(BaseWallet) @@ -363,7 +359,7 @@ async def create_request( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_invitee( - self.profile, connection, mediation_record + self.profile, connection, mediation_records ) # Create connection request message @@ -378,11 +374,8 @@ async def create_request( did_doc = await self.create_did_document( my_info, - connection.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) if not my_label: @@ -587,18 +580,10 @@ async def create_response( settings=self.profile.settings, ) - mediation_record = await self._route_manager.mediation_record_for_connection( + mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, connection, mediation_id ) - # Multitenancy setup - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - if ConnRecord.State.get(connection.state) not in ( ConnRecord.State.REQUEST, ConnRecord.State.RESPONSE, @@ -622,7 +607,7 @@ async def create_response( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( - self.profile, connection, mediation_record + self.profile, connection, mediation_records ) # Create connection response message @@ -637,11 +622,8 @@ async def create_response( did_doc = await self.create_did_document( my_info, - connection.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) response = ConnectionResponse( diff --git a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py index 79a402b356..03592102f1 100644 --- a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py +++ b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py @@ -3,8 +3,9 @@ from typing import Sequence from urllib.parse import parse_qs, urljoin, urlparse -from marshmallow import EXCLUDE, ValidationError, fields, validates_schema +from marshmallow import EXCLUDE, ValidationError, fields, pre_load, validates_schema +from .....did.did_key import DIDKey from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.valid import ( GENERIC_DID_EXAMPLE, @@ -58,6 +59,16 @@ def __init__( self.recipient_keys = list(recipient_keys) if recipient_keys else None self.endpoint = endpoint self.routing_keys = list(routing_keys) if routing_keys else None + self.routing_keys = ( + [ + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + for key in self.routing_keys + ] + if self.routing_keys + else None + ) self.image_url = image_url def to_url(self, base_url: str = None) -> str: @@ -157,6 +168,19 @@ class Meta: }, ) + @pre_load + def transform_routing_keys(self, data, **kwargs): + """Transform routingKeys from did:key refs, if necessary.""" + routing_keys = data.get("routingKeys") + if routing_keys: + data["routingKeys"] = [ + DIDKey.from_did(key).public_key_b58 + if key.startswith("did:key:") + else key + for key in routing_keys + ] + return data + @validates_schema def validate_fields(self, data, **kwargs): """Validate schema fields. diff --git a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py index dfe460c464..fd842c9d13 100644 --- a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py +++ b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py @@ -10,7 +10,7 @@ class DIDDocWrapper(fields.Field): """Field that loads and serializes DIDDoc.""" - def _serialize(self, value, attr, obj, **kwargs): + def _serialize(self, value: DIDDoc, attr, obj, **kwargs): """Serialize the DIDDoc. Args: @@ -20,7 +20,7 @@ def _serialize(self, value, attr, obj, **kwargs): The serialized DIDDoc """ - return value.serialize() + return value.serialize(normalize_routing_keys=True) def _deserialize(self, value, attr=None, data=None, **kwargs): """Deserialize a value into a DIDDoc. diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index b881d2fd34..ee8bfd1a66 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -315,7 +315,7 @@ async def test_create_invitation_mediation_using_default(self): assert invite.routing_keys == self.test_mediator_routing_keys assert invite.endpoint == self.test_mediator_endpoint self.route_manager.routing_info.assert_awaited_once_with( - self.profile, self.test_endpoint, mediation_record + self.profile, mediation_record ) async def test_receive_invitation(self): @@ -426,15 +426,11 @@ async def test_create_request_multitenant(self): with async_mock.patch.object( InMemoryWallet, "create_local_did", autospec=True ) as mock_wallet_create_local_did, async_mock.patch.object( - self.multitenant_mgr, - "get_default_mediator", - async_mock.CoroutineMock(return_value=mediation_record), - ), async_mock.patch.object( ConnectionManager, "create_did_document", autospec=True ) as create_did_document, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=None), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ): mock_wallet_create_local_did.return_value = DIDInfo( self.test_did, @@ -455,7 +451,6 @@ async def test_create_request_multitenant(self): create_did_document.assert_called_once_with( self.manager, mock_wallet_create_local_did.return_value, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -487,8 +482,8 @@ async def test_create_request_mediation_id(self): InMemoryWallet, "create_local_did" ) as create_local_did, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ): did_info = DIDInfo( did=self.test_did, @@ -507,7 +502,6 @@ async def test_create_request_mediation_id(self): create_did_document.assert_called_once_with( self.manager, did_info, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -539,8 +533,8 @@ async def test_create_request_default_mediator(self): InMemoryWallet, "create_local_did" ) as create_local_did, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ): did_info = DIDInfo( did=self.test_did, @@ -558,7 +552,6 @@ async def test_create_request_default_mediator(self): create_did_document.assert_called_once_with( self.manager, did_info, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -881,10 +874,6 @@ async def test_create_response_multitenant(self): ConnRecord, "save", autospec=True ), async_mock.patch.object( ConnRecord, "metadata_get", async_mock.CoroutineMock(return_value=False) - ), async_mock.patch.object( - self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), ), async_mock.patch.object( ConnRecord, "retrieve_request", autospec=True ), async_mock.patch.object( @@ -892,15 +881,11 @@ async def test_create_response_multitenant(self): ), async_mock.patch.object( InMemoryWallet, "create_local_did", autospec=True ) as mock_wallet_create_local_did, async_mock.patch.object( - self.multitenant_mgr, - "get_default_mediator", - async_mock.CoroutineMock(return_value=mediation_record), - ), async_mock.patch.object( ConnectionManager, "create_did_document", autospec=True ) as create_did_document, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=None), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ): mock_wallet_create_local_did.return_value = DIDInfo( self.test_did, @@ -918,7 +903,6 @@ async def test_create_response_multitenant(self): create_did_document.assert_called_once_with( self.manager, mock_wallet_create_local_did.return_value, - None, [self.test_endpoint], mediation_records=[mediation_record], ) @@ -970,8 +954,8 @@ async def test_create_response_mediation(self): InMemoryWallet, "create_local_did" ) as create_local_did, async_mock.patch.object( self.route_manager, - "mediation_record_for_connection", - async_mock.CoroutineMock(return_value=mediation_record), + "mediation_records_for_connection", + async_mock.CoroutineMock(return_value=[mediation_record]), ), async_mock.patch.object( record, "retrieve_request", autospec=True ), async_mock.patch.object( @@ -994,7 +978,6 @@ async def test_create_response_mediation(self): create_did_document.assert_called_once_with( self.manager, did_info, - None, [self.test_endpoint], mediation_records=[mediation_record], ) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py index e8924dbb83..d3a39dbebb 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/handlers/tests/test_mediation_grant_handler.py @@ -1,9 +1,10 @@ """Test mediate grant message handler.""" import pytest -from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock +from aries_cloudagent.core.profile import ProfileSession + from ......connections.models.conn_record import ConnRecord from ......messaging.base_handler import HandlerException from ......messaging.request_context import RequestContext @@ -18,69 +19,86 @@ from .. import mediation_grant_handler as test_module TEST_CONN_ID = "conn-id" -TEST_RECORD_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" -TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" +TEST_BASE58_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" +TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ENDPOINT = "https://example.com" -class TestMediationGrantHandler(AsyncTestCase): - """Test mediate grant message handler.""" +@pytest.fixture() +async def context(): + context = RequestContext.test_context() + context.message = MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY]) + context.connection_ready = True + context.connection_record = ConnRecord(connection_id=TEST_CONN_ID) + yield context - async def setUp(self): - """Setup test dependencies.""" - self.context = RequestContext.test_context() - self.session = await self.context.session() - self.context.message = MediationGrant( - endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY] - ) - self.context.connection_ready = True - self.context.connection_record = ConnRecord(connection_id=TEST_CONN_ID) - async def test_handler_no_active_connection(self): +@pytest.fixture() +async def session(context: RequestContext): + yield await context.session() + + +@pytest.mark.asyncio +class TestMediationGrantHandler: + """Test mediate grant message handler.""" + + async def test_handler_no_active_connection(self, context: RequestContext): handler, responder = MediationGrantHandler(), MockResponder() - self.context.connection_ready = False + context.connection_ready = False with pytest.raises(HandlerException) as exc: - await handler.handle(self.context, responder) + await handler.handle(context, responder) assert "no active connection" in str(exc.value) - async def test_handler_no_mediation_record(self): + async def test_handler_no_mediation_record(self, context: RequestContext): handler, responder = MediationGrantHandler(), MockResponder() with pytest.raises(HandlerException) as exc: - await handler.handle(self.context, responder) + await handler.handle(context, responder) assert "has not been requested" in str(exc.value) - async def test_handler(self): + @pytest.mark.parametrize( + "grant", + [ + MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY]), + MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_BASE58_VERKEY]), + ], + ) + async def test_handler( + self, grant: MediationGrant, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), MockResponder() - await MediationRecord(connection_id=TEST_CONN_ID).save(self.session) - await handler.handle(self.context, responder) - record = await MediationRecord.retrieve_by_connection_id( - self.session, TEST_CONN_ID - ) + await MediationRecord(connection_id=TEST_CONN_ID).save(session) + context.message = grant + await handler.handle(context, responder) + record = await MediationRecord.retrieve_by_connection_id(session, TEST_CONN_ID) assert record assert record.state == MediationRecord.STATE_GRANTED assert record.endpoint == TEST_ENDPOINT - assert record.routing_keys == [TEST_RECORD_VERKEY] + assert record.routing_keys == [TEST_VERKEY] - async def test_handler_connection_has_set_to_default_meta(self): + async def test_handler_connection_has_set_to_default_meta( + self, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), MockResponder() record = MediationRecord(connection_id=TEST_CONN_ID) - await record.save(self.session) + await record.save(session) with async_mock.patch.object( - self.context.connection_record, + context.connection_record, "metadata_get", async_mock.CoroutineMock(return_value=True), ), async_mock.patch.object( test_module, "MediationManager", autospec=True ) as mock_mediation_manager: - await handler.handle(self.context, responder) + await handler.handle(context, responder) mock_mediation_manager.return_value.set_default_mediator.assert_called_once_with( record ) - async def test_handler_multitenant_base_mediation(self): + async def test_handler_multitenant_base_mediation( + self, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), async_mock.CoroutineMock() responder.send = async_mock.CoroutineMock() - profile = self.context.profile + profile = context.profile profile.context.update_settings( {"multitenant.enabled": True, "wallet.id": "test_wallet"} @@ -94,28 +112,30 @@ async def test_handler_multitenant_base_mediation(self): multitenant_mgr.get_default_mediator.return_value = default_base_mediator record = MediationRecord(connection_id=TEST_CONN_ID) - await record.save(self.session) + await record.save(session) with async_mock.patch.object(MediationManager, "add_key") as add_key: keylist_updates = async_mock.MagicMock() add_key.return_value = keylist_updates - await handler.handle(self.context, responder) + await handler.handle(context, responder) add_key.assert_called_once_with("key2") responder.send.assert_called_once_with( keylist_updates, connection_id=TEST_CONN_ID ) - async def test_handler_connection_no_set_to_default(self): + async def test_handler_connection_no_set_to_default( + self, session: ProfileSession, context: RequestContext + ): handler, responder = MediationGrantHandler(), MockResponder() record = MediationRecord(connection_id=TEST_CONN_ID) - await record.save(self.session) + await record.save(session) with async_mock.patch.object( - self.context.connection_record, + context.connection_record, "metadata_get", async_mock.CoroutineMock(return_value=False), ), async_mock.patch.object( test_module, "MediationManager", autospec=True ) as mock_mediation_manager: - await handler.handle(self.context, responder) + await handler.handle(context, responder) mock_mediation_manager.return_value.set_default_mediator.assert_not_called() diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py index 0ab45b1434..a97055da63 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py @@ -26,7 +26,10 @@ from .messages.mediate_grant import MediationGrant from .messages.mediate_request import MediationRequest from .models.mediation_record import MediationRecord -from .normalization import normalize_from_did_key +from .normalization import ( + normalize_from_did_key, + normalize_to_did_key, +) LOGGER = logging.getLogger(__name__) @@ -176,7 +179,7 @@ async def grant_request( await mediation_record.save(session, reason="Mediation request granted") grant = MediationGrant( endpoint=session.settings.get("default_endpoint"), - routing_keys=[routing_did.verkey], + routing_keys=[normalize_to_did_key(routing_did.verkey).key_id], ) return mediation_record, grant @@ -458,11 +461,9 @@ async def request_granted(self, record: MediationRecord, grant: MediationGrant): """ record.state = MediationRecord.STATE_GRANTED record.endpoint = grant.endpoint - # record.routing_keys = grant.routing_keys - routing_keys = [] - for key in grant.routing_keys: - routing_keys.append(normalize_from_did_key(key)) - record.routing_keys = routing_keys + record.routing_keys = [ + normalize_to_did_key(key).key_id for key in grant.routing_keys + ] async with self._profile.session() as session: await record.save(session, reason="Mediation request granted.") diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py index d2595ede53..551f795eac 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/messages/mediate_grant.py @@ -9,7 +9,6 @@ from .....messaging.agent_message import AgentMessage, AgentMessageSchema from ..message_types import MEDIATE_GRANT, PROTOCOL_PACKAGE -from ..normalization import normalize_from_public_key HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.mediation_grant_handler.MediationGrantHandler" @@ -41,11 +40,7 @@ def __init__( """ super(MediationGrant, self).__init__(**kwargs) self.endpoint = endpoint - self.routing_keys = ( - [normalize_from_public_key(key) for key in routing_keys] - if routing_keys - else [] - ) + self.routing_keys = routing_keys or [] class MediationGrantSchema(AgentMessageSchema): diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py index d699565367..28fee2ce89 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/normalization.py @@ -1,4 +1,5 @@ """Normalization methods used while transitioning to DID:Key method.""" +from typing import Union from ....did.did_key import DIDKey from ....wallet.key_type import ED25519 @@ -17,3 +18,12 @@ def normalize_from_public_key(key: str): return key return DIDKey.from_public_key_b58(key, ED25519).did + + +def normalize_to_did_key(value: Union[str, DIDKey]) -> DIDKey: + """Normalize a value to a DIDKey.""" + if isinstance(value, DIDKey): + return value + if value.startswith("did:key:"): + return DIDKey.from_did(value) + return DIDKey.from_public_key_b58(value, ED25519) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py index 990252cd5a..c83e7984ec 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/route_manager.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod import logging -from typing import List, Optional, Tuple +from typing import List, NamedTuple, Optional from ....connections.models.conn_record import ConnRecord from ....core.profile import Profile @@ -20,7 +20,10 @@ from .manager import MediationManager from .messages.keylist_update import KeylistUpdate from .models.mediation_record import MediationRecord -from .normalization import normalize_from_did_key +from .normalization import ( + normalize_from_did_key, + normalize_to_did_key, +) LOGGER = logging.getLogger(__name__) @@ -30,6 +33,18 @@ class RouteManagerError(Exception): """Raised on error from route manager.""" +class RoutingInfo(NamedTuple): + """Routing info tuple contiaing routing keys and endpoint.""" + + routing_keys: Optional[List[str]] + endpoint: Optional[str] + + @classmethod + def empty(cls): + """Empty routing info.""" + return cls(routing_keys=None, endpoint=None) + + class RouteManager(ABC): """Base Route Manager.""" @@ -59,14 +74,15 @@ def _validate_mediation_state(self, mediation_record: MediationRecord): f"{mediation_record.mediation_id}" ) - async def mediation_record_for_connection( + async def mediation_records_for_connection( self, profile: Profile, conn_record: ConnRecord, mediation_id: Optional[str] = None, or_default: bool = False, - ): + ) -> List[MediationRecord]: """Return relevant mediator for connection.""" + # TODO Support multiple mediators? if conn_record.connection_id: async with profile.session() as session: mediation_metadata = await conn_record.metadata_get( @@ -83,7 +99,7 @@ async def mediation_record_for_connection( await self.save_mediator_for_connection( profile, conn_record, mediation_record ) - return mediation_record + return [mediation_record] if mediation_record else [] async def mediation_record_if_id( self, @@ -126,11 +142,13 @@ async def route_connection_as_invitee( self, profile: Profile, conn_record: ConnRecord, - mediation_record: Optional[MediationRecord] = None, + mediation_records: List[MediationRecord], ) -> Optional[KeylistUpdate]: """Set up routing for a new connection when we are the invitee.""" LOGGER.debug("Routing connection as invitee") my_info = await self.get_or_create_my_did(profile, conn_record) + # Only most destward mediator receives keylist updates + mediation_record = mediation_records[0] if mediation_records else None return await self._route_for_key( profile, my_info.verkey, mediation_record, skip_if_exists=True ) @@ -139,7 +157,7 @@ async def route_connection_as_inviter( self, profile: Profile, conn_record: ConnRecord, - mediation_record: Optional[MediationRecord] = None, + mediation_records: List[MediationRecord], ) -> Optional[KeylistUpdate]: """Set up routing for a new connection when we are the inviter.""" LOGGER.debug("Routing connection as inviter") @@ -154,6 +172,9 @@ async def route_connection_as_inviter( if public_did and public_did.verkey == conn_record.invitation_key: replace_key = None + # Only most destward mediator receives keylist updates + mediation_record = mediation_records[0] if mediation_records else None + return await self._route_for_key( profile, my_info.verkey, @@ -166,7 +187,7 @@ async def route_connection( self, profile: Profile, conn_record: ConnRecord, - mediation_record: Optional[MediationRecord] = None, + mediation_records: List[MediationRecord], ) -> Optional[KeylistUpdate]: """Set up routing for a connection. @@ -176,14 +197,14 @@ async def route_connection( ConnRecord.Role.RESPONDER ): return await self.route_connection_as_invitee( - profile, conn_record, mediation_record + profile, conn_record, mediation_records ) if conn_record.rfc23_state == ConnRecord.State.REQUEST.rfc23strict( ConnRecord.Role.REQUESTER ): return await self.route_connection_as_inviter( - profile, conn_record, mediation_record + profile, conn_record, mediation_records ) return None @@ -255,9 +276,8 @@ async def save_mediator_for_connection( async def routing_info( self, profile: Profile, - my_endpoint: str, mediation_record: Optional[MediationRecord] = None, - ) -> Tuple[List[str], str]: + ) -> RoutingInfo: """Retrieve routing keys.""" async def connection_from_recipient_key( @@ -321,11 +341,16 @@ async def _route_for_key( async def routing_info( self, profile: Profile, - my_endpoint: str, mediation_record: Optional[MediationRecord] = None, - ) -> Tuple[List[str], str]: + ) -> RoutingInfo: """Return routing info for mediator.""" if mediation_record: - return mediation_record.routing_keys, mediation_record.endpoint + return RoutingInfo( + routing_keys=[ + normalize_to_did_key(key).key_id + for key in mediation_record.routing_keys + ], + endpoint=mediation_record.endpoint, + ) - return [], my_endpoint + return RoutingInfo.empty() diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py index 52b87058fd..11231c36d7 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/routes.py @@ -498,14 +498,14 @@ async def update_keylist_for_connection(request: web.BaseRequest): async with context.session() as session: connection_record = await ConnRecord.retrieve_by_id(session, connection_id) - mediation_record = await route_manager.mediation_record_for_connection( + mediation_records = await route_manager.mediation_records_for_connection( context.profile, connection_record, mediation_id, or_default=True ) # MediationRecord is permitted to be None; route manager will # ensure the correct mediator is notified. keylist_update = await route_manager.route_connection( - context.profile, connection_record, mediation_record + context.profile, connection_record, mediation_records ) results = keylist_update.serialize() if keylist_update else {} diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py index 188f600c99..5848c59063 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_manager.py @@ -29,10 +29,10 @@ TEST_CONN_ID = "conn-id" TEST_THREAD_ID = "thread-id" TEST_ENDPOINT = "https://example.com" -TEST_RECORD_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" -TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" +TEST_BASE58_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" +TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL#z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ROUTE_RECORD_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" -TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" pytestmark = pytest.mark.asyncio @@ -121,7 +121,7 @@ async def test_grant_request(self, session, manager): routing_key = await manager._retrieve_routing_did(session) routing_key = DIDKey.from_public_key_b58( routing_key.verkey, routing_key.key_type - ).did + ).key_id assert grant.routing_keys == [routing_key] async def test_deny_request(self, manager): @@ -134,7 +134,7 @@ async def test_deny_request(self, manager): async def test_update_keylist_delete(self, session, manager, record): """test_update_keylist_delete.""" await RouteRecord( - connection_id=TEST_CONN_ID, recipient_key=TEST_RECORD_VERKEY + connection_id=TEST_CONN_ID, recipient_key=TEST_BASE58_VERKEY ).save(session) response = await manager.update_keylist( record=record, @@ -169,7 +169,7 @@ async def test_update_keylist_create(self, manager, record): async def test_update_keylist_create_existing(self, session, manager, record): """test_update_keylist_create_existing.""" await RouteRecord( - connection_id=TEST_CONN_ID, recipient_key=TEST_RECORD_VERKEY + connection_id=TEST_CONN_ID, recipient_key=TEST_BASE58_VERKEY ).save(session) response = await manager.update_keylist( record=record, @@ -273,14 +273,25 @@ async def test_prepare_request(self, manager): assert record.connection_id == TEST_CONN_ID assert request - async def test_request_granted(self, manager): + async def test_request_granted_base58(self, manager): """test_request_granted.""" record, _ = await manager.prepare_request(TEST_CONN_ID) - grant = MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_ROUTE_VERKEY]) + grant = MediationGrant( + endpoint=TEST_ENDPOINT, routing_keys=[TEST_BASE58_VERKEY] + ) + await manager.request_granted(record, grant) + assert record.state == MediationRecord.STATE_GRANTED + assert record.endpoint == TEST_ENDPOINT + assert record.routing_keys == [TEST_VERKEY] + + async def test_request_granted_did_key(self, manager): + """test_request_granted.""" + record, _ = await manager.prepare_request(TEST_CONN_ID) + grant = MediationGrant(endpoint=TEST_ENDPOINT, routing_keys=[TEST_VERKEY]) await manager.request_granted(record, grant) assert record.state == MediationRecord.STATE_GRANTED assert record.endpoint == TEST_ENDPOINT - assert record.routing_keys == [TEST_ROUTE_RECORD_VERKEY] + assert record.routing_keys == [TEST_VERKEY] async def test_request_denied(self, manager): """test_request_denied.""" diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py index 4d9efceb72..f543efe386 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_route_manager.py @@ -26,6 +26,9 @@ TEST_VERKEY = "did:key:z6Mkgg342Ycpuk263R9d8Aq6MUaxPn1DDeHyGo38EefXmgDL" TEST_ROUTE_RECORD_VERKEY = "9WCgWKUaAJj3VWxxtzvvMQN3AoFxoBtBDo9ntwJnVVCC" TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhya" +TEST_ROUTE_VERKEY_REF2 = "did:key:z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz#z6MknxTj6Zj1VrDWc1ofaZtmCVv2zNXpD58Xup4ijDGoQhyz" class MockRouteManager(RouteManager): @@ -51,7 +54,7 @@ def route_manager(): manager._route_for_key = mock.CoroutineMock( return_value=mock.MagicMock(KeylistUpdate) ) - manager.routing_info = mock.CoroutineMock(return_value=([], "http://example.com")) + manager.routing_info = mock.CoroutineMock(return_value=([], None)) yield manager @@ -113,12 +116,9 @@ async def test_mediation_record_for_connection_mediation_id( ) as mock_mediation_record_if_id, mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): - assert ( - await route_manager.mediation_record_for_connection( - profile, conn_record, mediation_record.mediation_id - ) - == mediation_record - ) + assert await route_manager.mediation_records_for_connection( + profile, conn_record, mediation_record.mediation_id + ) == [mediation_record] mock_mediation_record_if_id.assert_called_once_with( profile, mediation_record.mediation_id, False ) @@ -139,12 +139,9 @@ async def test_mediation_record_for_connection_mediation_metadata( ) as mock_mediation_record_if_id, mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): - assert ( - await route_manager.mediation_record_for_connection( - profile, conn_record, "another-mediation-id" - ) - == mediation_record - ) + assert await route_manager.mediation_records_for_connection( + profile, conn_record, "another-mediation-id" + ) == [mediation_record] mock_mediation_record_if_id.assert_called_once_with( profile, mediation_record.mediation_id, False ) @@ -162,12 +159,9 @@ async def test_mediation_record_for_connection_default( ) as mock_mediation_record_if_id, mock.patch.object( route_manager, "save_mediator_for_connection", mock.CoroutineMock() ): - assert ( - await route_manager.mediation_record_for_connection( - profile, conn_record, None, or_default=True - ) - == mediation_record - ) + assert await route_manager.mediation_records_for_connection( + profile, conn_record, None, or_default=True + ) == [mediation_record] mock_mediation_record_if_id.assert_called_once_with(profile, None, True) @@ -285,7 +279,7 @@ async def test_route_connection_as_invitee( mock.CoroutineMock(return_value=mock_did_info), ): await route_manager.route_connection_as_invitee( - profile, conn_record, mediation_record + profile, conn_record, [mediation_record] ) route_manager._route_for_key.assert_called_once_with( profile, mock_did_info.verkey, mediation_record, skip_if_exists=True @@ -305,7 +299,7 @@ async def test_route_connection_as_inviter( mock.CoroutineMock(return_value=mock_did_info), ): await route_manager.route_connection_as_inviter( - profile, conn_record, mediation_record + profile, conn_record, [mediation_record] ) route_manager._route_for_key.assert_called_once_with( profile, @@ -342,7 +336,7 @@ async def test_route_connection_state_inviter_replace_key_none( ), ): await route_manager.route_connection_as_inviter( - profile, conn_record, mediation_record + profile, conn_record, [mediation_record] ) route_manager._route_for_key.assert_called_once_with( profile, @@ -365,7 +359,7 @@ async def test_route_connection_state_invitee( ) as mock_route_connection_as_invitee, mock.patch.object( route_manager, "route_connection_as_inviter", mock.CoroutineMock() ) as mock_route_connection_as_inviter: - await route_manager.route_connection(profile, conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, [mediation_record]) mock_route_connection_as_invitee.assert_called_once() mock_route_connection_as_inviter.assert_not_called() @@ -382,7 +376,7 @@ async def test_route_connection_state_inviter( ) as mock_route_connection_as_invitee, mock.patch.object( route_manager, "route_connection_as_inviter", mock.CoroutineMock() ) as mock_route_connection_as_inviter: - await route_manager.route_connection(profile, conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, [mediation_record]) mock_route_connection_as_inviter.assert_called_once() mock_route_connection_as_invitee.assert_not_called() @@ -395,7 +389,7 @@ async def test_route_connection_state_other( conn_record.state = "response" conn_record.their_role = "requester" assert ( - await route_manager.route_connection(profile, conn_record, mediation_record) + await route_manager.route_connection(profile, conn_record, [mediation_record]) is None ) @@ -696,11 +690,11 @@ async def test_mediation_routing_info_with_mediator( mediation_record = MediationRecord( mediation_id="test-mediation-id", connection_id="test-mediator-conn-id", - routing_keys=["test-key-0", "test-key-1"], + routing_keys=[TEST_ROUTE_VERKEY_REF], endpoint="http://mediator.example.com", ) keys, endpoint = await mediation_route_manager.routing_info( - profile, "http://example.com", mediation_record + profile, mediation_record ) assert keys == mediation_record.routing_keys assert endpoint == mediation_record.endpoint @@ -711,8 +705,6 @@ async def test_mediation_routing_info_no_mediator( profile: Profile, mediation_route_manager: CoordinateMediationV1RouteManager, ): - keys, endpoint = await mediation_route_manager.routing_info( - profile, "http://example.com", None - ) - assert keys == [] - assert endpoint == "http://example.com" + keys, endpoint = await mediation_route_manager.routing_info(profile, None) + assert keys is None + assert endpoint is None diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 0db2702e51..1b36f4c50d 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -17,7 +17,6 @@ from ....did.did_key import DIDKey from ....messaging.decorators.attach_decorator import AttachDecorator from ....messaging.responder import BaseResponder -from ....multitenant.base import BaseMultitenantManager from ....resolver.base import ResolverError from ....resolver.did_resolver import DIDResolver from ....storage.error import StorageNotFoundError @@ -285,21 +284,13 @@ async def create_request( """ # Mediation Support - mediation_record = await self._route_manager.mediation_record_for_connection( + mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, conn_rec, mediation_id, or_default=True, ) - # Multitenancy setup - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - my_info = None if conn_rec.my_did: @@ -336,11 +327,8 @@ async def create_request( else: did_doc = await self.create_did_document( my_info, - conn_rec.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) attach = AttachDecorator.data_base64(did_doc.serialize()) async with self.profile.session() as session: @@ -377,7 +365,7 @@ async def create_request( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_invitee( - self.profile, conn_rec, mediation_record + self.profile, conn_rec, mediation_records ) return request @@ -599,18 +587,10 @@ async def create_response( settings=self.profile.settings, ) - mediation_record = await self._route_manager.mediation_record_for_connection( + mediation_records = await self._route_manager.mediation_records_for_connection( self.profile, conn_rec, mediation_id ) - # Multitenancy setup - multitenant_mgr = self.profile.inject_or(BaseMultitenantManager) - wallet_id = self.profile.settings.get("wallet.id") - - base_mediation_record = None - if multitenant_mgr and wallet_id: - base_mediation_record = await multitenant_mgr.get_default_mediator() - if ConnRecord.State.get(conn_rec.state) is not ConnRecord.State.REQUEST: raise DIDXManagerError( f"Connection not in state {ConnRecord.State.REQUEST.rfc23}" @@ -645,7 +625,7 @@ async def create_response( # Idempotent; if routing has already been set up, no action taken await self._route_manager.route_connection_as_inviter( - self.profile, conn_rec, mediation_record + self.profile, conn_rec, mediation_records ) # Create connection response message @@ -665,11 +645,8 @@ async def create_response( else: did_doc = await self.create_did_document( my_info, - conn_rec.inbound_connection_id, my_endpoints, - mediation_records=list( - filter(None, [base_mediation_record, mediation_record]) - ), + mediation_records=mediation_records, ) attach = AttachDecorator.data_base64(did_doc.serialize()) async with self.profile.session() as session: diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index 8bcbce0603..c0e1edb913 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -2198,177 +2198,10 @@ async def test_create_did_document(self): key_type=ED25519, ) - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - did_doc = self.make_did_doc( - did=TestConfig.test_target_did, - verkey=TestConfig.test_target_verkey, - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - did_doc = await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_not_completed(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.ABANDONED.rfc23, - ) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_no_services(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=TestConfig.test_target_did, verkey=TestConfig.test_target_verkey + did_doc = await self.manager.create_did_document( + did_info=did_info, + svc_endpoints=[TestConfig.test_endpoint], ) - x_did_doc._service = {} - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_no_service_endpoint(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=TestConfig.test_target_did, verkey=TestConfig.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service(TestConfig.test_target_did, "dummy", "IndyAgent", [], [], "", 0) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) - - async def test_create_did_document_no_service_recip_keys(self): - did_info = DIDInfo( - TestConfig.test_did, - TestConfig.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - - mock_conn = async_mock.MagicMock( - connection_id="dummy", - inbound_connection_id=None, - their_did=TestConfig.test_target_did, - state=ConnRecord.State.COMPLETED.rfc23, - ) - - x_did_doc = self.make_did_doc( - did=TestConfig.test_target_did, verkey=TestConfig.test_target_verkey - ) - x_did_doc._service = {} - x_did_doc.set( - Service( - TestConfig.test_target_did, - "dummy", - "IndyAgent", - [], - [], - TestConfig.test_endpoint, - 0, - ) - ) - for i in range(2): # first cover store-record, then update-value - await self.manager.store_did_document(x_did_doc) - - with async_mock.patch.object( - ConnRecord, "retrieve_by_id", async_mock.CoroutineMock() - ) as mock_conn_rec_retrieve_by_id: - mock_conn_rec_retrieve_by_id.return_value = mock_conn - - with self.assertRaises(BaseConnectionManagerError): - await self.manager.create_did_document( - did_info=did_info, - inbound_connection_id="dummy", - svc_endpoints=[TestConfig.test_endpoint], - ) async def test_did_key_storage(self): did_info = DIDInfo( diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 007b43a675..2dfbe7f421 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -317,9 +317,10 @@ async def create_invitation( async with self.profile.session() as session: await conn_rec.save(session, reason="Created new connection") - routing_keys, my_endpoint = await self._route_manager.routing_info( - self.profile, my_endpoint, mediation_record + routing_keys, routing_endpoint = await self._route_manager.routing_info( + self.profile, mediation_record ) + my_endpoint = routing_endpoint or my_endpoint if not conn_rec: our_service = ServiceDecorator( @@ -335,8 +336,8 @@ async def create_invitation( routing_keys = [ key if len(key.split(":")) == 3 - else DIDKey.from_public_key_b58(key, ED25519).did - for key in routing_keys + else DIDKey.from_public_key_b58(key, ED25519).key_id + for key in routing_keys or [] ] # Create connection invitation message @@ -353,7 +354,9 @@ async def create_invitation( _id="#inline", _type="did-communication", recipient_keys=[ - DIDKey.from_public_key_b58(connection_key.verkey, ED25519).did + DIDKey.from_public_key_b58( + connection_key.verkey, ED25519 + ).key_id ], service_endpoint=my_endpoint, routing_keys=routing_keys, @@ -814,11 +817,11 @@ async def _perform_handshake( "id": "#inline", "type": "did-communication", "recipientKeys": [ - DIDKey.from_public_key_b58(key, ED25519).did + DIDKey.from_public_key_b58(key, ED25519).key_id for key in recipient_keys ], "routingKeys": [ - DIDKey.from_public_key_b58(key, ED25519).did + DIDKey.from_public_key_b58(key, ED25519).key_id for key in routing_keys ], "serviceEndpoint": endpoint, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py index aca81b20f7..92ac2a7b88 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py @@ -1,13 +1,13 @@ """Record used to represent a service block of an out of band invitation.""" -from typing import Sequence +from typing import Optional, Sequence from marshmallow import EXCLUDE, fields, post_dump from .....messaging.models.base import BaseModel, BaseModelSchema from .....messaging.valid import ( - DID_KEY_EXAMPLE, - DID_KEY_VALIDATE, + DID_KEY_OR_REF_EXAMPLE, + DID_KEY_OR_REF_VALIDATE, INDY_DID_EXAMPLE, INDY_DID_VALIDATE, ) @@ -24,12 +24,12 @@ class Meta: def __init__( self, *, - _id: str = None, - _type: str = None, - did: str = None, - recipient_keys: Sequence[str] = None, - routing_keys: Sequence[str] = None, - service_endpoint: str = None, + _id: Optional[str] = None, + _type: Optional[str] = None, + did: Optional[str] = None, + recipient_keys: Optional[Sequence[str]] = None, + routing_keys: Optional[Sequence[str]] = None, + service_endpoint: Optional[str] = None, ): """Initialize a Service instance. @@ -72,10 +72,10 @@ class Meta: recipient_keys = fields.List( fields.Str( - validate=DID_KEY_VALIDATE, + validate=DID_KEY_OR_REF_VALIDATE, metadata={ "description": "Recipient public key", - "example": DID_KEY_EXAMPLE, + "example": DID_KEY_OR_REF_EXAMPLE, }, ), data_key="recipientKeys", @@ -85,8 +85,8 @@ class Meta: routing_keys = fields.List( fields.Str( - validate=DID_KEY_VALIDATE, - metadata={"description": "Routing key", "example": DID_KEY_EXAMPLE}, + validate=DID_KEY_OR_REF_VALIDATE, + metadata={"description": "Routing key", "example": DID_KEY_OR_REF_EXAMPLE}, ), data_key="routingKeys", required=False, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py index 4f33cbbbe9..b37cbd102b 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/tests/test_invitation.py @@ -108,7 +108,10 @@ def test_url_round_trip(self): service = Service( _id="#inline", _type=DID_COMM, - recipient_keys=[DIDKey.from_public_key_b58(TEST_VERKEY, ED25519).did], + recipient_keys=[ + DIDKey.from_public_key_b58(TEST_VERKEY, ED25519).did, + DIDKey.from_public_key_b58(TEST_VERKEY, ED25519).key_id, + ], service_endpoint="http://1.2.3.4:8080/service", ) invi_msg = InvitationMessage( diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index 7bee844ce4..25abf93193 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -808,7 +808,7 @@ async def test_create_invitation_peer_did(self): service["routingKeys"][0] == DIDKey.from_public_key_b58( self.test_mediator_routing_keys[0], ED25519 - ).did + ).key_id ) assert service["serviceEndpoint"] == self.test_mediator_endpoint diff --git a/aries_cloudagent/wallet/routes.py b/aries_cloudagent/wallet/routes.py index ba5e7d4bd4..58903a2963 100644 --- a/aries_cloudagent/wallet/routes.py +++ b/aries_cloudagent/wallet/routes.py @@ -671,7 +671,6 @@ async def wallet_set_public_did(request: web.BaseRequest): routing_keys, mediator_endpoint = await route_manager.routing_info( profile, - None, mediation_record, ) diff --git a/docker/Dockerfile.run b/docker/Dockerfile.run index 8660e45d76..ad4fde5622 100644 --- a/docker/Dockerfile.run +++ b/docker/Dockerfile.run @@ -7,14 +7,12 @@ RUN apt-get update && apt-get install -y curl && apt-get clean RUN pip install --no-cache-dir poetry -ADD . . +RUN mkdir -p aries_cloudagent && touch aries_cloudagent/__init__.py +ADD pyproject.toml poetry.lock README.md ./ +RUN mkdir -p logs && chmod -R ug+rw logs RUN poetry install -E "askar bbs" -RUN mkdir -p aries_cloudagent && touch aries_cloudagent/__init__.py -ADD aries_cloudagent/version.py aries_cloudagent/version.py - -RUN mkdir -p logs && chmod -R ug+rw logs -ADD aries_cloudagent ./aries_cloudagent +ADD . . ENTRYPOINT ["/bin/bash", "-c", "poetry run aca-py \"$@\"", "--"]