diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 9fd4542c02..ec157bf1b6 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -4,7 +4,7 @@ import json import logging -from typing import Mapping, Sequence +from typing import Mapping, Sequence, Optional from ....connections.base_manager import BaseConnectionManager from ....connections.models.conn_record import ConnRecord @@ -477,69 +477,7 @@ async def receive_invitation( and num_included_req_attachments == 0 and use_existing_connection ): - await self.create_handshake_reuse_message( - invi_msg=invitation, - conn_record=conn_rec, - ) - try: - await asyncio.wait_for( - self.check_reuse_msg_state( - conn_rec=conn_rec, - ), - 15, - ) - async with self.profile.session() as session: - await conn_rec.metadata_delete( - session=session, key="reuse_msg_id" - ) - - msg_state = await conn_rec.metadata_get( - session, "reuse_msg_state" - ) - - if msg_state == "not_accepted": - conn_rec = None - else: - async with self.profile.session() as session: - await conn_rec.metadata_delete( - session=session, key="reuse_msg_state" - ) - # refetch connection for accurate state after handshake - conn_rec = await ConnRecord.retrieve_by_id( - session=session, record_id=conn_rec.connection_id - ) - except asyncio.TimeoutError: - # If no reuse_accepted or problem_report message was received within - # the 15s timeout then a new connection to be created - async with self.profile.session() as session: - sent_reuse_msg_id = await conn_rec.metadata_get( - session=session, key="reuse_msg_id" - ) - await conn_rec.metadata_delete( - session=session, key="reuse_msg_id" - ) - await conn_rec.metadata_delete( - session=session, key="reuse_msg_state" - ) - conn_rec.state = ConnRecord.State.ABANDONED.rfc160 - await conn_rec.save( - session, reason="No HandshakeReuseAccept message received" - ) - # Emit webhook - await self.profile.notify( - REUSE_ACCEPTED_WEBHOOK_TOPIC, - { - "thread_id": sent_reuse_msg_id, - "connection_id": conn_rec.connection_id, - "state": "rejected", - "comment": ( - "No HandshakeReuseAccept message received, " - f"connection {conn_rec.connection_id} ", - f"and invitation {invitation._id}", - ), - }, - ) - conn_rec = None + conn_rec = await self.send_reuse_message(invitation, conn_rec) # Inverse of the following cases # Handshake_Protocol not included # Request_Attachment included @@ -609,6 +547,8 @@ async def receive_invitation( # Request Attach if len(invitation.requests_attach) >= 1 and conn_rec is not None: req_attach = invitation.requests_attach[0] + if use_existing_connection: + conn_rec = await self.send_reuse_message(invitation, conn_rec) if isinstance(req_attach, AttachDecorator): if req_attach.data is not None: unq_req_attach_type = DIDCommPrefix.unqualify( @@ -768,6 +708,73 @@ async def _process_pres_request_v1( ) ) + async def send_reuse_message( + self, invitation: InvitationMessage, conn_rec: ConnRecord + ) -> Optional[ConnRecord]: + """ + Create and wait for handshake reuse message. + + Args: + invitation: invitation message + conn_rec: connection record + """ + await self.create_handshake_reuse_message( + invi_msg=invitation, + conn_record=conn_rec, + ) + try: + await asyncio.wait_for( + self.check_reuse_msg_state( + conn_rec=conn_rec, + ), + 15, + ) + async with self.profile.session() as session: + await conn_rec.metadata_delete(session=session, key="reuse_msg_id") + + msg_state = await conn_rec.metadata_get(session, "reuse_msg_state") + + if msg_state == "not_accepted": + conn_rec = None + else: + async with self.profile.session() as session: + await conn_rec.metadata_delete( + session=session, key="reuse_msg_state" + ) + # refetch connection for accurate state after handshake + conn_rec = await ConnRecord.retrieve_by_id( + session=session, record_id=conn_rec.connection_id + ) + return conn_rec + except asyncio.TimeoutError: + # If no reuse_accepted or problem_report message was received within + # the 15s timeout then a new connection to be created + async with self.profile.session() as session: + sent_reuse_msg_id = await conn_rec.metadata_get( + session=session, key="reuse_msg_id" + ) + await conn_rec.metadata_delete(session=session, key="reuse_msg_id") + await conn_rec.metadata_delete(session=session, key="reuse_msg_state") + conn_rec.state = ConnRecord.State.ABANDONED.rfc160 + await conn_rec.save( + session, reason="No HandshakeReuseAccept message received" + ) + # Emit webhook + await self.profile.notify( + REUSE_ACCEPTED_WEBHOOK_TOPIC, + { + "thread_id": sent_reuse_msg_id, + "connection_id": conn_rec.connection_id, + "state": "rejected", + "comment": ( + "No HandshakeReuseAccept message received, " + f"connection {conn_rec.connection_id} ", + f"and invitation {invitation._id}", + ), + }, + ) + return None + async def _process_pres_request_v2( self, req_attach: AttachDecorator, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index 297d10c59b..144e9301cb 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -1351,7 +1351,19 @@ async def test_receive_invitation_attachment_x(self): ) as didx_mgr_receive_invitation, async_mock.patch( "aries_cloudagent.protocols.out_of_band.v1_0.manager.InvitationMessage", autospec=True, - ) as inv_message_cls: + ) as inv_message_cls, async_mock.patch.object( + OutOfBandManager, + "create_handshake_reuse_message", + autospec=True, + ) as oob_mgr_create_reuse_msg, async_mock.patch.object( + OutOfBandManager, + "check_reuse_msg_state", + autospec=True, + ) as oob_mgr_check_reuse_state, async_mock.patch.object( + OutOfBandManager, + "send_reuse_message", + autospec=True, + ) as oob_mgr_send_reuse_message: mock_oob_invi = async_mock.MagicMock( services=[TestConfig.test_did], @@ -1373,7 +1385,19 @@ async def test_receive_invitation_req_pres_v1_0_attachment_x(self): ) as didx_mgr_receive_invitation, async_mock.patch( "aries_cloudagent.protocols.out_of_band.v1_0.manager.InvitationMessage", autospec=True, - ) as inv_message_cls: + ) as inv_message_cls, async_mock.patch.object( + OutOfBandManager, + "create_handshake_reuse_message", + autospec=True, + ) as oob_mgr_create_reuse_msg, async_mock.patch.object( + OutOfBandManager, + "check_reuse_msg_state", + autospec=True, + ) as oob_mgr_check_reuse_state, async_mock.patch.object( + OutOfBandManager, + "send_reuse_message", + autospec=True, + ) as oob_mgr_send_reuse_message: mock_oob_invi = async_mock.MagicMock( handshake_protocols=[ pfx.qualify(HSProto.RFC23.name) for pfx in DIDCommPrefix diff --git a/aries_cloudagent/protocols/present_proof/dif/pres_exch.py b/aries_cloudagent/protocols/present_proof/dif/pres_exch.py index 4a2ef289cf..dafe0b8857 100644 --- a/aries_cloudagent/protocols/present_proof/dif/pres_exch.py +++ b/aries_cloudagent/protocols/present_proof/dif/pres_exch.py @@ -646,7 +646,7 @@ class Meta: ), example=( { - "oneOf": [ + "oneof_filter": [ [ {"uri": "https://www.w3.org/Test1#Test1"}, {"uri": "https://www.w3.org/Test2#Test2"}, diff --git a/aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py b/aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py index 86c8372206..638ce0fa3e 100644 --- a/aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py +++ b/aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py @@ -10,6 +10,7 @@ """ import pytz import re +import logging from datetime import datetime from dateutil.parser import parse as dateutil_parser @@ -61,6 +62,7 @@ PRESENTATION_SUBMISSION_JSONLD_TYPE = "PresentationSubmission" PYTZ_TIMEZONE_PATTERN = re.compile(r"(([a-zA-Z]+)(?:\/)([a-zA-Z]+))") LIST_INDEX_PATTERN = re.compile(r"\[(\W+)\]|\[(\d+)\]") +LOGGER = logging.getLogger(__name__) class DIFPresExchError(BaseError): @@ -789,8 +791,11 @@ def exclusive_minimum_check(self, val: any, _filter: Filter) -> bool: given_date = self.string_to_timezone_aware_datetime(str(val)) return given_date > to_compare_date else: - if self.is_numeric(val): + try: + val = self.is_numeric(val) return val > _filter.exclusive_min + except DIFPresExchError as err: + LOGGER.error(err) return False except (TypeError, ValueError, ParserError): return False @@ -817,8 +822,11 @@ def exclusive_maximum_check(self, val: any, _filter: Filter) -> bool: given_date = self.string_to_timezone_aware_datetime(str(val)) return given_date < to_compare_date else: - if self.is_numeric(val): + try: + val = self.is_numeric(val) return val < _filter.exclusive_max + except DIFPresExchError as err: + LOGGER.error(err) return False except (TypeError, ValueError, ParserError): return False @@ -845,8 +853,11 @@ def maximum_check(self, val: any, _filter: Filter) -> bool: given_date = self.string_to_timezone_aware_datetime(str(val)) return given_date <= to_compare_date else: - if self.is_numeric(val): + try: + val = self.is_numeric(val) return val <= _filter.maximum + except DIFPresExchError as err: + LOGGER.error(err) return False except (TypeError, ValueError, ParserError): return False @@ -873,8 +884,11 @@ def minimum_check(self, val: any, _filter: Filter) -> bool: given_date = self.string_to_timezone_aware_datetime(str(val)) return given_date >= to_compare_date else: - if self.is_numeric(val): + try: + val = self.is_numeric(val) return val >= _filter.minimum + except DIFPresExchError as err: + LOGGER.error(err) return False except (TypeError, ValueError, ParserError): return False @@ -1147,19 +1161,31 @@ async def apply_requirements( nested_result=nested_result, exclude=exclude ) - def is_numeric(self, val: any) -> bool: + def is_numeric(self, val: any): """ Check if val is an int or float. Args: val: to check Return: - bool + numeric value + Raises: + DIFPresExchError: Provided value has invalid/incompatible type + """ if isinstance(val, float) or isinstance(val, int): - return True - else: - return False + return val + elif isinstance(val, str): + if val.isdigit(): + return int(val) + else: + try: + return float(val) + except ValueError: + pass + raise DIFPresExchError( + "Invalid type provided for comparision/numeric operation." + ) async def merge_nested_results( self, nested_result: Sequence[dict], exclude: dict diff --git a/aries_cloudagent/protocols/present_proof/dif/tests/test_pres_exch_handler.py b/aries_cloudagent/protocols/present_proof/dif/tests/test_pres_exch_handler.py index 67a56badb0..5c8fc2039b 100644 --- a/aries_cloudagent/protocols/present_proof/dif/tests/test_pres_exch_handler.py +++ b/aries_cloudagent/protocols/present_proof/dif/tests/test_pres_exch_handler.py @@ -1669,9 +1669,14 @@ def test_subject_is_issuer(self, setup_tuple, profile): @pytest.mark.asyncio def test_is_numeric(self, profile): dif_pres_exch_handler = DIFPresExchHandler(profile) - assert dif_pres_exch_handler.is_numeric("test") is False - assert dif_pres_exch_handler.is_numeric(1) is True - assert dif_pres_exch_handler.is_numeric(2 + 3j) is False + with pytest.raises(DIFPresExchError): + dif_pres_exch_handler.is_numeric("test") + assert dif_pres_exch_handler.is_numeric(1) == 1 + assert dif_pres_exch_handler.is_numeric(2.20) == 2.20 + assert dif_pres_exch_handler.is_numeric("2.20") == 2.20 + assert dif_pres_exch_handler.is_numeric("2") == 2 + with pytest.raises(DIFPresExchError): + dif_pres_exch_handler.is_numeric(2 + 3j) @pytest.mark.asyncio def test_filter_no_match(self, profile): diff --git a/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py b/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py index 33f1b653d0..1a196fe270 100644 --- a/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py +++ b/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/handler.py @@ -459,16 +459,11 @@ async def verify_pres(self, pres_ex_record: V20PresExRecord) -> V20PresExRecord: pres_request = pres_ex_record.pres_request.attachment( DIFPresFormatHandler.format ) + challenge = None if "options" in pres_request: - challenge = pres_request["options"].get("challenge") - else: - raise V20PresFormatHandlerError( - "No options [challenge] set for the presentation request" - ) + challenge = pres_request["options"].get("challenge", str(uuid4())) if not challenge: - raise V20PresFormatHandlerError( - "No challenge is set for the presentation request" - ) + challenge = str(uuid4()) pres_ver_result = await verify_presentation( presentation=dif_proof, suites=await self._get_all_suites(wallet=wallet), diff --git a/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/tests/test_handler.py b/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/tests/test_handler.py index 4ac318c506..1b648f2aeb 100644 --- a/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/tests/test_handler.py +++ b/aries_cloudagent/protocols/present_proof/v2_0/formats/dif/tests/test_handler.py @@ -975,50 +975,7 @@ async def test_verify_pres_no_challenge(self): error_msg="error", ) - with self.assertRaises(V20PresFormatHandlerError): - await self.handler.verify_pres(record) - - async def test_verify_pres_invalid_challenge(self): - test_pd = deepcopy(DIF_PRES_REQUEST_B) - del test_pd["options"] - dif_pres_request = V20PresRequest( - formats=[ - V20PresFormat( - attach_id="dif", - format_=ATTACHMENT_FORMAT[PRES_20_REQUEST][ - V20PresFormat.Format.DIF.api - ], - ) - ], - request_presentations_attach=[ - AttachDecorator.data_json(test_pd, ident="dif") - ], - ) - dif_pres = V20Pres( - formats=[ - V20PresFormat( - attach_id="dif", - format_=ATTACHMENT_FORMAT[PRES_20][V20PresFormat.Format.DIF.api], - ) - ], - presentations_attach=[AttachDecorator.data_json(DIF_PRES, ident="dif")], - ) - record = V20PresExRecord( - pres_ex_id="pxid", - thread_id="thid", - connection_id="conn_id", - initiator="init", - role="role", - state="state", - pres_request=dif_pres_request, - pres=dif_pres, - verified="false", - auto_present=True, - error_msg="error", - ) - - with self.assertRaises(V20PresFormatHandlerError): - await self.handler.verify_pres(record) + assert await self.handler.verify_pres(record) async def test_create_pres_cred_limit_disclosure_no_bbs(self): test_pd = deepcopy(DIF_PRES_REQUEST_B) diff --git a/aries_cloudagent/protocols/present_proof/v2_0/manager.py b/aries_cloudagent/protocols/present_proof/v2_0/manager.py index 688f4ca35f..7693942921 100644 --- a/aries_cloudagent/protocols/present_proof/v2_0/manager.py +++ b/aries_cloudagent/protocols/present_proof/v2_0/manager.py @@ -339,7 +339,8 @@ async def receive_pres(self, message: V20Pres, conn_record: ConnRecord): ) pres_ex_record.pres = message pres_ex_record.state = V20PresExRecord.STATE_PRESENTATION_RECEIVED - + if not pres_ex_record.connection_id: + pres_ex_record.connection_id = conn_record.connection_id async with self._profile.session() as session: await pres_ex_record.save(session, reason="receive v2.0 presentation")