diff --git a/aries_cloudagent/admin/server.py b/aries_cloudagent/admin/server.py index c3163033d3..6240943330 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -85,10 +85,9 @@ def __init__( ): """Initialize the webhook target.""" self.endpoint = endpoint - self._topic_filter = None self.retries = retries - # call setter - self.topic_filter = topic_filter + self._topic_filter = None + self.topic_filter = topic_filter # call setter @property def topic_filter(self) -> Set[str]: @@ -176,6 +175,8 @@ async def check_token(request, handler): middlewares.append(check_token) + collector: Collector = await self.context.inject(Collector, required=False) + if self.task_queue: @web.middleware @@ -185,14 +186,11 @@ async def apply_limiter(request, handler): middlewares.append(apply_limiter) - stats: Collector = await self.context.inject(Collector, required=False) - if stats: + elif collector: @web.middleware async def collect_stats(request, handler): - handler = stats.wrap_coro( - handler, [handler.__qualname__, "any-admin-request"] - ) + handler = collector.wrap_coro(handler, [handler.__qualname__]) return await handler(request) middlewares.append(collect_stats) @@ -231,7 +229,7 @@ async def collect_stats(request, handler): for route in app.router.routes(): cors.add(route) # get agent label - agent_label = self.context.settings.get("default_label"), + agent_label = self.context.settings.get("default_label") version_string = f"v{__version__}" setup_aiohttp_apispec( @@ -288,7 +286,6 @@ async def plugins_handler(self, request: web.BaseRequest): registry: PluginRegistry = await self.context.inject( PluginRegistry, required=False ) - print(registry) plugins = registry and sorted(registry.plugin_names) or [] return web.json_response({"result": plugins}) diff --git a/aries_cloudagent/conductor.py b/aries_cloudagent/conductor.py index af7e32e4c5..725d6438c3 100644 --- a/aries_cloudagent/conductor.py +++ b/aries_cloudagent/conductor.py @@ -64,6 +64,7 @@ async def setup(self): context = await self.context_builder.build() self.dispatcher = Dispatcher(context) + await self.dispatcher.setup() wire_format = await context.inject(BaseWireFormat, required=False) if wire_format and hasattr(wire_format, "task_queue"): @@ -118,12 +119,11 @@ async def setup(self): # "create_inbound_session", ), ) - collector.wrap(self.dispatcher, "handle_message") # at the class level (!) should not be performed multiple times collector.wrap( ConnectionManager, ( - "get_connection_targets", + # "get_connection_targets", "fetch_did_document", "find_inbound_connection", ), @@ -214,6 +214,8 @@ async def start(self) -> None: async def stop(self, timeout=1.0): """Stop the agent.""" shutdown = TaskQueue() + if self.dispatcher: + shutdown.run(self.dispatcher.complete()) if self.admin_server: shutdown.run(self.admin_server.stop()) if self.inbound_transport_manager: diff --git a/aries_cloudagent/config/default_context.py b/aries_cloudagent/config/default_context.py index e1501762ad..7ce066251b 100644 --- a/aries_cloudagent/config/default_context.py +++ b/aries_cloudagent/config/default_context.py @@ -20,7 +20,6 @@ from ..stats import Collector from ..storage.base import BaseStorage from ..storage.provider import StorageProvider -from ..transport.pack_format import PackWireFormat from ..transport.wire_format import BaseWireFormat from ..wallet.base import BaseWallet from ..wallet.provider import WalletProvider @@ -67,14 +66,12 @@ async def bind_providers(self, context: InjectionContext): StatsProvider( WalletProvider(), ( - "create", - "open", "sign_message", "verify_message", "encrypt_message", "decrypt_message", - "pack_message", - "unpack_message", + # "pack_message", + # "unpack_message", "get_local_did", ), ) @@ -128,7 +125,12 @@ async def bind_providers(self, context: InjectionContext): BaseWireFormat, CachedProvider( StatsProvider( - ClassProvider(PackWireFormat), ("encode_message", "parse_message"), + ClassProvider( + "aries_cloudagent.transport.pack_format.PackWireFormat" + ), + ( + # "encode_message", "parse_message" + ), ) ), ) diff --git a/aries_cloudagent/config/provider.py b/aries_cloudagent/config/provider.py index e51da07b15..3aba0f39a8 100644 --- a/aries_cloudagent/config/provider.py +++ b/aries_cloudagent/config/provider.py @@ -103,7 +103,10 @@ def __init__( async def provide(self, config: BaseSettings, injector: BaseInjector): """Provide the object instance given a config and injector.""" instance = await self._provider.provide(config, injector) - collector: Collector = await injector.inject(Collector, required=False) - if collector: - collector.wrap(instance, self._methods, ignore_missing=self._ignore_missing) + if self._methods: + collector: Collector = await injector.inject(Collector, required=False) + if collector: + collector.wrap( + instance, self._methods, ignore_missing=self._ignore_missing + ) return instance diff --git a/aries_cloudagent/dispatcher.py b/aries_cloudagent/dispatcher.py index b976d98c4c..10cd2c8443 100644 --- a/aries_cloudagent/dispatcher.py +++ b/aries_cloudagent/dispatcher.py @@ -10,6 +10,8 @@ import os from typing import Callable, Coroutine, Union +from aiohttp.web import HTTPException + from .config.injection_context import InjectionContext from .messaging.agent_message import AgentMessage from .messaging.error import MessageParseError @@ -19,7 +21,7 @@ from .messaging.protocol_registry import ProtocolRegistry from .messaging.request_context import RequestContext from .messaging.responder import BaseResponder -from .messaging.task_queue import TaskQueue +from .messaging.task_queue import CompletedTask, PendingTask, TaskQueue from .messaging.util import datetime_now from .stats import Collector from .transport.inbound.message import InboundMessage @@ -39,16 +41,44 @@ class Dispatcher: def __init__(self, context: InjectionContext): """Initialize an instance of Dispatcher.""" self.context = context - max_active = int(os.getenv("DISPATCHER_MAX_ACTIVE", 100)) - self.task_queue = TaskQueue(max_active=max_active) + self.collector: Collector = None + self.task_queue: TaskQueue = None + + async def setup(self): + """Perform async instance setup.""" + self.collector = await self.context.inject(Collector, required=False) + max_active = int(os.getenv("DISPATCHER_MAX_ACTIVE", 50)) + self.task_queue = TaskQueue( + max_active=max_active, timed=bool(self.collector), trace_fn=self.log_task + ) - def put_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Future: + def put_task( + self, coro: Coroutine, complete: Callable = None, ident: str = None + ) -> PendingTask: """Run a task in the task queue, potentially blocking other handlers.""" - return self.task_queue.put(coro, complete) + return self.task_queue.put(coro, complete, ident) - def run_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Task: + def run_task( + self, coro: Coroutine, complete: Callable = None, ident: str = None + ) -> asyncio.Task: """Run a task in the task queue, potentially blocking other handlers.""" - return self.task_queue.run(coro, complete) + return self.task_queue.run(coro, complete, ident) + + def log_task(self, task: CompletedTask): + """Log a completed task using the stats collector.""" + if task.exc_info and not issubclass(task.exc_info[0], HTTPException): + # skip errors intentionally returned to HTTP clients + LOGGER.exception( + "Handler error: %s", task.ident or "", exc_info=task.exc_info + ) + if self.collector: + timing = task.timing + if "queued" in timing: + self.collector.log( + f"Dispatcher:queued", timing["unqueued"] - timing["queued"] + ) + if task.ident: + self.collector.log(task.ident, timing["ended"] - timing["started"]) def queue_message( self, @@ -56,7 +86,7 @@ def queue_message( send_outbound: Coroutine, send_webhook: Coroutine = None, complete: Callable = None, - ) -> asyncio.Future: + ) -> PendingTask: """ Add a message to the processing queue for handling. @@ -67,7 +97,7 @@ def queue_message( complete: Function to call when the handler has completed Returns: - A future resolving to the handler task + A pending task instance resolving to the handler task """ return self.put_task( @@ -133,11 +163,10 @@ async def handle_message( context.injector.bind_instance(BaseResponder, responder) handler_cls = context.message.Handler - handler_obj = handler_cls() - collector: Collector = await context.inject(Collector, required=False) - if collector: - collector.wrap(handler_obj, "handle", ["any-message-handler"]) - await handler_obj.handle(context, responder) + handler = handler_cls().handle + if self.collector: + handler = self.collector.wrap_coro(handler, [handler.__qualname__]) + await handler(context, responder) async def make_message(self, parsed_msg: dict) -> AgentMessage: """ @@ -175,6 +204,10 @@ async def make_message(self, parsed_msg: dict) -> AgentMessage: return instance + async def complete(self, timeout: float = 0.1): + """Wait for pending tasks to complete.""" + await self.task_queue.complete(timeout=timeout) + class DispatcherResponder(BaseResponder): """Handle outgoing messages from message handlers.""" diff --git a/aries_cloudagent/messaging/task_queue.py b/aries_cloudagent/messaging/task_queue.py index bba539d577..04256ed3f8 100644 --- a/aries_cloudagent/messaging/task_queue.py +++ b/aries_cloudagent/messaging/task_queue.py @@ -13,6 +13,15 @@ def coro_ident(coro: Coroutine): return coro and (hasattr(coro, "__qualname__") and coro.__qualname__ or repr(coro)) +async def coro_timed(coro: Coroutine, timing: dict): + """Capture timing for a coroutine.""" + timing["started"] = time.perf_counter() + try: + return await coro + finally: + timing["ended"] = time.perf_counter() + + def task_exc_info(task: asyncio.Task): """Extract exception info from an asyncio task.""" if not task or not task.done(): @@ -41,6 +50,10 @@ def __init__( self.task = task self.timing = timing + def __repr__(self) -> str: + """Generate string representation for logging.""" + return f"<{self.__class__.__name__} ident={self.ident} timing={self.timing}>" + class PendingTask: """Represent a task in the queue.""" @@ -51,6 +64,7 @@ def __init__( complete_hook: Callable = None, ident: str = None, task_future: asyncio.Future = None, + queued_time: float = None, ): """ Initialize the pending task. @@ -60,13 +74,15 @@ def __init__( complete_hook: A callback to run on completion ident: A string identifier for the task task_future: A future to be resolved to the asyncio Task + queued_time: When the pending task was added to the queue """ if not asyncio.iscoroutine(coro): raise ValueError(f"Expected coroutine, got {coro}") self._cancelled = False self.complete_hook = complete_hook self.coro = coro - self.created_time: float = time.perf_counter() + self.queued_time: float = queued_time + self.unqueued_time: float = None self.ident = ident or coro_ident(coro) self.task_future = task_future or asyncio.get_event_loop().create_future() @@ -100,22 +116,33 @@ def __await__(self): """Wait for the task to be queued.""" return self.task_future.__await__() + def __repr__(self) -> str: + """Generate string representation for logging.""" + return f"<{self.__class__.__name__} ident={self.ident}>" + class TaskQueue: """A class for managing a set of asyncio tasks.""" - def __init__(self, max_active: int = 0): + def __init__( + self, max_active: int = 0, timed: bool = False, trace_fn: Callable = None + ): """ Initialize the task queue. Args: max_active: The maximum number of tasks to automatically run + timed: A flag indicating that timing should be collected for tasks + trace_fn: A callback for all completed tasks """ self.loop = asyncio.get_event_loop() self.active_tasks = [] self.pending_tasks = [] + self.timed = timed self.total_done = 0 self.total_failed = 0 + self.total_started = 0 + self._trace_fn = trace_fn self._cancelled = False self._drain_evt = asyncio.Event() self._drain_task: asyncio.Task = None @@ -155,6 +182,14 @@ def current_size(self) -> int: """Accessor for the total number of tasks in the queue.""" return len(self.active_tasks) + len(self.pending_tasks) + def __bool__(self) -> bool: + """ + Support for the bool() builtin. + + Otherwise, evaluates as false when there are no tasks. + """ + return True + def __len__(self) -> int: """Support for the len() builtin.""" return self.current_size @@ -186,7 +221,14 @@ async def _drain_loop(self): not self._max_active or len(self.active_tasks) < self._max_active ): pending: PendingTask = self.pending_tasks.pop(0) - timing = {"queued": pending.created_time} + if pending.queued_time: + pending.unqueued_time = time.perf_counter() + timing = { + "queued": pending.queued_time, + "unqueued": pending.unqueued_time, + } + else: + timing = None task = self.run( pending.coro, pending.complete_hook, pending.ident, timing ) @@ -206,6 +248,8 @@ def add_pending(self, pending: PendingTask): Args: pending: The `PendingTask` to add to the task queue """ + if self.timed and not pending.queued_time: + pending.queued_time = time.perf_counter() self.pending_tasks.append(pending) self.drain() @@ -229,6 +273,7 @@ def add_active( task.add_done_callback( lambda fut: self.completed_task(task, task_complete, ident, timing) ) + self.total_started += 1 return task def run( @@ -256,9 +301,10 @@ def run( raise ValueError(f"Expected coroutine, got {coro}") if not ident: ident = coro_ident(coro) - if not timing: - timing = dict() - timing["start_time"] = time.perf_counter() + if self.timed: + if not timing: + timing = dict() + coro = coro_timed(coro, timing) task = self.loop.create_task(coro) return self.add_active(task, task_complete, ident, timing) @@ -286,22 +332,31 @@ def put( return pending def completed_task( - self, task: asyncio.Task, task_complete: Callable, ident: str, timing: dict + self, + task: asyncio.Task, + task_complete: Callable, + ident: str, + timing: dict = None, ): """Clean up after a task has completed and run callbacks.""" - timing["end_time"] = time.perf_counter() exc_info = task_exc_info(task) if exc_info: self.total_failed += 1 - if not task_complete: - LOGGER.exception("Error running task", exc_info=exc_info) + if not task_complete and not self._trace_fn: + LOGGER.exception( + "Error running task %s", ident or "", exc_info=exc_info + ) else: self.total_done += 1 - if task_complete: + if task_complete or self._trace_fn: + completed = CompletedTask(task, exc_info, ident, timing) try: - task_complete(CompletedTask(task, exc_info, ident, timing)) + if task_complete: + task_complete(completed) + if self._trace_fn: + self._trace_fn(completed) except Exception: - LOGGER.exception("Error finalizing task") + LOGGER.exception("Error finalizing task %s", completed) try: self.active_tasks.remove(task) except ValueError: diff --git a/aries_cloudagent/messaging/tests/test_task_queue.py b/aries_cloudagent/messaging/tests/test_task_queue.py index 9fce1f315c..fc3fad77a3 100644 --- a/aries_cloudagent/messaging/tests/test_task_queue.py +++ b/aries_cloudagent/messaging/tests/test_task_queue.py @@ -140,7 +140,7 @@ def done(complete: CompletedTask): async def test_cancel_long(self): queue = TaskQueue() - task = queue.run(asyncio.sleep(5)) + task = queue.run(retval(1, delay=5)) queue.cancel() await queue @@ -152,7 +152,7 @@ async def test_cancel_long(self): async def test_complete_with_timeout(self): queue = TaskQueue() - task = queue.run(asyncio.sleep(5)) + task = queue.run(retval(1, delay=5)) await queue.complete(0.01) # cancellation may take a second @@ -175,3 +175,19 @@ def done(complete: CompletedTask): await task queue.completed_task(task, done, None, dict()) assert completed == [1, 1] + + async def test_timed(self): + completed = [] + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.append((complete.task.result(), complete.timing)) + + queue = TaskQueue(max_active=1, timed=True, trace_fn=done) + task1 = queue.run(retval(1)) + task2 = await queue.put(retval(2)) + await queue.complete(0.1) + + assert len(completed) == 2 + assert "queued" not in completed[0][1] + assert "queued" in completed[1][1] diff --git a/aries_cloudagent/stats.py b/aries_cloudagent/stats.py index a41e42f191..e57d208cf7 100644 --- a/aries_cloudagent/stats.py +++ b/aries_cloudagent/stats.py @@ -82,7 +82,7 @@ def stop(self): if self.start_time: dur = self.now() - self.start_time for grp in self.groups: - self.collector.log(grp, dur) + self.collector.log(grp, dur, self.start_time) self.start_time = None def __enter__(self): @@ -124,12 +124,13 @@ def enabled(self, val: bool): """Setter for the collector's enabled property.""" self._enabled = val - def log(self, name: str, duration: float): + def log(self, name: str, duration: float, start: float = None): """Log an entry in the statistics if the collector is enabled.""" if self._enabled: self._stats.log(name, duration) if self._log_file: - start = time.perf_counter() - duration + if start is None: + start = time.perf_counter() - duration self._log_file.write(f"{name} {start:.5f} {duration:.5f}\n") def mark(self, *names): diff --git a/aries_cloudagent/tests/test_dispatcher.py b/aries_cloudagent/tests/test_dispatcher.py index 1cce81907b..66c1b5e7df 100644 --- a/aries_cloudagent/tests/test_dispatcher.py +++ b/aries_cloudagent/tests/test_dispatcher.py @@ -68,6 +68,7 @@ async def test_dispatch(self): {StubAgentMessage.Meta.message_type: StubAgentMessage} ) dispatcher = test_module.Dispatcher(context) + await dispatcher.setup() rcv = Receiver() message = {"@type": StubAgentMessage.Meta.message_type} @@ -84,6 +85,7 @@ async def test_dispatch(self): async def test_bad_message_dispatch(self): dispatcher = test_module.Dispatcher(make_context()) + await dispatcher.setup() rcv = Receiver() bad_message = {"bad": "message"} await dispatcher.queue_message(make_inbound(bad_message), rcv.send)