diff --git a/agents-api/agents_api/middleware.py b/agents-api/agents_api/middleware.py new file mode 100644 index 000000000..a76b29fa6 --- /dev/null +++ b/agents-api/agents_api/middleware.py @@ -0,0 +1,35 @@ +import re + +import yaml +from fastapi import Request + + +class YamlMiddleware: + def __init__(self, path_regex: str = r".*"): + self.path_regex = re.compile(path_regex) + + async def __call__(self, request: Request, call_next): + content_type = request.headers.get("content-type", "").strip().lower() + + # Filter out requests that are not for YAML and not for the specified path + if not self.path_regex.match(request.url.path) or content_type not in [ + "application/x-yaml", + "application/yaml", + "text/yaml", + "text/x-yaml", + ]: + return await call_next(request) + + # Parse the YAML body into a Python object + body = yaml.load(await request.body(), yaml.CSafeLoader) + request._json = body + + # Switch headers to JSON + headers = request.headers.mutablecopy() + headers["content-type"] = "application/json" + + request._headers = headers + + # Continue processing the request + response = await call_next(request) + return response diff --git a/agents-api/agents_api/routers/tasks/update_execution.py b/agents-api/agents_api/routers/tasks/update_execution.py index 697846f6c..779a7121b 100644 --- a/agents-api/agents_api/routers/tasks/update_execution.py +++ b/agents-api/agents_api/routers/tasks/update_execution.py @@ -38,7 +38,9 @@ async def update_execution( token_data = get_paused_execution_token( developer_id=x_developer_id, execution_id=execution_id ) - act_handle = temporal_client.get_async_activity_handle(token_data["task_token"]) + act_handle = temporal_client.get_async_activity_handle( + token_data["task_token"] + ) await act_handle.complete(data.input) case _: diff --git a/agents-api/agents_api/web.py b/agents-api/agents_api/web.py index 2a24dcebb..e20803e92 100644 --- a/agents-api/agents_api/web.py +++ b/agents-api/agents_api/web.py @@ -11,16 +11,18 @@ from fastapi import Depends, FastAPI, Request, status from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import JSONResponse from litellm.exceptions import APIError from pycozo.client import QueryException from temporalio.service import RPCError -from agents_api.common.exceptions import BaseCommonException -from agents_api.dependencies.auth import get_api_key -from agents_api.env import sentry_dsn -from agents_api.exceptions import PromptTooBigError -from agents_api.routers import ( +from .common.exceptions import BaseCommonException +from .dependencies.auth import get_api_key +from .env import sentry_dsn +from .exceptions import PromptTooBigError +from .middleware import YamlMiddleware +from .routers import ( agents, docs, jobs, @@ -89,6 +91,11 @@ def register_exceptions(app: FastAPI) -> None: max_age=3600, ) +app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=3) + +# Add yaml middleware +app.middleware("http")(YamlMiddleware(path_regex=r"/agents/.+/tasks.*")) + register_exceptions(app) app.include_router(agents.router) diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index 00d7a4708..8223195dd 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -4484,4 +4484,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "d0445b363a8642838c50a25b5f9e440e1defca471cb3bf547e800a44dc9fe083" +content-hash = "f0057b059c1db08b485252ae5e4139f89126a1f82873ee56604c94a9b9d2142a" diff --git a/agents-api/pyproject.toml b/agents-api/pyproject.toml index 59111206f..5e965a97b 100644 --- a/agents-api/pyproject.toml +++ b/agents-api/pyproject.toml @@ -32,6 +32,7 @@ pydantic-partial = "^0.5.5" simpleeval = "^0.9.13" lz4 = "^4.3.3" +pyyaml = "^6.0.2" [tool.poetry.group.dev.dependencies] ipython = "^8.26.0" ruff = "^0.5.5" diff --git a/agents-api/tests/test_agent_routes.py b/agents-api/tests/test_agent_routes.py index 353d6ab95..0bd6d36df 100644 --- a/agents-api/tests/test_agent_routes.py +++ b/agents-api/tests/test_agent_routes.py @@ -17,7 +17,7 @@ def _(client=client): response = client.request( method="POST", url="/agents", - data=data, + json=data, ) assert response.status_code == 403 diff --git a/agents-api/tests/test_workflow_routes.py b/agents-api/tests/test_workflow_routes.py new file mode 100644 index 000000000..b1fb5abf0 --- /dev/null +++ b/agents-api/tests/test_workflow_routes.py @@ -0,0 +1,88 @@ +# Tests for task queries + +from uuid import uuid4 + +from ward import test + +from .fixtures import cozo_client, test_agent, test_developer_id +from .utils import patch_http_client_with_temporal + + +@test("workflow route: evaluate step single") +async def _( + cozo_client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + agent_id = str(agent.id) + task_id = str(uuid4()) + + async with patch_http_client_with_temporal( + cozo_client=cozo_client, developer_id=developer_id + ) as ( + make_request, + client, + ): + task_data = { + "name": "test task", + "description": "test task about", + "input_schema": {"type": "object", "additionalProperties": True}, + "main": [{"evaluate": {"hello": '"world"'}}], + } + + make_request( + method="POST", + url=f"/agents/{agent_id}/tasks/{task_id}", + json=task_data, + ) + + execution_data = dict(input={"test": "input"}) + + make_request( + method="POST", + url=f"/tasks/{task_id}/executions", + json=execution_data, + ) + + +@test("workflow route: evaluate step single with yaml") +async def _( + cozo_client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + agent_id = str(agent.id) + task_id = str(uuid4()) + + async with patch_http_client_with_temporal( + cozo_client=cozo_client, developer_id=developer_id + ) as ( + make_request, + client, + ): + task_data = """ + name: test task + description: test task about + input_schema: + type: object + additionalProperties: true + + main: + - evaluate: + hello: '"world"' + """ + + make_request( + method="POST", + url=f"/agents/{agent_id}/tasks/{task_id}", + content=task_data, + headers={"Content-Type": "text/yaml"}, + ) + + execution_data = dict(input={"test": "input"}) + + make_request( + method="POST", + url=f"/tasks/{task_id}/executions", + json=execution_data, + ) diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py index b15332300..74127f78d 100644 --- a/agents-api/tests/utils.py +++ b/agents-api/tests/utils.py @@ -3,6 +3,7 @@ from contextlib import asynccontextmanager from unittest.mock import patch +from fastapi.testclient import TestClient from temporalio.testing import WorkflowEnvironment from agents_api.worker.codec import pydantic_data_converter @@ -41,3 +42,25 @@ async def patch_testing_temporal(): # Reset log level logger.setLevel(previous_log_level) + + +@asynccontextmanager +async def patch_http_client_with_temporal(*, cozo_client, developer_id): + async with patch_testing_temporal(): + from agents_api.env import api_key, api_key_header_name + from agents_api.web import app + + client = TestClient(app=app) + app.state.cozo_client = cozo_client + + def make_request(method, url, **kwargs): + headers = kwargs.pop("headers", {}) + headers = { + **headers, + "X-Developer-Id": str(developer_id), + api_key_header_name: api_key, + } + + return client.request(method, url, headers=headers, **kwargs) + + yield make_request, client