Skip to content

Commit

Permalink
Merge pull request #3 from Indicio-tech/feature/sessions
Browse files Browse the repository at this point in the history
Implement opening a websocket session to connection
  • Loading branch information
mepeltier authored Nov 17, 2021
2 parents 008aa56 + a82900c commit 031794f
Show file tree
Hide file tree
Showing 8 changed files with 524 additions and 250 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ COPY --from=base /usr/src/app /usr/src/app
ENV PATH="/usr/src/app/.venv/bin:$PATH"

COPY ./echo_agent/ ./echo_agent/
ENTRYPOINT ["/bin/sh", "-c", "python -m \"$@\"", "--"]
CMD ["uvicorn", "echo_agent:app", "--host", "0.0.0.0", "--port", "80"]
ENTRYPOINT ["/bin/sh", "-c", "python -m uvicorn echo_agent:app \"$@\"", "--"]
CMD ["--host", "0.0.0.0", "--port", "80"]
109 changes: 86 additions & 23 deletions echo_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- send message
"""

import asyncio
import logging
from typing import Dict, List, Optional
from uuid import uuid4
Expand All @@ -25,20 +26,31 @@
crypto,
)
from fastapi import Body, FastAPI, HTTPException, Request
from .models import NewConnection, ConnectionInfo
from pydantic import dataclasses

from .session import Session, SessionMessage
from .models import (
NewConnection,
ConnectionInfo as ConnectionInfoDataclass,
SessionInfo,
)

# Logging
LOGGER = logging.getLogger("uvicorn.error." + __name__)

# Global state
connections: Dict[str, Connection] = {}
sessions: Dict[str, Session] = {}
recip_key_to_connection_id: Dict[str, str] = {}
messages: Dict[str, MsgQueue] = {}


app = FastAPI(title="Echo Agent", version="0.1.0")


ConnectionInfo = dataclasses.dataclass(ConnectionInfoDataclass)


@app.post("/connection", response_model=ConnectionInfo, operation_id="new_connection")
async def new_connection(new_connection: NewConnection):
"""Create a new static connection."""
Expand Down Expand Up @@ -130,16 +142,21 @@ async def new_message(request: Request):
response_model=List[Message],
operation_id="retrieve_messages",
)
async def get_messages(connection_id: str):
async def get_messages(connection_id: str, session_id: Optional[str] = None):
"""Retrieve all received messages for recipient key."""
if connection_id not in messages:
raise HTTPException(
status_code=404, detail=f"No connection id matching {connection_id}"
)

LOGGER.debug("Retrieving messages for connection_id %s", connection_id)
queue = messages[connection_id]
return await queue.flush()
if not session_id:
LOGGER.debug("Retrieving messages for connection_id %s", connection_id)
return queue.get_all()

return queue.get_all(
lambda msg: isinstance(msg, SessionMessage) and msg.session_id == session_id
)


@app.get(
Expand All @@ -150,25 +167,21 @@ async def get_message(
thid: Optional[str] = None,
msg_type: Optional[str] = None,
wait: Optional[bool] = True,
session_id: Optional[str] = None,
timeout: int = 5,
):
"""Wait for a message matching criteria."""

def _thid_match(msg: Message):
return msg.thread["thid"] == thid

def _msg_type_match(msg: Message):
return msg.type == msg_type

def _thid_and_msg_type_match(msg: Message):
return _thid_match(msg) and _msg_type_match(msg)

condition = None
if thid is not None:
condition = _thid_match
if msg_type is not None:
condition = _msg_type_match
if thid is not None and msg_type is not None:
condition = _thid_and_msg_type_match
def _condition(msg: Message):
return all(
[
msg.thread["thid"] == thid if thid else True,
msg.type == msg_type if msg_type else True,
msg.session_id == session_id
if isinstance(msg, SessionMessage) and session_id
else True,
]
)

if connection_id not in messages:
raise HTTPException(
Expand All @@ -177,9 +190,18 @@ def _thid_and_msg_type_match(msg: Message):

queue = messages[connection_id]
if wait:
message = await queue.get(condition=condition)
try:
message = await queue.get(condition=_condition, timeout=timeout)
except asyncio.TimeoutError:
raise HTTPException(
status_code=408,
detail=(
f"No message found for connection id {connection_id} "
"before timeout"
),
)
else:
message = queue.get_nowait(condition=condition)
message = queue.get_nowait(condition=_condition)

if not message:
raise HTTPException(
Expand All @@ -193,7 +215,7 @@ def _thid_and_msg_type_match(msg: Message):

@app.post("/message/{connection_id}", operation_id="send_message")
async def send_message(connection_id: str, message: dict = Body(...)):
"""Send a message to connection identified by did."""
"""Send a message to connection identified by connection ID."""
LOGGER.debug("Sending message to %s: %s", connection_id, message)
if connection_id not in connections:
raise HTTPException(
Expand All @@ -203,4 +225,45 @@ async def send_message(connection_id: str, message: dict = Body(...)):
await conn.send_async(message)


@app.get(
"/session/{connection_id}", operation_id="open_session", response_model=SessionInfo
)
async def open_session(connection_id: str, endpoint: Optional[str] = None):
"""Open a session."""
if connection_id not in connections:
raise HTTPException(
status_code=404, detail=f"No connection matching {connection_id} found"
)
conn = connections[connection_id]

session = Session(conn, endpoint)
sessions[session.id] = session
session.open()
return SessionInfo(session.id, connection_id)


@app.delete("/session/{session_id}")
async def close_session(session_id: str):
"""Close an open session."""
if session_id not in sessions:
raise HTTPException(
status_code=404, detail=f"No session matching {session_id} found"
)
session = sessions[session_id]
await session.close()
sessions.pop(session_id)
return session_id


@app.post("/message/session/{session_id}")
async def send_message_to_session(session_id: str, message: dict = Body(...)):
"""Send a message to a session identified by session ID."""
if session_id not in sessions:
raise HTTPException(
status_code=404, detail=f"No session matching {session_id} found"
)
session = sessions[session_id]
await session.send(message)


__all__ = ["app"]
85 changes: 77 additions & 8 deletions echo_agent/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Client to Echo Agent."""
from contextlib import AbstractAsyncContextManager
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import asdict
from typing import Any, List, Mapping, Optional, Union

