Skip to content

Commit

Permalink
Merge pull request #1525 from shaangill025/issue_1524
Browse files Browse the repository at this point in the history
OOB: Fixes issues with multiple public explicit invitation and unused 0160 connection
  • Loading branch information
ianco authored Dec 1, 2021
2 parents 04a0264 + ff91cab commit 67539b9
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 40 deletions.
19 changes: 19 additions & 0 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,25 @@ async def retrieve_by_invitation_key(

return await cls.retrieve_by_tag_filter(session, tag_filter, post_filter)

@classmethod
async def retrieve_by_invitation_msg_id(
cls, session: ProfileSession, invitation_msg_id: str, their_role: str = None
) -> "ConnRecord":
"""Retrieve a connection record by invitation_msg_id.
Args:
session: The active profile session
invitation_msg_id: Invitation message identifier
initiator: Filter by the initiator value
"""
post_filter = {
"state": cls.State.INVITATION.rfc160,
"invitation_msg_id": invitation_msg_id,
}
if their_role:
post_filter["their_role"] = cls.Role.get(their_role).rfc160
return await cls.query(session, post_filter_positive=post_filter)

@classmethod
async def retrieve_by_request_id(
cls, session: ProfileSession, request_id: str
Expand Down
23 changes: 23 additions & 0 deletions aries_cloudagent/connections/models/tests/test_conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,29 @@ async def test_retrieve_by_invitation_key(self):
their_role=ConnRecord.Role.REQUESTER.rfc23,
)

async def test_retrieve_by_invitation_msg_id(self):
record = ConnRecord(
my_did=self.test_did,
their_did=self.test_target_did,
their_role=ConnRecord.Role.RESPONDER.rfc160,
state=ConnRecord.State.INVITATION.rfc160,
invitation_msg_id="test123",
)
await record.save(self.session)
results = await ConnRecord.retrieve_by_invitation_msg_id(
session=self.session,
invitation_msg_id="test123",
their_role=ConnRecord.Role.RESPONDER.rfc160,
)
assert len(results) == 1
assert results[0] == record
results = await ConnRecord.retrieve_by_invitation_msg_id(
session=self.session,
invitation_msg_id="test123",
their_role=ConnRecord.Role.REQUESTER.rfc160,
)
assert len(results) == 0

async def test_retrieve_by_request_id(self):
record = ConnRecord(
my_did=self.test_did,
Expand Down
36 changes: 22 additions & 14 deletions aries_cloudagent/protocols/connections/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ async def create_request(
connection=ConnectionDetail(did=connection.my_did, did_doc=did_doc),
image_url=self.profile.settings.get("image_url"),
)
request.assign_thread_id(thid=request._id, pthid=connection.invitation_msg_id)

# Update connection state
connection.request_id = request._id
Expand Down Expand Up @@ -601,21 +602,28 @@ async def receive_request(
# Add mapping for multitenant relay
if multitenant_mgr and wallet_id:
await multitenant_mgr.add_key(wallet_id, my_info.verkey)

connection = ConnRecord(
invitation_key=connection_key,
my_did=my_info.did,
their_role=ConnRecord.Role.RESPONDER.rfc160,
their_did=request.connection.did,
their_label=request.label,
accept=(
ConnRecord.ACCEPT_AUTO
if self.profile.settings.get("debug.auto_accept_requests")
else ConnRecord.ACCEPT_MANUAL
),
state=ConnRecord.State.REQUEST.rfc160,
connection_protocol=CONN_PROTO,
async with self.profile.session() as session:
conn_records = await ConnRecord.retrieve_by_invitation_msg_id(
session=session,
invitation_msg_id=request._thread.pthid,
their_role=ConnRecord.Role.REQUESTER.rfc160,
)
if len(conn_records) == 1:
connection = conn_records[0]
else:
connection = ConnRecord()
connection.invitation_key = connection_key
connection.my_did = my_info.did
connection.their_role = ConnRecord.Role.RESPONDER.rfc160
connection.their_did = request.connection.did
connection.their_label = request.label
connection.accept = (
ConnRecord.ACCEPT_AUTO
if self.profile.settings.get("debug.auto_accept_requests")
else ConnRecord.ACCEPT_MANUAL
)
connection.state = ConnRecord.State.REQUEST.rfc160
connection.connection_protocol = CONN_PROTO
async with self.profile.session() as session:
await connection.save(
session, reason="Received connection request from public DID"
Expand Down
59 changes: 55 additions & 4 deletions aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ async def test_create_request_mediation_not_granted(self):
record, mediation_id=mediation_record.mediation_id
)

async def test_receive_request_public_did(self):
async def test_receive_request_public_did_oob_invite(self):
async with self.profile.session() as session:
mock_request = async_mock.MagicMock()
mock_request.connection = async_mock.MagicMock()
Expand Down Expand Up @@ -706,7 +706,52 @@ async def test_receive_request_public_did(self):
ConnRecord, "retrieve_by_id", autospec=True
) as mock_conn_retrieve_by_id, async_mock.patch.object(
ConnRecord, "retrieve_request", autospec=True
):
), async_mock.patch.object(
ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock()
) as mock_conn_retrieve_by_invitation_msg_id:
mock_conn_retrieve_by_invitation_msg_id.return_value = [ConnRecord()]
conn_rec = await self.manager.receive_request(mock_request, receipt)
assert conn_rec

messages = self.responder.messages
assert len(messages) == 1
(result, target) = messages[0]
assert type(result) == ConnectionResponse
assert "connection_id" in target

async def test_receive_request_public_did_conn_invite(self):
async with self.profile.session() as session:
mock_request = async_mock.MagicMock()
mock_request.connection = async_mock.MagicMock()
mock_request.connection.did = self.test_did
mock_request.connection.did_doc = async_mock.MagicMock()
mock_request.connection.did_doc.did = self.test_did

receipt = MessageReceipt(
recipient_did=self.test_did, recipient_did_public=True
)
await session.wallet.create_local_did(
method=DIDMethod.SOV,
key_type=KeyType.ED25519,
seed=None,
did=self.test_did,
)

self.context.update_settings({"public_invites": True})
with async_mock.patch.object(
ConnRecord, "connection_id", autospec=True
), async_mock.patch.object(
ConnRecord, "save", autospec=True
) as mock_conn_rec_save, async_mock.patch.object(
ConnRecord, "attach_request", autospec=True
) as mock_conn_attach_request, async_mock.patch.object(
ConnRecord, "retrieve_by_id", autospec=True
) as mock_conn_retrieve_by_id, async_mock.patch.object(
ConnRecord, "retrieve_request", autospec=True
), async_mock.patch.object(
ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock()
) as mock_conn_retrieve_by_invitation_msg_id:
mock_conn_retrieve_by_invitation_msg_id.return_value = []
conn_rec = await self.manager.receive_request(mock_request, receipt)
assert conn_rec

Expand Down Expand Up @@ -800,7 +845,10 @@ async def test_receive_request_public_multitenant(self):
InMemoryWallet, "create_local_did", autospec=True
) as mock_wallet_create_local_did, async_mock.patch.object(
InMemoryWallet, "get_local_did", autospec=True
) as mock_wallet_get_local_did:
) as mock_wallet_get_local_did, async_mock.patch.object(
ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock()
) as mock_conn_retrieve_by_invitation_msg_id:
mock_conn_retrieve_by_invitation_msg_id.return_value = [ConnRecord()]
mock_wallet_create_local_did.return_value = DIDInfo(
new_info.did,
new_info.verkey,
Expand Down Expand Up @@ -940,7 +988,10 @@ async def test_receive_request_public_did_no_auto_accept(self):
ConnRecord, "retrieve_by_id", autospec=True
) as mock_conn_retrieve_by_id, async_mock.patch.object(
ConnRecord, "retrieve_request", autospec=True
):
), async_mock.patch.object(
ConnRecord, "retrieve_by_invitation_msg_id", async_mock.CoroutineMock()
) as mock_conn_retrieve_by_invitation_msg_id:
mock_conn_retrieve_by_invitation_msg_id.return_value = [ConnRecord()]
conn_rec = await self.manager.receive_request(mock_request, receipt)
assert conn_rec

