diff --git a/echo_agent/app.py b/echo_agent/app.py index b05d099..10d5cce 100644 --- a/echo_agent/app.py +++ b/echo_agent/app.py @@ -28,11 +28,14 @@ from fastapi import Body, FastAPI, HTTPException, Request from pydantic import dataclasses +from .webhook_queue import Queue + from .session import Session, SessionMessage from .models import ( NewConnection, ConnectionInfo as ConnectionInfoDataclass, SessionInfo, + Webhook, ) # Logging @@ -43,11 +46,17 @@ sessions: Dict[str, Session] = {} recip_key_to_connection_id: Dict[str, str] = {} messages: Dict[str, MsgQueue] = {} +webhooks: Queue[Webhook] = Queue() app = FastAPI(title="Echo Agent", version="0.1.0") +@app.on_event("startup") +async def setup_webhook_queue(): + await webhooks.setup() + + ConnectionInfo = dataclasses.dataclass(ConnectionInfoDataclass) @@ -266,4 +275,57 @@ async def send_message_to_session(session_id: str, message: dict = Body(...)): await session.send(message) +@app.post("/webhook/{topic:path}", response_model=Webhook) +async def receive_webhook(topic: str, payload: dict = Body(...)): + """Receive a webhook.""" + LOGGER.debug("Received webhook: topic %s, payload %s", topic, payload) + await webhooks.put(Webhook(topic, payload)) + + +@app.get( + "/webhooks", + response_model=List[Webhook], + operation_id="", +) +async def get_webhooks(topic: Optional[str] = None): + """Retrieve all received messages for recipient key.""" + if not topic: + LOGGER.debug("Retrieving webhooks") + return webhooks.get_all() + + return webhooks.get_all(lambda entry: entry.topic == topic) + + +@app.get("/webhook", response_model=Webhook, operation_id="wait_for_webhook") +async def get_webhook( + topic: Optional[str] = None, + wait: Optional[bool] = True, + timeout: int = 5, +): + """Wait for a message matching criteria.""" + + def _condition(entry: Webhook): + return entry.topic == topic if topic else True + + if wait: + try: + webhook = await webhooks.get(condition=_condition, timeout=timeout) + except asyncio.TimeoutError: + raise HTTPException( + status_code=408, + detail=("No webhook found before timeout"), + ) + else: + webhook = webhooks.get_nowait(condition=_condition) + + if not webhook: + raise HTTPException( + status_code=404, + detail="No webhook found", + ) + + LOGGER.debug("Received webhook, returning to waiting client: %s", webhook) + return webhook + + __all__ = ["app"] diff --git a/echo_agent/client.py b/echo_agent/client.py index 2785508..4eb615b 100644 --- a/echo_agent/client.py +++ b/echo_agent/client.py @@ -1,11 +1,11 @@ """Client to Echo Agent.""" from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import asdict -from typing import Any, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union from httpx import AsyncClient -from .models import ConnectionInfo, NewConnection, SessionInfo +from .models import ConnectionInfo, NewConnection, SessionInfo, Webhook class EchoClientError(Exception): @@ -228,3 +228,64 @@ async def send_message_to_session( if response.is_error: raise EchoClientError(f"Failed to send message: {response.content}") + + async def new_webhook(self, topic: str, payload: Dict[str, Any]): + if not self.client: + raise NoOpenClient( + "No client has been opened; use `async with echo_client`" + ) + + response = await self.client.post(f"/webhook/{topic}", json=payload) + + if response.is_error: + raise EchoClientError("Failed to receive webhook") + + async def get_webhooks( + self, + *, + topic: Optional[str] = None, + ) -> List[Webhook]: + if not self.client: + raise NoOpenClient( + "No client has been opened; use `async with echo_client`" + ) + + response = await self.client.get( + "/webhooks", + params={"topic": topic} if topic else {}, + ) + + if response.is_error: + raise EchoClientError(f"Failed to retrieve webhooks: {response.content}") + + return response.json() + + async def get_webhook( + self, + *, + topic: Optional[str] = None, + wait: Optional[bool] = True, + timeout: Optional[int] = 5, + ) -> Mapping[str, Any]: + if not self.client: + raise NoOpenClient( + "No client has been opened; use `async with echo_client`" + ) + + response = await self.client.get( + "/webhook", + params={ + k: v + for k, v in { + "topic": topic, + "wait": wait, + "timeout": timeout, + }.items() + if v is not None + }, + ) + + if response.is_error: + raise EchoClientError(f"Failed to wait for webhook: {response.content}") + + return response.json() diff --git a/echo_agent/models.py b/echo_agent/models.py index 8672d88..aae2e2e 100644 --- a/echo_agent/models.py +++ b/echo_agent/models.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any, Dict @dataclass @@ -21,3 +22,9 @@ class ConnectionInfo: class SessionInfo: session_id: str connection_id: str + + +@dataclass +class Webhook: + topic: str + payload: Dict[str, Any] diff --git a/echo_agent/webhook_queue.py b/echo_agent/webhook_queue.py new file mode 100644 index 0000000..5cff3ad --- /dev/null +++ b/echo_agent/webhook_queue.py @@ -0,0 +1,112 @@ +import asyncio +from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar + + +QueueEntry = TypeVar("QueueEntry") + + +class Queue(Generic[QueueEntry]): + def __init__( + self, + *, + condition: Optional[Callable[[QueueEntry], bool]] = None, + ): + self._queue: List[Any] = [] + self._cond: Optional[asyncio.Condition] = None + self.condition = condition + + async def setup(self): + self._cond = asyncio.Condition() + + def _first_matching_index(self, condition: Callable[[QueueEntry], bool]): + for index, entry in enumerate(self._queue): + if condition(entry): + return index + return None + + async def _get( + self, condition: Optional[Callable[[QueueEntry], bool]] = None + ) -> QueueEntry: + """Retrieve a message from the queue.""" + while True: + async with self._cond: + # Lock acquired + if not self._queue: + # No items on queue yet so we need to wait for items to show up + await self._cond.wait() + + if not self._queue: + # Another task grabbed the value before we got to it + continue + + if not condition: + # Just get the first message + return self._queue.pop() + + # Return first matching item, if present + match_idx = self._first_matching_index(condition) + if match_idx is not None: + return self._queue.pop(match_idx) + + async def get( + self, + condition: Optional[Callable[[QueueEntry], bool]] = None, + *, + timeout: int = 5, + ) -> QueueEntry: + """Retrieve a message from the queue.""" + return await asyncio.wait_for(self._get(condition), timeout) + + def get_all( + self, condition: Optional[Callable[[QueueEntry], bool]] = None + ) -> Sequence[QueueEntry]: + """Return all messages matching a given condition.""" + messages = [] + if not self._queue: + return messages + + if not condition: + messages = [entry for entry in self._queue] + self._queue.clear() + return messages + + # Store messages that didn't match in the order they are seen + filtered: List[QueueEntry] = [] + for entry in self._queue: + if condition(entry): + messages.append(entry) + else: + filtered.append(entry) + + # Queue contents set to messages that didn't match condition + self._queue[:] = filtered + return messages + + def get_nowait( + self, condition: Optional[Callable[[QueueEntry], bool]] = None + ) -> Optional[QueueEntry]: + """Return a message from the queue without waiting.""" + if not self._queue: + return None + + if not condition: + return self._queue.pop() + + match_idx = self._first_matching_index(condition) + if match_idx is not None: + return self._queue.pop(match_idx) + + return None + + async def put(self, value: QueueEntry): + """Push a message onto the queue and notify waiting tasks.""" + if not self.condition or self.condition(value): + async with self._cond: + self._queue.append(value) + self._cond.notify_all() + + def flush(self) -> Sequence[QueueEntry]: + """Clear queue and return final contents of queue at time of clear.""" + final = self._queue.copy() + self._queue.clear() + return final diff --git a/tests/conftest.py b/tests/conftest.py index c7e089b..986dc46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,11 @@ import pytest -from echo_agent import app, EchoClient +from echo_agent import EchoClient +from echo_agent.app import webhooks @pytest.fixture -def echo_client(): +async def echo_client(): + from echo_agent import app + + await webhooks.setup() yield EchoClient(base_url="http://test", app=app) diff --git a/tests/test_client.py b/tests/test_client.py index 8a86649..3963ebd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,7 +6,7 @@ from aries_staticagent.message import Message import pytest -from echo_agent.app import connections, messages, recip_key_to_connection_id +from echo_agent.app import connections, messages, recip_key_to_connection_id, webhooks from echo_agent.client import EchoClient, NoOpenClient from echo_agent.models import ConnectionInfo from echo_agent.session import SessionMessage @@ -199,3 +199,90 @@ async def test_get_message_no_wait( await echo_client.new_message(recip.pack(msg)) message = await echo_client.get_message(connection_id, wait=False) assert message + + +@pytest.mark.asyncio +async def test_receive_webhook( + echo_client: EchoClient, recip: Connection, conn: Connection, connection_id: str +): + """Test reception of a webhook.""" + async with echo_client: + await echo_client.new_webhook("test", {"test": "test"}) + assert webhooks._queue + + +@pytest.mark.asyncio +async def test_get_webhooks( + echo_client: EchoClient, recip: Connection, conn: Connection, connection_id: str +): + """Test reception of a webhook.""" + async with echo_client: + await echo_client.new_webhook("test", {"test": "test"}) + webhooks = await echo_client.get_webhooks() + assert webhooks + + +@pytest.mark.asyncio +async def test_get_webhooks_condition( + echo_client: EchoClient, recip: Connection, conn: Connection, connection_id: str +): + """Test reception of a webhook.""" + async with echo_client: + await echo_client.new_webhook("test", {"test": "test"}) + webhooks = await echo_client.get_webhooks(topic="test") + assert webhooks + + +@pytest.mark.asyncio +async def test_get_webhook_post( + echo_client: EchoClient, recip: Connection, conn: Connection, connection_id: str +): + """Test reception of a webhook.""" + async with echo_client: + await echo_client.new_webhook("test", {"test": "test"}) + webhook = await echo_client.get_webhook() + assert webhook + + +@pytest.mark.asyncio +async def test_get_webhook_post_condition( + echo_client: EchoClient, recip: Connection, conn: Connection, connection_id: str +): + """Test reception of a webhook.""" + async with echo_client: + await echo_client.new_webhook("test", {"test": "test"}) + webhook = await echo_client.get_webhook(topic="test") + assert webhook + + +@pytest.mark.asyncio +async def test_get_webhook_pre( + echo_client: EchoClient, recip: Connection, conn: Connection, connection_id: str +): + """Test reception of a webhook.""" + + async def _produce(echo_client): + await asyncio.sleep(0.5) + await echo_client.new_webhook("test", {"test": "test"}) + + async def _consume(echo_client): + return await echo_client.get_webhook(topic="test") + + async with echo_client: + loop = asyncio.get_event_loop() + _, webhook = await asyncio.gather( + loop.create_task(_produce(echo_client)), + loop.create_task(_consume(echo_client)), + ) + assert webhook + + +@pytest.mark.asyncio +async def test_get_webhook_no_wait( + echo_client: EchoClient, recip: Connection, conn: Connection, connection_id: str +): + """Test reception of a webhook.""" + async with echo_client: + await echo_client.new_webhook("test", {"test": "test"}) + webhook = await echo_client.get_webhook(topic="test", wait=False) + assert webhook