Skip to content

Commit

Permalink
Merge pull request #269 from andrewwhitehead/refactor/transport
Browse files Browse the repository at this point in the history
Message transport refactoring
  • Loading branch information
swcurran authored Nov 29, 2019
2 parents 456280e + ef49f83 commit b333150
Show file tree
Hide file tree
Showing 80 changed files with 3,320 additions and 1,918 deletions.
178 changes: 82 additions & 96 deletions aries_cloudagent/admin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@

import asyncio
import logging
from typing import Coroutine, Sequence, Set
from typing import Callable, Coroutine, Sequence, Set
import uuid

from aiohttp import web, ClientSession, DummyCookieJar
from aiohttp import web
from aiohttp_apispec import docs, response_schema, setup_aiohttp_apispec
import aiohttp_cors

from marshmallow import fields, Schema

from ..config.injection_context import InjectionContext
from ..messaging.outbound_message import OutboundMessage
from ..messaging.plugin_registry import PluginRegistry
from ..messaging.responder import BaseResponder
from ..messaging.task_queue import TaskQueue
from ..stats import Collector
from ..task_processor import TaskProcessor
from ..transport.outbound.queue.base import BaseOutboundMessageQueue
from ..transport.stats import StatsTracer
from ..transport.queue.basic import BasicMessageQueue
from ..transport.outbound.message import OutboundMessage

from .base_server import BaseAdminServer
from .error import AdminSetupError
Expand All @@ -43,7 +42,9 @@ class AdminStatusSchema(Schema):
class AdminResponder(BaseResponder):
"""Handle outgoing messages from message handlers."""

def __init__(self, send: Coroutine, webhook: Coroutine, **kwargs):
def __init__(
self, context: InjectionContext, send: Coroutine, webhook: Coroutine, **kwargs
):
"""
Initialize an instance of `AdminResponder`.
Expand All @@ -52,6 +53,7 @@ def __init__(self, send: Coroutine, webhook: Coroutine, **kwargs):
"""
super().__init__(**kwargs)
self._context = context
self._send = send
self._webhook = webhook

Expand All @@ -62,7 +64,7 @@ async def send_outbound(self, message: OutboundMessage):
Args:
message: The `OutboundMessage` to be sent
"""
await self._send(message)
await self._send(self._context, message)

async def send_webhook(self, topic: str, payload: dict):
"""
Expand Down Expand Up @@ -111,30 +113,36 @@ def __init__(
port: int,
context: InjectionContext,
outbound_message_router: Coroutine,
webhook_router: Callable,
task_queue: TaskQueue = None,
conductor_stats: Coroutine = None,
):
"""
Initialize an AdminServer instance.
Args:
host: Host to listen on
port: Port to listen on
context: The application context instance
outbound_message_router: Coroutine for delivering outbound messages
webhook_router: Callable for delivering webhooks
task_queue: An optional task queue for handlers
"""
self.app = None
self.host = host
self.port = port
self.conductor_stats = conductor_stats
self.loaded_modules = []
self.webhook_queue = None
self.webhook_retries = 5
self.webhook_session: ClientSession = None
self.task_queue = task_queue
self.webhook_router = webhook_router
self.webhook_targets = {}
self.webhook_task = None
self.webhook_processor: TaskProcessor = None
self.websocket_queues = {}
self.site = None

self.context = context.start_scope("admin")
self.responder = AdminResponder(outbound_message_router, self.send_webhook)
self.responder = AdminResponder(
self.context, outbound_message_router, self.send_webhook
)
self.context.injector.bind_instance(BaseResponder, self.responder)

async def make_application(self) -> web.Application:
Expand Down Expand Up @@ -168,6 +176,15 @@ async def check_token(request, handler):

middlewares.append(check_token)

if self.task_queue:

@web.middleware
async def apply_limiter(request, handler):
task = await self.task_queue.put(handler(request))
return await task

middlewares.append(apply_limiter)

stats: Collector = await self.context.inject(Collector, required=False)
if stats:

Expand All @@ -187,14 +204,16 @@ async def collect_stats(request, handler):
app.add_routes(
[
web.get("/", self.redirect_handler),
web.get("/modules", self.modules_handler),
web.get("/plugins", self.plugins_handler),
web.get("/status", self.status_handler),
web.post("/status/reset", self.status_reset_handler),
web.get("/ws", self.websocket_handler),
]
)

plugin_registry = await self.context.inject(PluginRegistry, required=False)
plugin_registry: PluginRegistry = await self.context.inject(
PluginRegistry, required=False
)
if plugin_registry:
await plugin_registry.register_admin_routes(app)

Expand Down Expand Up @@ -249,21 +268,15 @@ async def stop(self) -> None:
if self.site:
await self.site.stop()
self.site = None
if self.webhook_queue:
self.webhook_queue.stop()
self.webhook_queue = None
if self.webhook_session:
await self.webhook_session.close()
self.webhook_session = None

async def on_startup(self, app: web.Application):
"""Perform webserver startup actions."""

