Skip to content

Commit

Permalink
Merge branch '292-featapi-assistants-endpoints' into 493-add-authenti…
Browse files Browse the repository at this point in the history
…cation
  • Loading branch information
CollectiveUnicorn authored May 17, 2024
2 parents 5ea4687 + c1ac5f8 commit c9f8c55
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 75 deletions.
9 changes: 4 additions & 5 deletions src/leapfrogai_api/backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pydantic import BaseModel
from fastapi import UploadFile, Form, File
from openai.types import FileObject
from openai.types.beta import Assistant
from openai.types.beta import Assistant, AssistantTool
from openai.types.beta.assistant import ToolResources

##########
# GENERIC
Expand Down Expand Up @@ -250,10 +251,8 @@ class CreateAssistantRequest(BaseModel):
name: str | None = "Froggy Assistant"
description: str | None = "A helpful assistant."
instructions: str | None = "You are a helpful assistant."
tools: list[dict[Literal["type"], Literal["file_search"]]] | None = [
{"type": "file_search"}
] # This is all we support right now
tool_resources: object | None = {}
tools: list[AssistantTool] | None = [] # This is all we support right now
tool_resources: ToolResources | None = ToolResources()
metadata: object | None = {}
temperature: float | None = 1.0
top_p: float | None = 1.0
Expand Down
51 changes: 45 additions & 6 deletions src/leapfrogai_api/routers/openai/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import HTTPException, APIRouter, status, Header
from openai.types.beta import Assistant, AssistantDeleted
from openai.types.beta.assistant import ToolResources
from openai.types.beta.assistant import ToolResources, ToolResourcesCodeInterpreter
from leapfrogai_api.backend.types import (
CreateAssistantRequest,
ListAssistantsResponse,
Expand All @@ -11,6 +11,7 @@
from leapfrogai_api.routers.supabase_session import Session, get_user_session
from leapfrogai_api.utils.openai_util import validate_tools_typed_dict
from leapfrogai_api.data.crud_assistant_object import CRUDAssistant
from leapfrogai_api.routers.supabase_session import Session

router = APIRouter(prefix="/openai/v1/assistants", tags=["openai/assistants"])

Expand All @@ -23,6 +24,23 @@ async def create_assistant(
) -> Assistant:
"""Create an assistant."""

if request.tools is not None:
for tool in request.tools:
if tool.type not in ["file_search"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported tool type: {tool.type}",
)

if request.tool_resources is not None:
for tool_resource in request.tool_resources:
if tool_resource is ToolResourcesCodeInterpreter:
if tool_resource["file_ids"] is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Code interpreter tool is not supported",
)