from httpx import AsyncClient

from .models import ConnectionInfo, NewConnection
from .models import ConnectionInfo, NewConnection, SessionInfo


class EchoClientError(Exception):
Expand Down Expand Up @@ -37,7 +37,7 @@ async def __aexit__(self, exc_type, exc_value, traceback):
await self.client.__aexit__(exc_type, exc_value, traceback)

async def new_connection(
self, seed: Union[str, bytes], endpoint: str, their_vk: str
self, seed: str, endpoint: str, their_vk: str
) -> ConnectionInfo:
if not self.client:
raise NoOpenClient(
Expand Down Expand Up @@ -94,7 +94,9 @@ async def new_message(self, packed_message: bytes):
raise EchoClientError("Failed to receive message")

async def get_messages(
self, connection: Union[str, ConnectionInfo]
self,
connection: Union[str, ConnectionInfo],
session: Union[str, SessionInfo, None] = None,
) -> List[Mapping[str, Any]]:
if not self.client:
raise NoOpenClient(
Expand All @@ -104,7 +106,17 @@ async def get_messages(
connection_id = (
connection if isinstance(connection, str) else connection.connection_id
)
response = await self.client.get(f"/messages/{connection_id}")
session_id = (
session
if isinstance(session, str)
else session.session_id
if isinstance(session, SessionInfo)
else None
)
response = await self.client.get(
f"/messages/{connection_id}",
params={"session_id": session_id} if session_id else {},
)

if response.is_error:
raise EchoClientError(f"Failed to retrieve messages: {response.content}")
Expand All @@ -114,8 +126,10 @@ async def get_messages(
async def get_message(
self,
connection: Union[str, ConnectionInfo],
*,
thid: Optional[str] = None,
msg_type: Optional[str] = None,
session: Optional[Union[str, SessionInfo]] = None,
wait: Optional[bool] = True,
timeout: Optional[int] = 5,
) -> Mapping[str, Any]:
Expand All @@ -127,14 +141,26 @@ async def get_message(
connection_id = (
connection if isinstance(connection, str) else connection.connection_id
)
session_id = (
session
if isinstance(session, str)
else session.session_id
if isinstance(session, SessionInfo)
else None
)
response = await self.client.get(
f"/message/{connection_id}",
params={
k: v
for k, v in {"thid": thid, "msg_type": msg_type, "wait": wait}.items()
for k, v in {
"thid": thid,
"msg_type": msg_type,
"session_id": session_id,
"wait": wait,
"timeout": timeout,
}.items()
if v is not None
},
timeout=timeout,
)

if response.is_error:
Expand All @@ -158,4 +184,47 @@ async def send_message(
response = await self.client.post(f"/message/{connection_id}", json=message)

if response.is_error:
raise EchoClientError("Failed to send message")
raise EchoClientError(f"Failed to send message: {response.content}")

@asynccontextmanager
async def session(
self, connection: Union[str, ConnectionInfo], endpoint: Optional[str] = None
):
"""Open a session."""
if not self.client:
raise NoOpenClient(
"No client has been opened; use `async with echo_client`"
)

connection_id = (
connection if isinstance(connection, str) else connection.connection_id
)
session_info: Optional[SessionInfo] = None
try:
response = await self.client.get(
f"/session/{connection_id}",
params={"endpoint": endpoint} if endpoint else {},
)
if response.is_error:
raise EchoClientError(f"Failed to open session: {response.content}")
session_info = SessionInfo(**response.json())
yield session_info
finally:
if session_info:
await self.client.delete(f"/session/{session_info.session_id}")

async def send_message_to_session(
self, session: Union[str, SessionInfo], message: Mapping[str, Any]
):
if not self.client:
raise NoOpenClient(
"No client has been opened; use `async with echo_client`"
)

session_id = session if isinstance(session, str) else session.session_id
response = await self.client.post(
f"/message/session/{session_id}", json=message
)

if response.is_error:
raise EchoClientError(f"Failed to send message: {response.content}")
6 changes: 6 additions & 0 deletions echo_agent/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ class ConnectionInfo:
verkey: str
their_vk: str
endpoint: str


@dataclass
class SessionInfo:
session_id: str
connection_id: str
Loading

0 comments on commit 031794f

Please sign in to comment.