Skip to content

Commit

Permalink
fix: FastAPI Server pydantic 1.x compatibility
Browse files Browse the repository at this point in the history
Refs: #2137
  • Loading branch information
tazarov committed Jul 23, 2024
1 parent 7f52a7f commit f243a54
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from typing import Any, Callable, cast, Dict, List, Sequence, Optional, Tuple
from typing import (
Any,
Callable,
cast,
Dict,
List,
Sequence,
Optional,
Tuple,
Type,
TypeVar,
)
import fastapi
import orjson
from anyio import (
Expand All @@ -13,6 +24,7 @@
from uuid import UUID

from chromadb.api.configuration import CollectionConfigurationInternal
from pydantic import BaseModel
from chromadb.api.types import GetResult, QueryResult
from chromadb.auth import (
AuthzAction,
Expand Down Expand Up @@ -90,6 +102,16 @@ async def check_http_version_middleware(
return await call_next(request)


D = TypeVar("D", bound=BaseModel, contravariant=True)


def validate_model(model: Type[D], data: Any) -> D:
try:
return model.model_validate(data)
except AttributeError:
return model.parse_obj(data)


class ChromaAPIRouter(fastapi.APIRouter): # type: ignore
# A simple subclass of fastapi's APIRouter which treats URLs with a
# trailing "/" the same as URLs without. Docs will only contain URLs
Expand Down Expand Up @@ -375,7 +397,8 @@ async def create_database(
def process_create_database(
tenant: str, headers: Headers, raw_body: bytes
) -> None:
db = CreateDatabase.model_validate(orjson.loads(raw_body))
db = validate_model(CreateDatabase, orjson.loads(raw_body))

(
maybe_tenant,
maybe_database,
Expand Down Expand Up @@ -438,7 +461,7 @@ async def create_tenant(
self, request: Request, body: CreateTenant = Body(...)
) -> None:
def process_create_tenant(request: Request, raw_body: bytes) -> None:
tenant = CreateTenant.model_validate(orjson.loads(raw_body))
tenant = validate_model(CreateTenant, orjson.loads(raw_body))

maybe_tenant, _ = self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -565,7 +588,7 @@ async def create_collection(
def process_create_collection(
request: Request, tenant: str, database: str, raw_body: bytes
) -> CollectionModel:
create = CreateCollection.model_validate(orjson.loads(raw_body))
create = validate_model(CreateCollection, orjson.loads(raw_body))
configuration = (
CollectionConfigurationInternal()
if not create.configuration
Expand Down Expand Up @@ -652,7 +675,7 @@ async def update_collection(
def process_update_collection(
request: Request, collection_id: str, raw_body: bytes
) -> None:
update = UpdateCollection.model_validate(orjson.loads(raw_body))
update = validate_model(UpdateCollection, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.UPDATE_COLLECTION,
Expand Down Expand Up @@ -712,7 +735,7 @@ async def add(
try:

def process_add(request: Request, raw_body: bytes) -> bool:
add = AddEmbedding.model_validate(orjson.loads(raw_body))
add = validate_model(AddEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.ADD,
Expand Down Expand Up @@ -746,7 +769,7 @@ async def update(
self, request: Request, collection_id: str, body: UpdateEmbedding = Body(...)
) -> None:
def process_update(request: Request, raw_body: bytes) -> bool:
update = UpdateEmbedding.model_validate(orjson.loads(raw_body))
update = validate_model(UpdateEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -777,7 +800,7 @@ async def upsert(
self, request: Request, collection_id: str, body: AddEmbedding = Body(...)
) -> None:
def process_upsert(request: Request, raw_body: bytes) -> bool:
upsert = AddEmbedding.model_validate(orjson.loads(raw_body))
upsert = validate_model(AddEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -808,7 +831,7 @@ async def get(
self, collection_id: str, request: Request, body: GetEmbedding = Body(...)
) -> GetResult:
def process_get(request: Request, raw_body: bytes) -> GetResult:
get = GetEmbedding.model_validate(orjson.loads(raw_body))
get = validate_model(GetEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.GET,
Expand Down Expand Up @@ -842,7 +865,7 @@ async def delete(
self, collection_id: str, request: Request, body: DeleteEmbedding = Body(...)
) -> List[UUID]:
def process_delete(request: Request, raw_body: bytes) -> List[str]:
delete = DeleteEmbedding.model_validate(orjson.loads(raw_body))
delete = validate_model(DeleteEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.DELETE,
Expand Down Expand Up @@ -919,7 +942,7 @@ async def get_nearest_neighbors(
body: QueryEmbedding = Body(...),
) -> QueryResult:
def process_query(request: Request, raw_body: bytes) -> QueryResult:
query = QueryEmbedding.model_validate(orjson.loads(raw_body))
query = validate_model(QueryEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down

0 comments on commit f243a54

Please sign in to comment.