try:
assistant = Assistant(
id="", # This is set by the database to prevent conflicts
Expand All @@ -32,8 +50,8 @@ async def create_assistant(
instructions=request.instructions,
model=request.model,
object="assistant",
tools=validate_tools_typed_dict(request.tools),
tool_resources=ToolResources.model_validate(request.tool_resources),
tools=request.tools,
tool_resources=request.tool_resources,
temperature=request.temperature,
top_p=request.top_p,
metadata=request.metadata,
Expand Down Expand Up @@ -124,6 +142,24 @@ async def modify_assistant(
- metadata
- response_format
"""

if request.tools is not None:
for tool in request.tools:
if tool.type not in ["file_search"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported tool type: {tool.type}",
)

if request.tool_resources is not None:
for tool_resource in request.tool_resources:
if tool_resource is ToolResourcesCodeInterpreter:
if tool_resource["file_ids"] is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Code interpreter tool is not supported",
)

crud_assistant = CRUDAssistant(model=Assistant)

old_assistant = await crud_assistant.get(
Expand All @@ -143,10 +179,12 @@ async def modify_assistant(
instructions=request.instructions or old_assistant.instructions,
model=request.model or old_assistant.model,
object="assistant",
tools=validate_tools_typed_dict(request.tools) or old_assistant.tools,
tools=request.tools or old_assistant.tools,
tool_resources=ToolResources.model_validate(request.tool_resources)
or old_assistant.tool_resources,
temperature=request.temperature or old_assistant.temperature,
temperature=float(request.temperature)
if request.temperature is not None
else old_assistant.temperature,
top_p=request.top_p or old_assistant.top_p,
metadata=request.metadata or old_assistant.metadata,
response_format=request.response_format or old_assistant.response_format,
Expand All @@ -165,7 +203,8 @@ async def modify_assistant(
)
except FileNotFoundError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Assistant not found"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update assistant",
) from exc


Expand Down
2 changes: 1 addition & 1 deletion src/leapfrogai_api/routers/openai/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def upload_file(
file_object = FileObject(
id="", # This is set by the database to prevent conflicts
bytes=request.file.size,
created_at=123, # This is set by the database to prevent conflicts
created_at=0, # This is set by the database to prevent conflicts
filename=request.file.filename,
object="file", # Per OpenAI Spec this should always be file
purpose="assistants", # we only support assistants for now
Expand Down
38 changes: 0 additions & 38 deletions src/leapfrogai_api/utils/openai_util.py

This file was deleted.

122 changes: 97 additions & 25 deletions tests/integration/api/test_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,38 @@

client = TestClient(router)

starting_assistant = Assistant(
id="",
created_at=0,
name="test",
description="test",
instructions="test",
model="test",
object="assistant",
tools=[{"type": "file_search"}],
tool_resources={},
temperature=1.0,
top_p=1.0,
metadata={},
response_format="auto",
)

modified_assistant = Assistant(
id="",
created_at=0,
name="test1",
description="test1",
instructions="test1",
model="test1",
object="assistant",
tools=[{"type": "file_search"}],
tool_resources={},
temperature=0,
top_p=0.1,
metadata={},
response_format="auto",
)


@pytest.fixture(scope="session", autouse=True)
def create_assistant():
Expand All @@ -23,21 +55,56 @@ def create_assistant():
global assistant_response # pylint: disable=global-statement

request = CreateAssistantRequest(
model="test",
name="test",
description="test",
instructions="test",
tools=[{"type": "file_search"}],
tool_resources={},
metadata={},
temperature=1.0,
top_p=1.0,
response_format="auto",
model=starting_assistant.model,
name=starting_assistant.name,
description=starting_assistant.description,
instructions=starting_assistant.instructions,
tools=starting_assistant.tools,
tool_resources=starting_assistant.tool_resources,
metadata=starting_assistant.metadata,
temperature=starting_assistant.temperature,
top_p=starting_assistant.top_p,
response_format=starting_assistant.response_format,
)

assistant_response = client.post("/openai/v1/assistants", json=request.model_dump())


@pytest.mark.xfail
def test_code_interpreter_fails():
"""Test creating an assistant with a code interpreter tool. Requires a running Supabase instance."""
request = CreateAssistantRequest(
model=modified_assistant.model,
name=modified_assistant.name,
description=modified_assistant.description,
instructions=modified_assistant.instructions,
tools=[{"type": "code_interpreter"}],
tool_resources=modified_assistant.tool_resources,
metadata=modified_assistant.metadata,
temperature=modified_assistant.temperature,
top_p=modified_assistant.top_p,
response_format=modified_assistant,
)

assistant_fail_response = client.post(
"/openai/v1/assistants", json=request.model_dump()
)

assert assistant_fail_response.status_code is status.HTTP_400_BAD_REQUEST
assert (
assistant_fail_response.json()["detail"]
== "Unsupported tool type: code_interpreter"
)

modify_response = client.post(
f"/openai/v1/assistants/{assistant_response.json()['id']}",
json=request.model_dump(),
)

assert modify_response.status_code is status.HTTP_400_BAD_REQUEST
assert modify_response.json()["detail"] == "Unsupported tool type: code_interpreter"


def test_create():
"""Test creating an assistant. Requires a running Supabase instance."""
assert assistant_response.status_code is status.HTTP_200_OK
Expand Down Expand Up @@ -68,26 +135,27 @@ def test_list():

def test_modify():
"""Test modifying an assistant. Requires a running Supabase instance."""

global modified_assistant # pylint: disable=global-statement

assistant_id = assistant_response.json()["id"]
get_response = client.get(f"/openai/v1/assistants/{assistant_id}")
assert get_response.status_code is status.HTTP_200_OK
assert Assistant.model_validate(
get_response.json()
), f"Get endpoint should return Assistant {assistant_id}."

modified_name = "test1"

request = ModifyAssistantRequest(
model="test1",
name=modified_name,
description="test1",
instructions="test1",
tools=[{"type": "file_search"}],
tool_resources={},
metadata={},
temperature=1.0,
top_p=1.0,
response_format="auto",
model=modified_assistant.model,
name=modified_assistant.name,
description=modified_assistant.description,
instructions=modified_assistant.instructions,
tools=modified_assistant.tools,
tool_resources=modified_assistant.tool_resources,
metadata=modified_assistant.metadata,
temperature=modified_assistant.temperature,
top_p=modified_assistant.top_p,
response_format=modified_assistant.response_format,
)

modify_response = client.post(
Expand All @@ -98,9 +166,13 @@ def test_modify():
assert Assistant.model_validate(
modify_response.json()
), "Should return a Assistant."
assert (
modify_response.json()["name"] == modified_name
), f"Assistant {assistant_id} should be modified via modify endpoint."

modified_assistant.id = modify_response.json()["id"]
modified_assistant.created_at = modify_response.json()["created_at"]

assert modified_assistant == Assistant(
**modify_response.json()
), f"Modify endpoint should return modified Assistant {assistant_id}."

get_modified_response = client.get(f"/openai/v1/assistants/{assistant_id}")
assert get_modified_response.status_code is status.HTTP_200_OK
Expand Down

0 comments on commit c9f8c55

Please sign in to comment.