@docs(tags=["server"], summary="Fetch the list of loaded modules")
@docs(tags=["server"], summary="Fetch the list of loaded plugins")
@response_schema(AdminModulesSchema(), 200)
async def modules_handler(self, request: web.BaseRequest):
async def plugins_handler(self, request: web.BaseRequest):
"""
Request handler for the loaded modules list.
Request handler for the loaded plugins list.
Args:
request: aiohttp request object
Expand All @@ -272,7 +285,12 @@ async def modules_handler(self, request: web.BaseRequest):
The module list response
"""
return web.json_response({"result": self.loaded_modules})
registry: PluginRegistry = await self.context.inject(
PluginRegistry, required=False
)
print(registry)
plugins = registry and sorted(registry.plugin_names) or []
return web.json_response({"result": plugins})

@docs(tags=["server"], summary="Fetch the server status")
@response_schema(AdminStatusSchema(), 200)
Expand All @@ -291,6 +309,8 @@ async def status_handler(self, request: web.BaseRequest):
collector: Collector = await self.context.inject(Collector, required=False)
if collector:
status["timing"] = collector.results
if self.conductor_stats:
status["conductor"] = await self.conductor_stats()
return web.json_response(status)

@docs(tags=["server"], summary="Reset statistics")
Expand Down Expand Up @@ -321,7 +341,8 @@ async def websocket_handler(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
socket_id = str(uuid.uuid4())
queue = await self.context.inject(BaseOutboundMessageQueue)
queue = BasicMessageQueue()
loop = asyncio.get_event_loop()

try:
self.websocket_queues[socket_id] = queue
Expand All @@ -340,20 +361,40 @@ async def websocket_handler(self, request):
)

closed = False
receive = loop.create_task(ws.receive())
send = loop.create_task(queue.dequeue(timeout=5.0))

while not closed:
try:
msg = await queue.dequeue(timeout=5.0)
if msg is None:
# we send fake pings because the JS client
# can't detect real ones
msg = {"topic": "ping"}
await asyncio.wait(
(receive, send), return_when=asyncio.FIRST_COMPLETED
)
if ws.closed:
closed = True
if msg and not closed:
await ws.send_json(msg)

if receive.done():
# ignored
if not closed:
receive = loop.create_task(ws.receive())

if send.done():
msg = send.result()
if msg is None:
# we send fake pings because the JS client
# can't detect real ones
msg = {"topic": "ping"}
if not closed:
if msg:
await ws.send_json(msg)
send = loop.create_task(queue.dequeue(timeout=5.0))
except asyncio.CancelledError:
closed = True

if not receive.done():
receive.cancel()
if not send.done():
send.cancel()

finally:
del self.websocket_queues[socket_id]

Expand All @@ -374,65 +415,10 @@ def remove_webhook_target(self, target_url: str):

async def send_webhook(self, topic: str, payload: dict):
"""Add a webhook to the queue, to send to all registered targets."""
if not self.webhook_queue:
self.webhook_queue = await self.context.inject(BaseOutboundMessageQueue)
self.webhook_task = asyncio.get_event_loop().create_task(
self._process_webhooks()
)
await self.webhook_queue.enqueue((topic, payload))

async def _process_webhooks(self):
"""Continuously poll webhook queue and dispatch to targets."""
session_args = {}
collector: Collector = await self.context.inject(Collector, required=False)
if collector:
session_args["trace_configs"] = [StatsTracer(collector, "webhook-http:")]
session_args["cookie_jar"] = DummyCookieJar()
self.webhook_session = ClientSession(**session_args)
self.webhook_processor = TaskProcessor(max_pending=20)
async for topic, payload in self.webhook_queue:
for queue in self.websocket_queues.values():
await queue.enqueue({"topic": topic, "payload": payload})
if self.webhook_targets:
targets = self.webhook_targets.copy()
for idx, target in targets.items():
if not target.topic_filter or topic in target.topic_filter:
retries = (
self.webhook_retries
if target.retries is None
else target.retries
)
await self.webhook_processor.run_retry(
lambda pending: self._perform_send_webhook(
target.endpoint, topic, payload, pending.attempts + 1
),
ident=(target.endpoint, topic),
retries=retries,
)
self.webhook_queue.task_done()

async def _perform_send_webhook(
self, target_url: str, topic: str, payload: dict, attempt: int = None
):
"""Dispatch a webhook to a specific endpoint."""
full_webhook_url = f"{target_url}/topic/{topic}/"
attempt_str = f" (attempt {attempt})" if attempt else ""
LOGGER.debug("Sending webhook to : %s%s", full_webhook_url, attempt_str)
async with self.webhook_session.post(
full_webhook_url, json=payload
) as response:
if response.status < 200 or response.status > 299:
# raise Exception(f"Unexpected response status {response.status}")
raise Exception(
f"Unexpected: target {target_url}\n"
f"full {full_webhook_url}\n"
f"response {response}"
)
if self.webhook_router:
for idx, target in self.webhook_targets.items():
if not target.topic_filter or topic in target.topic_filter:
self.webhook_router(topic, payload, target.endpoint, target.retries)

async def complete_webhooks(self):
"""Wait for all pending webhooks to be dispatched, used in testing."""
if self.webhook_queue:
await self.webhook_queue.join()
self.webhook_queue.reset()
if self.webhook_processor:
await self.webhook_processor.wait_done()
for queue in self.websocket_queues.values():
await queue.enqueue({"topic": topic, "payload": payload})
Loading

0 comments on commit b333150

Please sign in to comment.