diff --git a/.dockerignore b/.dockerignore
index 9f25234e15f..4fcf5c716cd 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,3 +1,9 @@
venv
.git
-examples
\ No newline at end of file
+examples
+clients
+.hypothesis
+__pycache__
+.vscode
+*.egg-info
+.pytest_cache
\ No newline at end of file
diff --git a/.github/workflows/chroma-integration-test.yml b/.github/workflows/chroma-integration-test.yml
new file mode 100644
index 00000000000..7883371fd52
--- /dev/null
+++ b/.github/workflows/chroma-integration-test.yml
@@ -0,0 +1,37 @@
+name: Chroma Integration Tests
+
+on:
+ push:
+ branches:
+ - main
+ - team/hypothesis-tests
+ pull_request:
+ branches:
+ - main
+ - team/hypothesis-tests
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ python: ['3.7']
+ platform: [ubuntu-latest]
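+ # The first entry runs the suite with property tests excluded; each remaining entry runs one Hypothesis property-test module in its own job.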
+ testfile: ["--ignore-glob 'chromadb/test/property/*'",
+ "chromadb/test/property/test_add.py",
+ "chromadb/test/property/test_collections.py",
+ "chromadb/test/property/test_cross_version_persist.py",
+ "chromadb/test/property/test_embeddings.py",
+ "chromadb/test/property/test_filtering.py",
+ "chromadb/test/property/test_persist.py"]
+ runs-on: ${{ matrix.platform }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ - name: Set up Python ${{ matrix.python }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python }}
+ - name: Install test dependencies
+ run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
+ - name: Integration Test
+ run: bin/integration-test ${{ matrix.testfile }}
\ No newline at end of file
diff --git a/.github/workflows/chroma-test.yml b/.github/workflows/chroma-test.yml
index ec1c7e37c3c..142d0971f4e 100644
--- a/.github/workflows/chroma-test.yml
+++ b/.github/workflows/chroma-test.yml
@@ -4,16 +4,26 @@ on:
push:
branches:
- main
+ - team/hypothesis-tests
pull_request:
branches:
- main
+ - team/hypothesis-tests
jobs:
test:
+ timeout-minutes: 90
strategy:
matrix:
- python: ['3.10']
+ python: ['3.7', '3.8', '3.9', '3.10']
platform: [ubuntu-latest]
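+ # Each matrix entry runs either the non-property suite or a single property-test module.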
+ testfile: ["--ignore-glob 'chromadb/test/property/*'",
+ "chromadb/test/property/test_add.py",
+ "chromadb/test/property/test_collections.py",
+ "chromadb/test/property/test_cross_version_persist.py",
+ "chromadb/test/property/test_embeddings.py",
+ "chromadb/test/property/test_filtering.py",
+ "chromadb/test/property/test_persist.py"]
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout
@@ -25,6 +35,4 @@ jobs:
- name: Install test dependencies
run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
- name: Test
- run: python -m pytest
- - name: Integration Test
- run: bin/integration-test
\ No newline at end of file
+ run: python -m pytest ${{ matrix.testfile }}
diff --git a/.github/workflows/pr-review-checklist.yml b/.github/workflows/pr-review-checklist.yml
new file mode 100644
index 00000000000..6b7c9d38122
--- /dev/null
+++ b/.github/workflows/pr-review-checklist.yml
@@ -0,0 +1,37 @@
+name: PR Review Checklist
+
+on:
+ pull_request_target:
+ types:
+ - opened
+
+jobs:
+ PR-Comment:
+ runs-on: ubuntu-latest
+ steps:
+ - name: PR Comment
+ uses: actions/github-script@v2
+ with:
+ github-token: ${{secrets.GITHUB_TOKEN}}
+ script: |
+ github.issues.createComment({
+ issue_number: ${{ github.event.number }},
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ body: `# Reviewer Checklist
+ Please leverage this checklist to ensure your code review is thorough before approving
+ ## Testing, Bugs, Errors, Logs, Documentation
+ - [ ] Can you think of any use case in which the code does not behave as intended? Have they been tested?
+ - [ ] Can you think of any inputs or external events that could break the code? Is user input validated and safe? Have they been tested?
+ - [ ] If appropriate, are there adequate property based tests?
+ - [ ] If appropriate, are there adequate unit tests?
+ - [ ] Should any logging, debugging, tracing information be added or removed?
+ - [ ] Are error messages user-friendly?
+ - [ ] Have all documentation changes needed been made?
+ - [ ] Have all non-obvious changes been commented?
+ ## System Compatibility
+ - [ ] Are there any potential impacts on other parts of the system or backward compatibility?
+ - [ ] Does this change intersect with any items on our roadmap, and if so, is there a plan for fitting them together?
+ ## Quality
+ - [ ] Is this code of unexpectedly high quality (Readability, Modularity, Intuitiveness)?`
+ })
diff --git a/.gitignore b/.gitignore
index e4a4bed1330..de36093c7f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,4 +21,4 @@ dist
.terraform.lock.hcl
terraform.tfstate
.hypothesis/
-.idea
\ No newline at end of file
+.idea
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 78903060c50..4ec74c4e3d7 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,23 +1,28 @@
{
- "git.ignoreLimitWarning": true,
- "editor.rulers": [
- 120
- ],
- "editor.formatOnSave": true,
- "python.formatting.provider": "black",
- "files.exclude": {
- "**/__pycache__": true,
- "**/.ipynb_checkpoints": true,
- "**/.pytest_cache": true,
- "**/chroma.egg-info": true
- },
- "python.analysis.typeCheckingMode": "basic",
- "python.linting.flake8Enabled": true,
- "python.linting.enabled": true,
- "python.linting.flake8Args": [
- "--extend-ignore=E203",
- "--extend-ignore=E501",
- "--extend-ignore=E503",
- "--max-line-length=88",
- ],
-}
+ "git.ignoreLimitWarning": true,
+ "editor.rulers": [
+ 120
+ ],
+ "editor.formatOnSave": true,
+ "python.formatting.provider": "black",
+ "files.exclude": {
+ "**/__pycache__": true,
+ "**/.ipynb_checkpoints": true,
+ "**/.pytest_cache": true,
+ "**/chroma.egg-info": true
+ },
+ "python.analysis.typeCheckingMode": "basic",
+ "python.linting.flake8Enabled": true,
+ "python.linting.enabled": true,
+ "python.linting.flake8Args": [
+ "--extend-ignore=E203",
+ "--extend-ignore=E501",
+ "--extend-ignore=E503",
+ "--max-line-length=88"
+ ],
+ "python.testing.pytestArgs": [
+ "."
+ ],
+ "python.testing.unittestEnabled": false,
+ "python.testing.pytestEnabled": true
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 252ae58b579..8e7dc8339dd 100644
--- a/README.md
+++ b/README.md
@@ -13,10 +13,10 @@
|
- |
+ |
Docs
- |
+ |
Homepage
@@ -30,19 +30,19 @@ pip install chromadb # python client
The core API is only 4 functions (run our [💡 Google Colab](https://colab.research.google.com/drive/1QEzFyqnoFxq7LUGyP1vzR4iLt9PpCDXv?usp=sharing) or [Replit template](https://replit.com/@swyx/BasicChromaStarter?v=1)):
-```python
+```python
import chromadb
# setup Chroma in-memory, for easy prototyping. Can add persistence easily!
client = chromadb.Client()
# Create collection. get_collection, get_or_create_collection, delete_collection also available!
-collection = client.create_collection("all-my-documents")
+collection = client.create_collection("all-my-documents")
# Add docs to the collection. Can also update and delete. Row-based API coming soon!
collection.add(
documents=["This is document1", "This is document2"], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
metadatas=[{"source": "notion"}, {"source": "google-docs"}], # filter on these!
- ids=["doc1", "doc2"], # unique for each doc
+ ids=["doc1", "doc2"], # unique for each doc
)
# Query/search 2 most similar results. You can also .get by id
@@ -66,15 +66,15 @@ results = collection.query(
For example, the `"Chat your data"` use case:
1. Add documents to your database. You can pass in your own embeddings, embedding function, or let Chroma embed them for you.
2. Query relevant documents with natural language.
-3. Compose documents into the context window of an LLM like `GPT3` for additional summarization or analysis.
+3. Compose documents into the context window of an LLM like `GPT3` for additional summarization or analysis.
## Embeddings?
What are embeddings?
- [Read the guide from OpenAI](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
-- __Literal__: Embedding something turns it from image/text/audio into a list of numbers. 🖼️ or 📄 => `[1.2, 2.1, ....]`. This process makes documents "understandable" to a machine learning model.
-- __By analogy__: An embedding represents the essence of a document. This enables documents and queries with the same essence to be "near" each other and therefore easy to find.
+- __Literal__: Embedding something turns it from image/text/audio into a list of numbers. 🖼️ or 📄 => `[1.2, 2.1, ....]`. This process makes documents "understandable" to a machine learning model.
+- __By analogy__: An embedding represents the essence of a document. This enables documents and queries with the same essence to be "near" each other and therefore easy to find.
- __Technical__: An embedding is the latent-space position of a document at a layer of a deep neural network. For models trained specifically to embed data, this is the last layer.
- __A small example__: If you search your photos for "famous bridge in San Francisco". By embedding this query and comparing it to the embeddings of your photos and their metadata - it should return photos of the Golden Gate Bridge.
@@ -82,7 +82,7 @@ Embeddings databases (also known as **vector databases**) store embeddings and a
## Get involved
-Chroma is a rapidly developing project. We welcome PR contributors and ideas for how to improve the project.
+Chroma is a rapidly developing project. We welcome PR contributors and ideas for how to improve the project.
- [Join the conversation on Discord](https://discord.gg/MMeYNTmh3x)
- [Review the roadmap and contribute your ideas](https://docs.trychroma.com/roadmap)
- [Grab an issue and open a PR](https://github.com/chroma-core/chroma/issues)
diff --git a/bin/integration-test b/bin/integration-test
index ea3c4cc2d80..29c49cdb91b 100755
--- a/bin/integration-test
+++ b/bin/integration-test
@@ -12,15 +12,15 @@ trap cleanup EXIT
docker compose -f docker-compose.test.yml up --build -d
-export CHROMA_INTEGRATION_TEST=1
+export CHROMA_INTEGRATION_TEST_ONLY=1
export CHROMA_API_IMPL=rest
export CHROMA_SERVER_HOST=localhost
export CHROMA_SERVER_HTTP_PORT=8000
-python -m pytest
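+# Forward any arguments (for example a single property-test file from the CI matrix) to pytest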
+echo testing: python -m pytest "$@"
+python -m pytest "$@"
cd clients/js
yarn
yarn test:run
-cd ../..
-
+cd ../..
\ No newline at end of file
diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py
index 90c3446e88f..e3b6ff35e3b 100644
--- a/chromadb/api/__init__.py
+++ b/chromadb/api/__init__.py
@@ -61,7 +61,8 @@ def create_collection(
Args:
name (str): The name of the collection to create. The name must be unique.
metadata (Optional[Dict], optional): A dictionary of metadata to associate with the collection. Defaults to None.
- get_or_create (bool, optional): If True, will return the collection if it already exists. Defaults to False.
+ get_or_create (bool, optional): If True, will return the collection if it already exists,
+ and update the metadata (if applicable). Defaults to False.
embedding_function (Optional[Callable], optional): A function that takes documents and returns an embedding. Defaults to None.
Returns:
@@ -82,8 +83,11 @@ def delete_collection(
"""
@abstractmethod
- def get_or_create_collection(self, name: str, metadata: Optional[Dict] = None) -> Collection:
- """Calls create_collection with get_or_create=True
+ def get_or_create_collection(
+ self, name: str, metadata: Optional[Dict] = None
+ ) -> Collection:
+ """Calls create_collection with get_or_create=True.
+ If the collection exists, but with different metadata, the metadata will be replaced.
Args:
name (str): The name of the collection to create. The name must be unique.
@@ -141,7 +145,7 @@ def _add(
⚠️ It is recommended to use the more specific methods below when possible.
Args:
- collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
+ collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
metadata (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
@@ -162,17 +166,40 @@ def _update(
⚠️ It is recommended to use the more specific methods below when possible.
Args:
- collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
+ collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
"""
pass
+ @abstractmethod
+ def _upsert(
+ self,
+ collection_name: str,
+ ids: IDs,
+ embeddings: Optional[Embeddings] = None,
+ metadatas: Optional[Metadatas] = None,
+ documents: Optional[Documents] = None,
+ increment_index: bool = True,
+ ):
+ """Add or update entries in the embedding store.
+ If an entry with the same id already exists, it will be updated, otherwise it will be added.
+
+ Args:
+ collection_name (str): The collection to add the embeddings to
+ ids (Union[str, Sequence[str]]): The ids to associate with the embeddings.
+ embeddings (Sequence[Sequence[float]]): The sequence of embeddings to add
+ metadatas (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
+ documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
+ increment_index (bool, optional): If True, will incrementally add to the ANN index of the collection. Defaults to True.
+ """
+ pass
+
@abstractmethod
def _count(self, collection_name: str) -> int:
"""Returns the number of embeddings in the database
Args:
- collection_name (str): The model space to count the embeddings in.
+ collection_name (str): The collection to count the embeddings in.
Returns:
int: The number of embeddings in the collection
@@ -278,14 +305,19 @@ def raw_sql(self, sql: str) -> pd.DataFrame:
@abstractmethod
def create_index(self, collection_name: Optional[str] = None) -> bool:
- """Creates an index for the given model space
+ """Creates an index for the given collection
⚠️ This method should not be used directly.
Args:
- collection_name (Optional[str], optional): The model space to create the index for. Uses the client's model space if None. Defaults to None.
+ collection_name (Optional[str], optional): The collection to create the index for. Uses the client's collection if None. Defaults to None.
Returns:
bool: True if the index was created successfully
"""
pass
+
+ @abstractmethod
+ def persist(self):
+ """Persist the database to disk"""
+ pass
diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py
index 0be1a087fa7..c5eac52014c 100644
--- a/chromadb/api/fastapi.py
+++ b/chromadb/api/fastapi.py
@@ -15,6 +15,7 @@
from typing import Sequence
from chromadb.api.models.Collection import Collection
from chromadb.telemetry import Telemetry
+import chromadb.errors as errors
class FastAPI(API):
@@ -26,13 +27,13 @@ def __init__(self, settings, telemetry_client: Telemetry):
def heartbeat(self):
"""Returns the current server time in nanoseconds to check if the server is alive"""
resp = requests.get(self._api_url)
- resp.raise_for_status()
+ raise_chroma_error(resp)
return int(resp.json()["nanosecond heartbeat"])
def list_collections(self) -> Sequence[Collection]:
"""Returns a list of all collections"""
resp = requests.get(self._api_url + "/collections")
- resp.raise_for_status()
+ raise_chroma_error(resp)
json_collections = resp.json()
collections = []
for json_collection in json_collections:
@@ -52,7 +53,7 @@ def create_collection(
self._api_url + "/collections",
data=json.dumps({"name": name, "metadata": metadata, "get_or_create": get_or_create}),
)
- resp.raise_for_status()
+ raise_chroma_error(resp)
resp_json = resp.json()
return Collection(
client=self,
@@ -68,7 +69,7 @@ def get_collection(
) -> Collection:
"""Returns a collection"""
resp = requests.get(self._api_url + "/collections/" + name)
- resp.raise_for_status()
+ raise_chroma_error(resp)
resp_json = resp.json()
return Collection(
client=self,
@@ -93,18 +94,18 @@ def _modify(self, current_name: str, new_name: str, new_metadata: Optional[Dict]
self._api_url + "/collections/" + current_name,
data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
)
- resp.raise_for_status()
+ raise_chroma_error(resp)
return resp.json()
def delete_collection(self, name: str):
"""Deletes a collection"""
resp = requests.delete(self._api_url + "/collections/" + name)
- resp.raise_for_status()
+ raise_chroma_error(resp)
def _count(self, collection_name: str):
"""Returns the number of embeddings in the database"""
resp = requests.get(self._api_url + "/collections/" + collection_name + "/count")
- resp.raise_for_status()
+ raise_chroma_error(resp)
return resp.json()
def _peek(self, collection_name, limit=10):
@@ -147,7 +148,7 @@ def _get(
),
)
- resp.raise_for_status()
+ raise_chroma_error(resp)
return resp.json()
def _delete(self, collection_name, ids=None, where={}, where_document={}):
@@ -158,7 +159,7 @@ def _delete(self, collection_name, ids=None, where={}, where_document={}):
data=json.dumps({"where": where, "ids": ids, "where_document": where_document}),
)
- resp.raise_for_status()
+ raise_chroma_error(resp)
return resp.json()
def _add(
@@ -180,20 +181,16 @@ def _add(
self._api_url + "/collections/" + collection_name + "/add",
data=json.dumps(
{
+ "ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
- "ids": ids,
"increment_index": increment_index,
}
),
)
- try:
- resp.raise_for_status()
- except requests.HTTPError:
- raise (Exception(resp.text))
-
+ raise_chroma_error(resp)
return True
def _update(
@@ -224,6 +221,36 @@ def _update(
resp.raise_for_status()
return True
+ def _upsert(
+ self,
+ collection_name: str,
+ ids: IDs,
+ embeddings: Embeddings,
+ metadatas: Optional[Metadatas] = None,
+ documents: Optional[Documents] = None,
+ increment_index: bool = True,
+ ):
+ """
+ Upserts a batch of embeddings in the database
+ - pass in column oriented data lists
+ """
+
+ resp = requests.post(
+ self._api_url + "/collections/" + collection_name + "/upsert",
+ data=json.dumps(
+ {
+ "ids": ids,
+ "embeddings": embeddings,
+ "metadatas": metadatas,
+ "documents": documents,
+ "increment_index": increment_index,
+ }
+ ),
+ )
+
+ resp.raise_for_status()
+ return True
+
def _query(
self,
collection_name,
@@ -248,43 +275,60 @@ def _query(
),
)
- try:
- resp.raise_for_status()
- except requests.HTTPError:
- raise (Exception(resp.text))
-
+ raise_chroma_error(resp)
body = resp.json()
return body
def reset(self):
"""Resets the database"""
resp = requests.post(self._api_url + "/reset")
- resp.raise_for_status()
+ raise_chroma_error(resp)
return resp.json
def persist(self):
"""Persists the database"""
resp = requests.post(self._api_url + "/persist")
- resp.raise_for_status()
+ raise_chroma_error(resp)
return resp.json
def raw_sql(self, sql):
"""Runs a raw SQL query against the database"""
resp = requests.post(self._api_url + "/raw_sql", data=json.dumps({"raw_sql": sql}))
- resp.raise_for_status()
+ raise_chroma_error(resp)
return pd.DataFrame.from_dict(resp.json())
def create_index(self, collection_name: str):
"""Creates an index for the given space key"""
resp = requests.post(self._api_url + "/collections/" + collection_name + "/create_index")
- try:
- resp.raise_for_status()
- except requests.HTTPError:
- raise (Exception(resp.text))
+ raise_chroma_error(resp)
return resp.json()
def get_version(self):
"""Returns the version of the server"""
resp = requests.get(self._api_url + "/version")
- resp.raise_for_status()
+ raise_chroma_error(resp)
return resp.json()
+
+
+def raise_chroma_error(resp):
+ """Raises an error if the response is not ok, using a ChromaError if possible"""
+ if resp.ok:
+ return
+
+ chroma_error = None
+ try:
+ body = resp.json()
+ if "error" in body:
+ if body["error"] in errors.error_types:
+ chroma_error = errors.error_types[body["error"]](body["message"])
+
+ except BaseException:
+ pass
+
+ if chroma_error:
+ raise chroma_error
+
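+ # Fall back to the generic HTTP error when the response body is not a recognized Chroma error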
+ try:
+ resp.raise_for_status()
+ except requests.HTTPError:
+ raise (Exception(resp.text))
diff --git a/chromadb/api/local.py b/chromadb/api/local.py
index 732dc678cc6..1a30d70401e 100644
--- a/chromadb/api/local.py
+++ b/chromadb/api/local.py
@@ -3,7 +3,7 @@
from typing import Dict, List, Optional, Sequence, Callable, cast
from chromadb import __version__
-
+import chromadb.errors as errors
from chromadb.api import API
from chromadb.db import DB
from chromadb.api.types import (
@@ -38,11 +38,11 @@ def check_index_name(index_name):
)
if len(index_name) < 3 or len(index_name) > 63:
raise ValueError(msg)
- if not re.match("^[a-z0-9][a-z0-9._-]*[a-z0-9]$", index_name):
+ if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name):
raise ValueError(msg)
if ".." in index_name:
raise ValueError(msg)
- if re.match("^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$", index_name):
+ if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
raise ValueError(msg)
@@ -90,7 +90,10 @@ def create_collection(
res = self._db.create_collection(name, metadata, get_or_create)
return Collection(
- client=self, name=name, embedding_function=embedding_function, metadata=res[0][2]
+ client=self,
+ name=name,
+ embedding_function=embedding_function,
+ metadata=res[0][2],
)
def get_or_create_collection(
@@ -112,7 +115,9 @@ def get_or_create_collection(
>>> client.get_or_create_collection("my_collection")
collection(name="my_collection", metadata={})
"""
- return self.create_collection(name, metadata, embedding_function, get_or_create=True)
+ return self.create_collection(
+ name, metadata, embedding_function, get_or_create=True
+ )
def get_collection(
self,
@@ -138,7 +143,10 @@ def get_collection(
if len(res) == 0:
raise ValueError(f"Collection {name} does not exist")
return Collection(
- client=self, name=name, embedding_function=embedding_function, metadata=res[0][2]
+ client=self,
+ name=name,
+ embedding_function=embedding_function,
+ metadata=res[0][2],
)
def list_collections(self) -> Sequence[Collection]:
@@ -154,7 +162,9 @@ def list_collections(self) -> Sequence[Collection]:
db_collections = self._db.list_collections()
for db_collection in db_collections:
collections.append(
- Collection(client=self, name=db_collection[1], metadata=db_collection[2])
+ Collection(
+ client=self, name=db_collection[1], metadata=db_collection[2]
+ )
)
return collections
@@ -194,6 +204,11 @@ def _add(
documents: Optional[Documents] = None,
increment_index: bool = True,
):
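+ # Adding an ID that already exists is an error; _upsert should be used to add or update in one call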
+ existing_ids = self._get(collection_name, ids=ids, include=[])["ids"]
+ if len(existing_ids) > 0:
+ raise errors.IDAlreadyExistsError(
+ f"IDs {existing_ids} already exist in collection {collection_name}"
+ )
collection_uuid = self._db.get_collection_uuid_from_name(collection_name)
added_uuids = self._db.add(
@@ -220,6 +235,65 @@ def _update(
):
collection_uuid = self._db.get_collection_uuid_from_name(collection_name)
self._db.update(collection_uuid, ids, embeddings, metadatas, documents)
+ return True
+
+ def _upsert(
+ self,
+ collection_name: str,
+ ids: IDs,
+ embeddings: Embeddings,
+ metadatas: Optional[Metadatas] = None,
+ documents: Optional[Documents] = None,
+ increment_index: bool = True,
+ ):
+ # Determine which ids need to be added and which need to be updated based on the ids already in the collection
+ existing_ids = set(self._get(collection_name, ids=ids, include=[])["ids"])
+
+ ids_to_add = []
+ ids_to_update = []
+ embeddings_to_add: Embeddings = []
+ embeddings_to_update: Embeddings = []
+ metadatas_to_add: Optional[Metadatas] = [] if metadatas else None
+ metadatas_to_update: Optional[Metadatas] = [] if metadatas else None
+ documents_to_add: Optional[Documents] = [] if documents else None
+ documents_to_update: Optional[Documents] = [] if documents else None
+
+ for i, id in enumerate(ids):
+ if id in existing_ids:
+ ids_to_update.append(id)
+ if embeddings is not None:
+ embeddings_to_update.append(embeddings[i])
+ if metadatas is not None:
+ metadatas_to_update.append(metadatas[i])
+ if documents is not None:
+ documents_to_update.append(documents[i])
+ else:
+ ids_to_add.append(id)
+ if embeddings is not None:
+ embeddings_to_add.append(embeddings[i])
+ if metadatas is not None:
+ metadatas_to_add.append(metadatas[i])
+ if documents is not None:
+ documents_to_add.append(documents[i])
+
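+ # New ids are inserted via _add; ids that already exist are updated in place via _update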
+ if len(ids_to_add) > 0:
+ self._add(
+ ids_to_add,
+ collection_name,
+ embeddings_to_add,
+ metadatas_to_add,
+ documents_to_add,
+ increment_index=increment_index,
+ )
+
+ if len(ids_to_update) > 0:
+ self._update(
+ collection_name,
+ ids_to_update,
+ embeddings_to_update,
+ metadatas_to_update,
+ documents_to_update,
+ )
return True
@@ -252,7 +326,9 @@ def _get(
# Remove plural from include since db columns are singular
db_columns = [column[:-1] for column in include] + ["id"]
- column_index = {column_name: index for index, column_name in enumerate(db_columns)}
+ column_index = {
+ column_name: index for index, column_name in enumerate(db_columns)
+ }
db_result = self._db.get(
collection_name=collection_name,
@@ -274,11 +350,17 @@ def _get(
for entry in db_result:
if include_embeddings:
- cast(List, get_result["embeddings"]).append(entry[column_index["embedding"]])
+ cast(List, get_result["embeddings"]).append(
+ entry[column_index["embedding"]]
+ )
if include_documents:
- cast(List, get_result["documents"]).append(entry[column_index["document"]])
+ cast(List, get_result["documents"]).append(
+ entry[column_index["document"]]
+ )
if include_metadatas:
- cast(List, get_result["metadatas"]).append(entry[column_index["metadata"]])
+ cast(List, get_result["metadatas"]).append(
+ entry[column_index["metadata"]]
+ )
get_result["ids"].append(entry[column_index["id"]])
return get_result
@@ -291,9 +373,14 @@ def _delete(self, collection_name, ids=None, where=None, where_document=None):
collection_uuid = self._db.get_collection_uuid_from_name(collection_name)
deleted_uuids = self._db.delete(
- collection_uuid=collection_uuid, where=where, ids=ids, where_document=where_document
+ collection_uuid=collection_uuid,
+ where=where,
+ ids=ids,
+ where_document=where_document,
+ )
+ self._telemetry_client.capture(
+ CollectionDeleteEvent(collection_uuid, len(deleted_uuids))
)
- self._telemetry_client.capture(CollectionDeleteEvent(collection_uuid, len(deleted_uuids)))
return deleted_uuids
def _count(self, collection_name):
@@ -344,8 +431,12 @@ def _query(
ids = []
metadatas = []
# Remove plural from include since db columns are singular
- db_columns = [column[:-1] for column in include if column != "distances"] + ["id"]
- column_index = {column_name: index for index, column_name in enumerate(db_columns)}
+ db_columns = [
+ column[:-1] for column in include if column != "distances"
+ ] + ["id"]
+ column_index = {
+ column_name: index for index, column_name in enumerate(db_columns)
+ }
db_result = self._db.get_by_ids(uuids[i], columns=db_columns)
for entry in db_result:
diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py
index 9e8b2924c19..6aa9958d6df 100644
--- a/chromadb/api/models/Collection.py
+++ b/chromadb/api/models/Collection.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional, cast, List, Dict
+from typing import TYPE_CHECKING, Optional, cast, List, Dict, Tuple
from pydantic import BaseModel, PrivateAttr
from chromadb.api.types import (
@@ -42,7 +42,6 @@ def __init__(
embedding_function: Optional[EmbeddingFunction] = None,
metadata: Optional[Dict] = None,
):
-
self._client = client
if embedding_function is not None:
self._embedding_function = embedding_function
@@ -95,36 +94,13 @@ def add(
"""
- ids = validate_ids(maybe_cast_one_to_many(ids))
- embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
- metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
- documents = maybe_cast_one_to_many(documents) if documents else None
-
- # Check that one of embeddings or documents is provided
- if embeddings is None and documents is None:
- raise ValueError("You must provide either embeddings or documents, or both")
-
- # Check that, if they're provided, the lengths of the arrays match the length of ids
- if embeddings is not None and len(embeddings) != len(ids):
- raise ValueError(
- f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}"
- )
- if metadatas is not None and len(metadatas) != len(ids):
- raise ValueError(
- f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}"
- )
- if documents is not None and len(documents) != len(ids):
- raise ValueError(
- f"Number of documents {len(documents)} must match number of ids {len(ids)}"
- )
-
- # If document embeddings are not provided, we need to compute them
- if embeddings is None and documents is not None:
- if self._embedding_function is None:
- raise ValueError("You must provide embeddings or a function to compute them")
- embeddings = self._embedding_function(documents)
+ ids, embeddings, metadatas, documents = self._validate_embedding_set(
+ ids, embeddings, metadatas, documents
+ )
- self._client._add(ids, self.name, embeddings, metadatas, documents, increment_index)
+ self._client._add(
+ ids, self.name, embeddings, metadatas, documents, increment_index
+ )
def get(
self,
@@ -151,7 +127,9 @@ def get(
"""
where = validate_where(where) if where else None
- where_document = validate_where_document(where_document) if where_document else None
+ where_document = (
+ validate_where_document(where_document) if where_document else None
+ )
ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
include = validate_include(include, allow_distances=False)
return self._client._get(
@@ -204,8 +182,12 @@ def query(
"""
where = validate_where(where) if where else None
- where_document = validate_where_document(where_document) if where_document else None
- query_embeddings = maybe_cast_one_to_many(query_embeddings) if query_embeddings else None
+ where_document = (
+ validate_where_document(where_document) if where_document else None
+ )
+ query_embeddings = (
+ maybe_cast_one_to_many(query_embeddings) if query_embeddings else None
+ )
query_texts = maybe_cast_one_to_many(query_texts) if query_texts else None
include = validate_include(include, allow_distances=True)
@@ -220,9 +202,13 @@ def query(
# If query_embeddings are not provided, we need to compute them from the query_texts
if query_embeddings is None:
if self._embedding_function is None:
- raise ValueError("You must provide embeddings or a function to compute them")
+ raise ValueError(
+ "You must provide embeddings or a function to compute them"
+ )
# We know query texts is not None at this point, cast for the typechecker
- query_embeddings = self._embedding_function(cast(List[Document], query_texts))
+ query_embeddings = self._embedding_function(
+ cast(List[Document], query_texts)
+ )
if where is None:
where = {}
@@ -249,7 +235,9 @@ def modify(self, name: Optional[str] = None, metadata=None):
Returns:
None
"""
- self._client._modify(current_name=self.name, new_name=name, new_metadata=metadata)
+ self._client._modify(
+ current_name=self.name, new_name=name, new_metadata=metadata
+ )
if name:
self.name = name
if metadata:
@@ -274,40 +262,41 @@ def update(
None
"""
- ids = validate_ids(maybe_cast_one_to_many(ids))
- embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
- metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
- documents = maybe_cast_one_to_many(documents) if documents else None
+ ids, embeddings, metadatas, documents = self._validate_embedding_set(
+ ids, embeddings, metadatas, documents, require_embeddings_or_documents=False
+ )
- # Must update one of embeddings, metadatas, or documents
- if embeddings is None and documents is None and metadatas is None:
- raise ValueError("You must update at least one of embeddings, documents or metadatas.")
+ self._client._update(self.name, ids, embeddings, metadatas, documents)
- # Check that one of embeddings or documents is provided
- if embeddings is not None and documents is None:
- raise ValueError("You must provide updated documents with updated embeddings")
+ def upsert(
+ self,
+ ids: OneOrMany[ID],
+ embeddings: Optional[OneOrMany[Embedding]] = None,
+ metadatas: Optional[OneOrMany[Metadata]] = None,
+ documents: Optional[OneOrMany[Document]] = None,
+ increment_index: bool = True,
+ ):
+ """Update the embeddings, metadatas or documents for the provided ids, or create them if they don't exist.
- # Check that, if they're provided, the lengths of the arrays match the length of ids
- if embeddings is not None and len(embeddings) != len(ids):
- raise ValueError(
- f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}"
- )
- if metadatas is not None and len(metadatas) != len(ids):
- raise ValueError(
- f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}"
- )
- if documents is not None and len(documents) != len(ids):
- raise ValueError(
- f"Number of documents {len(documents)} must match number of ids {len(ids)}"
- )
+ Args:
+ ids: The ids of the embeddings to update
+ embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
+ metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
+ documents: The documents to associate with the embeddings. Optional.
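+ increment_index: If True, incrementally add the upserted embeddings to the ANN index of the collection. Optional.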
+ """
- # If document embeddings are not provided, we need to compute them
- if embeddings is None and documents is not None:
- if self._embedding_function is None:
- raise ValueError("You must provide embeddings or a function to compute them")
- embeddings = self._embedding_function(documents)
+ ids, embeddings, metadatas, documents = self._validate_embedding_set(
+ ids, embeddings, metadatas, documents
+ )
- self._client._update(self.name, ids, embeddings, metadatas, documents)
+ self._client._upsert(
+ collection_name=self.name,
+ ids=ids,
+ embeddings=embeddings,
+ metadatas=metadatas,
+ documents=documents,
+ increment_index=increment_index,
+ )
def delete(
self,
@@ -327,8 +316,65 @@ def delete(
"""
ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
where = validate_where(where) if where else None
- where_document = validate_where_document(where_document) if where_document else None
+ where_document = (
+ validate_where_document(where_document) if where_document else None
+ )
return self._client._delete(self.name, ids, where, where_document)
def create_index(self):
self._client.create_index(self.name)
+
+ def _validate_embedding_set(
+ self,
+ ids,
+ embeddings,
+ metadatas,
+ documents,
+ require_embeddings_or_documents=True,
+ ) -> Tuple[
+ IDs,
+ Optional[List[Embedding]],
+ Optional[List[Metadata]],
+ Optional[List[Document]],
+ ]:
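+ """Shared validation for add, update and upsert: casts single values to lists, checks that provided lengths match the number of ids, and computes embeddings from documents when embeddings are not supplied."""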
+ ids = validate_ids(maybe_cast_one_to_many(ids))
+ embeddings = (
+ maybe_cast_one_to_many(embeddings) if embeddings is not None else None
+ )
+ metadatas = (
+ validate_metadatas(maybe_cast_one_to_many(metadatas))
+ if metadatas is not None
+ else None
+ )
+ documents = maybe_cast_one_to_many(documents) if documents is not None else None
+
+ # Check that one of embeddings or documents is provided
+ if require_embeddings_or_documents:
+ if embeddings is None and documents is None:
+ raise ValueError(
+ "You must provide either embeddings or documents, or both"
+ )
+
+ # Check that, if they're provided, the lengths of the arrays match the length of ids
+ if embeddings is not None and len(embeddings) != len(ids):
+ raise ValueError(
+ f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}"
+ )
+ if metadatas is not None and len(metadatas) != len(ids):
+ raise ValueError(
+ f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}"
+ )
+ if documents is not None and len(documents) != len(ids):
+ raise ValueError(
+ f"Number of documents {len(documents)} must match number of ids {len(ids)}"
+ )
+
+ # If document embeddings are not provided, we need to compute them
+ if embeddings is None and documents is not None:
+ if self._embedding_function is None:
+ raise ValueError(
+ "You must provide embeddings or a function to compute them"
+ )
+ embeddings = self._embedding_function(documents)
+
+ return ids, embeddings, metadatas, documents
diff --git a/chromadb/api/types.py b/chromadb/api/types.py
index 39d50d723d4..edd303c186c 100644
--- a/chromadb/api/types.py
+++ b/chromadb/api/types.py
@@ -1,4 +1,6 @@
-from typing import Literal, Optional, Union, Dict, Sequence, TypedDict, Protocol, TypeVar, List
+from typing import Optional, Union, Dict, Sequence, TypeVar, List
+from typing_extensions import Literal, TypedDict, Protocol
+import chromadb.errors as errors
ID = str
IDs = List[ID]
@@ -26,7 +28,9 @@
WhereOperator = Literal["$gt", "$gte", "$lt", "$lte", "$ne", "$eq"]
OperatorExpression = Dict[Union[WhereOperator, LogicalOperator], LiteralValue]
-Where = Dict[Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]]]
+Where = Dict[
+ Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]]
+]
WhereDocumentOperator = Literal["$contains", LogicalOperator]
WhereDocument = Dict[WhereDocumentOperator, Union[str, List["WhereDocument"]]]
@@ -84,6 +88,11 @@ def validate_ids(ids: IDs) -> IDs:
for id in ids:
if not isinstance(id, str):
raise ValueError(f"Expected ID to be a str, got {id}")
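+ # IDs within a single call must be unique; report all duplicates in the error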
+ if len(ids) != len(set(ids)):
+ dups = set([x for x in ids if ids.count(x) > 1])
+ raise errors.DuplicateIDError(
+ f"Expected IDs to be unique, found duplicates for: {dups}"
+ )
return ids
@@ -95,7 +104,9 @@ def validate_metadata(metadata: Metadata) -> Metadata:
if not isinstance(key, str):
raise ValueError(f"Expected metadata key to be a str, got {key}")
if not isinstance(value, (str, int, float)):
- raise ValueError(f"Expected metadata value to be a str, int, or float, got {value}")
+ raise ValueError(
+ f"Expected metadata value to be a str, int, or float, got {value}"
+ )
return metadata
@@ -118,7 +129,11 @@ def validate_where(where: Where) -> Where:
for key, value in where.items():
if not isinstance(key, str):
raise ValueError(f"Expected where key to be a str, got {key}")
- if key != "$and" and key != "$or" and not isinstance(value, (str, int, float, dict)):
+ if (
+ key != "$and"
+ and key != "$or"
+ and not isinstance(value, (str, int, float, dict))
+ ):
raise ValueError(
f"Expected where value to be a str, int, float, or operator expression, got {value}"
)
@@ -167,7 +182,9 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument:
a list of where_document expressions
"""
if not isinstance(where_document, dict):
- raise ValueError(f"Expected where document to be a dictionary, got {where_document}")
+ raise ValueError(
+ f"Expected where document to be a dictionary, got {where_document}"
+ )
if len(where_document) != 1:
raise ValueError(
f"Expected where document to have exactly one operator, got {where_document}"
diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py
index b1b6874cdf2..7f1b2969fef 100644
--- a/chromadb/db/clickhouse.py
+++ b/chromadb/db/clickhouse.py
@@ -1,4 +1,11 @@
-from chromadb.api.types import Documents, Embeddings, IDs, Metadatas, Where, WhereDocument
+from chromadb.api.types import (
+ Documents,
+ Embeddings,
+ IDs,
+ Metadatas,
+ Where,
+ WhereDocument,
+)
from chromadb.db import DB
from chromadb.db.index.hnswlib import Hnswlib, delete_all_indexes
from chromadb.errors import (
@@ -53,7 +60,8 @@ def __init__(self, settings):
def _init_conn(self):
common.set_setting("autogenerate_session_id", False)
self._conn = clickhouse_connect.get_client(
- host=self._settings.clickhouse_host, port=int(self._settings.clickhouse_port)
+ host=self._settings.clickhouse_host,
+ port=int(self._settings.clickhouse_port),
)
self._create_table_collections(self._conn)
self._create_table_embeddings(self._conn)
@@ -100,7 +108,9 @@ def _delete_index(self, collection_id):
# UTILITY METHODS
#
def persist(self):
- raise NotImplementedError("Clickhouse is a persistent database, this method is not needed")
+ raise NotImplementedError(
+ "Clickhouse is a persistent database, this method is not needed"
+ )
def get_collection_uuid_from_name(self, name: str) -> str:
res = self._get_conn().query(
@@ -143,6 +153,9 @@ def create_collection(
if len(dupe_check) > 0:
if get_or_create:
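+ # With get_or_create, differing metadata replaces the stored metadata before the existing collection is returned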
+ if dupe_check[0][2] != metadata:
+ self.update_collection(name, new_name=name, new_metadata=metadata)
+ dupe_check = self.get_collection(name)
logger.info(
f"collection with name {name} already exists, returning existing collection"
)
@@ -189,7 +202,10 @@ def list_collections(self) -> Sequence:
return [[x[0], x[1], json.loads(x[2])] for x in res]
def update_collection(
- self, current_name: str, new_name: Optional[str] = None, new_metadata: Optional[Dict] = None
+ self,
+ current_name: str,
+ new_name: Optional[str] = None,
+ new_metadata: Optional[Dict] = None,
):
if new_name is None:
new_name = current_name
@@ -202,11 +218,11 @@ def update_collection(
ALTER TABLE
collections
UPDATE
- metadata = '{json.dumps(new_metadata)}',
- name = '{new_name}'
+ metadata = %s,
+ name = %s
WHERE
- name = '{current_name}'
- """
+ name = %s
+ """, [json.dumps(new_metadata), new_name, current_name]
)
def delete_collection(self, name: str):
@@ -241,7 +257,14 @@ def add(self, collection_uuid, embeddings, metadatas, documents, ids):
]
for i, embedding in enumerate(embeddings)
]
- column_names = ["collection_uuid", "uuid", "embedding", "metadata", "document", "id"]
+ column_names = [
+ "collection_uuid",
+ "uuid",
+ "embedding",
+ "metadata",
+ "document",
+ "id",
+ ]
self._get_conn().insert("embeddings", data_to_insert, column_names=column_names)
return [x[1] for x in data_to_insert] # return uuids
@@ -260,26 +283,28 @@ def _update(
update_fields = []
parameters[f"i{i}"] = ids[i]
if embeddings is not None:
- update_fields.append(f"embedding = {{e{i}:Array(Float64)}}")
+ update_fields.append(f"embedding = %(e{i})s")
parameters[f"e{i}"] = embeddings[i]
if metadatas is not None:
- update_fields.append(f"metadata = {{m{i}:String}}")
+ update_fields.append(f"metadata = %(m{i})s")
parameters[f"m{i}"] = json.dumps(metadatas[i])
if documents is not None:
- update_fields.append(f"document = {{d{i}:String}}")
+ update_fields.append(f"document = %(d{i})s")
parameters[f"d{i}"] = documents[i]
update_statement = f"""
UPDATE
{",".join(update_fields)}
WHERE
- id = {{i{i}:String}} AND
+ id = %(i{i})s AND
collection_uuid = '{collection_uuid}'{"" if i == len(ids) - 1 else ","}
"""
updates.append(update_statement)
update_clauses = ("").join(updates)
- self._get_conn().command(f"ALTER TABLE embeddings {update_clauses}", parameters=parameters)
+ self._get_conn().command(
+ f"ALTER TABLE embeddings {update_clauses}", parameters=parameters
+ )
def update(
self,
@@ -292,14 +317,19 @@ def update(
# Verify all IDs exist
existing_items = self.get(collection_uuid=collection_uuid, ids=ids)
if len(existing_items) != len(ids):
- raise ValueError(f"Could not find {len(ids) - len(existing_items)} items for update")
+ raise ValueError(
+ f"Could not find {len(ids) - len(existing_items)} items for update"
+ )
# Update the db
self._update(collection_uuid, ids, embeddings, metadatas, documents)
# Update the index
if embeddings is not None:
- update_uuids = [x[1] for x in existing_items]
+ # `get` currently returns items in arbitrary order.
+ # TODO if we fix `get`, we can remove this explicit mapping.
+ uuid_mapping = {r[4]: r[1] for r in existing_items}
+ update_uuids = [uuid_mapping[id] for id in ids]
index = self._index(collection_uuid)
index.add(update_uuids, embeddings, update=True)
@@ -318,37 +348,59 @@ def _get(self, where={}, columns: Optional[List] = None):
if "metadata" in select_columns:
metadata_column_index = select_columns.index("metadata")
db_metadata = val[i][metadata_column_index]
- val[i][metadata_column_index] = json.loads(db_metadata) if db_metadata else None
+ val[i][metadata_column_index] = (
+ json.loads(db_metadata) if db_metadata else None
+ )
return val
def _format_where(self, where, result):
for key, value in where.items():
+
+ def has_key_and(clause):
+ return f"(JSONHas(metadata,'{key}') = 1 AND {clause})"
+
# Shortcut for $eq
if type(value) == str:
- result.append(f" JSONExtractString(metadata,'{key}') = '{value}'")
+ result.append(has_key_and(f" JSONExtractString(metadata,'{key}') = '{value}'"))
elif type(value) == int:
- result.append(f" JSONExtractInt(metadata,'{key}') = {value}")
+ result.append(has_key_and(f" JSONExtractInt(metadata,'{key}') = {value}"))
elif type(value) == float:
- result.append(f" JSONExtractFloat(metadata,'{key}') = {value}")
+ result.append(has_key_and(f" JSONExtractFloat(metadata,'{key}') = {value}"))
# Operator expression
elif type(value) == dict:
operator, operand = list(value.items())[0]
if operator == "$gt":
- return result.append(f" JSONExtractFloat(metadata,'{key}') > {operand}")
+ return result.append(
+ has_key_and(f" JSONExtractFloat(metadata,'{key}') > {operand}")
+ )
elif operator == "$lt":
- return result.append(f" JSONExtractFloat(metadata,'{key}') < {operand}")
+ return result.append(
+ has_key_and(f" JSONExtractFloat(metadata,'{key}') < {operand}")
+ )
elif operator == "$gte":
- return result.append(f" JSONExtractFloat(metadata,'{key}') >= {operand}")
+ return result.append(
+ has_key_and(f" JSONExtractFloat(metadata,'{key}') >= {operand}")
+ )
elif operator == "$lte":
- return result.append(f" JSONExtractFloat(metadata,'{key}') <= {operand}")
+ return result.append(
+ has_key_and(f" JSONExtractFloat(metadata,'{key}') <= {operand}")
+ )
elif operator == "$ne":
if type(operand) == str:
- return result.append(f" JSONExtractString(metadata,'{key}') != '{operand}'")
- return result.append(f" JSONExtractFloat(metadata,'{key}') != {operand}")
+ return result.append(
+ has_key_and(f" JSONExtractString(metadata,'{key}') != '{operand}'")
+ )
+ return result.append(
+ has_key_and(f" JSONExtractFloat(metadata,'{key}') != {operand}")
+ )
elif operator == "$eq":
if type(operand) == str:
- return result.append(f" JSONExtractString(metadata,'{key}') = '{operand}'")
- return result.append(f" JSONExtractFloat(metadata,'{key}') = {operand}")
+ return result.append(
+ has_key_and(f" JSONExtractString(metadata,'{key}') = '{operand}'")
+ )
+ return result.append(
+ has_key_and(f" JSONExtractFloat(metadata,'{key}') = {operand}")
+ )
else:
raise ValueError(
f"Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got {operator}"
@@ -396,7 +448,9 @@ def get(
columns: Optional[List[str]] = None,
) -> Sequence:
if collection_name is None and collection_uuid is None:
- raise TypeError("Arguments collection_name and collection_uuid cannot both be None")
+ raise TypeError(
+ "Arguments collection_name and collection_uuid cannot both be None"
+ )
if collection_name is not None:
collection_uuid = self.get_collection_uuid_from_name(collection_name)
@@ -426,7 +480,11 @@ def get(
def _count(self, collection_uuid: str):
where_string = f"WHERE collection_uuid = '{collection_uuid}'"
- return self._get_conn().query(f"SELECT COUNT() FROM embeddings {where_string}").result_rows
+ return (
+ self._get_conn()
+ .query(f"SELECT COUNT() FROM embeddings {where_string}")
+ .result_rows
+ )
def count(self, collection_name: str):
collection_uuid = self.get_collection_uuid_from_name(collection_name)
@@ -434,7 +492,9 @@ def count(self, collection_name: str):
def _delete(self, where_str: Optional[str] = None) -> List:
deleted_uuids = (
- self._get_conn().query(f"""SELECT uuid FROM embeddings {where_str}""").result_rows
+ self._get_conn()
+ .query(f"""SELECT uuid FROM embeddings {where_str}""")
+ .result_rows
)
self._get_conn().command(
f"""
@@ -494,17 +554,20 @@ def get_nearest_neighbors(
collection_name=None,
collection_uuid=None,
) -> Tuple[List[List[uuid.UUID]], npt.NDArray]:
-
# Either the collection name or the collection uuid must be provided
if collection_name is None and collection_uuid is None:
- raise TypeError("Arguments collection_name and collection_uuid cannot both be None")
+ raise TypeError(
+ "Arguments collection_name and collection_uuid cannot both be None"
+ )
if collection_name is not None:
collection_uuid = self.get_collection_uuid_from_name(collection_name)
if len(where) != 0 or len(where_document) != 0:
results = self.get(
- collection_uuid=collection_uuid, where=where, where_document=where_document
+ collection_uuid=collection_uuid,
+ where=where,
+ where_document=where_document,
)
if len(results) > 0:
diff --git a/chromadb/db/duckdb.py b/chromadb/db/duckdb.py
index 44e35655906..5fd218c8146 100644
--- a/chromadb/db/duckdb.py
+++ b/chromadb/db/duckdb.py
@@ -13,6 +13,7 @@
import uuid
import os
import logging
+import atexit
logger = logging.getLogger(__name__)
@@ -40,7 +41,6 @@ def clickhouse_to_duckdb_schema(table_schema):
class DuckDB(Clickhouse):
# duckdb has a different way of connecting to the database
def __init__(self, settings):
-
self._conn = duckdb.connect()
self._create_table_collections()
self._create_table_embeddings()
@@ -68,9 +68,9 @@ def _create_table_embeddings(self):
# UTILITY METHODS
#
def get_collection_uuid_from_name(self, name):
- return self._conn.execute("SELECT uuid FROM collections WHERE name = ?", [name]).fetchall()[
- 0
- ][0]
+ return self._conn.execute(
+ "SELECT uuid FROM collections WHERE name = ?", [name]
+ ).fetchall()[0][0]
#
# COLLECTION METHODS
@@ -82,6 +82,10 @@ def create_collection(
dupe_check = self.get_collection(name)
if len(dupe_check) > 0:
if get_or_create is True:
+ if dupe_check[0][2] != metadata:
+ self.update_collection(name, new_name=name, new_metadata=metadata)
+ dupe_check = self.get_collection(name)
+
logger.info(
f"collection with name {name} already exists, returning existing collection"
)
@@ -97,12 +101,16 @@ def create_collection(
return [[str(collection_uuid), name, metadata]]
def get_collection(self, name: str) -> Sequence:
- res = self._conn.execute("""SELECT * FROM collections WHERE name = ?""", [name]).fetchall()
+ res = self._conn.execute(
+ """SELECT * FROM collections WHERE name = ?""", [name]
+ ).fetchall()
# json.loads the metadata
return [[x[0], x[1], json.loads(x[2])] for x in res]
def get_collection_by_id(self, uuid: str) -> Sequence:
- res = self._conn.execute("""SELECT * FROM collections WHERE uuid = ?""", [uuid]).fetchone()
+ res = self._conn.execute(
+ """SELECT * FROM collections WHERE uuid = ?""", [uuid]
+ ).fetchone()
return [res[0], res[1], json.loads(res[2])]
def list_collections(self) -> Sequence:
@@ -172,20 +180,32 @@ def _format_where(self, where, result):
if type(value) == str:
result.append(f" json_extract_string(metadata,'$.{key}') = '{value}'")
if type(value) == int:
- result.append(f" CAST(json_extract(metadata,'$.{key}') AS INT) = {value}")
+ result.append(
+ f" CAST(json_extract(metadata,'$.{key}') AS INT) = {value}"
+ )
if type(value) == float:
- result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) = {value}")
+ result.append(
+ f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) = {value}"
+ )
# Operator expression
elif type(value) == dict:
operator, operand = list(value.items())[0]
if operator == "$gt":
- result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) > {operand}")
+ result.append(
+ f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) > {operand}"
+ )
elif operator == "$lt":
- result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) < {operand}")
+ result.append(
+ f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) < {operand}"
+ )
elif operator == "$gte":
- result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) >= {operand}")
+ result.append(
+ f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) >= {operand}"
+ )
elif operator == "$lte":
- result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) <= {operand}")
+ result.append(
+ f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) <= {operand}"
+ )
elif operator == "$ne":
if type(operand) == str:
return result.append(
@@ -215,7 +235,9 @@ def _format_where(self, where, result):
elif key == "$and":
result.append(f"({' AND '.join(all_subresults)})")
else:
- raise ValueError(f"Operator {key} not supported with a list of where clauses")
+ raise ValueError(
+ f"Operator {key} not supported with a list of where clauses"
+ )
def _format_where_document(self, where_document, results):
operator = list(where_document.keys())[0]
@@ -335,7 +357,9 @@ def get_by_ids(self, ids: List, columns: Optional[List] = None):
).fetchall()
# sort db results by the order of the uuids
- response = sorted(response, key=lambda obj: ids.index(uuid.UUID(obj[len(columns) - 1])))
+ response = sorted(
+ response, key=lambda obj: ids.index(uuid.UUID(obj[len(columns) - 1]))
+ )
return response
@@ -374,6 +398,8 @@ def __init__(self, settings):
self._save_folder = settings.persist_directory
self.load()
+ # https://docs.python.org/3/library/atexit.html
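+ # Persist to parquet on interpreter exit rather than relying on __del__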
+ atexit.register(self.persist)
def set_save_folder(self, path):
self._save_folder = path
@@ -385,7 +411,9 @@ def persist(self):
"""
Persist the database to disk
"""
- logger.info(f"Persisting DB to disk, putting it in the save folder: {self._save_folder}")
+ logger.info(
+ f"Persisting DB to disk, putting it in the save folder: {self._save_folder}"
+ )
if self._conn is None:
return
@@ -426,7 +454,9 @@ def load(self):
logger.info(f"No existing DB found in {self._save_folder}, skipping load")
else:
path = self._save_folder + "/chroma-embeddings.parquet"
- self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
+ self._conn.execute(
+ f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');"
+ )
logger.info(
f"""loaded in {self._conn.query(f"SELECT COUNT() FROM embeddings").fetchall()[0][0]} embeddings"""
)
@@ -436,14 +466,16 @@ def load(self):
logger.info(f"No existing DB found in {self._save_folder}, skipping load")
else:
path = self._save_folder + "/chroma-collections.parquet"
- self._conn.execute(f"INSERT INTO collections SELECT * FROM read_parquet('{path}');")
+ self._conn.execute(
+ f"INSERT INTO collections SELECT * FROM read_parquet('{path}');"
+ )
logger.info(
f"""loaded in {self._conn.query(f"SELECT COUNT() FROM collections").fetchall()[0][0]} collections"""
)
def __del__(self):
- logger.info("PersistentDuckDB del, about to run persist")
- self.persist()
+ # No-op for duckdb with persistence since the base class will delete the indexes
+ pass
def reset(self):
super().reset()
diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py
index 0cfa760d6bb..f00aadd5fbc 100644
--- a/chromadb/db/index/hnswlib.py
+++ b/chromadb/db/index/hnswlib.py
@@ -2,10 +2,15 @@
import pickle
import time
from typing import Dict
+
from chromadb.api.types import IndexMetadata
import hnswlib
from chromadb.db.index import Index
-from chromadb.errors import NoIndexException, InvalidDimensionException, NotEnoughElementsException
+from chromadb.errors import (
+ NoIndexException,
+ InvalidDimensionException,
+ NotEnoughElementsException,
+)
import logging
import re
from uuid import UUID
@@ -24,7 +29,6 @@
class HnswParams:
-
space: str
construction_ef: int
search_ef: int
@@ -33,7 +37,6 @@ class HnswParams:
resize_factor: float
def __init__(self, metadata):
-
metadata = metadata or {}
# Convert all values to strings for future compatibility.
@@ -44,7 +47,9 @@ def __init__(self, metadata):
if param not in valid_params:
raise ValueError(f"Unknown HNSW parameter: {param}")
if not re.match(valid_params[param], value):
- raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}")
+ raise ValueError(
+ f"Invalid value for HNSW parameter: {param} = {value}"
+ )
self.space = metadata.get("hnsw:space", "l2")
self.construction_ef = int(metadata.get("hnsw:construction_ef", 100))
@@ -71,7 +76,7 @@ class Hnswlib(Index):
_index_metadata: IndexMetadata
_params: HnswParams
_id_to_label: Dict[str, int]
- _label_to_id: Dict[int, str]
+ _label_to_id: Dict[int, UUID]
def __init__(self, id, settings, metadata):
self._save_folder = settings.persist_directory + "/index"
@@ -128,7 +133,7 @@ def add(self, ids, embeddings, update=False):
labels = []
for id in ids:
- if id in self._id_to_label:
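+ # _id_to_label is keyed by hex id strings, so convert before checking membership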
+ if hexid(id) in self._id_to_label:
if update:
labels.append(self._id_to_label[hexid(id)])
else:
@@ -141,7 +146,9 @@ def add(self, ids, embeddings, update=False):
labels.append(next_label)
if self._index_metadata["elements"] > self._index.get_max_elements():
- new_size = max(self._index_metadata["elements"] * self._params.resize_factor, 1000)
+ new_size = max(
+ self._index_metadata["elements"] * self._params.resize_factor, 1000
+ )
self._index.resize_index(int(new_size))
self._index.add_items(embeddings, labels)
@@ -196,7 +203,6 @@ def _exists(self):
return
def _load(self):
-
if not os.path.exists(f"{self._save_folder}/index_{self._id}.bin"):
return
@@ -208,7 +214,9 @@ def _load(self):
with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "rb") as f:
self._index_metadata = pickle.load(f)
- p = hnswlib.Index(space=self._params.space, dim=self._index_metadata["dimensionality"])
+ p = hnswlib.Index(
+ space=self._params.space, dim=self._index_metadata["dimensionality"]
+ )
self._index = p
self._index.load_index(
f"{self._save_folder}/index_{self._id}.bin",
@@ -218,9 +226,10 @@ def _load(self):
self._index.set_num_threads(self._params.num_threads)
def get_nearest_neighbors(self, query, k, ids=None):
-
if self._index is None:
- raise NoIndexException("Index not found, please create an instance before querying")
+ raise NoIndexException(
+ "Index not found, please create an instance before querying"
+ )
# Check dimensionality
self._check_dimensionality(query)
@@ -245,8 +254,12 @@ def get_nearest_neighbors(self, query, k, ids=None):
logger.debug(f"time to pre process our knn query: {time.time() - s2}")
s3 = time.time()
- database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function)
+ database_labels, distances = self._index.knn_query(
+ query, k=k, filter=filter_function
+ )
logger.debug(f"time to run knn query: {time.time() - s3}")
- ids = [[self._label_to_id[label] for label in labels] for labels in database_labels]
+ ids = [
+ [self._label_to_id[label] for label in labels] for labels in database_labels
+ ]
return ids, distances
diff --git a/chromadb/errors.py b/chromadb/errors.py
index 0a75c43f5b2..ee58ee975f1 100644
--- a/chromadb/errors.py
+++ b/chromadb/errors.py
@@ -1,14 +1,66 @@
-class NoDatapointsException(Exception):
- pass
+from abc import ABCMeta, abstractmethod
-class NoIndexException(Exception):
- pass
+class ChromaError(Exception):
+ def code(self):
+ """Return an appropriate HTTP response code for this error"""
+ return 400 # Bad Request
-class InvalidDimensionException(Exception):
- pass
+ def message(self):
+ return ", ".join(self.args)
+ @classmethod
+ @abstractmethod
+    def name(cls):
+ """Return the error name"""
+ pass
-class NotEnoughElementsException(Exception):
- pass
+
+class NoDatapointsException(ChromaError):
+ @classmethod
+ def name(cls):
+ return "NoDatapoints"
+
+
+class NoIndexException(ChromaError):
+ @classmethod
+ def name(cls):
+ return "NoIndex"
+
+
+class InvalidDimensionException(ChromaError):
+ @classmethod
+ def name(cls):
+ return "InvalidDimension"
+
+
+class NotEnoughElementsException(ChromaError):
+ @classmethod
+ def name(cls):
+ return "NotEnoughElements"
+
+
+class IDAlreadyExistsError(ChromaError):
+
+ def code(self):
+ return 409 # Conflict
+
+ @classmethod
+ def name(cls):
+ return "IDAlreadyExists"
+
+
+class DuplicateIDError(ChromaError):
+ @classmethod
+ def name(cls):
+ return "DuplicateID"
+
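+# Lookup table from the serialized error name (see ChromaError.name) to its
+# exception class, so a matching error can be reconstructed from a response.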
+error_types = {
+ "NoDatapoints": NoDatapointsException,
+ "NoIndex": NoIndexException,
+ "InvalidDimension": InvalidDimensionException,
+ "NotEnoughElements": NotEnoughElementsException,
+ "IDAlreadyExists": IDAlreadyExistsError,
+ "DuplicateID": DuplicateIDError,
+}
diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py
index d24515ef8bb..09cf19a58a7 100644
--- a/chromadb/server/fastapi/__init__.py
+++ b/chromadb/server/fastapi/__init__.py
@@ -9,6 +9,7 @@
import chromadb
import chromadb.server
from chromadb.errors import (
+ ChromaError,
NoDatapointsException,
InvalidDimensionException,
NotEnoughElementsException,
@@ -45,6 +46,10 @@ def use_route_names_as_operation_ids(app: _FastAPI) -> None:
async def catch_exceptions_middleware(request: Request, call_next):
try:
return await call_next(request)
+ except ChromaError as e:
+ return JSONResponse(content={"error": e.name(),
+ "message": e.message()},
+ status_code=e.code())
except Exception as e:
logger.exception(e)
return JSONResponse(content={"error": repr(e)}, status_code=500)
@@ -86,6 +91,9 @@ def __init__(self, settings):
self.router.add_api_route(
"/api/v1/collections/{collection_name}/update", self.update, methods=["POST"]
)
+ self.router.add_api_route(
+ "/api/v1/collections/{collection_name}/upsert", self.upsert, methods=["POST"]
+ )
self.router.add_api_route(
"/api/v1/collections/{collection_name}/get", self.get, methods=["POST"]
)
@@ -180,6 +188,16 @@ def update(self, collection_name: str, add: UpdateEmbedding):
metadatas=add.metadatas,
)
+ def upsert(self, collection_name: str, upsert: AddEmbedding):
+ return self._api._upsert(
+ collection_name=collection_name,
+ ids=upsert.ids,
+ embeddings=upsert.embeddings,
+ documents=upsert.documents,
+ metadatas=upsert.metadatas,
+ increment_index=upsert.increment_index,
+ )
+
def get(self, collection_name, get: GetEmbedding):
return self._api._get(
collection_name=collection_name,
@@ -207,22 +225,15 @@ def reset(self):
return self._api.reset()
def get_nearest_neighbors(self, collection_name, query: QueryEmbedding):
- try:
- nnresult = self._api._query(
- collection_name=collection_name,
- where=query.where,
- where_document=query.where_document,
- query_embeddings=query.query_embeddings,
- n_results=query.n_results,
- include=query.include,
- )
- return nnresult
- except NoDatapointsException as e:
- raise HTTPException(status_code=500, detail=str(e))
- except InvalidDimensionException as e:
- raise HTTPException(status_code=500, detail=str(e))
- except NotEnoughElementsException as e:
- raise HTTPException(status_code=500, detail=str(e))
+ nnresult = self._api._query(
+ collection_name=collection_name,
+ where=query.where,
+ where_document=query.where_document,
+ query_embeddings=query.query_embeddings,
+ n_results=query.n_results,
+ include=query.include,
+ )
+ return nnresult
def raw_sql(self, raw_sql: RawSql):
return self._api.raw_sql(raw_sql.raw_sql)
diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py
new file mode 100644
index 00000000000..88729bb6ffa
--- /dev/null
+++ b/chromadb/test/conftest.py
@@ -0,0 +1,105 @@
+from chromadb.config import Settings
+from chromadb import Client
+from chromadb.api import API
+import chromadb.server.fastapi
+from requests.exceptions import ConnectionError
+import hypothesis
+import tempfile
+import os
+import uvicorn
+import time
+from multiprocessing import Process
+import pytest
+from typing import Generator, List, Tuple, Callable
+import shutil
+
+hypothesis.settings.register_profile(
+ "dev", deadline=10000, suppress_health_check=[
+ hypothesis.HealthCheck.data_too_large,
+ hypothesis.HealthCheck.large_base_example
+ ]
+)
+hypothesis.settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev"))
+
+
+def _run_server():
+ """Run a Chroma server locally"""
+ settings = Settings(
+ chroma_api_impl="local",
+ chroma_db_impl="duckdb",
+ persist_directory=tempfile.gettempdir() + "/test_server",
+ )
+ server = chromadb.server.fastapi.FastAPI(settings)
+ uvicorn.run(server.app(), host="0.0.0.0", port=6666, log_level="error")
+
+
+def _await_server(api, attempts=0):
+ try:
+ api.heartbeat()
+ except ConnectionError as e:
+ if attempts > 10:
+ raise e
+ else:
+ time.sleep(2)
+ _await_server(api, attempts + 1)
+
+
+def fastapi() -> Generator[API, None, None]:
+ """Fixture generator that launches a server in a separate process, and yields a
+ fastapi client connect to it"""
+ proc = Process(target=_run_server, args=(), daemon=True)
+ proc.start()
+ api = chromadb.Client(
+ Settings(
+ chroma_api_impl="rest", chroma_server_host="localhost", chroma_server_http_port="6666"
+ )
+ )
+ _await_server(api)
+ yield api
+ proc.kill()
+
+
+def duckdb() -> Generator[API, None, None]:
+ """Fixture generator for duckdb"""
+ yield Client(
+ Settings(
+ chroma_api_impl="local",
+ chroma_db_impl="duckdb",
+ persist_directory=tempfile.gettempdir(),
+ )
+ )
+
+
+def duckdb_parquet() -> Generator[API, None, None]:
+ """Fixture generator for duckdb+parquet"""
+
+ save_path = tempfile.gettempdir() + "/tests"
+ yield Client(
+ Settings(
+ chroma_api_impl="local",
+ chroma_db_impl="duckdb+parquet",
+ persist_directory=save_path,
+ )
+ )
+ if os.path.exists(save_path):
+ shutil.rmtree(save_path)
+
+
+def integration_api() -> Generator[API, None, None]:
+ """Fixture generator for returning a client configured via environmenet
+ variables, intended for externally configured integration tests
+ """
+ yield chromadb.Client()
+
+
+def fixtures() -> List[Callable[[], Generator[API, None, None]]]:
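+    """Select which API fixture generators to run tests against, based on environment variables."""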
+ api_fixtures = [duckdb, duckdb_parquet, fastapi]
+ if "CHROMA_INTEGRATION_TEST" in os.environ:
+ api_fixtures.append(integration_api)
+ if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ:
+ api_fixtures = [integration_api]
+ return api_fixtures
+
+@pytest.fixture(scope="module", params=fixtures())
+def api(request) -> Generator[API, None, None]:
+ yield next(request.param())
diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py
new file mode 100644
index 00000000000..0ba2f2e74a0
--- /dev/null
+++ b/chromadb/test/property/invariants.py
@@ -0,0 +1,251 @@
+import math
+from chromadb.test.property.strategies import RecordSet
+from typing import Callable, Optional, Union, List, TypeVar
+from typing_extensions import Literal
+import numpy as np
+from chromadb.api import types
+from chromadb.api.models.Collection import Collection
+from hypothesis import note
+from hypothesis.errors import InvalidArgument
+
+T = TypeVar("T")
+
+
+def maybe_wrap(value: Optional[Union[T, List[T]]]) -> Optional[List[T]]:
+ """Wrap a value in a list if it is not a list"""
+ if value is None:
+ return None
+ elif isinstance(value, List):
+ return value
+ else:
+ return [value]
+
+
+def wrap_all(embeddings: RecordSet) -> RecordSet:
+ """Ensure that an embedding set has lists for all its values"""
+
+ if embeddings["embeddings"] is None:
+ embedding_list = None
+ elif isinstance(embeddings["embeddings"], list):
+ if len(embeddings["embeddings"]) > 0:
+ if isinstance(embeddings["embeddings"][0], list):
+ embedding_list = embeddings["embeddings"]
+ else:
+ embedding_list = [embeddings["embeddings"]]
+ else:
+ embedding_list = []
+ else:
+ raise InvalidArgument("embeddings must be a list, list of lists, or None")
+
+ return {
+ "ids": maybe_wrap(embeddings["ids"]), # type: ignore
+ "documents": maybe_wrap(embeddings["documents"]), # type: ignore
+ "metadatas": maybe_wrap(embeddings["metadatas"]), # type: ignore
+ "embeddings": embedding_list,
+ }
+
+
+def count(collection: Collection, embeddings: RecordSet):
+ """The given collection count is equal to the number of embeddings"""
+ count = collection.count()
+ embeddings = wrap_all(embeddings)
+ assert count == len(embeddings["ids"])
+
+
+def _field_matches(
+ collection: Collection,
+ embeddings: RecordSet,
+ field_name: Union[Literal["documents"], Literal["metadatas"]],
+):
+ """
+ The actual embedding field is equal to the expected field
+ field_name: one of [documents, metadatas]
+ """
+ result = collection.get(ids=embeddings["ids"], include=[field_name])
+    # The API does not guarantee results in input order (see test_out_of_order_ids
+    # in test_add.py), so sort the results by id here to match the input order
+ embedding_id_to_index = {id: i for i, id in enumerate(embeddings["ids"])}
+ actual_field = result[field_name]
+    # This assert should never fire: if we include metadatas/documents, the result
+    # will be [None, None, ...] when there is no metadata, never just None.
+ assert actual_field is not None
+ actual_field = sorted(
+ enumerate(actual_field),
+ key=lambda index_and_field_value: embedding_id_to_index[
+ result["ids"][index_and_field_value[0]]
+ ],
+ )
+ actual_field = [field_value for _, field_value in actual_field]
+
+ expected_field = embeddings[field_name]
+ if expected_field is None:
+        # Since a RecordSet is the user input, convert the expected field to
+        # a List, since that's what the API returns: one None per entry
+ expected_field = [None] * len(embeddings["ids"])
+ assert actual_field == expected_field
+
+
+def ids_match(collection: Collection, embeddings: RecordSet):
+ """The actual embedding ids is equal to the expected ids"""
+ embeddings = wrap_all(embeddings)
+ actual_ids = collection.get(ids=embeddings["ids"], include=[])["ids"]
+    # The API does not guarantee results in input order (see test_out_of_order_ids
+    # in test_add.py), so sort the ids here to match the input order
+ embedding_id_to_index = {id: i for i, id in enumerate(embeddings["ids"])}
+ actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id])
+ assert actual_ids == embeddings["ids"]
+
+
+def metadatas_match(collection: Collection, embeddings: RecordSet):
+ """The actual embedding metadata is equal to the expected metadata"""
+ embeddings = wrap_all(embeddings)
+ _field_matches(collection, embeddings, "metadatas")
+
+
+def documents_match(collection: Collection, embeddings: RecordSet):
+ """The actual embedding documents is equal to the expected documents"""
+ embeddings = wrap_all(embeddings)
+ _field_matches(collection, embeddings, "documents")
+
+
+def no_duplicates(collection: Collection):
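+    """No ID in the collection appears more than once"""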
+ ids = collection.get()["ids"]
+ assert len(ids) == len(set(ids))
+
+
+# These distance functions match hnswlib's definitions
+# This epsilon is used to prevent division by zero; the value is the same as hnswlib's
+# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238
+NORM_EPS = 1e-30
+distance_functions = {
+ "l2": lambda x, y: np.linalg.norm(x - y) ** 2,
+ "cosine": lambda x, y: 1
+ - np.dot(x, y) / ((np.linalg.norm(x) + NORM_EPS) * (np.linalg.norm(y) + NORM_EPS)),
+ "ip": lambda x, y: 1 - np.dot(x, y),
+}
+
+
+def _exact_distances(
+ query: types.Embeddings,
+ targets: types.Embeddings,
+ distance_fn: Callable = lambda x, y: np.linalg.norm(x - y) ** 2,
+):
+ """Return the ordered indices and distances from each query to each target"""
+ np_query = np.array(query)
+ np_targets = np.array(targets)
+
+ # Compute the distance between each query and each target, using the distance function
+ distances = np.apply_along_axis(
+ lambda query: np.apply_along_axis(distance_fn, 1, np_targets, query),
+ 1,
+ np_query,
+ )
+ # Sort the distances and return the indices
+ return np.argsort(distances), distances
+
+
+def ann_accuracy(
+ collection: Collection,
+ record_set: RecordSet,
+ n_results: int = 1,
+ min_recall: float = 0.99,
+ embedding_function: Optional[types.EmbeddingFunction] = None,
+):
+ """Validate that the API performs nearest_neighbor searches correctly"""
+ record_set = wrap_all(record_set)
+
+ if len(record_set["ids"]) == 0:
+ return # nothing to test here
+
+ embeddings = record_set["embeddings"]
+ have_embeddings = embeddings is not None and len(embeddings) > 0
+ if not have_embeddings:
+ assert embedding_function is not None
+ assert record_set["documents"] is not None
+ # Compute the embeddings for the documents
+ embeddings = embedding_function(record_set["documents"])
+
+ # l2 is the default distance function
+ distance_function = distance_functions["l2"]
+ accuracy_threshold = 1e-6
+ if "hnsw:space" in collection.metadata:
+ space = collection.metadata["hnsw:space"]
+ # TODO: ip and cosine are numerically unstable in HNSW.
+        # The higher the dimensionality, the more noise is introduced, since each float
+        # element of the vector has noise added, which is then included in all
+        # normalization calculations. Higher dimensions therefore mean more noise and more error.
+ dim = len(embeddings[0])
+ accuracy_threshold = accuracy_threshold * math.pow(10, int(math.log10(dim)))
+
+ if space == "cosine":
+ distance_function = distance_functions["cosine"]
+
+ if space == "ip":
+ distance_function = distance_functions["ip"]
+
+ # Perform exact distance computation
+ indices, distances = _exact_distances(
+ embeddings, embeddings, distance_fn=distance_function
+ )
+
+ query_results = collection.query(
+ query_embeddings=record_set["embeddings"],
+ query_texts=record_set["documents"] if not have_embeddings else None,
+ n_results=n_results,
+ include=["embeddings", "documents", "metadatas", "distances"],
+ )
+
+ # Dict of ids to indices
+ id_to_index = {id: i for i, id in enumerate(record_set["ids"])}
+ missing = 0
+ for i, (indices_i, distances_i) in enumerate(zip(indices, distances)):
+ expected_ids = np.array(record_set["ids"])[indices_i[:n_results]]
+ missing += len(set(expected_ids) - set(query_results["ids"][i]))
+
+ # For each id in the query results, find the index in the embeddings set
+ # and assert that the embeddings are the same
+ for j, id in enumerate(query_results["ids"][i]):
+ # This may be because the true nth nearest neighbor didn't get returned by the ANN query
+ unexpected_id = id not in expected_ids
+ index = id_to_index[id]
+
+ correct_distance = np.allclose(
+ distances_i[index],
+ query_results["distances"][i][j],
+ atol=accuracy_threshold,
+ )
+ if unexpected_id:
+                # If the ID is unexpected, but the distance is correct, then we
+ # have a duplicate in the data. In this case, we should not reduce recall.
+ if correct_distance:
+ missing -= 1
+ else:
+ continue
+ else:
+ assert correct_distance
+
+ assert np.allclose(embeddings[index], query_results["embeddings"][i][j])
+ if record_set["documents"] is not None:
+ assert (
+ record_set["documents"][index] == query_results["documents"][i][j]
+ )
+ if record_set["metadatas"] is not None:
+ assert (
+ record_set["metadatas"][index] == query_results["metadatas"][i][j]
+ )
+
+ size = len(record_set["ids"])
+ recall = (size - missing) / size
+
+ try:
+ note(
+ f"recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}"
+ )
+ except InvalidArgument:
+ pass # it's ok if we're running outside hypothesis
+
+ assert recall >= min_recall
+
+ # Ensure that the query results are sorted by distance
+ for distance_result in query_results["distances"]:
+ assert np.allclose(np.sort(distance_result), distance_result)
diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py
new file mode 100644
index 00000000000..f862731cfdf
--- /dev/null
+++ b/chromadb/test/property/strategies.py
@@ -0,0 +1,461 @@
+import hashlib
+import hypothesis
+import hypothesis.strategies as st
+from typing import Optional, Callable, List, Dict, Union
+from typing_extensions import TypedDict
+import hypothesis.extra.numpy as npst
+import numpy as np
+import chromadb.api.types as types
+import re
+from hypothesis.strategies._internal.strategies import SearchStrategy
+from hypothesis.errors import InvalidDefinition
+
+from dataclasses import dataclass
+
+# Set the random seed for reproducibility
+np.random.seed(0) # unnecessary, hypothesis does this for us
+
+# See Hypothesis documentation for creating strategies at
+# https://hypothesis.readthedocs.io/en/latest/data.html
+
+# NOTE: Because these strategies are used in state machines, we need to
+# work around an issue with state machines, in which strategies that frequently
+# are marked as invalid (i.e. through the use of `assume` or `.filter`) can cause the
+# state machine tests to fail with a hypothesis.errors.Unsatisfiable.
+
+# Ultimately this is because the entire state machine is run as a single Hypothesis
+# example, which ends up drawing from the same strategies an enormous number of times.
+# Whenever a strategy marks itself as invalid, Hypothesis tries to start the entire
+# state machine run over. See https://github.com/HypothesisWorks/hypothesis/issues/3618
+
+# Because strategy generation is all interrelated, seemingly small changes (especially
+# ones called early in a test) can have an outsized effect. Generating lists with
+# unique=True, or dictionaries with a min size seems especially bad.
+
+# Please make changes to these strategies incrementally, testing to make sure they don't
+# start generating unsatisfiable examples.
+
+test_hnsw_config = {
+ "hnsw:construction_ef": 128,
+ "hnsw:search_ef": 128,
+ "hnsw:M": 128,
+}
+
+
+class RecordSet(TypedDict):
+ """
+ A generated set of embeddings, ids, metadatas, and documents that
+ represent what a user would pass to the API.
+ """
+
+ ids: Union[types.ID, List[types.ID]]
+ embeddings: Optional[Union[types.Embeddings, types.Embedding]]
+ metadatas: Optional[Union[List[types.Metadata], types.Metadata]]
+ documents: Optional[Union[List[types.Document], types.Document]]
+
+
+# TODO: support arbitrary text everywhere so we don't SQL-inject ourselves.
+# TODO: support empty strings everywhere
+sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
+safe_text = st.text(alphabet=sql_alphabet, min_size=1)
+
+# Workaround for FastAPI json encoding peculiarities
+# https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104
+safe_text = safe_text.filter(lambda s: not s.startswith("_sa"))
+
+safe_integers = st.integers(
+ min_value=-(2**31), max_value=2**31 - 1
+) # TODO: handle longs
+safe_floats = st.floats(
+ allow_infinity=False, allow_nan=False, allow_subnormal=False
+) # TODO: handle infinity and NAN
+
+safe_values = [safe_text, safe_integers, safe_floats]
+
+
+def one_or_both(strategy_a, strategy_b):
+ return st.one_of(
+ st.tuples(strategy_a, strategy_b),
+ st.tuples(strategy_a, st.none()),
+ st.tuples(st.none(), strategy_b),
+ )
+
+
+# Temporarily generate only these to avoid SQL formatting issues.
+legal_id_characters = (
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./+"
+)
+
+float_types = [np.float16, np.float32, np.float64]
+int_types = [np.int16, np.int32, np.int64] # TODO: handle int types
+
+
+@st.composite
+def collection_name(draw) -> str:
+ _collection_name_re = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{1,60}[a-zA-Z0-9]$")
+ _ipv4_address_re = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}$")
+ _two_periods_re = re.compile(r"\.\.")
+
+ name = draw(st.from_regex(_collection_name_re))
+ hypothesis.assume(not _ipv4_address_re.match(name))
+ hypothesis.assume(not _two_periods_re.search(name))
+
+ return name
+
+
+collection_metadata = st.one_of(
+ st.none(), st.dictionaries(safe_text, st.one_of(*safe_values))
+)
+
+
+# TODO: Use a hypothesis strategy while maintaining embedding uniqueness
+# Or handle duplicate embeddings within a known epsilon
+def create_embeddings(dim: int, count: int, dtype: np.dtype) -> types.Embeddings:
+ return (
+ np.random.uniform(
+ low=-1.0,
+ high=1.0,
+ size=(count, dim),
+ )
+ .astype(dtype)
+ .tolist()
+ )
+
+
+class hashing_embedding_function(types.EmbeddingFunction):
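+    """Deterministic embedding function for tests: each document is mapped to a
+    fixed-dimension vector derived from its SHA-256 hash, so results are
+    reproducible without a real embedding model."""
+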
+ def __init__(self, dim: int, dtype: np.dtype) -> None:
+ self.dim = dim
+ self.dtype = dtype
+
+ def __call__(self, texts: types.Documents) -> types.Embeddings:
+ # Hash the texts and convert to hex strings
+ hashed_texts = [
+ list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in texts
+ ]
+ # Pad with repetition, or truncate the hex strings to the desired dimension
+ padded_texts = [
+ text * (self.dim // len(text)) + text[: self.dim % len(text)]
+ for text in hashed_texts
+ ]
+
+ # Convert the hex strings to dtype
+ return np.array(
+ [[int(char, 16) / 15.0 for char in text] for text in padded_texts],
+ dtype=self.dtype,
+ ).tolist()
+
+
+def embedding_function_strategy(
+ dim: int, dtype: np.dtype
+) -> st.SearchStrategy[types.EmbeddingFunction]:
+ return st.just(hashing_embedding_function(dim, dtype))
+
+
+@dataclass
+class Collection:
+ name: str
+ metadata: Optional[types.Metadata]
+ dimension: int
+ dtype: np.dtype
+ known_metadata_keys: Dict[str, st.SearchStrategy]
+ known_document_keywords: List[str]
+ has_documents: bool = False
+ has_embeddings: bool = False
+ embedding_function: Optional[types.EmbeddingFunction] = None
+
+
+@st.composite
+def collections(
+ draw,
+ add_filterable_data=False,
+ with_hnsw_params=False,
+ has_embeddings: Optional[bool] = None,
+ has_documents: Optional[bool] = None,
+) -> Collection:
+ """Strategy to generate a Collection object. If add_filterable_data is True, then known_metadata_keys and known_document_keywords will be populated with consistent data."""
+
+ assert not ((has_embeddings is False) and (has_documents is False))
+
+ name = draw(collection_name())
+ metadata = draw(collection_metadata)
+ dimension = draw(st.integers(min_value=2, max_value=2048))
+ dtype = draw(st.sampled_from(float_types))
+
+ if with_hnsw_params:
+ if metadata is None:
+ metadata = {}
+ metadata.update(test_hnsw_config)
+ # Sometimes, select a space at random
+ if draw(st.booleans()):
+            # TODO: pull the distance functions from a source of truth that doesn't
+            # live in tests once https://github.com/chroma-core/issues/issues/61 lands
+ metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
+
+ known_metadata_keys = {}
+ if add_filterable_data:
+ while len(known_metadata_keys) < 5:
+ key = draw(safe_text)
+ known_metadata_keys[key] = draw(st.sampled_from(safe_values))
+
+ if has_documents is None:
+ has_documents = draw(st.booleans())
+ if has_documents and add_filterable_data:
+ known_document_keywords = draw(st.lists(safe_text, min_size=5, max_size=5))
+ else:
+ known_document_keywords = []
+
+ if not has_documents:
+ has_embeddings = True
+ else:
+ if has_embeddings is None:
+ has_embeddings = draw(st.booleans())
+
+ embedding_function = draw(embedding_function_strategy(dimension, dtype))
+
+ return Collection(
+ name=name,
+ metadata=metadata,
+ dimension=dimension,
+ dtype=dtype,
+ known_metadata_keys=known_metadata_keys,
+ has_documents=has_documents,
+ known_document_keywords=known_document_keywords,
+ has_embeddings=has_embeddings,
+ embedding_function=embedding_function,
+ )
+
+
+@st.composite
+def metadata(draw, collection: Collection):
+ """Strategy for generating metadata that could be a part of the given collection"""
+ # First draw a random dictionary.
+ md = draw(st.dictionaries(safe_text, st.one_of(*safe_values)))
+ # Then, remove keys that overlap with the known keys for the coll
+ # to avoid type errors when comparing.
+ if collection.known_metadata_keys:
+ for key in collection.known_metadata_keys.keys():
+ if key in md:
+ del md[key]
+ # Finally, add in some of the known keys for the collection
+ md.update(
+ draw(st.fixed_dictionaries({}, optional=collection.known_metadata_keys))
+ )
+ return md
+
+
+@st.composite
+def document(draw, collection: Collection):
+ """Strategy for generating documents that could be a part of the given collection"""
+
+ if collection.known_document_keywords:
+ known_words_st = st.sampled_from(collection.known_document_keywords)
+ else:
+ known_words_st = st.text(min_size=1)
+
+ random_words_st = st.text(min_size=1)
+ words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1))
+ return " ".join(words)
+
+
+@st.composite
+def record(draw, collection: Collection, id_strategy=safe_text):
+ md = draw(metadata(collection))
+
+ if collection.has_embeddings:
+ embedding = create_embeddings(collection.dimension, 1, collection.dtype)[0]
+ else:
+ embedding = None
+ if collection.has_documents:
+ doc = draw(document(collection))
+ else:
+ doc = None
+
+ return {
+ "id": draw(id_strategy),
+ "embedding": embedding,
+ "metadata": md,
+ "document": doc,
+ }
+
+
+@st.composite
+def recordsets(
+ draw,
+ collection_strategy=collections(),
+ id_strategy=safe_text,
+ min_size=1,
+ max_size=50,
+) -> RecordSet:
+ collection = draw(collection_strategy)
+
+ records = draw(
+ st.lists(record(collection, id_strategy), min_size=min_size, max_size=max_size)
+ )
+
+ records = {r["id"]: r for r in records}.values() # Remove duplicates
+
+ ids = [r["id"] for r in records]
+ embeddings = (
+ [r["embedding"] for r in records] if collection.has_embeddings else None
+ )
+ metadatas = [r["metadata"] for r in records]
+ documents = [r["document"] for r in records] if collection.has_documents else None
+
+ # in the case where we have a single record, sometimes exercise
+ # the code that handles individual values rather than lists
+ if len(records) == 1:
+ if draw(st.booleans()):
+ ids = ids[0]
+ if collection.has_embeddings and draw(st.booleans()):
+ embeddings = embeddings[0]
+ if draw(st.booleans()):
+ metadatas = metadatas[0]
+ if collection.has_documents and draw(st.booleans()):
+ documents = documents[0]
+
+ return {
+ "ids": ids,
+ "embeddings": embeddings,
+ "metadatas": metadatas,
+ "documents": documents,
+ }
+
+
+# This class is mostly cloned from hypothesis.stateful.RuleStrategy,
+# but always runs all the rules, instead of using a FeatureStrategy to
+# enable/disable rules. Disabled rules cause the entire test to be marked invalid and,
+# combined with the complexity of our other strategies, lead to an
+# unacceptably high incidence of hypothesis.errors.Unsatisfiable.
+class DeterministicRuleStrategy(SearchStrategy):
+ def __init__(self, machine):
+ super().__init__()
+ self.machine = machine
+ self.rules = list(machine.rules())
+
+ # The order is a bit arbitrary. Primarily we're trying to group rules
+ # that write to the same location together, and to put rules with no
+ # target first as they have less effect on the structure. We order from
+ # fewer to more arguments on grounds that it will plausibly need less
+ # data. This probably won't work especially well and we could be
+ # smarter about it, but it's better than just doing it in definition
+ # order.
+ self.rules.sort(
+ key=lambda rule: (
+ sorted(rule.targets),
+ len(rule.arguments),
+ rule.function.__name__,
+ )
+ )
+
+ def __repr__(self):
+ return "{}(machine={}({{...}}))".format(
+ self.__class__.__name__,
+ self.machine.__class__.__name__,
+ )
+
+ def do_draw(self, data):
+ if not any(self.is_valid(rule) for rule in self.rules):
+ msg = f"No progress can be made from state {self.machine!r}"
+ raise InvalidDefinition(msg) from None
+
+ rule = data.draw(st.sampled_from([r for r in self.rules if self.is_valid(r)]))
+ argdata = data.draw(rule.arguments_strategy)
+ return (rule, argdata)
+
+ def is_valid(self, rule):
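+        """A rule is valid only if all its preconditions hold and every bundle it draws from is non-empty."""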
+ if not all(precond(self.machine) for precond in rule.preconditions):
+ return False
+
+ for b in rule.bundles:
+ bundle = self.machine.bundle(b.name)
+ if not bundle:
+ return False
+ return True
+
+
+@st.composite
+def where_clause(draw, collection):
+ """Generate a filter that could be used in a query against the given collection"""
+
+ known_keys = sorted(collection.known_metadata_keys.keys())
+
+ key = draw(st.sampled_from(known_keys))
+ value = draw(collection.known_metadata_keys[key])
+
+ legal_ops = [None, "$eq", "$ne"]
+ if not isinstance(value, str):
+ legal_ops = ["$gt", "$lt", "$lte", "$gte"] + legal_ops
+
+ op = draw(st.sampled_from(legal_ops))
+
+ if op is None:
+ return {key: value}
+ else:
+ return {key: {op: value}}
+
+
+@st.composite
+def where_doc_clause(draw, collection):
+ """Generate a where_document filter that could be used against the given collection"""
+ if collection.known_document_keywords:
+ word = draw(st.sampled_from(collection.known_document_keywords))
+ else:
+ word = draw(safe_text)
+ return {"$contains": word}
+
+
+@st.composite
+def binary_operator_clause(draw, base_st):
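+    """Combine two draws from base_st under a randomly chosen $and/$or operator."""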
+ op = draw(st.sampled_from(["$and", "$or"]))
+ return {op: [draw(base_st), draw(base_st)]}
+
+
+@st.composite
+def recursive_where_clause(draw, collection):
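+    """Generate arbitrarily nested where filters by recursively combining base clauses with $and/$or."""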
+ base_st = where_clause(collection)
+ return draw(st.recursive(base_st, binary_operator_clause))
+
+
+@st.composite
+def recursive_where_doc_clause(draw, collection):
+ base_st = where_doc_clause(collection)
+ return draw(st.recursive(base_st, binary_operator_clause))
+
+
+class Filter(TypedDict):
+ where: Optional[Dict[str, Union[str, int, float]]]
+ ids: Optional[Union[str, List[str]]]
+ where_document: Optional[types.WhereDocument]
+
+
+@st.composite
+def filters(
+ draw,
+ collection_st: st.SearchStrategy[Collection],
+ recordset_st: st.SearchStrategy[RecordSet],
+ include_all_ids=False,
+) -> Filter:
+ collection = draw(collection_st)
+ recordset = draw(recordset_st)
+
+ where_clause = draw(st.one_of(st.none(), recursive_where_clause(collection)))
+ where_document_clause = draw(
+ st.one_of(st.none(), recursive_where_doc_clause(collection))
+ )
+
+ ids = recordset["ids"]
+ # Record sets can be a value instead of a list of values if there is only one record
+ if isinstance(ids, str):
+ ids = [ids]
+
+ if not include_all_ids:
+ ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids))))
+ if ids is not None:
+ # Remove duplicates since hypothesis samples with replacement
+ ids = list(set(ids))
+
+ # Test both the single value list and the unwrapped single value case
+ if ids is not None and len(ids) == 1 and draw(st.booleans()):
+ ids = ids[0]
+
+ return {"where": where_clause, "where_document": where_document_clause, "ids": ids}
diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py
new file mode 100644
index 00000000000..be31e1a63a7
--- /dev/null
+++ b/chromadb/test/property/test_add.py
@@ -0,0 +1,74 @@
+import pytest
+import hypothesis.strategies as st
+from hypothesis import given, settings
+from chromadb.api import API
+import chromadb.test.property.strategies as strategies
+import chromadb.test.property.invariants as invariants
+
+collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")
+
+
+@given(collection=collection_st, embeddings=strategies.recordsets(collection_st))
+@settings(deadline=None)
+def test_add(
+ api: API, collection: strategies.Collection, embeddings: strategies.RecordSet
+):
+ api.reset()
+
+ # TODO: Generative embedding functions
+ coll = api.create_collection(
+ name=collection.name,
+ metadata=collection.metadata,
+ embedding_function=collection.embedding_function,
+ )
+ coll.add(**embeddings)
+
+ embeddings = invariants.wrap_all(embeddings)
+ invariants.count(coll, embeddings)
+ n_results = max(1, (len(embeddings["ids"]) // 10))
+ invariants.ann_accuracy(
+ coll,
+ embeddings,
+ n_results=n_results,
+ embedding_function=collection.embedding_function,
+ )
+
+
+# TODO: This test fails right now because the ids are not sorted by the input order
+@pytest.mark.xfail(
+ reason="This is expected to fail right now. We should change the API to sort the \
+ ids by input order."
+)
+def test_out_of_order_ids(api: API):
+ api.reset()
+ ooo_ids = [
+ "40",
+ "05",
+ "8",
+ "6",
+ "10",
+ "01",
+ "00",
+ "3",
+ "04",
+ "20",
+ "02",
+ "9",
+ "30",
+ "11",
+ "13",
+ "2",
+ "0",
+ "7",
+ "06",
+ "5",
+ "50",
+ "12",
+ "03",
+ "4",
+ "1",
+ ]
+ coll = api.create_collection("test", embedding_function=lambda x: [1, 2, 3])
+ coll.add(ids=ooo_ids, embeddings=[[1, 2, 3] for _ in range(len(ooo_ids))])
+ get_ids = coll.get(ids=ooo_ids)["ids"]
+ assert get_ids == ooo_ids
diff --git a/chromadb/test/property/test_collections.py b/chromadb/test/property/test_collections.py
new file mode 100644
index 00000000000..de1cde067ff
--- /dev/null
+++ b/chromadb/test/property/test_collections.py
@@ -0,0 +1,122 @@
+import pytest
+import logging
+import hypothesis.strategies as st
+from typing import List
+import chromadb
+from chromadb.api import API
+from chromadb.api.models.Collection import Collection
+import chromadb.test.property.strategies as strategies
+from hypothesis.stateful import (
+ Bundle,
+ RuleBasedStateMachine,
+ rule,
+ initialize,
+ multiple,
+ consumes,
+ run_state_machine_as_test,
+)
+
+
+class CollectionStateMachine(RuleBasedStateMachine):
+ def __init__(self, api):
+ super().__init__()
+ self.existing = set()
+ self.model = {}
+ self.api = api
+
+ collections = Bundle("collections")
+
+ @initialize()
+ def initialize(self):
+ self.api.reset()
+ self.existing = set()
+
+ @rule(target=collections, coll=strategies.collections())
+ def create_coll(self, coll):
+ if coll.name in self.existing:
+ with pytest.raises(Exception):
+ c = self.api.create_collection(name=coll.name,
+ metadata=coll.metadata,
+ embedding_function=coll.embedding_function)
+ return multiple()
+
+ c = self.api.create_collection(name=coll.name,
+ metadata=coll.metadata,
+ embedding_function=coll.embedding_function)
+ self.existing.add(coll.name)
+
+ assert c.name == coll.name
+ assert c.metadata == coll.metadata
+ return coll
+
+ @rule(coll=collections)
+ def get_coll(self, coll):
+ if coll.name in self.existing:
+ c = self.api.get_collection(name=coll.name)
+ assert c.name == coll.name
+ assert c.metadata == coll.metadata
+ else:
+ with pytest.raises(Exception):
+ self.api.get_collection(name=coll.name)
+
+ @rule(coll=consumes(collections))
+ def delete_coll(self, coll):
+ if coll.name in self.existing:
+ self.api.delete_collection(name=coll.name)
+ self.existing.remove(coll.name)
+ else:
+ with pytest.raises(Exception):
+ self.api.delete_collection(name=coll.name)
+
+ with pytest.raises(Exception):
+ self.api.get_collection(name=coll.name)
+
+ @rule()
+ def list_collections(self):
+ colls = self.api.list_collections()
+ assert len(colls) == len(self.existing)
+ for c in colls:
+ assert c.name in self.existing
+
+ @rule(
+ target=collections,
+ coll=st.one_of(consumes(collections), strategies.collections()),
+ )
+ def get_or_create_coll(self, coll):
+ c = self.api.get_or_create_collection(name=coll.name,
+ metadata=coll.metadata,
+ embedding_function=coll.embedding_function)
+ assert c.name == coll.name
+ if coll.metadata is not None:
+ assert c.metadata == coll.metadata
+ self.existing.add(coll.name)
+ return coll
+
+ @rule(
+ target=collections,
+ coll=consumes(collections),
+ new_metadata=strategies.collection_metadata,
+ new_name=st.one_of(st.none(), strategies.collection_name()),
+ )
+ def modify_coll(self, coll, new_metadata, new_name):
+ c = self.api.get_collection(name=coll.name)
+
+ if new_metadata is not None:
+ coll.metadata = new_metadata
+
+ if new_name is not None:
+ self.existing.remove(coll.name)
+ self.existing.add(new_name)
+ coll.name = new_name
+
+ c.modify(metadata=new_metadata, name=new_name)
+ c = self.api.get_collection(name=coll.name)
+
+ assert c.name == coll.name
+ assert c.metadata == coll.metadata
+ return coll
+
+
+def test_collections(caplog, api):
+ caplog.set_level(logging.ERROR)
+ run_state_machine_as_test(lambda: CollectionStateMachine(api))
\ No newline at end of file
diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py
new file mode 100644
index 00000000000..af6d0487aa8
--- /dev/null
+++ b/chromadb/test/property/test_cross_version_persist.py
@@ -0,0 +1,252 @@
+import sys
+import os
+import shutil
+import subprocess
+import tempfile
+from typing import Generator, Tuple
+from hypothesis import given, settings
+import hypothesis.strategies as st
+import pytest
+import json
+from urllib import request
+from chromadb.api import API
+import chromadb.test.property.strategies as strategies
+import chromadb.test.property.invariants as invariants
+from importlib.util import spec_from_file_location, module_from_spec
+from packaging import version as packaging_version
+import re
+import multiprocessing
+from chromadb import Client
+from chromadb.config import Settings
+
+MINIMUM_VERSION = "0.3.20"
+COLLECTION_NAME_LOWERCASE_VERSION = "0.3.21"
+version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")
+
+
+def _patch_uppercase_coll_name(
+ collection: strategies.Collection, embeddings: strategies.RecordSet
+):
+ """Old versions didn't handle uppercase characters in collection names"""
+ collection.name = collection.name.lower()
+
+
+def _patch_empty_dict_metadata(
+ collection: strategies.Collection, embeddings: strategies.RecordSet
+):
+ """Old versions do the wrong thing when metadata is a single empty dict"""
+ if embeddings["metadatas"] == {}:
+ embeddings["metadatas"] = None
+
+
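+# Each entry is (version, patch_fn); the patch is applied when testing any
+# version <= that version (see patch_for_version below).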
+version_patches = [
+ ("0.3.21", _patch_uppercase_coll_name),
+ ("0.3.21", _patch_empty_dict_metadata),
+]
+
+
+def patch_for_version(
+ version, collection: strategies.Collection, embeddings: strategies.RecordSet
+):
+ """Override aspects of the collection and embeddings, before testing, to account for
+ breaking changes in old versions."""
+
+ for patch_version, patch in version_patches:
+ if packaging_version.Version(version) <= packaging_version.Version(
+ patch_version
+ ):
+ patch(collection, embeddings)
+
+
+def versions():
+ """Returns the pinned minimum version and the latest version of chromadb."""
+ url = "https://pypi.org/pypi/chromadb/json"
+ data = json.load(request.urlopen(request.Request(url)))
+ versions = list(data["releases"].keys())
+ # Older versions on pypi contain "devXYZ" suffixes
+ versions = [v for v in versions if version_re.match(v)]
+ versions.sort(key=packaging_version.Version)
+ return [MINIMUM_VERSION, versions[-1]]
+
+
+def configurations(versions):
+ return [
+ (
+ version,
+ Settings(
+ chroma_api_impl="local",
+ chroma_db_impl="duckdb+parquet",
+ persist_directory=tempfile.gettempdir() + "/tests/" + version + "/",
+ ),
+ )
+ for version in versions
+ ]
+
+
+test_old_versions = versions()
+base_install_dir = tempfile.gettempdir() + "/persistence_test_chromadb_versions"
+
+
+# This fixture is not shared with the rest of the tests because it is unique in how it
+# installs the versions of chromadb
+@pytest.fixture(scope="module", params=configurations(test_old_versions))
+def version_settings(request) -> Generator[Tuple[str, Settings], None, None]:
+ configuration = request.param
+ version = configuration[0]
+ install_version(version)
+ yield configuration
+ # Cleanup the installed version
+ path = get_path_to_version_install(version)
+ shutil.rmtree(path)
+ # Cleanup the persisted data
+ data_path = configuration[1].persist_directory
+ if os.path.exists(data_path):
+ shutil.rmtree(data_path)
+
+
+def get_path_to_version_install(version):
+ return base_install_dir + "/" + version
+
+
+def get_path_to_version_library(version):
+ return get_path_to_version_install(version) + "/chromadb/__init__.py"
+
+
+def install_version(version):
+ # Check if already installed
+ version_library = get_path_to_version_library(version)
+ if os.path.exists(version_library):
+ return
+ path = get_path_to_version_install(version)
+ install(f"chromadb=={version}", path)
+
+
+def install(pkg, path):
+ # -q -q to suppress pip output to ERROR level
+ # https://pip.pypa.io/en/stable/cli/pip/#quiet
+ print(f"Installing chromadb version {pkg} to {path}")
+ return subprocess.check_call(
+ [
+ sys.executable,
+ "-m",
+ "pip",
+ "-q",
+ "-q",
+ "install",
+ pkg,
+ "--target={}".format(path),
+ ]
+ )
+
+
+def switch_to_version(version):
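+    """Unload the current chromadb modules and import the installed copy of the given version in their place."""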
+ module_name = "chromadb"
+ # Remove old version from sys.modules, except test modules
+ old_modules = {
+ n: m
+ for n, m in sys.modules.items()
+ if n == module_name or (n.startswith(module_name + "."))
+ }
+ for n in old_modules:
+ del sys.modules[n]
+
+ # Load the target version and override the path to the installed version
+ # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
+ path = get_path_to_version_library(version)
+ sys.path.insert(0, get_path_to_version_install(version))
+ spec = spec_from_file_location(module_name, path)
+ assert spec is not None and spec.loader is not None
+ module = module_from_spec(spec)
+ spec.loader.exec_module(module)
+ assert module.__version__ == version
+ sys.modules[module_name] = module
+ return module
+
+
+def persist_generated_data_with_old_version(
+ version,
+ settings,
+ collection_strategy: strategies.Collection,
+ embeddings_strategy: strategies.RecordSet,
+):
+ old_module = switch_to_version(version)
+ api: API = old_module.Client(settings)
+ api.reset()
+ coll = api.create_collection(
+ name=collection_strategy.name,
+ metadata=collection_strategy.metadata,
+ embedding_function=lambda x: None,
+ )
+ coll.add(**embeddings_strategy)
+    # We can't use the invariants module here because it uses the current version.
+    # Just do some basic sanity checks here, plus manual testing when you break the
+    # new version
+
+ check_embeddings = invariants.wrap_all(embeddings_strategy)
+ # Check count
+ assert coll.count() == len(check_embeddings["embeddings"] or [])
+ # Check ids
+ result = coll.get()
+ actual_ids = result["ids"]
+ embedding_id_to_index = {id: i for i, id in enumerate(check_embeddings["ids"])}
+ actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id])
+ assert actual_ids == check_embeddings["ids"]
+ api.persist()
+ del api
+
+
+# Since we can't pickle the embedding function, we always generate record sets with embeddings
+collection_st = st.shared(
+ strategies.collections(with_hnsw_params=True, has_embeddings=True), key="coll"
+)
+
+
+@given(
+ collection_strategy=collection_st,
+ embeddings_strategy=strategies.recordsets(collection_st),
+)
+@pytest.mark.skipif(
+ sys.version_info.major < 3
+ or (sys.version_info.major == 3 and sys.version_info.minor <= 7),
+ reason="The mininum supported versions of chroma do not work with python <= 3.7",
+)
+@settings(deadline=None)
+def test_cycle_versions(
+ version_settings: Tuple[str, Settings],
+ collection_strategy: strategies.Collection,
+ embeddings_strategy: strategies.RecordSet,
+):
+    # Test backwards compatibility
+    # For the current version, ensure that we can load a collection from
+    # the previous versions
+ version, settings = version_settings
+
+ patch_for_version(version, collection_strategy, embeddings_strategy)
+
+ # Can't pickle a function, and we won't need them
+ collection_strategy.embedding_function = None
+ collection_strategy.known_metadata_keys = {}
+
+ # Run the task in a separate process to avoid polluting the current process
+ # with the old version. Using spawn instead of fork to avoid sharing the
+ # current process memory which would cause the old version to be loaded
+ ctx = multiprocessing.get_context("spawn")
+ p = ctx.Process(
+ target=persist_generated_data_with_old_version,
+ args=(version, settings, collection_strategy, embeddings_strategy),
+ )
+ p.start()
+ p.join()
+
+ # Switch to the current version (local working directory) and check the invariants
+ # are preserved for the collection
+ api = Client(settings)
+ coll = api.get_collection(
+ name=collection_strategy.name, embedding_function=lambda x: None
+ )
+ invariants.count(coll, embeddings_strategy)
+ invariants.metadatas_match(coll, embeddings_strategy)
+ invariants.documents_match(coll, embeddings_strategy)
+ invariants.ids_match(coll, embeddings_strategy)
+ invariants.ann_accuracy(coll, embeddings_strategy)
diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py
new file mode 100644
index 00000000000..0a9a682e5d8
--- /dev/null
+++ b/chromadb/test/property/test_embeddings.py
@@ -0,0 +1,264 @@
+import numpy as np
+import pytest
+import logging
+from hypothesis import given
+import hypothesis.strategies as st
+from typing import Set, List, Optional, cast
+from dataclasses import dataclass
+import chromadb.errors as errors
+import chromadb
+from chromadb.api import API
+from chromadb.api.models.Collection import Collection
+import chromadb.test.property.strategies as strategies
+from hypothesis.stateful import (
+ Bundle,
+ RuleBasedStateMachine,
+ rule,
+ initialize,
+ precondition,
+ consumes,
+ run_state_machine_as_test,
+ multiple,
+ invariant,
+)
+from collections import defaultdict
+import chromadb.test.property.invariants as invariants
+import hypothesis
+
+
+traces = defaultdict(lambda: 0)
+
+
+def trace(key):
+ global traces
+ traces[key] += 1
+
+
+def print_traces():
+ global traces
+ for key, value in traces.items():
+ print(f"{key}: {value}")
+
+
+dtype_shared_st = st.shared(st.sampled_from(strategies.float_types), key="dtype")
+dimension_shared_st = st.shared(
+ st.integers(min_value=2, max_value=2048), key="dimension"
+)
+
+
+@dataclass
+class EmbeddingStateMachineStates:
+ initialize = "initialize"
+ add_embeddings = "add_embeddings"
+ delete_by_ids = "delete_by_ids"
+ update_embeddings = "update_embeddings"
+ upsert_embeddings = "upsert_embeddings"
+
+
+collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")
+
+
+class EmbeddingStateMachine(RuleBasedStateMachine):
+ collection: Collection
+ embedding_ids: Bundle = Bundle("embedding_ids")
+
+ def __init__(self, api=None):
+ super().__init__()
+ # For debug only, to run as class-based test
+ if not api:
+ api = chromadb.Client(configurations()[0])
+ self.api = api
+ self._rules_strategy = strategies.DeterministicRuleStrategy(self)
+
+ @initialize(collection=collection_st)
+ def initialize(self, collection: strategies.Collection):
+ self.api.reset()
+ self.collection = self.api.create_collection(
+ name=collection.name,
+ metadata=collection.metadata,
+ embedding_function=collection.embedding_function,
+ )
+ self.embedding_function = collection.embedding_function
+ trace("init")
+ self.on_state_change(EmbeddingStateMachineStates.initialize)
+ self.embeddings = {
+ "ids": [],
+ "embeddings": [],
+ "metadatas": [],
+ "documents": [],
+ }
+
+ @rule(target=embedding_ids, embedding_set=strategies.recordsets(collection_st))
+ def add_embeddings(self, embedding_set):
+ trace("add_embeddings")
+ self.on_state_change(EmbeddingStateMachineStates.add_embeddings)
+
+ normalized_embedding_set = invariants.wrap_all(embedding_set)
+
+ if len(normalized_embedding_set["ids"]) > 0:
+ trace("add_more_embeddings")
+
+ if set(normalized_embedding_set["ids"]).intersection(
+ set(self.embeddings["ids"])
+ ):
+ with pytest.raises(errors.IDAlreadyExistsError):
+ self.collection.add(**embedding_set)
+ return multiple()
+ else:
+ self.collection.add(**embedding_set)
+ self._upsert_embeddings(embedding_set)
+ return multiple(*normalized_embedding_set["ids"])
+
+ @precondition(lambda self: len(self.embeddings["ids"]) > 20)
+ @rule(ids=st.lists(consumes(embedding_ids), min_size=1, max_size=20))
+ def delete_by_ids(self, ids):
+ trace("remove embeddings")
+ self.on_state_change(EmbeddingStateMachineStates.delete_by_ids)
+ indices_to_remove = [self.embeddings["ids"].index(id) for id in ids]
+
+ self.collection.delete(ids=ids)
+ self._remove_embeddings(set(indices_to_remove))
+
+ # Removing the precondition causes the tests to frequently fail as "unsatisfiable"
+ # Using a value < 5 causes retries and lowers the number of valid samples
+ @precondition(lambda self: len(self.embeddings["ids"]) >= 5)
+ @rule(
+ embedding_set=strategies.recordsets(
+ collection_strategy=collection_st,
+ id_strategy=embedding_ids,
+ min_size=1,
+ max_size=5,
+ )
+ )
+ def update_embeddings(self, embedding_set):
+ trace("update embeddings")
+ self.on_state_change(EmbeddingStateMachineStates.update_embeddings)
+ self.collection.update(**embedding_set)
+ self._upsert_embeddings(embedding_set)
+
+ # Using a value < 3 causes more retries and lowers the number of valid samples
+ @precondition(lambda self: len(self.embeddings["ids"]) >= 3)
+ @rule(
+ embedding_set=strategies.recordsets(
+ collection_strategy=collection_st,
+ id_strategy=st.one_of(embedding_ids, strategies.safe_text),
+ min_size=1,
+ max_size=5,
+ )
+ )
+ def upsert_embeddings(self, embedding_set):
+ trace("upsert embeddings")
+ self.on_state_change(EmbeddingStateMachineStates.upsert_embeddings)
+ self.collection.upsert(**embedding_set)
+ self._upsert_embeddings(embedding_set)
+
+ @invariant()
+ def count(self):
+ invariants.count(self.collection, self.embeddings) # type: ignore
+
+ @invariant()
+ def no_duplicates(self):
+ invariants.no_duplicates(self.collection)
+
+ @invariant()
+ def ann_accuracy(self):
+ invariants.ann_accuracy(
+ collection=self.collection, record_set=self.embeddings, min_recall=0.95, embedding_function=self.embedding_function # type: ignore
+ )
+
+ def _upsert_embeddings(self, embeddings: strategies.RecordSet):
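+        """Apply an add/update/upsert to the in-memory model of what the collection should contain."""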
+ embeddings = invariants.wrap_all(embeddings)
+ for idx, id in enumerate(embeddings["ids"]):
+ if id in self.embeddings["ids"]:
+ target_idx = self.embeddings["ids"].index(id)
+ if "embeddings" in embeddings and embeddings["embeddings"] is not None:
+ self.embeddings["embeddings"][target_idx] = embeddings[
+ "embeddings"
+ ][idx]
+ else:
+ self.embeddings["embeddings"][target_idx] = self.embedding_function(
+ [embeddings["documents"][idx]]
+ )[0]
+ if "metadatas" in embeddings and embeddings["metadatas"] is not None:
+ self.embeddings["metadatas"][target_idx] = embeddings["metadatas"][
+ idx
+ ]
+ if "documents" in embeddings and embeddings["documents"] is not None:
+ self.embeddings["documents"][target_idx] = embeddings["documents"][
+ idx
+ ]
+ else:
+ # Add path
+ self.embeddings["ids"].append(id)
+ if "embeddings" in embeddings and embeddings["embeddings"] is not None:
+ self.embeddings["embeddings"].append(embeddings["embeddings"][idx])
+ else:
+ self.embeddings["embeddings"].append(
+ self.embedding_function([embeddings["documents"][idx]])[0]
+ )
+ if "metadatas" in embeddings and embeddings["metadatas"] is not None:
+ self.embeddings["metadatas"].append(embeddings["metadatas"][idx])
+ else:
+ self.embeddings["metadatas"].append(None)
+ if "documents" in embeddings and embeddings["documents"] is not None:
+ self.embeddings["documents"].append(embeddings["documents"][idx])
+ else:
+ self.embeddings["documents"].append(None)
+
+ def _remove_embeddings(self, indices_to_remove: Set[int]):
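+        """Remove the records at the given indices from the in-memory model."""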
+ indices_list = list(indices_to_remove)
+ indices_list.sort(reverse=True)
+
+ for i in indices_list:
+ del self.embeddings["ids"][i]
+ del self.embeddings["embeddings"][i]
+ del self.embeddings["metadatas"][i]
+ del self.embeddings["documents"][i]
+
+ def on_state_change(self, new_state):
+ pass
+
+
+def test_embeddings_state(caplog, api):
+ caplog.set_level(logging.ERROR)
+ run_state_machine_as_test(lambda: EmbeddingStateMachine(api))
+ print_traces()
+
+
+def test_multi_add(api: API):
+ api.reset()
+ coll = api.create_collection(name="foo")
+ coll.add(ids=["a"], embeddings=[[0.0]])
+ assert coll.count() == 1
+
+ with pytest.raises(errors.IDAlreadyExistsError):
+ coll.add(ids=["a"], embeddings=[[0.0]])
+
+ assert coll.count() == 1
+
+ results = coll.get()
+ assert results["ids"] == ["a"]
+
+ coll.delete(ids=["a"])
+ assert coll.count() == 0
+
+
+def test_dup_add(api: API):
+ api.reset()
+ coll = api.create_collection(name="foo")
+ with pytest.raises(errors.DuplicateIDError):
+ coll.add(ids=["a", "a"], embeddings=[[0.0], [1.1]])
+ with pytest.raises(errors.DuplicateIDError):
+ coll.upsert(ids=["a", "a"], embeddings=[[0.0], [1.1]])
+
+
+# TODO: Use SQL escaping correctly internally
+@pytest.mark.xfail(reason="We don't properly escape SQL internally, causing problems")
+def test_escape_chars_in_ids(api: API):
+ api.reset()
+ id = "\x1f"
+ coll = api.create_collection(name="foo")
+ coll.add(ids=[id], embeddings=[[0.0]])
+ assert coll.count() == 1
+ coll.delete(ids=[id])
+ assert coll.count() == 0
diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py
new file mode 100644
index 00000000000..1008f895fd4
--- /dev/null
+++ b/chromadb/test/property/test_filtering.py
@@ -0,0 +1,177 @@
+from hypothesis import given, settings, HealthCheck
+from chromadb.api import API
+from chromadb.errors import NoDatapointsException
+from chromadb.test.property import invariants
+import chromadb.test.property.strategies as strategies
+import hypothesis.strategies as st
+import logging
+import random
+
+
+def _filter_where_clause(clause, mm):
+ """Return true if the where clause is true for the given metadata map"""
+
+ key, expr = list(clause.items())[0]
+
+ # Handle the shorthand for equal: {key: val} where val is a simple value
+ if isinstance(expr, str) or isinstance(expr, int) or isinstance(expr, float):
+ return _filter_where_clause({key: {"$eq": expr}}, mm)
+
+ if key == "$and":
+ return all(_filter_where_clause(clause, mm) for clause in expr)
+ if key == "$or":
+ return any(_filter_where_clause(clause, mm) for clause in expr)
+
+ op, val = list(expr.items())[0]
+
+ if op == "$eq":
+ return key in mm and mm[key] == val
+ elif op == "$ne":
+ return key in mm and mm[key] != val
+ elif op == "$gt":
+ return key in mm and mm[key] > val
+ elif op == "$gte":
+ return key in mm and mm[key] >= val
+ elif op == "$lt":
+ return key in mm and mm[key] < val
+ elif op == "$lte":
+ return key in mm and mm[key] <= val
+ else:
+ raise ValueError("Unknown operator: {}".format(key))
+
+
+def _filter_where_doc_clause(clause, doc):
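+    """Return true if the where_document clause matches the given document"""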
+ key, expr = list(clause.items())[0]
+ if key == "$and":
+ return all(_filter_where_doc_clause(clause, doc) for clause in expr)
+ elif key == "$or":
+ return any(_filter_where_doc_clause(clause, doc) for clause in expr)
+ elif key == "$contains":
+ return expr in doc
+ else:
+ raise ValueError("Unknown operator: {}".format(key))
+
+
+EMPTY_DICT = {}
+EMPTY_STRING = ""
+
+
+def _filter_embedding_set(recordset: strategies.RecordSet, filter: strategies.Filter):
+ """Return IDs from the embedding set that match the given filter object"""
+
+ recordset = invariants.wrap_all(recordset)
+
+ ids = set(recordset["ids"])
+
+ filter_ids = filter["ids"]
+ if filter_ids is not None:
+ filter_ids = invariants.maybe_wrap(filter_ids)
+ assert filter_ids is not None
+ # If the filter ids is an empty list then we treat that as get all
+ if len(filter_ids) != 0:
+ ids = ids.intersection(filter_ids)
+
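+    # Walk each record and drop IDs whose metadata or document fails the where filters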
+ for i in range(len(recordset["ids"])):
+ if filter["where"]:
+ metadatas = recordset["metadatas"] or [EMPTY_DICT] * len(recordset["ids"])
+ if not _filter_where_clause(filter["where"], metadatas[i]):
+ ids.discard(recordset["ids"][i])
+
+ if filter["where_document"]:
+ documents = recordset["documents"] or [EMPTY_STRING] * len(recordset["ids"])
+ if not _filter_where_doc_clause(filter["where_document"], documents[i]):
+ ids.discard(recordset["ids"][i])
+
+ return list(ids)
+
+
+collection_st = st.shared(
+ strategies.collections(add_filterable_data=True, with_hnsw_params=True),
+ key="coll",
+)
+recordset_st = st.shared(
+ strategies.recordsets(collection_st, max_size=1000), key="recordset"
+)
+
+
+@settings(
+ suppress_health_check=[
+ HealthCheck.function_scoped_fixture,
+ HealthCheck.large_base_example,
+ ]
+)
+@given(
+ collection=collection_st,
+ recordset=recordset_st,
+ filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1),
+)
+def test_filterable_metadata_get(caplog, api: API, collection, recordset, filters):
+ caplog.set_level(logging.ERROR)
+
+ api.reset()
+ coll = api.create_collection(
+ name=collection.name,
+ metadata=collection.metadata,
+ embedding_function=collection.embedding_function,
+ )
+ coll.add(**recordset)
+
+ for filter in filters:
+ result_ids = coll.get(**filter)["ids"]
+ expected_ids = _filter_embedding_set(recordset, filter)
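+        # get() is an exact lookup (no ANN search), so results must match the brute-force filter exactly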
+ assert sorted(result_ids) == sorted(expected_ids)
+
+
+@settings(
+ suppress_health_check=[
+ HealthCheck.function_scoped_fixture,
+ HealthCheck.large_base_example,
+ ]
+)
+@given(
+ collection=collection_st,
+ recordset=recordset_st,
+ filters=st.lists(
+ strategies.filters(collection_st, recordset_st, include_all_ids=True),
+ min_size=1,
+ ),
+)
+def test_filterable_metadata_query(
+ caplog,
+ api: API,
+ collection: strategies.Collection,
+ recordset: strategies.RecordSet,
+ filters,
+):
+ caplog.set_level(logging.ERROR)
+
+ api.reset()
+ coll = api.create_collection(
+ name=collection.name,
+ metadata=collection.metadata,
+ embedding_function=collection.embedding_function,
+ )
+ coll.add(**recordset)
+ recordset = invariants.wrap_all(recordset)
+ total_count = len(recordset["ids"])
+ # Pick a random vector
+ if collection.has_embeddings:
+ random_query = recordset["embeddings"][random.randint(0, total_count - 1)]
+ else:
+ random_query = collection.embedding_function(
+ recordset["documents"][random.randint(0, total_count - 1)]
+ )
+ for filter in filters:
+ try:
+ result_ids = set(
+ coll.query(
+ query_embeddings=random_query,
+ n_results=total_count,
+ where=filter["where"],
+ where_document=filter["where_document"],
+ )["ids"][0]
+ )
+ except NoDatapointsException:
+ result_ids = set()
+ expected_ids = set(_filter_embedding_set(recordset, filter))
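+        # query() may legitimately return fewer hits than the brute-force filter (ANN recall),
+        # so we only require that every returned ID is in the expected set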
+ assert len(result_ids.intersection(expected_ids)) == len(result_ids)
diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py
new file mode 100644
index 00000000000..d4632de2e30
--- /dev/null
+++ b/chromadb/test/property/test_persist.py
@@ -0,0 +1,156 @@
+import logging
+import multiprocessing
+from typing import Generator, Callable
+from hypothesis import given
+import hypothesis.strategies as st
+import pytest
+import chromadb
+from chromadb.api import API
+from chromadb.config import Settings
+import chromadb.test.property.strategies as strategies
+import chromadb.test.property.invariants as invariants
+from chromadb.test.property.test_embeddings import (
+ EmbeddingStateMachine,
+ EmbeddingStateMachineStates,
+)
+from hypothesis.stateful import run_state_machine_as_test, rule, precondition
+import os
+import shutil
+import tempfile
+
+CreatePersistAPI = Callable[[], API]
+
+configurations = [
+ Settings(
+ chroma_api_impl="local",
+ chroma_db_impl="duckdb+parquet",
+ persist_directory=tempfile.gettempdir() + "/tests",
+ )
+]
+
+
+@pytest.fixture(scope="module", params=configurations)
+def settings(request) -> Generator[Settings, None, None]:
+ configuration = request.param
+ yield configuration
+ save_path = configuration.persist_directory
+ # Remove if it exists
+ if os.path.exists(save_path):
+ shutil.rmtree(save_path)
+
+
+collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")
+
+
+@given(
+ collection_strategy=collection_st,
+ embeddings_strategy=strategies.recordsets(collection_st),
+)
+def test_persist(
+ settings: Settings,
+ collection_strategy: strategies.Collection,
+ embeddings_strategy: strategies.RecordSet,
+):
+ api_1 = chromadb.Client(settings)
+ api_1.reset()
+ coll = api_1.create_collection(
+ name=collection_strategy.name,
+ metadata=collection_strategy.metadata,
+ embedding_function=collection_strategy.embedding_function,
+ )
+
+ coll.add(**embeddings_strategy)
+
+ invariants.count(coll, embeddings_strategy)
+ invariants.metadatas_match(coll, embeddings_strategy)
+ invariants.documents_match(coll, embeddings_strategy)
+ invariants.ids_match(coll, embeddings_strategy)
+ invariants.ann_accuracy(
+ coll,
+ embeddings_strategy,
+ embedding_function=collection_strategy.embedding_function,
+ )
+
+ api_1.persist()
+ del api_1
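+    # A fresh client below must rebuild the collection solely from the persisted data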
+
+ api_2 = chromadb.Client(settings)
+ coll = api_2.get_collection(
+ name=collection_strategy.name,
+ embedding_function=collection_strategy.embedding_function,
+ )
+ invariants.count(coll, embeddings_strategy)
+ invariants.metadatas_match(coll, embeddings_strategy)
+ invariants.documents_match(coll, embeddings_strategy)
+ invariants.ids_match(coll, embeddings_strategy)
+ invariants.ann_accuracy(
+ coll,
+ embeddings_strategy,
+ embedding_function=collection_strategy.embedding_function,
+ )
+
+
+def load_and_check(settings: Settings, collection_name: str, embeddings_set, conn):
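+    # Runs in a spawned subprocess; any failure is reported back to the parent through conn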
+ try:
+ api = chromadb.Client(settings)
+ coll = api.get_collection(
+ name=collection_name, embedding_function=lambda x: None
+ )
+ invariants.count(coll, embeddings_set)
+ invariants.metadatas_match(coll, embeddings_set)
+ invariants.documents_match(coll, embeddings_set)
+ invariants.ids_match(coll, embeddings_set)
+ invariants.ann_accuracy(coll, embeddings_set)
+ except Exception as e:
+ conn.send(e)
+ raise e
+
+
+class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates):
+ persist = "persist"
+
+
+class PersistEmbeddingsStateMachine(EmbeddingStateMachine):
+ def __init__(self, api: API, settings: Settings):
+ self.api = api
+ self.settings = settings
+ self.last_persist_delay = 10
+ self.api.reset()
+ super().__init__(self.api)
+
+ @precondition(lambda self: len(self.embeddings["ids"]) >= 1)
+ @precondition(lambda self: self.last_persist_delay <= 0)
+ @rule()
+ def persist(self):
+ self.on_state_change(PersistEmbeddingsStateMachineStates.persist)
+ self.api.persist()
+ collection_name = self.collection.name
+ # Create a new process and then inside the process run the invariants
+ # TODO: Once we switch off of duckdb and onto sqlite we can remove this
+ ctx = multiprocessing.get_context("spawn")
+ conn1, conn2 = multiprocessing.Pipe()
+ p = ctx.Process(
+ target=load_and_check,
+ args=(self.settings, collection_name, self.embeddings, conn2),
+ )
+ p.start()
+ p.join()
+
+ if conn1.poll():
+ e = conn1.recv()
+ raise e
+
+ def on_state_change(self, new_state):
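+        # Reset the countdown whenever we persist; otherwise tick it down so persist()
+        # only becomes eligible again after several other rules have fired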
+ if new_state == PersistEmbeddingsStateMachineStates.persist:
+ self.last_persist_delay = 10
+ else:
+ self.last_persist_delay -= 1
+
+
+def test_persist_embeddings_state(caplog, settings: Settings):
+ caplog.set_level(logging.ERROR)
+ api = chromadb.Client(settings)
+ run_state_machine_as_test(
+ lambda: PersistEmbeddingsStateMachine(settings=settings, api=api)
+ )
diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py
index 8c010545711..277e8530504 100644
--- a/chromadb/test/test_api.py
+++ b/chromadb/test/test_api.py
@@ -13,17 +13,6 @@
import numpy as np
-@pytest.fixture
-def local_api():
- return chromadb.Client(
- Settings(
- chroma_api_impl="local",
- chroma_db_impl="duckdb",
- persist_directory=tempfile.gettempdir(),
- )
- )
-
-
@pytest.fixture
def local_persist_api():
return chromadb.Client(
@@ -34,7 +23,6 @@ def local_persist_api():
)
)
-
# https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached
@pytest.fixture
def local_persist_api_cache_bust():
@@ -47,67 +35,6 @@ def local_persist_api_cache_bust():
)
-@pytest.fixture
-def fastapi_integration_api():
- return chromadb.Client() # configured by environment variables
-
-
-def _build_fastapi_api():
- return chromadb.Client(
- Settings(
- chroma_api_impl="rest", chroma_server_host="localhost", chroma_server_http_port="6666"
- )
- )
-
-
-@pytest.fixture
-def fastapi_api():
- return _build_fastapi_api()
-
-
-def run_server():
- settings = Settings(
- chroma_api_impl="local",
- chroma_db_impl="duckdb",
- persist_directory=tempfile.gettempdir() + "/test_server",
- )
- server = chromadb.server.fastapi.FastAPI(settings)
- uvicorn.run(server.app(), host="0.0.0.0", port=6666, log_level="info")
-
-
-def await_server(attempts=0):
- api = _build_fastapi_api()
-
- try:
- api.heartbeat()
- except ConnectionError as e:
- if attempts > 10:
- raise e
- else:
- time.sleep(2)
- await_server(attempts + 1)
-
-
-@pytest.fixture(scope="module", autouse=True)
-def fastapi_server():
- proc = Process(target=run_server, args=(), daemon=True)
- proc.start()
- await_server()
- yield
- proc.kill()
-
-
-test_apis = [local_api, fastapi_api]
-
-if "CHROMA_INTEGRATION_TEST" in os.environ:
- print("Including integration tests")
- test_apis.append(fastapi_integration_api)
-
-if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ:
- print("Including integration tests only")
- test_apis = [fastapi_integration_api]
-
-
@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist_index_loading(api_fixture, request):
api = request.getfixturevalue("local_persist_api")
@@ -203,22 +130,18 @@ def test_persist(api_fixture, request):
assert api.list_collections() == []
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_heartbeat(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_heartbeat(api):
assert isinstance(api.heartbeat(), int)
batch_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
- "ids": ["https://example.com", "https://example.com"],
+ "ids": ["https://example.com/1", "https://example.com/2"],
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_add(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_add(api):
api.reset()
@@ -229,9 +152,7 @@ def test_add(api_fixture, request):
assert collection.count() == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_get_or_create(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_get_or_create(api):
api.reset()
@@ -251,14 +172,11 @@ def test_get_or_create(api_fixture, request):
minimal_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
- "ids": ["https://example.com", "https://example.com"],
+ "ids": ["https://example.com/1", "https://example.com/2"],
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_add_minimal(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
-
+def test_add_minimal(api):
api.reset()
collection = api.create_collection("testspace")
@@ -268,9 +186,7 @@ def test_add_minimal(api_fixture, request):
assert collection.count() == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_get_from_db(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_get_from_db(api):
api.reset()
collection = api.create_collection("testspace")
@@ -280,9 +196,7 @@ def test_get_from_db(api_fixture, request):
assert len(records[key]) == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_reset_db(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_reset_db(api):
api.reset()
@@ -294,9 +208,7 @@ def test_reset_db(api_fixture, request):
assert len(api.list_collections()) == 0
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_get_nearest_neighbors(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_get_nearest_neighbors(api):
api.reset()
collection = api.create_collection("testspace")
@@ -331,10 +243,7 @@ def test_get_nearest_neighbors(api_fixture, request):
assert len(nn[key]) == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_get_nearest_neighbors_filter(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
-
+def test_get_nearest_neighbors_filter(api, request):
api.reset()
collection = api.create_collection("testspace")
collection.add(**batch_records)
@@ -349,9 +258,7 @@ def test_get_nearest_neighbors_filter(api_fixture, request):
assert str(e.value).__contains__("found")
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_delete(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_delete(api):
api.reset()
collection = api.create_collection("testspace")
@@ -362,9 +269,7 @@ def test_delete(api_fixture, request):
assert collection.count() == 0
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_delete_with_index(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_delete_with_index(api):
api.reset()
collection = api.create_collection("testspace")
@@ -373,9 +278,7 @@ def test_delete_with_index(api_fixture, request):
collection.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1)
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_count(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_count(api):
api.reset()
collection = api.create_collection("testspace")
@@ -384,9 +287,7 @@ def test_count(api_fixture, request):
assert collection.count() == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_modify(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_modify(api):
api.reset()
collection = api.create_collection("testspace")
@@ -396,9 +297,7 @@ def test_modify(api_fixture, request):
assert collection.name == "testspace2"
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_cru(api_fixture, request):
- api: API = request.getfixturevalue(api_fixture.__name__)
+def test_metadata_cru(api):
api.reset()
metadata_a = {"a": 1, "b": 2}
@@ -448,9 +347,9 @@ def test_metadata_cru(api_fixture, request):
assert collection.metadata is None
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_increment_index_on(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_increment_index_on(api):
+
api.reset()
collection = api.create_collection("testspace")
@@ -468,9 +367,9 @@ def test_increment_index_on(api_fixture, request):
assert len(nn[key]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_increment_index_off(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_increment_index_off(api):
+
api.reset()
collection = api.create_collection("testspace")
@@ -488,9 +387,9 @@ def test_increment_index_off(api_fixture, request):
assert len(nn[key]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def skipping_indexing_will_fail(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def skipping_indexing_will_fail(api):
+
api.reset()
collection = api.create_collection("testspace")
@@ -503,9 +402,9 @@ def skipping_indexing_will_fail(api_fixture, request):
assert str(e.value).__contains__("index not found")
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_add_a_collection(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_add_a_collection(api):
+
api.reset()
api.create_collection("testspace")
@@ -519,9 +418,9 @@ def test_add_a_collection(api_fixture, request):
collection = api.get_collection("testspace2")
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_list_collections(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_list_collections(api):
+
api.reset()
api.create_collection("testspace")
@@ -532,9 +431,9 @@ def test_list_collections(api_fixture, request):
assert len(collections) == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_reset(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_reset(api):
+
api.reset()
api.create_collection("testspace")
@@ -549,9 +448,9 @@ def test_reset(api_fixture, request):
assert len(collections) == 0
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_peek(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_peek(api):
+
api.reset()
collection = api.create_collection("testspace")
@@ -574,9 +473,9 @@ def test_peek(api_fixture, request):
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_add_get_int_float(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_add_get_int_float(api):
+
api.reset()
collection = api.create_collection("test_int")
@@ -590,9 +489,9 @@ def test_metadata_add_get_int_float(api_fixture, request):
assert type(items["metadatas"][0]["float_value"]) == float
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_add_query_int_float(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_add_query_int_float(api):
+
api.reset()
collection = api.create_collection("test_int")
@@ -606,9 +505,9 @@ def test_metadata_add_query_int_float(api_fixture, request):
assert type(items["metadatas"][0][0]["float_value"]) == float
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_get_where_string(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_get_where_string(api):
+
api.reset()
collection = api.create_collection("test_int")
@@ -619,9 +518,9 @@ def test_metadata_get_where_string(api_fixture, request):
assert items["metadatas"][0]["string_value"] == "one"
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_get_where_int(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_get_where_int(api):
+
api.reset()
collection = api.create_collection("test_int")
@@ -632,9 +531,9 @@ def test_metadata_get_where_int(api_fixture, request):
assert items["metadatas"][0]["string_value"] == "one"
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_get_where_float(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_get_where_float(api):
+
api.reset()
collection = api.create_collection("test_int")
@@ -646,9 +545,9 @@ def test_metadata_get_where_float(api_fixture, request):
assert items["metadatas"][0]["float_value"] == 1.001
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_update_get_int_float(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_update_get_int_float(api):
+
api.reset()
collection = api.create_collection("test_int")
@@ -670,9 +569,9 @@ def test_metadata_update_get_int_float(api_fixture, request):
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_validation_add(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_validation_add(api):
+
api.reset()
collection = api.create_collection("test_metadata_validation")
@@ -680,9 +579,9 @@ def test_metadata_validation_add(api_fixture, request):
collection.add(**bad_metadata_records)
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_metadata_validation_update(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_metadata_validation_update(api):
+
api.reset()
collection = api.create_collection("test_metadata_validation")
@@ -691,9 +590,9 @@ def test_metadata_validation_update(api_fixture, request):
collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}})
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_validation_get(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_validation_get(api):
+
api.reset()
collection = api.create_collection("test_where_validation")
@@ -701,9 +600,9 @@ def test_where_validation_get(api_fixture, request):
collection.get(where={"value": {"nested": "5"}})
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_validation_query(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_validation_query(api):
+
api.reset()
collection = api.create_collection("test_where_validation")
@@ -721,9 +620,9 @@ def test_where_validation_query(api_fixture, request):
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_lt(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_lt(api):
+
api.reset()
collection = api.create_collection("test_where_lt")
@@ -732,9 +631,9 @@ def test_where_lt(api_fixture, request):
assert len(items["metadatas"]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_lte(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_lte(api):
+
api.reset()
collection = api.create_collection("test_where_lte")
@@ -743,9 +642,9 @@ def test_where_lte(api_fixture, request):
assert len(items["metadatas"]) == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_gt(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_gt(api):
+
api.reset()
collection = api.create_collection("test_where_lte")
@@ -754,9 +653,9 @@ def test_where_gt(api_fixture, request):
assert len(items["metadatas"]) == 2
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_gte(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_gte(api):
+
api.reset()
collection = api.create_collection("test_where_lte")
@@ -765,9 +664,9 @@ def test_where_gte(api_fixture, request):
assert len(items["metadatas"]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_ne_string(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_ne_string(api):
+
api.reset()
collection = api.create_collection("test_where_lte")
@@ -776,9 +675,9 @@ def test_where_ne_string(api_fixture, request):
assert len(items["metadatas"]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_ne_eq_number(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_ne_eq_number(api):
+
api.reset()
collection = api.create_collection("test_where_lte")
@@ -789,9 +688,9 @@ def test_where_ne_eq_number(api_fixture, request):
assert len(items["metadatas"]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_valid_operators(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_valid_operators(api):
+
api.reset()
collection = api.create_collection("test_where_valid_operators")
@@ -851,9 +750,9 @@ def test_where_valid_operators(api_fixture, request):
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_dimensionality_validation_add(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_dimensionality_validation_add(api):
+
api.reset()
collection = api.create_collection("test_dimensionality_validation")
@@ -864,9 +763,9 @@ def test_dimensionality_validation_add(api_fixture, request):
assert "dimensionality" in str(e.value)
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_dimensionality_validation_query(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_dimensionality_validation_query(api):
+
api.reset()
collection = api.create_collection("test_dimensionality_validation_query")
@@ -877,9 +776,9 @@ def test_dimensionality_validation_query(api_fixture, request):
assert "dimensionality" in str(e.value)
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_number_of_elements_validation_query(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_number_of_elements_validation_query(api):
+
api.reset()
collection = api.create_collection("test_number_of_elements_validation")
@@ -890,9 +789,9 @@ def test_number_of_elements_validation_query(api_fixture, request):
assert "number of elements" in str(e.value)
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_query_document_valid_operators(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_query_document_valid_operators(api):
+
api.reset()
collection = api.create_collection("test_where_valid_operators")
@@ -936,9 +835,9 @@ def test_query_document_valid_operators(api_fixture, request):
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_get_where_document(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_get_where_document(api):
+
api.reset()
collection = api.create_collection("test_get_where_document")
@@ -954,9 +853,9 @@ def test_get_where_document(api_fixture, request):
assert len(items["metadatas"]) == 0
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_query_where_document(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_query_where_document(api):
+
api.reset()
collection = api.create_collection("test_query_where_document")
@@ -979,9 +878,9 @@ def test_query_where_document(api_fixture, request):
assert "datapoints" in str(e.value)
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_delete_where_document(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_delete_where_document(api):
+
api.reset()
collection = api.create_collection("test_delete_where_document")
@@ -1015,9 +914,9 @@ def test_delete_where_document(api_fixture, request):
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_logical_operators(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_logical_operators(api):
+
api.reset()
collection = api.create_collection("test_logical_operators")
@@ -1055,9 +954,9 @@ def test_where_logical_operators(api_fixture, request):
assert len(items["metadatas"]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_where_document_logical_operators(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_where_document_logical_operators(api):
+
api.reset()
collection = api.create_collection("test_document_logical_operators")
@@ -1107,9 +1006,9 @@ def test_where_document_logical_operators(api_fixture, request):
}
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_query_include(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_query_include(api):
+
api.reset()
collection = api.create_collection("test_query_include")
@@ -1141,9 +1040,9 @@ def test_query_include(api_fixture, request):
assert items["ids"][0][1] == "id2"
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_get_include(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_get_include(api):
+
api.reset()
collection = api.create_collection("test_get_include")
@@ -1174,9 +1073,9 @@ def test_get_include(api_fixture, request):
# make sure query results are returned in the right order
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_query_order(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_query_order(api):
+
api.reset()
collection = api.create_collection("test_query_order")
@@ -1193,9 +1092,9 @@ def test_query_order(api_fixture, request):
# test to make sure add, get, delete error on invalid id input
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_invalid_id(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_invalid_id(api):
+
api.reset()
collection = api.create_collection("test_invalid_id")
@@ -1215,9 +1114,9 @@ def test_invalid_id(api_fixture, request):
assert "ID" in str(e.value)
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_index_params(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_index_params(api):
+
# first standard add
api.reset()
@@ -1254,10 +1153,10 @@ def test_index_params(api_fixture, request):
assert items["distances"][0][0] < -5
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_invalid_index_params(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_invalid_index_params(api):
+
api.reset()
with pytest.raises(Exception):
@@ -1273,8 +1172,7 @@ def test_invalid_index_params(api_fixture, request):
collection.add(**records)
-@pytest.mark.parametrize("api_fixture", [local_persist_api])
-def test_persist_index_loading_params(api_fixture, request):
+def test_persist_index_loading_params(api, request):
api = request.getfixturevalue("local_persist_api")
api.reset()
collection = api.create_collection("test", metadata={"hnsw:space": "ip"})
@@ -1297,9 +1195,9 @@ def test_persist_index_loading_params(api_fixture, request):
assert len(nn[key]) == 1
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_add_large(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_add_large(api):
+
api.reset()
@@ -1317,9 +1215,9 @@ def test_add_large(api_fixture, request):
# test get_version
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_get_version(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_get_version(api):
+
api.reset()
version = api.get_version()
@@ -1330,9 +1228,9 @@ def test_get_version(api_fixture, request):
# test delete_collection
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_delete_collection(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+
+def test_delete_collection(api):
+
api.reset()
collection = api.create_collection("test_delete_collection")
collection.add(**records)
@@ -1342,15 +1240,15 @@ def test_delete_collection(api_fixture, request):
assert len(api.list_collections()) == 0
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_multiple_collections(api_fixture, request):
+
+def test_multiple_collections(api):
embeddings1 = np.random.rand(10, 512).astype(np.float32).tolist()
embeddings2 = np.random.rand(10, 512).astype(np.float32).tolist()
ids1 = [f"http://example.com/1/{i}" for i in range(len(embeddings1))]
ids2 = [f"http://example.com/2/{i}" for i in range(len(embeddings2))]
- api = request.getfixturevalue(api_fixture.__name__)
+
api.reset()
coll1 = api.create_collection("coll1")
coll1.add(embeddings=embeddings1, ids=ids1)
@@ -1369,10 +1267,10 @@ def test_multiple_collections(api_fixture, request):
assert results2["ids"][0][0] == ids2[0]
-@pytest.mark.parametrize("api_fixture", test_apis)
-def test_update_query(api_fixture, request):
- api = request.getfixturevalue(api_fixture.__name__)
+def test_update_query(api):
+
api.reset()
collection = api.create_collection("test_update_query")
collection.add(**records)
@@ -1397,3 +1295,47 @@ def test_update_query(api_fixture, request):
assert results["documents"][0][0] == updated_records["documents"][0]
assert results["metadatas"][0][0]["foo"] == "bar"
assert results["embeddings"][0][0] == updated_records["embeddings"][0]
+
+
+initial_records = {
+ "embeddings": [[0, 0, 0], [1.2, 2.24, 3.2], [2.2, 3.24, 4.2]],
+ "ids": ["id1", "id2", "id3"],
+ "metadatas": [{"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2}, {"string_value": "three"}],
+ "documents": ["this document is first", "this document is second", "this document is third"],
+}
+
+new_records = {
+ "embeddings": [[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]],
+ "ids": ["id1", "id4"],
+ "metadatas": [{"int_value": 1, "string_value": "one_of_one", "float_value": 1.001}, {"int_value": 4}],
+ "documents": ["this document is even more first", "this document is new and fourth"],
+}
+
+def test_upsert(api):
+ api.reset()
+ collection = api.create_collection("test")
+
+ collection.add(**initial_records)
+ assert collection.count() == 3
+
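+    # "id1" already exists and is updated in place; "id4" is new, so the count grows to 4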
+ collection.upsert(**new_records)
+ assert collection.count() == 4
+
+ get_result = collection.get(include=['embeddings', 'metadatas', 'documents'], ids=new_records['ids'][0])
+ assert get_result['embeddings'][0] == new_records['embeddings'][0]
+ assert get_result['metadatas'][0] == new_records['metadatas'][0]
+ assert get_result['documents'][0] == new_records['documents'][0]
+
+ query_result = collection.query(query_embeddings=get_result['embeddings'], n_results=1, include=['embeddings', 'metadatas', 'documents'])
+ assert query_result['embeddings'][0][0] == new_records['embeddings'][0]
+ assert query_result['metadatas'][0][0] == new_records['metadatas'][0]
+ assert query_result['documents'][0][0] == new_records['documents'][0]
+
+ collection.delete(ids=initial_records['ids'][2])
+ collection.upsert(ids=initial_records['ids'][2], embeddings=[[1.1, 0.99, 2.21]], metadatas=[{"string_value": "a new string value"}])
+ assert collection.count() == 4
+
+ get_result = collection.get(include=['embeddings', 'metadatas', 'documents'], ids=['id3'])
+ assert get_result['embeddings'][0] == [1.1, 0.99, 2.21]
+ assert get_result['metadatas'][0] == {"string_value": "a new string value"}
+    assert get_result['documents'][0] is None
\ No newline at end of file
diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py
index 28e1103edd9..057ecac9fae 100644
--- a/chromadb/utils/embedding_functions.py
+++ b/chromadb/utils/embedding_functions.py
@@ -3,16 +3,21 @@
class SentenceTransformerEmbeddingFunction(EmbeddingFunction):
+
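+    # Class-level cache of loaded models keyed by model name, so creating several
+    # embedding functions with the same model does not reload the weights each time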
+ models = {}
+
# If you have a beefier machine, try "gtr-t5-large".
# for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
- try:
- from sentence_transformers import SentenceTransformer
- except ImportError:
- raise ValueError(
- "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
- )
- self._model = SentenceTransformer(model_name)
+ if model_name not in self.models:
+ try:
+ from sentence_transformers import SentenceTransformer
+ except ImportError:
+ raise ValueError(
+ "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
+ )
+ self.models[model_name] = SentenceTransformer(model_name)
+ self._model = self.models[model_name]
def __call__(self, texts: Documents) -> Embeddings:
return self._model.encode(list(texts), convert_to_numpy=True).tolist()
diff --git a/clients/js/src/generated/api.ts b/clients/js/src/generated/api.ts
index 20d03625c5b..f4c5bd3638d 100644
--- a/clients/js/src/generated/api.ts
+++ b/clients/js/src/generated/api.ts
@@ -614,6 +614,49 @@ export const ApiApiFetchParamCreator = function (configuration?: Configuration)
options: localVarRequestOptions,
};
},
+ /**
+ * @summary Upsert
+ * @param {string} collectionName
+ * @param {Api.AddEmbedding} request
+ * @param {RequestInit} [options] Override http request option.
+ * @throws {RequiredError}
+ */
+ upsert(collectionName: string, request: Api.AddEmbedding, options: RequestInit = {}): FetchArgs {
+ // verify required parameter 'collectionName' is not null or undefined
+ if (collectionName === null || collectionName === undefined) {
+ throw new RequiredError('collectionName', 'Required parameter collectionName was null or undefined when calling upsert.');
+ }
+ // verify required parameter 'request' is not null or undefined
+ if (request === null || request === undefined) {
+ throw new RequiredError('request', 'Required parameter request was null or undefined when calling upsert.');
+ }
+ let localVarPath = `/api/v1/collections/{collection_name}/upsert`
+ .replace('{collection_name}', encodeURIComponent(String(collectionName)));
+ const localVarPathQueryStart = localVarPath.indexOf("?");
+ const localVarRequestOptions: RequestInit = Object.assign({ method: 'POST' }, options);
+ const localVarHeaderParameter: Headers = options.headers ? new Headers(options.headers) : new Headers();
+ const localVarQueryParameter = new URLSearchParams(localVarPathQueryStart !== -1 ? localVarPath.substring(localVarPathQueryStart + 1) : "");
+ if (localVarPathQueryStart !== -1) {
+ localVarPath = localVarPath.substring(0, localVarPathQueryStart);
+ }
+
+ localVarHeaderParameter.set('Content-Type', 'application/json');
+
+ localVarRequestOptions.headers = localVarHeaderParameter;
+
+ if (request !== undefined) {
+ localVarRequestOptions.body = JSON.stringify(request || {});
+ }
+
+ const localVarQueryParameterString = localVarQueryParameter.toString();
+ if (localVarQueryParameterString) {
+ localVarPath += "?" + localVarQueryParameterString;
+ }
+ return {
+ url: localVarPath,
+ options: localVarRequestOptions,
+ };
+ },
/**
* @summary Version
* @param {RequestInit} [options] Override http request option.
@@ -1113,6 +1156,36 @@ export const ApiApiFp = function(configuration?: Configuration) {
});
};
},
+ /**
+ * @summary Upsert
+ * @param {string} collectionName
+ * @param {Api.AddEmbedding} request
+ * @param {RequestInit} [options] Override http request option.
+ * @throws {RequiredError}
+ */
+        upsert(collectionName: string, request: Api.AddEmbedding, options?: RequestInit): (fetch?: FetchAPI, basePath?: string) => Promise<Api.Upsert200Response> {
+ const localVarFetchArgs = ApiApiFetchParamCreator(configuration).upsert(collectionName, request, options);
+ return (fetch: FetchAPI = defaultFetch, basePath: string = BASE_PATH) => {
+ return fetch(basePath + localVarFetchArgs.url, localVarFetchArgs.options).then((response) => {
+ const contentType = response.headers.get('Content-Type');
+ const mimeType = contentType ? contentType.replace(/;.*/, '') : undefined;
+
+ if (response.status === 200) {
+ if (mimeType === 'application/json') {
+ return response.json() as any;
+ }
+ throw response;
+ }
+ if (response.status === 422) {
+ if (mimeType === 'application/json') {
+ throw response;
+ }
+ throw response;
+ }
+ throw response;
+ });
+ };
+ },
/**
* @summary Version
* @param {RequestInit} [options] Override http request option.
@@ -1324,6 +1397,17 @@ export class ApiApi extends BaseAPI {
return ApiApiFp(this.configuration).updateCollection(collectionName, request, options)(this.fetch, this.basePath);
}
+ /**
+ * @summary Upsert
+ * @param {string} collectionName
+ * @param {Api.AddEmbedding} request
+ * @param {RequestInit} [options] Override http request option.
+ * @throws {RequiredError}
+ */
+ public upsert(collectionName: string, request: Api.AddEmbedding, options?: RequestInit) {
+ return ApiApiFp(this.configuration).upsert(collectionName, request, options)(this.fetch, this.basePath);
+ }
+
/**
* @summary Version
* @param {RequestInit} [options] Override http request option.
diff --git a/clients/js/src/generated/models.ts b/clients/js/src/generated/models.ts
index 0ca3d7b2c9b..b17f701b622 100644
--- a/clients/js/src/generated/models.ts
+++ b/clients/js/src/generated/models.ts
@@ -315,6 +315,9 @@ export namespace Api {
}
+ export interface Upsert200Response {
+ }
+
export interface ValidationError {
loc: (string | number)[];
msg: string;
diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts
index 4b53a35a58b..48015dc714e 100644
--- a/clients/js/src/index.ts
+++ b/clients/js/src/index.ts
@@ -170,16 +170,23 @@ export class Collection {
this.metadata = metadata;
}
- public async add(
+ private async validate(
+ require_embeddings_or_documents: boolean, // set to false in the case of Update
ids: string | string[],
embeddings: number[] | number[][] | undefined,
metadatas?: object | object[],
documents?: string | string[],
- increment_index: boolean = true
) {
- if (embeddings === undefined && documents === undefined) {
- throw new Error("embeddings and documents cannot both be undefined");
- } else if (embeddings === undefined && documents !== undefined) {
+
+ if (require_embeddings_or_documents) {
+ if ((embeddings === undefined) && (documents === undefined)) {
+ throw new Error(
+ "embeddings and documents cannot both be undefined",
+ );
+ }
+ }
+
+ if ((embeddings === undefined) && (documents !== undefined)) {
const documentsArray = toArray(documents);
if (this.embeddingFunction !== undefined) {
embeddings = await this.embeddingFunction.generate(documentsArray);
@@ -222,21 +229,84 @@ export class Collection {
);
}
- return await this.api
- .add(this.name, {
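+    // Reject duplicate IDs within a single batch client-side, before any request is made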
+ const uniqueIds = new Set(idsArray);
+ if (uniqueIds.size !== idsArray.length) {
+ const duplicateIds = idsArray.filter((item, index) => idsArray.indexOf(item) !== index);
+ throw new Error(
+ `Expected IDs to be unique, found duplicates for: ${duplicateIds}`,
+ );
+ }
+
+ return [idsArray, embeddingsArray, metadatasArray, documentsArray]
+ }
+
+ public async add(
+ ids: string | string[],
+ embeddings: number[] | number[][] | undefined,
+ metadatas?: object | object[],
+ documents?: string | string[],
+ increment_index: boolean = true,
+ ) {
+
+ const [idsArray, embeddingsArray, metadatasArray, documentsArray] = await this.validate(
+ true,
+ ids,
+ embeddings,
+ metadatas,
+ documents
+ )
+
+ const response = await this.api.add(this.name,
+ {
+ // @ts-ignore
ids: idsArray,
- embeddings: embeddingsArray,
- //@ts-ignore
+ embeddings: embeddingsArray as number[][], // We know this is defined because of the validate function
+ // @ts-ignore
documents: documentsArray,
metadatas: metadatasArray,
incrementIndex: increment_index,
})
- .then(function (response: any) {
- return JSON.parse(response);
- })
+ .then(handleSuccess)
+ .catch(handleError);
+
+ return response
+ }
+
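+  // Like add(), but calls the collection's upsert endpoint: existing IDs are updated
+  // in place and new IDs are inserted in a single request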
+ public async upsert(
+ ids: string | string[],
+ embeddings: number[] | number[][] | undefined,
+ metadatas?: object | object[],
+ documents?: string | string[],
+ increment_index: boolean = true,
+ ) {
+
+ const [idsArray, embeddingsArray, metadatasArray, documentsArray] = await this.validate(
+ true,
+ ids,
+ embeddings,
+ metadatas,
+ documents
+ )
+
+ const response = await this.api.upsert(this.name,
+ {
+ //@ts-ignore
+ ids: idsArray,
+ embeddings: embeddingsArray as number[][], // We know this is defined because of the validate function
+ //@ts-ignore
+ documents: documentsArray,
+ metadatas: metadatasArray,
+ increment_index: increment_index,
+ },
+ )
+ .then(handleSuccess)
.catch(handleError);
+
+ return response
+
}
+
public async count() {
const response = await this.api.count(this.name);
return handleSuccess(response);
diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts
index 956a3fcafa9..6ac078b5121 100644
--- a/clients/js/test/add.collections.test.ts
+++ b/clients/js/test/add.collections.test.ts
@@ -1,6 +1,7 @@
-import { expect, test } from "@jest/globals";
-import chroma from "./initClient";
-import { DOCUMENTS, EMBEDDINGS, IDS } from "./data";
+import { expect, test } from '@jest/globals';
+import chroma from './initClient'
+import { DOCUMENTS, EMBEDDINGS, IDS } from './data';
+import { METADATAS } from './data';
import { IncludeEnum } from "../src/types";
test("it should add single embeddings to a collection", async () => {
@@ -37,3 +38,25 @@ test("add documents", async () => {
const results = await collection.get(["test1"]);
expect(results.documents[0]).toBe("This is a test");
});
+
+test('it should return an error when inserting an ID that already exists in the Collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ await collection.add(IDS, EMBEDDINGS, METADATAS)
+ const results = await collection.add(IDS, EMBEDDINGS, METADATAS);
+ expect(results.error).toBeDefined()
+ expect(results.error).toContain("IDAlreadyExists")
+})
+
+test('It should return an error when inserting duplicate IDs in the same batch', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = IDS.concat(["test1"])
+ const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+ const metadatas = METADATAS.concat([{ test: 'test1', 'float_value': 0.1 }])
+  // Expect the add to reject because the batch contains a duplicate ID
+  await expect(collection.add(ids, embeddings, metadatas)).rejects.toThrow('duplicates')
+})
\ No newline at end of file
diff --git a/clients/js/test/client.test.ts b/clients/js/test/client.test.ts
index 1a39d6159a6..06595c0149b 100644
--- a/clients/js/test/client.test.ts
+++ b/clients/js/test/client.test.ts
@@ -3,38 +3,188 @@ import { ChromaClient } from "../src/index";
import chroma from "./initClient";
test("it should create the client connection", async () => {
- expect(chroma).toBeDefined();
- expect(chroma).toBeInstanceOf(ChromaClient);
+ expect(chroma).toBeDefined();
+ expect(chroma).toBeInstanceOf(ChromaClient);
});
test("it should get the version", async () => {
- const version = await chroma.version();
- expect(version).toBeDefined();
- expect(version).toMatch(/^[0-9]+\.[0-9]+\.[0-9]+$/);
+ const version = await chroma.version();
+ expect(version).toBeDefined();
+ expect(version).toMatch(/^[0-9]+\.[0-9]+\.[0-9]+$/);
});
test("it should get the heartbeat", async () => {
- const heartbeat = await chroma.heartbeat();
- expect(heartbeat).toBeDefined();
- expect(heartbeat).toBeGreaterThan(0);
+ const heartbeat = await chroma.heartbeat();
+ expect(heartbeat).toBeDefined();
+ expect(heartbeat).toBeGreaterThan(0);
});
test("it should reset the database", async () => {
- await chroma.reset();
- const collections = await chroma.listCollections();
- expect(collections).toBeDefined();
- expect(collections).toBeInstanceOf(Array);
- expect(collections.length).toBe(0);
-
- const collection = await chroma.createCollection("test");
- const collections2 = await chroma.listCollections();
- expect(collections2).toBeDefined();
- expect(collections2).toBeInstanceOf(Array);
- expect(collections2.length).toBe(1);
-
- await chroma.reset();
- const collections3 = await chroma.listCollections();
- expect(collections3).toBeDefined();
- expect(collections3).toBeInstanceOf(Array);
- expect(collections3.length).toBe(0);
+ await chroma.reset();
+ const collections = await chroma.listCollections();
+ expect(collections).toBeDefined();
+ expect(collections).toBeInstanceOf(Array);
+ expect(collections.length).toBe(0);
+
+ const collection = await chroma.createCollection("test");
+ const collections2 = await chroma.listCollections();
+ expect(collections2).toBeDefined();
+ expect(collections2).toBeInstanceOf(Array);
+ expect(collections2.length).toBe(1);
+
+ await chroma.reset();
+ const collections3 = await chroma.listCollections();
+ expect(collections3).toBeDefined();
+ expect(collections3).toBeInstanceOf(Array);
+ expect(collections3.length).toBe(0);
});
+
+test('it should list collections', async () => {
+ await chroma.reset()
+ let collections = await chroma.listCollections()
+ expect(collections).toBeDefined()
+ expect(collections).toBeInstanceOf(Array)
+ expect(collections.length).toBe(0)
+ const collection = await chroma.createCollection('test')
+ collections = await chroma.listCollections()
+ expect(collections.length).toBe(1)
+})
+
+test('it should get a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const collection2 = await chroma.getCollection('test')
+ expect(collection).toBeDefined()
+ expect(collection2).toBeDefined()
+ expect(collection).toHaveProperty('name')
+ expect(collection2).toHaveProperty('name')
+ expect(collection.name).toBe(collection2.name)
+})
+
+test('it should delete a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ let collections = await chroma.listCollections()
+ expect(collections.length).toBe(1)
+ var resp = await chroma.deleteCollection('test')
+ collections = await chroma.listCollections()
+ expect(collections.length).toBe(0)
+})
+
+test('it should add single embeddings to a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const id = 'test1'
+ const embedding = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ const metadata = { test: 'test' }
+ await collection.add(id, embedding, metadata)
+ const count = await collection.count()
+ expect(count).toBe(1)
+})
+
+test('it should add batch embeddings to a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = ['test1', 'test2', 'test3']
+ const embeddings = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ ]
+ await collection.add(ids, embeddings)
+ const count = await collection.count()
+ expect(count).toBe(3)
+})
+
+test('it should query a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = ['test1', 'test2', 'test3']
+ const embeddings = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ ]
+ await collection.add(ids, embeddings)
+ const results = await collection.query([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 2)
+ expect(results).toBeDefined()
+ expect(results).toBeInstanceOf(Object)
+ // expect(results.embeddings[0].length).toBe(2)
+ expect(['test1', 'test2']).toEqual(expect.arrayContaining(results.ids[0]));
+ expect(['test3']).not.toEqual(expect.arrayContaining(results.ids[0]));
+})
+
+test('it should peek a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = ['test1', 'test2', 'test3']
+ const embeddings = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ ]
+ await collection.add(ids, embeddings)
+ const results = await collection.peek(2)
+ expect(results).toBeDefined()
+ expect(results).toBeInstanceOf(Object)
+ expect(results.ids.length).toBe(2)
+ expect(['test1', 'test2']).toEqual(expect.arrayContaining(results.ids));
+})
+
+test('it should get embeddings from a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = ['test1', 'test2', 'test3']
+ const embeddings = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ ]
+ const metadatas = [{ test: 'test1' }, { test: 'test2' }, { test: 'test3' }]
+ await collection.add(ids, embeddings, metadatas)
+ const results = await collection.get(['test1'])
+ expect(results).toBeDefined()
+ expect(results).toBeInstanceOf(Object)
+ expect(results.ids.length).toBe(1)
+ expect(['test1']).toEqual(expect.arrayContaining(results.ids));
+ expect(['test2']).not.toEqual(expect.arrayContaining(results.ids));
+
+ const results2 = await collection.get(undefined, { 'test': 'test1' })
+ expect(results2).toBeDefined()
+ expect(results2).toBeInstanceOf(Object)
+ expect(results2.ids.length).toBe(1)
+})
+
+test('it should delete embeddings from a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = ['test1', 'test2', 'test3']
+ const embeddings = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ ]
+ const metadatas = [{ test: 'test1' }, { test: 'test2' }, { test: 'test3' }]
+ await collection.add(ids, embeddings, metadatas)
+ let count = await collection.count()
+ expect(count).toBe(3)
+ var resp = await collection.delete(undefined, { 'test': 'test1' })
+ count = await collection.count()
+ expect(count).toBe(2)
+})
+
+test('it should return an error when using an unsupported where operator', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = ['test1', 'test2', 'test3']
+ const embeddings = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ ]
+ const metadatas = [{ test: 'test1' }, { test: 'test2' }, { test: 'test3' }]
+ await collection.add(ids, embeddings, metadatas)
+ const results = await collection.get(undefined, { "test": { "$contains": "hello" } });
+ expect(results.error).toBeDefined()
+ expect(results.error).toBe("ValueError('Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got $contains')")
+})
diff --git a/clients/js/test/upsert.collections.test.ts b/clients/js/test/upsert.collections.test.ts
new file mode 100644
index 00000000000..4bbf3cb0981
--- /dev/null
+++ b/clients/js/test/upsert.collections.test.ts
@@ -0,0 +1,29 @@
+import { expect, test } from '@jest/globals';
+import chroma from './initClient'
+import { DOCUMENTS, EMBEDDINGS, IDS } from './data';
+import { METADATAS } from './data';
+
+
+test('it should upsert embeddings to a collection', async () => {
+ await chroma.reset()
+ const collection = await chroma.createCollection('test')
+ const ids = ['test1', 'test2']
+ const embeddings = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ ]
+ await collection.add(ids, embeddings)
+ const count = await collection.count()
+ expect(count).toBe(2)
+
+ const ids2 = ["test2", "test3"]
+ const embeddings2 = [
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 15],
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ ]
+
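+  // "test2" already exists and is updated; "test3" is new, so the count becomes 3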
+ await collection.upsert(ids2, embeddings2)
+
+ const count2 = await collection.count()
+ expect(count2).toBe(3)
+})
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index e3f33840277..55ca469031f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,12 +25,13 @@ dependencies = [
'fastapi >= 0.85.1',
'uvicorn[standard] >= 0.18.3',
'numpy >= 1.21.6',
- 'posthog >= 2.4.0'
+ 'posthog >= 2.4.0',
+ 'typing_extensions >= 4.5.0'
]
[tool.black]
-line-length = 100
-required-version = "22.10.0" # Black will refuse to run if it's not this version.
+line-length = 88
+required-version = "23.3.0" # Black will refuse to run if it's not this version.
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']
[tool.pytest.ini_options]
diff --git a/requirements.txt b/requirements.txt
index 267522d04c8..5bb66359137 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,5 @@ hnswlib==0.7.0
clickhouse-connect==0.5.7
pydantic==1.9.0
sentence-transformers==2.2.2
-posthog==2.4.0
\ No newline at end of file
+posthog==2.4.0
+typing_extensions==4.5.0
\ No newline at end of file
diff --git a/requirements_dev.txt b/requirements_dev.txt
index df8913b3c04..78456ffcafa 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -2,4 +2,6 @@ build
pytest
setuptools_scm
httpx
-black==22.10.0 # match what's in pyproject.toml
+black==23.3.0 # match what's in pyproject.toml
+hypothesis
+hypothesis[numpy]
\ No newline at end of file