diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index d19bff39c6..a00fffa8b2 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -1039,7 +1039,17 @@ def add_arguments(self, parser: ArgumentParser): action="store_true", env_var="ACAPY_PUBLIC_INVITES", help=( - "Send invitations out, and receive connection requests, " + "Send invitations out using the public DID for the agent, " + "and receive connection requests solicited by invitations " + "which use the public DID. Default: false." + ), + ) + parser.add_argument( + "--requests-through-public-did", + action="store_true", + env_var="ACAPY_REQUESTS_THROUGH_PUBLIC_DID", + help=( + "Allow agent to receive unsolicited connection requests, " "using the public DID for the agent. Default: false." ), ) @@ -1134,6 +1144,13 @@ def get_settings(self, args: Namespace) -> dict: settings["monitor_forward"] = args.monitor_forward if args.public_invites: settings["public_invites"] = True + if args.requests_through_public_did: + if not args.public_invites: + raise ArgsParseError( + "--public-invites is required to use " + "--requests-through-public-did" + ) + settings["requests_through_public_did"] = True if args.timing: settings["timing.enabled"] = True if args.timing_log: diff --git a/aries_cloudagent/connections/models/conn_record.py b/aries_cloudagent/connections/models/conn_record.py index eb0342970a..ca9e21b07f 100644 --- a/aries_cloudagent/connections/models/conn_record.py +++ b/aries_cloudagent/connections/models/conn_record.py @@ -677,11 +677,7 @@ class Meta: required=False, description="Routing state of connection", validate=validate.OneOf( - [ - getattr(ConnRecord, m) - for m in vars(ConnRecord) - if m.startswith("ROUTING_STATE_") - ] + ConnRecord.get_attributes_by_prefix("ROUTING_STATE_", walk_mro=False) ), example=ConnRecord.ROUTING_STATE_ACTIVE, ) @@ -690,11 +686,7 @@ class Meta: description="Connection acceptance: manual or auto", example=ConnRecord.ACCEPT_AUTO, validate=validate.OneOf( - [ - getattr(ConnRecord, a) - for a in vars(ConnRecord) - if a.startswith("ACCEPT_") - ] + ConnRecord.get_attributes_by_prefix("ACCEPT_", walk_mro=False) ), ) error_msg = fields.Str( @@ -707,11 +699,7 @@ class Meta: description="Invitation mode", example=ConnRecord.INVITATION_MODE_ONCE, validate=validate.OneOf( - [ - getattr(ConnRecord, i) - for i in vars(ConnRecord) - if i.startswith("INVITATION_MODE_") - ] + ConnRecord.get_attributes_by_prefix("INVITATION_MODE_", walk_mro=False) ), ) alias = fields.Str( diff --git a/aries_cloudagent/messaging/models/base_record.py b/aries_cloudagent/messaging/models/base_record.py index da1f7914f1..c696cf6771 100644 --- a/aries_cloudagent/messaging/models/base_record.py +++ b/aries_cloudagent/messaging/models/base_record.py @@ -81,6 +81,7 @@ class Meta: EVENT_NAMESPACE: str = "acapy::record" LOG_STATE_FLAG = None TAG_NAMES = {"state"} + STATE_DELETED = "deleted" def __init__( self, @@ -420,7 +421,7 @@ async def delete_record(self, session: ProfileSession): storage = session.inject(BaseStorage) if self.state: self._previous_state = self.state - self.state = "deleted" + self.state = BaseRecord.STATE_DELETED await self.emit_event(session, self.serialize()) await storage.delete_record(self.storage_record) @@ -497,6 +498,24 @@ def __eq__(self, other: Any) -> bool: return self.value == other.value and self.tags == other.tags return False + @classmethod + def get_attributes_by_prefix(cls, prefix: str, walk_mro: bool = True): + """ + List all values for attributes with common prefix. + + Args: + prefix: Common prefix to look for + walk_mro: Walk MRO to find attributes inherited from superclasses + """ + + bases = cls.__mro__ if walk_mro else [cls] + return [ + vars(base)[name] + for base in bases + for name in vars(base) + if name.startswith(prefix) + ] + class BaseExchangeRecord(BaseRecord): """Represents a base record with event tracing capability.""" diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index 23c8057645..28b673a061 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -127,10 +127,42 @@ async def create_invitation( or_default=True, ) image_url = self.profile.context.settings.get("image_url") + invitation = None + connection = None + + invitation_mode = ConnRecord.INVITATION_MODE_ONCE + if multi_use: + invitation_mode = ConnRecord.INVITATION_MODE_MULTI if not my_label: my_label = self.profile.settings.get("default_label") + accept = ( + ConnRecord.ACCEPT_AUTO + if ( + auto_accept + or ( + auto_accept is None + and self.profile.settings.get("debug.auto_accept_requests") + ) + ) + else ConnRecord.ACCEPT_MANUAL + ) + + if recipient_keys: + # TODO: register recipient keys for relay + # TODO: check that recipient keys are in wallet + invitation_key = recipient_keys[0] # TODO first key appropriate? + else: + # Create and store new invitation key + async with self.profile.session() as session: + wallet = session.inject(BaseWallet) + invitation_signing_key = await wallet.create_signing_key( + key_type=ED25519 + ) + invitation_key = invitation_signing_key.verkey + recipient_keys = [invitation_key] + if public: if not self.profile.settings.get("public_invites"): raise ConnectionManagerError("Public invitations are not enabled") @@ -143,89 +175,64 @@ async def create_invitation( "Cannot create public invitation with no public DID" ) - if multi_use: - raise ConnectionManagerError( - "Cannot use public and multi_use at the same time" - ) - - if metadata: - raise ConnectionManagerError( - "Cannot use public and set metadata at the same time" - ) - # FIXME - allow ledger instance to format public DID with prefix? invitation = ConnectionInvitation( label=my_label, did=f"did:sov:{public_did.did}", image_url=image_url ) + connection = ConnRecord( # create connection record + invitation_key=public_did.verkey, + invitation_msg_id=invitation._id, + invitation_mode=invitation_mode, + their_role=ConnRecord.Role.REQUESTER.rfc23, + state=ConnRecord.State.INVITATION.rfc23, + accept=accept, + alias=alias, + connection_protocol=CONN_PROTO, + ) + + async with self.profile.session() as session: + await connection.save(session, reason="Created new invitation") + # Add mapping for multitenant relaying. # Mediation of public keys is not supported yet await self._route_manager.route_public_did(self.profile, public_did.verkey) - return None, invitation - - invitation_mode = ConnRecord.INVITATION_MODE_ONCE - if multi_use: - invitation_mode = ConnRecord.INVITATION_MODE_MULTI - - if recipient_keys: - # TODO: register recipient keys for relay - # TODO: check that recipient keys are in wallet - invitation_key = recipient_keys[0] # TODO first key appropriate? else: - # Create and store new invitation key + # Create connection record + connection = ConnRecord( + invitation_key=invitation_key, # TODO: determine correct key to use + their_role=ConnRecord.Role.REQUESTER.rfc160, + state=ConnRecord.State.INVITATION.rfc160, + accept=accept, + invitation_mode=invitation_mode, + alias=alias, + connection_protocol=CONN_PROTO, + ) async with self.profile.session() as session: - wallet = session.inject(BaseWallet) - invitation_signing_key = await wallet.create_signing_key( - key_type=ED25519 - ) - invitation_key = invitation_signing_key.verkey - recipient_keys = [invitation_key] + await connection.save(session, reason="Created new invitation") - accept = ( - ConnRecord.ACCEPT_AUTO - if ( - auto_accept - or ( - auto_accept is None - and self.profile.settings.get("debug.auto_accept_requests") - ) + await self._route_manager.route_invitation( + self.profile, connection, mediation_record + ) + routing_keys, my_endpoint = await self._route_manager.routing_info( + self.profile, + my_endpoint or cast(str, self.profile.settings.get("default_endpoint")), + mediation_record, ) - else ConnRecord.ACCEPT_MANUAL - ) - - # Create connection record - connection = ConnRecord( - invitation_key=invitation_key, # TODO: determine correct key to use - their_role=ConnRecord.Role.REQUESTER.rfc160, - state=ConnRecord.State.INVITATION.rfc160, - accept=accept, - invitation_mode=invitation_mode, - alias=alias, - connection_protocol=CONN_PROTO, - ) - async with self.profile.session() as session: - await connection.save(session, reason="Created new invitation") - await self._route_manager.route_invitation( - self.profile, connection, mediation_record - ) - routing_keys, my_endpoint = await self._route_manager.routing_info( - self.profile, - my_endpoint or cast(str, self.profile.settings.get("default_endpoint")), - mediation_record, - ) + # Create connection invitation message + # Note: Need to split this into two stages + # to support inbound routing of invites + # Would want to reuse create_did_document and convert the result + invitation = ConnectionInvitation( + label=my_label, + recipient_keys=recipient_keys, + routing_keys=routing_keys, + endpoint=my_endpoint, + image_url=image_url, + ) - # Create connection invitation message - # Note: Need to split this into two stages to support inbound routing of invites - # Would want to reuse create_did_document and convert the result - invitation = ConnectionInvitation( - label=my_label, - recipient_keys=recipient_keys, - routing_keys=routing_keys, - endpoint=my_endpoint, - image_url=image_url, - ) async with self.profile.session() as session: await connection.attach_invitation(session, invitation) @@ -529,6 +536,11 @@ async def receive_request( their_role=ConnRecord.Role.REQUESTER.rfc160, ) if not connection: + if not self.profile.settings.get("requests_through_public_did"): + raise ConnectionManagerError( + "Unsolicited connection requests to " + "public DID is not enabled" + ) connection = ConnRecord() connection.invitation_key = connection_key connection.my_did = my_info.did diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index c9d6105ac7..03c7e64647 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -34,6 +34,7 @@ from ....discovery.v2_0.manager import V20DiscoveryMgr from ..manager import ConnectionManager, ConnectionManagerError +from .. import manager as test_module from ..messages.connection_invitation import ConnectionInvitation from ..messages.connection_request import ConnectionRequest from ..messages.connection_response import ConnectionResponse @@ -112,21 +113,6 @@ async def setUp(self): self.manager = ConnectionManager(self.profile) assert self.manager.profile - async def test_create_invitation_public_and_multi_use_fails(self): - self.context.update_settings({"public_invites": True}) - with async_mock.patch.object( - InMemoryWallet, "get_public_did", autospec=True - ) as mock_wallet_get_public_did: - mock_wallet_get_public_did.return_value = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - with self.assertRaises(ConnectionManagerError): - await self.manager.create_invitation(public=True, multi_use=True) - async def test_create_invitation_non_multi_use_invitation_fails_on_reuse(self): connect_record, connect_invite = await self.manager.create_invitation() @@ -174,7 +160,7 @@ async def test_create_invitation_public(self): public=True, my_endpoint="testendpoint" ) - assert connect_record is None + assert connect_record assert connect_invite.did.endswith(self.test_did) self.route_manager.route_public_did.assert_called_once_with( self.profile, self.test_verkey @@ -266,23 +252,6 @@ async def test_create_invitation_metadata_assigned(self): assert await record.metadata_get_all(session) == {"hello": "world"} - async def test_create_invitation_public_and_metadata_fails(self): - self.context.update_settings({"public_invites": True}) - with async_mock.patch.object( - InMemoryWallet, "get_public_did", autospec=True - ) as mock_wallet_get_public_did: - mock_wallet_get_public_did.return_value = DIDInfo( - self.test_did, - self.test_verkey, - None, - method=SOV, - key_type=ED25519, - ) - with self.assertRaises(ConnectionManagerError): - await self.manager.create_invitation( - public=True, metadata={"hello": "world"} - ) - async def test_create_invitation_multi_use_metadata_transfers_to_connection(self): async with self.profile.session() as session: connect_record, _ = await self.manager.create_invitation( @@ -643,7 +612,83 @@ async def test_receive_request_public_did_oob_invite(self): self.profile, mock_request ) + async def test_receive_request_public_did_unsolicited_fails(self): + async with self.profile.session() as session: + mock_request = async_mock.MagicMock() + mock_request.connection = async_mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = async_mock.MagicMock() + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + self.context.update_settings({"public_invites": True}) + with self.assertRaises(ConnectionManagerError), async_mock.patch.object( + ConnRecord, "connection_id", autospec=True + ), async_mock.patch.object( + ConnRecord, "save", autospec=True + ) as mock_conn_rec_save, async_mock.patch.object( + ConnRecord, "attach_request", autospec=True + ) as mock_conn_attach_request, async_mock.patch.object( + ConnRecord, "retrieve_by_id", autospec=True + ) as mock_conn_retrieve_by_id, async_mock.patch.object( + ConnRecord, "retrieve_request", autospec=True + ), async_mock.patch.object( + ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock() + ) as mock_conn_retrieve_by_invitation_msg_id: + mock_conn_retrieve_by_invitation_msg_id.return_value = None + conn_rec = await self.manager.receive_request(mock_request, receipt) + async def test_receive_request_public_did_conn_invite(self): + async with self.profile.session() as session: + mock_request = async_mock.MagicMock() + mock_request.connection = async_mock.MagicMock() + mock_request.connection.did = self.test_did + mock_request.connection.did_doc = async_mock.MagicMock() + mock_request.connection.did_doc.did = self.test_did + + receipt = MessageReceipt( + recipient_did=self.test_did, recipient_did_public=True + ) + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=self.test_did, + ) + + mock_connection_record = async_mock.MagicMock() + mock_connection_record.save = async_mock.CoroutineMock() + mock_connection_record.attach_request = async_mock.CoroutineMock() + + self.context.update_settings({"public_invites": True}) + with async_mock.patch.object( + ConnRecord, "connection_id", autospec=True + ), async_mock.patch.object( + ConnRecord, "save", autospec=True + ) as mock_conn_rec_save, async_mock.patch.object( + ConnRecord, "attach_request", autospec=True + ) as mock_conn_attach_request, async_mock.patch.object( + ConnRecord, "retrieve_by_id", autospec=True + ) as mock_conn_retrieve_by_id, async_mock.patch.object( + ConnRecord, "retrieve_request", autospec=True + ), async_mock.patch.object( + ConnRecord, + "retrieve_by_invitation_msg_id", + async_mock.CoroutineMock(return_value=mock_connection_record), + ) as mock_conn_retrieve_by_invitation_msg_id: + conn_rec = await self.manager.receive_request(mock_request, receipt) + assert conn_rec + + async def test_receive_request_public_did_unsolicited(self): async with self.profile.session() as session: mock_request = async_mock.MagicMock() mock_request.connection = async_mock.MagicMock() @@ -662,6 +707,7 @@ async def test_receive_request_public_did_conn_invite(self): ) self.context.update_settings({"public_invites": True}) + self.context.update_settings({"requests_through_public_did": True}) with async_mock.patch.object( ConnRecord, "connection_id", autospec=True ), async_mock.patch.object( diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 2509766e9d..b209114b01 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -483,6 +483,10 @@ async def receive_request( ) else: # request is against implicit invitation on public DID + if not self.profile.settings.get("requests_through_public_did"): + raise DIDXManagerError( + "Unsolicited connection requests to " "public DID is not enabled" + ) async with self.profile.session() as session: wallet = session.inject(BaseWallet) my_info = await wallet.create_local_did( diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index 69de49fcfd..30ae59db04 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -965,6 +965,151 @@ async def test_receive_request_public_did_no_auto_accept(self): messages = self.responder.messages assert not messages + async def test_receive_request_implicit_public_did_not_enabled(self): + async with self.profile.session() as session: + mock_request = async_mock.MagicMock( + did=TestConfig.test_did, + did_doc_attach=async_mock.MagicMock( + data=async_mock.MagicMock( + verify=async_mock.CoroutineMock(return_value=True), + signed=async_mock.MagicMock( + decode=async_mock.MagicMock(return_value="dummy-did-doc") + ), + ) + ), + _thread=async_mock.MagicMock(pthid="did:sov:publicdid0000000000000"), + ) + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + await mediation_record.save(session) + + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=TestConfig.test_did, + ) + + self.profile.context.update_settings({"public_invites": True}) + + with async_mock.patch.object( + test_module, "ConnRecord", async_mock.MagicMock() + ) as mock_conn_rec_cls, async_mock.patch.object( + test_module, "DIDDoc", autospec=True + ) as mock_did_doc, async_mock.patch.object( + test_module, "DIDPosture", autospec=True + ) as mock_did_posture, async_mock.patch.object( + self.manager, + "verify_diddoc", + async_mock.CoroutineMock(return_value=DIDDoc(TestConfig.test_did)), + ): + mock_did_posture.get = async_mock.MagicMock( + return_value=test_module.DIDPosture.PUBLIC + ) + mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( + side_effect=StorageNotFoundError() + ) + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=None) + ) + + with self.assertRaises(DIDXManagerError) as context: + await self.manager.receive_request( + request=mock_request, + recipient_did=TestConfig.test_did, + my_endpoint=None, + alias=None, + auto_accept_implicit=None, + ) + assert "Unsolicited connection requests" in str(context.exception) + + async def test_receive_request_implicit_public_did(self): + async with self.profile.session() as session: + mock_request = async_mock.MagicMock( + did=TestConfig.test_did, + did_doc_attach=async_mock.MagicMock( + data=async_mock.MagicMock( + verify=async_mock.CoroutineMock(return_value=True), + signed=async_mock.MagicMock( + decode=async_mock.MagicMock(return_value="dummy-did-doc") + ), + ) + ), + _thread=async_mock.MagicMock(pthid="did:sov:publicdid0000000000000"), + ) + mediation_record = MediationRecord( + role=MediationRecord.ROLE_CLIENT, + state=MediationRecord.STATE_GRANTED, + connection_id=self.test_mediator_conn_id, + routing_keys=self.test_mediator_routing_keys, + endpoint=self.test_mediator_endpoint, + ) + await mediation_record.save(session) + + await session.wallet.create_local_did( + method=SOV, + key_type=ED25519, + seed=None, + did=TestConfig.test_did, + ) + + self.profile.context.update_settings({"public_invites": True}) + self.profile.context.update_settings({"requests_through_public_did": True}) + ACCEPT_AUTO = ConnRecord.ACCEPT_AUTO + STATE_REQUEST = ConnRecord.State.REQUEST + + with async_mock.patch.object( + test_module, "ConnRecord", async_mock.MagicMock() + ) as mock_conn_rec_cls, async_mock.patch.object( + test_module, "DIDDoc", autospec=True + ) as mock_did_doc, async_mock.patch.object( + test_module, "DIDPosture", autospec=True + ) as mock_did_posture, async_mock.patch.object( + self.manager, + "verify_diddoc", + async_mock.CoroutineMock(return_value=DIDDoc(TestConfig.test_did)), + ): + mock_did_posture.get = async_mock.MagicMock( + return_value=test_module.DIDPosture.PUBLIC + ) + mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock( + side_effect=StorageNotFoundError() + ) + mock_conn_rec_cls.retrieve_by_invitation_msg_id = ( + async_mock.CoroutineMock(return_value=None) + ) + + mock_conn_record = async_mock.MagicMock( + accept=ACCEPT_AUTO, + my_did=None, + state=STATE_REQUEST.rfc23, + attach_request=async_mock.CoroutineMock(), + retrieve_request=async_mock.CoroutineMock(), + metadata_get_all=async_mock.CoroutineMock(return_value={}), + metadata_get=async_mock.CoroutineMock(return_value=True), + save=async_mock.CoroutineMock(), + ) + + mock_conn_rec_cls.return_value = mock_conn_record + + conn_rec = await self.manager.receive_request( + request=mock_request, + recipient_did=TestConfig.test_did, + recipient_verkey=None, + my_endpoint=None, + alias=None, + auto_accept_implicit=None, + ) + assert conn_rec + self.oob_mock.clean_finished_oob_record.assert_called_once_with( + self.profile, mock_request + ) + async def test_receive_request_peer_did(self): async with self.profile.session() as session: mock_request = async_mock.MagicMock( diff --git a/aries_cloudagent/protocols/issue_credential/v2_0/models/cred_ex_record.py b/aries_cloudagent/protocols/issue_credential/v2_0/models/cred_ex_record.py index fbbc0adaf5..900310f524 100644 --- a/aries_cloudagent/protocols/issue_credential/v2_0/models/cred_ex_record.py +++ b/aries_cloudagent/protocols/issue_credential/v2_0/models/cred_ex_record.py @@ -296,11 +296,7 @@ class Meta: description="Issue-credential exchange initiator: self or external", example=V20CredExRecord.INITIATOR_SELF, validate=validate.OneOf( - [ - getattr(V20CredExRecord, m) - for m in vars(V20CredExRecord) - if m.startswith("INITIATOR_") - ] + V20CredExRecord.get_attributes_by_prefix("INITIATOR_", walk_mro=False) ), ) role = fields.Str( @@ -308,11 +304,7 @@ class Meta: description="Issue-credential exchange role: holder or issuer", example=V20CredExRecord.ROLE_ISSUER, validate=validate.OneOf( - [ - getattr(V20CredExRecord, m) - for m in vars(V20CredExRecord) - if m.startswith("ROLE_") - ] + V20CredExRecord.get_attributes_by_prefix("ROLE_", walk_mro=False) ), ) state = fields.Str( @@ -320,11 +312,7 @@ class Meta: description="Issue-credential exchange state", example=V20CredExRecord.STATE_DONE, validate=validate.OneOf( - [ - getattr(V20CredExRecord, m) - for m in vars(V20CredExRecord) - if m.startswith("STATE_") - ] + V20CredExRecord.get_attributes_by_prefix("STATE_", walk_mro=True) ), ) cred_preview = fields.Nested( diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 0c9b0d5c05..b3c3f97311 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -140,15 +140,6 @@ async def create_invitation( raise OutOfBandManagerError( "Cannot store metadata without handshake protocols" ) - if public: - if multi_use: - raise OutOfBandManagerError( - "Cannot create public invitation with multi_use" - ) - if metadata: - raise OutOfBandManagerError( - "Cannot store metadata on public invitations" - ) if attachments and multi_use: raise OutOfBandManagerError( @@ -247,9 +238,15 @@ async def create_invitation( # Only create connection record if hanshake_protocols is defined if handshake_protocols: + invitation_mode = ( + ConnRecord.INVITATION_MODE_MULTI + if multi_use + else ConnRecord.INVITATION_MODE_ONCE + ) conn_rec = ConnRecord( # create connection record invitation_key=public_did.verkey, invitation_msg_id=invi_msg._id, + invitation_mode=invitation_mode, their_role=ConnRecord.Role.REQUESTER.rfc23, state=ConnRecord.State.INVITATION.rfc23, accept=ConnRecord.ACCEPT_AUTO @@ -262,6 +259,12 @@ async def create_invitation( async with self.profile.session() as session: await conn_rec.save(session, reason="Created new invitation") await conn_rec.attach_invitation(session, invi_msg) + + await conn_rec.attach_invitation(session, invi_msg) + + if metadata: + for key, value in metadata.items(): + await conn_rec.metadata_set(session, key, value) else: our_service = ServiceDecorator( recipient_keys=[our_recipient_key], diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py b/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py index 69f668335b..e550d9eb54 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py @@ -3,7 +3,7 @@ import json from typing import Any, Mapping, Optional, Union -from marshmallow import fields +from marshmallow import fields, validate from .....connections.models.conn_record import ConnRecord from .....core.profile import ProfileSession @@ -248,6 +248,9 @@ class Meta: required=True, description="Out of band message exchange state", example=OobRecord.STATE_AWAIT_RESPONSE, + validate=validate.OneOf( + OobRecord.get_attributes_by_prefix("STATE_", walk_mro=True) + ), ) invi_msg_id = fields.Str( required=True, @@ -287,4 +290,7 @@ class Meta: description="OOB Role", required=False, example=OobRecord.ROLE_RECEIVER, + validate=validate.OneOf( + OobRecord.get_attributes_by_prefix("ROLE_", walk_mro=False) + ), ) diff --git a/aries_cloudagent/protocols/present_proof/v2_0/models/pres_exchange.py b/aries_cloudagent/protocols/present_proof/v2_0/models/pres_exchange.py index cc314aa9db..4b5e22a10b 100644 --- a/aries_cloudagent/protocols/present_proof/v2_0/models/pres_exchange.py +++ b/aries_cloudagent/protocols/present_proof/v2_0/models/pres_exchange.py @@ -244,11 +244,7 @@ class Meta: description="Present-proof exchange initiator: self or external", example=V20PresExRecord.INITIATOR_SELF, validate=validate.OneOf( - [ - getattr(V20PresExRecord, m) - for m in vars(V20PresExRecord) - if m.startswith("INITIATOR_") - ] + V20PresExRecord.get_attributes_by_prefix("INITIATOR_", walk_mro=False) ), ) role = fields.Str( @@ -256,22 +252,14 @@ class Meta: description="Present-proof exchange role: prover or verifier", example=V20PresExRecord.ROLE_PROVER, validate=validate.OneOf( - [ - getattr(V20PresExRecord, m) - for m in vars(V20PresExRecord) - if m.startswith("ROLE_") - ] + V20PresExRecord.get_attributes_by_prefix("ROLE_", walk_mro=False) ), ) state = fields.Str( required=False, description="Present-proof exchange state", validate=validate.OneOf( - [ - getattr(V20PresExRecord, m) - for m in vars(V20PresExRecord) - if m.startswith("STATE_") - ] + V20PresExRecord.get_attributes_by_prefix("STATE_", walk_mro=True) ), ) pres_proposal = fields.Nested(