Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add webhook capture #5

Merged
merged 5 commits into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions echo_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
from fastapi import Body, FastAPI, HTTPException, Request
from pydantic import dataclasses

from .webhook_queue import Queue

from .session import Session, SessionMessage
from .models import (
NewConnection,
ConnectionInfo as ConnectionInfoDataclass,
SessionInfo,
Webhook,
)

# Logging
Expand All @@ -43,11 +46,17 @@
sessions: Dict[str, Session] = {}
recip_key_to_connection_id: Dict[str, str] = {}
messages: Dict[str, MsgQueue] = {}
webhooks: Queue[Webhook] = Queue()


app = FastAPI(title="Echo Agent", version="0.1.0")


@app.on_event("startup")
async def setup_webhook_queue():
await webhooks.setup()


ConnectionInfo = dataclasses.dataclass(ConnectionInfoDataclass)


Expand Down Expand Up @@ -266,4 +275,57 @@ async def send_message_to_session(session_id: str, message: dict = Body(...)):
await session.send(message)


@app.post("/webhook/{topic:path}", response_model=Webhook)
async def receive_webhook(topic: str, payload: dict = Body(...)):
"""Receive a webhook."""
LOGGER.debug("Received webhook: topic %s, payload %s", topic, payload)
await webhooks.put(Webhook(topic, payload))


@app.get(
"/webhooks",
response_model=List[Webhook],
operation_id="",
)
async def get_webhooks(topic: Optional[str] = None):
"""Retrieve all received messages for recipient key."""
if not topic:
LOGGER.debug("Retrieving webhooks")
return webhooks.get_all()

return webhooks.get_all(lambda entry: entry.topic == topic)


@app.get("/webhook", response_model=Webhook, operation_id="wait_for_webhook")
async def get_webhook(
topic: Optional[str] = None,
wait: Optional[bool] = True,
timeout: int = 5,
):
"""Wait for a message matching criteria."""

def _condition(entry: Webhook):
return entry.topic == topic if topic else True

if wait:
try:
webhook = await webhooks.get(condition=_condition, timeout=timeout)
except asyncio.TimeoutError:
raise HTTPException(
status_code=408,
detail=("No webhook found before timeout"),
)
else:
webhook = webhooks.get_nowait(condition=_condition)

if not webhook:
raise HTTPException(
status_code=404,
detail="No webhook found",
)

LOGGER.debug("Received webhook, returning to waiting client: %s", webhook)
return webhook


__all__ = ["app"]
65 changes: 63 additions & 2 deletions echo_agent/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Client to Echo Agent."""
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import asdict
from typing import Any, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

from httpx import AsyncClient

from .models import ConnectionInfo, NewConnection, SessionInfo
from .models import ConnectionInfo, NewConnection, SessionInfo, Webhook


class EchoClientError(Exception):
Expand Down Expand Up @@ -228,3 +228,64 @@ async def send_message_to_session(

if response.is_error:
raise EchoClientError(f"Failed to send message: {response.content}")

async def new_webhook(self, topic: str, payload: Dict[str, Any]):
if not self.client:
raise NoOpenClient(
"No client has been opened; use `async with echo_client`"
)

response = await self.client.post(f"/webhook/{topic}", json=payload)

if response.is_error:
raise EchoClientError("Failed to receive webhook")

async def get_webhooks(
self,
*,
topic: Optional[str] = None,
) -> List[Webhook]:
if not self.client:
raise NoOpenClient(
"No client has been opened; use `async with echo_client`"
)

response = await self.client.get(
"/webhooks",
params={"topic": topic} if topic else {},
)

if response.is_error:
raise EchoClientError(f"Failed to retrieve webhooks: {response.content}")

return response.json()

async def get_webhook(
self,
*,
topic: Optional[str] = None,
wait: Optional[bool] = True,
timeout: Optional[int] = 5,
) -> Mapping[str, Any]:
if not self.client:
raise NoOpenClient(
"No client has been opened; use `async with echo_client`"
)

response = await self.client.get(
"/webhook",
params={
k: v
for k, v in {
"topic": topic,
"wait": wait,
"timeout": timeout,
}.items()
if v is not None
},
)

if response.is_error:
raise EchoClientError(f"Failed to wait for webhook: {response.content}")

return response.json()
7 changes: 7 additions & 0 deletions echo_agent/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
Expand All @@ -21,3 +22,9 @@ class ConnectionInfo:
class SessionInfo:
session_id: str
connection_id: str


@dataclass
class Webhook:
topic: str
payload: Dict[str, Any]
112 changes: 112 additions & 0 deletions echo_agent/webhook_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import asyncio
from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar


QueueEntry = TypeVar("QueueEntry")


class Queue(Generic[QueueEntry]):
def __init__(
self,
*,
condition: Optional[Callable[[QueueEntry], bool]] = None,
):
self._queue: List[Any] = []
self._cond: Optional[asyncio.Condition] = None
self.condition = condition

async def setup(self):
self._cond = asyncio.Condition()

def _first_matching_index(self, condition: Callable[[QueueEntry], bool]):
for index, entry in enumerate(self._queue):
if condition(entry):
return index
return None

async def _get(
self, condition: Optional[Callable[[QueueEntry], bool]] = None
) -> QueueEntry:
"""Retrieve a message from the queue."""
while True:
async with self._cond:
# Lock acquired
if not self._queue:
# No items on queue yet so we need to wait for items to show up
await self._cond.wait()

if not self._queue:
# Another task grabbed the value before we got to it
continue

if not condition:
# Just get the first message
return self._queue.pop()

# Return first matching item, if present
match_idx = self._first_matching_index(condition)
if match_idx is not None:
return self._queue.pop(match_idx)

async def get(
self,
condition: Optional[Callable[[QueueEntry], bool]] = None,
*,
timeout: int = 5,
) -> QueueEntry:
"""Retrieve a message from the queue."""
return await asyncio.wait_for(self._get(condition), timeout)

def get_all(
self, condition: Optional[Callable[[QueueEntry], bool]] = None
) -> Sequence[QueueEntry]:
"""Return all messages matching a given condition."""
messages = []
if not self._queue:
return messages

if not condition:
messages = [entry for entry in self._queue]
self._queue.clear()
return messages

# Store messages that didn't match in the order they are seen
filtered: List[QueueEntry] = []
for entry in self._queue:
if condition(entry):
messages.append(entry)
else:
filtered.append(entry)

# Queue contents set to messages that didn't match condition
self._queue[:] = filtered
return messages

def get_nowait(
self, condition: Optional[Callable[[QueueEntry], bool]] = None
) -> Optional[QueueEntry]:
"""Return a message from the queue without waiting."""
if not self._queue:
return None

if not condition:
return self._queue.pop()

match_idx = self._first_matching_index(condition)
if match_idx is not None:
return self._queue.pop(match_idx)

return None

async def put(self, value: QueueEntry):
"""Push a message onto the queue and notify waiting tasks."""
if not self.condition or self.condition(value):
async with self._cond:
self._queue.append(value)
self._cond.notify_all()

def flush(self) -> Sequence[QueueEntry]:
"""Clear queue and return final contents of queue at time of clear."""
final = self._queue.copy()
self._queue.clear()
return final
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pytest
from echo_agent import app, EchoClient
from echo_agent import EchoClient
from echo_agent.app import webhooks


@pytest.fixture
def echo_client():
async def echo_client():
from echo_agent import app

await webhooks.setup()
yield EchoClient(base_url="http://test", app=app)
Loading