Expand Down
28 changes: 18 additions & 10 deletions aries_cloudagent/protocols/didexchange/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,20 @@ async def receive_request(
# Determine what key will need to sign the response
if recipient_verkey: # peer DID
connection_key = recipient_verkey
try:
async with self.profile.session() as session:
conn_rec = await ConnRecord.retrieve_by_invitation_key(
session=session,
invitation_key=connection_key,
their_role=ConnRecord.Role.REQUESTER.rfc23,
)
except StorageNotFoundError:
if recipient_verkey:
raise DIDXManagerError(
"No explicit invitation found for pairwise connection "
f"in state {ConnRecord.State.INVITATION.rfc23}: "
"a prior connection request may have updated the connection state"
)
else:
if not self.profile.settings.get("public_invites"):
raise DIDXManagerError(
Expand All @@ -391,20 +405,14 @@ async def receive_request(
raise DIDXManagerError(f"Request DID {recipient_did} is not public")
connection_key = my_info.verkey

try:
async with self.profile.session() as session:
conn_rec = await ConnRecord.retrieve_by_invitation_key(
conn_records = await ConnRecord.retrieve_by_invitation_msg_id(
session=session,
invitation_key=connection_key,
invitation_msg_id=request._thread.pthid,
their_role=ConnRecord.Role.REQUESTER.rfc23,
)
except StorageNotFoundError:
if recipient_verkey:
raise DIDXManagerError(
"No explicit invitation found for pairwise connection "
f"in state {ConnRecord.State.INVITATION.rfc23}: "
"a prior connection request may have updated the connection state"
)
if len(conn_records) == 1:
conn_rec = conn_records[0]

if conn_rec: # invitation was explicit
connection_key = conn_rec.invitation_key
Expand Down
24 changes: 12 additions & 12 deletions aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,8 @@ async def test_receive_request_explicit_public_did(self):
mock_conn_rec_cls.retrieve_by_id = async_mock.CoroutineMock(
return_value=async_mock.MagicMock(save=async_mock.CoroutineMock())
)
mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock(
return_value=mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_msg_id = (
async_mock.CoroutineMock(return_value=[mock_conn_record])
)
mock_conn_rec_cls.return_value = mock_conn_record

Expand Down Expand Up @@ -783,8 +783,8 @@ async def test_receive_request_public_did_no_did_doc_attachment(self):
save=async_mock.CoroutineMock(),
)
mock_conn_rec_cls.return_value = mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock(
return_value=mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_msg_id = (
async_mock.CoroutineMock(return_value=[mock_conn_record])
)

mock_did_posture.get = async_mock.MagicMock(
Expand Down Expand Up @@ -890,8 +890,8 @@ async def test_receive_request_public_did_x_wrong_did(self):
save=async_mock.CoroutineMock(),
)
mock_conn_rec_cls.return_value = mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock(
return_value=mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_msg_id = (
async_mock.CoroutineMock(return_value=[mock_conn_record])
)
mock_did_doc_from_json.return_value = async_mock.MagicMock(
did="wrong-did"
Expand Down Expand Up @@ -950,8 +950,8 @@ async def test_receive_request_public_did_x_did_doc_attach_bad_sig(self):
save=async_mock.CoroutineMock(),
)
mock_conn_rec_cls.return_value = mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock(
return_value=mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_msg_id = (
async_mock.CoroutineMock(return_value=[mock_conn_record])
)

mock_did_posture.get = async_mock.MagicMock(
Expand Down Expand Up @@ -1068,8 +1068,8 @@ async def test_receive_request_public_did_no_auto_accept(self):
save=async_mock.CoroutineMock(),
)
mock_conn_rec_cls.return_value = mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock(
return_value=mock_conn_record
mock_conn_rec_cls.retrieve_by_invitation_msg_id = (
async_mock.CoroutineMock(return_value=[mock_conn_record])
)

mock_did_posture.get = async_mock.MagicMock(
Expand Down Expand Up @@ -1306,8 +1306,8 @@ async def test_receive_request_implicit_multitenant(self):
retrieve_request=async_mock.CoroutineMock(),
)
mock_conn_rec_cls.return_value = mock_conn_rec
mock_conn_rec_cls.retrieve_by_invitation_key = async_mock.CoroutineMock(
side_effect=StorageNotFoundError()
mock_conn_rec_cls.retrieve_by_invitation_msg_id = (
async_mock.CoroutineMock(return_value=[])
)

mock_did_posture.get = async_mock.MagicMock(
Expand Down

0 comments on commit 67539b9

Please sign in to comment.