Skip to content

Commit

Permalink
[BUG]: Pydantic 1.9+ compatibility (#2229)
Browse files Browse the repository at this point in the history
Closes #2503
Closes #2137

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
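	 - Restore support for Pydantic 1.x (1.9+) in the FastAPI server by falling back to the 1.x model-validation and model-field APIs when the 2.x ones are unavailable.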

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
N/A
  • Loading branch information
tazarov authored Jul 24, 2024
1 parent 1c0fb13 commit d62c13d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
46 changes: 35 additions & 11 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from typing import Any, Callable, cast, Dict, List, Sequence, Optional, Tuple
from typing import (
Any,
Callable,
cast,
Dict,
List,
Sequence,
Optional,
Tuple,
Type,
TypeVar,
)
import fastapi
import orjson
from anyio import (
Expand All @@ -13,6 +24,7 @@
from uuid import UUID

from chromadb.api.configuration import CollectionConfigurationInternal
from pydantic import BaseModel
from chromadb.api.types import GetResult, QueryResult
from chromadb.auth import (
AuthzAction,
Expand Down Expand Up @@ -90,6 +102,17 @@ async def check_http_version_middleware(
return await call_next(request)


D = TypeVar("D", bound=BaseModel, contravariant=True)


def validate_model(model: Type[D], data: Any) -> D:
    """Validate ``data`` against ``model`` on both Pydantic 1.x and 2.x.

    Pydantic 2.x exposes ``model_validate`` while Pydantic 1.x exposes
    ``parse_obj``. The API is selected with an explicit capability check
    rather than by catching ``AttributeError`` around the call, so that an
    ``AttributeError`` raised *during* validation is propagated to the caller
    instead of being mistaken for "this is Pydantic 1.x" and triggering a
    second validation pass.

    Args:
        model: The model class to validate against.
        data: The raw (already deserialized) payload to validate.

    Returns:
        The result of the model's validation method for ``data``.

    Raises:
        Whatever the underlying validation method raises for invalid data.
    """
    if hasattr(model, "model_validate"):
        return model.model_validate(data)  # pydantic 2.x
    return model.parse_obj(data)  # pydantic 1.x


class ChromaAPIRouter(fastapi.APIRouter): # type: ignore
# A simple subclass of fastapi's APIRouter which treats URLs with a
# trailing "/" the same as URLs without. Docs will only contain URLs
Expand Down Expand Up @@ -375,7 +398,8 @@ async def create_database(
def process_create_database(
tenant: str, headers: Headers, raw_body: bytes
) -> None:
db = CreateDatabase.model_validate(orjson.loads(raw_body))
db = validate_model(CreateDatabase, orjson.loads(raw_body))

(
maybe_tenant,
maybe_database,
Expand Down Expand Up @@ -438,7 +462,7 @@ async def create_tenant(
self, request: Request, body: CreateTenant = Body(...)
) -> None:
def process_create_tenant(request: Request, raw_body: bytes) -> None:
tenant = CreateTenant.model_validate(orjson.loads(raw_body))
tenant = validate_model(CreateTenant, orjson.loads(raw_body))

maybe_tenant, _ = self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -565,7 +589,7 @@ async def create_collection(
def process_create_collection(
request: Request, tenant: str, database: str, raw_body: bytes
) -> CollectionModel:
create = CreateCollection.model_validate(orjson.loads(raw_body))
create = validate_model(CreateCollection, orjson.loads(raw_body))
configuration = (
CollectionConfigurationInternal()
if not create.configuration
Expand Down Expand Up @@ -652,7 +676,7 @@ async def update_collection(
def process_update_collection(
request: Request, collection_id: str, raw_body: bytes
) -> None:
update = UpdateCollection.model_validate(orjson.loads(raw_body))
update = validate_model(UpdateCollection, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.UPDATE_COLLECTION,
Expand Down Expand Up @@ -712,7 +736,7 @@ async def add(
try:

def process_add(request: Request, raw_body: bytes) -> bool:
add = AddEmbedding.model_validate(orjson.loads(raw_body))
add = validate_model(AddEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.ADD,
Expand Down Expand Up @@ -746,7 +770,7 @@ async def update(
self, request: Request, collection_id: str, body: UpdateEmbedding = Body(...)
) -> None:
def process_update(request: Request, raw_body: bytes) -> bool:
update = UpdateEmbedding.model_validate(orjson.loads(raw_body))
update = validate_model(UpdateEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -777,7 +801,7 @@ async def upsert(
self, request: Request, collection_id: str, body: AddEmbedding = Body(...)
) -> None:
def process_upsert(request: Request, raw_body: bytes) -> bool:
upsert = AddEmbedding.model_validate(orjson.loads(raw_body))
upsert = validate_model(AddEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down Expand Up @@ -808,7 +832,7 @@ async def get(
self, collection_id: str, request: Request, body: GetEmbedding = Body(...)
) -> GetResult:
def process_get(request: Request, raw_body: bytes) -> GetResult:
get = GetEmbedding.model_validate(orjson.loads(raw_body))
get = validate_model(GetEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.GET,
Expand Down Expand Up @@ -842,7 +866,7 @@ async def delete(
self, collection_id: str, request: Request, body: DeleteEmbedding = Body(...)
) -> List[UUID]:
def process_delete(request: Request, raw_body: bytes) -> List[str]:
delete = DeleteEmbedding.model_validate(orjson.loads(raw_body))
delete = validate_model(DeleteEmbedding, orjson.loads(raw_body))
self.auth_and_get_tenant_and_database_for_request(
request.headers,
AuthzAction.DELETE,
Expand Down Expand Up @@ -919,7 +943,7 @@ async def get_nearest_neighbors(
body: QueryEmbedding = Body(...),
) -> QueryResult:
def process_query(request: Request, raw_body: bytes) -> QueryResult:
query = QueryEmbedding.model_validate(orjson.loads(raw_body))
query = validate_model(QueryEmbedding, orjson.loads(raw_body))

self.auth_and_get_tenant_and_database_for_request(
request.headers,
Expand Down
17 changes: 13 additions & 4 deletions chromadb/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __getitem__(self, key: str) -> Optional[Any]:
if key == "configuration":
return self.get_configuration()
# For the other model attributes we allow the user to access them directly
if key in self.model_fields:
if key in self.get_model_fields():
return getattr(self, key)
return None

Expand All @@ -106,16 +106,18 @@ def __setitem__(self, key: str, value: Any) -> None:
# For the model attributes we allow the user to access them directly
if key == "configuration":
self.set_configuration(value)
if key in self.model_fields:
if key in self.get_model_fields():
setattr(self, key, value)
else:
raise KeyError(f"No such key: {key}, valid keys are: {self.model_fields}")
raise KeyError(
f"No such key: {key}, valid keys are: {self.get_model_fields()}"
)

def __eq__(self, __value: object) -> bool:
# Check that all the model fields are equal
if not isinstance(__value, Collection):
return False
for field in self.model_fields:
for field in self.get_model_fields():
if getattr(self, field) != getattr(__value, field):
return False
return True
Expand All @@ -128,6 +130,13 @@ def set_configuration(self, configuration: CollectionConfigurationInternal) -> N
"""Sets the configuration of the collection"""
self.configuration_json = configuration.to_json()

def get_model_fields(self) -> Dict[Any, Any]:
    """Return the model's field mapping on both Pydantic 1.x and 2.x.

    Pydantic 2.x stores the mapping in ``model_fields``; Pydantic 1.x stores
    it in ``__fields__``. Both are looked up on the class rather than on the
    instance: the mapping is class-level data, and newer Pydantic 2.x
    releases deprecate reading ``model_fields`` through an instance.

    Returns:
        The mapping of field names to field definitions for this model.
    """
    model_cls = type(self)
    if hasattr(model_cls, "model_fields"):
        return model_cls.model_fields  # pydantic 2.x
    return model_cls.__fields__  # pydantic 1.x

@classmethod
@override
def from_json(cls, json_map: Dict[str, Any]) -> Self:
Expand Down

0 comments on commit d62c13d

Please sign in to comment.