Skip to content

Commit

Permalink
Merge pull request #1676 from shaangill025/issue#1662
Browse files Browse the repository at this point in the history
Fix DIF PresExch and OOB request_attach delete unused connection
  • Loading branch information
andrewwhitehead authored Mar 29, 2022
2 parents a5594dd + 0dd5aa2 commit bdba6b5
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 132 deletions.
135 changes: 71 additions & 64 deletions aries_cloudagent/protocols/out_of_band/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging

from typing import Mapping, Sequence
from typing import Mapping, Sequence, Optional

from ....connections.base_manager import BaseConnectionManager
from ....connections.models.conn_record import ConnRecord
Expand Down Expand Up @@ -477,69 +477,7 @@ async def receive_invitation(
and num_included_req_attachments == 0
and use_existing_connection
):
await self.create_handshake_reuse_message(
invi_msg=invitation,
conn_record=conn_rec,
)
try:
await asyncio.wait_for(
self.check_reuse_msg_state(
conn_rec=conn_rec,
),
15,
)
async with self.profile.session() as session:
await conn_rec.metadata_delete(
session=session, key="reuse_msg_id"
)

msg_state = await conn_rec.metadata_get(
session, "reuse_msg_state"
)

if msg_state == "not_accepted":
conn_rec = None
else:
async with self.profile.session() as session:
await conn_rec.metadata_delete(
session=session, key="reuse_msg_state"
)
# refetch connection for accurate state after handshake
conn_rec = await ConnRecord.retrieve_by_id(
session=session, record_id=conn_rec.connection_id
)
except asyncio.TimeoutError:
# If no reuse_accepted or problem_report message was received within
# the 15s timeout then a new connection to be created
async with self.profile.session() as session:
sent_reuse_msg_id = await conn_rec.metadata_get(
session=session, key="reuse_msg_id"
)
await conn_rec.metadata_delete(
session=session, key="reuse_msg_id"
)
await conn_rec.metadata_delete(
session=session, key="reuse_msg_state"
)
conn_rec.state = ConnRecord.State.ABANDONED.rfc160
await conn_rec.save(
session, reason="No HandshakeReuseAccept message received"
)
# Emit webhook
await self.profile.notify(
REUSE_ACCEPTED_WEBHOOK_TOPIC,
{
"thread_id": sent_reuse_msg_id,
"connection_id": conn_rec.connection_id,
"state": "rejected",
"comment": (
"No HandshakeReuseAccept message received, "
f"connection {conn_rec.connection_id} ",
f"and invitation {invitation._id}",
),
},
)
conn_rec = None
conn_rec = await self.send_reuse_message(invitation, conn_rec)
# Inverse of the following cases
# Handshake_Protocol not included
# Request_Attachment included
Expand Down Expand Up @@ -609,6 +547,8 @@ async def receive_invitation(
# Request Attach
if len(invitation.requests_attach) >= 1 and conn_rec is not None:
req_attach = invitation.requests_attach[0]
if use_existing_connection:
conn_rec = await self.send_reuse_message(invitation, conn_rec)
if isinstance(req_attach, AttachDecorator):
if req_attach.data is not None:
unq_req_attach_type = DIDCommPrefix.unqualify(
Expand Down Expand Up @@ -768,6 +708,73 @@ async def _process_pres_request_v1(
)
)

async def send_reuse_message(
self, invitation: InvitationMessage, conn_rec: ConnRecord
) -> Optional[ConnRecord]:
"""
Create and wait for handshake reuse message.
Args:
invitation: invitation message
conn_rec: connection record
"""
await self.create_handshake_reuse_message(
invi_msg=invitation,
conn_record=conn_rec,
)
try:
await asyncio.wait_for(
self.check_reuse_msg_state(
conn_rec=conn_rec,
),
15,
)
async with self.profile.session() as session:
await conn_rec.metadata_delete(session=session, key="reuse_msg_id")

msg_state = await conn_rec.metadata_get(session, "reuse_msg_state")

if msg_state == "not_accepted":
conn_rec = None
else:
async with self.profile.session() as session:
await conn_rec.metadata_delete(
session=session, key="reuse_msg_state"
)
# refetch connection for accurate state after handshake
conn_rec = await ConnRecord.retrieve_by_id(
session=session, record_id=conn_rec.connection_id
)
return conn_rec
except asyncio.TimeoutError:
# If no reuse_accepted or problem_report message was received within
# the 15s timeout then a new connection to be created
async with self.profile.session() as session:
sent_reuse_msg_id = await conn_rec.metadata_get(
session=session, key="reuse_msg_id"
)
await conn_rec.metadata_delete(session=session, key="reuse_msg_id")
await conn_rec.metadata_delete(session=session, key="reuse_msg_state")
conn_rec.state = ConnRecord.State.ABANDONED.rfc160
await conn_rec.save(
session, reason="No HandshakeReuseAccept message received"
)
# Emit webhook
await self.profile.notify(
REUSE_ACCEPTED_WEBHOOK_TOPIC,
{
"thread_id": sent_reuse_msg_id,
"connection_id": conn_rec.connection_id,
"state": "rejected",
"comment": (
"No HandshakeReuseAccept message received, "
f"connection {conn_rec.connection_id} ",
f"and invitation {invitation._id}",
),
},
)
return None

async def _process_pres_request_v2(
self,
req_attach: AttachDecorator,
Expand Down
28 changes: 26 additions & 2 deletions aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,19 @@ async def test_receive_invitation_attachment_x(self):
) as didx_mgr_receive_invitation, async_mock.patch(
"aries_cloudagent.protocols.out_of_band.v1_0.manager.InvitationMessage",
autospec=True,
) as inv_message_cls:
) as inv_message_cls, async_mock.patch.object(
OutOfBandManager,
"create_handshake_reuse_message",
autospec=True,
) as oob_mgr_create_reuse_msg, async_mock.patch.object(
OutOfBandManager,
"check_reuse_msg_state",
autospec=True,
) as oob_mgr_check_reuse_state, async_mock.patch.object(
OutOfBandManager,
"send_reuse_message",
autospec=True,
) as oob_mgr_send_reuse_message:

