diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py index 30ddd66fa3..b7cb04e35d 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..messages.menu import Menu from ..util import save_connection_menu @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("MenuHandler called with context %s", context) assert isinstance(context.message, Menu) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received action menu: %s", context.message) await save_connection_menu( diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py index 905e039ce8..4aef949495 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/menu_request_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..base_service import BaseMenuService from ..messages.menu_request import MenuRequest @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("MenuRequestHandler called with context %s", context) assert isinstance(context.message, MenuRequest) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received action menu request") service: BaseMenuService = context.inject_or(BaseMenuService) diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py index 5e38bc90f3..4f6a5b1387 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/perform_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..base_service import BaseMenuService from ..messages.perform import Perform @@ -23,6 +23,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("PerformHandler called with context %s", context) assert isinstance(context.message, Perform) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received action menu perform request") service: BaseMenuService = context.inject_or(BaseMenuService) diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py index 4034bf5ea0..392bc3f8cd 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_handler.py @@ -1,9 +1,9 @@ from unittest import IsolatedAsyncioTestCase + from aries_cloudagent.tests import mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from .. import menu_handler as handler @@ -12,6 +12,7 @@ async def test_called(self): request_context = RequestContext.test_context() request_context.connection_record = mock.MagicMock() request_context.connection_record.connection_id = "dummy" + request_context.connection_ready = True handler.save_connection_menu = mock.CoroutineMock() responder = MockResponder() diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py index 30d97e65f4..63214fe409 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_menu_request_handler.py @@ -1,9 +1,9 @@ from unittest import IsolatedAsyncioTestCase + from aries_cloudagent.tests import mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from .. import menu_request_handler as handler @@ -18,6 +18,7 @@ async def test_called(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.MenuRequest() @@ -39,6 +40,7 @@ async def test_called_no_active_menu(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.MenuRequest() diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py index a8decf96ab..7af6672ee8 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/handlers/tests/test_perform_handler.py @@ -1,9 +1,9 @@ from unittest import IsolatedAsyncioTestCase + from aries_cloudagent.tests import mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from .. import perform_handler as handler @@ -18,12 +18,11 @@ async def test_called(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.Perform() - self.menu_service.perform_menu_action = mock.CoroutineMock( - return_value="perform" - ) + self.menu_service.perform_menu_action = mock.CoroutineMock(return_value="perform") handler_inst = handler.PerformHandler() await handler_inst.handle(self.context, responder) @@ -41,6 +40,7 @@ async def test_called_no_active_menu(self): self.context.connection_record = mock.MagicMock() self.context.connection_record.connection_id = "dummy" + self.context.connection_ready = True responder = MockResponder() self.context.message = handler.Perform() diff --git a/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py b/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py index 93bd91760f..286e62a63e 100644 --- a/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py +++ b/aries_cloudagent/protocols/basicmessage/v1_0/handlers/basicmessage_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..messages.basicmessage import BasicMessage @@ -22,6 +22,9 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("BasicMessageHandler called with context %s", context) assert isinstance(context.message, BasicMessage) + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info("Received basic message: %s", context.message.content) body = context.message.content diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py index 7c5a313fc4..9fe504873b 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/disclose_handler.py @@ -3,10 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, - RequestContext, HandlerException, + RequestContext, ) - from ..manager import V10DiscoveryMgr from ..messages.disclose import Disclose @@ -18,10 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("DiscloseHandler called with context %s", context) assert isinstance(context.message, Disclose) + if not context.connection_ready: raise HandlerException( "Received disclosures message from inactive connection" ) + profile = context.profile mgr = V10DiscoveryMgr(profile) await mgr.receive_disclose( diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py index 0336b01351..c0276d5272 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/query_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..manager import V10DiscoveryMgr from ..messages.query import Query @@ -17,6 +17,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("QueryHandler called with context %s", context) assert isinstance(context.message, Query) + + if not context.connection_ready: + raise HandlerException("No connection established") + profile = context.profile mgr = V10DiscoveryMgr(profile) reply = await mgr.receive_query(context.message) diff --git a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py index 10ab3af132..348bce1107 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py +++ b/aries_cloudagent/protocols/discovery/v1_0/handlers/tests/test_query_handler.py @@ -5,7 +5,6 @@ from ......core.protocol_registry import ProtocolRegistry from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder - from ...handlers.query_handler import QueryHandler from ...messages.disclose import Disclose from ...messages.query import Query @@ -30,6 +29,7 @@ async def test_query_all(self, request_context): query_msg = Query(query="*") query_msg.assign_thread_id("test123") request_context.message = query_msg + request_context.connection_ready = True handler = QueryHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -50,6 +50,7 @@ async def test_query_all_disclose_list_settings(self, request_context): query_msg = Query(query="*") query_msg.assign_thread_id("test123") request_context.message = query_msg + request_context.connection_ready = True handler = QueryHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -65,6 +66,7 @@ async def test_receive_query_process_disclosed(self, request_context): query_msg = Query(query="*") query_msg.assign_thread_id("test123") request_context.message = query_msg + request_context.connection_ready = True handler = QueryHandler() responder = MockResponder() with mock.patch.object( diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py index adab14bf3e..6b9f4047ba 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/disclosures_handler.py @@ -3,10 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, - RequestContext, HandlerException, + RequestContext, ) - from ..manager import V20DiscoveryMgr from ..messages.disclosures import Disclosures @@ -18,10 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("DiscloseHandler called with context %s", context) assert isinstance(context.message, Disclosures) + if not context.connection_ready: raise HandlerException( "Received disclosures message from inactive connection" ) + profile = context.profile mgr = V20DiscoveryMgr(profile) await mgr.receive_disclose( diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py index e970ad0c6f..95c26d8866 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/queries_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..manager import V20DiscoveryMgr from ..messages.queries import Queries @@ -17,6 +17,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): """Message handler implementation.""" self._logger.debug("QueryHandler called with context %s", context) assert isinstance(context.message, Queries) + + if not context.connection_ready: + raise HandlerException("No connection established") + profile = context.profile mgr = V20DiscoveryMgr(profile) reply = await mgr.receive_query(context.message) diff --git a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py index 1998a973e0..4fcaba2c4a 100644 --- a/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py +++ b/aries_cloudagent/protocols/discovery/v2_0/handlers/tests/test_queries_handler.py @@ -1,9 +1,11 @@ +from typing import Generator + import pytest from aries_cloudagent.tests import mock -from ......core.protocol_registry import ProtocolRegistry from ......core.goal_code_registry import GoalCodeRegistry +from ......core.protocol_registry import ProtocolRegistry from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder from ......protocols.issue_credential.v1_0.controller import ( @@ -16,7 +18,6 @@ from ......protocols.present_proof.v1_0.message_types import ( CONTROLLERS as pres_proof_v1_controller, ) - from ...handlers.queries_handler import QueriesHandler from ...manager import V20DiscoveryMgr from ...messages.disclosures import Disclosures @@ -27,7 +28,7 @@ @pytest.fixture() -def request_context() -> RequestContext: +def request_context() -> Generator[RequestContext, None, None]: ctx = RequestContext.test_context() protocol_registry = ProtocolRegistry() goal_code_registry = GoalCodeRegistry() @@ -48,6 +49,7 @@ async def test_queries_all(self, request_context): queries = Queries(queries=test_queries) queries.assign_thread_id("test123") request_context.message = queries + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -67,6 +69,7 @@ async def test_queries_protocol_goal_code_all(self, request_context): queries = Queries(queries=test_queries) queries.assign_thread_id("test123") request_context.message = queries + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -105,6 +108,7 @@ async def test_queries_protocol_goal_code_all_disclose_list_settings( queries = Queries(queries=test_queries) queries.assign_thread_id("test123") request_context.message = queries + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() await handler.handle(request_context, responder) @@ -129,6 +133,7 @@ async def test_receive_query_process_disclosed(self, request_context): queries_msg = Queries(queries=test_queries) queries_msg.assign_thread_id("test123") request_context.message = queries_msg + request_context.connection_ready = True handler = QueriesHandler() responder = MockResponder() with mock.patch.object( diff --git a/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py b/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py index 35e1d50bed..f43b1093cc 100644 --- a/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py +++ b/aries_cloudagent/protocols/endorse_transaction/v1_0/handlers/transaction_job_to_send_handler.py @@ -3,9 +3,9 @@ from .....messaging.base_handler import ( BaseHandler, BaseResponder, + HandlerException, RequestContext, ) - from ..manager import TransactionManager, TransactionManagerError from ..messages.transaction_job_to_send import TransactionJobToSend @@ -24,10 +24,11 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug(f"TransactionJobToSendHandler called with context {context}") assert isinstance(context.message, TransactionJobToSend) + if not context.connection_ready: + raise HandlerException("No connection established") + mgr = TransactionManager(context.profile) try: - await mgr.set_transaction_their_job( - context.message, context.message_receipt - ) + await mgr.set_transaction_their_job(context.message, context.message_receipt) except TransactionManagerError: self._logger.exception("Error receiving transaction jobs") diff --git a/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py b/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py index e5f1fedc77..71ba1b29a3 100644 --- a/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py +++ b/aries_cloudagent/protocols/notification/v1_0/handlers/ack_handler.py @@ -1,10 +1,9 @@ """Generic ack message handler.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder -from .....utils.tracing import trace_event, get_timer - +from .....utils.tracing import get_timer, trace_event from ..messages.ack import V10Ack @@ -22,6 +21,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("V20PresAckHandler called with context %s", context) assert isinstance(context.message, V10Ack) + + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.info( "Received v1.0 notification ack message: %s", context.message.serialize(as_string=True), diff --git a/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py b/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py index 2389b5aade..59b9effd43 100644 --- a/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py +++ b/aries_cloudagent/protocols/notification/v1_0/handlers/tests/test_ack_handler.py @@ -1,12 +1,9 @@ -from unittest import mock -from unittest import IsolatedAsyncioTestCase +from unittest import IsolatedAsyncioTestCase, mock from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder from ......transport.inbound.receipt import MessageReceipt - from ...messages.ack import V10Ack - from .. import ack_handler as test_module @@ -15,6 +12,7 @@ async def test_called(self): request_context = RequestContext.test_context() request_context.message_receipt = MessageReceipt() request_context.connection_record = mock.MagicMock() + request_context.connection_ready = True request_context.message = V10Ack(status="OK") handler = test_module.V10AckHandler() diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py index a312a2af1f..d5851df485 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/reuse_handler.py @@ -1,9 +1,8 @@ """Handshake Reuse Message Handler under RFC 0434.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder - from ..manager import OutOfBandManager, OutOfBandManagerError from ..messages.reuse import HandshakeReuse @@ -18,11 +17,12 @@ async def handle(self, context: RequestContext, responder: BaseResponder): context: Request context responder: Responder callback """ - self._logger.debug( - f"HandshakeReuseMessageHandler called with context {context}" - ) + self._logger.debug(f"HandshakeReuseMessageHandler called with context {context}") assert isinstance(context.message, HandshakeReuse) + if not context.connection_ready: + raise HandlerException("No connection established") + profile = context.profile mgr = OutOfBandManager(profile) try: diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py index 2e6982aef2..9a958b11ed 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/handlers/tests/test_reuse_handler.py @@ -1,4 +1,7 @@ """Test Reuse Message Handler.""" + +from typing import AsyncGenerator, Generator + import pytest from aries_cloudagent.tests import mock @@ -8,7 +11,6 @@ from ......messaging.request_context import RequestContext from ......messaging.responder import MockResponder from ......transport.inbound.receipt import MessageReceipt - from ...handlers import reuse_handler as test_module from ...manager import OutOfBandManagerError from ...messages.reuse import HandshakeReuse @@ -16,14 +18,14 @@ @pytest.fixture() -async def request_context() -> RequestContext: +def request_context() -> Generator[RequestContext, None, None]: ctx = RequestContext.test_context() ctx.message_receipt = MessageReceipt() yield ctx @pytest.fixture() -async def session(request_context) -> ProfileSession: +async def session(request_context) -> AsyncGenerator[ProfileSession, None]: yield await request_context.session() @@ -35,6 +37,7 @@ async def test_called(self, mock_oob_mgr, request_context): request_context.message = HandshakeReuse() handler = test_module.HandshakeReuseMessageHandler() request_context.connection_record = ConnRecord() + request_context.connection_ready = True responder = MockResponder() await handler.handle(request_context, responder) mock_oob_mgr.return_value.receive_reuse_message.assert_called_once_with( @@ -52,6 +55,7 @@ async def test_reuse_accepted(self, mock_oob_mgr, request_context): request_context.message = HandshakeReuse() handler = test_module.HandshakeReuseMessageHandler() request_context.connection_record = ConnRecord() + request_context.connection_ready = True responder = MockResponder() await handler.handle(request_context, responder) mock_oob_mgr.return_value.receive_reuse_message.assert_called_once_with( @@ -72,6 +76,7 @@ async def test_exception( request_context.message = HandshakeReuse() handler = test_module.HandshakeReuseMessageHandler() request_context.connection_record = ConnRecord() + request_context.connection_ready = True responder = MockResponder() with caplog.at_level("ERROR"): await handler.handle(request_context, responder) diff --git a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py index 366cbcd6c2..88b804b30a 100644 --- a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/revoke_handler.py @@ -1,9 +1,8 @@ """Handler for revoke message.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder - from ..messages.revoke import Revoke @@ -16,6 +15,10 @@ class RevokeHandler(BaseHandler): async def handle(self, context: RequestContext, responder: BaseResponder): """Handle revoke message.""" assert isinstance(context.message, Revoke) + + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.debug( "Received notification of revocation for cred issued in thread %s " "with comment: %s", diff --git a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py index 29c90c0692..8154ceeb84 100644 --- a/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v1_0/handlers/tests/test_revoke_handler.py @@ -1,12 +1,14 @@ """Test RevokeHandler.""" +from typing import Generator + import pytest from ......core.event_bus import EventBus, MockEventBus from ......core.in_memory import InMemoryProfile from ......core.profile import Profile from ......messaging.request_context import RequestContext -from ......messaging.responder import MockResponder, BaseResponder +from ......messaging.responder import BaseResponder, MockResponder from ...messages.revoke import Revoke from ..revoke_handler import RevokeHandler @@ -32,7 +34,7 @@ def message(): @pytest.fixture -def context(profile: Profile, message: Revoke): +def context(profile: Profile, message: Revoke) -> Generator[RequestContext, None, None]: request_context = RequestContext(profile) request_context.message = message yield request_context @@ -42,6 +44,7 @@ def context(profile: Profile, message: Revoke): async def test_handle( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): + context.connection_ready = True await RevokeHandler().handle(context, responder) assert event_bus.events [(_, received)] = event_bus.events @@ -55,6 +58,7 @@ async def test_handle_monitor( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): context.settings["revocation.monitor_notification"] = True + context.connection_ready = True await RevokeHandler().handle(context, responder) [(_, webhook), (_, received)] = event_bus.events diff --git a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py index f2ffafe7e0..2b6bd23ba1 100644 --- a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/revoke_handler.py @@ -1,9 +1,8 @@ """Handler for revoke message.""" -from .....messaging.base_handler import BaseHandler +from .....messaging.base_handler import BaseHandler, HandlerException from .....messaging.request_context import RequestContext from .....messaging.responder import BaseResponder - from ..messages.revoke import Revoke @@ -16,6 +15,10 @@ class RevokeHandler(BaseHandler): async def handle(self, context: RequestContext, responder: BaseResponder): """Handle revoke message.""" assert isinstance(context.message, Revoke) + + if not context.connection_ready: + raise HandlerException("No connection established") + self._logger.debug( "Received notification of revocation for %s cred %s with comment: %s", context.message.revocation_format, diff --git a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py index 110b47dee0..d6bdc30f57 100644 --- a/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py +++ b/aries_cloudagent/protocols/revocation_notification/v2_0/handlers/tests/test_revoke_handler.py @@ -1,12 +1,14 @@ """Test RevokeHandler.""" +from typing import Generator + import pytest from ......core.event_bus import EventBus, MockEventBus from ......core.in_memory import InMemoryProfile from ......core.profile import Profile from ......messaging.request_context import RequestContext -from ......messaging.responder import MockResponder, BaseResponder +from ......messaging.responder import BaseResponder, MockResponder from ...messages.revoke import Revoke from ..revoke_handler import RevokeHandler @@ -36,7 +38,7 @@ def message(): @pytest.fixture -def context(profile: Profile, message: Revoke): +def context(profile: Profile, message: Revoke) -> Generator[RequestContext, None, None]: request_context = RequestContext(profile) request_context.message = message yield request_context @@ -46,6 +48,7 @@ def context(profile: Profile, message: Revoke): async def test_handle( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): + context.connection_ready = True await RevokeHandler().handle(context, responder) assert event_bus.events [(_, received)] = event_bus.events @@ -60,6 +63,7 @@ async def test_handle_monitor( context: RequestContext, responder: BaseResponder, event_bus: MockEventBus ): context.settings["revocation.monitor_notification"] = True + context.connection_ready = True await RevokeHandler().handle(context, responder) [(_, webhook), (_, received)] = event_bus.events