diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 1f4bc3a590..4de5712624 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -181,6 +181,7 @@ async def create_request_implicit( alias: str = None, goal_code: str = None, goal: str = None, + auto_accept: bool = False, ) -> ConnRecord: """Create and send a request against a public DID only (no explicit invitation). @@ -192,6 +193,7 @@ async def create_request_implicit( use_public_did: use my public DID for this connection goal_code: Optional self-attested code for sharing intent of connection goal: Optional self-attested string for sharing intent of connection + auto_accept: auto-accept a corresponding connection request Returns: The new `ConnRecord` instance @@ -223,7 +225,13 @@ async def create_request_implicit( ) except StorageNotFoundError: pass - + auto_accept = bool( + auto_accept + or ( + auto_accept is None + and self.profile.settings.get("debug.auto_accept_requests") + ) + ) conn_rec = ConnRecord( my_did=my_public_info.did if my_public_info @@ -233,10 +241,10 @@ async def create_request_implicit( their_role=ConnRecord.Role.RESPONDER.rfc23, invitation_key=None, invitation_msg_id=None, - accept=None, alias=alias, their_public_did=their_public_did, connection_protocol=DIDX_PROTO, + accept=ConnRecord.ACCEPT_AUTO if auto_accept else ConnRecord.ACCEPT_MANUAL, ) request = await self.create_request( # saves and updates conn_rec conn_rec=conn_rec, diff --git a/aries_cloudagent/protocols/didexchange/v1_0/routes.py b/aries_cloudagent/protocols/didexchange/v1_0/routes.py index abfacda882..0c7b90cd7c 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/routes.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/routes.py @@ -62,6 +62,10 @@ class DIDXCreateRequestImplicitQueryStringSchema(OpenAPISchema): required=False, metadata={"description": "Alias for connection", "example": "Barry"}, ) + auto_accept = fields.Boolean( + required=False, + metadata={"description": "Auto-accept connection (defaults to configuration)"}, + ) my_endpoint = fields.Str( required=False, validate=ENDPOINT_VALIDATE, @@ -260,6 +264,7 @@ async def didx_create_request_implicit(request: web.BaseRequest): use_public_did = json.loads(request.query.get("use_public_did", "null")) goal_code = request.query.get("goal_code") or None goal = request.query.get("goal") or None + auto_accept = json.loads(request.query.get("auto_accept", "null")) profile = context.profile didx_mgr = DIDXManager(profile) @@ -273,6 +278,7 @@ async def didx_create_request_implicit(request: web.BaseRequest): alias=alias, goal_code=goal_code, goal=goal, + auto_accept=auto_accept, ) except StorageNotFoundError as err: raise web.HTTPNotFound(reason=err.roll_up) from err diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index d6511ee48a..32b92585dd 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -311,6 +311,7 @@ async def test_create_request_implicit_use_public_did(self): mediation_id=None, use_public_did=True, alias="Tester", + auto_accept=True, ) assert info_public.did == conn_rec.my_did diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_routes.py index 6a3926b5b3..2888c91166 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_routes.py @@ -101,6 +101,7 @@ async def test_didx_create_request_implicit_not_found_x(self): "my_label": "label baby junior", "my_endpoint": "http://endpoint.ca", "mediator_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", + "auto_accept": "true", } with mock.patch.object( @@ -121,6 +122,7 @@ async def test_didx_create_request_implicit_wallet_x(self): "their_public_did": "public-did", "my_label": "label baby junior", "my_endpoint": "http://endpoint.ca", + "auto_accept": "true", } with mock.patch.object( diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 649401a5c5..7ddbbdc847 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -947,6 +947,28 @@ async def delete_stale_connection_by_invitation(self, invi_msg_id: str): for conn_rec in conn_records: await conn_rec.delete_record(session) + async def delete_conn_and_oob_record_invitation(self, invi_msg_id: str): + """Delete conn_record and oob_record associated with an invi_msg_id.""" + async with self.profile.session() as session: + conn_records = await ConnRecord.query( + session, + tag_filter={ + "invitation_msg_id": invi_msg_id, + }, + post_filter_positive={}, + ) + for conn_rec in conn_records: + await conn_rec.delete_record(session) + oob_records = await OobRecord.query( + session, + tag_filter={ + "invi_msg_id": invi_msg_id, + }, + post_filter_positive={}, + ) + for oob_rec in oob_records: + await oob_rec.delete_record(session) + async def receive_reuse_message( self, reuse_msg: HandshakeReuse, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py b/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py index 5a3ff8d84a..3451ab14c9 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/models/oob_record.py @@ -110,6 +110,7 @@ def record_value(self) -> dict: "connection_id", "role", "our_service", + "invi_msg_id", ) }, **{ diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py index 4f11ca51ca..e50732ccab 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py @@ -4,8 +4,13 @@ import logging from aiohttp import web -from aiohttp_apispec import docs, querystring_schema, request_schema, response_schema - +from aiohttp_apispec import ( + docs, + querystring_schema, + request_schema, + match_info_schema, + response_schema, +) from marshmallow import fields, validate from marshmallow.exceptions import ValidationError @@ -175,6 +180,23 @@ class InvitationReceiveQueryStringSchema(OpenAPISchema): ) +class InvitationRecordResponseSchema(OpenAPISchema): + """Response schema for Invitation Record.""" + + +class InvitationRecordMatchInfoSchema(OpenAPISchema): + """Path parameters and validators for request taking invitation record.""" + + invi_msg_id = fields.Str( + required=True, + validate=UUID4_VALIDATE, + metadata={ + "description": "Invitation Message identifier", + "example": UUID4_EXAMPLE, + }, + ) + + @docs( tags=["out-of-band"], summary="Create a new connection invitation", @@ -284,12 +306,37 @@ async def invitation_receive(request: web.BaseRequest): return web.json_response(result.serialize()) +@docs(tags=["out-of-band"], summary="Delete records associated with invitation") +@match_info_schema(InvitationRecordMatchInfoSchema()) +@response_schema(InvitationRecordResponseSchema(), description="") +async def invitation_remove(request: web.BaseRequest): + """Request handler for removing a invitation related conn and oob records. + + Args: + request: aiohttp request object + + """ + context: AdminRequestContext = request["context"] + invi_msg_id = request.match_info["invi_msg_id"] + profile = context.profile + oob_mgr = OutOfBandManager(profile) + try: + await oob_mgr.delete_conn_and_oob_record_invitation(invi_msg_id) + except StorageNotFoundError as err: + raise web.HTTPNotFound(reason=err.roll_up) from err + except StorageError as err: + raise web.HTTPBadRequest(reason=err.roll_up) from err + + return web.json_response({}) + + async def register(app: web.Application): """Register routes.""" app.add_routes( [ web.post("/out-of-band/create-invitation", invitation_create), web.post("/out-of-band/receive-invitation", invitation_receive), + web.delete("/out-of-band/invitations/{invi_msg_id}", invitation_remove), ] ) diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index b8a33c2598..745f51c42f 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -1844,3 +1844,40 @@ async def test_delete_stale_connection_by_invitation(self): mock_connrecord_query.return_value = records await self.manager.delete_stale_connection_by_invitation("test123") mock_connrecord_delete.assert_called_once() + + async def test_delete_conn_and_oob_record_invitation(self): + invitation = InvitationMessage() + oob_records = [ + OobRecord( + invitation=invitation, + invi_msg_id=invitation._id, + role=OobRecord.ROLE_RECEIVER, + connection_id=self.test_conn_rec.connection_id, + state=OobRecord.STATE_INITIAL, + ) + ] + conn_records = [ + ConnRecord( + my_did=self.test_did, + their_did="FBmi5JLf5g58kDnNXMy4QM", + their_role=ConnRecord.Role.RESPONDER.rfc160, + state=ConnRecord.State.INVITATION.rfc160, + invitation_key="dummy2", + invitation_mode="once", + invitation_msg_id="test123", + ) + ] + with mock.patch.object( + ConnRecord, "query", mock.CoroutineMock() + ) as mock_connrecord_query, mock.patch.object( + ConnRecord, "delete_record", mock.CoroutineMock() + ) as mock_connrecord_delete, mock.patch.object( + OobRecord, "query", mock.CoroutineMock() + ) as mock_oobrecord_query, mock.patch.object( + OobRecord, "delete_record", mock.CoroutineMock() + ) as mock_oobrecord_delete: + mock_connrecord_query.return_value = conn_records + mock_oobrecord_query.return_value = oob_records + await self.manager.delete_conn_and_oob_record_invitation("test123") + mock_connrecord_delete.assert_called_once() + mock_oobrecord_delete.assert_called_once() diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py index 22bcd06226..ecf22baa31 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py @@ -3,14 +3,15 @@ from .....admin.request_context import AdminRequestContext from .....connections.models.conn_record import ConnRecord +from .....core.in_memory import InMemoryProfile from .. import routes as test_module class TestOutOfBandRoutes(IsolatedAsyncioTestCase): async def asyncSetUp(self): - self.session_inject = {} - self.context = AdminRequestContext.test_context(self.session_inject) + self.profile = InMemoryProfile.test_profile() + self.context = AdminRequestContext.test_context(profile=self.profile) self.request_dict = { "context": self.context, "outbound_message_router": mock.CoroutineMock(), @@ -64,6 +65,20 @@ async def test_invitation_create(self): ) mock_json_response.assert_called_once_with({"abc": "123"}) + async def test_invitation_remove(self): + self.request.match_info = {"invi_msg_id": "dummy"} + + with mock.patch.object( + test_module, "OutOfBandManager", autospec=True + ) as mock_oob_mgr, mock.patch.object( + test_module.web, "json_response", mock.Mock() + ) as mock_json_response: + mock_oob_mgr.return_value.delete_conn_and_oob_record_invitation = ( + mock.CoroutineMock(return_value=None) + ) + await test_module.invitation_remove(self.request) + mock_json_response.assert_called_once_with({}) + async def test_invitation_create_with_accept(self): self.request.query = { "multi_use": "true", diff --git a/aries_cloudagent/resolver/default/peer3.py b/aries_cloudagent/resolver/default/peer3.py index 5a3100d95d..bca776e449 100644 --- a/aries_cloudagent/resolver/default/peer3.py +++ b/aries_cloudagent/resolver/default/peer3.py @@ -84,8 +84,10 @@ async def create_and_store(self, profile: Profile, peer2: str): async def remove_record_for_deleted_conn(self, profile: Profile, event: Event): """Remove record for deleted connection, if found.""" - their_did = event.payload["their_did"] - my_did = event.payload["my_did"] + their_did = event.payload.get("their_did") + my_did = event.payload.get("my_did") + if not their_did and not my_did: + return dids = [ *(did for did in (their_did, my_did) if PEER3_PATTERN.match(did)), *(peer2to3(did) for did in (their_did, my_did) if PEER2_PATTERN.match(did)),