diff --git a/aries_cloudagent/admin/server.py b/aries_cloudagent/admin/server.py index 73f5122a5a..c3163033d3 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -2,23 +2,22 @@ import asyncio import logging -from typing import Coroutine, Sequence, Set +from typing import Callable, Coroutine, Sequence, Set import uuid -from aiohttp import web, ClientSession, DummyCookieJar +from aiohttp import web from aiohttp_apispec import docs, response_schema, setup_aiohttp_apispec import aiohttp_cors from marshmallow import fields, Schema from ..config.injection_context import InjectionContext -from ..messaging.outbound_message import OutboundMessage from ..messaging.plugin_registry import PluginRegistry from ..messaging.responder import BaseResponder +from ..messaging.task_queue import TaskQueue from ..stats import Collector -from ..task_processor import TaskProcessor -from ..transport.outbound.queue.base import BaseOutboundMessageQueue -from ..transport.stats import StatsTracer +from ..transport.queue.basic import BasicMessageQueue +from ..transport.outbound.message import OutboundMessage from .base_server import BaseAdminServer from .error import AdminSetupError @@ -43,7 +42,9 @@ class AdminStatusSchema(Schema): class AdminResponder(BaseResponder): """Handle outgoing messages from message handlers.""" - def __init__(self, send: Coroutine, webhook: Coroutine, **kwargs): + def __init__( + self, context: InjectionContext, send: Coroutine, webhook: Coroutine, **kwargs + ): """ Initialize an instance of `AdminResponder`. @@ -52,6 +53,7 @@ def __init__(self, send: Coroutine, webhook: Coroutine, **kwargs): """ super().__init__(**kwargs) + self._context = context self._send = send self._webhook = webhook @@ -62,7 +64,7 @@ async def send_outbound(self, message: OutboundMessage): Args: message: The `OutboundMessage` to be sent """ - await self._send(message) + await self._send(self._context, message) async def send_webhook(self, topic: str, payload: dict): """ @@ -111,6 +113,9 @@ def __init__( port: int, context: InjectionContext, outbound_message_router: Coroutine, + webhook_router: Callable, + task_queue: TaskQueue = None, + conductor_stats: Coroutine = None, ): """ Initialize an AdminServer instance. @@ -118,23 +123,26 @@ def __init__( Args: host: Host to listen on port: Port to listen on - + context: The application context instance + outbound_message_router: Coroutine for delivering outbound messages + webhook_router: Callable for delivering webhooks + task_queue: An optional task queue for handlers """ self.app = None self.host = host self.port = port + self.conductor_stats = conductor_stats self.loaded_modules = [] - self.webhook_queue = None - self.webhook_retries = 5 - self.webhook_session: ClientSession = None + self.task_queue = task_queue + self.webhook_router = webhook_router self.webhook_targets = {} - self.webhook_task = None - self.webhook_processor: TaskProcessor = None self.websocket_queues = {} self.site = None self.context = context.start_scope("admin") - self.responder = AdminResponder(outbound_message_router, self.send_webhook) + self.responder = AdminResponder( + self.context, outbound_message_router, self.send_webhook + ) self.context.injector.bind_instance(BaseResponder, self.responder) async def make_application(self) -> web.Application: @@ -168,6 +176,15 @@ async def check_token(request, handler): middlewares.append(check_token) + if self.task_queue: + + @web.middleware + async def apply_limiter(request, handler): + task = await self.task_queue.put(handler(request)) + return await task + + middlewares.append(apply_limiter) + stats: Collector = await self.context.inject(Collector, required=False) if stats: @@ -187,14 +204,16 @@ async def collect_stats(request, handler): app.add_routes( [ web.get("/", self.redirect_handler), - web.get("/modules", self.modules_handler), + web.get("/plugins", self.plugins_handler), web.get("/status", self.status_handler), web.post("/status/reset", self.status_reset_handler), web.get("/ws", self.websocket_handler), ] ) - plugin_registry = await self.context.inject(PluginRegistry, required=False) + plugin_registry: PluginRegistry = await self.context.inject( + PluginRegistry, required=False + ) if plugin_registry: await plugin_registry.register_admin_routes(app) @@ -249,21 +268,15 @@ async def stop(self) -> None: if self.site: await self.site.stop() self.site = None - if self.webhook_queue: - self.webhook_queue.stop() - self.webhook_queue = None - if self.webhook_session: - await self.webhook_session.close() - self.webhook_session = None async def on_startup(self, app: web.Application): """Perform webserver startup actions.""" - @docs(tags=["server"], summary="Fetch the list of loaded modules") + @docs(tags=["server"], summary="Fetch the list of loaded plugins") @response_schema(AdminModulesSchema(), 200) - async def modules_handler(self, request: web.BaseRequest): + async def plugins_handler(self, request: web.BaseRequest): """ - Request handler for the loaded modules list. + Request handler for the loaded plugins list. Args: request: aiohttp request object @@ -272,7 +285,12 @@ async def modules_handler(self, request: web.BaseRequest): The module list response """ - return web.json_response({"result": self.loaded_modules}) + registry: PluginRegistry = await self.context.inject( + PluginRegistry, required=False + ) + print(registry) + plugins = registry and sorted(registry.plugin_names) or [] + return web.json_response({"result": plugins}) @docs(tags=["server"], summary="Fetch the server status") @response_schema(AdminStatusSchema(), 200) @@ -291,6 +309,8 @@ async def status_handler(self, request: web.BaseRequest): collector: Collector = await self.context.inject(Collector, required=False) if collector: status["timing"] = collector.results + if self.conductor_stats: + status["conductor"] = await self.conductor_stats() return web.json_response(status) @docs(tags=["server"], summary="Reset statistics") @@ -321,7 +341,8 @@ async def websocket_handler(self, request): ws = web.WebSocketResponse() await ws.prepare(request) socket_id = str(uuid.uuid4()) - queue = await self.context.inject(BaseOutboundMessageQueue) + queue = BasicMessageQueue() + loop = asyncio.get_event_loop() try: self.websocket_queues[socket_id] = queue @@ -340,20 +361,40 @@ async def websocket_handler(self, request): ) closed = False + receive = loop.create_task(ws.receive()) + send = loop.create_task(queue.dequeue(timeout=5.0)) + while not closed: try: - msg = await queue.dequeue(timeout=5.0) - if msg is None: - # we send fake pings because the JS client - # can't detect real ones - msg = {"topic": "ping"} + await asyncio.wait( + (receive, send), return_when=asyncio.FIRST_COMPLETED + ) if ws.closed: closed = True - if msg and not closed: - await ws.send_json(msg) + + if receive.done(): + # ignored + if not closed: + receive = loop.create_task(ws.receive()) + + if send.done(): + msg = send.result() + if msg is None: + # we send fake pings because the JS client + # can't detect real ones + msg = {"topic": "ping"} + if not closed: + if msg: + await ws.send_json(msg) + send = loop.create_task(queue.dequeue(timeout=5.0)) except asyncio.CancelledError: closed = True + if not receive.done(): + receive.cancel() + if not send.done(): + send.cancel() + finally: del self.websocket_queues[socket_id] @@ -374,65 +415,10 @@ def remove_webhook_target(self, target_url: str): async def send_webhook(self, topic: str, payload: dict): """Add a webhook to the queue, to send to all registered targets.""" - if not self.webhook_queue: - self.webhook_queue = await self.context.inject(BaseOutboundMessageQueue) - self.webhook_task = asyncio.get_event_loop().create_task( - self._process_webhooks() - ) - await self.webhook_queue.enqueue((topic, payload)) - - async def _process_webhooks(self): - """Continuously poll webhook queue and dispatch to targets.""" - session_args = {} - collector: Collector = await self.context.inject(Collector, required=False) - if collector: - session_args["trace_configs"] = [StatsTracer(collector, "webhook-http:")] - session_args["cookie_jar"] = DummyCookieJar() - self.webhook_session = ClientSession(**session_args) - self.webhook_processor = TaskProcessor(max_pending=20) - async for topic, payload in self.webhook_queue: - for queue in self.websocket_queues.values(): - await queue.enqueue({"topic": topic, "payload": payload}) - if self.webhook_targets: - targets = self.webhook_targets.copy() - for idx, target in targets.items(): - if not target.topic_filter or topic in target.topic_filter: - retries = ( - self.webhook_retries - if target.retries is None - else target.retries - ) - await self.webhook_processor.run_retry( - lambda pending: self._perform_send_webhook( - target.endpoint, topic, payload, pending.attempts + 1 - ), - ident=(target.endpoint, topic), - retries=retries, - ) - self.webhook_queue.task_done() - - async def _perform_send_webhook( - self, target_url: str, topic: str, payload: dict, attempt: int = None - ): - """Dispatch a webhook to a specific endpoint.""" - full_webhook_url = f"{target_url}/topic/{topic}/" - attempt_str = f" (attempt {attempt})" if attempt else "" - LOGGER.debug("Sending webhook to : %s%s", full_webhook_url, attempt_str) - async with self.webhook_session.post( - full_webhook_url, json=payload - ) as response: - if response.status < 200 or response.status > 299: - # raise Exception(f"Unexpected response status {response.status}") - raise Exception( - f"Unexpected: target {target_url}\n" - f"full {full_webhook_url}\n" - f"response {response}" - ) + if self.webhook_router: + for idx, target in self.webhook_targets.items(): + if not target.topic_filter or topic in target.topic_filter: + self.webhook_router(topic, payload, target.endpoint, target.retries) - async def complete_webhooks(self): - """Wait for all pending webhooks to be dispatched, used in testing.""" - if self.webhook_queue: - await self.webhook_queue.join() - self.webhook_queue.reset() - if self.webhook_processor: - await self.webhook_processor.wait_done() + for queue in self.websocket_queues.values(): + await queue.enqueue({"topic": topic, "payload": payload}) diff --git a/aries_cloudagent/admin/tests/test_admin_server.py b/aries_cloudagent/admin/tests/test_admin_server.py index d0f36e2c57..fc0effb03f 100644 --- a/aries_cloudagent/admin/tests/test_admin_server.py +++ b/aries_cloudagent/admin/tests/test_admin_server.py @@ -8,10 +8,9 @@ from ...config.default_context import DefaultContextBuilder from ...config.injection_context import InjectionContext from ...config.provider import ClassProvider -from ...messaging.outbound_message import OutboundMessage +from ...messaging.plugin_registry import PluginRegistry from ...messaging.protocol_registry import ProtocolRegistry -from ...transport.outbound.queue.base import BaseOutboundMessageQueue -from ...transport.outbound.queue.basic import BasicOutboundMessageQueue +from ...transport.outbound.message import OutboundMessage from ..server import AdminServer @@ -19,24 +18,29 @@ class TestAdminServerBasic(AsyncTestCase): async def setUp(self): self.message_results = [] + self.webhook_results = [] def get_admin_server( self, settings: dict = None, context: InjectionContext = None ) -> AdminServer: if not context: context = InjectionContext() - context.injector.bind_provider( - BaseOutboundMessageQueue, ClassProvider(BasicOutboundMessageQueue) - ) if settings: context.update_settings(settings) return AdminServer( - "0.0.0.0", unused_port(), context, self.outbound_message_router + "0.0.0.0", + unused_port(), + context, + self.outbound_message_router, + self.webhook_router, ) async def outbound_message_router(self, *args): self.message_results.append(args) + def webhook_router(self, *args): + self.webhook_results.append(args) + async def test_start_stop(self): with self.assertRaises(AssertionError): await self.get_admin_server().start() @@ -61,19 +65,23 @@ async def test_start_stop(self): await server.stop() async def test_responder_send(self): - message = OutboundMessage("{}") + message = OutboundMessage(payload="{}") admin_server = self.get_admin_server() await admin_server.responder.send_outbound(message) - assert self.message_results == [(message,)] + assert self.message_results == [(admin_server.context, message)] @unittest_run_loop async def test_responder_webhook(self): - with patch.object(AdminServer, "send_webhook", autospec=True) as sender: - admin_server = self.get_admin_server() - test_topic = "test_topic" - test_payload = {"test": "TEST"} - await admin_server.responder.send_webhook(test_topic, test_payload) - sender.assert_awaited_once_with(admin_server, test_topic, test_payload) + admin_server = self.get_admin_server() + test_url = "target_url" + test_retries = 99 + admin_server.add_webhook_target(test_url, retries=test_retries) + test_topic = "test_topic" + test_payload = {"test": "TEST"} + await admin_server.responder.send_webhook(test_topic, test_payload) + assert self.webhook_results == [ + (test_topic, test_payload, test_url, test_retries) + ] async def test_import_routes(self): # this test just imports all default admin routes @@ -86,8 +94,11 @@ async def test_import_routes(self): class TestAdminServerClient(AioHTTPTestCase): - async def setUpAsync(self): + def setUp(self): + self.admin_server = None self.message_results = [] + self.webhook_results = [] + super().setUp() async def get_application(self): """ @@ -98,16 +109,21 @@ async def get_application(self): async def outbound_message_router(self, *args): self.message_results.append(args) + def webhook_router(self, *args): + self.webhook_results.append(args) + def get_admin_server(self) -> AdminServer: - context = InjectionContext() - context.injector.bind_provider( - BaseOutboundMessageQueue, ClassProvider(BasicOutboundMessageQueue) - ) - context.settings["admin.admin_insecure_mode"] = True - server = AdminServer( - "0.0.0.0", unused_port(), context, self.outbound_message_router - ) - return server + if not self.admin_server: + context = InjectionContext() + context.settings["admin.admin_insecure_mode"] = True + self.admin_server = AdminServer( + "0.0.0.0", + unused_port(), + context, + self.outbound_message_router, + self.webhook_router, + ) + return self.admin_server # the unittest_run_loop decorator can be used in tandem with # the AioHTTPTestCase to simplify running @@ -124,6 +140,16 @@ async def test_swagger(self): text = await resp.text() assert "Swagger UI" in text + @unittest_run_loop + async def test_plugins(self): + test_registry = PluginRegistry() + test_plugin = "aries_cloudagent.protocols.trustping" + test_registry.register_plugin(test_plugin) + self.admin_server.context.injector.bind_instance(PluginRegistry, test_registry) + resp = await self.client.request("GET", "/plugins") + resp_dict = await resp.json() + assert test_plugin in resp_dict["result"] + @unittest_run_loop async def test_status(self): resp = await self.client.request("GET", "/status") @@ -149,16 +175,20 @@ async def get_application(self): return await self.get_admin_server().make_application() async def outbound_message_router(self, *args): - self.message_results.append(args) + raise Exception() + + def webhook_router(self, *args): + raise Exception() def get_admin_server(self) -> AdminServer: context = InjectionContext() - context.injector.bind_provider( - BaseOutboundMessageQueue, ClassProvider(BasicOutboundMessageQueue) - ) context.settings["admin.admin_api_key"] = self.TEST_API_KEY self.server = AdminServer( - "0.0.0.0", unused_port(), context, self.outbound_message_router + "0.0.0.0", + unused_port(), + context, + self.outbound_message_router, + self.webhook_router, ) return self.server @@ -187,16 +217,20 @@ async def receive_hook(self, request): raise web.HTTPOk() async def outbound_message_router(self, *args): - pass + raise Exception() + + def webhook_router(self, *args): + raise Exception() def get_admin_server(self) -> AdminServer: context = InjectionContext() - context.injector.bind_provider( - BaseOutboundMessageQueue, ClassProvider(BasicOutboundMessageQueue) - ) context.settings["admin.admin_insecure_mode"] = True server = AdminServer( - "0.0.0.0", unused_port(), context, self.outbound_message_router + "0.0.0.0", + unused_port(), + context, + self.outbound_message_router, + self.webhook_router, ) return server @@ -207,21 +241,3 @@ async def get_application(self): app = web.Application() app.add_routes([web.post("/topic/{topic}/", self.receive_hook)]) return app - - @unittest_run_loop - async def test_webhook(self): - server_addr = f"http://localhost:{self.server.port}" - admin_server = self.get_admin_server() - await admin_server.start() - - admin_server.add_webhook_target(server_addr) - test_topic = "test_topic" - test_payload = {"test": "TEST"} - await admin_server.send_webhook(test_topic, test_payload) - await asyncio.wait_for(admin_server.complete_webhooks(), 5.0) - assert self.hook_results == [(test_topic, test_payload)] - - admin_server.remove_webhook_target(server_addr) - assert admin_server.webhook_targets == {} - - await admin_server.stop() diff --git a/aries_cloudagent/conductor.py b/aries_cloudagent/conductor.py index d7e3a851bf..af7e32e4c5 100644 --- a/aries_cloudagent/conductor.py +++ b/aries_cloudagent/conductor.py @@ -8,13 +8,9 @@ """ -import asyncio -from collections import OrderedDict import hashlib import logging -from typing import Coroutine, Union -from .delivery_queue import DeliveryQueue from .admin.base_server import BaseAdminServer from .admin.server import AdminServer from .config.default_context import ContextBuilder @@ -24,18 +20,17 @@ from .config.wallet import wallet_config from .dispatcher import Dispatcher from .protocols.connections.manager import ConnectionManager, ConnectionManagerError -from .connections.models.connection_record import ConnectionRecord -from .messaging.error import MessageParseError, MessagePrepareError -from .messaging.outbound_message import OutboundMessage from .messaging.responder import BaseResponder -from .messaging.serializer import MessageSerializer -from .messaging.socket import SocketInfo, SocketRef +from .messaging.task_queue import CompletedTask, TaskQueue from .stats import Collector -from .storage.error import StorageNotFoundError -from .transport.inbound.base import InboundTransportConfiguration from .transport.inbound.manager import InboundTransportManager +from .transport.inbound.message import InboundMessage +from .transport.outbound.base import OutboundDeliveryError from .transport.outbound.manager import OutboundTransportManager -from .transport.outbound.queue.base import BaseOutboundMessageQueue +from .transport.outbound.message import OutboundMessage +from .transport.wire_format import BaseWireFormat + +LOGGER = logging.getLogger(__name__) class Conductor: @@ -60,57 +55,31 @@ def __init__(self, context_builder: ContextBuilder) -> None: self.context: InjectionContext = None self.context_builder = context_builder self.dispatcher: Dispatcher = None - self.logger = logging.getLogger(__name__) - self.message_serializer: MessageSerializer = None self.inbound_transport_manager: InboundTransportManager = None self.outbound_transport_manager: OutboundTransportManager = None - self.sockets = OrderedDict() - self.undelivered_queue: DeliveryQueue = None async def setup(self): """Initialize the global request context.""" context = await self.context_builder.build() - # Populate message serializer - self.message_serializer = await context.inject(MessageSerializer) + self.dispatcher = Dispatcher(context) - # Setup Delivery Queue - if context.settings.get("queue.enable_undelivered_queue"): - self.undelivered_queue = DeliveryQueue() + wire_format = await context.inject(BaseWireFormat, required=False) + if wire_format and hasattr(wire_format, "task_queue"): + wire_format.task_queue = self.dispatcher.task_queue # Register all inbound transports - self.inbound_transport_manager = InboundTransportManager() - inbound_transports = context.settings.get("transport.inbound_configs") or [] - for transport in inbound_transports: - try: - module, host, port = transport - self.inbound_transport_manager.register( - InboundTransportConfiguration(module=module, host=host, port=port), - self.inbound_message_router, - self.register_socket, - ) - except Exception: - self.logger.exception("Unable to register inbound transport") - raise - - # Fetch stats collector, if any - collector = await context.inject(Collector, required=False) + self.inbound_transport_manager = InboundTransportManager( + context, self.inbound_message_router, self.handle_not_returned, + ) + await self.inbound_transport_manager.setup() # Register all outbound transports - outbound_queue = await context.inject(BaseOutboundMessageQueue) - if collector: - collector.wrap(outbound_queue, ("enqueue", "dequeue")) self.outbound_transport_manager = OutboundTransportManager( - outbound_queue, collector + context, self.handle_not_delivered ) - outbound_transports = context.settings.get("transport.outbound_configs") or [] - for outbound_transport in outbound_transports: - try: - self.outbound_transport_manager.register(outbound_transport) - except Exception: - self.logger.exception("Unable to register outbound transport") - raise + await self.outbound_transport_manager.setup() # Admin API if context.settings.get("admin.enabled"): @@ -118,41 +87,50 @@ async def setup(self): admin_host = context.settings.get("admin.host", "0.0.0.0") admin_port = context.settings.get("admin.port", "80") self.admin_server = AdminServer( - admin_host, admin_port, context, self.outbound_message_router + admin_host, + admin_port, + context, + self.outbound_message_router, + self.webhook_router, + self.dispatcher.task_queue, + self.get_stats, ) webhook_urls = context.settings.get("admin.webhook_urls") if webhook_urls: for url in webhook_urls: self.admin_server.add_webhook_target(url) context.injector.bind_instance(BaseAdminServer, self.admin_server) + if "http" not in self.outbound_transport_manager.registered_schemes: + self.outbound_transport_manager.register("http") except Exception: - self.logger.exception("Unable to register admin server") + LOGGER.exception("Unable to register admin server") raise - self.context = context - self.dispatcher = Dispatcher(self.context) - + # Fetch stats collector, if any + collector = await context.inject(Collector, required=False) if collector: # add stats to our own methods collector.wrap( self, ( - "inbound_message_router", + # "inbound_message_router", "outbound_message_router", - "prepare_outbound_message", + # "create_inbound_session", ), ) - collector.wrap(self.dispatcher, "dispatch") + collector.wrap(self.dispatcher, "handle_message") # at the class level (!) should not be performed multiple times collector.wrap( ConnectionManager, ( - "get_connection_target", + "get_connection_targets", "fetch_did_document", - "find_message_connection", + "find_inbound_connection", ), ) + self.context = context + async def start(self) -> None: """Start the agent.""" @@ -168,12 +146,12 @@ async def start(self) -> None: try: await self.inbound_transport_manager.start() except Exception: - self.logger.exception("Unable to start inbound transports") + LOGGER.exception("Unable to start inbound transports") raise try: await self.outbound_transport_manager.start() except Exception: - self.logger.exception("Unable to start outbound transports") + LOGGER.exception("Unable to start outbound transports") raise # Start up Admin server @@ -181,7 +159,7 @@ async def start(self) -> None: try: await self.admin_server.start() except Exception: - self.logger.exception("Unable to start administration API") + LOGGER.exception("Unable to start administration API") # Make admin responder available during message parsing # This allows webhooks to be called when a connection is marked active, # for example @@ -231,245 +209,155 @@ async def start(self) -> None: print("Invitation URL:") print(invite_url) except Exception: - self.logger.exception("Error creating invitation") + LOGGER.exception("Error creating invitation") - async def stop(self, timeout=0.1): + async def stop(self, timeout=1.0): """Stop the agent.""" - tasks = [] + shutdown = TaskQueue() if self.admin_server: - tasks.append(self.admin_server.stop()) + shutdown.run(self.admin_server.stop()) if self.inbound_transport_manager: - tasks.append(self.inbound_transport_manager.stop()) + shutdown.run(self.inbound_transport_manager.stop()) if self.outbound_transport_manager: - tasks.append(self.outbound_transport_manager.stop()) - await asyncio.wait_for(asyncio.gather(*tasks), timeout) - - async def register_socket( - self, *, handler: Coroutine = None, single_response: asyncio.Future = None - ) -> SocketRef: - """Register a new duplex connection.""" - socket = SocketInfo(handler=handler, single_response=single_response) - socket_id = socket.socket_id - self.sockets[socket_id] = socket - - async def close_socket(): - socket.closed = True + shutdown.run(self.outbound_transport_manager.stop()) + await shutdown.complete(timeout) - return SocketRef(socket_id=socket_id, close=close_socket) - - async def inbound_message_router( - self, - message_body: Union[str, bytes], - transport_type: str = None, - socket_id: str = None, - single_response: asyncio.Future = None, - ) -> asyncio.Future: + def inbound_message_router( + self, message: InboundMessage, can_respond: bool = False + ): """ Route inbound messages. Args: - message_body: Body of the incoming message - transport_type: Type of transport this message came from - socket_id: The identifier of the incoming socket connection - single_response: A future to contain the first direct response message + message: The inbound message instance + can_respond: If the session supports return routing """ - try: - parsed_msg, delivery = await self.message_serializer.parse_message( - self.context, message_body, transport_type - ) - except MessageParseError: - self.logger.exception("Error expanding message") - raise - - connection_mgr = ConnectionManager(self.context) - connection = await connection_mgr.find_message_connection(delivery) - if connection: - delivery.connection_id = connection.connection_id - - if single_response and not socket_id: - # if transport wasn't a socket, make a virtual socket used for responses - socket = SocketInfo(single_response=single_response) - socket_id = socket.socket_id - self.sockets[socket_id] = socket - - if socket_id: - if socket_id not in self.sockets: - self.logger.warning( - "Inbound message on unregistered socket ID: %s", socket_id - ) - socket_id = None - elif self.sockets[socket_id].closed: - self.logger.warning( - "Inbound message on closed socket ID: %s", socket_id - ) - socket_id = None - - delivery.socket_id = socket_id - socket: SocketInfo = self.sockets[socket_id] if socket_id else None - - if socket: - socket.process_incoming(parsed_msg, delivery) - elif ( - delivery.direct_response_requested - and delivery.direct_response_requested != SocketInfo.REPLY_MODE_NONE - ): - self.logger.warning( + if message.receipt.direct_response_requested and not can_respond: + LOGGER.warning( "Direct response requested, but not supported by transport: %s", - delivery.transport_type, + message.transport_type, ) - handler_done = await self.dispatcher.dispatch( - parsed_msg, delivery, connection, self.outbound_message_router - ) - return asyncio.ensure_future(self.complete_dispatch(handler_done, socket)) + # Note: at this point we could send the message to a shared queue + # if this pod is too busy to process it - async def complete_dispatch(self, dispatch: asyncio.Future, socket: SocketInfo): - """Wait for the dispatch to complete and perform final actions.""" - await dispatch - await self.queue_processing(socket) - if socket: - socket.dispatch_complete() + self.dispatcher.queue_message( + message, + self.outbound_message_router, + self.admin_server and self.admin_server.send_webhook, + lambda completed: self.dispatch_complete(message, completed), + ) - async def queue_processing(self, socket: SocketInfo): - """ - Interact with undelivered queue to find applicable messages. + def dispatch_complete(self, message: InboundMessage, completed: CompletedTask): + """Handle completion of message dispatch.""" + if completed.exc_info: + LOGGER.exception( + "Exception in message handler:", exc_info=completed.exc_info + ) + self.inbound_transport_manager.dispatch_complete(message, completed) + + async def get_stats(self) -> dict: + """Get the current stats tracked by the conductor.""" + stats = { + "in_sessions": len(self.inbound_transport_manager.sessions), + "out_encode": 0, + "out_deliver": 0, + "task_active": self.dispatcher.task_queue.current_active, + "task_done": self.dispatcher.task_queue.total_done, + "task_failed": self.dispatcher.task_queue.total_failed, + "task_pending": self.dispatcher.task_queue.current_pending, + } + for m in self.outbound_transport_manager.outbound_buffer: + if m.state == m.STATE_ENCODE: + stats["out_encode"] += 1 + if m.state == m.STATE_DELIVER: + stats["out_deliver"] += 1 + return stats - Args: - socket: The incoming socket connection + async def outbound_message_router( + self, + context: InjectionContext, + outbound: OutboundMessage, + inbound: InboundMessage = None, + ) -> None: """ - if ( - socket - and socket.reply_mode - and not socket.closed - and self.undelivered_queue - ): - for key in socket.reply_verkeys: - if not isinstance(key, str): - key = key.value - if self.undelivered_queue.has_message_for_key(key): - for ( - undelivered_message - ) in self.undelivered_queue.inspect_all_messages_for_key(key): - # pending message. Transmit, then kill single_response - if socket.select_outgoing(undelivered_message): - self.logger.debug( - "Sending Queued Message via inbound connection" - ) - self.undelivered_queue.remove_message_for_key( - key, undelivered_message - ) - await socket.send(undelivered_message) - - async def get_connection_target( - self, connection_id: str, context: InjectionContext = None - ): - """Get a `ConnectionTarget` instance representing a connection. + Route an outbound message. Args: - connection_id: The connection record identifier - context: An optional injection context + context: The request context + message: An outbound message to be sent + inbound: The inbound message that produced this response, if available """ + if not outbound.target and outbound.reply_to_verkey: + if not outbound.reply_from_verkey and inbound: + outbound.reply_from_verkey = inbound.receipt.recipient_verkey + # return message to an inbound session + if self.inbound_transport_manager.return_to_session(outbound): + return - context = context or self.context + await self.queue_outbound(context, outbound, inbound) - try: - record = await ConnectionRecord.retrieve_by_id(context, connection_id) - except StorageNotFoundError as e: - raise MessagePrepareError( - "Could not locate connection record: {}".format(connection_id) - ) from e - mgr = ConnectionManager(context) - try: - target = await mgr.get_connection_target(record) - except ConnectionManagerError as e: - raise MessagePrepareError(str(e)) from e - if not target: - raise MessagePrepareError( - "No target found for connection: {}".format(connection_id) - ) - return target + def handle_not_returned(self, context: InjectionContext, outbound: OutboundMessage): + """Handle a message that failed delivery via an inbound session.""" + self.dispatcher.run_task(self.queue_outbound(context, outbound)) - async def prepare_outbound_message( + async def queue_outbound( self, - message: OutboundMessage, - context: InjectionContext = None, - direct_response: bool = False, + context: InjectionContext, + outbound: OutboundMessage, + inbound: InboundMessage = None, ): - """Prepare a response message for transmission. + """ + Queue an outbound message. Args: + context: The request context message: An outbound message to be sent - context: Optional request context - direct_response: Skip wrapping the response in forward messages + inbound: The inbound message that produced this response, if available """ + # populate connection target(s) + if not outbound.target and not outbound.target_list and outbound.connection_id: + # using provided request context + mgr = ConnectionManager(context) + try: + outbound.target_list = await self.dispatcher.run_task( + mgr.get_connection_targets(connection_id=outbound.connection_id) + ) + except ConnectionManagerError: + LOGGER.exception("Error preparing outbound message for transmission") + return - context = context or self.context - - if message.connection_id and not message.target: - message.target = await self.get_connection_target(message.connection_id) + try: + self.outbound_transport_manager.enqueue_message(context, outbound) + except OutboundDeliveryError: + LOGGER.warning("Cannot queue message for delivery, no supported transport") + self.handle_not_delivered(context, outbound) - if not message.encoded and message.target: - target = message.target - message.payload = await self.message_serializer.encode_message( - context, - message.payload, - target.recipient_keys or [], - (not direct_response) and target.routing_keys or [], - target.sender_key, - ) - message.encoded = True + def handle_not_delivered( + self, context: InjectionContext, outbound: OutboundMessage + ): + """Handle a message that failed delivery via outbound transports.""" + self.inbound_transport_manager.return_undelivered(outbound) - async def outbound_message_router( - self, message: OutboundMessage, context: InjectionContext = None - ) -> None: + def webhook_router( + self, topic: str, payload: dict, endpoint: str, retries: int = None + ): """ - Route an outbound message. + Route a webhook through the outbound transport manager. Args: - message: An outbound message to be sent - context: Optional request context + topic: The webhook topic + payload: The webhook payload + endpoint: The endpoint of the webhook target + retries: The number of retries """ - - # try socket connections first, preferring the same socket ID - socket_id = message.reply_socket_id - sel_socket = None - if ( - socket_id - and socket_id in self.sockets - and self.sockets[socket_id].select_outgoing(message) - ): - sel_socket = self.sockets[socket_id] - else: - for socket in self.sockets.values(): - if socket.select_outgoing(message): - sel_socket = socket - break - if sel_socket: - try: - await self.prepare_outbound_message(message, context, True) - except MessagePrepareError: - self.logger.exception( - "Error preparing outbound message for direct response" - ) - return - - await sel_socket.send(message) - self.logger.debug("Returned message to socket %s", sel_socket.socket_id) - return - try: - await self.prepare_outbound_message(message, context) - except MessagePrepareError: - self.logger.exception("Error preparing outbound message for transmission") - return - - # deliver directly to endpoint - if message.endpoint: - await self.outbound_transport_manager.send_message(message) - return - - # Add message to outbound queue, indexed by key - if self.undelivered_queue: - self.undelivered_queue.add_message(message) + self.outbound_transport_manager.enqueue_webhook( + topic, payload, endpoint, retries + ) + except OutboundDeliveryError: + LOGGER.warning( + "Cannot queue message webhook for delivery, no supported transport" + ) diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index c4043669e1..1f676f9dd2 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -7,6 +7,7 @@ from typing import Type from .error import ArgsParseError +from .util import ByteSize CAT_PROVISION = "general" CAT_START = "start" @@ -552,30 +553,6 @@ def get_settings(self, args: Namespace) -> dict: return settings -@group(CAT_START) -class QueueGroup(ArgumentGroup): - """Queue settings.""" - - GROUP_NAME = "Queue" - - def add_arguments(self, parser: ArgumentParser): - """Add queue-specific command line arguments to the parser.""" - parser.add_argument( - "--enable-undelivered-queue", - action="store_true", - help="Enable the outbound undelivered queue that enables this agent to hold messages\ - for delivery to agents without an endpoint. This option will require\ - additional memory to store messages in the queue.", - ) - - def get_settings(self, args: Namespace): - """Extract queue settings.""" - settings = {} - settings["queue.enable_undelivered_queue"] = args.enable_undelivered_queue - - return settings - - @group(CAT_START) class TransportGroup(ArgumentGroup): """Transport settings.""" @@ -598,7 +575,6 @@ def add_arguments(self, parser: ArgumentParser): be specified multiple times to create multiple interfaces.\ Supported inbound transport types are 'http' and 'ws'.", ) - parser.add_argument( "-ot", "--outbound-transport", @@ -612,7 +588,6 @@ def add_arguments(self, parser: ArgumentParser): multiple times to supoort multiple transport types. Supported outbound\ transport types are 'http' and 'ws'.", ) - parser.add_argument( "-e", "--endpoint", @@ -629,7 +604,6 @@ def add_arguments(self, parser: ArgumentParser): The endpoints are used in the formation of a connection \ with another agent.", ) - parser.add_argument( "-l", "--label", @@ -638,18 +612,37 @@ def add_arguments(self, parser: ArgumentParser): help="Specifies the label for this agent. This label is publicized\ (self-attested) to other agents as part of forming a connection.", ) + parser.add_argument( + "--max-message-size", + default=2097152, + type=ByteSize(min_size=1024), + metavar="", + help="Set the maximum size in bytes for inbound agent messages.", + ) + + parser.add_argument( + "--enable-undelivered-queue", + action="store_true", + help="Enable the outbound undelivered queue that enables this agent to hold messages\ + for delivery to agents without an endpoint. This option will require\ + additional memory to store messages in the queue.", + ) def get_settings(self, args: Namespace): """Extract transport settings.""" settings = {} settings["transport.inbound_configs"] = args.inbound_transports settings["transport.outbound_configs"] = args.outbound_transports + settings["transport.enable_undelivered_queue"] = args.enable_undelivered_queue if args.endpoint: settings["default_endpoint"] = args.endpoint[0] settings["additional_endpoints"] = args.endpoint[1:] if args.label: settings["default_label"] = args.label + if args.max_message_size: + settings["transport.max_message_size"] = args.max_message_size + return settings diff --git a/aries_cloudagent/config/default_context.py b/aries_cloudagent/config/default_context.py index e9a3739c1e..e1501762ad 100644 --- a/aries_cloudagent/config/default_context.py +++ b/aries_cloudagent/config/default_context.py @@ -13,7 +13,6 @@ from ..verifier.base import BaseVerifier from ..messaging.plugin_registry import PluginRegistry from ..messaging.protocol_registry import ProtocolRegistry -from ..messaging.serializer import MessageSerializer from ..protocols.actionmenu.base_service import BaseMenuService from ..protocols.actionmenu.driver_service import DriverMenuService from ..protocols.introduction.base_service import BaseIntroductionService @@ -21,7 +20,8 @@ from ..stats import Collector from ..storage.base import BaseStorage from ..storage.provider import StorageProvider -from ..transport.outbound.queue.base import BaseOutboundMessageQueue +from ..transport.pack_format import PackWireFormat +from ..transport.wire_format import BaseWireFormat from ..wallet.base import BaseWallet from ..wallet.provider import WalletProvider @@ -123,26 +123,16 @@ async def bind_providers(self, context: InjectionContext): ), ) - # Register message serializer + # Register default pack format context.injector.bind_provider( - MessageSerializer, + BaseWireFormat, CachedProvider( StatsProvider( - ClassProvider(MessageSerializer), - ("encode_message", "parse_message"), + ClassProvider(PackWireFormat), ("encode_message", "parse_message"), ) ), ) - # Set default outbound message queue - context.injector.bind_provider( - BaseOutboundMessageQueue, - ClassProvider( - "aries_cloudagent.transport.outbound.queue" - + ".basic.BasicOutboundMessageQueue" - ), - ) - # Allow action menu to be provided by driver context.injector.bind_instance(BaseMenuService, DriverMenuService(context)) context.injector.bind_instance( diff --git a/aries_cloudagent/config/logging.py b/aries_cloudagent/config/logging.py index 16110bb4c7..2dbe91352a 100644 --- a/aries_cloudagent/config/logging.py +++ b/aries_cloudagent/config/logging.py @@ -129,7 +129,7 @@ def lr_pad(content: str): ) inbound_transport_strings = [] - for transport in inbound_transports: + for transport in inbound_transports.values(): host_port_string = ( f" - {transport.scheme}://{transport.host}:{transport.port}" ) @@ -142,11 +142,13 @@ def lr_pad(content: str): ) outbound_transport_strings = [] - for schemes in outbound_transports: - for scheme in schemes: - schema_string = f" - {scheme}" - scheme_spacer = " " * (banner_length - len(schema_string)) - outbound_transport_strings.append((schema_string, scheme_spacer)) + schemes = set().union( + *(transport.schemes for transport in outbound_transports.values()) + ) + for scheme in sorted(schemes): + schema_string = f" - {scheme}" + scheme_spacer = " " * (banner_length - len(schema_string)) + outbound_transport_strings.append((schema_string, scheme_spacer)) version_string = f"ver: {__version__}" version_string_spacer = " " * (banner_length - len(version_string)) diff --git a/aries_cloudagent/config/tests/test_argparse.py b/aries_cloudagent/config/tests/test_argparse.py index bb64b915f2..ae6dd85251 100644 --- a/aries_cloudagent/config/tests/test_argparse.py +++ b/aries_cloudagent/config/tests/test_argparse.py @@ -1,9 +1,10 @@ import itertools -from argparse import ArgumentParser +from argparse import ArgumentParser, ArgumentTypeError from asynctest import TestCase as AsyncTestCase, mock as async_mock from .. import argparse +from ..util import ByteSize class TestArgParse(AsyncTestCase): @@ -41,13 +42,16 @@ async def test_transport_settings(self): "http", "-e", "http://default.endpoint/", - "ws://alternate.endpoint/" + "ws://alternate.endpoint/", ] ) assert result.inbound_transports == [["http", "0.0.0.0", "80"]] assert result.outbound_transports == ["http"] - assert result.endpoint == ["http://default.endpoint/", "ws://alternate.endpoint/"] + assert result.endpoint == [ + "http://default.endpoint/", + "ws://alternate.endpoint/", + ] settings = group.get_settings(result) @@ -55,3 +59,34 @@ async def test_transport_settings(self): assert settings.get("transport.outbound_configs") == ["http"] assert settings.get("default_endpoint") == "http://default.endpoint/" assert settings.get("additional_endpoints") == ["ws://alternate.endpoint/"] + + def test_bytesize(self): + bs = ByteSize() + with self.assertRaises(ArgumentTypeError): + bs(None) + with self.assertRaises(ArgumentTypeError): + bs("") + with self.assertRaises(ArgumentTypeError): + bs("a") + with self.assertRaises(ArgumentTypeError): + bs("1.5") + with self.assertRaises(ArgumentTypeError): + bs("-1") + assert bs("101") == 101 + assert bs("101b") == 101 + assert bs("101KB") == 103424 + assert bs("2M") == 2097152 + assert bs("1G") == 1073741824 + assert bs("1t") == 1099511627776 + + bs = ByteSize(min_size=10) + with self.assertRaises(ArgumentTypeError): + bs("5") + assert bs("12") == 12 + + bs = ByteSize(max_size=10) + with self.assertRaises(ArgumentTypeError): + bs("15") + assert bs("10") == 10 + + assert repr(bs) == "ByteSize" diff --git a/aries_cloudagent/config/tests/test_default_context.py b/aries_cloudagent/config/tests/test_default_context.py index 1d69e77d59..2f7932df7d 100644 --- a/aries_cloudagent/config/tests/test_default_context.py +++ b/aries_cloudagent/config/tests/test_default_context.py @@ -1,9 +1,8 @@ from asynctest import TestCase as AsyncTestCase from ...messaging.protocol_registry import ProtocolRegistry -from ...messaging.serializer import MessageSerializer from ...storage.base import BaseStorage -from ...transport.outbound.queue.base import BaseOutboundMessageQueue +from ...transport.wire_format import BaseWireFormat from ...wallet.base import BaseWallet from ..default_context import DefaultContextBuilder @@ -19,8 +18,7 @@ async def test_build_context(self): assert isinstance(result, InjectionContext) for cls in ( - BaseOutboundMessageQueue, - MessageSerializer, + BaseWireFormat, ProtocolRegistry, BaseWallet, BaseStorage, diff --git a/aries_cloudagent/config/tests/test_logging.py b/aries_cloudagent/config/tests/test_logging.py index ac36101ae3..9aaacedf6e 100644 --- a/aries_cloudagent/config/tests/test_logging.py +++ b/aries_cloudagent/config/tests/test_logging.py @@ -39,6 +39,6 @@ def test_banner(self): with contextlib.redirect_stdout(stdout): test_label = "Aries Cloud Agent" test_did = "55GkHamhTU1ZbTbV2ab9DE" - test_module.LoggingConfigurator.print_banner(test_label, [], [], test_did) + test_module.LoggingConfigurator.print_banner(test_label, {}, {}, test_did) output = stdout.getvalue() assert test_did in output diff --git a/aries_cloudagent/config/util.py b/aries_cloudagent/config/util.py index 5fdfaebe86..a8ad9f9933 100644 --- a/aries_cloudagent/config/util.py +++ b/aries_cloudagent/config/util.py @@ -1,6 +1,9 @@ """Entrypoint.""" import os +import re + +from argparse import ArgumentTypeError from typing import Any, Mapping from .logging import LoggingConfigurator @@ -21,3 +24,43 @@ def common_config(settings: Mapping[str, Any]): and settings.get("wallet.storage_type") == "postgres_storage" ): load_postgres_plugin() + + +class ByteSize: + """Argument value parser for byte sizes.""" + + def __init__(self, min_size: int = 0, max_size: int = 0): + """Initialize the ByteSize parser.""" + self.min_size = min_size + self.max_size = max_size + + def __call__(self, arg: str) -> int: + """Interpret the argument value.""" + if not arg: + raise ArgumentTypeError("Expected value") + parts = re.match(r"^(\d+)([kKmMgGtT]?)[bB]?$", arg) + if not parts: + raise ArgumentTypeError("Invalid format") + size = int(parts[1]) + suffix = parts[2].upper() + if suffix == "K": + size = size << 10 + elif suffix == "M": + size = size << 20 + elif suffix == "G": + size = size << 30 + elif suffix == "T": + size = size << 40 + if size < self.min_size: + raise ArgumentTypeError( + f"Size must be greater than or equal to {self.min_size}" + ) + if self.max_size and size > self.max_size: + raise ArgumentTypeError( + f"Size must be less than or equal to {self.max_size}" + ) + return size + + def __repr__(self): + """Format for in error reporting.""" + return self.__class__.__name__ diff --git a/aries_cloudagent/connections/models/diddoc/diddoc.py b/aries_cloudagent/connections/models/diddoc/diddoc.py index 69bcdae3af..b764df5002 100644 --- a/aries_cloudagent/connections/models/diddoc/diddoc.py +++ b/aries_cloudagent/connections/models/diddoc/diddoc.py @@ -321,4 +321,9 @@ def from_json(cls, did_doc_json: str) -> "DIDDoc": def __str__(self) -> str: """Return string representation for abbreviated display.""" - return "DIDDoc({})".format(self.did) + return f"DIDDoc({self.did})" + + def __repr__(self) -> str: + """Format DIDDoc for logging.""" + + return f"" diff --git a/aries_cloudagent/dispatcher.py b/aries_cloudagent/dispatcher.py index 5b6dffea6e..91b7842691 100644 --- a/aries_cloudagent/dispatcher.py +++ b/aries_cloudagent/dispatcher.py @@ -7,23 +7,24 @@ import asyncio import logging -from typing import Coroutine, Union +from typing import Callable, Coroutine, Union -from .admin.base_server import BaseAdminServer from .config.injection_context import InjectionContext from .messaging.agent_message import AgentMessage -from .connections.models.connection_record import ConnectionRecord from .messaging.error import MessageParseError -from .messaging.message_delivery import MessageDelivery from .messaging.models.base import BaseModelError -from .messaging.outbound_message import OutboundMessage +from .protocols.connections.manager import ConnectionManager from .protocols.problem_report.message import ProblemReport from .messaging.protocol_registry import ProtocolRegistry from .messaging.request_context import RequestContext from .messaging.responder import BaseResponder -from .messaging.serializer import MessageSerializer +from .messaging.task_queue import TaskQueue from .messaging.util import datetime_now from .stats import Collector +from .transport.inbound.message import InboundMessage +from .transport.outbound.message import OutboundMessage + +LOGGER = logging.getLogger(__name__) class Dispatcher: @@ -37,59 +38,95 @@ class Dispatcher: def __init__(self, context: InjectionContext): """Initialize an instance of Dispatcher.""" self.context = context - self.logger = logging.getLogger(__name__) + self.task_queue = TaskQueue(max_active=20) + + def put_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Future: + """Run a task in the task queue, potentially blocking other handlers.""" + return self.task_queue.put(coro, complete) - async def dispatch( + def run_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Task: + """Run a task in the task queue, potentially blocking other handlers.""" + return self.task_queue.run(coro, complete) + + def queue_message( self, - parsed_msg: dict, - delivery: MessageDelivery, - connection: ConnectionRecord, - send: Coroutine, + inbound_message: InboundMessage, + send_outbound: Coroutine, + send_webhook: Coroutine = None, + complete: Callable = None, ) -> asyncio.Future: """ - Configure responder and dispatch message context to message handler. + Add a message to the processing queue for handling. + + Args: + inbound_message: The inbound message instance + send_outbound: Async function to send outbound messages + send_webhook: Async function to dispatch a webhook + complete: Function to call when the handler has completed + + Returns: + A future resolving to the handler task + + """ + return self.put_task( + self.handle_message(inbound_message, send_outbound, send_webhook), complete + ) + + async def handle_message( + self, + inbound_message: InboundMessage, + send_outbound: Coroutine, + send_webhook: Coroutine = None, + ): + """ + Configure responder and message context and invoke the message handler. Args: - parsed_msg: The parsed message body - delivery: The incoming message delivery metadata - connection: The related connection record, if any - send: Function to send outbound messages + inbound_message: The inbound message instance + send_outbound: Async function to send outbound messages + send_webhook: Async function to dispatch a webhook Returns: The response from the handler """ + connection_mgr = ConnectionManager(self.context) + connection = await connection_mgr.find_inbound_connection( + inbound_message.receipt + ) + if connection: + inbound_message.connection_id = connection.connection_id + error_result = None try: - message = await self.make_message(parsed_msg) + message = await self.make_message(inbound_message.payload) except MessageParseError as e: - self.logger.error( - f"Message parsing failed: {str(e)}, sending problem report" - ) + LOGGER.error(f"Message parsing failed: {str(e)}, sending problem report") error_result = ProblemReport(explain_ltxt=str(e)) - if delivery.thread_id: - error_result.assign_thread_id(delivery.thread_id) + if inbound_message.receipt.thread_id: + error_result.assign_thread_id(inbound_message.receipt.thread_id) message = None context = RequestContext(base_context=self.context) context.message = message - context.message_delivery = delivery + context.message_receipt = inbound_message.receipt context.connection_ready = connection and connection.is_ready context.connection_record = connection responder = DispatcherResponder( - send, context, + inbound_message, + send_outbound, + send_webhook, connection_id=connection and connection.connection_id, - reply_socket_id=delivery.socket_id, - reply_to_verkey=delivery.sender_verkey, + reply_session_id=inbound_message.session_id, + reply_to_verkey=inbound_message.receipt.sender_verkey, ) if error_result: - return asyncio.get_event_loop().create_task( - responder.send_reply(error_result) - ) + await responder.send_reply(error_result) + return context.injector.bind_instance(BaseResponder, responder) @@ -98,10 +135,7 @@ async def dispatch( collector: Collector = await context.inject(Collector, required=False) if collector: collector.wrap(handler_obj, "handle", ["any-message-handler"]) - handler = asyncio.get_event_loop().create_task( - handler_obj.handle(context, responder) - ) - return handler + await handler_obj.handle(context, responder) async def make_message(self, parsed_msg: dict) -> AgentMessage: """ @@ -124,13 +158,11 @@ async def make_message(self, parsed_msg: dict) -> AgentMessage: """ registry: ProtocolRegistry = await self.context.inject(ProtocolRegistry) - serializer: MessageSerializer = await self.context.inject(MessageSerializer) - - # throws a MessageParseError on failure - message_type = serializer.extract_message_type(parsed_msg) + message_type = parsed_msg.get("@type") + if not message_type: + raise MessageParseError("Message does not contain '@type' parameter") message_cls = registry.resolve_message_class(message_type) - if not message_cls: raise MessageParseError(f"Unrecognized message type {message_type}") @@ -145,30 +177,45 @@ async def make_message(self, parsed_msg: dict) -> AgentMessage: class DispatcherResponder(BaseResponder): """Handle outgoing messages from message handlers.""" - def __init__(self, send: Coroutine, context: RequestContext, **kwargs): + def __init__( + self, + context: RequestContext, + inbound_message: InboundMessage, + send_outbound: Coroutine, + send_webhook: Coroutine = None, + **kwargs, + ): """ Initialize an instance of `DispatcherResponder`. Args: - send: Function to send outbound message context: The request context of the incoming message + inbound_message: The inbound message triggering this handler + send_outbound: Async function to send outbound message + send_webhook: Async function to dispatch a webhook """ super().__init__(**kwargs) self._context = context - self._send = send + self._inbound_message = inbound_message + self._send = send_outbound + self._webhook = send_webhook async def create_outbound( self, message: Union[AgentMessage, str, bytes], **kwargs ) -> OutboundMessage: - """Create an OutboundMessage from a message body.""" + """ + Create an OutboundMessage from a message body. + + Args: + message: The message payload + """ if isinstance(message, AgentMessage) and self._context.settings.get( "timing.enabled" ): # Inject the timing decorator in_time = ( - self._context.message_delivery - and self._context.message_delivery.in_time + self._context.message_receipt and self._context.message_receipt.in_time ) if not message._decorators.get("timing"): message._decorators["timing"] = { @@ -184,7 +231,7 @@ async def send_outbound(self, message: OutboundMessage): Args: message: The `OutboundMessage` to be sent """ - await self._send(message) + await self._send(self._context, message, self._inbound_message) async def send_webhook(self, topic: str, payload: dict): """ @@ -194,10 +241,5 @@ async def send_webhook(self, topic: str, payload: dict): topic: the webhook topic identifier payload: the webhook payload value """ - asyncio.get_event_loop().create_task(self._dispatch_webhook(topic, payload)) - - async def _dispatch_webhook(self, topic: str, payload: dict): - """Perform dispatch of a webhook.""" - server = await self._context.inject(BaseAdminServer, required=False) - if server: - await server.send_webhook(topic, payload) + if self._webhook: + await self._webhook(topic, payload) diff --git a/aries_cloudagent/messaging/base_context.py b/aries_cloudagent/messaging/base_context.py deleted file mode 100644 index c1e4de4d72..0000000000 --- a/aries_cloudagent/messaging/base_context.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Abstract RequestContext base class.""" - -from abc import ABC - - -class BaseRequestContext(ABC): - """ - Abstract RequestContext base class. - - This class is only for resolving recursive import issues. - """ diff --git a/aries_cloudagent/messaging/request_context.py b/aries_cloudagent/messaging/request_context.py index 7edb39c686..61e068525b 100644 --- a/aries_cloudagent/messaging/request_context.py +++ b/aries_cloudagent/messaging/request_context.py @@ -9,9 +9,9 @@ from ..config.injection_context import InjectionContext from ..connections.models.connection_record import ConnectionRecord +from ..transport.inbound.receipt import MessageReceipt from .agent_message import AgentMessage -from .message_delivery import MessageDelivery class RequestContext(InjectionContext): @@ -33,7 +33,7 @@ def __init__( self._connection_ready = False self._connection_record = None self._message = None - self._message_delivery = None + self._message_receipt = None @property def connection_ready(self) -> bool: @@ -137,25 +137,25 @@ def message(self, msg: AgentMessage): self._message = msg @property - def message_delivery(self) -> MessageDelivery: + def message_receipt(self) -> MessageReceipt: """ - Accessor for the message delivery information. + Accessor for the message receipt information. Returns: - This context's message delivery information + This context's message receipt information """ - return self._message_delivery + return self._message_receipt - @message_delivery.setter - def message_delivery(self, delivery: MessageDelivery): + @message_receipt.setter + def message_receipt(self, receipt: MessageReceipt): """ - Setter for the message delivery information. + Setter for the message receipt information. Args: - msg: This context's new message delivery information + msg: This context's new message receipt information """ - self._message_delivery = delivery + self._message_receipt = receipt def __repr__(self) -> str: """ diff --git a/aries_cloudagent/messaging/responder.py b/aries_cloudagent/messaging/responder.py index 0ef2361396..3718c59145 100644 --- a/aries_cloudagent/messaging/responder.py +++ b/aries_cloudagent/messaging/responder.py @@ -6,13 +6,13 @@ """ from abc import ABC, abstractmethod -from typing import Union +from typing import Sequence, Union from ..error import BaseError from ..connections.models.connection_target import ConnectionTarget +from ..transport.outbound.message import OutboundMessage from .agent_message import AgentMessage -from .outbound_message import OutboundMessage class ResponderError(BaseError): @@ -26,12 +26,12 @@ def __init__( self, *, connection_id: str = None, - reply_socket_id: str = None, + reply_session_id: str = None, reply_to_verkey: str = None, ): """Initialize a base responder.""" self.connection_id = connection_id - self.reply_socket_id = reply_socket_id + self.reply_session_id = reply_session_id self.reply_to_verkey = reply_to_verkey async def create_outbound( @@ -39,28 +39,30 @@ async def create_outbound( message: Union[AgentMessage, str, bytes], *, connection_id: str = None, - reply_socket_id: str = None, + reply_session_id: str = None, reply_thread_id: str = None, reply_to_verkey: str = None, target: ConnectionTarget = None, + target_list: Sequence[ConnectionTarget] = None, ) -> OutboundMessage: """Create an OutboundMessage from a message payload.""" if isinstance(message, AgentMessage): payload = message.to_json() - encoded = False + enc_payload = None if not reply_thread_id: reply_thread_id = message._thread_id else: - payload = message - encoded = True + payload = None + enc_payload = message return OutboundMessage( - payload, connection_id=connection_id, - encoded=encoded, - reply_socket_id=reply_socket_id, + enc_payload=enc_payload, + payload=payload, + reply_session_id=reply_session_id, reply_thread_id=reply_thread_id, reply_to_verkey=reply_to_verkey, target=target, + target_list=target_list, ) async def send(self, message: Union[AgentMessage, str, bytes], **kwargs): @@ -74,6 +76,7 @@ async def send_reply( *, connection_id: str = None, target: ConnectionTarget = None, + target_list: Sequence[ConnectionTarget] = None, ): """ Send a reply to an incoming message. @@ -90,9 +93,10 @@ async def send_reply( outbound = await self.create_outbound( message, connection_id=connection_id or self.connection_id, - reply_socket_id=self.reply_socket_id, + reply_session_id=self.reply_session_id, reply_to_verkey=self.reply_to_verkey, target=target, + target_list=target_list, ) await self.send_outbound(outbound) diff --git a/aries_cloudagent/messaging/socket.py b/aries_cloudagent/messaging/socket.py deleted file mode 100644 index 108a95ecd5..0000000000 --- a/aries_cloudagent/messaging/socket.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Duplex connection handling classes.""" - -import asyncio -from typing import Coroutine, Sequence -import uuid - -from .message_delivery import MessageDelivery -from .outbound_message import OutboundMessage - - -class SocketInfo: - """Track an open transport connection for direct routing of outbound messages.""" - - REPLY_MODE_ALL = "all" - REPLY_MODE_NONE = "none" - REPLY_MODE_THREAD = "thread" - - def __init__( - self, - *, - connection_id: str = None, - handler: Coroutine = None, - reply_mode: str = None, - reply_thread_ids: Sequence[str] = None, - reply_verkeys: Sequence[str] = None, - single_response: asyncio.Future = None, - socket_id: str = None, - ): - """Initialize the socket info.""" - self._closed = False - self.connection_id = connection_id - self.handler = handler - self.reply_thread_ids = set(reply_thread_ids) if reply_thread_ids else set() - self.reply_verkeys = set(reply_verkeys) if reply_verkeys else set() - self.single_response = single_response - self.socket_id = socket_id or str(uuid.uuid4()) - # calls setter - self._reply_mode = None - self.reply_mode = reply_mode - - @property - def closed(self) -> bool: - """Accessor for the socket closed state.""" - if self._closed: - return True - if self.single_response and self.single_response.done(): - self._closed = True - return True - return False - - @closed.setter - def closed(self, flag: bool): - """Setter for the socket closed state.""" - self._closed = flag - - @property - def reply_mode(self) -> str: - """Accessor for the socket reply mode.""" - return self._reply_mode - - @reply_mode.setter - def reply_mode(self, mode: str): - """Setter for the socket reply mode.""" - if mode not in (self.REPLY_MODE_ALL, self.REPLY_MODE_THREAD): - mode = None - # reset the tracked thread IDs when the mode is changed to none - self.reply_thread_ids = set() - self._reply_mode = mode - - def add_reply_thread_id(self, thid: str): - """Add a thread ID to the set of potential reply targets.""" - if thid: - self.reply_thread_ids.add(thid) - - def add_reply_verkey(self, verkey: str): - """Add a verkey to the set of potential reply targets.""" - if verkey: - self.reply_verkeys.add(verkey) - - def process_incoming(self, parsed_msg: dict, delivery: MessageDelivery): - """Process an incoming message and update the socket metadata as necessary. - - Args: - parsed_msg: The unserialized message body - delivery: The message delivery metadata - """ - mode = self.reply_mode = delivery.direct_response_requested - self.add_reply_verkey(delivery.sender_verkey) - if mode == self.REPLY_MODE_THREAD: - self.add_reply_thread_id(delivery.thread_id) - delivery.direct_response = bool(mode) - if delivery.connection_id: - self.connection_id = delivery.connection_id - - def dispatch_complete(self): - """Indicate that a message handler has completed.""" - if not self.closed and self.single_response: - self.single_response.cancel() - - def select_outgoing(self, message: OutboundMessage) -> bool: - """Determine if an outbound message should be sent to this socket. - - Args: - message: The outbound message to be checked - """ - mode = self.reply_mode - if not self.closed: - if ( - mode == self.REPLY_MODE_ALL - and message.reply_socket_id == self.socket_id - ): - return True - if ( - mode == self.REPLY_MODE_ALL - and message.reply_to_verkey - and message.reply_to_verkey in self.reply_verkeys - ): - return True - if ( - mode == self.REPLY_MODE_ALL - and message.target - and message.target.recipient_keys - and any(True for k in message.target.recipient_keys - if k in self.reply_verkeys) - ): - return True - if ( - mode == self.REPLY_MODE_THREAD - and message.reply_thread_id - and message.reply_thread_id in self.reply_thread_ids - ): - return True - return False - - async def send(self, message: OutboundMessage): - """.""" - if self.single_response: - self.single_response.set_result(message.payload) - elif self.handler: - await self.handler(message.payload) - - -class SocketRef: - """A reference to a registered duplex connection.""" - - def __init__(self, socket_id: str, close: Coroutine): - """Initialize the socket reference.""" - self.close = close - self.socket_id = socket_id diff --git a/aries_cloudagent/messaging/task_queue.py b/aries_cloudagent/messaging/task_queue.py new file mode 100644 index 0000000000..4a513b182e --- /dev/null +++ b/aries_cloudagent/messaging/task_queue.py @@ -0,0 +1,271 @@ +"""Classes for managing a set of asyncio tasks.""" + +import asyncio +import logging +from typing import Callable, Coroutine, Tuple + +LOGGER = logging.getLogger(__name__) + + +def task_exc_info(task: asyncio.Task): + """Extract exception info from an asyncio task.""" + if not task or not task.done(): + return + try: + exc_val = task.exception() + except asyncio.CancelledError: + exc_val = asyncio.CancelledError("Task was cancelled") + if exc_val: + return type(exc_val), exc_val, exc_val.__traceback__ + + +class CompletedTask: + """Represent the result of a queued task.""" + + # Note: this would be a good place to return timing information + + def __init__(self, task: asyncio.Task, exc_info: Tuple): + """Initialize the completed task.""" + self.exc_info = exc_info + self.task = task + + +class TaskQueue: + """A class for managing a set of asyncio tasks.""" + + def __init__(self, max_active: int = 0): + """ + Initialize the task queue. + + Args: + max_active: The maximum number of tasks to automatically run + """ + self.loop = asyncio.get_event_loop() + self.active_tasks = [] + self.pending_tasks = [] + self.total_done = 0 + self.total_failed = 0 + self._cancelled = False + self._drain_evt = asyncio.Event() + self._drain_task: asyncio.Task = None + self._max_active = max_active + + @property + def cancelled(self) -> bool: + """Accessor for the cancelled property of the queue.""" + return self._cancelled + + @property + def max_active(self) -> int: + """Accessor for the maximum number of active tasks in the queue.""" + return self._max_active + + @property + def ready(self) -> bool: + """Accessor for the ready property of the queue.""" + return ( + not self._cancelled + and not self._max_active + or self.current_size < self._max_active + ) + + @property + def current_active(self) -> int: + """Accessor for the current number of active tasks in the queue.""" + return len(self.active_tasks) + + @property + def current_pending(self) -> int: + """Accessor for the current number of pending tasks in the queue.""" + return len(self.pending_tasks) + + @property + def current_size(self) -> int: + """Accessor for the total number of tasks in the queue.""" + return len(self.active_tasks) + len(self.pending_tasks) + + def __len__(self) -> int: + """Support for the len() builtin.""" + return self.current_size + + def drain(self) -> asyncio.Task: + """Start the process to run queued tasks.""" + if self._drain_task and not self._drain_task.done(): + self._drain_evt.set() + elif self.pending_tasks: + self._drain_task = self.loop.create_task(self._drain_loop()) + self._drain_task.add_done_callback(lambda task: self._drain_done(task)) + return self._drain_task + + def _drain_done(self, task: asyncio.Task): + """Handle completion of the drain process.""" + exc_info = task_exc_info(task) + if exc_info: + LOGGER.exception("Error draining task queue:", exc_info=exc_info) + if self._drain_task and self._drain_task.done(): + self._drain_task = None + + async def _drain_loop(self): + """Run pending tasks while there is room in the queue.""" + # Note: this method should not call async methods apart from + # waiting for the updated event, to avoid yielding to other queue methods + while True: + self._drain_evt.clear() + while self.pending_tasks and ( + not self._max_active or len(self.active_tasks) < self._max_active + ): + coro, task_complete, fut = self.pending_tasks.pop(0) + task = self.run(coro, task_complete) + if fut and not fut.done(): + fut.set_result(task) + if self.pending_tasks: + await self._drain_evt.wait() + else: + break + + def add_pending( + self, + coro: Coroutine, + task_complete: Callable = None, + fut: asyncio.Future = None, + ): + """ + Add a task to the pending queue. + + Args: + coro: The coroutine to run + task_complete: An optional callback when the task has completed + fut: A future that resolves to the task once it is queued + """ + if not asyncio.iscoroutine(coro): + raise ValueError(f"Expected coroutine, got {coro}") + self.pending_tasks.append((coro, task_complete, fut)) + self.drain() + + def add_active( + self, task: asyncio.Task, task_complete: Callable = None + ) -> asyncio.Task: + """ + Register an active async task with an optional completion callback. + + Args: + task: The asyncio task instance + task_complete: An optional callback to run on completion + """ + self.active_tasks.append(task) + task.add_done_callback(lambda fut: self.completed_task(task, task_complete)) + return task + + def run(self, coro: Coroutine, task_complete: Callable = None) -> asyncio.Task: + """ + Start executing a coroutine as an async task, bypassing the pending queue. + + Args: + coro: The coroutine to run + task_complete: A callback to run on completion + + Returns: the new asyncio task instance + + """ + if self._cancelled: + raise RuntimeError("Task queue has been cancelled") + if not asyncio.iscoroutine(coro): + raise ValueError(f"Expected coroutine, got {coro}") + task = self.loop.create_task(coro) + return self.add_active(task, task_complete) + + def put(self, coro: Coroutine, task_complete: Callable = None) -> asyncio.Future: + """ + Add a new task to the queue, delaying execution if busy. + + Args: + coro: The coroutine to run + task_complete: A callback to run on completion + + Returns: a future resolving to the asyncio task instance once queued + + """ + fut = self.loop.create_future() + if self._cancelled: + coro.close() + fut.cancel() + elif self.ready: + task = self.run(coro, task_complete) + fut.set_result(task) + else: + self.add_pending(coro, task_complete, fut) + return fut + + def completed_task(self, task: asyncio.Task, task_complete: Callable): + """Clean up after a task has completed and run callbacks.""" + exc_info = task_exc_info(task) + if exc_info: + self.total_failed += 1 + if not task_complete: + LOGGER.exception("Error running task", exc_info=exc_info) + else: + self.total_done += 1 + if task_complete: + try: + task_complete(CompletedTask(task, exc_info)) + except Exception: + LOGGER.exception("Error finalizing task") + try: + self.active_tasks.remove(task) + except ValueError: + pass + self.drain() + + def cancel_pending(self): + """Cancel any pending tasks in the queue.""" + if self._drain_task: + self._drain_task.cancel() + self._drain_task = None + for coro, task_complete, fut in self.pending_tasks: + coro.close() + fut.cancel() + self.pending_tasks = [] + + def cancel(self): + """Cancel any pending or active tasks in the queue.""" + self._cancelled = True + self.cancel_pending() + for task in self.active_tasks: + if not task.done(): + task.cancel() + + async def complete(self, timeout: float = None, cleanup: bool = True): + """Cancel any pending tasks and wait for, or cancel active tasks.""" + self._cancelled = True + self.cancel_pending() + if timeout or timeout is None: + try: + await self.wait_for(timeout) + except asyncio.TimeoutError: + pass + for task in self.active_tasks: + if not task.done(): + task.cancel() + if cleanup: + while True: + drain = self.drain() + if not drain: + break + await drain + + async def flush(self): + """Wait for any active or pending tasks to be completed.""" + self.drain() + while self.active_tasks or self._drain_task: + if self._drain_task: + await self._drain_task + if self.active_tasks: + await asyncio.wait(self.active_tasks) + + def __await__(self): + """Handle the builtin await operator.""" + yield from self.flush().__await__() + + async def wait_for(self, timeout: float): + """Wait for all queued tasks to complete with a timeout.""" + return await asyncio.wait_for(self.flush(), timeout) diff --git a/aries_cloudagent/messaging/tests/test_task_queue.py b/aries_cloudagent/messaging/tests/test_task_queue.py new file mode 100644 index 0000000000..4b1ae08a71 --- /dev/null +++ b/aries_cloudagent/messaging/tests/test_task_queue.py @@ -0,0 +1,159 @@ +import asyncio +from asynctest import TestCase + +from ..task_queue import CompletedTask, TaskQueue + + +async def retval(val): + return val + + +class TestTaskQueue(TestCase): + async def test_run(self): + queue = TaskQueue() + task = None + completed = [] + + def done(complete: CompletedTask): + assert complete.task is task + assert not complete.exc_info + completed.append(complete.task.result()) + + task = queue.run(retval(1), done) + assert queue.current_active == 1 + assert len(queue) == queue.current_size == 1 + assert not queue.current_pending + await queue.flush() + assert completed == [1] + assert task.result() == 1 + + with self.assertRaises(ValueError): + queue.run(None, done) + + async def test_put_no_limit(self): + queue = TaskQueue() + completed = [] + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.append(complete.task.result()) + + fut = queue.put(retval(1), done) + assert not queue.pending_tasks + await queue.flush() + assert completed == [1] + assert fut.result().result() == 1 + + with self.assertRaises(ValueError): + queue.add_pending(None, done) + + async def test_put_limited(self): + queue = TaskQueue(1) + assert queue.max_active == 1 + assert not queue.cancelled + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + fut1 = queue.put(retval(1), done) + fut2 = queue.put(retval(2), done) + assert queue.pending_tasks + await queue.flush() + assert completed == {1, 2} + assert fut1.result().result() == 1 + assert fut2.result().result() == 2 + + async def test_complete(self): + queue = TaskQueue() + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + queue.run(retval(1), done) + await queue.put(retval(2), done) + queue.put(retval(3), done) + await queue.complete() + assert completed == {1, 2, 3} + + async def test_cancel_pending(self): + queue = TaskQueue(1) + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + queue.run(retval(1), done) + queue.put(retval(2), done) + queue.put(retval(3), done) + queue.cancel_pending() + await queue.flush() + assert completed == {1} + + async def test_cancel_all(self): + queue = TaskQueue(1) + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + queue.run(retval(1), done) + queue.put(retval(2), done) + queue.put(retval(3), done) + queue.cancel() + assert queue.cancelled + await queue.flush() + assert not completed + assert not queue.current_size + + co = retval(1) + with self.assertRaises(RuntimeError): + queue.run(co, done) + co.close() + + co = retval(1) + fut = queue.put(co) + assert fut.cancelled() + + async def test_cancel_long(self): + queue = TaskQueue() + task = queue.run(asyncio.sleep(5)) + queue.cancel() + await queue + + # cancellation may take a second + # assert task.cancelled() + + with self.assertRaises(asyncio.CancelledError): + await task + + async def test_complete_with_timeout(self): + queue = TaskQueue() + task = queue.run(asyncio.sleep(5)) + await queue.complete(0.01) + + # cancellation may take a second + # assert task.cancelled() + + with self.assertRaises(asyncio.CancelledError): + await task + + async def test_repeat_callback(self): + # check that running the callback twice does not throw an exception + + queue = TaskQueue() + completed = [] + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.append(complete.task.result()) + + task = queue.run(retval(1), done) + await task + queue.completed_task(task, done) + assert completed == [1, 1] diff --git a/aries_cloudagent/protocols/connections/handlers/connection_request_handler.py b/aries_cloudagent/protocols/connections/handlers/connection_request_handler.py index 491aa13adc..56f0bdb347 100644 --- a/aries_cloudagent/protocols/connections/handlers/connection_request_handler.py +++ b/aries_cloudagent/protocols/connections/handlers/connection_request_handler.py @@ -24,16 +24,16 @@ async def handle(self, context: RequestContext, responder: BaseResponder): mgr = ConnectionManager(context) try: - await mgr.receive_request(context.message, context.message_delivery) + await mgr.receive_request(context.message, context.message_receipt) except ConnectionManagerError as e: self._logger.exception("Error receiving connection request") if e.error_code: - target = None + targets = None if context.message.connection and context.message.connection.did_doc: try: - target = mgr.diddoc_connection_target( + targets = mgr.diddoc_connection_targets( context.message.connection.did_doc, - context.message_delivery.recipient_verkey, + context.message_receipt.recipient_verkey, ) except ConnectionManagerError: self._logger.exception( @@ -41,5 +41,5 @@ async def handle(self, context: RequestContext, responder: BaseResponder): ) await responder.send_reply( ProblemReport(problem_code=e.error_code, explain=str(e)), - target=target, + target_list=targets, ) diff --git a/aries_cloudagent/protocols/connections/handlers/connection_response_handler.py b/aries_cloudagent/protocols/connections/handlers/connection_response_handler.py index db49f62efc..2450453ed4 100644 --- a/aries_cloudagent/protocols/connections/handlers/connection_response_handler.py +++ b/aries_cloudagent/protocols/connections/handlers/connection_response_handler.py @@ -25,17 +25,17 @@ async def handle(self, context: RequestContext, responder: BaseResponder): mgr = ConnectionManager(context) try: connection = await mgr.accept_response( - context.message, context.message_delivery + context.message, context.message_receipt ) except ConnectionManagerError as e: self._logger.exception("Error receiving connection response") if e.error_code: - target = None + targets = None if context.message.connection and context.message.connection.did_doc: try: - target = mgr.diddoc_connection_target( + targets = mgr.diddoc_connection_targets( context.message.connection.did_doc, - context.message_delivery.recipient_verkey, + context.message_receipt.recipient_verkey, ) except ConnectionManagerError: self._logger.exception( @@ -43,7 +43,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder): ) await responder.send_reply( ProblemReport(problem_code=e.error_code, explain=str(e)), - target=target, + target_list=targets, ) return diff --git a/aries_cloudagent/protocols/connections/handlers/tests/test_invitation_handler.py b/aries_cloudagent/protocols/connections/handlers/tests/test_invitation_handler.py index 9e0757efec..4e1a5fee40 100644 --- a/aries_cloudagent/protocols/connections/handlers/tests/test_invitation_handler.py +++ b/aries_cloudagent/protocols/connections/handlers/tests/test_invitation_handler.py @@ -1,9 +1,9 @@ import pytest from .....messaging.base_handler import HandlerException -from .....messaging.message_delivery import MessageDelivery from .....messaging.request_context import RequestContext from .....messaging.responder import MockResponder +from .....transport.inbound.receipt import MessageReceipt from ...handlers.connection_invitation_handler import ConnectionInvitationHandler from ...messages.connection_invitation import ConnectionInvitation @@ -13,7 +13,7 @@ @pytest.fixture() def request_context() -> RequestContext: ctx = RequestContext() - ctx.message_delivery = MessageDelivery() + ctx.message_receipt = MessageReceipt() yield ctx diff --git a/aries_cloudagent/protocols/connections/handlers/tests/test_request_handler.py b/aries_cloudagent/protocols/connections/handlers/tests/test_request_handler.py index da9666b83b..969df6ad8d 100644 --- a/aries_cloudagent/protocols/connections/handlers/tests/test_request_handler.py +++ b/aries_cloudagent/protocols/connections/handlers/tests/test_request_handler.py @@ -2,9 +2,9 @@ from asynctest import mock as async_mock from .....messaging.base_handler import HandlerException -from .....messaging.message_delivery import MessageDelivery from .....messaging.request_context import RequestContext from .....messaging.responder import MockResponder +from .....transport.inbound.receipt import MessageReceipt from ...handlers import connection_request_handler as handler from ...manager import ConnectionManagerError @@ -15,7 +15,7 @@ @pytest.fixture() def request_context() -> RequestContext: ctx = RequestContext() - ctx.message_delivery = MessageDelivery() + ctx.message_receipt = MessageReceipt() yield ctx @@ -31,7 +31,7 @@ async def test_called(self, mock_conn_mgr, request_context): mock_conn_mgr.assert_called_once_with(request_context) mock_conn_mgr.return_value.receive_request.assert_called_once_with( - request_context.message, request_context.message_delivery + request_context.message, request_context.message_receipt ) assert not responder.messages @@ -53,4 +53,4 @@ async def test_problem_report(self, mock_conn_mgr, request_context): isinstance(result, ProblemReport) and result.problem_code == ProblemReportReason.REQUEST_NOT_ACCEPTED ) - assert target == {"target": None} + assert target == {"target_list": None} diff --git a/aries_cloudagent/protocols/connections/handlers/tests/test_response_handler.py b/aries_cloudagent/protocols/connections/handlers/tests/test_response_handler.py index 1569bc329b..29cfad4a78 100644 --- a/aries_cloudagent/protocols/connections/handlers/tests/test_response_handler.py +++ b/aries_cloudagent/protocols/connections/handlers/tests/test_response_handler.py @@ -2,9 +2,9 @@ from asynctest import mock as async_mock from .....messaging.base_handler import HandlerException -from .....messaging.message_delivery import MessageDelivery from .....messaging.request_context import RequestContext from .....messaging.responder import MockResponder +from .....transport.inbound.receipt import MessageReceipt from ...handlers import connection_response_handler as handler from ...manager import ConnectionManagerError @@ -15,7 +15,7 @@ @pytest.fixture() def request_context() -> RequestContext: ctx = RequestContext() - ctx.message_delivery = MessageDelivery() + ctx.message_receipt = MessageReceipt() yield ctx @@ -31,7 +31,7 @@ async def test_called(self, mock_conn_mgr, request_context): mock_conn_mgr.assert_called_once_with(request_context) mock_conn_mgr.return_value.accept_response.assert_called_once_with( - request_context.message, request_context.message_delivery + request_context.message, request_context.message_receipt ) assert not responder.messages @@ -53,4 +53,4 @@ async def test_problem_report(self, mock_conn_mgr, request_context): isinstance(result, ProblemReport) and result.problem_code == ProblemReportReason.RESPONSE_NOT_ACCEPTED ) - assert target == {"target": None} + assert target == {"target_list": None} diff --git a/aries_cloudagent/protocols/connections/manager.py b/aries_cloudagent/protocols/connections/manager.py index 57045fa7ef..cb4ba5a060 100644 --- a/aries_cloudagent/protocols/connections/manager.py +++ b/aries_cloudagent/protocols/connections/manager.py @@ -2,7 +2,7 @@ import logging -from typing import Tuple, List +from typing import Sequence, Tuple from ...cache.base import BaseCache from ...connections.models.connection_record import ConnectionRecord @@ -12,11 +12,11 @@ from ...config.injection_context import InjectionContext from ...error import BaseError from ...ledger.base import BaseLedger -from ...messaging.message_delivery import MessageDelivery from ...messaging.responder import BaseResponder from ...storage.base import BaseStorage from ...storage.error import StorageError, StorageNotFoundError from ...storage.record import StorageRecord +from ...transport.inbound.receipt import MessageReceipt from ...wallet.base import BaseWallet, DIDInfo from ...wallet.crypto import create_keypair, seed_to_did from ...wallet.error import WalletNotFoundError @@ -292,14 +292,14 @@ async def create_request( return request async def receive_request( - self, request: ConnectionRequest, delivery: MessageDelivery + self, request: ConnectionRequest, receipt: MessageReceipt ) -> ConnectionRecord: """ Receive and store a connection request. Args: request: The `ConnectionRequest` to accept - delivery: The message delivery metadata + receipt: The message receipt Returns: The new or updated `ConnectionRecord` instance @@ -313,12 +313,12 @@ async def receive_request( connection_key = None # Determine what key will need to sign the response - if delivery.recipient_did_public: + if receipt.recipient_did_public: wallet: BaseWallet = await self.context.inject(BaseWallet) - my_info = await wallet.get_local_did(delivery.recipient_did) + my_info = await wallet.get_local_did(receipt.recipient_did) connection_key = my_info.verkey else: - connection_key = delivery.recipient_verkey + connection_key = receipt.recipient_verkey try: connection = await ConnectionRecord.retrieve_by_invitation_key( self.context, connection_key, ConnectionRecord.INITIATOR_SELF @@ -476,7 +476,7 @@ async def create_response( return response async def accept_response( - self, response: ConnectionResponse, delivery: MessageDelivery + self, response: ConnectionResponse, receipt: MessageReceipt ) -> ConnectionRecord: """ Accept a connection response. @@ -486,7 +486,7 @@ async def accept_response( Args: response: The `ConnectionResponse` to accept - delivery: The message delivery metadata + receipt: The message receipt Returns: The updated `ConnectionRecord` representing the connection @@ -509,11 +509,11 @@ async def accept_response( except StorageNotFoundError: pass - if not connection and delivery.sender_did: + if not connection and receipt.sender_did: # identify connection by the DID they used for us try: connection = await ConnectionRecord.retrieve_by_did( - self.context, delivery.sender_did, delivery.recipient_did + self.context, receipt.sender_did, receipt.recipient_did ) except StorageNotFoundError: pass @@ -595,8 +595,6 @@ async def create_static_connection( their_verkey = bytes_to_b58(their_verkey_bin) their_info = DIDInfo(their_did, their_verkey, {}) - print(their_info, my_info) - # Create connection record connection = ConnectionRecord( initiator=ConnectionRecord.INITIATOR_SELF, @@ -670,14 +668,14 @@ async def find_connection( return connection - async def find_message_connection( - self, delivery: MessageDelivery + async def find_inbound_connection( + self, receipt: MessageReceipt ) -> ConnectionRecord: """ Deserialize an incoming message and further populate the request context. Args: - delivery: The message delivery details + receipt: The message receipt Returns: The `ConnectionRecord` associated with the expanded message, if any @@ -688,99 +686,91 @@ async def find_message_connection( connection = None resolved = False - if delivery.sender_verkey and delivery.recipient_verkey: + if receipt.sender_verkey and receipt.recipient_verkey: cache_key = ( - f"connection_by_verkey::{delivery.sender_verkey}" - f"::{delivery.recipient_verkey}" + f"connection_by_verkey::{receipt.sender_verkey}" + f"::{receipt.recipient_verkey}" ) cache: BaseCache = await self.context.inject(BaseCache, required=False) if cache: async with cache.acquire(cache_key) as entry: if entry.result: cached = entry.result - delivery.sender_did = cached["sender_did"] - delivery.recipient_did_public = cached["recipient_did_public"] - delivery.recipient_did = cached["recipient_did"] + receipt.sender_did = cached["sender_did"] + receipt.recipient_did_public = cached["recipient_did_public"] + receipt.recipient_did = cached["recipient_did"] connection = await ConnectionRecord.retrieve_by_id( self.context, cached["id"] ) else: - connection = await self.resolve_message_connection(delivery) + connection = await self.resolve_inbound_connection(receipt) if connection: cache_val = { "id": connection.connection_id, - "sender_did": delivery.sender_did, - "recipient_did": delivery.recipient_did, - "recipient_did_public": delivery.recipient_did_public, + "sender_did": receipt.sender_did, + "recipient_did": receipt.recipient_did, + "recipient_did_public": receipt.recipient_did_public, } await entry.set_result(cache_val, 3600) resolved = True if not connection and not resolved: - connection = await self.resolve_message_connection(delivery) + connection = await self.resolve_inbound_connection(receipt) return connection - async def resolve_message_connection( - self, delivery: MessageDelivery + async def resolve_inbound_connection( + self, receipt: MessageReceipt ) -> ConnectionRecord: """ - Populate the delivery DID information and find the related `ConnectionRecord`. + Populate the receipt DID information and find the related `ConnectionRecord`. Args: - delivery: The message delivery details + receipt: The message receipt Returns: The `ConnectionRecord` associated with the expanded message, if any """ - if delivery.sender_verkey: + if receipt.sender_verkey: try: - delivery.sender_did = await self.find_did_for_key( - delivery.sender_verkey - ) + receipt.sender_did = await self.find_did_for_key(receipt.sender_verkey) except StorageNotFoundError: self._logger.warning( "No corresponding DID found for sender verkey: %s", - delivery.sender_verkey, + receipt.sender_verkey, ) - if delivery.recipient_verkey: + if receipt.recipient_verkey: try: wallet: BaseWallet = await self.context.inject(BaseWallet) my_info = await wallet.get_local_did_for_verkey( - delivery.recipient_verkey + receipt.recipient_verkey ) - delivery.recipient_did = my_info.did - if ( - "public" in my_info.metadata - and my_info.metadata["public"] is True - ): - delivery.recipient_did_public = True + receipt.recipient_did = my_info.did + if "public" in my_info.metadata and my_info.metadata["public"] is True: + receipt.recipient_did_public = True except InjectorError: self._logger.warning( "Cannot resolve recipient verkey, no wallet defined by " "context: %s", - delivery.recipient_verkey, + receipt.recipient_verkey, ) except WalletNotFoundError: self._logger.warning( "No corresponding DID found for recipient verkey: %s", - delivery.recipient_verkey, + receipt.recipient_verkey, ) return await self.find_connection( - delivery.sender_did, - delivery.recipient_did, - delivery.recipient_verkey, - True, + receipt.sender_did, receipt.recipient_did, receipt.recipient_verkey, True, ) async def create_did_document( self, did_info: DIDInfo, inbound_connection_id: str = None, - svc_endpoints: List[str] = [], + svc_endpoints: Sequence[str] = [], ) -> DIDDoc: """Create our DID document for a given DID. @@ -928,33 +918,40 @@ async def remove_keys_for_did(self, did: str): for record in keys: await storage.delete_record(record) - async def get_connection_target( - self, connection: ConnectionRecord - ) -> ConnectionTarget: + async def get_connection_targets( + self, *, connection_id: str = None, connection: ConnectionRecord = None + ): """Create a connection target from a `ConnectionRecord`. Args: - connection: The connection record (with associated `DIDDoc`) - used to generate the connection target + connection_id: The connection ID to search for + connection: The connection record itself, if already available """ - + if not connection_id: + connection_id = connection.connection_id cache: BaseCache = await self.context.inject(BaseCache, required=False) - cache_key = f"connection_target::{connection.connection_id}" + cache_key = f"connection_target::{connection_id}" if cache: async with cache.acquire(cache_key) as entry: if entry.result: - return ConnectionTarget.deserialize(entry.result) + targets = [ + ConnectionTarget.deserialize(row) for row in entry.result + ] else: - target = await self.fetch_connection_target(connection) - await entry.set_result(target.serialize(), 60) + if not connection: + connection = await ConnectionRecord.retrieve_by_id( + self.context, connection_id + ) + targets = await self.fetch_connection_targets(connection) + await entry.set_result([row.serialize() for row in targets], 3600) else: - target = await self.fetch_connection_target(connection) - return target + targets = await self.fetch_connection_targets(connection) + return targets - async def fetch_connection_target( + async def fetch_connection_targets( self, connection: ConnectionRecord - ) -> ConnectionTarget: - """Create a connection target from a `ConnectionRecord`. + ) -> Sequence[ConnectionTarget]: + """Get a list of connection target from a `ConnectionRecord`. Args: connection: The connection record (with associated `DIDDoc`) @@ -967,6 +964,7 @@ async def fetch_connection_target( wallet: BaseWallet = await self.context.inject(BaseWallet) my_info = await wallet.get_local_did(connection.my_did) + results = None if ( connection.state in (connection.STATE_INVITATION, connection.STATE_REQUEST) @@ -991,30 +989,32 @@ async def fetch_connection_target( recipient_keys = invitation.recipient_keys routing_keys = invitation.routing_keys - result = ConnectionTarget( - did=connection.their_did, - endpoint=endpoint, - label=invitation.label, - recipient_keys=recipient_keys, - routing_keys=routing_keys, - sender_key=my_info.verkey, - ) + results = [ + ConnectionTarget( + did=connection.their_did, + endpoint=endpoint, + label=invitation.label, + recipient_keys=recipient_keys, + routing_keys=routing_keys, + sender_key=my_info.verkey, + ) + ] else: if not connection.their_did: self._logger.debug("No target DID associated with connection") return None doc = await self.fetch_did_document(connection.their_did) - result = self.diddoc_connection_target( + results = self.diddoc_connection_targets( doc, my_info.verkey, connection.their_label ) - return result + return results - def diddoc_connection_target( + def diddoc_connection_targets( self, doc: DIDDoc, sender_verkey: str, their_label: str = None - ) -> ConnectionTarget: - """Create a connection target from a DID Document. + ) -> Sequence[ConnectionTarget]: + """Get a list of connection targets from a DID Document. Args: doc: The DID Document to create the target from @@ -1029,18 +1029,24 @@ def diddoc_connection_target( if not doc.service: raise ConnectionManagerError("No services defined by DIDDoc") + targets = [] for service in doc.service.values(): - if not service.recip_keys: - raise ConnectionManagerError("DIDDoc service has no recipient key(s)") - - return ConnectionTarget( - did=doc.did, - endpoint=service.endpoint, - label=their_label, - recipient_keys=[key.value for key in (service.recip_keys or ())], - routing_keys=[key.value for key in (service.routing_keys or ())], - sender_key=sender_verkey, - ) + if service.recip_keys: + targets.append( + ConnectionTarget( + did=doc.did, + endpoint=service.endpoint, + label=their_label, + recipient_keys=[ + key.value for key in (service.recip_keys or ()) + ], + routing_keys=[ + key.value for key in (service.routing_keys or ()) + ], + sender_key=sender_verkey, + ) + ) + return targets async def establish_inbound( self, connection: ConnectionRecord, inbound_connection_id: str, outbound_handler diff --git a/aries_cloudagent/protocols/connections/tests/test_manager.py b/aries_cloudagent/protocols/connections/tests/test_manager.py index dcaefaf288..e2fd4a35c3 100644 --- a/aries_cloudagent/protocols/connections/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/tests/test_manager.py @@ -3,9 +3,9 @@ from ....config.injection_context import InjectionContext from ....connections.models.connection_record import ConnectionRecord from ....connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service -from ....messaging.message_delivery import MessageDelivery from ....storage.base import BaseStorage from ....storage.basic import BasicStorage +from ....transport.inbound.receipt import MessageReceipt from ....wallet.base import BaseWallet from ....wallet.basic import BasicWallet @@ -65,7 +65,7 @@ async def test_non_multi_use_invitation_fails_on_reuse(self): my_endpoint="testendpoint" ) - delivery = MessageDelivery(recipient_verkey=connect_record.invitation_key) + receipt = MessageReceipt(recipient_verkey=connect_record.invitation_key) requestA = ConnectionRequest( connection=ConnectionDetail( @@ -77,7 +77,7 @@ async def test_non_multi_use_invitation_fails_on_reuse(self): label="SameInviteRequestA", ) - await self.manager.receive_request(requestA, delivery) + await self.manager.receive_request(requestA, receipt) requestB = ConnectionRequest( connection=ConnectionDetail( @@ -88,7 +88,7 @@ async def test_non_multi_use_invitation_fails_on_reuse(self): ) # requestB fails because the invitation was not set to multi-use - rr_awaitable = self.manager.receive_request(requestB, delivery) + rr_awaitable = self.manager.receive_request(requestB, receipt) await self.assertAsyncRaises(ConnectionManagerError, rr_awaitable) async def test_multi_use_invitation(self): @@ -96,7 +96,7 @@ async def test_multi_use_invitation(self): my_endpoint="testendpoint", multi_use=True ) - delivery = MessageDelivery(recipient_verkey=connect_record.invitation_key) + receipt = MessageReceipt(recipient_verkey=connect_record.invitation_key) requestA = ConnectionRequest( connection=ConnectionDetail( @@ -108,7 +108,7 @@ async def test_multi_use_invitation(self): label="SameInviteRequestA", ) - await self.manager.receive_request(requestA, delivery) + await self.manager.receive_request(requestA, receipt) requestB = ConnectionRequest( connection=ConnectionDetail( @@ -118,5 +118,5 @@ async def test_multi_use_invitation(self): label="SameInviteRequestB", ) - await self.manager.receive_request(requestB, delivery) + await self.manager.receive_request(requestB, receipt) diff --git a/aries_cloudagent/protocols/introduction/base_service.py b/aries_cloudagent/protocols/introduction/base_service.py index 8a9d72820f..e04a603ade 100644 --- a/aries_cloudagent/protocols/introduction/base_service.py +++ b/aries_cloudagent/protocols/introduction/base_service.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from ...error import BaseError -from ...messaging.base_context import BaseRequestContext +from ...messaging.request_context import RequestContext from .messages.invitation import Invitation @@ -15,7 +15,7 @@ class IntroductionError(BaseError): class BaseIntroductionService(ABC): """Service handler for allowing connections to exchange invitations.""" - def __init__(self, context: BaseRequestContext): + def __init__(self, context: RequestContext): """Init admin service.""" self._context = context @@ -23,7 +23,7 @@ def __init__(self, context: BaseRequestContext): def service_handler(cls): """Quick accessor for conductor to use.""" - async def get_instance(context: BaseRequestContext): + async def get_instance(context: RequestContext): """Return registered server.""" return cls(context) diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_ack_handler.py b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_ack_handler.py index b990c5ba3f..c389c5f083 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_ack_handler.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_ack_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.credential_ack import CredentialAck from .. import credential_ack_handler as handler @@ -15,7 +15,7 @@ class TestCredentialAckHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True @@ -35,7 +35,7 @@ async def test_called(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_issue_handler.py b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_issue_handler.py index f62f123410..8de78f0c8d 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_issue_handler.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_issue_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.credential_issue import CredentialIssue from .. import credential_issue_handler as handler @@ -15,7 +15,7 @@ class TestCredentialIssueHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_store_credential"] = False with async_mock.patch.object( @@ -34,7 +34,7 @@ async def test_called(self): async def test_called_auto_store(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_store_credential"] = True request_context.connection_record = async_mock.MagicMock() @@ -61,14 +61,12 @@ async def test_called_auto_store(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True ) as mock_cred_mgr: - mock_cred_mgr.return_value.receive_credential = ( - async_mock.CoroutineMock() - ) + mock_cred_mgr.return_value.receive_credential = async_mock.CoroutineMock() request_context.message = CredentialIssue() request_context.connection_ready = False handler_inst = handler.CredentialIssueHandler() diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_offer_handler.py b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_offer_handler.py index a8daf31e20..b39fe16d1d 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_offer_handler.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_offer_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.credential_offer import CredentialOffer from .. import credential_offer_handler as handler @@ -15,7 +15,7 @@ class TestCredentialOfferHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_respond_credential_offer"] = False with async_mock.patch.object( @@ -34,7 +34,7 @@ async def test_called(self): async def test_called_auto_request(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_respond_credential_offer"] = True request_context.connection_record = async_mock.MagicMock() request_context.connection_record.my_did = "dummy" @@ -62,14 +62,12 @@ async def test_called_auto_request(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True ) as mock_cred_mgr: - mock_cred_mgr.return_value.receive_offer = ( - async_mock.CoroutineMock() - ) + mock_cred_mgr.return_value.receive_offer = async_mock.CoroutineMock() request_context.message = CredentialOffer() request_context.connection_ready = False handler_inst = handler.CredentialOfferHandler() diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_proposal_handler.py b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_proposal_handler.py index 0a1136c29c..470e532e76 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_proposal_handler.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_proposal_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.credential_proposal import CredentialProposal from .. import credential_proposal_handler as handler @@ -15,7 +15,7 @@ class TestCredentialProposalHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True @@ -36,7 +36,7 @@ async def test_called(self): async def test_called_auto_offer(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.connection_record = async_mock.MagicMock() with async_mock.patch.object( @@ -65,14 +65,12 @@ async def test_called_auto_offer(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True ) as mock_cred_mgr: - mock_cred_mgr.return_value.receive_proposal = ( - async_mock.CoroutineMock() - ) + mock_cred_mgr.return_value.receive_proposal = async_mock.CoroutineMock() request_context.message = CredentialProposal() request_context.connection_ready = False handler_inst = handler.CredentialProposalHandler() diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_request_handler.py b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_request_handler.py index e9e274217b..00370a9ecd 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_request_handler.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/handlers/tests/test_credential_request_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.credential_request import CredentialRequest from .. import credential_request_handler as handler @@ -15,7 +15,7 @@ class TestCredentialRequestHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True @@ -36,7 +36,7 @@ async def test_called(self): async def test_called_auto_issue(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.connection_record = async_mock.MagicMock() with async_mock.patch.object( @@ -55,7 +55,7 @@ async def test_called_auto_issue(self): return_value=mock_cred_proposal ) mock_cred_proposal.credential_proposal = async_mock.MagicMock() - mock_cred_proposal.credential_proposal.attr_dict=async_mock.MagicMock() + mock_cred_proposal.credential_proposal.attr_dict = async_mock.MagicMock() request_context.message = CredentialRequest() request_context.connection_ready = True handler_inst = handler.CredentialRequestHandler() @@ -72,14 +72,12 @@ async def test_called_auto_issue(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "CredentialManager", autospec=True ) as mock_cred_mgr: - mock_cred_mgr.return_value.receive_request = ( - async_mock.CoroutineMock() - ) + mock_cred_mgr.return_value.receive_request = async_mock.CoroutineMock() request_context.message = CredentialRequest() request_context.connection_ready = False handler_inst = handler.CredentialRequestHandler() diff --git a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_ack_handler.py b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_ack_handler.py index 9faa639753..5a916678ee 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_ack_handler.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_ack_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.presentation_ack import PresentationAck from .. import presentation_ack_handler as handler @@ -15,7 +15,7 @@ class TestPresentationAckHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "PresentationManager", autospec=True @@ -35,7 +35,7 @@ async def test_called(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "PresentationManager", autospec=True diff --git a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_handler.py b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_handler.py index 0f641c2de1..37a88afd8a 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_handler.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.presentation import Presentation from .. import presentation_handler as handler @@ -15,7 +15,7 @@ class TestPresentationHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_verify_presentation"] = False with async_mock.patch.object( @@ -34,7 +34,7 @@ async def test_called(self): async def test_called_auto_verify(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_verify_presentation"] = True with async_mock.patch.object( @@ -54,14 +54,12 @@ async def test_called_auto_verify(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "PresentationManager", autospec=True ) as mock_pres_mgr: - mock_pres_mgr.return_value.receive_presentation = ( - async_mock.CoroutineMock() - ) + mock_pres_mgr.return_value.receive_presentation = async_mock.CoroutineMock() request_context.message = Presentation() request_context.connection_ready = False handler_inst = handler.PresentationHandler() diff --git a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_proposal_handler.py b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_proposal_handler.py index 4313f42be6..7aaefa75ac 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_proposal_handler.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_proposal_handler.py @@ -5,8 +5,8 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder +from ......transport.inbound.receipt import MessageReceipt from ...messages.presentation_proposal import PresentationProposal from .. import presentation_proposal_handler as handler @@ -15,7 +15,7 @@ class TestPresentationProposalHandler(AsyncTestCase): async def test_called(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_respond_presentation_proposal"] = False with async_mock.patch.object( @@ -38,7 +38,7 @@ async def test_called_auto_request(self): request_context = RequestContext() request_context.message = async_mock.MagicMock() request_context.message.comment = "hello world" - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.settings["debug.auto_respond_presentation_proposal"] = True with async_mock.patch.object( @@ -50,7 +50,7 @@ async def test_called_auto_request(self): mock_pres_mgr.return_value.create_bound_request = async_mock.CoroutineMock( return_value=( mock_pres_mgr.return_value.receive_proposal.return_value, - "presentation_request_message" + "presentation_request_message", ) ) request_context.message = PresentationProposal() @@ -64,7 +64,7 @@ async def test_called_auto_request(self): presentation_exchange_record=( mock_pres_mgr.return_value.receive_proposal.return_value ), - comment=request_context.message.comment + comment=request_context.message.comment, ) messages = responder.messages assert len(messages) == 1 @@ -74,14 +74,12 @@ async def test_called_auto_request(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "PresentationManager", autospec=True ) as mock_pres_mgr: - mock_pres_mgr.return_value.receive_proposal = ( - async_mock.CoroutineMock() - ) + mock_pres_mgr.return_value.receive_proposal = async_mock.CoroutineMock() request_context.message = PresentationProposal() request_context.connection_ready = False handler_inst = handler.PresentationProposalHandler() diff --git a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_request_handler.py b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_request_handler.py index ba24b9cdfb..07e266851b 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_request_handler.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/handlers/tests/test_presentation_request_handler.py @@ -5,9 +5,9 @@ ) from ......messaging.request_context import RequestContext -from ......messaging.message_delivery import MessageDelivery from ......messaging.responder import MockResponder from ......storage.error import StorageNotFoundError +from ......transport.inbound.receipt import MessageReceipt from ...messages.presentation_request import PresentationRequest from .. import presentation_request_handler as handler @@ -18,7 +18,7 @@ async def test_called(self): request_context = RequestContext() request_context.connection_record = async_mock.MagicMock() request_context.connection_record.connection_id = "dummy" - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.message = PresentationRequest() request_context.message.indy_proof_request = async_mock.MagicMock( return_value=async_mock.MagicMock() @@ -54,7 +54,7 @@ async def test_called_not_found(self): request_context = RequestContext() request_context.connection_record = async_mock.MagicMock() request_context.connection_record.connection_id = "dummy" - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.message = PresentationRequest() request_context.message.indy_proof_request = async_mock.MagicMock( return_value=async_mock.MagicMock() @@ -95,7 +95,7 @@ async def test_called_auto_present(self): request_context.message.indy_proof_request = async_mock.MagicMock( return_value=async_mock.MagicMock() ) - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "PresentationManager", autospec=True @@ -119,7 +119,7 @@ async def test_called_auto_present(self): handler.indy_proof_request2indy_requested_creds = async_mock.CoroutineMock( return_value=async_mock.MagicMock() ) - + mock_pres_mgr.return_value.create_presentation = async_mock.CoroutineMock( return_value=(mock_pres_ex_rec, "presentation_message") ) @@ -147,7 +147,7 @@ async def test_called_auto_present_value_error(self): request_context.message.indy_proof_request = async_mock.MagicMock( return_value=async_mock.MagicMock() ) - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "PresentationManager", autospec=True @@ -171,7 +171,7 @@ async def test_called_auto_present_value_error(self): handler.indy_proof_request2indy_requested_creds = async_mock.CoroutineMock( side_effect=ValueError ) - + mock_pres_mgr.return_value.create_presentation = async_mock.CoroutineMock( return_value=(mock_pres_ex_rec, "presentation_message") ) @@ -189,14 +189,12 @@ async def test_called_auto_present_value_error(self): async def test_called_not_ready(self): request_context = RequestContext() - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() with async_mock.patch.object( handler, "PresentationManager", autospec=True ) as mock_pres_mgr: - mock_pres_mgr.return_value.receive_request = ( - async_mock.CoroutineMock() - ) + mock_pres_mgr.return_value.receive_request = async_mock.CoroutineMock() request_context.message = PresentationRequest() request_context.connection_ready = False handler_inst = handler.PresentationRequestHandler() diff --git a/aries_cloudagent/protocols/problem_report/handler.py b/aries_cloudagent/protocols/problem_report/handler.py index 8a094358f3..2f56452846 100644 --- a/aries_cloudagent/protocols/problem_report/handler.py +++ b/aries_cloudagent/protocols/problem_report/handler.py @@ -22,7 +22,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.info( "Received problem report from: %s, %r", - context.message_delivery.sender_did, + context.message_receipt.sender_did, context.message, ) diff --git a/aries_cloudagent/protocols/problem_report/tests/test_handler.py b/aries_cloudagent/protocols/problem_report/tests/test_handler.py index 9c616b1fe1..7d765c8f49 100644 --- a/aries_cloudagent/protocols/problem_report/tests/test_handler.py +++ b/aries_cloudagent/protocols/problem_report/tests/test_handler.py @@ -1,9 +1,9 @@ import pytest from ....messaging.base_handler import HandlerException -from ....messaging.message_delivery import MessageDelivery from ....messaging.request_context import RequestContext from ....messaging.responder import MockResponder +from ....transport.inbound.receipt import MessageReceipt from ..handler import ProblemReportHandler from ..message import ProblemReport @@ -18,7 +18,7 @@ def request_context() -> RequestContext: class TestPingHandler: @pytest.mark.asyncio async def test_problem_report(self, request_context): - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.message = ProblemReport() request_context.connection_ready = True handler = ProblemReportHandler() diff --git a/aries_cloudagent/protocols/routing/handlers/forward_handler.py b/aries_cloudagent/protocols/routing/handlers/forward_handler.py index 21e308682a..bf7322a776 100644 --- a/aries_cloudagent/protocols/routing/handlers/forward_handler.py +++ b/aries_cloudagent/protocols/routing/handlers/forward_handler.py @@ -19,10 +19,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): self._logger.debug("ForwardHandler called with context %s", context) assert isinstance(context.message, Forward) - if not context.message_delivery.recipient_verkey: + if not context.message_receipt.recipient_verkey: raise HandlerException("Cannot forward message: unknown recipient") self._logger.info( - "Received forward for: %s", context.message_delivery.recipient_verkey + "Received forward for: %s", context.message_receipt.recipient_verkey ) packed = context.message.msg.encode("ascii") diff --git a/aries_cloudagent/protocols/routing/handlers/tests/test_query_update_handlers.py b/aries_cloudagent/protocols/routing/handlers/tests/test_query_update_handlers.py index 047576ba68..6e05e4eda0 100644 --- a/aries_cloudagent/protocols/routing/handlers/tests/test_query_update_handlers.py +++ b/aries_cloudagent/protocols/routing/handlers/tests/test_query_update_handlers.py @@ -2,11 +2,11 @@ from .....connections.models.connection_record import ConnectionRecord from .....messaging.base_handler import HandlerException -from .....messaging.message_delivery import MessageDelivery from .....messaging.request_context import RequestContext from .....messaging.responder import MockResponder from .....storage.base import BaseStorage from .....storage.basic import BasicStorage +from .....transport.inbound.receipt import MessageReceipt from ...handlers.route_query_request_handler import RouteQueryRequestHandler from ...handlers.route_update_request_handler import RouteUpdateRequestHandler @@ -27,7 +27,7 @@ def request_context() -> RequestContext: ctx = RequestContext() ctx.connection_ready = True ctx.connection_record = ConnectionRecord(connection_id="conn-id") - ctx.message_delivery = MessageDelivery(sender_verkey=TEST_VERKEY) + ctx.message_receipt = MessageReceipt(sender_verkey=TEST_VERKEY) ctx.injector.bind_instance(BaseStorage, BasicStorage()) yield ctx diff --git a/aries_cloudagent/protocols/routing/tests/test_routing_manager.py b/aries_cloudagent/protocols/routing/tests/test_routing_manager.py index b8375c6f0d..799ad087dd 100644 --- a/aries_cloudagent/protocols/routing/tests/test_routing_manager.py +++ b/aries_cloudagent/protocols/routing/tests/test_routing_manager.py @@ -1,9 +1,9 @@ import pytest -from aries_cloudagent.messaging.message_delivery import MessageDelivery from aries_cloudagent.messaging.request_context import RequestContext from aries_cloudagent.storage.base import BaseStorage from aries_cloudagent.storage.basic import BasicStorage +from aries_cloudagent.transport.inbound.receipt import MessageReceipt from ..manager import RoutingManager, RoutingManagerError from ..models.route_record import RouteRecord @@ -16,7 +16,7 @@ @pytest.fixture() def request_context() -> RequestContext: ctx = RequestContext() - ctx.message_delivery = MessageDelivery(sender_verkey=TEST_VERKEY) + ctx.message_receipt = MessageReceipt(sender_verkey=TEST_VERKEY) ctx.injector.bind_instance(BaseStorage, BasicStorage()) yield ctx @@ -24,7 +24,7 @@ def request_context() -> RequestContext: @pytest.fixture() def manager() -> RoutingManager: ctx = RequestContext() - ctx.message_delivery = MessageDelivery(sender_verkey=TEST_VERKEY) + ctx.message_receipt = MessageReceipt(sender_verkey=TEST_VERKEY) ctx.injector.bind_instance(BaseStorage, BasicStorage()) return RoutingManager(ctx) diff --git a/aries_cloudagent/protocols/trustping/handlers/ping_handler.py b/aries_cloudagent/protocols/trustping/handlers/ping_handler.py index b66451b8cc..fe00b4dd3d 100644 --- a/aries_cloudagent/protocols/trustping/handlers/ping_handler.py +++ b/aries_cloudagent/protocols/trustping/handlers/ping_handler.py @@ -22,13 +22,13 @@ async def handle(self, context: RequestContext, responder: BaseResponder): assert isinstance(context.message, Ping) self._logger.info( - "Received trust ping from: %s", context.message_delivery.sender_did + "Received trust ping from: %s", context.message_receipt.sender_did ) if not context.connection_ready: self._logger.info( "Connection not active, skipping ping response: %s", - context.message_delivery.sender_did, + context.message_receipt.sender_did, ) return @@ -42,7 +42,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder): "ping", { "comment": context.message.comment, - "connection_id": context.message_delivery.connection_id, + "connection_id": context.message_receipt.connection_id, "responded": context.message.response_requested, "state": "received", "thread_id": context.message._thread_id, diff --git a/aries_cloudagent/protocols/trustping/handlers/ping_response_handler.py b/aries_cloudagent/protocols/trustping/handlers/ping_response_handler.py index 48a3d358f4..0a5d5cf5e9 100644 --- a/aries_cloudagent/protocols/trustping/handlers/ping_response_handler.py +++ b/aries_cloudagent/protocols/trustping/handlers/ping_response_handler.py @@ -22,7 +22,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder): assert isinstance(context.message, PingResponse) self._logger.info( - "Received trust ping response from: %s", context.message_delivery.sender_did + "Received trust ping response from: %s", context.message_receipt.sender_did ) if context.settings.get("debug.monitor_ping"): @@ -30,7 +30,7 @@ async def handle(self, context: RequestContext, responder: BaseResponder): "ping", { "comment": context.message.comment, - "connection_id": context.message_delivery.connection_id, + "connection_id": context.message_receipt.connection_id, "state": "response_received", "thread_id": context.message._thread_id, }, diff --git a/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_handler.py b/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_handler.py index bb8d8a6fbd..5913ff0f65 100644 --- a/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_handler.py +++ b/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_handler.py @@ -1,9 +1,9 @@ import pytest from .....messaging.base_handler import HandlerException -from .....messaging.message_delivery import MessageDelivery from .....messaging.request_context import RequestContext from .....messaging.responder import MockResponder +from .....transport.inbound.receipt import MessageReceipt from ...handlers.ping_handler import PingHandler from ...messages.ping import Ping @@ -19,7 +19,7 @@ def request_context() -> RequestContext: class TestPingHandler: @pytest.mark.asyncio async def test_ping(self, request_context): - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.message = Ping(response_requested=False) request_context.connection_ready = True handler = PingHandler() @@ -30,7 +30,7 @@ async def test_ping(self, request_context): @pytest.mark.asyncio async def test_ping_response(self, request_context): - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.message = Ping(response_requested=True) request_context.connection_ready = True handler = PingHandler() diff --git a/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_response_handler.py b/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_response_handler.py index 72fa522d0a..e0d2ab87fe 100644 --- a/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_response_handler.py +++ b/aries_cloudagent/protocols/trustping/handlers/tests/test_ping_response_handler.py @@ -1,9 +1,9 @@ import pytest from .....messaging.base_handler import HandlerException -from .....messaging.message_delivery import MessageDelivery from .....messaging.request_context import RequestContext from .....messaging.responder import MockResponder +from .....transport.inbound.receipt import MessageReceipt from ...handlers.ping_response_handler import PingResponseHandler from ...messages.ping_response import PingResponse @@ -18,7 +18,7 @@ def request_context() -> RequestContext: class TestPingResponseHandler: @pytest.mark.asyncio async def test_ping_response(self, request_context): - request_context.message_delivery = MessageDelivery() + request_context.message_receipt = MessageReceipt() request_context.message = PingResponse() request_context.connection_ready = True handler = PingResponseHandler() diff --git a/aries_cloudagent/task_processor.py b/aries_cloudagent/task_processor.py deleted file mode 100644 index e577fa61aa..0000000000 --- a/aries_cloudagent/task_processor.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Classes for managing a limited set of concurrent tasks.""" - -import asyncio -import logging -import time -from typing import Awaitable, Callable - -LOGGER = logging.getLogger(__name__) - - -async def delay_task(delay: float, task: Awaitable): - """Wait a given amount of time before executing an awaitable.""" - await asyncio.sleep(delay) - return await task - - -class PendingTask: - """Class for tracking pending tasks.""" - - def __init__( - self, - ident, - fn: Callable[["PendingTask"], Awaitable], - retries: int = None, - retry_delay: float = None, - ): - """Initialize the pending task instance.""" - self.attempts = 0 - self.ident = ident - self.fn = fn - self.future = asyncio.get_event_loop().create_future() - self.retries = retries - self.retry_delay = retry_delay - self.running: asyncio.Future = None - self.start = time.perf_counter() - - def done(self): - """Check if the task is done.""" - return self.future.done() - - def exception(self): - """Get the exception raised by the task, if any.""" - return self.future.exception() - - def result(self): - """Get the result of the task.""" - return self.future.result() - - def cancel(self): - """Cancel the running task.""" - if not self.future.done(): - self.future.cancel() - if self.running and not self.running.done(): - self.running.cancel() - - def __await__(self): - """Await the pending task.""" - return self.future.__await__() - - -class TaskProcessor: - """Class for managing a limited set of concurrent tasks.""" - - def __init__(self, *, max_pending: int = 10): - """Instantiate the dispatcher.""" - self.done_event = asyncio.Event() - self.done_event.set() - self.loop = asyncio.get_event_loop() - self.max_pending = max_pending - self.pending = set() - self.pending_lock = asyncio.Lock() - self.ready_event = asyncio.Event() - self.ready_event.set() - - def ready(self): - """Check if the processor is ready.""" - return self.ready_event.is_set() - - async def wait_ready(self): - """Wait for the processor to be ready for more tasks.""" - await self.ready_event.wait() - - def done(self): - """Check if the processor has any pending tasks.""" - return self.done_event.is_set() - - async def wait_done(self): - """Wait for all pending tasks to complete.""" - await self.done_event.wait() - - def _enqueue_task(self, task: PendingTask): - """Enqueue the given pending task.""" - if not task.done(): - awaitable = task.fn(task) - if awaitable: - if task.attempts and task.retry_delay: - awaitable = delay_task(task.retry_delay, awaitable) - task.attempts += 1 - task.running = asyncio.ensure_future(awaitable) - task.running.add_done_callback( - lambda fut: self.loop.create_task(self._check_task(task)) - ) - else: - task.future.set_result(None) - self.loop.create_task(self._check_task(task)) - - async def _check_task(self, task: PendingTask): - """Complete a task.""" - if task.running and task.running.done(): - future = task.running - task.running = None - exception = future.exception() - if exception: - LOGGER.debug( - "Task raised exception: (%s) %s", task.ident or task, exception - ) - if task.retries and task.attempts < task.retries: - asyncio.get_event_loop().call_soon(self._enqueue_task, task) - else: - LOGGER.warning("Task failed: %s", task.ident or task) - task.future.set_exception(exception) - else: - task.future.set_result(future.result()) - if task.done(): - async with self.pending_lock: - if task in self.pending: - self.pending.remove(task) - else: - LOGGER.warning( - "Task not found in pending list: %s", task.ident or task - ) - if len(self.pending) < self.max_pending: - self.ready_event.set() - if not self.pending: - self.done_event.set() - - async def run_retry( - self, - fn: Callable[[PendingTask], Awaitable], - *, - ident=None, - retries: int = 5, - retry_delay: float = 10.0, - when_ready: bool = True, - ) -> PendingTask: - """Process a task and track the result.""" - if when_ready: - await self.wait_ready() - task = PendingTask(ident, fn, retries=retries, retry_delay=retry_delay) - async with self.pending_lock: - self.pending.add(task) - self.done_event.clear() - if len(self.pending) >= self.max_pending: - self.ready_event.clear() - asyncio.get_event_loop().call_soon(self._enqueue_task, task) - return task - - async def run_task( - self, task: Awaitable, *, ident=None, when_ready: bool = True - ) -> PendingTask: - """Run a single coroutine with no retries.""" - return await self.run_retry( - lambda pending: task, ident=ident, retries=0, when_ready=when_ready - ) diff --git a/aries_cloudagent/tests/test_conductor.py b/aries_cloudagent/tests/test_conductor.py index 3b2f508d45..2c4a6bf3ce 100644 --- a/aries_cloudagent/tests/test_conductor.py +++ b/aries_cloudagent/tests/test_conductor.py @@ -16,25 +16,21 @@ PublicKeyType, Service, ) -from ..messaging.message_delivery import MessageDelivery -from ..messaging.serializer import MessageSerializer -from ..messaging.outbound_message import OutboundMessage from ..messaging.protocol_registry import ProtocolRegistry from ..stats import Collector from ..storage.base import BaseStorage from ..storage.basic import BasicStorage from ..transport.inbound.base import InboundTransportConfiguration -from ..transport.outbound.queue.base import BaseOutboundMessageQueue -from ..transport.outbound.queue.basic import BasicOutboundMessageQueue +from ..transport.inbound.message import InboundMessage +from ..transport.inbound.receipt import MessageReceipt +from ..transport.outbound.base import OutboundDeliveryError +from ..transport.outbound.message import OutboundMessage +from ..transport.wire_format import BaseWireFormat from ..wallet.base import BaseWallet from ..wallet.basic import BasicWallet class Config: - good_inbound_transports = {"transport.inbound_configs": [["http", "host", 80]]} - good_outbound_transports = {"transport.outbound_configs": ["http"]} - bad_inbound_transports = {"transport.inbound_configs": [["bad", "host", 80]]} - bad_outbound_transports = {"transport.outbound_configs": ["bad"]} test_settings = {} test_settings_with_queue = {"queue.enable_undelivered_queue": True} @@ -70,18 +66,15 @@ def make_did_doc(self, did, verkey): class StubContextBuilder(ContextBuilder): def __init__(self, settings): super().__init__(settings) - self.message_serializer = async_mock.create_autospec(MessageSerializer()) + self.wire_format = async_mock.create_autospec(BaseWireFormat()) async def build(self) -> InjectionContext: context = InjectionContext(settings=self.settings) context.injector.enforce_typing = False - context.injector.bind_instance( - BaseOutboundMessageQueue, BasicOutboundMessageQueue() - ) context.injector.bind_instance(BaseStorage, BasicStorage()) context.injector.bind_instance(BaseWallet, BasicWallet()) context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) - context.injector.bind_instance(MessageSerializer, self.message_serializer) + context.injector.bind_instance(BaseWireFormat, self.wire_format) return context @@ -95,8 +88,6 @@ async def build(self) -> InjectionContext: class TestConductor(AsyncTestCase, Config, TestDIDs): async def test_startup(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) - builder.update_settings(self.good_inbound_transports) - builder.update_settings(self.good_outbound_transports) conductor = test_module.Conductor(builder) with async_mock.patch.object( @@ -109,21 +100,19 @@ async def test_startup(self): await conductor.setup() - mock_inbound_mgr.return_value.register.assert_called_once_with( - InboundTransportConfiguration(module="http", host="host", port=80), - conductor.inbound_message_router, - conductor.register_socket, - ) - mock_outbound_mgr.return_value.register.assert_called_once_with("http") + mock_inbound_mgr.return_value.setup.assert_awaited_once() + mock_outbound_mgr.return_value.setup.assert_awaited_once() - mock_inbound_mgr.return_value.registered_transports = [] - mock_outbound_mgr.return_value.registered_transports = [] + mock_inbound_mgr.return_value.registered_transports = {} + mock_outbound_mgr.return_value.registered_transports = {} await conductor.start() mock_inbound_mgr.return_value.start.assert_awaited_once_with() mock_outbound_mgr.return_value.start.assert_awaited_once_with() + mock_logger.print_banner.assert_called_once() + await conductor.stop() mock_inbound_mgr.return_value.stop.assert_awaited_once_with() @@ -136,70 +125,47 @@ async def test_inbound_message_handler(self): await conductor.setup() with async_mock.patch.object( - conductor.dispatcher, "dispatch", new_callable=async_mock.CoroutineMock + conductor.dispatcher, "queue_message", autospec=True ) as mock_dispatch: - delivery = MessageDelivery() - parsed_msg = {} - mock_serializer = builder.message_serializer - mock_serializer.extract_message_type.return_value = "message_type" - mock_serializer.parse_message.return_value = (parsed_msg, delivery) - message_body = "{}" - transport = "http" - await conductor.inbound_message_router(message_body, transport) + receipt = MessageReceipt() + message = InboundMessage(message_body, receipt) - mock_serializer.parse_message.assert_awaited_once_with( - conductor.context, message_body, transport - ) + conductor.inbound_message_router(message) - mock_dispatch.assert_awaited_once_with( - parsed_msg, delivery, None, conductor.outbound_message_router - ) + mock_dispatch.assert_called_once() + assert mock_dispatch.call_args[0][0] is message + assert mock_dispatch.call_args[0][1] == conductor.outbound_message_router + assert mock_dispatch.call_args[0][2] is None # admin webhook router + assert callable(mock_dispatch.call_args[0][3]) - async def test_direct_response(self): + async def test_outbound_message_handler_return_route(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) conductor = test_module.Conductor(builder) + test_to_verkey = "test-to-verkey" + test_from_verkey = "test-from-verkey" await conductor.setup() - single_response = asyncio.Future() - dispatch_result = """{"@type": "..."}""" - - async def mock_dispatch(parsed_msg, delivery, connection, outbound): - socket_id = delivery.socket_id - socket = conductor.sockets[socket_id] - socket.reply_mode = "all" - reply = OutboundMessage( - dispatch_result, - connection_id=None, - encoded=False, - endpoint=None, - reply_socket_id=socket_id, - ) - await outbound(reply) - result = asyncio.Future() - result.set_result(None) - return result - - with async_mock.patch.object(conductor.dispatcher, "dispatch", mock_dispatch): - - delivery = MessageDelivery() - parsed_msg = {} - mock_serializer = builder.message_serializer - mock_serializer.extract_message_type.return_value = "message_type" - mock_serializer.parse_message.return_value = (parsed_msg, delivery) + payload = "{}" + message = OutboundMessage(payload=payload) + message.reply_to_verkey = test_to_verkey + receipt = MessageReceipt() + receipt.recipient_verkey = test_from_verkey + inbound = InboundMessage("[]", receipt) - message_body = "{}" - transport = "http" - complete = await conductor.inbound_message_router( - message_body, transport, None, single_response - ) - await asyncio.wait_for(complete, 1.0) - - assert single_response.result() == dispatch_result - - async def test_outbound_message_handler(self): + with async_mock.patch.object( + conductor.inbound_transport_manager, "return_to_session" + ) as mock_return, async_mock.patch.object( + conductor, "queue_outbound", async_mock.CoroutineMock() + ) as mock_queue: + mock_return.return_value = True + await conductor.outbound_message_router(conductor.context, message) + mock_return.assert_called_once_with(message) + mock_queue.assert_not_awaited() + + async def test_outbound_message_handler_with_target(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) conductor = test_module.Conductor(builder) @@ -215,121 +181,42 @@ async def test_outbound_message_handler(self): ) message = OutboundMessage(payload=payload, target=target) - await conductor.outbound_message_router(message) - - mock_serializer = builder.message_serializer - mock_serializer.encode_message.assert_awaited_once_with( - conductor.context, - payload, - target.recipient_keys, - target.routing_keys, - target.sender_key, - ) + await conductor.outbound_message_router(conductor.context, message) - mock_outbound_mgr.return_value.send_message.assert_awaited_once_with( - message + mock_outbound_mgr.return_value.enqueue_message.assert_called_once_with( + conductor.context, message ) - async def test_outbound_queue_add_with_no_endpoint(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings_with_queue) + async def test_outbound_message_handler_with_connection(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) conductor = test_module.Conductor(builder) - # set up relationship without endpoint + with async_mock.patch.object( - test_module, "DeliveryQueue", autospec=True - ) as mock_delivery_queue: + test_module, "OutboundTransportManager", autospec=True + ) as mock_outbound_mgr, async_mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as conn_mgr: await conductor.setup() - sender_did_doc, sender_pk = self.make_did_doc( - self.test_did, self.test_verkey - ) - target_did_doc, target_pk = self.make_did_doc( - self.test_target_did, self.test_target_verkey - ) - payload = "{}" - target = ConnectionTarget( - recipient_keys=[target_pk], routing_keys=(), sender_key=sender_pk - ) - message = OutboundMessage(payload=payload, target=target) + connection_id = "connection_id" + message = OutboundMessage(payload=payload, connection_id=connection_id) - await conductor.outbound_message_router(message) + await conductor.outbound_message_router(conductor.context, message) - mock_delivery_queue.return_value.add_message.assert_called_once_with( - message + conn_mgr.assert_called_once_with(conductor.context) + conn_mgr.return_value.get_connection_targets.assert_awaited_once_with( + connection_id=connection_id + ) + assert ( + message.target_list + is conn_mgr.return_value.get_connection_targets.return_value ) - async def test_outbound_queue_check_on_inbound(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings_with_queue) - conductor = test_module.Conductor(builder) - - with async_mock.patch.object( - test_module, "DeliveryQueue", autospec=True - ) as mock_delivery_queue: - await conductor.setup() - - async def mock_dispatch(parsed_msg, delivery, connection, outbound): - result = asyncio.Future() - result.set_result(None) - return result - - # set up relationship without endpoint - with async_mock.patch.object( - conductor.dispatcher, "dispatch", mock_dispatch - ) as mock_dispatch_method, async_mock.patch.object( - test_module, "ConnectionManager", autospec=True - ) as mock_connection_manager: - - sender_did_doc, sender_pk = self.make_did_doc( - self.test_did, self.test_verkey - ) - - # we don't need the connection, so avoid looking for one. - mock_connection_manager.find_message_connection.return_value = None - - delivery = MessageDelivery() - delivery.sender_verkey = sender_pk - delivery.direct_response_requested = "all" - parsed_msg = {} - mock_serializer = builder.message_serializer - mock_serializer.extract_message_type.return_value = ( - "message_type" # messaging.trustping.message_types.PING - ) - mock_serializer.parse_message.return_value = (parsed_msg, delivery) - - message_body = "{}" - transport = "http" - delivery_future = asyncio.Future() - r_future = await conductor.inbound_message_router( - message_body, transport, single_response=delivery_future - ) - r_future_result = await r_future # required for test passing. - mock_delivery_queue.return_value.has_message_for_key.assert_called_once_with( - sender_pk.value - ) - - async def test_connection_target(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings) - conductor = test_module.Conductor(builder) - - await conductor.setup() - - test_target = ConnectionTarget( - endpoint="endpoint", recipient_keys=(), routing_keys=(), sender_key="" - ) - test_conn_id = "1" - - with async_mock.patch.object( - ConnectionRecord, "retrieve_by_id", autospec=True - ) as retrieve_by_id, async_mock.patch.object( - ConnectionManager, "get_connection_target", autospec=True - ) as get_target: - - get_target.return_value = test_target - - target = await conductor.get_connection_target(test_conn_id) - - assert target is test_target + mock_outbound_mgr.return_value.enqueue_message.assert_called_once_with( + conductor.context, message + ) async def test_admin(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) @@ -353,8 +240,6 @@ async def test_admin(self): async def test_setup_collector(self): builder: ContextBuilder = StubCollectorContextBuilder(self.test_settings) - builder.update_settings(self.good_inbound_transports) - builder.update_settings(self.good_outbound_transports) conductor = test_module.Conductor(builder) with async_mock.patch.object( @@ -367,6 +252,17 @@ async def test_setup_collector(self): await conductor.setup() + async def test_start_static(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + builder.update_settings({"debug.test_suite_endpoint": True}) + conductor = test_module.Conductor(builder) + + with async_mock.patch.object(test_module, "ConnectionManager") as mock_mgr: + await conductor.setup() + mock_mgr.return_value.create_static_connection = async_mock.CoroutineMock() + await conductor.start() + mock_mgr.return_value.create_static_connection.assert_awaited_once() + async def test_print_invite(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) builder.update_settings( @@ -382,3 +278,40 @@ async def test_print_invite(self): await conductor.stop() assert "http://localhost?c_i=" in captured.getvalue() + + async def test_webhook_router(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + builder.update_settings( + {"debug.print_invitation": True, "invite_base_url": "http://localhost"} + ) + conductor = test_module.Conductor(builder) + + test_topic = "test-topic" + test_payload = {"test": "payload"} + test_endpoint = "http://example" + test_retries = 2 + + await conductor.setup() + with async_mock.patch.object( + conductor.outbound_transport_manager, "enqueue_webhook" + ) as mock_enqueue: + conductor.webhook_router( + test_topic, test_payload, test_endpoint, test_retries + ) + mock_enqueue.assert_called_once_with( + test_topic, test_payload, test_endpoint, test_retries + ) + + # swallow error + with async_mock.patch.object( + conductor.outbound_transport_manager, + "enqueue_webhook", + side_effect=OutboundDeliveryError, + ) as mock_enqueue: + conductor.webhook_router( + test_topic, test_payload, test_endpoint, test_retries + ) + mock_enqueue.assert_called_once_with( + test_topic, test_payload, test_endpoint, test_retries + ) + diff --git a/aries_cloudagent/tests/test_dispatcher.py b/aries_cloudagent/tests/test_dispatcher.py index ad68e9a2dd..1cce81907b 100644 --- a/aries_cloudagent/tests/test_dispatcher.py +++ b/aries_cloudagent/tests/test_dispatcher.py @@ -8,34 +8,38 @@ from ..connections.models.connection_record import ConnectionRecord from ..messaging.agent_message import AgentMessage, AgentMessageSchema from ..messaging.error import MessageParseError -from ..messaging.message_delivery import MessageDelivery -from ..messaging.outbound_message import OutboundMessage -from ..protocols.problem_report.message import ProblemReport from ..messaging.protocol_registry import ProtocolRegistry -from ..messaging.serializer import MessageSerializer +from ..protocols.problem_report.message import ProblemReport +from ..transport.inbound.message import InboundMessage +from ..transport.inbound.receipt import MessageReceipt +from ..transport.outbound.message import OutboundMessage def make_context() -> InjectionContext: context = InjectionContext() context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) - context.injector.bind_instance(MessageSerializer, MessageSerializer()) return context -def make_delivery() -> MessageDelivery: - return MessageDelivery() +def make_inbound(payload) -> InboundMessage: + return InboundMessage(payload, MessageReceipt()) -def make_connection_record() -> ConnectionRecord: - return ConnectionRecord() +# def make_connection_record() -> ConnectionRecord: +# return ConnectionRecord() class Receiver: def __init__(self): self.messages = [] - async def send(self, message: OutboundMessage): - self.messages.append(message) + async def send( + self, + context: InjectionContext, + message: OutboundMessage, + inbound: InboundMessage = None, + ): + self.messages.append((context, message, inbound)) class StubAgentMessage(AgentMessage): @@ -70,10 +74,8 @@ async def test_dispatch(self): with async_mock.patch.object( StubAgentMessageHandler, "handle", autospec=True ) as handler_mock: - await dispatcher.dispatch( - message, make_delivery(), make_connection_record(), rcv.send - ) - await asyncio.sleep(0.1) + await dispatcher.queue_message(make_inbound(message), rcv.send) + await dispatcher.task_queue handler_mock.assert_awaited_once() assert isinstance(handler_mock.call_args[0][1].message, StubAgentMessage) assert isinstance( @@ -84,10 +86,8 @@ async def test_bad_message_dispatch(self): dispatcher = test_module.Dispatcher(make_context()) rcv = Receiver() bad_message = {"bad": "message"} - await dispatcher.dispatch( - bad_message, make_delivery(), make_connection_record(), rcv.send - ) - await asyncio.sleep(0.1) - assert rcv.messages and isinstance(rcv.messages[0], OutboundMessage) - payload = json.loads(rcv.messages[0].payload) + await dispatcher.queue_message(make_inbound(bad_message), rcv.send) + await dispatcher.task_queue + assert rcv.messages and isinstance(rcv.messages[0][1], OutboundMessage) + payload = json.loads(rcv.messages[0][1].payload) assert payload["@type"] == ProblemReport.Meta.message_type diff --git a/aries_cloudagent/tests/test_task_processor.py b/aries_cloudagent/tests/test_task_processor.py deleted file mode 100644 index b4ac96cb62..0000000000 --- a/aries_cloudagent/tests/test_task_processor.py +++ /dev/null @@ -1,60 +0,0 @@ -import asyncio - -from asynctest import TestCase as AsyncTestCase -from asynctest import mock as async_mock - -from ..task_processor import TaskProcessor, PendingTask - - -class RetryTask: - def __init__(self, retries: int, result): - self.attempts = 0 - self.retries = retries - self.result = result - - async def run(self, pending: PendingTask): - self.attempts += 1 - if self.attempts <= self.retries: - raise Exception() - return self.result - - -class TestTaskProcessor(AsyncTestCase): - async def test_coro(self): - collected = [] - - async def test_task(val): - collected.append(val) - return val - - processor = TaskProcessor() - await processor.run_task(test_task(1)) - await processor.run_task(test_task(2)) - future = await processor.run_task(test_task(3)) - result = await asyncio.wait_for(future, timeout=5.0) - assert result == 3 - await asyncio.wait_for(processor.wait_done(), timeout=5.0) - collected.sort() - assert collected == [1, 2, 3] - - async def test_error(self): - async def test_error(): - raise ValueError("test") - - processor = TaskProcessor() - future = await processor.run_task(test_error()) - with self.assertRaises(ValueError): - result = await asyncio.wait_for(future, timeout=5.0) - await asyncio.wait_for(processor.wait_done(), timeout=5.0) - - async def test_retry(self): - test_value = "test_value" - task = RetryTask(1, test_value) - processor = TaskProcessor() - future = await processor.run_retry( - lambda pending: task.run(pending), retries=5, retry_delay=0.01 - ) - result = await asyncio.wait_for(future, timeout=5.0) - assert result == test_value - await asyncio.wait_for(processor.wait_done(), timeout=5.0) - assert task.attempts == 2 diff --git a/aries_cloudagent/transport/error.py b/aries_cloudagent/transport/error.py new file mode 100644 index 0000000000..6615c1bf87 --- /dev/null +++ b/aries_cloudagent/transport/error.py @@ -0,0 +1,23 @@ +"""Transport-related error classes and codes.""" + +from ..error import BaseError + + +class TransportError(BaseError): + """Base class for all transport errors.""" + + +class WireFormatError(TransportError): + """Base class for wire-format errors.""" + + +class MessageParseError(WireFormatError): + """Message parse error.""" + + error_code = "message_parse_error" + + +class MessageEncodeError(WireFormatError): + """Message encoding error.""" + + error_code = "message_encode_error" diff --git a/aries_cloudagent/transport/inbound/base.py b/aries_cloudagent/transport/inbound/base.py index 3bf5951976..b3c87a4196 100644 --- a/aries_cloudagent/transport/inbound/base.py +++ b/aries_cloudagent/transport/inbound/base.py @@ -2,13 +2,73 @@ from abc import ABC, abstractmethod from collections import namedtuple +from typing import Awaitable, Callable -from ...error import BaseError +from ..error import TransportError +from ..wire_format import BaseWireFormat + +from .session import InboundSession class BaseInboundTransport(ABC): """Base inbound transport class.""" + def __init__( + self, + scheme: str, + create_session: Callable, + *, + max_message_size: int = 0, + wire_format: BaseWireFormat = None, + ): + """ + Initialize the inbound transport instance. + + Args: + scheme: The transport scheme identifier + create_session: Method to create a new inbound session + """ + + self._create_session = create_session + self._max_message_size = max_message_size + self._scheme = scheme + self.wire_format: BaseWireFormat = wire_format + + @property + def max_message_size(self): + """Accessor for this transport's max message size.""" + return self._max_message_size + + @property + def scheme(self): + """Accessor for this transport's scheme.""" + return self._scheme + + def create_session( + self, + *, + accept_undelivered: bool = False, + can_respond: bool = False, + client_info: dict = None, + wire_format: BaseWireFormat = None, + ) -> Awaitable[InboundSession]: + """ + Create a new inbound session. + + Args: + accept_undelivered: Flag for accepting undelivered messages + can_respond: Flag indicating that the transport can send responses + client_info: Request-specific client information + wire_format: Optionally override the session wire format + """ + return self._create_session( + accept_undelivered=accept_undelivered, + can_respond=can_respond, + client_info=client_info, + wire_format=wire_format or self.wire_format, + transport_type=self.scheme, + ) + @abstractmethod async def start(self) -> None: """Start listening for on this transport.""" @@ -18,11 +78,15 @@ async def stop(self) -> None: """Stop listening for on this transport.""" -class InboundTransportRegistrationError(BaseError): +class InboundTransportError(TransportError): + """Generic inbound transport error.""" + + +class InboundTransportRegistrationError(InboundTransportError): """Error in loading an inbound transport.""" -class InboundTransportSetupError(BaseError): +class InboundTransportSetupError(InboundTransportError): """Setup error for an inbound transport.""" diff --git a/aries_cloudagent/delivery_queue.py b/aries_cloudagent/transport/inbound/delivery_queue.py similarity index 67% rename from aries_cloudagent/delivery_queue.py rename to aries_cloudagent/transport/inbound/delivery_queue.py index 48cded649d..fff932ff6c 100644 --- a/aries_cloudagent/delivery_queue.py +++ b/aries_cloudagent/transport/inbound/delivery_queue.py @@ -7,7 +7,7 @@ """ import time -from aries_cloudagent.messaging.outbound_message import OutboundMessage +from ..outbound.message import OutboundMessage class QueuedMessage: @@ -26,13 +26,14 @@ def __init__(self, msg: OutboundMessage): self.msg = msg self.timestamp = time.time() - def older_than(self, compare_timestamp): + def older_than(self, compare_timestamp: float) -> bool: """ Age Comparison. Allows you to test age as compared to the provided timestamp. - :param compare_timestamp: - :return: + + Args: + compare_timestamp: The timestamp to compare """ return self.timestamp < compare_timestamp @@ -58,16 +59,16 @@ def expire_messages(self, ttl=None): """ Expire messages that are past the time limit. - :param ttl: Optional. Allows override of configured ttl - :return: None + Args: + ttl: Optional. Allows override of configured ttl """ ttl_seconds = ttl or self.ttl_seconds horizon = time.time() - ttl_seconds for key in self.queue_by_key.keys(): - self.queue_by_key[key] = [wm for - wm in self.queue_by_key[key] - if not wm.older_than(horizon)] + self.queue_by_key[key] = [ + wm for wm in self.queue_by_key[key] if not wm.older_than(horizon) + ] def add_message(self, msg: OutboundMessage): """ @@ -75,11 +76,16 @@ def add_message(self, msg: OutboundMessage): The message is added once per recipient key - Arguments: + Args: msg: The OutboundMessage to add """ + keys = set() + if msg.target: + keys.update(msg.target.recipient_keys) + if msg.reply_to_verkey: + keys.add(msg.reply_to_verkey) wrapped_msg = QueuedMessage(msg) - for recipient_key in msg.target.recipient_keys: + for recipient_key in keys: if recipient_key not in self.queue_by_key: self.queue_by_key[recipient_key] = [] self.queue_by_key[recipient_key].append(wrapped_msg) @@ -88,7 +94,7 @@ def has_message_for_key(self, key: str): """ Check for queued messages by key. - Arguments: + Args: key: The key to use for lookup """ if key in self.queue_by_key and len(self.queue_by_key[key]): @@ -99,7 +105,7 @@ def message_count_for_key(self, key: str): """ Count of queued messages by key. - Arguments: + Args: key: The key to use for lookup """ if key in self.queue_by_key: @@ -111,30 +117,35 @@ def get_one_message_for_key(self, key: str): """ Remove and return a matching message. - Arguments: + Args: key: The key to use for lookup """ - return self.queue_by_key[key].pop(0).msg + if key in self.queue_by_key: + return self.queue_by_key[key].pop(0).msg def inspect_all_messages_for_key(self, key: str): """ Return all messages for key. - Arguments: + Args: key: The key to use for lookup """ - for wrapped_msg in self.queue_by_key[key]: - yield wrapped_msg.msg + if key in self.queue_by_key: + for wrapped_msg in self.queue_by_key[key]: + yield wrapped_msg.msg - def remove_message_for_key(self, key, msg: OutboundMessage): + def remove_message_for_key(self, key: str, msg: OutboundMessage): """ Remove specified message from queue for key. - Arguments: + Args: key: The key to use for lookup msg: The message to remove from the queue """ - for wrapped_msg in self.queue_by_key[key]: - if wrapped_msg.msg == msg: - self.queue_by_key[key].remove(wrapped_msg) - break # exit processing loop + if key in self.queue_by_key: + for wrapped_msg in self.queue_by_key[key]: + if wrapped_msg.msg == msg: + self.queue_by_key[key].remove(wrapped_msg) + if not self.queue_by_key[key]: + del self.queue_by_key[key] + break # exit processing loop diff --git a/aries_cloudagent/transport/inbound/http.py b/aries_cloudagent/transport/inbound/http.py index d268db3f03..823dc70f5c 100644 --- a/aries_cloudagent/transport/inbound/http.py +++ b/aries_cloudagent/transport/inbound/http.py @@ -1,51 +1,40 @@ """Http Transport classes and functions.""" -import asyncio import logging -from typing import Coroutine from aiohttp import web +from ...messaging.error import MessageParseError + from .base import BaseInboundTransport, InboundTransportSetupError +LOGGER = logging.getLogger(__name__) + class HttpTransport(BaseInboundTransport): """Http Transport class.""" - def __init__( - self, - host: str, - port: int, - message_router: Coroutine, - register_socket: Coroutine, - ) -> None: + def __init__(self, host: str, port: int, create_session, **kwargs) -> None: """ - Initialize a Transport instance. + Initialize an inbound HTTP transport instance. Args: host: Host to listen on port: Port to listen on - message_router: Function to pass incoming messages to - register_socket: A coroutine for registering a new socket + create_session: Method to create a new inbound session """ + super().__init__("http", create_session, **kwargs) self.host = host self.port = port - self.message_router = message_router - self.register_socket = register_socket - self.site = None - - self._scheme = "http" - self.logger = logging.getLogger(__name__) - - @property - def scheme(self): - """Accessor for this transport's scheme.""" - return self._scheme + self.site: web.BaseSite = None async def make_application(self) -> web.Application: """Construct the aiohttp application.""" - app = web.Application() + app_args = {} + if self.max_message_size: + app_args["client_max_size"] = self.max_message_size + app = web.Application(**app_args) app.add_routes([web.get("/", self.invite_message_handler)]) app.add_routes([web.post("/", self.inbound_message_handler)]) return app @@ -93,36 +82,38 @@ async def inbound_message_handler(self, request: web.BaseRequest): else: body = await request.read() - try: - response = asyncio.Future() - await self.message_router(body, self._scheme, single_response=response) - except Exception: - self.logger.exception("Error handling message") - return web.Response(status=400) - - try: - await asyncio.wait_for(response, 30) - except asyncio.TimeoutError: - if not response.done(): - response.cancel() - return web.Response(status=504) - except asyncio.CancelledError: - return web.Response(status=200) - - message = response.result() - if message: - if isinstance(message, bytes): - return web.Response( - body=message, - status=200, - headers={"Content-Type": "application/ssi-agent-wire"}, - ) - else: - return web.Response( - text=message, - status=200, - headers={"Content-Type": "application/json"}, - ) + client_info = {"host": request.host, "remote": request.remote} + + session = await self.create_session( + accept_undelivered=True, can_respond=True, client_info=client_info + ) + + async with session: + try: + inbound = await session.receive(body) + except MessageParseError: + raise web.HTTPBadRequest() + + if inbound.receipt.direct_response_requested: + response = await session.wait_response() + + # no more responses + session.can_respond = False + session.clear_response() + + if response: + if isinstance(response, bytes): + return web.Response( + body=response, + status=200, + headers={"Content-Type": "application/ssi-agent-wire"}, + ) + else: + return web.Response( + text=response, + status=200, + headers={"Content-Type": "application/json"}, + ) return web.Response(status=200) async def invite_message_handler(self, request: web.BaseRequest): diff --git a/aries_cloudagent/transport/inbound/manager.py b/aries_cloudagent/transport/inbound/manager.py index c7630207d7..ce32bc8968 100644 --- a/aries_cloudagent/transport/inbound/manager.py +++ b/aries_cloudagent/transport/inbound/manager.py @@ -1,37 +1,79 @@ """Inbound transport manager.""" +import asyncio import logging +import uuid +from collections import OrderedDict +from typing import Callable, Coroutine + +from ...config.injection_context import InjectionContext +from ...classloader import ClassLoader, ModuleLoadError, ClassNotFoundError +from ...messaging.task_queue import CompletedTask, TaskQueue + +from ..outbound.message import OutboundMessage +from ..wire_format import BaseWireFormat from .base import ( BaseInboundTransport, InboundTransportConfiguration, InboundTransportRegistrationError, ) -from ...classloader import ClassLoader, ModuleLoadError, ClassNotFoundError +from .delivery_queue import DeliveryQueue +from .message import InboundMessage +from .session import InboundSession +LOGGER = logging.getLogger(__name__) MODULE_BASE_PATH = "aries_cloudagent.transport.inbound" class InboundTransportManager: """Inbound transport manager class.""" - def __init__(self): + def __init__( + self, + context: InjectionContext, + receive_inbound: Coroutine, + return_inbound: Callable = None, + ): """Initialize an `InboundTransportManager` instance.""" - self.logger = logging.getLogger(__name__) - self.registered_transports = [] + self.context = context + self.max_message_size = 0 + self.receive_inbound = receive_inbound + self.return_inbound = return_inbound + self.registered_transports = {} + self.running_transports = {} + self.sessions = OrderedDict() + self.session_limit: asyncio.Semaphore = None + self.task_queue = TaskQueue() + self.undelivered_queue: DeliveryQueue = None - def register( - self, config: InboundTransportConfiguration, message_handler, register_socket - ): + async def setup(self): + """Perform setup operations.""" + # Load config settings + if self.context.settings.get("transport.max_message_size"): + self.max_message_size = self.context.settings["transport.max_message_size"] + + inbound_transports = ( + self.context.settings.get("transport.inbound_configs") or [] + ) + for transport in inbound_transports: + module, host, port = transport + self.register( + InboundTransportConfiguration(module=module, host=host, port=port) + ) + + # Setup queue for undelivered messages + if self.context.settings.get("transport.enable_undelivered_queue"): + self.undelivered_queue = DeliveryQueue() + + # self.session_limit = asyncio.Semaphore(50) + + def register(self, config: InboundTransportConfiguration) -> str: """ Register transport module. Args: - module_path: Path to module - host: The host to register on - port: The port to register on - message_handler: The message handler for incoming messages - register_socket: A coroutine for registering a new socket + config: The inbound transport configuration """ try: @@ -43,27 +85,163 @@ def register( f"Failed to load inbound transport {config.module}" ) from e - instance = imported_class( - config.host, config.port, message_handler, register_socket + return self.register_transport( + imported_class( + config.host, + config.port, + self.create_session, + max_message_size=self.max_message_size, + ), + imported_class.__qualname__, ) - self.register_instance(instance) - def register_instance(self, transport: BaseInboundTransport): + def register_transport( + self, transport: BaseInboundTransport, transport_id: str + ) -> str: """ - Register a new inbound transport instance. + Register a new inbound transport class. Args: - transport: Inbound transport instance to register + transport: Transport instance to register + transport_id: The transport ID to register + + """ + self.registered_transports[transport_id] = transport + async def start_transport(self, transport_id: str): """ - self.registered_transports.append(transport) + Start a registered inbound transport. + + Args: + transport_id: ID for the inbound transport to start + + """ + transport = self.registered_transports[transport_id] + await transport.start() + self.running_transports[transport_id] = transport + + def get_transport_instance(self, transport_id: str) -> BaseInboundTransport: + """Get an instance of a running transport by ID.""" + return self.running_transports[transport_id] async def start(self): """Start all registered transports.""" - for transport in self.registered_transports: - await transport.start() + for transport_id in self.registered_transports: + self.task_queue.run(self.start_transport(transport_id)) - async def stop(self): + async def stop(self, wait: bool = True): """Stop all registered transports.""" - for transport in self.registered_transports: + await self.task_queue.complete(None if wait else 0) + for transport in self.running_transports.values(): await transport.stop() + + async def create_session( + self, + transport_type: str, + *, + accept_undelivered: bool = False, + can_respond: bool = False, + client_info: dict = None, + wire_format: BaseWireFormat = None, + ): + """ + Create a new inbound session. + + Args: + transport_type: The inbound transport identifier + accept_undelivered: Flag for accepting undelivered messages + can_respond: Flag indicating that the transport can send responses + client_info: An optional dict describing the client + wire_format: Override the wire format for this session + """ + if self.session_limit: + await self.session_limit + if not wire_format: + wire_format = await self.context.inject(BaseWireFormat) + session = InboundSession( + context=self.context, + accept_undelivered=accept_undelivered, + can_respond=can_respond, + client_info=client_info, + close_handler=self.closed_session, + inbound_handler=self.receive_inbound, + session_id=str(uuid.uuid4()), + transport_type=transport_type, + wire_format=wire_format, + ) + self.sessions[session.session_id] = session + return session + + def dispatch_complete(self, message: InboundMessage, completed: CompletedTask): + """Handle completion of message dispatch.""" + session: InboundSession = self.sessions.get(message.session_id) + if session and session.accept_undelivered and not session.response_buffered: + self.process_undelivered(session) + + def closed_session(self, session: InboundSession): + """ + Clean up a closed session. + + Returns an undelivered message to the caller if possible. + """ + if session.session_id in self.sessions: + del self.sessions[session.session_id] + if self.session_limit: + self.session_limit.release() + if session.response_buffer: + if self.return_inbound: + self.return_inbound(session.context, session.response_buffer) + else: + LOGGER.warning("Message failed return delivery, will not be delivered") + + def return_to_session(self, outbound: OutboundMessage) -> bool: + """Return an outbound message via an open session, if possible.""" + accepted = False + + # prefer the same session ID + if outbound.reply_session_id and outbound.reply_session_id in self.sessions: + session = self.sessions[outbound.reply_session_id] + accepted = session.accept_response(outbound) + + if not accepted: + for session in self.sessions.values(): + if session.session_id != outbound.reply_session_id: + accepted = session.accept_response(outbound) + if accepted: + break + + if accepted: + LOGGER.debug("Returned message to socket %s", session.session_id) + return accepted + + def return_undelivered(self, outbound: OutboundMessage) -> bool: + """ + Add an undelivered message to the undelivered queue. + + At this point the message could not be associated with an inbound + session and could not be delivered via an outbound transport. + """ + if self.undelivered_queue: + self.undelivered_queue.add_message(outbound) + return True + return False + + def process_undelivered(self, session: InboundSession): + """ + Interact with undelivered queue to find applicable messages. + + Args: + session: The inbound session + """ + if session and session.can_respond and self.undelivered_queue: + for key in session.reply_verkeys: + for ( + undelivered_message + ) in self.undelivered_queue.inspect_all_messages_for_key(key): + if session.accept_response(undelivered_message): + LOGGER.debug( + "Sending previously undelivered message via inbound session" + ) + self.undelivered_queue.remove_message_for_key( + key, undelivered_message + ) diff --git a/aries_cloudagent/transport/inbound/message.py b/aries_cloudagent/transport/inbound/message.py new file mode 100644 index 0000000000..169b2dc35c --- /dev/null +++ b/aries_cloudagent/transport/inbound/message.py @@ -0,0 +1,25 @@ +"""Classes representing inbound messages.""" + +from typing import Union + +from .receipt import MessageReceipt + + +class InboundMessage: + """Container class linking a message payload with its receipt details.""" + + def __init__( + self, + payload: Union[str, bytes], + receipt: MessageReceipt, + *, + connection_id: str = None, + session_id: str = None, + transport_type: str = None, + ): + """Initialize the inbound message.""" + self.connection_id = connection_id + self.payload = payload + self.receipt = receipt + self.session_id = session_id + self.transport_type = transport_type diff --git a/aries_cloudagent/messaging/message_delivery.py b/aries_cloudagent/transport/inbound/receipt.py similarity index 72% rename from aries_cloudagent/messaging/message_delivery.py rename to aries_cloudagent/transport/inbound/receipt.py index 55cc4e49da..a7d478e6c3 100644 --- a/aries_cloudagent/messaging/message_delivery.py +++ b/aries_cloudagent/transport/inbound/receipt.py @@ -1,19 +1,22 @@ -"""Classes for representing message delivery details.""" +"""Classes for representing message receipt details.""" from datetime import datetime -class MessageDelivery: +class MessageReceipt: """Properties of an agent message's delivery.""" # TODO - add trust context information + REPLY_MODE_ALL = "all" + REPLY_MODE_NONE = "none" + REPLY_MODE_THREAD = "thread" + def __init__( self, *, connection_id: str = None, - direct_response: bool = False, - direct_response_requested: str = None, + direct_response_mode: str = None, in_time: datetime = None, raw_message: str = None, recipient_verkey: str = None, @@ -21,14 +24,11 @@ def __init__( recipient_did_public: str = None, sender_did: str = None, sender_verkey: str = None, - socket_id: str = None, thread_id: str = None, - transport_type: str = None, ): """Initialize the message delivery instance.""" self._connection_id = connection_id - self._direct_response = direct_response - self._direct_response_requested = direct_response_requested + self._direct_response_mode = direct_response_mode self._in_time = in_time self._raw_message = raw_message self._recipient_verkey = recipient_verkey @@ -36,9 +36,7 @@ def __init__( self._recipient_did_public = recipient_did_public self._sender_did = sender_did self._sender_verkey = sender_verkey - self._socket_id = socket_id self._thread_id = thread_id - self._transport_type = transport_type @property def connection_id(self) -> str: @@ -63,48 +61,31 @@ def connection_id(self, connection_id: bool): self._connection_id = connection_id @property - def direct_response(self) -> bool: + def direct_response_mode(self) -> str: """ - Accessor for the flag indicating that direct responses are preferred. + Accessor for the requested direct response mode. Returns: - This context's direct response flag - - """ - return self._direct_response + This context's requested direct response mode - @direct_response.setter - def direct_response(self, direct: bool): """ - Setter for the flag indicating that direct responses are preferred. + return self._direct_response_mode - Args: - direct: This context's new direct response flag - - """ - self._direct_response = direct + @direct_response_mode.setter + def direct_response_mode(self, mode: str) -> str: + """Setter for the requested direct response mode.""" + self._direct_response_mode = mode @property def direct_response_requested(self) -> str: """ - Accessor for the requested direct response mode. + Accessor for the the state of the direct response mode. Returns: This context's requested direct response mode """ - return self._direct_response_requested - - @direct_response_requested.setter - def direct_response_requested(self, direct_mode: str): - """ - Setter for the string indicating the requested direct responses mode. - - Args: - direct_mode: This context's new direct response mode - - """ - self._direct_response_requested = direct_mode + return self._direct_response_mode and self._direct_response_mode != "none" @property def in_time(self) -> str: @@ -263,28 +244,6 @@ def sender_verkey(self, verkey: str): """ self._sender_verkey = verkey - @property - def socket_id(self) -> str: - """ - Accessor for the identifier of the incoming socket connection. - - Returns: - This context's socket identifier - - """ - return self._socket_id - - @socket_id.setter - def socket_id(self, socket: str): - """ - Setter for the incoming socket identifier. - - Args: - socket: This context's socket identifier - - """ - self._socket_id = socket - @property def thread_id(self) -> str: """ @@ -307,28 +266,6 @@ def thread_id(self, thread: str): """ self._thread_id = thread - @property - def transport_type(self) -> str: - """ - Accessor for the transport type used to receive the message. - - Returns: - This context's transport type - - """ - return self._transport_type - - @transport_type.setter - def transport_type(self, transport: str): - """ - Setter for the transport type used to receive the message. - - Args: - transport: This context's new transport - - """ - self._transport_type = transport - def __repr__(self) -> str: """ Provide a human readable representation of this object. diff --git a/aries_cloudagent/transport/inbound/session.py b/aries_cloudagent/transport/inbound/session.py new file mode 100644 index 0000000000..a36194dbd1 --- /dev/null +++ b/aries_cloudagent/transport/inbound/session.py @@ -0,0 +1,276 @@ +"""Inbound connection handling classes.""" + +import asyncio +import logging +from typing import Callable, Sequence, Union + +from ...config.injection_context import InjectionContext + +from ..error import WireFormatError +from ..outbound.message import OutboundMessage +from ..wire_format import BaseWireFormat + +from .message import InboundMessage +from .receipt import MessageReceipt + +LOGGER = logging.getLogger(__name__) + + +class AcceptResult: + """Represent the result of accept_response.""" + + def __init__(self, accepted: bool, retry: bool = False): + """Initialize the `AcceptResult` instance.""" + self.accepted = accepted + self.retry = retry + + def __bool__(self) -> bool: + """Check if the result is true.""" + return self.accepted + + +class InboundSession: + """Track an open transport connection for direct routing of outbound messages.""" + + def __init__( + self, + *, + context: InjectionContext, + inbound_handler: Callable, + session_id: str, + wire_format: BaseWireFormat, + accept_undelivered: bool = False, + can_respond: bool = False, + client_info: dict = None, + close_handler: Callable = None, + reply_mode: str = None, + reply_thread_ids: Sequence[str] = None, + reply_verkeys: Sequence[str] = None, + transport_type: str = None, + ): + """Initialize the inbound session.""" + self.context = context + self.inbound_handler = inbound_handler + self.session_id = session_id + self.wire_format = wire_format + + self.accept_undelivered = accept_undelivered + self.client_info = client_info + self.close_handler = close_handler + self.response_buffer: OutboundMessage = None + self.response_event = asyncio.Event() + self.transport_type = transport_type + + self._can_respond = can_respond + self._closed = False + self._reply_mode = None + self._reply_verkeys = None + self._reply_thread_ids = None + + # call setters + self.reply_thread_ids = reply_thread_ids + self.reply_verkeys = reply_verkeys + self.reply_mode = reply_mode + + @property + def can_respond(self) -> bool: + """Accessor for the session can-respond state.""" + return self._can_respond and not self._closed + + @can_respond.setter + def can_respond(self, can_respond: bool): + """Setter for the session can-respond state.""" + self._can_respond = can_respond + + @property + def closed(self) -> bool: + """Accessor for the session closed state.""" + return self._closed + + def close(self): + """Setter for the session closed state.""" + self._closed = True + self.response_event.set() # end wait_response if blocked + if self.close_handler: + self.close_handler(self) + + @property + def reply_mode(self) -> str: + """Accessor for the session reply mode.""" + return self._reply_mode + + @reply_mode.setter + def reply_mode(self, mode: str): + """Setter for the session reply mode.""" + if mode not in ( + MessageReceipt.REPLY_MODE_ALL, + MessageReceipt.REPLY_MODE_THREAD, + ): + mode = None + self._reply_mode = mode + if not mode: + # reset the tracked thread IDs when the mode is changed to none + self.reply_thread_ids = set() + + @property + def reply_verkeys(self): + """Accessor for the reply verkeys.""" + return self._reply_verkeys.copy() + + @reply_verkeys.setter + def reply_verkeys(self, verkeys: Sequence[str]): + """Setter for the reply verkeys.""" + self._reply_verkeys = set(verkeys) if verkeys else set() + + @property + def reply_thread_ids(self): + """Accessor for the reply thread IDs.""" + return self._reply_thread_ids.copy() + + @reply_thread_ids.setter + def reply_thread_ids(self, thread_ids: Sequence[str]): + """Setter for the reply thread IDs.""" + self._reply_thread_ids = set(thread_ids) if thread_ids else set() + + def add_reply_thread_ids(self, *thids): + """Add a thread ID to the set of potential reply targets.""" + for thid in filter(None, thids): + self._reply_thread_ids.add(thid) + + def add_reply_verkeys(self, *verkeys): + """Add a verkey to the set of potential reply targets.""" + for verkey in filter(None, verkeys): + self._reply_verkeys.add(verkey) + + @property + def response_buffered(self) -> bool: + """Check if a response is currently buffered.""" + return bool(self.response_buffer) + + def process_inbound(self, message: InboundMessage): + """ + Process an incoming message and update the session metadata as necessary. + + Args: + message: The inbound message instance + """ + receipt = message.receipt + mode = self.reply_mode = ( + receipt.direct_response_requested and receipt.direct_response_mode + ) + self.add_reply_verkeys(receipt.sender_verkey) + if mode == MessageReceipt.REPLY_MODE_THREAD: + self.add_reply_thread_ids(receipt.thread_id) + + async def parse_inbound(self, payload_enc: Union[str, bytes]) -> InboundMessage: + """Convert a message payload and to an inbound message.""" + payload, receipt = await self.wire_format.parse_message( + self.context, payload_enc + ) + return InboundMessage( + payload, + receipt, + session_id=self.session_id, + transport_type=self.transport_type, + ) + + async def receive(self, payload_enc: Union[str, bytes]) -> InboundMessage: + """Receive a new message payload and dispatch the message.""" + message = await self.parse_inbound(payload_enc) + self.receive_inbound(message) + return message + + def receive_inbound(self, message: InboundMessage): + """Deliver the inbound message to the conductor.""" + self.process_inbound(message) + self.inbound_handler(message, can_respond=self.can_respond) + + def select_outbound(self, message: OutboundMessage) -> bool: + """Determine if an outbound message should be sent to this session. + + Args: + message: The outbound message to be checked + + """ + if not self.can_respond: + return False + + mode = self.reply_mode + reply_verkey = message.reply_to_verkey + reply_thread_id = message.reply_thread_id + + if reply_verkey and reply_verkey in self.reply_verkeys: + if mode == MessageReceipt.REPLY_MODE_ALL: + return True + elif ( + mode == MessageReceipt.REPLY_MODE_THREAD + and reply_thread_id + and reply_thread_id in self._reply_thread_ids + ): + return True + + return False + + async def encode_outbound(self, outbound: OutboundMessage) -> OutboundMessage: + """Apply wire formatting to an outbound message.""" + if not outbound.payload: + raise WireFormatError("Message has no payload to encode") + if not outbound.reply_to_verkey: + raise WireFormatError("No reply verkey available for encoding message") + + return await self.wire_format.encode_message( + self.context, + outbound.payload, + [outbound.reply_to_verkey], + None, + outbound.reply_from_verkey, + ) + + def accept_response(self, message: OutboundMessage) -> AcceptResult: + """ + Try to queue an outbound message if it applies to this session. + + Returns: a tuple of (message buffered, retry later) + """ + if not self.select_outbound(message): + return AcceptResult(False, False) + if self.response_buffer: + return AcceptResult(False, True) + self.set_response(message) + return AcceptResult(True) + + def set_response(self, message: OutboundMessage): + """Set the contents of the response message buffer.""" + self.response_buffer = message + self.response_event.set() + + def clear_response(self): + """Handle when the buffered response message has been delivered.""" + self.response_buffer = None + self.response_event.set() + + async def wait_response(self) -> Union[str, bytes]: + """Wait for a response to be buffered and pack it.""" + while True: + if self._closed: + return + if self.response_buffer: + response = self.response_buffer.enc_payload + if not response: + try: + response = await self.encode_outbound(self.response_buffer) + except WireFormatError as e: + LOGGER.warning("Error encoding direct response: %s", str(e)) + self.clear_response() + if response: + return response + self.response_event.clear() + await self.response_event.wait() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + """Async context manager entry.""" + self.close() diff --git a/aries_cloudagent/tests/test_delivery_queue.py b/aries_cloudagent/transport/inbound/tests/test_delivery_queue.py similarity index 61% rename from aries_cloudagent/tests/test_delivery_queue.py rename to aries_cloudagent/transport/inbound/tests/test_delivery_queue.py index 8a4ace47b4..6f05bfd865 100644 --- a/aries_cloudagent/tests/test_delivery_queue.py +++ b/aries_cloudagent/transport/inbound/tests/test_delivery_queue.py @@ -1,35 +1,21 @@ import asyncio from unittest import mock, TestCase -from aries_cloudagent.delivery_queue import DeliveryQueue from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock -from aries_cloudagent import messaging -from aries_cloudagent.connections.models.diddoc import DIDDoc, PublicKey, PublicKeyType, Service -from .. import conductor as test_module -from ..admin.base_server import BaseAdminServer -from ..config.base_context import ContextBuilder -from ..config.injection_context import InjectionContext -from ..connections.models.connection_target import ConnectionTarget -from ..messaging.message_delivery import MessageDelivery -from ..messaging.serializer import MessageSerializer -from ..messaging.outbound_message import OutboundMessage -from ..messaging.protocol_registry import ProtocolRegistry -from ..transport.inbound.base import InboundTransportConfiguration -from ..transport.outbound.queue.base import BaseOutboundMessageQueue -from ..transport.outbound.queue.basic import BasicOutboundMessageQueue -from ..wallet.base import BaseWallet -from ..wallet.basic import BasicWallet +from ....connections.models.connection_target import ConnectionTarget +from ....transport.outbound.message import OutboundMessage +from ..delivery_queue import DeliveryQueue -class TestDeliveryQueue(AsyncTestCase): +class TestDeliveryQueue(AsyncTestCase): async def test_message_add_and_check(self): queue = DeliveryQueue() t = ConnectionTarget(recipient_keys=["aaa"]) - msg = OutboundMessage("x", target=t) + msg = OutboundMessage(payload="x", target=t) queue.add_message(msg) assert queue.has_message_for_key("aaa") @@ -37,7 +23,7 @@ async def test_message_add_not_false_check(self): queue = DeliveryQueue() t = ConnectionTarget(recipient_keys=["aaa"]) - msg = OutboundMessage("x", target=t) + msg = OutboundMessage(payload="x", target=t) queue.add_message(msg) assert queue.has_message_for_key("bbb") is False @@ -45,7 +31,7 @@ async def test_message_add_get_by_key(self): queue = DeliveryQueue() t = ConnectionTarget(recipient_keys=["aaa"]) - msg = OutboundMessage("x", target=t) + msg = OutboundMessage(payload="x", target=t) queue.add_message(msg) assert queue.has_message_for_key("aaa") assert queue.get_one_message_for_key("aaa") == msg @@ -55,7 +41,7 @@ async def test_message_add_get_by_list(self): queue = DeliveryQueue() t = ConnectionTarget(recipient_keys=["aaa"]) - msg = OutboundMessage("x", target=t) + msg = OutboundMessage(payload="x", target=t) queue.add_message(msg) assert queue.has_message_for_key("aaa") msg_list = [m for m in queue.inspect_all_messages_for_key("aaa")] @@ -69,7 +55,7 @@ async def test_message_ttl(self): queue = DeliveryQueue() t = ConnectionTarget(recipient_keys=["aaa"]) - msg = OutboundMessage("x", target=t) + msg = OutboundMessage(payload="x", target=t) queue.add_message(msg) assert queue.has_message_for_key("aaa") queue.expire_messages(ttl=-10) diff --git a/aries_cloudagent/transport/inbound/tests/test_http_transport.py b/aries_cloudagent/transport/inbound/tests/test_http_transport.py index c626afff81..5d4a3fc9e3 100644 --- a/aries_cloudagent/transport/inbound/tests/test_http_transport.py +++ b/aries_cloudagent/transport/inbound/tests/test_http_transport.py @@ -3,40 +3,87 @@ from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop, unused_port from aiohttp import web +from asynctest import mock as async_mock + +from ...outbound.message import OutboundMessage +from ...wire_format import JsonWireFormat from ..http import HttpTransport +from ..message import InboundMessage +from ..session import InboundSession class TestHttpTransport(AioHTTPTestCase): def setUp(self): self.message_results = [] self.port = unused_port() - self.transport = None + self.session = None + self.transport = HttpTransport("0.0.0.0", self.port, self.create_session) + self.transport.wire_format = JsonWireFormat() + self.result_event = None + self.response_message = None super(TestHttpTransport, self).setUp() - def get_transport(self): - if not self.transport: - self.transport = HttpTransport( - "0.0.0.0", self.port, self.receive_message, None + def create_session( + self, + transport_type, + *, + client_info, + wire_format, + can_respond: bool = False, + **kwargs + ): + if not self.session: + session = InboundSession( + context=None, + can_respond=can_respond, + inbound_handler=self.receive_message, + session_id=None, + wire_format=wire_format, + client_info=client_info, + transport_type=transport_type, ) - return self.transport + self.session = session + result = asyncio.Future() + result.set_result(self.session) + return result - def get_application(self): - return self.get_transport().make_application() + def receive_message(self, message: InboundMessage, can_respond: bool = False): + self.message_results.append((message.payload, message.receipt, can_respond)) + if self.result_event: + self.result_event.set() + if self.response_message and self.session: + self.session.set_response(self.response_message) - async def receive_message(self, payload, scheme, single_response=None): - self.message_results.append([json.loads(payload), scheme]) - if single_response: - single_response.set_result('{"response": "ok"}') + def get_application(self): + return self.transport.make_application() @unittest_run_loop async def test_send_message(self): await self.transport.start() test_message = {"test": "message"} - resp = await self.client.post("/", json=test_message) + async with self.client.post("/", json=test_message) as resp: + await resp.text() + + assert self.session is not None + assert self.session.transport_type == "http" + assert len(self.message_results) == 1 + assert self.message_results[0][0] == test_message + + await self.transport.stop() + + @unittest_run_loop + async def test_send_receive_message(self): + await self.transport.start() + + test_message = {"~transport": {"return_route": "all"}, "test": "message"} + test_response = {"response": "ok"} + self.response_message = OutboundMessage( + payload=None, enc_payload=json.dumps(test_response) + ) - assert self.message_results == [[test_message, "http"]] - assert await resp.json() == {"response": "ok"} + async with self.client.post("/", json=test_message) as resp: + assert await resp.json() == {"response": "ok"} await self.transport.stop() diff --git a/aries_cloudagent/transport/inbound/tests/test_manager.py b/aries_cloudagent/transport/inbound/tests/test_manager.py index 5cf4666c8e..ec984673c7 100644 --- a/aries_cloudagent/transport/inbound/tests/test_manager.py +++ b/aries_cloudagent/transport/inbound/tests/test_manager.py @@ -2,32 +2,159 @@ from asynctest import TestCase as AsyncTestCase, mock as async_mock +from ....config.injection_context import InjectionContext + +from ...outbound.message import OutboundMessage + from ..base import InboundTransportConfiguration, InboundTransportRegistrationError from ..manager import InboundTransportManager class TestInboundTransportManager(AsyncTestCase): def test_register_path(self): - mgr = InboundTransportManager() + context = InjectionContext() + mgr = InboundTransportManager(context, None) config = InboundTransportConfiguration(module="http", host="0.0.0.0", port=80) - mgr.register(config, None, None) + mgr.register(config) config = InboundTransportConfiguration( module="notransport", host="0.0.0.0", port=80 ) with self.assertRaises(InboundTransportRegistrationError): - mgr.register(config, None, None) + mgr.register(config) + + async def test_setup(self): + context = InjectionContext() + test_module = "http" + test_host = "host" + test_port = 80 + context.update_settings( + { + "transport.inbound_configs": [[test_module, test_host, test_port]], + "transport.enable_undelivered_queue": True, + } + ) + mgr = InboundTransportManager(context, None) + + with async_mock.patch.object(mgr, "register") as mock_register: + await mgr.setup() + mock_register.assert_called_once() + tcfg: InboundTransportConfiguration = mock_register.call_args[0][0] + assert (tcfg.module, tcfg.host, tcfg.port) == ( + test_module, + test_host, + test_port, + ) + + assert mgr.undelivered_queue async def test_start_stop(self): transport = async_mock.MagicMock() transport.start = async_mock.CoroutineMock() transport.stop = async_mock.CoroutineMock() - mgr = InboundTransportManager() - mgr.register_instance(transport) + context = InjectionContext() + mgr = InboundTransportManager(context, None) + mgr.register_transport(transport, "transport_cls") await mgr.start() - transport.start.assert_called_once_with() + await mgr.task_queue + transport.start.assert_awaited_once_with() + assert mgr.get_transport_instance("transport_cls") is transport await mgr.stop() - transport.stop.assert_called_once_with() + transport.stop.assert_awaited_once_with() + + async def test_create_session(self): + context = InjectionContext() + test_inbound_handler = async_mock.CoroutineMock() + mgr = InboundTransportManager(context, test_inbound_handler) + test_transport = "http" + test_accept = True + test_can_respond = True + test_client_info = {"client": "info"} + test_wire_format = async_mock.MagicMock() + session = await mgr.create_session( + test_transport, + accept_undelivered=test_accept, + can_respond=test_can_respond, + client_info=test_client_info, + wire_format=test_wire_format, + ) + + assert session.accept_undelivered == test_accept + assert session.can_respond == test_can_respond + assert session.client_info == test_client_info + assert session.transport_type == test_transport + assert session.wire_format is test_wire_format + assert session.session_id and mgr.sessions[session.session_id] is session + + await session.inbound_handler() + test_inbound_handler.assert_awaited_once_with() + + session.close_handler(session) + assert session.session_id not in mgr.sessions + + async def test_return_to_session(self): + context = InjectionContext() + mgr = InboundTransportManager(context, None) + test_wire_format = async_mock.MagicMock() + + session = await mgr.create_session("http", wire_format=test_wire_format) + + test_outbound = OutboundMessage(payload=None) + test_outbound.reply_session_id = session.session_id + + with async_mock.patch.object( + session, "accept_response", return_value=True + ) as mock_accept: + assert mgr.return_to_session(test_outbound) is True + mock_accept.assert_called_once_with(test_outbound) + + test_outbound = OutboundMessage(payload=None) + test_outbound.reply_session_id = None + + with async_mock.patch.object( + session, "accept_response", return_value=False + ) as mock_accept: + assert mgr.return_to_session(test_outbound) is False + mock_accept.assert_called_once_with(test_outbound) + + async def test_close_return(self): + context = InjectionContext() + test_return = async_mock.MagicMock() + mgr = InboundTransportManager(context, None, return_inbound=test_return) + test_wire_format = async_mock.MagicMock() + + session = await mgr.create_session("http", wire_format=test_wire_format) + + test_outbound = OutboundMessage(payload=None) + session.set_response(test_outbound) + + session.close() + test_return.assert_called_once_with(session.context, test_outbound) + + async def test_process_undelivered(self): + context = InjectionContext() + context.update_settings({"transport.enable_undelivered_queue": True}) + test_verkey = "test-verkey" + test_wire_format = async_mock.MagicMock() + mgr = InboundTransportManager(context, None) + await mgr.setup() + + test_outbound = OutboundMessage(payload=None) + test_outbound.reply_to_verkey = test_verkey + assert mgr.return_undelivered(test_outbound) + assert mgr.undelivered_queue.has_message_for_key(test_verkey) + + session = await mgr.create_session( + "http", can_respond=True, wire_format=test_wire_format + ) + session.add_reply_verkeys(test_verkey) + + with async_mock.patch.object( + session, "accept_response", return_value=True + ) as mock_accept: + mgr.process_undelivered(session) + mock_accept.assert_called_once_with(test_outbound) + assert not mgr.undelivered_queue.has_message_for_key(test_verkey) diff --git a/aries_cloudagent/transport/inbound/tests/test_session.py b/aries_cloudagent/transport/inbound/tests/test_session.py new file mode 100644 index 0000000000..456e98793f --- /dev/null +++ b/aries_cloudagent/transport/inbound/tests/test_session.py @@ -0,0 +1,265 @@ +import asyncio + +from asynctest import TestCase, mock as async_mock + +from ....config.injection_context import InjectionContext + +from ...error import WireFormatError +from ...outbound.message import OutboundMessage + +from ..message import InboundMessage +from ..receipt import MessageReceipt +from ..session import InboundSession + + +class TestInboundSession(TestCase): + def test_init(self): + test_ctx = InjectionContext() + test_inbound = async_mock.MagicMock() + test_session_id = "session-id" + test_wire_format = async_mock.MagicMock() + test_client_info = {"client": "info"} + test_close = async_mock.MagicMock() + test_reply_mode = MessageReceipt.REPLY_MODE_ALL + test_reply_thread_ids = {"1", "2"} + test_reply_verkeys = {"3", "4"} + test_transport_type = "transport-type" + sess = InboundSession( + context=test_ctx, + inbound_handler=test_inbound, + session_id=test_session_id, + wire_format=test_wire_format, + client_info=test_client_info, + close_handler=test_close, + reply_mode=test_reply_mode, + reply_thread_ids=test_reply_thread_ids, + reply_verkeys=test_reply_verkeys, + transport_type=test_transport_type, + ) + + assert sess.context is test_ctx + assert sess.session_id == test_session_id + assert sess.wire_format is test_wire_format + assert sess.client_info == test_client_info + assert sess.reply_mode == test_reply_mode + assert sess.transport_type == test_transport_type + assert "1" in sess.reply_thread_ids + assert "3" in sess.reply_verkeys + + test_msg = async_mock.MagicMock() + with async_mock.patch.object(sess, "process_inbound") as process: + sess.receive_inbound(test_msg) + process.assert_called_once_with(test_msg) + test_inbound.assert_called_once_with(test_msg, can_respond=False) + + sess.close() + test_close.assert_called_once() + assert sess.closed + + def test_setters(self): + test_ctx = InjectionContext() + sess = InboundSession( + context=test_ctx, inbound_handler=None, session_id=None, wire_format=None + ) + + sess.reply_mode = MessageReceipt.REPLY_MODE_ALL + assert sess.reply_mode == MessageReceipt.REPLY_MODE_ALL + + sess.add_reply_thread_ids("1") + assert "1" in sess.reply_thread_ids + sess.add_reply_verkeys("2") + assert "2" in sess.reply_verkeys + + sess.reply_mode = "invalid" + assert not sess.reply_mode + assert not sess.reply_thread_ids # reset by setter method + + async def test_parse_inbound(self): + test_ctx = InjectionContext() + test_session_id = "session-id" + test_transport_type = "transport-type" + test_wire_format = async_mock.MagicMock() + test_wire_format.parse_message = async_mock.CoroutineMock() + test_parsed = "parsed-payload" + test_receipt = async_mock.MagicMock() + test_wire_format.parse_message.return_value = (test_parsed, test_receipt) + sess = InboundSession( + context=test_ctx, + inbound_handler=None, + session_id=test_session_id, + wire_format=test_wire_format, + transport_type=test_transport_type, + ) + + test_payload = "{}" + result = await sess.parse_inbound(test_payload) + test_wire_format.parse_message.assert_awaited_once_with(test_ctx, test_payload) + assert result.payload == test_parsed + assert result.receipt is test_receipt + assert result.session_id == test_session_id + assert result.transport_type == test_transport_type + + async def test_receive(self): + test_ctx = InjectionContext() + sess = InboundSession( + context=test_ctx, inbound_handler=None, session_id=None, wire_format=None, + ) + test_msg = async_mock.MagicMock() + + with async_mock.patch.object( + sess, "parse_inbound", async_mock.CoroutineMock() + ) as encode, async_mock.patch.object( + sess, "receive_inbound", async_mock.MagicMock() + ) as receive: + result = await sess.receive(test_msg) + encode.assert_awaited_once_with(test_msg) + receive.assert_called_once_with(encode.return_value) + assert result is encode.return_value + + def test_process_inbound(self): + test_ctx = InjectionContext() + test_session_id = "session-id" + test_thread_id = "thread-id" + test_verkey = "verkey" + sess = InboundSession( + context=test_ctx, + inbound_handler=None, + session_id=test_session_id, + wire_format=None, + ) + + receipt = MessageReceipt( + direct_response_mode=MessageReceipt.REPLY_MODE_THREAD, + thread_id=test_thread_id, + sender_verkey=test_verkey, + ) + message = InboundMessage(payload=None, receipt=receipt) + sess.process_inbound(message) + assert sess.reply_mode == receipt.direct_response_mode + assert test_verkey in sess.reply_verkeys + assert test_thread_id in sess.reply_thread_ids + + def test_select_outbound(self): + test_ctx = InjectionContext() + test_session_id = "session-id" + test_thread_id = "thread-id" + test_verkey = "verkey" + sess = InboundSession( + context=test_ctx, + inbound_handler=None, + session_id=test_session_id, + wire_format=None, + ) + + sess.reply_mode = MessageReceipt.REPLY_MODE_ALL + test_msg = OutboundMessage(payload=None) + assert not sess.select_outbound(test_msg) # no key + test_msg.reply_session_id = test_session_id + assert not sess.select_outbound(test_msg) # no difference + sess.can_respond = True + assert not sess.select_outbound(test_msg) # no difference + test_msg.reply_to_verkey = test_verkey + sess.add_reply_verkeys(test_verkey) + assert sess.select_outbound(test_msg) + + sess.reply_mode = MessageReceipt.REPLY_MODE_THREAD + sess.reply_verkeys = None + sess.reply_thread_ids = None + test_msg = OutboundMessage(payload=None) + assert not sess.select_outbound(test_msg) + sess.add_reply_thread_ids(test_thread_id) + test_msg.reply_thread_id = test_thread_id + assert not sess.select_outbound(test_msg) + sess.add_reply_verkeys(test_verkey) + test_msg.reply_to_verkey = test_verkey + assert sess.select_outbound(test_msg) + + sess.close() + assert not sess.select_outbound(test_msg) + + async def test_wait_response(self): + test_ctx = InjectionContext() + sess = InboundSession( + context=test_ctx, inbound_handler=None, session_id=None, wire_format=None, + ) + test_msg = OutboundMessage(payload=None) + sess.set_response(test_msg) + assert sess.response_event.is_set() + assert sess.response_buffered + + with async_mock.patch.object( + sess, "encode_outbound", async_mock.CoroutineMock() + ) as encode: + result = await asyncio.wait_for(sess.wait_response(), 0.1) + assert encode.awaited_once_with(test_msg) + assert result is encode.return_value + + sess.clear_response() + assert not sess.response_buffer + + sess.close() + assert await asyncio.wait_for(sess.wait_response(), 0.1) is None + + async def test_encode_response(self): + test_ctx = InjectionContext() + test_wire_format = async_mock.MagicMock() + test_wire_format.encode_message = async_mock.CoroutineMock() + sess = InboundSession( + context=test_ctx, + inbound_handler=None, + session_id=None, + wire_format=test_wire_format, + ) + test_msg = OutboundMessage(payload=None) + test_from_verkey = "from-verkey" + test_to_verkey = "to-verkey" + + with self.assertRaises(WireFormatError): + await sess.encode_outbound(test_msg) + test_msg.payload = "{}" + with self.assertRaises(WireFormatError): + await sess.encode_outbound(test_msg) + test_msg.reply_from_verkey = test_from_verkey + test_msg.reply_to_verkey = test_to_verkey + result = await sess.encode_outbound(test_msg) + assert result is test_wire_format.encode_message.return_value + + test_wire_format.encode_message.assert_awaited_once_with( + test_ctx, + test_msg.payload, + [test_to_verkey], + None, + test_from_verkey, + ) + + async def test_accept_response(self): + test_ctx = InjectionContext() + sess = InboundSession( + context=test_ctx, inbound_handler=None, session_id=None, wire_format=None, + ) + test_msg = OutboundMessage(payload=None) + + with async_mock.patch.object(sess, "select_outbound") as selector: + selector.return_value = False + + accepted = sess.accept_response(test_msg) + assert not accepted and not accepted.retry + + sess.set_response(OutboundMessage(payload=None)) + selector.return_value = True + accepted = sess.accept_response(test_msg) + assert not accepted and accepted.retry + + sess.clear_response() + accepted = sess.accept_response(test_msg) + assert accepted + + async def test_context_mgr(self): + test_ctx = InjectionContext() + sess = InboundSession( + context=test_ctx, inbound_handler=None, session_id=None, wire_format=None, + ) + assert not sess.closed + async with sess: + pass + assert sess.closed diff --git a/aries_cloudagent/transport/inbound/tests/test_ws_transport.py b/aries_cloudagent/transport/inbound/tests/test_ws_transport.py index 8d4fef5b23..50dd5d0840 100644 --- a/aries_cloudagent/transport/inbound/tests/test_ws_transport.py +++ b/aries_cloudagent/transport/inbound/tests/test_ws_transport.py @@ -3,8 +3,13 @@ from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop, unused_port from aiohttp import web +from asynctest import mock as async_mock -from ....messaging.socket import SocketRef +from ...outbound.message import OutboundMessage +from ...wire_format import JsonWireFormat + +from ..message import InboundMessage +from ..session import InboundSession from ..ws import WsTransport @@ -12,46 +17,68 @@ class TestWsTransport(AioHTTPTestCase): def setUp(self): self.message_results = [] self.port = unused_port() - self.response_handler = None - self.socket_id = 99 - self.transport = WsTransport( - "0.0.0.0", self.port, self.receive_message, self.register_socket - ) + self.session = None + self.transport = WsTransport("0.0.0.0", self.port, self.create_session) + self.transport.wire_format = JsonWireFormat() + self.result_event = None super(TestWsTransport, self).setUp() - async def register_socket(self, handler): - assert not self.response_handler - self.response_handler = handler - return SocketRef(self.socket_id, self.close_socket) - - async def close_socket(self): - assert self.response_handler - self.response_handler = None + def create_session( + self, + transport_type, + *, + client_info, + wire_format, + can_respond: bool = False, + **kwargs + ): + if not self.session: + session = InboundSession( + context=None, + can_respond=can_respond, + inbound_handler=self.receive_message, + session_id=None, + wire_format=wire_format, + client_info=client_info, + transport_type=transport_type, + ) + self.session = session + result = asyncio.Future() + result.set_result(self.session) + return result def get_application(self): return self.transport.make_application() - async def receive_message(self, payload, scheme, socket_id): - assert socket_id == self.socket_id - self.message_results.append([json.loads(payload), scheme]) - if self.response_handler: - await self.response_handler('{"response": "ok"}') - else: - self.fail("no response handler") + def receive_message(self, message: InboundMessage, can_respond: bool = False): + self.message_results.append((message.payload, message.receipt, can_respond)) + if self.result_event: + self.result_event.set() @unittest_run_loop - async def test_send_message(self): + async def test_message_and_response(self): await self.transport.start() test_message = {"test": "message"} + test_response = {"response": "ok"} + async with self.client.ws_connect("/") as ws: + self.result_event = asyncio.Event() await ws.send_json(test_message) - assert self.response_handler - result = await asyncio.wait_for(ws.receive_json(), 5.0) - assert json.loads(result) == {"response": "ok"} + await asyncio.wait((self.result_event.wait(),), timeout=0.1) - assert self.message_results == [[test_message, "ws"]] + assert self.session is not None + assert len(self.message_results) == 1 + received, receipt, can_respond = self.message_results[0] + assert received == test_message + assert can_respond - await self.transport.stop() + response = OutboundMessage( + payload=None, enc_payload=json.dumps(test_response) + ) + self.session.set_response(response) + + result = await asyncio.wait_for(ws.receive_json(), 1.0) + assert result == {"response": "ok"} - assert not self.response_handler + await self.transport.stop() diff --git a/aries_cloudagent/transport/inbound/ws.py b/aries_cloudagent/transport/inbound/ws.py index bc4aad09d7..13de94f6c0 100644 --- a/aries_cloudagent/transport/inbound/ws.py +++ b/aries_cloudagent/transport/inbound/ws.py @@ -1,44 +1,36 @@ """Websockets Transport classes and functions.""" +import asyncio import logging -from typing import Coroutine -from aiohttp import web, WSMsgType +from aiohttp import web, WSMessage, WSMsgType -from ...messaging.socket import SocketRef +from ...messaging.error import MessageParseError from .base import BaseInboundTransport, InboundTransportSetupError +LOGGER = logging.getLogger(__name__) + class WsTransport(BaseInboundTransport): """Websockets Transport class.""" - def __init__( - self, - host: str, - port: int, - message_router: Coroutine, - register_socket: Coroutine, - ) -> None: + def __init__(self, host: str, port: int, create_session, **kwargs) -> None: """ - Initialize a Transport instance. + Initialize an inbound WebSocket transport instance. Args: host: Host to listen on port: Port to listen on - message_router: Function to pass incoming messages to - register_socket: A coroutine for registering a new socket + create_session: Method to create a new inbound session """ + super().__init__("ws", create_session, **kwargs) self.host = host self.port = port - self.message_router = message_router - self.register_socket = register_socket - self.site = None + self.site: web.BaseSite = None # TODO: set scheme dynamically based on SSL settings (ws/wss) - self._scheme = "ws" - self.logger = logging.getLogger(__name__) @property def scheme(self): @@ -88,37 +80,59 @@ async def inbound_message_handler(self, request): The web response """ + ws = web.WebSocketResponse() await ws.prepare(request) + loop = asyncio.get_event_loop() - async def reply(result): - if isinstance(result, str): - await ws.send_json(result) - else: - await ws.send_bytes(result) - - socket: SocketRef = await self.register_socket(handler=reply) - - # Listen for incoming messages - async for msg in ws: - self.logger.info(f"Received message: {msg.data}") - if msg.type in (WSMsgType.TEXT, WSMsgType.BINARY): - try: - # Route message and provide connection instance as means to respond - await self.message_router(msg.data, self._scheme, socket.socket_id) - except Exception: - self.logger.exception("Error handling message") - - elif msg.type == WSMsgType.ERROR: - self.logger.error( - f"Websocket connection closed with exception {ws.exception()}" - ) + client_info = {"host": request.host, "remote": request.remote} - else: - self.logger.warning(f"Unexpected websocket message type {msg.type}") + session = await self.create_session( + accept_undelivered=True, can_respond=True, client_info=client_info + ) - self.logger.info("Websocket connection closed") + async with session: + inbound = loop.create_task(ws.receive()) + outbound = loop.create_task(session.wait_response()) + + while not ws.closed: + await asyncio.wait( + (inbound, outbound), return_when=asyncio.FIRST_COMPLETED + ) - await socket.close() + if inbound.done(): + msg: WSMessage = inbound.result() + LOGGER.info("Websocket received message: %s", msg.data) + if msg.type in (WSMsgType.TEXT, WSMsgType.BINARY): + try: + await session.receive(msg.data) + except MessageParseError: + await ws.close(1003) # unsupported data error + elif msg.type == WSMsgType.ERROR: + LOGGER.error( + "Websocket connection closed with exception: %s", + ws.exception(), + ) + if not ws.closed: + inbound = loop.create_task(ws.receive()) + + if outbound.done() and not ws.closed: + # response would be None if session was closed + response = outbound.result() + if isinstance(response, bytes): + await ws.send_bytes(response) + else: + await ws.send_str(response) + session.clear_response() + outbound = loop.create_task(session.wait_response()) + + if inbound and not inbound.done(): + inbound.cancel() + if outbound and not outbound.done(): + outbound.cancel() + + if not ws.closed: + await ws.close() + LOGGER.info("Websocket connection closed") return ws diff --git a/aries_cloudagent/transport/outbound/base.py b/aries_cloudagent/transport/outbound/base.py index c53d7d098f..30258d0cdb 100644 --- a/aries_cloudagent/transport/outbound/base.py +++ b/aries_cloudagent/transport/outbound/base.py @@ -1,19 +1,22 @@ """Base outbound transport.""" -from abc import ABC, abstractmethod import asyncio +from abc import ABC, abstractmethod +from typing import Union -from ...error import BaseError -from ...messaging.outbound_message import OutboundMessage from ...stats import Collector +from ..error import TransportError +from ..wire_format import BaseWireFormat + class BaseOutboundTransport(ABC): """Base outbound transport class.""" - def __init__(self) -> None: + def __init__(self, wire_format: BaseWireFormat = None) -> None: """Initialize a `BaseOutboundTransport` instance.""" self._collector = None + self._wire_format = wire_format @property def collector(self) -> Collector: @@ -43,15 +46,34 @@ async def start(self): async def stop(self): """Shut down the transport.""" + @property + def wire_format(self) -> BaseWireFormat: + """Accessor for a custom wire format for the transport.""" + return self._wire_format + + @wire_format.setter + def wire_format(self, format: BaseWireFormat): + """Setter for a custom wire format for the transport.""" + self._wire_format = format + @abstractmethod - async def handle_message(self, message: OutboundMessage): + async def handle_message(self, payload: Union[str, bytes], endpoint: str): """ Handle message from queue. Args: - message: `OutboundMessage` to send over transport implementation + payload: message payload in string or byte format + endpoint: URI endpoint for delivery """ -class OutboundTransportRegistrationError(BaseError): +class OutboundTransportError(TransportError): + """Generic outbound transport error.""" + + +class OutboundTransportRegistrationError(OutboundTransportError): """Outbound transport registration error.""" + + +class OutboundDeliveryError(OutboundTransportError): + """Base exception when a message cannot be delivered via an outbound transport.""" diff --git a/aries_cloudagent/transport/outbound/http.py b/aries_cloudagent/transport/outbound/http.py index 9f435b6dfd..5078116ba4 100644 --- a/aries_cloudagent/transport/outbound/http.py +++ b/aries_cloudagent/transport/outbound/http.py @@ -1,13 +1,13 @@ """Http outbound transport.""" import logging +from typing import Union -from aiohttp import ClientSession, DummyCookieJar +from aiohttp import ClientSession, DummyCookieJar, TCPConnector -from ...messaging.outbound_message import OutboundMessage from ..stats import StatsTracer -from .base import BaseOutboundTransport +from .base import BaseOutboundTransport, OutboundTransportError class HttpTransport(BaseOutboundTransport): @@ -18,16 +18,20 @@ class HttpTransport(BaseOutboundTransport): def __init__(self) -> None: """Initialize an `HttpTransport` instance.""" super(HttpTransport, self).__init__() + self.client_session: ClientSession = None + self.connector: TCPConnector = None self.logger = logging.getLogger(__name__) async def start(self): """Start the transport.""" session_args = {} + self.connector = TCPConnector(limit=200, limit_per_host=50) if self.collector: session_args["trace_configs"] = [ StatsTracer(self.collector, "outbound-http:") ] session_args["cookie_jar"] = DummyCookieJar() + session_args["connector"] = self.connector self.client_session = ClientSession(**session_args) return self @@ -36,20 +40,22 @@ async def stop(self): await self.client_session.close() self.client_session = None - async def handle_message(self, message: OutboundMessage): + async def handle_message(self, payload: Union[str, bytes], endpoint: str): """ Handle message from queue. Args: message: `OutboundMessage` to send over transport implementation """ + if not endpoint: + raise OutboundTransportError("No endpoint provided") headers = {} - if isinstance(message.payload, bytes): + if isinstance(payload, bytes): headers["Content-Type"] = "application/ssi-agent-wire" else: headers["Content-Type"] = "application/json" async with self.client_session.post( - message.endpoint, data=message.payload, headers=headers + endpoint, data=payload, headers=headers ) as response: if response.status < 200 or response.status > 299: - raise Exception("Unexpected response status") + raise OutboundTransportError("Unexpected response status") diff --git a/aries_cloudagent/transport/outbound/manager.py b/aries_cloudagent/transport/outbound/manager.py index a0159c8f19..cb8aed6d72 100644 --- a/aries_cloudagent/transport/outbound/manager.py +++ b/aries_cloudagent/transport/outbound/manager.py @@ -1,52 +1,98 @@ """Outbound transport manager.""" import asyncio +import json import logging -import uuid +import time -from typing import Type +from typing import Callable, Type, Union from urllib.parse import urlparse -from ...error import BaseError from ...classloader import ClassLoader, ModuleLoadError, ClassNotFoundError -from ...messaging.outbound_message import OutboundMessage +from ...connections.models.connection_target import ConnectionTarget +from ...config.injection_context import InjectionContext +from ...messaging.task_queue import CompletedTask, TaskQueue, task_exc_info from ...stats import Collector -from ...task_processor import TaskProcessor -from .base import BaseOutboundTransport, OutboundTransportRegistrationError -from .queue.base import BaseOutboundMessageQueue +from ..wire_format import BaseWireFormat +from .base import ( + BaseOutboundTransport, + OutboundDeliveryError, + OutboundTransportRegistrationError, +) +from .message import OutboundMessage +LOGGER = logging.getLogger(__name__) MODULE_BASE_PATH = "aries_cloudagent.transport.outbound" -class OutboundDeliveryError(BaseError): - """Base exception when a message cannot be delivered via an outbound transport.""" +class QueuedOutboundMessage: + """Class representing an outbound message pending delivery.""" + + STATE_NEW = "new" + STATE_PENDING = "pending" + STATE_ENCODE = "encode" + STATE_DELIVER = "deliver" + STATE_RETRY = "retry" + STATE_DONE = "done" + + def __init__( + self, + context: InjectionContext, + message: OutboundMessage, + target: ConnectionTarget, + transport_id: str, + ): + """Initialize the queued outbound message.""" + self.context = context + self.endpoint = target and target.endpoint + self.error: Exception = None + self.message = message + self.payload: Union[str, bytes] = None + self.retries = None + self.retry_at: float = None + self.state = self.STATE_NEW + self.target = target + self.task: asyncio.Task = None + self.transport_id: str = transport_id class OutboundTransportManager: """Outbound transport manager class.""" def __init__( - self, queue: BaseOutboundMessageQueue = None, collector: Collector = None + self, context: InjectionContext, handle_not_delivered: Callable = None ): """ Initialize a `OutboundTransportManager` instance. Args: - queue: `BaseOutboundMessageQueue` instance to use + context: The application context + handle_not_delivered: An optional handler for undelivered messages """ - self.logger = logging.getLogger(__name__) - self.polling_task = None - self.processor: TaskProcessor = None - self.queue: BaseOutboundMessageQueue = queue + self.context = context + self.loop = asyncio.get_event_loop() + self.handle_not_delivered = handle_not_delivered + self.outbound_buffer = [] + self.outbound_event = asyncio.Event() + self.outbound_new = [] + self.registered_schemes = {} self.registered_transports = {} self.running_transports = {} - self.startup_tasks = [] - self.collector = collector - - def register(self, module: str): + self.task_queue = TaskQueue(max_active=200) + self._process_task: asyncio.Task = None + + async def setup(self): + """Perform setup operations.""" + outbound_transports = ( + self.context.settings.get("transport.outbound_configs") or [] + ) + for outbound_transport in outbound_transports: + self.register(outbound_transport) + + def register(self, module: str) -> str: """ Register a new outbound transport by module path. @@ -71,9 +117,11 @@ def register(self, module: str): f"Outbound transport module {module} could not be resolved." ) - self.register_class(imported_class) + return self.register_class(imported_class) - def register_class(self, transport_class: Type[BaseOutboundTransport]): + def register_class( + self, transport_class: Type[BaseOutboundTransport], transport_id: str = None + ) -> str: """ Register a new outbound transport class. @@ -94,136 +142,289 @@ def register_class(self, transport_class: Type[BaseOutboundTransport]): f"Imported class {transport_class} does not " + "specify a required 'schemes' attribute" ) + if not transport_id: + transport_id = transport_class.__qualname__ for scheme in schemes: - # A scheme can only be registered once - for scheme_tuple in self.registered_transports.keys(): - if scheme in scheme_tuple: - raise OutboundTransportRegistrationError( - f"Cannot register transport '{transport_class.__qualname__}'" - + f"for '{scheme}' scheme because the scheme" - + "has already been registered" - ) - - self.registered_transports[tuple(schemes)] = transport_class - - async def start_transport(self, schemes, transport_cls): - """Start the transport.""" - transport = transport_cls() - transport.collector = self.collector + if scheme in self.registered_schemes: + # A scheme can only be registered once + raise OutboundTransportRegistrationError( + f"Cannot register transport '{transport_id}'" + f"for '{scheme}' scheme because the scheme" + "has already been registered" + ) + + self.registered_transports[transport_id] = transport_class + + for scheme in schemes: + self.registered_schemes[scheme] = transport_id + + return transport_id + + async def start_transport(self, transport_id: str): + """Start a registered transport.""" + transport = self.registered_transports[transport_id]() + transport.collector = await self.context.inject(Collector, required=False) await transport.start() - self.running_transports[schemes] = transport + self.running_transports[transport_id] = transport async def start(self): """Start all transports and feed messages from the queue.""" - startup = [] - loop = asyncio.get_event_loop() - for schemes, transport_class in self.registered_transports.items(): - # Don't block the loop - startup.append( - loop.create_task(self.start_transport(schemes, transport_class)) - ) - self.startup_tasks = startup - self.polling_task = loop.create_task(self.poll()) + for transport_id in self.registered_transports: + self.task_queue.run(self.start_transport(transport_id)) async def stop(self, wait: bool = True): - """Stop all transports.""" - self.queue.stop() - if wait: - await self.queue.join() - if self.polling_task: - if wait: - await self.polling_task - elif not self.polling_task.done: - self.polling_task.cancel() - self.polling_task = None + """Stop all running transports.""" + if self._process_task and not self._process_task.done(): + self._process_task.cancel() + await self.task_queue.complete(None if wait else 0) for transport in self.running_transports.values(): await transport.stop() - if self.startup_tasks: - for task in self.startup_tasks: - if wait: - await task - elif not task.done(): - task.cancel() - self.startup_tasks = [] self.running_transports = {} - async def poll(self): - """Send messages from the queue to the transports.""" - self.processor = TaskProcessor(max_pending=10) - async for message in self.queue: - unique = str(uuid.uuid4) - self.logger.debug(f"Processing message from queue. id: {unique}") - await self.processor.run_retry( - lambda pending, msg=message: self.dispatch_message( - msg, pending.attempts + 1 - ), - retries=5, - retry_delay=10.0, - ) - self.logger.debug(f"Done processing message from queue id: {unique}") - self.queue.task_done() - - await self.processor.wait_done() - - def get_registered_transport_for_scheme(self, scheme: str): - """Find the registered transport for a given scheme.""" + def get_registered_transport_for_scheme(self, scheme: str) -> str: + """Find the registered transport ID for a given scheme.""" try: return next( - transport - for schemes, transport in self.registered_transports.items() - if scheme in schemes + transport_id + for transport_id, transport in self.registered_transports.items() + if scheme in transport.schemes ) except StopIteration: pass - def get_running_transport_for_scheme(self, scheme: str): - """Find the running transport for a given scheme.""" + def get_running_transport_for_scheme(self, scheme: str) -> str: + """Find the running transport ID for a given scheme.""" try: return next( - transport - for schemes, transport in self.running_transports.items() - if scheme in schemes + transport_id + for transport_id, transport in self.running_transports.items() + if scheme in transport.schemes ) except StopIteration: pass - async def send_message(self, message: OutboundMessage): + def get_running_transport_for_endpoint(self, endpoint: str): + """Find the running transport ID to use for a given endpoint.""" + # Grab the scheme from the uri + scheme = urlparse(endpoint).scheme + if scheme == "": + raise OutboundDeliveryError( + f"The uri '{endpoint}' does not specify a scheme" + ) + + # Look up transport that is registered to handle this scheme + transport_id = self.get_running_transport_for_scheme(scheme) + if not transport_id: + raise OutboundDeliveryError( + f"No transport driver exists to handle scheme '{scheme}'" + ) + return transport_id + + def get_transport_instance(self, transport_id: str) -> BaseOutboundTransport: + """Get an instance of a running transport by ID.""" + return self.running_transports[transport_id] + + def enqueue_message(self, context: InjectionContext, outbound: OutboundMessage): """ - Add a message to the outbound queue. + Add an outbound message to the queue. Args: - message: The outbound message to send - + context: The context of the request + outbound: The outbound message to deliver """ - if self.queue: - await self.queue.enqueue(message) - else: - await self.dispatch_message(message) + targets = [outbound.target] if outbound.target else (outbound.target_list or []) + transport_id = None + for target in targets: + endpoint = target.endpoint + try: + transport_id = self.get_running_transport_for_endpoint(endpoint) + except OutboundDeliveryError: + pass + if transport_id: + break + if not transport_id: + raise OutboundDeliveryError("No supported transport for outbound message") + + queued = QueuedOutboundMessage(context, outbound, target, transport_id) + queued.retries = 5 + self.outbound_new.append(queued) + self.process_queued() + + def enqueue_webhook( + self, topic: str, payload: dict, endpoint: str, retries: int = None + ): + """ + Add a webhook to the queue. - async def dispatch_message(self, message: OutboundMessage, attempt: int = None): - """Dispatch a message to the relevant transport. + Args: + topic: The webhook topic + payload: The webhook payload + endpoint: The webhook endpoint + retries: Override the number of retries - Find a registered transport for the scheme in the uri and - use it to send the message. + Raises: + OutboundDeliveryError: if the associated transport is not running - Args: - message: The outbound message to dispatch + """ + transport_id = self.get_running_transport_for_endpoint(endpoint) + queued = QueuedOutboundMessage(None, None, None, transport_id) + queued.endpoint = f"{endpoint}/topic/{topic}/" + queued.payload = json.dumps(payload) + queued.state = QueuedOutboundMessage.STATE_PENDING + queued.retries = 5 if retries is None else retries + self.outbound_new.append(queued) + self.process_queued() + + def process_queued(self) -> asyncio.Task: + """ + Start the process to deliver queued messages if necessary. + + Returns: the current queue processing task or None """ - # Grab the scheme from the uri - scheme = urlparse(message.endpoint).scheme - if scheme == "": - raise OutboundDeliveryError( - f"The uri '{message.endpoint}' does not specify a scheme" + if self._process_task and not self._process_task.done(): + self.outbound_event.set() + elif self.outbound_new or self.outbound_buffer: + self._process_task = self.loop.create_task(self._process_loop()) + self._process_task.add_done_callback(lambda task: self._process_done(task)) + return self._process_task + + def _process_done(self, task: asyncio.Task): + """Handle completion of the drain process.""" + exc_info = task_exc_info(task) + if exc_info: + LOGGER.exception( + "Exception in outbound queue processing:", exc_info=exc_info ) - - # Look up transport that is registered to handle this scheme - transport = self.get_running_transport_for_scheme(scheme) - if not transport: - raise OutboundDeliveryError( - f"No transport driver exists to handle scheme '{scheme}'" + if self._process_task and self._process_task.done(): + self._process_task = None + + async def _process_loop(self): + """Continually kick off encoding and delivery on outbound messages.""" + # Note: this method should not call async methods apart from + # waiting for the updated event, to avoid yielding to other queue methods + + while True: + self.outbound_event.clear() + loop_time = time.perf_counter() + upd_buffer = [] + + for queued in self.outbound_buffer: + if queued.state == QueuedOutboundMessage.STATE_DONE: + if queued.error: + LOGGER.exception( + "Outbound message could not be delivered to %s", + queued.endpoint, + exc_info=queued.error, + ) + if self.handle_not_delivered: + self.handle_not_delivered(queued.context, queued.message) + continue # remove from buffer + + deliver = False + + if queued.state == QueuedOutboundMessage.STATE_PENDING: + deliver = True + elif queued.state == QueuedOutboundMessage.STATE_RETRY: + if queued.retry_at < loop_time: + queued.retry_at = None + deliver = True + + if deliver: + queued.state = QueuedOutboundMessage.STATE_DELIVER + self.deliver_queued_message(queued) + + upd_buffer.append(queued) + + new_pending = 0 + new_messages = self.outbound_new + self.outbound_new = [] + + for queued in new_messages: + if queued.state == QueuedOutboundMessage.STATE_NEW: + if queued.message and queued.message.enc_payload: + queued.payload = queued.message.enc_payload + queued.state = QueuedOutboundMessage.STATE_PENDING + new_pending += 1 + else: + queued.state = QueuedOutboundMessage.STATE_ENCODE + self.encode_queued_message(queued) + else: + new_pending += 1 + + upd_buffer.append(queued) + + self.outbound_buffer = upd_buffer + if self.outbound_buffer: + if not new_pending: + await self.outbound_event.wait() + else: + break + + def encode_queued_message(self, queued: QueuedOutboundMessage) -> asyncio.Task: + """Kick off encoding of a queued message.""" + queued.task = self.task_queue.run( + self.perform_encode(queued), + lambda completed: self.finished_encode(queued, completed), + ) + return queued.task + + async def perform_encode(self, queued: QueuedOutboundMessage): + """Perform message encoding.""" + transport = self.get_transport_instance(queued.transport_id) + wire_format = transport.wire_format or await queued.context.inject( + BaseWireFormat + ) + queued.payload = await wire_format.encode_message( + queued.context, + queued.message.payload, + queued.target.recipient_keys, + queued.target.routing_keys, + queued.target.sender_key, + ) + + def finished_encode(self, queued: QueuedOutboundMessage, completed: CompletedTask): + """Handle completion of queued message encoding.""" + if completed.exc_info: + queued.error = completed.exc_info + queued.state = QueuedOutboundMessage.STATE_DONE + else: + queued.state = QueuedOutboundMessage.STATE_PENDING + queued.task = None + self.process_queued() + + def deliver_queued_message(self, queued: QueuedOutboundMessage) -> asyncio.Task: + """Kick off delivery of a queued message.""" + transport = self.get_transport_instance(queued.transport_id) + queued.task = self.task_queue.run( + transport.handle_message(queued.payload, queued.endpoint), + lambda completed: self.finished_deliver(queued, completed), + ) + return queued.task + + def finished_deliver(self, queued: QueuedOutboundMessage, completed: CompletedTask): + """Handle completion of queued message delivery.""" + if completed.exc_info: + queued.error = completed.exc_info + LOGGER.exception( + "Outbound message could not be delivered", exc_info=queued.error, ) - # TODO log delivery error on final attempt - await transport.handle_message(message) + if queued.retries: + queued.retries -= 1 + queued.state = QueuedOutboundMessage.STATE_RETRY + queued.retry_at = time.perf_counter() + 10 + else: + queued.state = QueuedOutboundMessage.STATE_DONE + else: + queued.error = None + queued.state = QueuedOutboundMessage.STATE_DONE + queued.task = None + self.process_queued() + + async def flush(self): + """Wait for any queued messages to be delivered.""" + proc_task = self.process_queued() + if proc_task: + await proc_task diff --git a/aries_cloudagent/messaging/outbound_message.py b/aries_cloudagent/transport/outbound/message.py similarity index 61% rename from aries_cloudagent/messaging/outbound_message.py rename to aries_cloudagent/transport/outbound/message.py index 6b2d3ea9b7..c946906db9 100644 --- a/aries_cloudagent/messaging/outbound_message.py +++ b/aries_cloudagent/transport/outbound/message.py @@ -1,8 +1,8 @@ """Outbound message representation.""" -from typing import Union +from typing import Sequence, Union -from ..connections.models.connection_target import ConnectionTarget +from ...connections.models.connection_target import ConnectionTarget class OutboundMessage: @@ -10,38 +10,29 @@ class OutboundMessage: def __init__( self, - payload: Union[str, bytes], *, connection_id: str = None, - encoded: bool = False, + enc_payload: Union[str, bytes] = None, endpoint: str = None, - reply_socket_id: str = None, + payload: Union[str, bytes], + reply_session_id: str = None, reply_thread_id: str = None, reply_to_verkey: str = None, + reply_from_verkey: str = None, target: ConnectionTarget = None, + target_list: Sequence[ConnectionTarget] = None, ): """Initialize an outgoing message.""" self.connection_id = connection_id - self.encoded = encoded + self.enc_payload = enc_payload self._endpoint = endpoint self.payload = payload - self.reply_socket_id = reply_socket_id + self.reply_session_id = reply_session_id self.reply_thread_id = reply_thread_id self.reply_to_verkey = reply_to_verkey + self.reply_from_verkey = reply_from_verkey self.target = target - - @property - def endpoint(self) -> str: - """Return the endpoint of the outbound message. - - Defaults to the endpoint of the connection target. - """ - return self._endpoint or (self.target and self.target.endpoint) - - @endpoint.setter - def endpoint(self, endp: str) -> None: - """Set the endpoint of the outbound message.""" - self._endpoint = endp + self.target_list = list(target_list) if target_list else [] def __repr__(self) -> str: """ diff --git a/aries_cloudagent/transport/outbound/tests/test_http_transport.py b/aries_cloudagent/transport/outbound/tests/test_http_transport.py index 9e9a03ae3f..43d4bb44f9 100644 --- a/aries_cloudagent/transport/outbound/tests/test_http_transport.py +++ b/aries_cloudagent/transport/outbound/tests/test_http_transport.py @@ -3,9 +3,10 @@ from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop from aiohttp import web -from ....messaging.outbound_message import OutboundMessage from ....stats import Collector +from ...outbound.message import OutboundMessage + from ..http import HttpTransport @@ -30,27 +31,27 @@ async def get_application(self): async def test_handle_message(self): server_addr = f"http://localhost:{self.server.port}" - async def send_message(transport, message): + async def send_message(transport, payload, endpoint): async with transport: - await transport.handle_message(message) + await transport.handle_message(payload, endpoint) transport = HttpTransport() - message = OutboundMessage("{}", endpoint=server_addr) - await asyncio.wait_for(send_message(transport, message), 5.0) + await asyncio.wait_for(send_message(transport, "{}", endpoint=server_addr), 5.0) assert self.message_results == [{}] @unittest_run_loop async def test_stats(self): server_addr = f"http://localhost:{self.server.port}" - async def send_message(transport, message): + async def send_message(transport, payload, endpoint): async with transport: - await transport.handle_message(message) + await transport.handle_message(payload, endpoint) transport = HttpTransport() transport.collector = Collector() - message = OutboundMessage(b"{}", endpoint=server_addr) - await asyncio.wait_for(send_message(transport, message), 5.0) + await asyncio.wait_for( + send_message(transport, b"{}", endpoint=server_addr), 5.0 + ) results = transport.collector.extract() assert results["count"] == { diff --git a/aries_cloudagent/transport/outbound/tests/test_manager.py b/aries_cloudagent/transport/outbound/tests/test_manager.py index b3c6119c62..a8b10ef05b 100644 --- a/aries_cloudagent/transport/outbound/tests/test_manager.py +++ b/aries_cloudagent/transport/outbound/tests/test_manager.py @@ -1,51 +1,117 @@ import asyncio +import json from asynctest import TestCase as AsyncTestCase, mock as async_mock -from ....messaging.outbound_message import OutboundMessage +from ....config.injection_context import InjectionContext +from ....connections.models.connection_target import ConnectionTarget -from ..manager import OutboundTransportManager, OutboundTransportRegistrationError -from ..queue.basic import BasicOutboundMessageQueue +from ..manager import ( + OutboundDeliveryError, + OutboundTransportManager, + OutboundTransportRegistrationError, + QueuedOutboundMessage, +) +from ..message import OutboundMessage class TestOutboundTransportManager(AsyncTestCase): def test_register_path(self): - mgr = OutboundTransportManager() + mgr = OutboundTransportManager(InjectionContext()) mgr.register("http") assert mgr.get_registered_transport_for_scheme("http") with self.assertRaises(OutboundTransportRegistrationError): mgr.register("http") + async def test_setup(self): + context = InjectionContext() + context.update_settings({"transport.outbound_configs": ["http"]}) + mgr = OutboundTransportManager(context) + with async_mock.patch.object(mgr, "register") as mock_register: + await mgr.setup() + mock_register.assert_called_once_with("http") + async def test_send_message(self): - mgr = OutboundTransportManager(BasicOutboundMessageQueue()) + context = InjectionContext() + mgr = OutboundTransportManager(context) transport_cls = async_mock.Mock(spec=[]) with self.assertRaises(OutboundTransportRegistrationError): - mgr.register_class(transport_cls) + mgr.register_class(transport_cls, "transport_cls") transport = async_mock.MagicMock() transport.handle_message = async_mock.CoroutineMock() + transport.wire_format.encode_message = async_mock.CoroutineMock() transport.start = async_mock.CoroutineMock() transport.stop = async_mock.CoroutineMock() + transport.schemes = ["http"] transport_cls = async_mock.MagicMock() transport_cls.schemes = ["http"] transport_cls.return_value = transport - mgr.register_class(transport_cls) - assert mgr.get_registered_transport_for_scheme("http") is transport_cls + mgr.register_class(transport_cls, "transport_cls") + assert mgr.get_registered_transport_for_scheme("http") == "transport_cls" await mgr.start() - await asyncio.sleep(0.1) - transport.start.assert_called_once_with() - assert mgr.get_running_transport_for_scheme("http") is transport + await mgr.task_queue + transport.start.assert_awaited_once_with() + assert mgr.get_running_transport_for_scheme("http") == "transport_cls" - message = OutboundMessage("") - message.endpoint = "http://localhost" + message = OutboundMessage(payload="{}") + message.target = ConnectionTarget( + endpoint="http://localhost", + recipient_keys=[1, 2], + routing_keys=[3], + sender_key=4, + ) - await mgr.send_message(message) + mgr.enqueue_message(context, message) + await mgr.flush() + transport.wire_format.encode_message.assert_awaited_once_with( + context, + message.payload, + message.target.recipient_keys, + message.target.routing_keys, + message.target.sender_key, + ) + transport.handle_message.assert_awaited_once_with( + transport.wire_format.encode_message.return_value, message.target.endpoint + ) await mgr.stop() - transport.handle_message.assert_called_once_with(message) assert mgr.get_running_transport_for_scheme("http") is None - transport.stop.assert_called_once_with() + transport.stop.assert_awaited_once_with() + + async def test_enqueue_webhook(self): + context = InjectionContext() + mgr = OutboundTransportManager(context) + test_topic = "test-topic" + test_payload = {"test": "payload"} + test_endpoint = "http://example" + test_retries = 2 + + with self.assertRaises(OutboundDeliveryError): + mgr.enqueue_webhook( + test_topic, test_payload, test_endpoint, retries=test_retries + ) + + transport_cls = async_mock.MagicMock() + transport_cls.schemes = ["http"] + transport_cls.return_value = async_mock.MagicMock() + transport_cls.return_value.schemes = ["http"] + transport_cls.return_value.start = async_mock.CoroutineMock() + tid = mgr.register_class(transport_cls, "transport_cls") + await mgr.start_transport(tid) + + with async_mock.patch.object(mgr, "process_queued") as mock_process: + mgr.enqueue_webhook( + test_topic, test_payload, test_endpoint, retries=test_retries + ) + mock_process.assert_called_once_with() + assert len(mgr.outbound_new) == 1 + queued = mgr.outbound_new[0] + assert queued.endpoint == f"{test_endpoint}/topic/{test_topic}/" + assert json.loads(queued.payload) == test_payload + assert queued.retries == test_retries + assert queued.state == QueuedOutboundMessage.STATE_PENDING diff --git a/aries_cloudagent/transport/outbound/tests/test_ws_transport.py b/aries_cloudagent/transport/outbound/tests/test_ws_transport.py index 220f57b289..bccd320b22 100644 --- a/aries_cloudagent/transport/outbound/tests/test_ws_transport.py +++ b/aries_cloudagent/transport/outbound/tests/test_ws_transport.py @@ -4,8 +4,6 @@ from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop from aiohttp import web, WSMsgType -from ....messaging.outbound_message import OutboundMessage - from ..ws import WsTransport @@ -38,11 +36,10 @@ async def get_application(self): async def test_handle_message(self): server_addr = f"ws://localhost:{self.server.port}" - async def send_message(transport, message): + async def send_message(transport, payload, endpoint: str): async with transport: - await transport.handle_message(message) + await transport.handle_message(payload, endpoint) transport = WsTransport() - message = OutboundMessage("{}", endpoint=server_addr) - await asyncio.wait_for(send_message(transport, message), 5.0) + await asyncio.wait_for(send_message(transport, "{}", endpoint=server_addr), 5.0) assert self.message_results == [{}] diff --git a/aries_cloudagent/transport/outbound/ws.py b/aries_cloudagent/transport/outbound/ws.py index 25f52acd30..d77065615c 100644 --- a/aries_cloudagent/transport/outbound/ws.py +++ b/aries_cloudagent/transport/outbound/ws.py @@ -1,11 +1,10 @@ """Websockets outbound transport.""" import logging +from typing import Union from aiohttp import ClientSession, DummyCookieJar -from ...messaging.outbound_message import OutboundMessage - from .base import BaseOutboundTransport @@ -29,7 +28,7 @@ async def stop(self): await self.client_session.close() self.client_session = None - async def handle_message(self, message: OutboundMessage): + async def handle_message(self, payload: Union[str, bytes], endpoint: str): """ Handle message from queue. @@ -37,8 +36,8 @@ async def handle_message(self, message: OutboundMessage): message: `OutboundMessage` to send over transport implementation """ # aiohttp should automatically handle websocket sessions - async with self.client_session.ws_connect(message.endpoint) as ws: - if isinstance(message.payload, bytes): - await ws.send_bytes(message.payload) + async with self.client_session.ws_connect(endpoint) as ws: + if isinstance(payload, bytes): + await ws.send_bytes(payload) else: - await ws.send_str(message.payload) + await ws.send_str(payload) diff --git a/aries_cloudagent/messaging/serializer.py b/aries_cloudagent/transport/pack_format.py similarity index 50% rename from aries_cloudagent/messaging/serializer.py rename to aries_cloudagent/transport/pack_format.py index 7201a1a1fb..20e79bf16a 100644 --- a/aries_cloudagent/messaging/serializer.py +++ b/aries_cloudagent/transport/pack_format.py @@ -1,4 +1,4 @@ -"""Standard message serializer classes.""" +"""Standard packed message format classes.""" import json import logging @@ -7,38 +7,38 @@ from ..config.base import InjectorError from ..config.injection_context import InjectionContext from ..protocols.routing.messages.forward import Forward +from ..messaging.task_queue import TaskQueue +from ..messaging.util import time_now from ..wallet.base import BaseWallet from ..wallet.error import WalletError -from .error import MessageParseError -from .message_delivery import MessageDelivery -from .util import time_now +from .error import MessageParseError, MessageEncodeError +from .inbound.receipt import MessageReceipt +from .wire_format import BaseWireFormat LOGGER = logging.getLogger(__name__) -class MessageSerializer: +class PackWireFormat(BaseWireFormat): """Standard DIDComm message parser and serializer.""" def __init__(self): - """Initialize the message serializer instance.""" + """Initialize the pack wire format instance.""" + super().__init__() + self.task_queue: TaskQueue = None async def parse_message( - self, - context: InjectionContext, - message_body: Union[str, bytes], - transport_type: str, - ) -> Tuple[dict, MessageDelivery]: + self, context: InjectionContext, message_body: Union[str, bytes], + ) -> Tuple[dict, MessageReceipt]: """ Deserialize an incoming message and further populate the request context. Args: context: The injection context for settings and services message_body: The body of the message - transport_type: The transport the message was received on Returns: - A message delivery object with details on the parsed message + A tuple of the parsed message and a message receipt instance Raises: MessageParseError: If the JSON parsing failed @@ -46,10 +46,9 @@ async def parse_message( """ - delivery = MessageDelivery() - delivery.in_time = time_now() - delivery.raw_message = message_body - delivery.transport_type = transport_type + receipt = MessageReceipt() + receipt.in_time = time_now() + receipt.raw_message = message_body message_dict = None message_json = message_body @@ -66,22 +65,16 @@ async def parse_message( # packed messages are detected by the absence of @type if "@type" not in message_dict: - try: - wallet: BaseWallet = await context.inject(BaseWallet) - except InjectorError: - raise MessageParseError("Wallet not defined in request context") try: - unpacked = await wallet.unpack_message(message_body) - ( - message_json, - delivery.sender_verkey, - delivery.recipient_verkey, - ) = unpacked - except WalletError: + unpack = self.unpack(context, message_body, receipt) + message_json = await ( + self.task_queue and self.task_queue.run(unpack) or unpack + ) + except MessageParseError: LOGGER.debug("Message unpack failed, falling back to JSON") else: - delivery.raw_message = message_json + receipt.raw_message = message_json try: message_dict = json.loads(message_json) except ValueError: @@ -91,32 +84,37 @@ async def parse_message( # parse thread ID thread_dec = message_dict.get("~thread") - delivery.thread_id = ( + receipt.thread_id = ( thread_dec and thread_dec.get("thid") or message_dict.get("@id") ) # handle transport decorator transport_dec = message_dict.get("~transport") if transport_dec: - delivery.direct_response_requested = transport_dec.get("return_route") + receipt.direct_response_mode = transport_dec.get("return_route") LOGGER.debug(f"Expanded message: {message_dict}") - return message_dict, delivery + return message_dict, receipt - def extract_message_type(self, parsed_msg: dict) -> str: - """ - Extract the message type identifier from a parsed message. - - Raises: - MessageParseError: If the message doesn't specify a type - - """ + async def unpack( + self, + context: InjectionContext, + message_body: Union[str, bytes], + receipt: MessageReceipt, + ): + """Look up the wallet instance and perform the message unpack.""" + try: + wallet: BaseWallet = await context.inject(BaseWallet) + except InjectorError: + raise MessageParseError("Wallet not defined in request context") - msg_type = parsed_msg.get("@type") - if not msg_type: - raise MessageParseError("Message does not contain '@type' parameter") - return msg_type + try: + unpacked = await wallet.unpack_message(message_body) + (message_json, receipt.sender_verkey, receipt.recipient_verkey,) = unpacked + return message_json + except WalletError as e: + raise MessageParseError("Message unpack failed") from e async def encode_message( self, @@ -139,21 +137,51 @@ async def encode_message( Returns: The encoded message - """ + Raises: + MessageEncodeError: If the message could not be encoded - wallet: BaseWallet = await context.inject(BaseWallet) + """ if sender_key and recipient_keys: + pack = self.pack( + context, message_json, recipient_keys, routing_keys, sender_key + ) + message = await (self.task_queue and self.task_queue.run(pack) or pack) + else: + message = message_json + return message + + async def pack( + self, + context: InjectionContext, + message_json: Union[str, bytes], + recipient_keys: Sequence[str], + routing_keys: Sequence[str], + sender_key: str, + ): + """Look up the wallet instance and perform the message pack.""" + if not sender_key or not recipient_keys: + raise MessageEncodeError("Cannot pack message without associated keys") + + wallet: BaseWallet = await context.inject(BaseWallet, required=False) + if not wallet: + raise MessageEncodeError("No wallet instance") + + try: message = await wallet.pack_message( message_json, recipient_keys, sender_key ) - if routing_keys: - recip_keys = recipient_keys - for router_key in routing_keys: - fwd_msg = Forward(to=recip_keys[0], msg=message) - # Forwards are anon packed - recip_keys = [router_key] + except WalletError as e: + raise MessageEncodeError("Message pack failed") from e + + if routing_keys: + recip_keys = recipient_keys + for router_key in routing_keys: + fwd_msg = Forward(to=recip_keys[0], msg=message) + # Forwards are anon packed + recip_keys = [router_key] + try: message = await wallet.pack_message(fwd_msg.to_json(), recip_keys) - else: - message = message_json + except WalletError as e: + raise MessageEncodeError("Forward message pack failed") from e return message diff --git a/aries_cloudagent/transport/outbound/queue/__init__.py b/aries_cloudagent/transport/queue/__init__.py similarity index 100% rename from aries_cloudagent/transport/outbound/queue/__init__.py rename to aries_cloudagent/transport/queue/__init__.py diff --git a/aries_cloudagent/transport/outbound/queue/base.py b/aries_cloudagent/transport/queue/base.py similarity index 92% rename from aries_cloudagent/transport/outbound/queue/base.py rename to aries_cloudagent/transport/queue/base.py index 9078bf5023..bd5959ce53 100644 --- a/aries_cloudagent/transport/outbound/queue/base.py +++ b/aries_cloudagent/transport/queue/base.py @@ -1,11 +1,11 @@ -"""Abstract outbound queue.""" +"""Abstract message queue.""" from abc import ABC, abstractmethod import asyncio -class BaseOutboundMessageQueue(ABC): - """Abstract outbound queue class.""" +class BaseMessageQueue(ABC): + """Abstract message queue class.""" @abstractmethod async def enqueue(self, message): diff --git a/aries_cloudagent/transport/outbound/queue/basic.py b/aries_cloudagent/transport/queue/basic.py similarity index 88% rename from aries_cloudagent/transport/outbound/queue/basic.py rename to aries_cloudagent/transport/queue/basic.py index 98f5533138..bc4e251dbc 100644 --- a/aries_cloudagent/transport/outbound/queue/basic.py +++ b/aries_cloudagent/transport/queue/basic.py @@ -2,27 +2,22 @@ import asyncio import logging -import os -from .base import BaseOutboundMessageQueue +from .base import BaseMessageQueue -class BasicOutboundMessageQueue(BaseOutboundMessageQueue): +class BasicMessageQueue(BaseMessageQueue): """Basic in memory queue implementation class.""" def __init__(self): - """Initialize a `BasicOutboundMessageQueue` instance.""" + """Initialize a `BasicMessageQueue` instance.""" self.queue = self.make_queue() self.logger = logging.getLogger(__name__) self.stop_event = asyncio.Event() def make_queue(self): """Create the queue instance.""" - queue_size = os.environ.get("QUEUE_SIZE") - if queue_size: - return asyncio.Queue(maxsize=int(queue_size)) - else: - return asyncio.Queue() + return asyncio.Queue() async def enqueue(self, message): """ diff --git a/aries_cloudagent/transport/outbound/queue/tests/__init__.py b/aries_cloudagent/transport/queue/tests/__init__.py similarity index 100% rename from aries_cloudagent/transport/outbound/queue/tests/__init__.py rename to aries_cloudagent/transport/queue/tests/__init__.py diff --git a/aries_cloudagent/transport/outbound/queue/tests/test_basic_queue.py b/aries_cloudagent/transport/queue/tests/test_basic_queue.py similarity index 89% rename from aries_cloudagent/transport/outbound/queue/tests/test_basic_queue.py rename to aries_cloudagent/transport/queue/tests/test_basic_queue.py index f0b8b43e33..91115e7074 100644 --- a/aries_cloudagent/transport/outbound/queue/tests/test_basic_queue.py +++ b/aries_cloudagent/transport/queue/tests/test_basic_queue.py @@ -2,7 +2,7 @@ from asynctest import TestCase as AsyncTestCase -from ..basic import BasicOutboundMessageQueue +from ..basic import BasicMessageQueue async def collect(queue, count=1): @@ -16,7 +16,7 @@ async def collect(queue, count=1): class TestBasicQueue(AsyncTestCase): async def test_enqueue_dequeue(self): - queue = BasicOutboundMessageQueue() + queue = BasicMessageQueue() with self.assertRaises(asyncio.TimeoutError): await queue.dequeue(timeout=0) @@ -28,7 +28,7 @@ async def test_enqueue_dequeue(self): await queue.dequeue(timeout=0) async def test_async_iter(self): - queue = BasicOutboundMessageQueue() + queue = BasicMessageQueue() results = asyncio.wait_for(collect(queue), timeout=1.0) test_value = "test value" @@ -37,7 +37,7 @@ async def test_async_iter(self): assert found == [test_value] async def test_stopped(self): - queue = BasicOutboundMessageQueue() + queue = BasicMessageQueue() queue.stop() with self.assertRaises(asyncio.CancelledError): diff --git a/aries_cloudagent/messaging/tests/test_serializer.py b/aries_cloudagent/transport/tests/test_pack_format.py similarity index 74% rename from aries_cloudagent/messaging/tests/test_serializer.py rename to aries_cloudagent/transport/tests/test_pack_format.py index 1a893f23e7..c790c36509 100644 --- a/aries_cloudagent/messaging/tests/test_serializer.py +++ b/aries_cloudagent/transport/tests/test_pack_format.py @@ -8,10 +8,10 @@ from ...wallet.basic import BasicWallet from ..error import MessageParseError -from ..serializer import MessageSerializer +from ..pack_format import PackWireFormat -class TestMessageSerializer(AsyncTestCase): +class TestPackWireFormat(AsyncTestCase): test_message_type = "PROTOCOL/MESSAGE" test_message_id = "MESSAGE_ID" test_content = "CONTENT" @@ -23,7 +23,6 @@ class TestMessageSerializer(AsyncTestCase): "~transport": {"return_route": "all"}, "content": test_content, } - test_transport_type = "http" test_seed = "testseed000000000000000000000001" test_routing_seed = "testseed000000000000000000000002" @@ -33,44 +32,43 @@ def setUp(self): self.context.injector.bind_instance(BaseWallet, self.wallet) async def test_errors(self): - serializer = MessageSerializer() + serializer = PackWireFormat() bad_values = [None, "", "1", "[]", "{..."] for message_json in bad_values: with self.assertRaises(MessageParseError): message_dict, delivery = await serializer.parse_message( - self.context, message_json, self.test_transport_type + self.context, message_json ) async def test_unpacked(self): - serializer = MessageSerializer() + serializer = PackWireFormat() message_json = json.dumps(self.test_message) message_dict, delivery = await serializer.parse_message( - self.context, message_json, self.test_transport_type + self.context, message_json ) assert message_dict == self.test_message - assert serializer.extract_message_type(message_dict) == self.test_message_type + assert message_dict["@type"] == self.test_message_type assert delivery.thread_id == self.test_thread_id - assert delivery.direct_response_requested == "all" + assert delivery.direct_response_mode == "all" async def test_fallback(self): - serializer = MessageSerializer() + serializer = PackWireFormat() message = self.test_message.copy() message.pop("@type") message_json = json.dumps(message) message_dict, delivery = await serializer.parse_message( - self.context, message_json, self.test_transport_type + self.context, message_json ) assert delivery.raw_message == message_json - with self.assertRaises(MessageParseError): - serializer.extract_message_type(message_dict) + assert message_dict == message async def test_encode_decode(self): local_did = await self.wallet.create_local_did(self.test_seed) - serializer = MessageSerializer() + serializer = PackWireFormat() recipient_keys = (local_did.verkey,) routing_keys = () sender_key = local_did.verkey @@ -84,17 +82,17 @@ async def test_encode_decode(self): assert isinstance(packed, dict) and "protected" in packed message_dict, delivery = await serializer.parse_message( - self.context, packed_json, self.test_transport_type + self.context, packed_json ) assert message_dict == self.test_message - assert serializer.extract_message_type(message_dict) == self.test_message_type + assert message_dict["@type"] == self.test_message_type assert delivery.thread_id == self.test_thread_id - assert delivery.direct_response_requested == "all" + assert delivery.direct_response_mode == "all" async def test_forward(self): local_did = await self.wallet.create_local_did(self.test_seed) router_did = await self.wallet.create_local_did(self.test_routing_seed) - serializer = MessageSerializer() + serializer = PackWireFormat() recipient_keys = (local_did.verkey,) routing_keys = (router_did.verkey,) sender_key = local_did.verkey @@ -108,9 +106,8 @@ async def test_forward(self): assert isinstance(packed, dict) and "protected" in packed message_dict, delivery = await serializer.parse_message( - self.context, packed_json, self.test_transport_type + self.context, packed_json ) - assert serializer.extract_message_type(message_dict) == FORWARD + assert message_dict["@type"] == FORWARD assert delivery.recipient_verkey == router_did.verkey assert delivery.sender_verkey is None - diff --git a/aries_cloudagent/transport/wire_format.py b/aries_cloudagent/transport/wire_format.py new file mode 100644 index 0000000000..8ad93e3df3 --- /dev/null +++ b/aries_cloudagent/transport/wire_format.py @@ -0,0 +1,150 @@ +"""Abstract wire format classes.""" + +import json +import logging + +from abc import abstractmethod +from typing import Sequence, Tuple, Union + +from ..config.injection_context import InjectionContext +from ..messaging.util import time_now + +from .inbound.receipt import MessageReceipt +from .error import MessageParseError + +LOGGER = logging.getLogger(__name__) + + +class BaseWireFormat: + """Abstract messaging wire format.""" + + def __init__(self): + """Initialize the base wire format instance.""" + + @abstractmethod + async def parse_message( + self, context: InjectionContext, message_body: Union[str, bytes], + ) -> Tuple[dict, MessageReceipt]: + """ + Deserialize an incoming message and further populate the request context. + + Args: + context: The injection context for settings and services + message_body: The body of the message + + Returns: + A tuple of the parsed message and a message receipt instance + + Raises: + MessageParseError: If the message can't be parsed + + """ + + @abstractmethod + async def encode_message( + self, + context: InjectionContext, + message_json: Union[str, bytes], + recipient_keys: Sequence[str], + routing_keys: Sequence[str], + sender_key: str, + ) -> Union[str, bytes]: + """ + Encode an outgoing message for transport. + + Args: + context: The injection context for settings and services + message_json: The message body to serialize + recipient_keys: A sequence of recipient verkeys + routing_keys: A sequence of routing verkeys + sender_key: The verification key of the sending agent + + Returns: + The encoded message + + Raises: + MessageEncodeError: If the message could not be encoded + + """ + + +class JsonWireFormat(BaseWireFormat): + """Unencrypted wire format.""" + + @abstractmethod + async def parse_message( + self, context: InjectionContext, message_body: Union[str, bytes], + ) -> Tuple[dict, MessageReceipt]: + """ + Deserialize an incoming message and further populate the request context. + + Args: + context: The injection context for settings and services + message_body: The body of the message + + Returns: + A tuple of the parsed message and a message receipt instance + + Raises: + MessageParseError: If the JSON parsing failed + + """ + receipt = MessageReceipt() + receipt.in_time = time_now() + receipt.raw_message = message_body + + message_dict = None + message_json = message_body + + if not message_json: + raise MessageParseError("Message body is empty") + + try: + message_dict = json.loads(message_json) + except ValueError: + raise MessageParseError("Message JSON parsing failed") + if not isinstance(message_dict, dict): + raise MessageParseError("Message JSON result is not an object") + + # parse thread ID + thread_dec = message_dict.get("~thread") + receipt.thread_id = ( + thread_dec and thread_dec.get("thid") or message_dict.get("@id") + ) + + # handle transport decorator + transport_dec = message_dict.get("~transport") + if transport_dec: + receipt.direct_response_mode = transport_dec.get("return_route") + + LOGGER.debug(f"Expanded message: {message_dict}") + + return message_dict, receipt + + @abstractmethod + async def encode_message( + self, + context: InjectionContext, + message_json: Union[str, bytes], + recipient_keys: Sequence[str], + routing_keys: Sequence[str], + sender_key: str, + ) -> Union[str, bytes]: + """ + Encode an outgoing message for transport. + + Args: + context: The injection context for settings and services + message_json: The message body to serialize + recipient_keys: A sequence of recipient verkeys + routing_keys: A sequence of routing verkeys + sender_key: The verification key of the sending agent + + Returns: + The encoded message + + Raises: + MessageEncodeError: If the message could not be encoded + + """ + return message_json diff --git a/demo/runners/support/agent.py b/demo/runners/support/agent.py index 3ecbcb3fe8..50b3d70796 100644 --- a/demo/runners/support/agent.py +++ b/demo/runners/support/agent.py @@ -319,14 +319,14 @@ async def _receive_webhook(self, request: ClientRequest): topic = request.match_info["topic"] payload = await request.json() await self.handle_webhook(topic, payload) - return web.Response(text="") + return web.Response(status=200) async def handle_webhook(self, topic: str, payload): if topic != "webhook": # would recurse handler = f"handle_{topic}" method = getattr(self, handler, None) if method: - await method(payload) + asyncio.get_event_loop().create_task(method(payload)) else: log_msg( f"Error: agent {self.ident} " diff --git a/scripts/run_docker b/scripts/run_docker index 40b50d07e9..72fc4e9693 100755 --- a/scripts/run_docker +++ b/scripts/run_docker @@ -11,7 +11,12 @@ done PTVSD_PORT="${PTVSD_PORT-5678}" -if [ ! -z "${ENABLE_PTVSD}" ] || [[ "$@" == *--debug* ]]; then +for arg in $@; do + if [ "$arg" = "--timing" ]; then + ENABLE_PTVSD=1 + fi +done +if [ ! -z "${ENABLE_PTVSD}" ]; then ARGS="${ARGS} -e ENABLE_PTVSD=\"${ENABLE_PTVSD}\" -p $PTVSD_PORT:$PTVSD_PORT" fi