Skip to content

Commit

Permalink
feat(agents-api): Add YAML support
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 20, 2024
1 parent d92b054 commit 9962356
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 8 deletions.
35 changes: 35 additions & 0 deletions agents-api/agents_api/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import re

import yaml
from fastapi import Request


class YamlMiddleware:
def __init__(self, path_regex: str = r".*"):
self.path_regex = re.compile(path_regex)

async def __call__(self, request: Request, call_next):
content_type = request.headers.get("content-type", "").strip().lower()

# Filter out requests that are not for YAML and not for the specified path
if not self.path_regex.match(request.url.path) or content_type not in [
"application/x-yaml",
"application/yaml",
"text/yaml",
"text/x-yaml",
]:
return await call_next(request)

# Parse the YAML body into a Python object
body = yaml.load(await request.body(), yaml.CSafeLoader)
request._json = body

# Switch headers to JSON
headers = request.headers.mutablecopy()
headers["content-type"] = "application/json"

request._headers = headers

# Continue processing the request
response = await call_next(request)
return response
4 changes: 3 additions & 1 deletion agents-api/agents_api/routers/tasks/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ async def update_execution(
token_data = get_paused_execution_token(
developer_id=x_developer_id, execution_id=execution_id
)
act_handle = temporal_client.get_async_activity_handle(token_data["task_token"])
act_handle = temporal_client.get_async_activity_handle(
token_data["task_token"]
)
await act_handle.complete(data.input)

case _:
Expand Down
17 changes: 12 additions & 5 deletions agents-api/agents_api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
from fastapi import Depends, FastAPI, Request, status
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from litellm.exceptions import APIError
from pycozo.client import QueryException
from temporalio.service import RPCError

from agents_api.common.exceptions import BaseCommonException
from agents_api.dependencies.auth import get_api_key
from agents_api.env import sentry_dsn
from agents_api.exceptions import PromptTooBigError
from agents_api.routers import (
from .common.exceptions import BaseCommonException
from .dependencies.auth import get_api_key
from .env import sentry_dsn
from .exceptions import PromptTooBigError
from .middleware import YamlMiddleware
from .routers import (
agents,
docs,
jobs,
Expand Down Expand Up @@ -89,6 +91,11 @@ def register_exceptions(app: FastAPI) -> None:
max_age=3600,
)

app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=3)

# Add yaml middleware
app.middleware("http")(YamlMiddleware(path_regex=r"/agents/.+/tasks.*"))

register_exceptions(app)

app.include_router(agents.router)
Expand Down
2 changes: 1 addition & 1 deletion agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pydantic-partial = "^0.5.5"
simpleeval = "^0.9.13"
lz4 = "^4.3.3"

pyyaml = "^6.0.2"
[tool.poetry.group.dev.dependencies]
ipython = "^8.26.0"
ruff = "^0.5.5"
Expand Down
2 changes: 1 addition & 1 deletion agents-api/tests/test_agent_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _(client=client):
response = client.request(
method="POST",
url="/agents",
data=data,
json=data,
)

assert response.status_code == 403
Expand Down
88 changes: 88 additions & 0 deletions agents-api/tests/test_workflow_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Tests for task queries

from uuid import uuid4

from ward import test

from .fixtures import cozo_client, test_agent, test_developer_id
from .utils import patch_http_client_with_temporal


@test("workflow route: evaluate step single")
async def _(
cozo_client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
agent_id = str(agent.id)
task_id = str(uuid4())

async with patch_http_client_with_temporal(
cozo_client=cozo_client, developer_id=developer_id
) as (
make_request,
client,
):
task_data = {
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [{"evaluate": {"hello": '"world"'}}],
}

make_request(
method="POST",
url=f"/agents/{agent_id}/tasks/{task_id}",
json=task_data,
)

execution_data = dict(input={"test": "input"})

make_request(
method="POST",
url=f"/tasks/{task_id}/executions",
json=execution_data,
)


@test("workflow route: evaluate step single with yaml")
async def _(
cozo_client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
agent_id = str(agent.id)
task_id = str(uuid4())

async with patch_http_client_with_temporal(
cozo_client=cozo_client, developer_id=developer_id
) as (
make_request,
client,
):
task_data = """
name: test task
description: test task about
input_schema:
type: object
additionalProperties: true
main:
- evaluate:
hello: '"world"'
"""

make_request(
method="POST",
url=f"/agents/{agent_id}/tasks/{task_id}",
content=task_data,
headers={"Content-Type": "text/yaml"},
)

execution_data = dict(input={"test": "input"})

make_request(
method="POST",
url=f"/tasks/{task_id}/executions",
json=execution_data,
)
23 changes: 23 additions & 0 deletions agents-api/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import asynccontextmanager
from unittest.mock import patch

from fastapi.testclient import TestClient
from temporalio.testing import WorkflowEnvironment

from agents_api.worker.codec import pydantic_data_converter
Expand Down Expand Up @@ -41,3 +42,25 @@ async def patch_testing_temporal():

# Reset log level
logger.setLevel(previous_log_level)


@asynccontextmanager
async def patch_http_client_with_temporal(*, cozo_client, developer_id):
async with patch_testing_temporal():
from agents_api.env import api_key, api_key_header_name
from agents_api.web import app

client = TestClient(app=app)
app.state.cozo_client = cozo_client

def make_request(method, url, **kwargs):
headers = kwargs.pop("headers", {})
headers = {
**headers,
"X-Developer-Id": str(developer_id),
api_key_header_name: api_key,
}

return client.request(method, url, headers=headers, **kwargs)

yield make_request, client

0 comments on commit 9962356

Please sign in to comment.