Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Ensure supported DID before calling Rotate #3380

Merged
merged 11 commits into from
Dec 17, 2024
2 changes: 1 addition & 1 deletion acapy_agent/connections/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ async def resolve_didcomm_services(
try:
doc_dict: dict = await resolver.resolve(self._profile, did, service_accept)
doc: ResolvedDocument = pydid.deserialize_document(doc_dict, strict=True)
except ResolverError as error:
except (ResolverError, ValueError) as error:
raise BaseConnectionManagerError("Failed to resolve DID services") from error

if not doc.service:
Expand Down
2 changes: 1 addition & 1 deletion acapy_agent/ledger/indy_vdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ async def credential_definition_id2schema_id(self, credential_definition_id):
seq_no = tokens[3]
return (await self.get_schema(seq_no))["id"]

async def get_key_for_did(self, did: str) -> str:
async def get_key_for_did(self, did: str) -> Optional[str]:
"""Fetch the verkey for a ledger DID.

Args:
Expand Down
4 changes: 2 additions & 2 deletions acapy_agent/protocols/did_rotate/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def receive_rotate(self, conn: ConnRecord, rotate: Rotate) -> RotateRecord
)

try:
await self._ensure_supported_did(rotate.to_did)
await self.ensure_supported_did(rotate.to_did)
except ReportableDIDRotateError as err:
responder = self.profile.inject(BaseResponder)
err.message.assign_thread_from(rotate)
Expand Down Expand Up @@ -234,7 +234,7 @@ async def receive_hangup(self, conn: ConnRecord):
async with self.profile.session() as session:
await conn.delete_record(session)

async def _ensure_supported_did(self, did: str):
async def ensure_supported_did(self, did: str):
"""Check if the DID is supported."""
resolver = self.profile.inject(DIDResolver)
conn_mgr = BaseConnectionManager(self.profile)
Expand Down
4 changes: 2 additions & 2 deletions acapy_agent/protocols/did_rotate/v1_0/message_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
MESSAGE_TYPES = DIDCommPrefix.qualify_all(
{
ROTATE: f"{PROTOCOL_PACKAGE}.messages.rotate.Rotate",
ACK: f"{PROTOCOL_PACKAGE}.messages.ack.Ack",
ACK: f"{PROTOCOL_PACKAGE}.messages.ack.RotateAck",
HANGUP: f"{PROTOCOL_PACKAGE}.messages.hangup.Hangup",
PROBLEM_REPORT: f"{PROTOCOL_PACKAGE}.messages.problem_report.ProblemReport",
PROBLEM_REPORT: f"{PROTOCOL_PACKAGE}.messages.problem_report.RotateProblemReport",
}
)
17 changes: 16 additions & 1 deletion acapy_agent/protocols/did_rotate/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from ....messaging.models.openapi import OpenAPISchema
from ....messaging.valid import DID_WEB_EXAMPLE, UUID4_EXAMPLE
from ....storage.error import StorageNotFoundError
from .manager import DIDRotateManager
from .manager import (
DIDRotateManager,
UnresolvableDIDCommServicesError,
UnresolvableDIDError,
UnsupportedDIDMethodError,
)
from .message_types import SPEC_URI
from .messages.hangup import HangupSchema as HangupMessageSchema
from .messages.rotate import RotateSchema as RotateMessageSchema
Expand Down Expand Up @@ -63,6 +68,16 @@ async def rotate(request: web.BaseRequest):
body = await request.json()
to_did = body["to_did"]

# Validate DID before proceeding
try:
await did_rotate_mgr.ensure_supported_did(to_did)
except (
UnsupportedDIDMethodError,
UnresolvableDIDError,
UnresolvableDIDCommServicesError,
) as err:
raise web.HTTPBadRequest(reason=str(err)) from err

async with context.session() as session:
try:
conn = await ConnRecord.retrieve_by_id(session, connection_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def test_receive_rotate_x(self):

with (
mock.patch.object(
self.manager, "_ensure_supported_did", side_effect=test_problem_report
self.manager, "ensure_supported_did", side_effect=test_problem_report
),
mock.patch.object(self.responder, "send", mock.CoroutineMock()) as mock_send,
):
Expand Down
35 changes: 33 additions & 2 deletions acapy_agent/protocols/did_rotate/v1_0/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ async def asyncSetUp(self):
"DIDRotateManager",
autospec=True,
return_value=mock.MagicMock(
rotate_my_did=mock.CoroutineMock(return_value=generate_mock_rotate_message())
rotate_my_did=mock.CoroutineMock(return_value=generate_mock_rotate_message()),
ensure_supported_did=mock.CoroutineMock(),
),
)
async def test_rotate(self, *_):
Expand Down Expand Up @@ -102,7 +103,15 @@ async def test_hangup(self, *_):
}
)

async def test_rotate_conn_not_found(self):
@mock.patch.object(
test_module,
"DIDRotateManager",
autospec=True,
return_value=mock.MagicMock(
ensure_supported_did=mock.CoroutineMock(),
),
)
async def test_rotate_conn_not_found(self, *_):
self.request.match_info = {"conn_id": test_conn_id}
self.request.json = mock.CoroutineMock(return_value=test_valid_rotate_request)

Expand All @@ -114,6 +123,28 @@ async def test_rotate_conn_not_found(self):
with self.assertRaises(test_module.web.HTTPNotFound):
await test_module.rotate(self.request)

async def test_rotate_did_validation_errors(self):
self.request.match_info = {"conn_id": test_conn_id}
self.request.json = mock.CoroutineMock(return_value=test_valid_rotate_request)

for error_class in [
test_module.UnsupportedDIDMethodError,
test_module.UnresolvableDIDError,
test_module.UnresolvableDIDCommServicesError,
]:
with mock.patch.object(
test_module,
"DIDRotateManager",
autospec=True,
return_value=mock.MagicMock(
ensure_supported_did=mock.CoroutineMock(
side_effect=error_class("test error")
),
),
):
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.rotate(self.request)


if __name__ == "__main__":
unittest.main()
Loading