mock_oob_invi = async_mock.MagicMock(
services=[TestConfig.test_did],
Expand All @@ -1373,7 +1385,19 @@ async def test_receive_invitation_req_pres_v1_0_attachment_x(self):
) as didx_mgr_receive_invitation, async_mock.patch(
"aries_cloudagent.protocols.out_of_band.v1_0.manager.InvitationMessage",
autospec=True,
) as inv_message_cls:
) as inv_message_cls, async_mock.patch.object(
OutOfBandManager,
"create_handshake_reuse_message",
autospec=True,
) as oob_mgr_create_reuse_msg, async_mock.patch.object(
OutOfBandManager,
"check_reuse_msg_state",
autospec=True,
) as oob_mgr_check_reuse_state, async_mock.patch.object(
OutOfBandManager,
"send_reuse_message",
autospec=True,
) as oob_mgr_send_reuse_message:
mock_oob_invi = async_mock.MagicMock(
handshake_protocols=[
pfx.qualify(HSProto.RFC23.name) for pfx in DIDCommPrefix
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/protocols/present_proof/dif/pres_exch.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ class Meta:
),
example=(
{
"oneOf": [
"oneof_filter": [
[
{"uri": "https://www.w3.org/Test1#Test1"},
{"uri": "https://www.w3.org/Test2#Test2"},
Expand Down
44 changes: 35 additions & 9 deletions aries_cloudagent/protocols/present_proof/dif/pres_exch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""
import pytz
import re
import logging

from datetime import datetime
from dateutil.parser import parse as dateutil_parser
Expand Down Expand Up @@ -61,6 +62,7 @@
PRESENTATION_SUBMISSION_JSONLD_TYPE = "PresentationSubmission"
PYTZ_TIMEZONE_PATTERN = re.compile(r"(([a-zA-Z]+)(?:\/)([a-zA-Z]+))")
LIST_INDEX_PATTERN = re.compile(r"\[(\W+)\]|\[(\d+)\]")
LOGGER = logging.getLogger(__name__)


class DIFPresExchError(BaseError):
Expand Down Expand Up @@ -789,8 +791,11 @@ def exclusive_minimum_check(self, val: any, _filter: Filter) -> bool:
given_date = self.string_to_timezone_aware_datetime(str(val))
return given_date > to_compare_date
else:
if self.is_numeric(val):
try:
val = self.is_numeric(val)
return val > _filter.exclusive_min
except DIFPresExchError as err:
LOGGER.error(err)
return False
except (TypeError, ValueError, ParserError):
return False
Expand All @@ -817,8 +822,11 @@ def exclusive_maximum_check(self, val: any, _filter: Filter) -> bool:
given_date = self.string_to_timezone_aware_datetime(str(val))
return given_date < to_compare_date
else:
if self.is_numeric(val):
try:
val = self.is_numeric(val)
return val < _filter.exclusive_max
except DIFPresExchError as err:
LOGGER.error(err)
return False
except (TypeError, ValueError, ParserError):
return False
Expand All @@ -845,8 +853,11 @@ def maximum_check(self, val: any, _filter: Filter) -> bool:
given_date = self.string_to_timezone_aware_datetime(str(val))
return given_date <= to_compare_date
else:
if self.is_numeric(val):
try:
val = self.is_numeric(val)
return val <= _filter.maximum
except DIFPresExchError as err:
LOGGER.error(err)
return False
except (TypeError, ValueError, ParserError):
return False
Expand All @@ -873,8 +884,11 @@ def minimum_check(self, val: any, _filter: Filter) -> bool:
given_date = self.string_to_timezone_aware_datetime(str(val))
return given_date >= to_compare_date
else:
if self.is_numeric(val):
try:
val = self.is_numeric(val)
return val >= _filter.minimum
except DIFPresExchError as err:
LOGGER.error(err)
return False
except (TypeError, ValueError, ParserError):
return False
Expand Down Expand Up @@ -1147,19 +1161,31 @@ async def apply_requirements(
nested_result=nested_result, exclude=exclude
)

def is_numeric(self, val: any) -> bool:
def is_numeric(self, val: any):
"""
Check if val is an int or float.
Args:
val: to check
Return:
bool
numeric value
Raises:
DIFPresExchError: Provided value has invalid/incompatible type
"""
if isinstance(val, float) or isinstance(val, int):
return True
else:
return False
return val
elif isinstance(val, str):
if val.isdigit():
return int(val)
else:
try:
return float(val)
except ValueError:
pass
raise DIFPresExchError(
"Invalid type provided for comparision/numeric operation."
)

async def merge_nested_results(
self, nested_result: Sequence[dict], exclude: dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1669,9 +1669,14 @@ def test_subject_is_issuer(self, setup_tuple, profile):
@pytest.mark.asyncio
def test_is_numeric(self, profile):
dif_pres_exch_handler = DIFPresExchHandler(profile)
assert dif_pres_exch_handler.is_numeric("test") is False
assert dif_pres_exch_handler.is_numeric(1) is True
assert dif_pres_exch_handler.is_numeric(2 + 3j) is False
with pytest.raises(DIFPresExchError):
dif_pres_exch_handler.is_numeric("test")
assert dif_pres_exch_handler.is_numeric(1) == 1
assert dif_pres_exch_handler.is_numeric(2.20) == 2.20
assert dif_pres_exch_handler.is_numeric("2.20") == 2.20
assert dif_pres_exch_handler.is_numeric("2") == 2
with pytest.raises(DIFPresExchError):
dif_pres_exch_handler.is_numeric(2 + 3j)

@pytest.mark.asyncio
def test_filter_no_match(self, profile):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,11 @@ async def verify_pres(self, pres_ex_record: V20PresExRecord) -> V20PresExRecord:
pres_request = pres_ex_record.pres_request.attachment(
DIFPresFormatHandler.format
)
challenge = None
if "options" in pres_request:
challenge = pres_request["options"].get("challenge")
else:
raise V20PresFormatHandlerError(
"No options [challenge] set for the presentation request"
)
challenge = pres_request["options"].get("challenge", str(uuid4()))
if not challenge:
raise V20PresFormatHandlerError(
"No challenge is set for the presentation request"
)
challenge = str(uuid4())
pres_ver_result = await verify_presentation(
presentation=dif_proof,
suites=await self._get_all_suites(wallet=wallet),
Expand Down
Loading

0 comments on commit bdba6b5

Please sign in to comment.