diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index 263a44b005..e31eacdd61 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -115,7 +115,7 @@ async def setup(self): # Register all outbound transports self.outbound_transport_manager = OutboundTransportManager( - context, self.handle_not_delivered + self.root_profile, self.handle_not_delivered ) await self.outbound_transport_manager.setup() diff --git a/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/handler.py b/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/handler.py index c0fffb53ee..1c42e97e5f 100644 --- a/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/handler.py +++ b/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/handler.py @@ -1,5 +1,7 @@ """V2.0 issue-credential linked data proof credential format handler.""" +from ......vc.ld_proofs.error import LinkedDataProofException +from ......vc.ld_proofs.check import get_properties_without_context import logging from typing import Mapping @@ -399,6 +401,17 @@ async def create_offer( detail = LDProofVCDetail.deserialize(offer_data) detail = await self._prepare_detail(detail) + document_loader = self.profile.inject(DocumentLoader) + missing_properties = get_properties_without_context( + detail.credential.serialize(), document_loader + ) + + if len(missing_properties) > 0: + raise LinkedDataProofException( + f"{len(missing_properties)} attributes dropped. " + f"Provide definitions in context to correct. {missing_properties}" + ) + # Make sure we can issue with the did and proof type await self._assert_can_issue_with_id_and_proof_type( detail.credential.issuer_id, detail.options.proof_type diff --git a/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/tests/test_handler.py b/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/tests/test_handler.py index 730bff1f0d..c4666a679e 100644 --- a/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/tests/test_handler.py +++ b/aries_cloudagent/protocols/issue_credential/v2_0/formats/ld_proof/tests/test_handler.py @@ -1,6 +1,8 @@ from copy import deepcopy +from .......vc.ld_proofs.error import LinkedDataProofException from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock +from unittest.mock import patch from marshmallow import ValidationError from .. import handler as test_module @@ -406,7 +408,9 @@ async def test_create_offer(self): LDProofCredFormatHandler, "_assert_can_issue_with_id_and_proof_type", async_mock.CoroutineMock(), - ) as mock_can_issue: + ) as mock_can_issue, patch.object( + test_module, "get_properties_without_context", return_value=[] + ): (cred_format, attachment) = await self.handler.create_offer( self.cred_proposal ) @@ -444,7 +448,7 @@ async def test_create_offer_adds_bbs_context(self): LDProofCredFormatHandler, "_assert_can_issue_with_id_and_proof_type", async_mock.CoroutineMock(), - ): + ), patch.object(test_module, "get_properties_without_context", return_value=[]): (cred_format, attachment) = await self.handler.create_offer(cred_proposal) # assert BBS url added to context @@ -457,6 +461,27 @@ async def test_create_offer_x_no_proposal(self): context.exception ) + async def test_create_offer_x_wrong_attributes(self): + missing_properties = ["foo"] + with async_mock.patch.object( + LDProofCredFormatHandler, + "_assert_can_issue_with_id_and_proof_type", + async_mock.CoroutineMock(), + ), patch.object( + test_module, + "get_properties_without_context", + return_value=missing_properties, + ), self.assertRaises( + LinkedDataProofException + ) as context: + await self.handler.create_offer(self.cred_proposal) + + assert ( + f"{len(missing_properties)} attributes dropped. " + f"Provide definitions in context to correct. {missing_properties}" + in str(context.exception) + ) + async def test_receive_offer(self): cred_ex_record = async_mock.MagicMock() cred_offer_message = async_mock.MagicMock() diff --git a/aries_cloudagent/protocols/issue_credential/v2_0/routes.py b/aries_cloudagent/protocols/issue_credential/v2_0/routes.py index 3eb05e8732..cd27db11f4 100644 --- a/aries_cloudagent/protocols/issue_credential/v2_0/routes.py +++ b/aries_cloudagent/protocols/issue_credential/v2_0/routes.py @@ -1,5 +1,6 @@ """Credential exchange admin routes.""" +from ....vc.ld_proofs.error import LinkedDataProofException from json.decoder import JSONDecodeError from typing import Mapping @@ -1060,6 +1061,8 @@ async def credential_exchange_send_bound_offer(request: web.BaseRequest): cred_ex_record, outbound_handler, ) + except LinkedDataProofException as err: + raise web.HTTPBadRequest(reason=err) from err await outbound_handler(cred_offer_message, connection_id=connection_id) diff --git a/aries_cloudagent/protocols/issue_credential/v2_0/tests/test_routes.py b/aries_cloudagent/protocols/issue_credential/v2_0/tests/test_routes.py index 40cc486c79..23c00f4b3f 100644 --- a/aries_cloudagent/protocols/issue_credential/v2_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/issue_credential/v2_0/tests/test_routes.py @@ -1,3 +1,4 @@ +from .....vc.ld_proofs.error import LinkedDataProofException from asynctest import mock as async_mock, TestCase as AsyncTestCase from .....admin.request_context import AdminRequestContext @@ -802,6 +803,37 @@ async def test_credential_exchange_send_bound_offer(self): mock_response.assert_called_once_with(mock_cx_rec.serialize.return_value) + async def test_credential_exchange_send_bound_offer_linked_data_error(self): + self.request.json = async_mock.CoroutineMock(return_value={}) + self.request.match_info = {"cred_ex_id": "dummy"} + + with async_mock.patch.object( + test_module, "ConnRecord", autospec=True + ) as mock_conn_rec, async_mock.patch.object( + test_module, "V20CredManager", autospec=True + ) as mock_cred_mgr, async_mock.patch.object( + test_module, "V20CredExRecord", autospec=True + ) as mock_cx_rec_cls, async_mock.patch.object( + test_module.web, "json_response" + ) as mock_response: + mock_cx_rec_cls.retrieve_by_id = async_mock.CoroutineMock() + mock_cx_rec_cls.retrieve_by_id.return_value.state = ( + test_module.V20CredExRecord.STATE_PROPOSAL_RECEIVED + ) + + mock_cred_mgr.return_value.create_offer = async_mock.CoroutineMock() + + mock_cx_rec = async_mock.MagicMock() + + exception_message = "ex" + mock_cred_mgr.return_value.create_offer.side_effect = ( + LinkedDataProofException(exception_message) + ) + with self.assertRaises(test_module.web.HTTPBadRequest) as error: + await test_module.credential_exchange_send_bound_offer(self.request) + + assert exception_message in str(error.exception) + async def test_credential_exchange_send_bound_offer_bad_cred_ex_id(self): self.request.json = async_mock.CoroutineMock(return_value={}) self.request.match_info = {"cred_ex_id": "dummy"} diff --git a/aries_cloudagent/transport/outbound/base.py b/aries_cloudagent/transport/outbound/base.py index 1a7bff9dbc..467ef69d55 100644 --- a/aries_cloudagent/transport/outbound/base.py +++ b/aries_cloudagent/transport/outbound/base.py @@ -14,10 +14,13 @@ class BaseOutboundTransport(ABC): """Base outbound transport class.""" - def __init__(self, wire_format: BaseWireFormat = None) -> None: + def __init__( + self, wire_format: BaseWireFormat = None, root_profile: Profile = None + ) -> None: """Initialize a `BaseOutboundTransport` instance.""" self._collector = None self._wire_format = wire_format + self.root_profile = root_profile @property def collector(self) -> Collector: diff --git a/aries_cloudagent/transport/outbound/http.py b/aries_cloudagent/transport/outbound/http.py index 5bd2069c28..363e59ee2a 100644 --- a/aries_cloudagent/transport/outbound/http.py +++ b/aries_cloudagent/transport/outbound/http.py @@ -18,9 +18,9 @@ class HttpTransport(BaseOutboundTransport): schemes = ("http", "https") - def __init__(self) -> None: + def __init__(self, **kwargs) -> None: """Initialize an `HttpTransport` instance.""" - super().__init__() + super().__init__(**kwargs) self.client_session: ClientSession = None self.connector: TCPConnector = None self.logger = logging.getLogger(__name__) diff --git a/aries_cloudagent/transport/outbound/manager.py b/aries_cloudagent/transport/outbound/manager.py index abdea326aa..18fdf4b182 100644 --- a/aries_cloudagent/transport/outbound/manager.py +++ b/aries_cloudagent/transport/outbound/manager.py @@ -9,7 +9,6 @@ from urllib.parse import urlparse from ...connections.models.connection_target import ConnectionTarget -from ...config.injection_context import InjectionContext from ...core.profile import Profile from ...utils.classloader import ClassLoader, ModuleLoadError, ClassNotFoundError from ...utils.stats import Collector @@ -68,18 +67,16 @@ class OutboundTransportManager: MAX_RETRY_COUNT = 4 - def __init__( - self, context: InjectionContext, handle_not_delivered: Callable = None - ): + def __init__(self, profile: Profile, handle_not_delivered: Callable = None): """ Initialize a `OutboundTransportManager` instance. Args: - context: The application context + root_profile: The application root profile handle_not_delivered: An optional handler for undelivered messages """ - self.context = context + self.root_profile = profile self.loop = asyncio.get_event_loop() self.handle_not_delivered = handle_not_delivered self.outbound_buffer = [] @@ -90,13 +87,15 @@ def __init__( self.running_transports = {} self.task_queue = TaskQueue(max_active=200) self._process_task: asyncio.Task = None - if self.context.settings.get("transport.max_outbound_retry"): - self.MAX_RETRY_COUNT = self.context.settings["transport.max_outbound_retry"] + if self.root_profile.settings.get("transport.max_outbound_retry"): + self.MAX_RETRY_COUNT = self.root_profile.settings[ + "transport.max_outbound_retry" + ] async def setup(self): """Perform setup operations.""" outbound_transports = ( - self.context.settings.get("transport.outbound_configs") or [] + self.root_profile.settings.get("transport.outbound_configs") or [] ) for outbound_transport in outbound_transports: self.register(outbound_transport) @@ -172,8 +171,10 @@ def register_class( async def start_transport(self, transport_id: str): """Start a registered transport.""" - transport = self.registered_transports[transport_id]() - transport.collector = self.context.inject_or(Collector) + transport = self.registered_transports[transport_id]( + root_profile=self.root_profile + ) + transport.collector = self.root_profile.inject_or(Collector) await transport.start() self.running_transports[transport_id] = transport @@ -379,14 +380,14 @@ async def _process_loop(self): if deliver: queued.state = QueuedOutboundMessage.STATE_DELIVER p_time = trace_event( - self.context.settings, + self.root_profile.settings, queued.message if queued.message else queued.payload, outcome="OutboundTransportManager.DELIVER.START." + queued.endpoint, ) self.deliver_queued_message(queued) trace_event( - self.context.settings, + self.root_profile.settings, queued.message if queued.message else queued.payload, outcome="OutboundTransportManager.DELIVER.END." + queued.endpoint, @@ -408,13 +409,13 @@ async def _process_loop(self): else: queued.state = QueuedOutboundMessage.STATE_ENCODE p_time = trace_event( - self.context.settings, + self.root_profile.settings, queued.message if queued.message else queued.payload, outcome="OutboundTransportManager.ENCODE.START", ) self.encode_queued_message(queued) trace_event( - self.context.settings, + self.root_profile.settings, queued.message if queued.message else queued.payload, outcome="OutboundTransportManager.ENCODE.END", perf_counter=p_time, @@ -449,7 +450,7 @@ async def perform_encode( self, queued: QueuedOutboundMessage, wire_format: BaseWireFormat = None ): """Perform message encoding.""" - wire_format = wire_format or self.context.inject(BaseWireFormat) + wire_format = wire_format or self.root_profile.inject(BaseWireFormat) session = await queued.profile.session() queued.payload = await wire_format.encode_message( diff --git a/aries_cloudagent/transport/outbound/tests/test_manager.py b/aries_cloudagent/transport/outbound/tests/test_manager.py index 9ab95b2b8a..834184b208 100644 --- a/aries_cloudagent/transport/outbound/tests/test_manager.py +++ b/aries_cloudagent/transport/outbound/tests/test_manager.py @@ -2,7 +2,7 @@ from asynctest import TestCase as AsyncTestCase, mock as async_mock -from ....config.injection_context import InjectionContext +from ....core.in_memory import InMemoryProfile from ....connections.models.connection_target import ConnectionTarget from ....core.in_memory import InMemoryProfile from ...wire_format import BaseWireFormat @@ -19,7 +19,7 @@ class TestOutboundTransportManager(AsyncTestCase): def test_register_path(self): - mgr = OutboundTransportManager(InjectionContext()) + mgr = OutboundTransportManager(InMemoryProfile.test_profile()) mgr.register("http") assert mgr.get_registered_transport_for_scheme("http") assert mgr.MAX_RETRY_COUNT == 4 @@ -33,23 +33,21 @@ def test_register_path(self): mgr.register("no.such.module.path") def test_maximum_retry_count(self): - context = InjectionContext() - context.update_settings({"transport.max_outbound_retry": 5}) - mgr = OutboundTransportManager(context) + profile = InMemoryProfile.test_profile({"transport.max_outbound_retry": 5}) + mgr = OutboundTransportManager(profile) mgr.register("http") assert mgr.MAX_RETRY_COUNT == 5 async def test_setup(self): - context = InjectionContext() - context.update_settings({"transport.outbound_configs": ["http"]}) - mgr = OutboundTransportManager(context) + profile = InMemoryProfile.test_profile({"transport.outbound_configs": ["http"]}) + mgr = OutboundTransportManager(profile) with async_mock.patch.object(mgr, "register") as mock_register: await mgr.setup() mock_register.assert_called_once_with("http") async def test_send_message(self): - context = InjectionContext() - mgr = OutboundTransportManager(context) + profile = InMemoryProfile.test_profile() + mgr = OutboundTransportManager(profile) transport_cls = async_mock.Mock(spec=[]) with self.assertRaises(OutboundTransportRegistrationError): @@ -124,9 +122,8 @@ async def test_send_message(self): transport.stop.assert_awaited_once_with() async def test_stop_cancel(self): - context = InjectionContext() - context.update_settings({"transport.outbound_configs": ["http"]}) - mgr = OutboundTransportManager(context) + profile = InMemoryProfile.test_profile({"transport.outbound_configs": ["http"]}) + mgr = OutboundTransportManager(profile) mgr._process_task = async_mock.MagicMock( done=async_mock.MagicMock(return_value=False), cancel=async_mock.MagicMock() ) @@ -135,8 +132,8 @@ async def test_stop_cancel(self): mgr._process_task.cancel.assert_called_once() async def test_enqueue_webhook(self): - context = InjectionContext() - mgr = OutboundTransportManager(context) + profile = InMemoryProfile.test_profile() + mgr = OutboundTransportManager(profile) test_topic = "test-topic" test_payload = {"test": "payload"} test_endpoint_host = "http://example" @@ -173,8 +170,8 @@ async def test_process_done_x(self): done=async_mock.MagicMock(return_value=True), exception=async_mock.MagicMock(return_value=KeyError("No such key")), ) - context = InjectionContext() - mgr = OutboundTransportManager(context) + profile = InMemoryProfile.test_profile() + mgr = OutboundTransportManager(profile) with async_mock.patch.object( mgr, "_process_task", async_mock.MagicMock() @@ -187,8 +184,8 @@ async def test_process_finished_x(self): mock_task = async_mock.MagicMock( exc_info=(KeyError, KeyError("nope"), None), ) - context = InjectionContext() - mgr = OutboundTransportManager(context) + profile = InMemoryProfile.test_profile() + mgr = OutboundTransportManager(profile) with async_mock.patch.object( mgr, "process_queued", async_mock.MagicMock() @@ -203,9 +200,9 @@ async def test_process_loop_retry_now(self): retry_at=test_module.get_timer() - 1, ) - context = InjectionContext() + profile = InMemoryProfile.test_profile() mock_handle_not_delivered = async_mock.MagicMock() - mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr = OutboundTransportManager(profile, mock_handle_not_delivered) mgr.outbound_buffer.append(mock_queued) with async_mock.patch.object( @@ -222,9 +219,9 @@ async def test_process_loop_retry_later(self): retry_at=test_module.get_timer() + 3600, ) - context = InjectionContext() + profile = InMemoryProfile.test_profile() mock_handle_not_delivered = async_mock.MagicMock() - mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr = OutboundTransportManager(profile, mock_handle_not_delivered) mgr.outbound_buffer.append(mock_queued) with async_mock.patch.object( @@ -236,9 +233,9 @@ async def test_process_loop_retry_later(self): assert mock_queued.retry_at is not None async def test_process_loop_new(self): - context = InjectionContext() + profile = InMemoryProfile.test_profile() mock_handle_not_delivered = async_mock.MagicMock() - mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr = OutboundTransportManager(profile, mock_handle_not_delivered) mgr.outbound_new = [ async_mock.MagicMock( @@ -259,9 +256,9 @@ async def test_process_loop_new(self): await mgr._process_loop() async def test_process_loop_new_deliver(self): - context = InjectionContext() + profile = InMemoryProfile.test_profile() mock_handle_not_delivered = async_mock.MagicMock() - mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr = OutboundTransportManager(profile, mock_handle_not_delivered) mgr.outbound_new = [ async_mock.MagicMock( @@ -289,9 +286,9 @@ async def test_process_loop_x(self): payload="Hello world", ) - context = InjectionContext() + profile = InMemoryProfile.test_profile() mock_handle_not_delivered = async_mock.MagicMock() - mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr = OutboundTransportManager(profile, mock_handle_not_delivered) mgr.outbound_buffer.append(mock_queued) await mgr._process_loop() @@ -302,9 +299,9 @@ async def test_finished_deliver_x_log_debug(self): ) mock_completed_x = async_mock.MagicMock(exc_info=KeyError("an error occurred")) - context = InjectionContext() + profile = InMemoryProfile.test_profile() mock_handle_not_delivered = async_mock.MagicMock() - mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr = OutboundTransportManager(profile, mock_handle_not_delivered) mgr.outbound_buffer.append(mock_queued) with async_mock.patch.object( test_module.LOGGER, "exception", async_mock.MagicMock() @@ -319,19 +316,17 @@ async def test_finished_deliver_x_log_debug(self): mgr.finished_deliver(mock_queued, mock_completed_x) async def test_should_encode_outbound_message(self): - context = InjectionContext() base_wire_format = BaseWireFormat() encoded_msg = "encoded_message" base_wire_format.encode_message = async_mock.CoroutineMock( return_value=encoded_msg ) - context.injector.bind_instance(BaseWireFormat, base_wire_format) - profile = InMemoryProfile.test_session().profile + profile = InMemoryProfile.test_profile(bind={BaseWireFormat: base_wire_format}) profile.session = async_mock.CoroutineMock(return_value=async_mock.MagicMock()) outbound = async_mock.MagicMock(payload="payload", enc_payload=None) target = async_mock.MagicMock() - mgr = OutboundTransportManager(context) + mgr = OutboundTransportManager(profile) result = await mgr.encode_outbound_message(profile, outbound, target) assert result.payload == encoded_msg @@ -344,13 +339,12 @@ async def test_should_encode_outbound_message(self): ) async def test_should_not_encode_already_packed_message(self): - context = InjectionContext() profile = InMemoryProfile.test_session().profile enc_payload = "enc_payload" outbound = async_mock.MagicMock(enc_payload=enc_payload) target = async_mock.MagicMock() - mgr = OutboundTransportManager(context) + mgr = OutboundTransportManager(profile) result = await mgr.encode_outbound_message(profile, outbound, target) assert result.payload == enc_payload diff --git a/aries_cloudagent/transport/outbound/ws.py b/aries_cloudagent/transport/outbound/ws.py index 06421a2f3f..4ffcfdac08 100644 --- a/aries_cloudagent/transport/outbound/ws.py +++ b/aries_cloudagent/transport/outbound/ws.py @@ -15,9 +15,9 @@ class WsTransport(BaseOutboundTransport): schemes = ("ws", "wss") - def __init__(self) -> None: + def __init__(self, **kwargs) -> None: """Initialize an `WsTransport` instance.""" - super().__init__() + super().__init__(**kwargs) self.logger = logging.getLogger(__name__) async def start(self):