Skip to content

Commit

Permalink
[CHORE] Add support for pydantic v2 (#1174)
Browse files Browse the repository at this point in the history
## Description of changes
Closes #893 

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Adds support for pydantic v2.0 by changing how Collection model inits
- this simple change fixes pydantic v2
	 - Fixes the cross version tests to handle pydantic specifically
- Conditionally imports pydantic-settings based on what is available. In
v2 BaseSettings was moved to a new package.
 - New functionality
	 - N/A

## Test plan
Existing tests were run with the following configs
1. Fastapi < 0.100, Pydantic >= 2.0 - Unsupported as the fastapi
dependencies will not allow it. They likely should, as pydantic.v1
imports would support this, but this is a downstream issue.
2. Fastapi >= 0.100, Pydantic >= 2.0, Supported via normal imports ✅
(Tested with fastapi==0.103.1, pydantic==2.3.0)
3. Fastapi < 0.100 Pydantic < 2.0, Supported via normal imports ✅
(Tested with fastapi==0.95.2, pydantic==1.9.2)
4. Fastapi >= 0.100, Pydantic < 2.0, Supported via normal imports ✅
(Tested with latest fastapi, pydantic==1.9.2)

- [x] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
None required.
  • Loading branch information
HammadB authored Sep 25, 2023
1 parent c7a0414 commit 8a6ad07
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 11 deletions.
3 changes: 2 additions & 1 deletion chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Optional, Tuple, cast, List
from pydantic import BaseModel, PrivateAttr

from uuid import UUID
import chromadb.utils.embedding_functions as ef

Expand Down Expand Up @@ -50,9 +51,9 @@ def __init__(
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
metadata: Optional[CollectionMetadata] = None,
):
super().__init__(name=name, metadata=metadata, id=id)
self._client = client
self._embedding_function = embedding_function
super().__init__(name=name, metadata=metadata, id=id)

def __repr__(self) -> str:
return f"Collection(name={self.name})"
Expand Down
1 change: 0 additions & 1 deletion chromadb/auth/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import requests
from overrides import override
from pydantic import SecretStr

from chromadb.auth import (
ServerAuthCredentialsProvider,
AbstractCredentials,
Expand Down
13 changes: 12 additions & 1 deletion chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,20 @@

from overrides import EnforceOverrides
from overrides import override
from pydantic import BaseSettings, validator
from typing_extensions import Literal


in_pydantic_v2 = False
try:
from pydantic import BaseSettings
except ImportError:
in_pydantic_v2 = True
from pydantic.v1 import BaseSettings
from pydantic.v1 import validator

if not in_pydantic_v2:
from pydantic import validator # type: ignore # noqa

# The thin client will have a flag to control which implementations to use
is_thin_client = False
try:
Expand Down
12 changes: 9 additions & 3 deletions chromadb/test/property/test_cross_version_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
MINIMUM_VERSION = "0.4.1"
version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")

# Some modules do not work across versions, since we upgrade our support for them, and should be explicitly reimported in the subprocess
VERSIONED_MODULES = ["pydantic"]


def versions() -> List[str]:
"""Returns the pinned minimum version and the latest version of chromadb."""
Expand All @@ -49,7 +52,7 @@ def _patch_boolean_metadata(
# boolean value metadata to int
collection_metadata = collection.metadata
if collection_metadata is not None:
_bool_to_int(collection_metadata)
_bool_to_int(collection_metadata) # type: ignore

if embeddings["metadatas"] is not None:
if isinstance(embeddings["metadatas"], list):
Expand Down Expand Up @@ -162,7 +165,10 @@ def switch_to_version(version: str) -> ModuleType:
old_modules = {
n: m
for n, m in sys.modules.items()
if n == module_name or (n.startswith(module_name + "."))
if n == module_name
or (n.startswith(module_name + "."))
or n in VERSIONED_MODULES
or (any(n.startswith(m + ".") for m in VERSIONED_MODULES))
}
for n in old_modules:
del sys.modules[n]
Expand Down Expand Up @@ -197,7 +203,7 @@ def persist_generated_data_with_old_version(
api.reset()
coll = api.create_collection(
name=collection_strategy.name,
metadata=collection_strategy.metadata,
metadata=collection_strategy.metadata, # type: ignore
# In order to test old versions, we can't rely on the not_implemented function
embedding_function=not_implemented_ef(),
)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ classifiers = [
]
dependencies = [
'requests >= 2.28',
'pydantic>=1.9,<2.0',
'pydantic >= 1.9',
'chroma-hnswlib==0.7.3',
'fastapi>=0.95.2, <0.100.0',
'fastapi >= 0.95.2',
'uvicorn[standard] >= 0.18.3',
'numpy == 1.21.6; python_version < "3.8"',
'numpy >= 1.22.5; python_version >= "3.8"',
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
bcrypt==4.0.1
chroma-hnswlib==0.7.3
fastapi>=0.95.2, <0.100.0
fastapi>=0.95.2
graphlib_backport==1.0.3; python_version < '3.9'
importlib-resources
numpy==1.21.6; python_version < '3.8'
Expand All @@ -9,11 +9,11 @@ onnxruntime==1.14.1
overrides==7.3.1
posthog==2.4.0
pulsar-client==3.1.0
pydantic>=1.9,<2.0
pydantic>=1.9
pypika==0.48.9
requests==2.28.1
tokenizers==0.13.2
tqdm==4.65.0
typer>=0.9.0
typing_extensions==4.5.0
typing_extensions>=4.5.0
uvicorn[standard]==0.18.3

0 comments on commit 8a6ad07

Please sign in to comment.