diff --git a/.dockerignore b/.dockerignore index 9f25234e15f..4fcf5c716cd 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,9 @@ venv .git -examples \ No newline at end of file +examples +clients +.hypothesis +__pycache__ +.vscode +*.egg-info +.pytest_cache \ No newline at end of file diff --git a/.github/workflows/chroma-integration-test.yml b/.github/workflows/chroma-integration-test.yml new file mode 100644 index 00000000000..7883371fd52 --- /dev/null +++ b/.github/workflows/chroma-integration-test.yml @@ -0,0 +1,37 @@ +name: Chroma Integration Tests + +on: + push: + branches: + - main + - team/hypothesis-tests + pull_request: + branches: + - main + - team/hypothesis-tests + +jobs: + test: + strategy: + matrix: + python: ['3.7'] + platform: [ubuntu-latest] + testfile: ["--ignore-glob 'chromadb/test/property/*'", + "chromadb/test/property/test_add.py", + "chromadb/test/property/test_collections.py", + "chromadb/test/property/test_cross_version_persist.py", + "chromadb/test/property/test_embeddings.py", + "chromadb/test/property/test_filtering.py", + "chromadb/test/property/test_persist.py"] + runs-on: ${{ matrix.platform }} + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + - name: Install test dependencies + run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt + - name: Integration Test + run: bin/integration-test ${{ matrix.testfile }} \ No newline at end of file diff --git a/.github/workflows/chroma-test.yml b/.github/workflows/chroma-test.yml index ec1c7e37c3c..142d0971f4e 100644 --- a/.github/workflows/chroma-test.yml +++ b/.github/workflows/chroma-test.yml @@ -4,16 +4,26 @@ on: push: branches: - main + - team/hypothesis-tests pull_request: branches: - main + - team/hypothesis-tests jobs: test: + timeout-minutes: 90 strategy: matrix: - python: ['3.10'] + python: ['3.7', '3.8', '3.9', '3.10'] platform: [ubuntu-latest] + testfile: ["--ignore-glob 'chromadb/test/property/*'", + "chromadb/test/property/test_add.py", + "chromadb/test/property/test_collections.py", + "chromadb/test/property/test_cross_version_persist.py", + "chromadb/test/property/test_embeddings.py", + "chromadb/test/property/test_filtering.py", + "chromadb/test/property/test_persist.py"] runs-on: ${{ matrix.platform }} steps: - name: Checkout @@ -25,6 +35,4 @@ jobs: - name: Install test dependencies run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt - name: Test - run: python -m pytest - - name: Integration Test - run: bin/integration-test \ No newline at end of file + run: python -m pytest ${{ matrix.testfile }} diff --git a/.github/workflows/pr-review-checklist.yml b/.github/workflows/pr-review-checklist.yml new file mode 100644 index 00000000000..6b7c9d38122 --- /dev/null +++ b/.github/workflows/pr-review-checklist.yml @@ -0,0 +1,37 @@ +name: PR Review Checklist + +on: + pull_request_target: + types: + - opened + +jobs: + PR-Comment: + runs-on: ubuntu-latest + steps: + - name: PR Comment + uses: actions/github-script@v2 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + github.issues.createComment({ + issue_number: ${{ github.event.number }}, + owner: context.repo.owner, + repo: context.repo.repo, + body: `# Reviewer Checklist + Please leverage this checklist to ensure your code review is thorough before approving + ## Testing, Bugs, Errors, Logs, Documentation + - [ ] Can you 
think of any use case in which the code does not behave as intended? Have they been tested? + - [ ] Can you think of any inputs or external events that could break the code? Is user input validated and safe? Have they been tested? + - [ ] If appropriate, are there adequate property based tests? + - [ ] If appropriate, are there adequate unit tests? + - [ ] Should any logging, debugging, tracing information be added or removed? + - [ ] Are error messages user-friendly? + - [ ] Have all documentation changes needed been made? + - [ ] Have all non-obvious changes been commented? + ## System Compatibility + - [ ] Are there any potential impacts on other parts of the system or backward compatibility? + - [ ] Does this change intersect with any items on our roadmap, and if so, is there a plan for fitting them together? + ## Quality + - [ ] Is this code of a unexpectedly high quality (Readbility, Modularity, Intuitiveness)` + }) diff --git a/.gitignore b/.gitignore index e4a4bed1330..de36093c7f6 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,4 @@ dist .terraform.lock.hcl terraform.tfstate .hypothesis/ -.idea \ No newline at end of file +.idea diff --git a/.vscode/settings.json b/.vscode/settings.json index 78903060c50..4ec74c4e3d7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,23 +1,28 @@ { - "git.ignoreLimitWarning": true, - "editor.rulers": [ - 120 - ], - "editor.formatOnSave": true, - "python.formatting.provider": "black", - "files.exclude": { - "**/__pycache__": true, - "**/.ipynb_checkpoints": true, - "**/.pytest_cache": true, - "**/chroma.egg-info": true - }, - "python.analysis.typeCheckingMode": "basic", - "python.linting.flake8Enabled": true, - "python.linting.enabled": true, - "python.linting.flake8Args": [ - "--extend-ignore=E203", - "--extend-ignore=E501", - "--extend-ignore=E503", - "--max-line-length=88", - ], -} + "git.ignoreLimitWarning": true, + "editor.rulers": [ + 120 + ], + "editor.formatOnSave": true, + "python.formatting.provider": "black", + "files.exclude": { + "**/__pycache__": true, + "**/.ipynb_checkpoints": true, + "**/.pytest_cache": true, + "**/chroma.egg-info": true + }, + "python.analysis.typeCheckingMode": "basic", + "python.linting.flake8Enabled": true, + "python.linting.enabled": true, + "python.linting.flake8Args": [ + "--extend-ignore=E203", + "--extend-ignore=E501", + "--extend-ignore=E503", + "--max-line-length=88" + ], + "python.testing.pytestArgs": [ + "." + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/README.md b/README.md index 252ae58b579..8e7dc8339dd 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,10 @@ | License - | + | Docs - | + | Homepage @@ -30,19 +30,19 @@ pip install chromadb # python client The core API is only 4 functions (run our [💡 Google Colab](https://colab.research.google.com/drive/1QEzFyqnoFxq7LUGyP1vzR4iLt9PpCDXv?usp=sharing) or [Replit template](https://replit.com/@swyx/BasicChromaStarter?v=1)): -```python +```python import chromadb # setup Chroma in-memory, for easy prototyping. Can add persistence easily! client = chromadb.Client() # Create collection. get_collection, get_or_create_collection, delete_collection also available! -collection = client.create_collection("all-my-documents") +collection = client.create_collection("all-my-documents") # Add docs to the collection. Can also update and delete. Row-based API coming soon! 
collection.add( documents=["This is document1", "This is document2"], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well metadatas=[{"source": "notion"}, {"source": "google-docs"}], # filter on these! - ids=["doc1", "doc2"], # unique for each doc + ids=["doc1", "doc2"], # unique for each doc ) # Query/search 2 most similar results. You can also .get by id @@ -66,15 +66,15 @@ results = collection.query( For example, the `"Chat your data"` use case: 1. Add documents to your database. You can pass in your own embeddings, embedding function, or let Chroma embed them for you. 2. Query relevant documents with natural language. -3. Compose documents into the context window of an LLM like `GPT3` for additional summarization or analysis. +3. Compose documents into the context window of an LLM like `GPT3` for additional summarization or analysis. ## Embeddings? What are embeddings? - [Read the guide from OpenAI](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings) -- __Literal__: Embedding something turns it from image/text/audio into a list of numbers. 🖼️ or 📄 => `[1.2, 2.1, ....]`. This process makes documents "understandable" to a machine learning model. -- __By analogy__: An embedding represents the essence of a document. This enables documents and queries with the same essence to be "near" each other and therefore easy to find. +- __Literal__: Embedding something turns it from image/text/audio into a list of numbers. 🖼️ or 📄 => `[1.2, 2.1, ....]`. This process makes documents "understandable" to a machine learning model. +- __By analogy__: An embedding represents the essence of a document. This enables documents and queries with the same essence to be "near" each other and therefore easy to find. - __Technical__: An embedding is the latent-space position of a document at a layer of a deep neural network. For models trained specifically to embed data, this is the last layer. - __A small example__: If you search your photos for "famous bridge in San Francisco". By embedding this query and comparing it to the embeddings of your photos and their metadata - it should return photos of the Golden Gate Bridge. @@ -82,7 +82,7 @@ Embeddings databases (also known as **vector databases**) store embeddings and a ## Get involved -Chroma is a rapidly developing project. We welcome PR contributors and ideas for how to improve the project. +Chroma is a rapidly developing project. We welcome PR contributors and ideas for how to improve the project. - [Join the conversation on Discord](https://discord.gg/MMeYNTmh3x) - [Review the roadmap and contribute your ideas](https://docs.trychroma.com/roadmap) - [Grab an issue and open a PR](https://github.com/chroma-core/chroma/issues) diff --git a/bin/integration-test b/bin/integration-test index ea3c4cc2d80..29c49cdb91b 100755 --- a/bin/integration-test +++ b/bin/integration-test @@ -12,15 +12,15 @@ trap cleanup EXIT docker compose -f docker-compose.test.yml up --build -d -export CHROMA_INTEGRATION_TEST=1 +export CHROMA_INTEGRATION_TEST_ONLY=1 export CHROMA_API_IMPL=rest export CHROMA_SERVER_HOST=localhost export CHROMA_SERVER_HTTP_PORT=8000 -python -m pytest +echo testing: python -m pytest "$@" +python -m pytest "$@" cd clients/js yarn yarn test:run -cd ../.. - +cd ../.. 
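The script exports `CHROMA_API_IMPL=rest`, `CHROMA_SERVER_HOST=localhost` and `CHROMA_SERVER_HTTP_PORT=8000` before forwarding its arguments to pytest. A hedged sketch of the equivalent client-side configuration, mirroring the REST fixture added in `chromadb/test/conftest.py` later in this diff (host and port are the values exported above):

```python
import chromadb
from chromadb.config import Settings

# Connect to the server that docker-compose.test.yml brings up on localhost:8000.
client = chromadb.Client(
    Settings(
        chroma_api_impl="rest",
        chroma_server_host="localhost",
        chroma_server_http_port="8000",
    )
)
client.heartbeat()  # raises if the server is not reachable yet
```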
\ No newline at end of file diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 90c3446e88f..e3b6ff35e3b 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -61,7 +61,8 @@ def create_collection( Args: name (str): The name of the collection to create. The name must be unique. metadata (Optional[Dict], optional): A dictionary of metadata to associate with the collection. Defaults to None. - get_or_create (bool, optional): If True, will return the collection if it already exists. Defaults to False. + get_or_create (bool, optional): If True, will return the collection if it already exists, + and update the metadata (if applicable). Defaults to False. embedding_function (Optional[Callable], optional): A function that takes documents and returns an embedding. Defaults to None. Returns: @@ -82,8 +83,11 @@ def delete_collection( """ @abstractmethod - def get_or_create_collection(self, name: str, metadata: Optional[Dict] = None) -> Collection: - """Calls create_collection with get_or_create=True + def get_or_create_collection( + self, name: str, metadata: Optional[Dict] = None + ) -> Collection: + """Calls create_collection with get_or_create=True. + If the collection exists, but with different metadata, the metadata will be replaced. Args: name (str): The name of the collection to create. The name must be unique. @@ -141,7 +145,7 @@ def _add( ⚠️ It is recommended to use the more specific methods below when possible. Args: - collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to + collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to embedding (Sequence[Sequence[float]]): The sequence of embeddings to add metadata (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None. documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None. @@ -162,17 +166,40 @@ def _update( ⚠️ It is recommended to use the more specific methods below when possible. Args: - collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to + collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to embedding (Sequence[Sequence[float]]): The sequence of embeddings to add """ pass + @abstractmethod + def _upsert( + self, + collection_name: str, + ids: IDs, + embeddings: Optional[Embeddings] = None, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + increment_index: bool = True, + ): + """Add or update entries in the embedding store. + If an entry with the same id already exists, it will be updated, otherwise it will be added. + + Args: + collection_name (str): The collection to add the embeddings to + ids (Optional[Union[str, Sequence[str]]], optional): The ids to associate with the embeddings. Defaults to None. + embeddings (Sequence[Sequence[float]]): The sequence of embeddings to add + metadatas (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None. + documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None. + increment_index (bool, optional): If True, will incrementally add to the ANN index of the collection. Defaults to True. 
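A sketch of what this contract means when exercised through the public client; the collection name and values are illustrative, not part of the API docs:

```python
import chromadb

client = chromadb.Client()
coll = client.create_collection("upsert-demo")
coll.add(ids=["a"], embeddings=[[0.0]])
coll.upsert(ids=["a", "b"], embeddings=[[1.0], [2.0]])
assert coll.count() == 2  # "a" was updated in place, "b" was newly added
```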
+ """ + pass + @abstractmethod def _count(self, collection_name: str) -> int: """Returns the number of embeddings in the database Args: - collection_name (str): The model space to count the embeddings in. + collection_name (str): The collection to count the embeddings in. Returns: int: The number of embeddings in the collection @@ -278,14 +305,19 @@ def raw_sql(self, sql: str) -> pd.DataFrame: @abstractmethod def create_index(self, collection_name: Optional[str] = None) -> bool: - """Creates an index for the given model space + """Creates an index for the given collection ⚠️ This method should not be used directly. Args: - collection_name (Optional[str], optional): The model space to create the index for. Uses the client's model space if None. Defaults to None. + collection_name (Optional[str], optional): The collection to create the index for. Uses the client's collection if None. Defaults to None. Returns: bool: True if the index was created successfully """ pass + + @abstractmethod + def persist(self): + """Persist the database to disk""" + pass diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 0be1a087fa7..c5eac52014c 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -15,6 +15,7 @@ from typing import Sequence from chromadb.api.models.Collection import Collection from chromadb.telemetry import Telemetry +import chromadb.errors as errors class FastAPI(API): @@ -26,13 +27,13 @@ def __init__(self, settings, telemetry_client: Telemetry): def heartbeat(self): """Returns the current server time in nanoseconds to check if the server is alive""" resp = requests.get(self._api_url) - resp.raise_for_status() + raise_chroma_error(resp) return int(resp.json()["nanosecond heartbeat"]) def list_collections(self) -> Sequence[Collection]: """Returns a list of all collections""" resp = requests.get(self._api_url + "/collections") - resp.raise_for_status() + raise_chroma_error(resp) json_collections = resp.json() collections = [] for json_collection in json_collections: @@ -52,7 +53,7 @@ def create_collection( self._api_url + "/collections", data=json.dumps({"name": name, "metadata": metadata, "get_or_create": get_or_create}), ) - resp.raise_for_status() + raise_chroma_error(resp) resp_json = resp.json() return Collection( client=self, @@ -68,7 +69,7 @@ def get_collection( ) -> Collection: """Returns a collection""" resp = requests.get(self._api_url + "/collections/" + name) - resp.raise_for_status() + raise_chroma_error(resp) resp_json = resp.json() return Collection( client=self, @@ -93,18 +94,18 @@ def _modify(self, current_name: str, new_name: str, new_metadata: Optional[Dict] self._api_url + "/collections/" + current_name, data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}), ) - resp.raise_for_status() + raise_chroma_error(resp) return resp.json() def delete_collection(self, name: str): """Deletes a collection""" resp = requests.delete(self._api_url + "/collections/" + name) - resp.raise_for_status() + raise_chroma_error(resp) def _count(self, collection_name: str): """Returns the number of embeddings in the database""" resp = requests.get(self._api_url + "/collections/" + collection_name + "/count") - resp.raise_for_status() + raise_chroma_error(resp) return resp.json() def _peek(self, collection_name, limit=10): @@ -147,7 +148,7 @@ def _get( ), ) - resp.raise_for_status() + raise_chroma_error(resp) return resp.json() def _delete(self, collection_name, ids=None, where={}, where_document={}): @@ -158,7 +159,7 @@ def _delete(self, 
collection_name, ids=None, where={}, where_document={}): data=json.dumps({"where": where, "ids": ids, "where_document": where_document}), ) - resp.raise_for_status() + raise_chroma_error(resp) return resp.json() def _add( @@ -180,20 +181,16 @@ def _add( self._api_url + "/collections/" + collection_name + "/add", data=json.dumps( { + "ids": ids, "embeddings": embeddings, "metadatas": metadatas, "documents": documents, - "ids": ids, "increment_index": increment_index, } ), ) - try: - resp.raise_for_status() - except requests.HTTPError: - raise (Exception(resp.text)) - + raise_chroma_error(resp) return True def _update( @@ -224,6 +221,36 @@ def _update( resp.raise_for_status() return True + def _upsert( + self, + collection_name: str, + ids: IDs, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + increment_index: bool = True, + ): + """ + Updates a batch of embeddings in the database + - pass in column oriented data lists + """ + + resp = requests.post( + self._api_url + "/collections/" + collection_name + "/upsert", + data=json.dumps( + { + "ids": ids, + "embeddings": embeddings, + "metadatas": metadatas, + "documents": documents, + "increment_index": increment_index, + } + ), + ) + + resp.raise_for_status() + return True + def _query( self, collection_name, @@ -248,43 +275,60 @@ def _query( ), ) - try: - resp.raise_for_status() - except requests.HTTPError: - raise (Exception(resp.text)) - + raise_chroma_error(resp) body = resp.json() return body def reset(self): """Resets the database""" resp = requests.post(self._api_url + "/reset") - resp.raise_for_status() + raise_chroma_error(resp) return resp.json def persist(self): """Persists the database""" resp = requests.post(self._api_url + "/persist") - resp.raise_for_status() + raise_chroma_error(resp) return resp.json def raw_sql(self, sql): """Runs a raw SQL query against the database""" resp = requests.post(self._api_url + "/raw_sql", data=json.dumps({"raw_sql": sql})) - resp.raise_for_status() + raise_chroma_error(resp) return pd.DataFrame.from_dict(resp.json()) def create_index(self, collection_name: str): """Creates an index for the given space key""" resp = requests.post(self._api_url + "/collections/" + collection_name + "/create_index") - try: - resp.raise_for_status() - except requests.HTTPError: - raise (Exception(resp.text)) + raise_chroma_error(resp) return resp.json() def get_version(self): """Returns the version of the server""" resp = requests.get(self._api_url + "/version") - resp.raise_for_status() + raise_chroma_error(resp) return resp.json() + + +def raise_chroma_error(resp): + """Raises an error if the response is not ok, using a ChromaError if possible""" + if resp.ok: + return + + chroma_error = None + try: + body = resp.json() + if "error" in body: + if body["error"] in errors.error_types: + chroma_error = errors.error_types[body["error"]](body["message"]) + + except BaseException: + pass + + if chroma_error: + raise chroma_error + + try: + resp.raise_for_status() + except requests.HTTPError: + raise (Exception(resp.text)) diff --git a/chromadb/api/local.py b/chromadb/api/local.py index 732dc678cc6..1a30d70401e 100644 --- a/chromadb/api/local.py +++ b/chromadb/api/local.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence, Callable, cast from chromadb import __version__ - +import chromadb.errors as errors from chromadb.api import API from chromadb.db import DB from chromadb.api.types import ( @@ -38,11 +38,11 @@ def check_index_name(index_name): ) 
if len(index_name) < 3 or len(index_name) > 63: raise ValueError(msg) - if not re.match("^[a-z0-9][a-z0-9._-]*[a-z0-9]$", index_name): + if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name): raise ValueError(msg) if ".." in index_name: raise ValueError(msg) - if re.match("^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$", index_name): + if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name): raise ValueError(msg) @@ -90,7 +90,10 @@ def create_collection( res = self._db.create_collection(name, metadata, get_or_create) return Collection( - client=self, name=name, embedding_function=embedding_function, metadata=res[0][2] + client=self, + name=name, + embedding_function=embedding_function, + metadata=res[0][2], ) def get_or_create_collection( @@ -112,7 +115,9 @@ def get_or_create_collection( >>> client.get_or_create_collection("my_collection") collection(name="my_collection", metadata={}) """ - return self.create_collection(name, metadata, embedding_function, get_or_create=True) + return self.create_collection( + name, metadata, embedding_function, get_or_create=True + ) def get_collection( self, @@ -138,7 +143,10 @@ def get_collection( if len(res) == 0: raise ValueError(f"Collection {name} does not exist") return Collection( - client=self, name=name, embedding_function=embedding_function, metadata=res[0][2] + client=self, + name=name, + embedding_function=embedding_function, + metadata=res[0][2], ) def list_collections(self) -> Sequence[Collection]: @@ -154,7 +162,9 @@ def list_collections(self) -> Sequence[Collection]: db_collections = self._db.list_collections() for db_collection in db_collections: collections.append( - Collection(client=self, name=db_collection[1], metadata=db_collection[2]) + Collection( + client=self, name=db_collection[1], metadata=db_collection[2] + ) ) return collections @@ -194,6 +204,11 @@ def _add( documents: Optional[Documents] = None, increment_index: bool = True, ): + existing_ids = self._get(collection_name, ids=ids, include=[])["ids"] + if len(existing_ids) > 0: + raise errors.IDAlreadyExistsError( + f"IDs {existing_ids} already exist in collection {collection_name}" + ) collection_uuid = self._db.get_collection_uuid_from_name(collection_name) added_uuids = self._db.add( @@ -220,6 +235,65 @@ def _update( ): collection_uuid = self._db.get_collection_uuid_from_name(collection_name) self._db.update(collection_uuid, ids, embeddings, metadatas, documents) + return True + + def _upsert( + self, + collection_name: str, + ids: IDs, + embeddings: Embeddings, + metadatas: Optional[Metadatas] = None, + documents: Optional[Documents] = None, + increment_index: bool = True, + ): + # Determine which ids need to be added and which need to be updated based on the ids already in the collection + existing_ids = set(self._get(collection_name, ids=ids, include=[])["ids"]) + + ids_to_add = [] + ids_to_update = [] + embeddings_to_add: Embeddings = [] + embeddings_to_update: Embeddings = [] + metadatas_to_add: Optional[Metadatas] = [] if metadatas else None + metadatas_to_update: Optional[Metadatas] = [] if metadatas else None + documents_to_add: Optional[Documents] = [] if documents else None + documents_to_update: Optional[Documents] = [] if documents else None + + for i, id in enumerate(ids): + if id in existing_ids: + ids_to_update.append(id) + if embeddings is not None: + embeddings_to_update.append(embeddings[i]) + if metadatas is not None: + metadatas_to_update.append(metadatas[i]) + if documents is not None: + 
documents_to_update.append(documents[i]) + else: + ids_to_add.append(id) + if embeddings is not None: + embeddings_to_add.append(embeddings[i]) + if metadatas is not None: + metadatas_to_add.append(metadatas[i]) + if documents is not None: + documents_to_add.append(documents[i]) + + if len(ids_to_add) > 0: + self._add( + ids_to_add, + collection_name, + embeddings_to_add, + metadatas_to_add, + documents_to_add, + increment_index=increment_index, + ) + + if len(ids_to_update) > 0: + self._update( + collection_name, + ids_to_update, + embeddings_to_update, + metadatas_to_update, + documents_to_update, + ) return True @@ -252,7 +326,9 @@ def _get( # Remove plural from include since db columns are singular db_columns = [column[:-1] for column in include] + ["id"] - column_index = {column_name: index for index, column_name in enumerate(db_columns)} + column_index = { + column_name: index for index, column_name in enumerate(db_columns) + } db_result = self._db.get( collection_name=collection_name, @@ -274,11 +350,17 @@ def _get( for entry in db_result: if include_embeddings: - cast(List, get_result["embeddings"]).append(entry[column_index["embedding"]]) + cast(List, get_result["embeddings"]).append( + entry[column_index["embedding"]] + ) if include_documents: - cast(List, get_result["documents"]).append(entry[column_index["document"]]) + cast(List, get_result["documents"]).append( + entry[column_index["document"]] + ) if include_metadatas: - cast(List, get_result["metadatas"]).append(entry[column_index["metadata"]]) + cast(List, get_result["metadatas"]).append( + entry[column_index["metadata"]] + ) get_result["ids"].append(entry[column_index["id"]]) return get_result @@ -291,9 +373,14 @@ def _delete(self, collection_name, ids=None, where=None, where_document=None): collection_uuid = self._db.get_collection_uuid_from_name(collection_name) deleted_uuids = self._db.delete( - collection_uuid=collection_uuid, where=where, ids=ids, where_document=where_document + collection_uuid=collection_uuid, + where=where, + ids=ids, + where_document=where_document, + ) + self._telemetry_client.capture( + CollectionDeleteEvent(collection_uuid, len(deleted_uuids)) ) - self._telemetry_client.capture(CollectionDeleteEvent(collection_uuid, len(deleted_uuids))) return deleted_uuids def _count(self, collection_name): @@ -344,8 +431,12 @@ def _query( ids = [] metadatas = [] # Remove plural from include since db columns are singular - db_columns = [column[:-1] for column in include if column != "distances"] + ["id"] - column_index = {column_name: index for index, column_name in enumerate(db_columns)} + db_columns = [ + column[:-1] for column in include if column != "distances" + ] + ["id"] + column_index = { + column_name: index for index, column_name in enumerate(db_columns) + } db_result = self._db.get_by_ids(uuids[i], columns=db_columns) for entry in db_result: diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 9e8b2924c19..6aa9958d6df 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, cast, List, Dict +from typing import TYPE_CHECKING, Optional, cast, List, Dict, Tuple from pydantic import BaseModel, PrivateAttr from chromadb.api.types import ( @@ -42,7 +42,6 @@ def __init__( embedding_function: Optional[EmbeddingFunction] = None, metadata: Optional[Dict] = None, ): - self._client = client if embedding_function is not None: self._embedding_function = embedding_function @@ -95,36 +94,13 
@@ def add( """ - ids = validate_ids(maybe_cast_one_to_many(ids)) - embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None - metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None - documents = maybe_cast_one_to_many(documents) if documents else None - - # Check that one of embeddings or documents is provided - if embeddings is None and documents is None: - raise ValueError("You must provide either embeddings or documents, or both") - - # Check that, if they're provided, the lengths of the arrays match the length of ids - if embeddings is not None and len(embeddings) != len(ids): - raise ValueError( - f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}" - ) - if metadatas is not None and len(metadatas) != len(ids): - raise ValueError( - f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}" - ) - if documents is not None and len(documents) != len(ids): - raise ValueError( - f"Number of documents {len(documents)} must match number of ids {len(ids)}" - ) - - # If document embeddings are not provided, we need to compute them - if embeddings is None and documents is not None: - if self._embedding_function is None: - raise ValueError("You must provide embeddings or a function to compute them") - embeddings = self._embedding_function(documents) + ids, embeddings, metadatas, documents = self._validate_embedding_set( + ids, embeddings, metadatas, documents + ) - self._client._add(ids, self.name, embeddings, metadatas, documents, increment_index) + self._client._add( + ids, self.name, embeddings, metadatas, documents, increment_index + ) def get( self, @@ -151,7 +127,9 @@ def get( """ where = validate_where(where) if where else None - where_document = validate_where_document(where_document) if where_document else None + where_document = ( + validate_where_document(where_document) if where_document else None + ) ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None include = validate_include(include, allow_distances=False) return self._client._get( @@ -204,8 +182,12 @@ def query( """ where = validate_where(where) if where else None - where_document = validate_where_document(where_document) if where_document else None - query_embeddings = maybe_cast_one_to_many(query_embeddings) if query_embeddings else None + where_document = ( + validate_where_document(where_document) if where_document else None + ) + query_embeddings = ( + maybe_cast_one_to_many(query_embeddings) if query_embeddings else None + ) query_texts = maybe_cast_one_to_many(query_texts) if query_texts else None include = validate_include(include, allow_distances=True) @@ -220,9 +202,13 @@ def query( # If query_embeddings are not provided, we need to compute them from the query_texts if query_embeddings is None: if self._embedding_function is None: - raise ValueError("You must provide embeddings or a function to compute them") + raise ValueError( + "You must provide embeddings or a function to compute them" + ) # We know query texts is not None at this point, cast for the typechecker - query_embeddings = self._embedding_function(cast(List[Document], query_texts)) + query_embeddings = self._embedding_function( + cast(List[Document], query_texts) + ) if where is None: where = {} @@ -249,7 +235,9 @@ def modify(self, name: Optional[str] = None, metadata=None): Returns: None """ - self._client._modify(current_name=self.name, new_name=name, new_metadata=metadata) + self._client._modify( + current_name=self.name, new_name=name, 
new_metadata=metadata + ) if name: self.name = name if metadata: @@ -274,40 +262,41 @@ def update( None """ - ids = validate_ids(maybe_cast_one_to_many(ids)) - embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None - metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None - documents = maybe_cast_one_to_many(documents) if documents else None + ids, embeddings, metadatas, documents = self._validate_embedding_set( + ids, embeddings, metadatas, documents, require_embeddings_or_documents=False + ) - # Must update one of embeddings, metadatas, or documents - if embeddings is None and documents is None and metadatas is None: - raise ValueError("You must update at least one of embeddings, documents or metadatas.") + self._client._update(self.name, ids, embeddings, metadatas, documents) - # Check that one of embeddings or documents is provided - if embeddings is not None and documents is None: - raise ValueError("You must provide updated documents with updated embeddings") + def upsert( + self, + ids: OneOrMany[ID], + embeddings: Optional[OneOrMany[Embedding]] = None, + metadatas: Optional[OneOrMany[Metadata]] = None, + documents: Optional[OneOrMany[Document]] = None, + increment_index: bool = True, + ): + """Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist. - # Check that, if they're provided, the lengths of the arrays match the length of ids - if embeddings is not None and len(embeddings) != len(ids): - raise ValueError( - f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}" - ) - if metadatas is not None and len(metadatas) != len(ids): - raise ValueError( - f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}" - ) - if documents is not None and len(documents) != len(ids): - raise ValueError( - f"Number of documents {len(documents)} must match number of ids {len(ids)}" - ) + Args: + ids: The ids of the embeddings to update + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional. + metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. + documents: The documents to associate with the embeddings. Optional. 
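As the docstring notes, documents may be passed instead of embeddings, in which case the collection's embedding function is used. A hedged usage sketch, reusing the ids and metadata from the README example above:

```python
import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection("all-my-documents")
collection.upsert(
    ids=["doc1", "doc2"],
    documents=["revised text for doc1", "a brand-new doc2"],  # embedded by the collection's embedding function
    metadatas=[{"source": "notion"}, {"source": "google-docs"}],
)
```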
+ """ - # If document embeddings are not provided, we need to compute them - if embeddings is None and documents is not None: - if self._embedding_function is None: - raise ValueError("You must provide embeddings or a function to compute them") - embeddings = self._embedding_function(documents) + ids, embeddings, metadatas, documents = self._validate_embedding_set( + ids, embeddings, metadatas, documents + ) - self._client._update(self.name, ids, embeddings, metadatas, documents) + self._client._upsert( + collection_name=self.name, + ids=ids, + embeddings=embeddings, + metadatas=metadatas, + documents=documents, + increment_index=increment_index, + ) def delete( self, @@ -327,8 +316,65 @@ def delete( """ ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None where = validate_where(where) if where else None - where_document = validate_where_document(where_document) if where_document else None + where_document = ( + validate_where_document(where_document) if where_document else None + ) return self._client._delete(self.name, ids, where, where_document) def create_index(self): self._client.create_index(self.name) + + def _validate_embedding_set( + self, + ids, + embeddings, + metadatas, + documents, + require_embeddings_or_documents=True, + ) -> Tuple[ + IDs, + Optional[List[Embedding]], + Optional[List[Metadata]], + Optional[List[Document]], + ]: + ids = validate_ids(maybe_cast_one_to_many(ids)) + embeddings = ( + maybe_cast_one_to_many(embeddings) if embeddings is not None else None + ) + metadatas = ( + validate_metadatas(maybe_cast_one_to_many(metadatas)) + if metadatas is not None + else None + ) + documents = maybe_cast_one_to_many(documents) if documents is not None else None + + # Check that one of embeddings or documents is provided + if require_embeddings_or_documents: + if embeddings is None and documents is None: + raise ValueError( + "You must provide either embeddings or documents, or both" + ) + + # Check that, if they're provided, the lengths of the arrays match the length of ids + if embeddings is not None and len(embeddings) != len(ids): + raise ValueError( + f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}" + ) + if metadatas is not None and len(metadatas) != len(ids): + raise ValueError( + f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}" + ) + if documents is not None and len(documents) != len(ids): + raise ValueError( + f"Number of documents {len(documents)} must match number of ids {len(ids)}" + ) + + # If document embeddings are not provided, we need to compute them + if embeddings is None and documents is not None: + if self._embedding_function is None: + raise ValueError( + "You must provide embeddings or a function to compute them" + ) + embeddings = self._embedding_function(documents) + + return ids, embeddings, metadatas, documents diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 39d50d723d4..edd303c186c 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,4 +1,6 @@ -from typing import Literal, Optional, Union, Dict, Sequence, TypedDict, Protocol, TypeVar, List +from typing import Optional, Union, Dict, Sequence, TypeVar, List +from typing_extensions import Literal, TypedDict, Protocol +import chromadb.errors as errors ID = str IDs = List[ID] @@ -26,7 +28,9 @@ WhereOperator = Literal["$gt", "$gte", "$lt", "$lte", "$ne", "$eq"] OperatorExpression = Dict[Union[WhereOperator, LogicalOperator], LiteralValue] -Where = Dict[Union[str, LogicalOperator], Union[LiteralValue, 
OperatorExpression, List["Where"]]] +Where = Dict[ + Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]] +] WhereDocumentOperator = Literal["$contains", LogicalOperator] WhereDocument = Dict[WhereDocumentOperator, Union[str, List["WhereDocument"]]] @@ -84,6 +88,11 @@ def validate_ids(ids: IDs) -> IDs: for id in ids: if not isinstance(id, str): raise ValueError(f"Expected ID to be a str, got {id}") + if len(ids) != len(set(ids)): + dups = set([x for x in ids if ids.count(x) > 1]) + raise errors.DuplicateIDError( + f"Expected IDs to be unique, found duplicates for: {dups}" + ) return ids @@ -95,7 +104,9 @@ def validate_metadata(metadata: Metadata) -> Metadata: if not isinstance(key, str): raise ValueError(f"Expected metadata key to be a str, got {key}") if not isinstance(value, (str, int, float)): - raise ValueError(f"Expected metadata value to be a str, int, or float, got {value}") + raise ValueError( + f"Expected metadata value to be a str, int, or float, got {value}" + ) return metadata @@ -118,7 +129,11 @@ def validate_where(where: Where) -> Where: for key, value in where.items(): if not isinstance(key, str): raise ValueError(f"Expected where key to be a str, got {key}") - if key != "$and" and key != "$or" and not isinstance(value, (str, int, float, dict)): + if ( + key != "$and" + and key != "$or" + and not isinstance(value, (str, int, float, dict)) + ): raise ValueError( f"Expected where value to be a str, int, float, or operator expression, got {value}" ) @@ -167,7 +182,9 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument: a list of where_document expressions """ if not isinstance(where_document, dict): - raise ValueError(f"Expected where document to be a dictionary, got {where_document}") + raise ValueError( + f"Expected where document to be a dictionary, got {where_document}" + ) if len(where_document) != 1: raise ValueError( f"Expected where document to have exactly one operator, got {where_document}" diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py index b1b6874cdf2..7f1b2969fef 100644 --- a/chromadb/db/clickhouse.py +++ b/chromadb/db/clickhouse.py @@ -1,4 +1,11 @@ -from chromadb.api.types import Documents, Embeddings, IDs, Metadatas, Where, WhereDocument +from chromadb.api.types import ( + Documents, + Embeddings, + IDs, + Metadatas, + Where, + WhereDocument, +) from chromadb.db import DB from chromadb.db.index.hnswlib import Hnswlib, delete_all_indexes from chromadb.errors import ( @@ -53,7 +60,8 @@ def __init__(self, settings): def _init_conn(self): common.set_setting("autogenerate_session_id", False) self._conn = clickhouse_connect.get_client( - host=self._settings.clickhouse_host, port=int(self._settings.clickhouse_port) + host=self._settings.clickhouse_host, + port=int(self._settings.clickhouse_port), ) self._create_table_collections(self._conn) self._create_table_embeddings(self._conn) @@ -100,7 +108,9 @@ def _delete_index(self, collection_id): # UTILITY METHODS # def persist(self): - raise NotImplementedError("Clickhouse is a persistent database, this method is not needed") + raise NotImplementedError( + "Clickhouse is a persistent database, this method is not needed" + ) def get_collection_uuid_from_name(self, name: str) -> str: res = self._get_conn().query( @@ -143,6 +153,9 @@ def create_collection( if len(dupe_check) > 0: if get_or_create: + if dupe_check[0][2] != metadata: + self.update_collection(name, new_name=name, new_metadata=metadata) + dupe_check = self.get_collection(name) 
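The ID checks introduced above surface to callers as catchable exceptions: `DuplicateIDError` from `validate_ids` and `IDAlreadyExistsError` from the local `_add`. A hedged sketch of the expected behaviour, including recovering via the new `upsert` (collection name and values are illustrative):

```python
import chromadb
from chromadb.errors import DuplicateIDError, IDAlreadyExistsError

client = chromadb.Client()
coll = client.create_collection("id-checks")

try:
    coll.add(ids=["x", "x"], embeddings=[[0.0], [0.0]])  # duplicate ids within a single call
except DuplicateIDError:
    pass

coll.add(ids=["x"], embeddings=[[0.0]])
try:
    coll.add(ids=["x"], embeddings=[[1.0]])  # id already stored in the collection
except IDAlreadyExistsError:
    coll.upsert(ids=["x"], embeddings=[[1.0]])  # upsert updates instead of raising
```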
logger.info( f"collection with name {name} already exists, returning existing collection" ) @@ -189,7 +202,10 @@ def list_collections(self) -> Sequence: return [[x[0], x[1], json.loads(x[2])] for x in res] def update_collection( - self, current_name: str, new_name: Optional[str] = None, new_metadata: Optional[Dict] = None + self, + current_name: str, + new_name: Optional[str] = None, + new_metadata: Optional[Dict] = None, ): if new_name is None: new_name = current_name @@ -202,11 +218,11 @@ def update_collection( ALTER TABLE collections UPDATE - metadata = '{json.dumps(new_metadata)}', - name = '{new_name}' + metadata = %s, + name = %s WHERE - name = '{current_name}' - """ + name = %s + """, [json.dumps(new_metadata), new_name, current_name] ) def delete_collection(self, name: str): @@ -241,7 +257,14 @@ def add(self, collection_uuid, embeddings, metadatas, documents, ids): ] for i, embedding in enumerate(embeddings) ] - column_names = ["collection_uuid", "uuid", "embedding", "metadata", "document", "id"] + column_names = [ + "collection_uuid", + "uuid", + "embedding", + "metadata", + "document", + "id", + ] self._get_conn().insert("embeddings", data_to_insert, column_names=column_names) return [x[1] for x in data_to_insert] # return uuids @@ -260,26 +283,28 @@ def _update( update_fields = [] parameters[f"i{i}"] = ids[i] if embeddings is not None: - update_fields.append(f"embedding = {{e{i}:Array(Float64)}}") + update_fields.append(f"embedding = %(e{i})s") parameters[f"e{i}"] = embeddings[i] if metadatas is not None: - update_fields.append(f"metadata = {{m{i}:String}}") + update_fields.append(f"metadata = %(m{i})s") parameters[f"m{i}"] = json.dumps(metadatas[i]) if documents is not None: - update_fields.append(f"document = {{d{i}:String}}") + update_fields.append(f"document = %(d{i})s") parameters[f"d{i}"] = documents[i] update_statement = f""" UPDATE {",".join(update_fields)} WHERE - id = {{i{i}:String}} AND + id = %(i{i})s AND collection_uuid = '{collection_uuid}'{"" if i == len(ids) - 1 else ","} """ updates.append(update_statement) update_clauses = ("").join(updates) - self._get_conn().command(f"ALTER TABLE embeddings {update_clauses}", parameters=parameters) + self._get_conn().command( + f"ALTER TABLE embeddings {update_clauses}", parameters=parameters + ) def update( self, @@ -292,14 +317,19 @@ def update( # Verify all IDs exist existing_items = self.get(collection_uuid=collection_uuid, ids=ids) if len(existing_items) != len(ids): - raise ValueError(f"Could not find {len(ids) - len(existing_items)} items for update") + raise ValueError( + f"Could not find {len(ids) - len(existing_items)} items for update" + ) # Update the db self._update(collection_uuid, ids, embeddings, metadatas, documents) # Update the index if embeddings is not None: - update_uuids = [x[1] for x in existing_items] + # `get` current returns items in arbitrary order. + # TODO if we fix `get`, we can remove this explicit mapping. 
+ uuid_mapping = {r[4]: r[1] for r in existing_items} + update_uuids = [uuid_mapping[id] for id in ids] index = self._index(collection_uuid) index.add(update_uuids, embeddings, update=True) @@ -318,37 +348,59 @@ def _get(self, where={}, columns: Optional[List] = None): if "metadata" in select_columns: metadata_column_index = select_columns.index("metadata") db_metadata = val[i][metadata_column_index] - val[i][metadata_column_index] = json.loads(db_metadata) if db_metadata else None + val[i][metadata_column_index] = ( + json.loads(db_metadata) if db_metadata else None + ) return val def _format_where(self, where, result): for key, value in where.items(): + + def has_key_and(clause): + return f"(JSONHas(metadata,'{key}') = 1 AND {clause})" + # Shortcut for $eq if type(value) == str: - result.append(f" JSONExtractString(metadata,'{key}') = '{value}'") + result.append(has_key_and(f" JSONExtractString(metadata,'{key}') = '{value}'")) elif type(value) == int: - result.append(f" JSONExtractInt(metadata,'{key}') = {value}") + result.append(has_key_and(f" JSONExtractInt(metadata,'{key}') = {value}")) elif type(value) == float: - result.append(f" JSONExtractFloat(metadata,'{key}') = {value}") + result.append(has_key_and(f" JSONExtractFloat(metadata,'{key}') = {value}")) # Operator expression elif type(value) == dict: operator, operand = list(value.items())[0] if operator == "$gt": - return result.append(f" JSONExtractFloat(metadata,'{key}') > {operand}") + return result.append( + has_key_and(f" JSONExtractFloat(metadata,'{key}') > {operand}") + ) elif operator == "$lt": - return result.append(f" JSONExtractFloat(metadata,'{key}') < {operand}") + return result.append( + has_key_and(f" JSONExtractFloat(metadata,'{key}') < {operand}") + ) elif operator == "$gte": - return result.append(f" JSONExtractFloat(metadata,'{key}') >= {operand}") + return result.append( + has_key_and(f" JSONExtractFloat(metadata,'{key}') >= {operand}") + ) elif operator == "$lte": - return result.append(f" JSONExtractFloat(metadata,'{key}') <= {operand}") + return result.append( + has_key_and(f" JSONExtractFloat(metadata,'{key}') <= {operand}") + ) elif operator == "$ne": if type(operand) == str: - return result.append(f" JSONExtractString(metadata,'{key}') != '{operand}'") - return result.append(f" JSONExtractFloat(metadata,'{key}') != {operand}") + return result.append( + has_key_and(f" JSONExtractString(metadata,'{key}') != '{operand}'") + ) + return result.append( + has_key_and(f" JSONExtractFloat(metadata,'{key}') != {operand}") + ) elif operator == "$eq": if type(operand) == str: - return result.append(f" JSONExtractString(metadata,'{key}') = '{operand}'") - return result.append(f" JSONExtractFloat(metadata,'{key}') = {operand}") + return result.append( + has_key_and(f" JSONExtractString(metadata,'{key}') = '{operand}'") + ) + return result.append( + has_key_and(f" JSONExtractFloat(metadata,'{key}') = {operand}") + ) else: raise ValueError( f"Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got {operator}" @@ -396,7 +448,9 @@ def get( columns: Optional[List[str]] = None, ) -> Sequence: if collection_name is None and collection_uuid is None: - raise TypeError("Arguments collection_name and collection_uuid cannot both be None") + raise TypeError( + "Arguments collection_name and collection_uuid cannot both be None" + ) if collection_name is not None: collection_uuid = self.get_collection_uuid_from_name(collection_name) @@ -426,7 +480,11 @@ def get( def _count(self, collection_uuid: str): where_string = f"WHERE 
collection_uuid = '{collection_uuid}'" - return self._get_conn().query(f"SELECT COUNT() FROM embeddings {where_string}").result_rows + return ( + self._get_conn() + .query(f"SELECT COUNT() FROM embeddings {where_string}") + .result_rows + ) def count(self, collection_name: str): collection_uuid = self.get_collection_uuid_from_name(collection_name) @@ -434,7 +492,9 @@ def count(self, collection_name: str): def _delete(self, where_str: Optional[str] = None) -> List: deleted_uuids = ( - self._get_conn().query(f"""SELECT uuid FROM embeddings {where_str}""").result_rows + self._get_conn() + .query(f"""SELECT uuid FROM embeddings {where_str}""") + .result_rows ) self._get_conn().command( f""" @@ -494,17 +554,20 @@ def get_nearest_neighbors( collection_name=None, collection_uuid=None, ) -> Tuple[List[List[uuid.UUID]], npt.NDArray]: - # Either the collection name or the collection uuid must be provided if collection_name is None and collection_uuid is None: - raise TypeError("Arguments collection_name and collection_uuid cannot both be None") + raise TypeError( + "Arguments collection_name and collection_uuid cannot both be None" + ) if collection_name is not None: collection_uuid = self.get_collection_uuid_from_name(collection_name) if len(where) != 0 or len(where_document) != 0: results = self.get( - collection_uuid=collection_uuid, where=where, where_document=where_document + collection_uuid=collection_uuid, + where=where, + where_document=where_document, ) if len(results) > 0: diff --git a/chromadb/db/duckdb.py b/chromadb/db/duckdb.py index 44e35655906..5fd218c8146 100644 --- a/chromadb/db/duckdb.py +++ b/chromadb/db/duckdb.py @@ -13,6 +13,7 @@ import uuid import os import logging +import atexit logger = logging.getLogger(__name__) @@ -40,7 +41,6 @@ def clickhouse_to_duckdb_schema(table_schema): class DuckDB(Clickhouse): # duckdb has a different way of connecting to the database def __init__(self, settings): - self._conn = duckdb.connect() self._create_table_collections() self._create_table_embeddings() @@ -68,9 +68,9 @@ def _create_table_embeddings(self): # UTILITY METHODS # def get_collection_uuid_from_name(self, name): - return self._conn.execute("SELECT uuid FROM collections WHERE name = ?", [name]).fetchall()[ - 0 - ][0] + return self._conn.execute( + "SELECT uuid FROM collections WHERE name = ?", [name] + ).fetchall()[0][0] # # COLLECTION METHODS @@ -82,6 +82,10 @@ def create_collection( dupe_check = self.get_collection(name) if len(dupe_check) > 0: if get_or_create is True: + if dupe_check[0][2] != metadata: + self.update_collection(name, new_name=name, new_metadata=metadata) + dupe_check = self.get_collection(name) + logger.info( f"collection with name {name} already exists, returning existing collection" ) @@ -97,12 +101,16 @@ def create_collection( return [[str(collection_uuid), name, metadata]] def get_collection(self, name: str) -> Sequence: - res = self._conn.execute("""SELECT * FROM collections WHERE name = ?""", [name]).fetchall() + res = self._conn.execute( + """SELECT * FROM collections WHERE name = ?""", [name] + ).fetchall() # json.loads the metadata return [[x[0], x[1], json.loads(x[2])] for x in res] def get_collection_by_id(self, uuid: str) -> Sequence: - res = self._conn.execute("""SELECT * FROM collections WHERE uuid = ?""", [uuid]).fetchone() + res = self._conn.execute( + """SELECT * FROM collections WHERE uuid = ?""", [uuid] + ).fetchone() return [res[0], res[1], json.loads(res[2])] def list_collections(self) -> Sequence: @@ -172,20 +180,32 @@ def _format_where(self, 
where, result): if type(value) == str: result.append(f" json_extract_string(metadata,'$.{key}') = '{value}'") if type(value) == int: - result.append(f" CAST(json_extract(metadata,'$.{key}') AS INT) = {value}") + result.append( + f" CAST(json_extract(metadata,'$.{key}') AS INT) = {value}" + ) if type(value) == float: - result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) = {value}") + result.append( + f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) = {value}" + ) # Operator expression elif type(value) == dict: operator, operand = list(value.items())[0] if operator == "$gt": - result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) > {operand}") + result.append( + f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) > {operand}" + ) elif operator == "$lt": - result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) < {operand}") + result.append( + f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) < {operand}" + ) elif operator == "$gte": - result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) >= {operand}") + result.append( + f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) >= {operand}" + ) elif operator == "$lte": - result.append(f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) <= {operand}") + result.append( + f" CAST(json_extract(metadata,'$.{key}') AS DOUBLE) <= {operand}" + ) elif operator == "$ne": if type(operand) == str: return result.append( @@ -215,7 +235,9 @@ def _format_where(self, where, result): elif key == "$and": result.append(f"({' AND '.join(all_subresults)})") else: - raise ValueError(f"Operator {key} not supported with a list of where clauses") + raise ValueError( + f"Operator {key} not supported with a list of where clauses" + ) def _format_where_document(self, where_document, results): operator = list(where_document.keys())[0] @@ -335,7 +357,9 @@ def get_by_ids(self, ids: List, columns: Optional[List] = None): ).fetchall() # sort db results by the order of the uuids - response = sorted(response, key=lambda obj: ids.index(uuid.UUID(obj[len(columns) - 1]))) + response = sorted( + response, key=lambda obj: ids.index(uuid.UUID(obj[len(columns) - 1])) + ) return response @@ -374,6 +398,8 @@ def __init__(self, settings): self._save_folder = settings.persist_directory self.load() + # https://docs.python.org/3/library/atexit.html + atexit.register(self.persist) def set_save_folder(self, path): self._save_folder = path @@ -385,7 +411,9 @@ def persist(self): """ Persist the database to disk """ - logger.info(f"Persisting DB to disk, putting it in the save folder: {self._save_folder}") + logger.info( + f"Persisting DB to disk, putting it in the save folder: {self._save_folder}" + ) if self._conn is None: return @@ -426,7 +454,9 @@ def load(self): logger.info(f"No existing DB found in {self._save_folder}, skipping load") else: path = self._save_folder + "/chroma-embeddings.parquet" - self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');") + self._conn.execute( + f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');" + ) logger.info( f"""loaded in {self._conn.query(f"SELECT COUNT() FROM embeddings").fetchall()[0][0]} embeddings""" ) @@ -436,14 +466,16 @@ def load(self): logger.info(f"No existing DB found in {self._save_folder}, skipping load") else: path = self._save_folder + "/chroma-collections.parquet" - self._conn.execute(f"INSERT INTO collections SELECT * FROM read_parquet('{path}');") + self._conn.execute( + f"INSERT INTO collections SELECT * FROM read_parquet('{path}');" + ) 
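With `persist()` now registered via `atexit`, a `duckdb+parquet` client is flushed to disk on normal interpreter shutdown instead of relying on `__del__`; calling it explicitly still works. A minimal sketch, assuming an illustrative persist directory:

```python
import chromadb
from chromadb.config import Settings

client = chromadb.Client(
    Settings(
        chroma_api_impl="local",
        chroma_db_impl="duckdb+parquet",
        persist_directory="/tmp/chroma-demo",  # illustrative path
    )
)
client.create_collection("notes").add(ids=["n1"], embeddings=[[0.1, 0.2]])
client.persist()  # optional now; atexit also triggers this on a clean exit
```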
logger.info( f"""loaded in {self._conn.query(f"SELECT COUNT() FROM collections").fetchall()[0][0]} collections""" ) def __del__(self): - logger.info("PersistentDuckDB del, about to run persist") - self.persist() + # No-op for duckdb with persistence since the base class will delete the indexes + pass def reset(self): super().reset() diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index 0cfa760d6bb..f00aadd5fbc 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -2,10 +2,15 @@ import pickle import time from typing import Dict + from chromadb.api.types import IndexMetadata import hnswlib from chromadb.db.index import Index -from chromadb.errors import NoIndexException, InvalidDimensionException, NotEnoughElementsException +from chromadb.errors import ( + NoIndexException, + InvalidDimensionException, + NotEnoughElementsException, +) import logging import re from uuid import UUID @@ -24,7 +29,6 @@ class HnswParams: - space: str construction_ef: int search_ef: int @@ -33,7 +37,6 @@ class HnswParams: resize_factor: float def __init__(self, metadata): - metadata = metadata or {} # Convert all values to strings for future compatibility. @@ -44,7 +47,9 @@ def __init__(self, metadata): if param not in valid_params: raise ValueError(f"Unknown HNSW parameter: {param}") if not re.match(valid_params[param], value): - raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}") + raise ValueError( + f"Invalid value for HNSW parameter: {param} = {value}" + ) self.space = metadata.get("hnsw:space", "l2") self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) @@ -71,7 +76,7 @@ class Hnswlib(Index): _index_metadata: IndexMetadata _params: HnswParams _id_to_label: Dict[str, int] - _label_to_id: Dict[int, str] + _label_to_id: Dict[int, UUID] def __init__(self, id, settings, metadata): self._save_folder = settings.persist_directory + "/index" @@ -128,7 +133,7 @@ def add(self, ids, embeddings, update=False): labels = [] for id in ids: - if id in self._id_to_label: + if hexid(id) in self._id_to_label: if update: labels.append(self._id_to_label[hexid(id)]) else: @@ -141,7 +146,9 @@ def add(self, ids, embeddings, update=False): labels.append(next_label) if self._index_metadata["elements"] > self._index.get_max_elements(): - new_size = max(self._index_metadata["elements"] * self._params.resize_factor, 1000) + new_size = max( + self._index_metadata["elements"] * self._params.resize_factor, 1000 + ) self._index.resize_index(int(new_size)) self._index.add_items(embeddings, labels) @@ -196,7 +203,6 @@ def _exists(self): return def _load(self): - if not os.path.exists(f"{self._save_folder}/index_{self._id}.bin"): return @@ -208,7 +214,9 @@ def _load(self): with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "rb") as f: self._index_metadata = pickle.load(f) - p = hnswlib.Index(space=self._params.space, dim=self._index_metadata["dimensionality"]) + p = hnswlib.Index( + space=self._params.space, dim=self._index_metadata["dimensionality"] + ) self._index = p self._index.load_index( f"{self._save_folder}/index_{self._id}.bin", @@ -218,9 +226,10 @@ def _load(self): self._index.set_num_threads(self._params.num_threads) def get_nearest_neighbors(self, query, k, ids=None): - if self._index is None: - raise NoIndexException("Index not found, please create an instance before querying") + raise NoIndexException( + "Index not found, please create an instance before querying" + ) # Check dimensionality self._check_dimensionality(query) @@ 
-245,8 +254,12 @@ def get_nearest_neighbors(self, query, k, ids=None): logger.debug(f"time to pre process our knn query: {time.time() - s2}") s3 = time.time() - database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function) + database_labels, distances = self._index.knn_query( + query, k=k, filter=filter_function + ) logger.debug(f"time to run knn query: {time.time() - s3}") - ids = [[self._label_to_id[label] for label in labels] for labels in database_labels] + ids = [ + [self._label_to_id[label] for label in labels] for labels in database_labels + ] return ids, distances diff --git a/chromadb/errors.py b/chromadb/errors.py index 0a75c43f5b2..ee58ee975f1 100644 --- a/chromadb/errors.py +++ b/chromadb/errors.py @@ -1,14 +1,66 @@ -class NoDatapointsException(Exception): - pass +from abc import ABCMeta, abstractmethod -class NoIndexException(Exception): - pass +class ChromaError(Exception): + def code(self): + """Return an appropriate HTTP response code for this error""" + return 400 # Bad Request -class InvalidDimensionException(Exception): - pass + def message(self): + return ", ".join(self.args) + @classmethod + @abstractmethod + def name(self): + """Return the error name""" + pass -class NotEnoughElementsException(Exception): - pass + +class NoDatapointsException(ChromaError): + @classmethod + def name(cls): + return "NoDatapoints" + + +class NoIndexException(ChromaError): + @classmethod + def name(cls): + return "NoIndex" + + +class InvalidDimensionException(ChromaError): + @classmethod + def name(cls): + return "InvalidDimension" + + +class NotEnoughElementsException(ChromaError): + @classmethod + def name(cls): + return "NotEnoughElements" + + +class IDAlreadyExistsError(ChromaError): + + def code(self): + return 409 # Conflict + + @classmethod + def name(cls): + return "IDAlreadyExists" + + +class DuplicateIDError(ChromaError): + @classmethod + def name(cls): + return "DuplicateID" + +error_types = { + "NoDatapoints": NoDatapointsException, + "NoIndex": NoIndexException, + "InvalidDimension": InvalidDimensionException, + "NotEnoughElements": NotEnoughElementsException, + "IDAlreadyExists": IDAlreadyExistsError, + "DuplicateID": DuplicateIDError, +} diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index d24515ef8bb..09cf19a58a7 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -9,6 +9,7 @@ import chromadb import chromadb.server from chromadb.errors import ( + ChromaError, NoDatapointsException, InvalidDimensionException, NotEnoughElementsException, @@ -45,6 +46,10 @@ def use_route_names_as_operation_ids(app: _FastAPI) -> None: async def catch_exceptions_middleware(request: Request, call_next): try: return await call_next(request) + except ChromaError as e: + return JSONResponse(content={"error": e.name(), + "message": e.message()}, + status_code=e.code()) except Exception as e: logger.exception(e) return JSONResponse(content={"error": repr(e)}, status_code=500) @@ -86,6 +91,9 @@ def __init__(self, settings): self.router.add_api_route( "/api/v1/collections/{collection_name}/update", self.update, methods=["POST"] ) + self.router.add_api_route( + "/api/v1/collections/{collection_name}/upsert", self.upsert, methods=["POST"] + ) self.router.add_api_route( "/api/v1/collections/{collection_name}/get", self.get, methods=["POST"] ) @@ -180,6 +188,16 @@ def update(self, collection_name: str, add: UpdateEmbedding): metadatas=add.metadatas, ) + def upsert(self, collection_name: str, upsert: 
AddEmbedding): + return self._api._upsert( + collection_name=collection_name, + ids=upsert.ids, + embeddings=upsert.embeddings, + documents=upsert.documents, + metadatas=upsert.metadatas, + increment_index=upsert.increment_index, + ) + def get(self, collection_name, get: GetEmbedding): return self._api._get( collection_name=collection_name, @@ -207,22 +225,15 @@ def reset(self): return self._api.reset() def get_nearest_neighbors(self, collection_name, query: QueryEmbedding): - try: - nnresult = self._api._query( - collection_name=collection_name, - where=query.where, - where_document=query.where_document, - query_embeddings=query.query_embeddings, - n_results=query.n_results, - include=query.include, - ) - return nnresult - except NoDatapointsException as e: - raise HTTPException(status_code=500, detail=str(e)) - except InvalidDimensionException as e: - raise HTTPException(status_code=500, detail=str(e)) - except NotEnoughElementsException as e: - raise HTTPException(status_code=500, detail=str(e)) + nnresult = self._api._query( + collection_name=collection_name, + where=query.where, + where_document=query.where_document, + query_embeddings=query.query_embeddings, + n_results=query.n_results, + include=query.include, + ) + return nnresult def raw_sql(self, raw_sql: RawSql): return self._api.raw_sql(raw_sql.raw_sql) diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py new file mode 100644 index 00000000000..88729bb6ffa --- /dev/null +++ b/chromadb/test/conftest.py @@ -0,0 +1,105 @@ +from chromadb.config import Settings +from chromadb import Client +from chromadb.api import API +import chromadb.server.fastapi +from requests.exceptions import ConnectionError +import hypothesis +import tempfile +import os +import uvicorn +import time +from multiprocessing import Process +import pytest +from typing import Generator, List, Tuple, Callable +import shutil + +hypothesis.settings.register_profile( + "dev", deadline=10000, suppress_health_check=[ + hypothesis.HealthCheck.data_too_large, + hypothesis.HealthCheck.large_base_example + ] +) +hypothesis.settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev")) + + +def _run_server(): + """Run a Chroma server locally""" + settings = Settings( + chroma_api_impl="local", + chroma_db_impl="duckdb", + persist_directory=tempfile.gettempdir() + "/test_server", + ) + server = chromadb.server.fastapi.FastAPI(settings) + uvicorn.run(server.app(), host="0.0.0.0", port=6666, log_level="error") + + +def _await_server(api, attempts=0): + try: + api.heartbeat() + except ConnectionError as e: + if attempts > 10: + raise e + else: + time.sleep(2) + _await_server(api, attempts + 1) + + +def fastapi() -> Generator[API, None, None]: + """Fixture generator that launches a server in a separate process, and yields a + fastapi client connect to it""" + proc = Process(target=_run_server, args=(), daemon=True) + proc.start() + api = chromadb.Client( + Settings( + chroma_api_impl="rest", chroma_server_host="localhost", chroma_server_http_port="6666" + ) + ) + _await_server(api) + yield api + proc.kill() + + +def duckdb() -> Generator[API, None, None]: + """Fixture generator for duckdb""" + yield Client( + Settings( + chroma_api_impl="local", + chroma_db_impl="duckdb", + persist_directory=tempfile.gettempdir(), + ) + ) + + +def duckdb_parquet() -> Generator[API, None, None]: + """Fixture generator for duckdb+parquet""" + + save_path = tempfile.gettempdir() + "/tests" + yield Client( + Settings( + chroma_api_impl="local", + chroma_db_impl="duckdb+parquet", + 
persist_directory=save_path, + ) + ) + if os.path.exists(save_path): + shutil.rmtree(save_path) + + +def integration_api() -> Generator[API, None, None]: + """Fixture generator for returning a client configured via environmenet + variables, intended for externally configured integration tests + """ + yield chromadb.Client() + + +def fixtures() -> List[Callable[[], Generator[API, None, None]]]: + api_fixtures = [duckdb, duckdb_parquet, fastapi] + if "CHROMA_INTEGRATION_TEST" in os.environ: + api_fixtures.append(integration_api) + if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ: + api_fixtures = [integration_api] + return api_fixtures + +@pytest.fixture(scope="module", params=fixtures()) +def api(request) -> Generator[API, None, None]: + yield next(request.param()) diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py new file mode 100644 index 00000000000..0ba2f2e74a0 --- /dev/null +++ b/chromadb/test/property/invariants.py @@ -0,0 +1,251 @@ +import math +from chromadb.test.property.strategies import RecordSet +from typing import Callable, Optional, Union, List, TypeVar +from typing_extensions import Literal +import numpy as np +from chromadb.api import types +from chromadb.api.models.Collection import Collection +from hypothesis import note +from hypothesis.errors import InvalidArgument + +T = TypeVar("T") + + +def maybe_wrap(value: Union[T, List[T]]) -> Union[None, List[T]]: + """Wrap a value in a list if it is not a list""" + if value is None: + return None + elif isinstance(value, List): + return value + else: + return [value] + + +def wrap_all(embeddings: RecordSet) -> RecordSet: + """Ensure that an embedding set has lists for all its values""" + + if embeddings["embeddings"] is None: + embedding_list = None + elif isinstance(embeddings["embeddings"], list): + if len(embeddings["embeddings"]) > 0: + if isinstance(embeddings["embeddings"][0], list): + embedding_list = embeddings["embeddings"] + else: + embedding_list = [embeddings["embeddings"]] + else: + embedding_list = [] + else: + raise InvalidArgument("embeddings must be a list, list of lists, or None") + + return { + "ids": maybe_wrap(embeddings["ids"]), # type: ignore + "documents": maybe_wrap(embeddings["documents"]), # type: ignore + "metadatas": maybe_wrap(embeddings["metadatas"]), # type: ignore + "embeddings": embedding_list, + } + + +def count(collection: Collection, embeddings: RecordSet): + """The given collection count is equal to the number of embeddings""" + count = collection.count() + embeddings = wrap_all(embeddings) + assert count == len(embeddings["ids"]) + + +def _field_matches( + collection: Collection, + embeddings: RecordSet, + field_name: Union[Literal["documents"], Literal["metadatas"]], +): + """ + The actual embedding field is equal to the expected field + field_name: one of [documents, metadatas] + """ + result = collection.get(ids=embeddings["ids"], include=[field_name]) + # The test_out_of_order_ids test fails because of this in test_add.py + # Here we sort by the ids to match the input order + embedding_id_to_index = {id: i for i, id in enumerate(embeddings["ids"])} + actual_field = result[field_name] + # This assert should never happen, if we include metadatas/documents it will be + # [None, None..] if there is no metadata. It will not be just None. 
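As a quick illustration of the normalization helpers above (values made up): a record set written with bare scalars is wrapped into the list form the invariants compare against. This assumes the module layout added in this PR.

```python
# Assumes the invariants module introduced by this PR is importable.
from chromadb.test.property.invariants import wrap_all

single = {
    "ids": "id1",
    "embeddings": [0.0, 1.0],     # a single embedding, not a list of embeddings
    "metadatas": {"k": "v"},
    "documents": "a document",
}

assert wrap_all(single) == {
    "ids": ["id1"],
    "documents": ["a document"],
    "metadatas": [{"k": "v"}],
    "embeddings": [[0.0, 1.0]],
}
```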
+ assert actual_field is not None + actual_field = sorted( + enumerate(actual_field), + key=lambda index_and_field_value: embedding_id_to_index[ + result["ids"][index_and_field_value[0]] + ], + ) + actual_field = [field_value for _, field_value in actual_field] + + expected_field = embeddings[field_name] + if expected_field is None: + # Since a RecordSet is the user input, we need to convert the documents to + # a List since that's what the API returns -> None per entry + expected_field = [None] * len(embeddings["ids"]) + assert actual_field == expected_field + + +def ids_match(collection: Collection, embeddings: RecordSet): + """The actual embedding ids are equal to the expected ids""" + embeddings = wrap_all(embeddings) + actual_ids = collection.get(ids=embeddings["ids"], include=[])["ids"] + # The test_out_of_order_ids test fails because of this in test_add.py + # Here we sort the ids to match the input order + embedding_id_to_index = {id: i for i, id in enumerate(embeddings["ids"])} + actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id]) + assert actual_ids == embeddings["ids"] + + +def metadatas_match(collection: Collection, embeddings: RecordSet): + """The actual embedding metadata is equal to the expected metadata""" + embeddings = wrap_all(embeddings) + _field_matches(collection, embeddings, "metadatas") + + +def documents_match(collection: Collection, embeddings: RecordSet): + """The actual embedding documents are equal to the expected documents""" + embeddings = wrap_all(embeddings) + _field_matches(collection, embeddings, "documents") + + +def no_duplicates(collection: Collection): + ids = collection.get()["ids"] + assert len(ids) == len(set(ids)) + + +# These match the hnswlib spec +# This epsilon is used to prevent division by zero and the value is the same as in hnswlib: +# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238 +NORM_EPS = 1e-30 +distance_functions = { + "l2": lambda x, y: np.linalg.norm(x - y) ** 2, + "cosine": lambda x, y: 1 + - np.dot(x, y) / ((np.linalg.norm(x) + NORM_EPS) * (np.linalg.norm(y) + NORM_EPS)), + "ip": lambda x, y: 1 - np.dot(x, y), +} + + +def _exact_distances( + query: types.Embeddings, + targets: types.Embeddings, + distance_fn: Callable = lambda x, y: np.linalg.norm(x - y) ** 2, +): + """Return the ordered indices and distances from each query to each target""" + np_query = np.array(query) + np_targets = np.array(targets) + + # Compute the distance between each query and each target, using the distance function + distances = np.apply_along_axis( + lambda query: np.apply_along_axis(distance_fn, 1, np_targets, query), + 1, + np_query, + ) + # Sort the distances and return the indices + return np.argsort(distances), distances + + +def ann_accuracy( + collection: Collection, + record_set: RecordSet, + n_results: int = 1, + min_recall: float = 0.99, + embedding_function: Optional[types.EmbeddingFunction] = None, +): + """Validate that the API performs nearest_neighbor searches correctly""" + record_set = wrap_all(record_set) + + if len(record_set["ids"]) == 0: + return # nothing to test here + + embeddings = record_set["embeddings"] + have_embeddings = embeddings is not None and len(embeddings) > 0 + if not have_embeddings: + assert embedding_function is not None + assert record_set["documents"] is not None + # Compute the embeddings for the documents + embeddings = embedding_function(record_set["documents"]) + + # l2 is the default distance function +
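A tiny worked check of the three reference distance functions defined above, on made-up vectors; this mirrors the lambdas rather than adding anything new:

```python
import numpy as np

NORM_EPS = 1e-30  # same epsilon as above, guarding against division by zero

x = np.array([1.0, 0.0])
y = np.array([0.0, 1.0])

l2 = np.linalg.norm(x - y) ** 2                # squared L2, as hnswlib defines "l2"
cosine = 1 - np.dot(x, y) / (
    (np.linalg.norm(x) + NORM_EPS) * (np.linalg.norm(y) + NORM_EPS)
)
ip = 1 - np.dot(x, y)                          # inner-product distance

# Orthogonal unit vectors: squared L2 is 2, cosine and ip distances are both 1.
assert np.isclose(l2, 2.0) and np.isclose(cosine, 1.0) and np.isclose(ip, 1.0)
```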
distance_function = distance_functions["l2"] + accuracy_threshold = 1e-6 + if "hnsw:space" in collection.metadata: + space = collection.metadata["hnsw:space"] + # TODO: ip and cosine are numerically unstable in HNSW. + # The higher the dimensionality, the more noise is introduced, since each float element + # of the vector has noise added, which is then subsequently included in all normalization calculations. + # This means that higher dimensions will have more noise, and thus more error. + dim = len(embeddings[0]) + accuracy_threshold = accuracy_threshold * math.pow(10, int(math.log10(dim))) + + if space == "cosine": + distance_function = distance_functions["cosine"] + + if space == "ip": + distance_function = distance_functions["ip"] + + # Perform exact distance computation + indices, distances = _exact_distances( + embeddings, embeddings, distance_fn=distance_function + ) + + query_results = collection.query( + query_embeddings=record_set["embeddings"], + query_texts=record_set["documents"] if not have_embeddings else None, + n_results=n_results, + include=["embeddings", "documents", "metadatas", "distances"], + ) + + # Dict of ids to indices + id_to_index = {id: i for i, id in enumerate(record_set["ids"])} + missing = 0 + for i, (indices_i, distances_i) in enumerate(zip(indices, distances)): + expected_ids = np.array(record_set["ids"])[indices_i[:n_results]] + missing += len(set(expected_ids) - set(query_results["ids"][i])) + + # For each id in the query results, find the index in the embeddings set + # and assert that the embeddings are the same + for j, id in enumerate(query_results["ids"][i]): + # This may be because the true nth nearest neighbor didn't get returned by the ANN query + unexpected_id = id not in expected_ids + index = id_to_index[id] + + correct_distance = np.allclose( + distances_i[index], + query_results["distances"][i][j], + atol=accuracy_threshold, + ) + if unexpected_id: + # If the ID is unexpcted, but the distance is correct, then we + # have a duplicate in the data. In this case, we should not reduce recall. 
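When the collection metadata specifies an `hnsw:space`, the comparison tolerance above is loosened with dimensionality. A small worked example of that scaling (dimensions chosen arbitrarily):

```python
import math

BASE_THRESHOLD = 1e-6  # the default tolerance used above

def scaled_threshold(dim: int) -> float:
    # Mirrors the computation above: scale the base tolerance by the largest
    # power of ten that does not exceed the dimensionality.
    return BASE_THRESHOLD * math.pow(10, int(math.log10(dim)))

assert math.isclose(scaled_threshold(2), 1e-6)      # 10**0
assert math.isclose(scaled_threshold(128), 1e-4)    # 10**2
assert math.isclose(scaled_threshold(2048), 1e-3)   # 10**3
```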
+ if correct_distance: + missing -= 1 + else: + continue + else: + assert correct_distance + + assert np.allclose(embeddings[index], query_results["embeddings"][i][j]) + if record_set["documents"] is not None: + assert ( + record_set["documents"][index] == query_results["documents"][i][j] + ) + if record_set["metadatas"] is not None: + assert ( + record_set["metadatas"][index] == query_results["metadatas"][i][j] + ) + + size = len(record_set["ids"]) + recall = (size - missing) / size + + try: + note( + f"recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}" + ) + except InvalidArgument: + pass # it's ok if we're running outside hypothesis + + assert recall >= min_recall + + # Ensure that the query results are sorted by distance + for distance_result in query_results["distances"]: + assert np.allclose(np.sort(distance_result), distance_result) diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py new file mode 100644 index 00000000000..f862731cfdf --- /dev/null +++ b/chromadb/test/property/strategies.py @@ -0,0 +1,461 @@ +import hashlib +import hypothesis +import hypothesis.strategies as st +from typing import Optional, Callable, List, Dict, Union +from typing_extensions import TypedDict +import hypothesis.extra.numpy as npst +import numpy as np +import chromadb.api.types as types +import re +from hypothesis.strategies._internal.strategies import SearchStrategy +from hypothesis.errors import InvalidDefinition + +from dataclasses import dataclass + +# Set the random seed for reproducibility +np.random.seed(0) # unnecessary, hypothesis does this for us + +# See Hypothesis documentation for creating strategies at +# https://hypothesis.readthedocs.io/en/latest/data.html + +# NOTE: Because these strategies are used in state machines, we need to +# work around an issue with state machines, in which strategies that frequently +# are marked as invalid (i.e. through the use of `assume` or `.filter`) can cause the +# state machine tests to fail with an hypothesis.errors.Unsatisfiable. + +# Ultimately this is because the entire state machine is run as a single Hypothesis +# example, which ends up drawing from the same strategies an enormous number of times. +# Whenever a strategy marks itself as invalid, Hypothesis tries to start the entire +# state machine run over. See https://github.com/HypothesisWorks/hypothesis/issues/3618 + +# Because strategy generation is all interrelated, seemingly small changes (especially +# ones called early in a test) can have an outside effect. Generating lists with +# unique=True, or dictionaries with a min size seems especially bad. + +# Please make changes to these strategies incrementally, testing to make sure they don't +# start generating unsatisfiable examples. + +test_hnsw_config = { + "hnsw:construction_ef": 128, + "hnsw:search_ef": 128, + "hnsw:M": 128, +} + + +class RecordSet(TypedDict): + """ + A generated set of embeddings, ids, metadatas, and documents that + represent what a user would pass to the API. + """ + + ids: Union[types.ID, List[types.ID]] + embeddings: Optional[Union[types.Embeddings, types.Embedding]] + metadatas: Optional[Union[List[types.Metadata], types.Metadata]] + documents: Optional[Union[List[types.Document], types.Document]] + + +# TODO: support arbitrary text everywhere so we don't SQL-inject ourselves. 
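For orientation, a concrete (entirely made-up) example of the shape a generated `RecordSet` takes, matching the TypedDict above; the single-record form may also appear unwrapped:

```python
# A multi-record set, all fields as parallel lists.
example_record_set = {
    "ids": ["doc-1", "doc-2"],
    "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    "metadatas": [{"source": "a", "rank": 1}, {"source": "b", "rank": 2}],
    "documents": ["first document", "second document"],
}

# A single record passed with scalar values instead of lists, which the
# invariants normalize via wrap_all/maybe_wrap.
single_record_set = {
    "ids": "doc-3",
    "embeddings": [0.7, 0.8, 0.9],
    "metadatas": {"source": "c"},
    "documents": "third document",
}
```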
+# TODO: support empty strings everywhere +sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_" +safe_text = st.text(alphabet=sql_alphabet, min_size=1) + +# Workaround for FastAPI json encoding peculiarities +# https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104 +safe_text = safe_text.filter(lambda s: not s.startswith("_sa")) + +safe_integers = st.integers( + min_value=-(2**31), max_value=2**31 - 1 +) # TODO: handle longs +safe_floats = st.floats( + allow_infinity=False, allow_nan=False, allow_subnormal=False +) # TODO: handle infinity and NAN + +safe_values = [safe_text, safe_integers, safe_floats] + + +def one_or_both(strategy_a, strategy_b): + return st.one_of( + st.tuples(strategy_a, strategy_b), + st.tuples(strategy_a, st.none()), + st.tuples(st.none(), strategy_b), + ) + + +# Temporarily generate only these to avoid SQL formatting issues. +legal_id_characters = ( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./+" +) + +float_types = [np.float16, np.float32, np.float64] +int_types = [np.int16, np.int32, np.int64] # TODO: handle int types + + +@st.composite +def collection_name(draw) -> str: + _collection_name_re = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{1,60}[a-zA-Z0-9]$") + _ipv4_address_re = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}$") + _two_periods_re = re.compile(r"\.\.") + + name = draw(st.from_regex(_collection_name_re)) + hypothesis.assume(not _ipv4_address_re.match(name)) + hypothesis.assume(not _two_periods_re.search(name)) + + return name + + +collection_metadata = st.one_of( + st.none(), st.dictionaries(safe_text, st.one_of(*safe_values)) +) + + +# TODO: Use a hypothesis strategy while maintaining embedding uniqueness +# Or handle duplicate embeddings within a known epsilon +def create_embeddings(dim: int, count: int, dtype: np.dtype) -> types.Embeddings: + return ( + np.random.uniform( + low=-1.0, + high=1.0, + size=(count, dim), + ) + .astype(dtype) + .tolist() + ) + + +class hashing_embedding_function(types.EmbeddingFunction): + def __init__(self, dim: int, dtype: np.dtype) -> None: + self.dim = dim + self.dtype = dtype + + def __call__(self, texts: types.Documents) -> types.Embeddings: + # Hash the texts and convert to hex strings + hashed_texts = [ + list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in texts + ] + # Pad with repetition, or truncate the hex strings to the desired dimension + padded_texts = [ + text * (self.dim // len(text)) + text[: self.dim % len(text)] + for text in hashed_texts + ] + + # Convert the hex strings to dtype + return np.array( + [[int(char, 16) / 15.0 for char in text] for text in padded_texts], + dtype=self.dtype, + ).tolist() + + +def embedding_function_strategy( + dim: int, dtype: np.dtype +) -> st.SearchStrategy[types.EmbeddingFunction]: + return st.just(hashing_embedding_function(dim, dtype)) + + +@dataclass +class Collection: + name: str + metadata: Optional[types.Metadata] + dimension: int + dtype: np.dtype + known_metadata_keys: Dict[str, st.SearchStrategy] + known_document_keywords: List[str] + has_documents: bool = False + has_embeddings: bool = False + embedding_function: Optional[types.EmbeddingFunction] = None + + +@st.composite +def collections( + draw, + add_filterable_data=False, + with_hnsw_params=False, + has_embeddings: Optional[bool] = None, + has_documents: Optional[bool] = None, +) -> Collection: + """Strategy to generate a Collection object. 
If add_filterable_data is True, then known_metadata_keys and known_document_keywords will be populated with consistent data.""" + + assert not ((has_embeddings is False) and (has_documents is False)) + + name = draw(collection_name()) + metadata = draw(collection_metadata) + dimension = draw(st.integers(min_value=2, max_value=2048)) + dtype = draw(st.sampled_from(float_types)) + + if with_hnsw_params: + if metadata is None: + metadata = {} + metadata.update(test_hnsw_config) + # Sometimes, select a space at random + if draw(st.booleans()): + # TODO: pull the distance functions from a source of truth that lives not + # in tests once https://github.com/chroma-core/issues/issues/61 lands + metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"])) + + known_metadata_keys = {} + if add_filterable_data: + while len(known_metadata_keys) < 5: + key = draw(safe_text) + known_metadata_keys[key] = draw(st.sampled_from(safe_values)) + + if has_documents is None: + has_documents = draw(st.booleans()) + if has_documents and add_filterable_data: + known_document_keywords = draw(st.lists(safe_text, min_size=5, max_size=5)) + else: + known_document_keywords = [] + + if not has_documents: + has_embeddings = True + else: + if has_embeddings is None: + has_embeddings = draw(st.booleans()) + + embedding_function = draw(embedding_function_strategy(dimension, dtype)) + + return Collection( + name=name, + metadata=metadata, + dimension=dimension, + dtype=dtype, + known_metadata_keys=known_metadata_keys, + has_documents=has_documents, + known_document_keywords=known_document_keywords, + has_embeddings=has_embeddings, + embedding_function=embedding_function, + ) + + +@st.composite +def metadata(draw, collection: Collection): + """Strategy for generating metadata that could be a part of the given collection""" + # First draw a random dictionary. + md = draw(st.dictionaries(safe_text, st.one_of(*safe_values))) + # Then, remove keys that overlap with the known keys for the coll + # to avoid type errors when comparing. 
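Tying this back to the `HnswParams` parsing earlier in the diff: when `with_hnsw_params=True`, the generated metadata carries `hnsw:*` keys that the index layer reads. A hand-written equivalent (collection name and values illustrative, matching `test_hnsw_config` above) would be:

```python
import chromadb

client = chromadb.Client()  # in-memory client for illustration

collection = client.create_collection(
    "hnsw-tuned",  # illustrative name
    metadata={
        "hnsw:space": "cosine",       # one of l2 (default), cosine, ip
        "hnsw:construction_ef": 128,
        "hnsw:search_ef": 128,
        "hnsw:M": 128,
    },
)
```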
+ if collection.known_metadata_keys: + for key in collection.known_metadata_keys.keys(): + if key in md: + del md[key] + # Finally, add in some of the known keys for the collection + md.update( + draw(st.fixed_dictionaries({}, optional=collection.known_metadata_keys)) + ) + return md + + +@st.composite +def document(draw, collection: Collection): + """Strategy for generating documents that could be a part of the given collection""" + + if collection.known_document_keywords: + known_words_st = st.sampled_from(collection.known_document_keywords) + else: + known_words_st = st.text(min_size=1) + + random_words_st = st.text(min_size=1) + words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1)) + return " ".join(words) + + +@st.composite +def record(draw, collection: Collection, id_strategy=safe_text): + md = draw(metadata(collection)) + + if collection.has_embeddings: + embedding = create_embeddings(collection.dimension, 1, collection.dtype)[0] + else: + embedding = None + if collection.has_documents: + doc = draw(document(collection)) + else: + doc = None + + return { + "id": draw(id_strategy), + "embedding": embedding, + "metadata": md, + "document": doc, + } + + +@st.composite +def recordsets( + draw, + collection_strategy=collections(), + id_strategy=safe_text, + min_size=1, + max_size=50, +) -> RecordSet: + collection = draw(collection_strategy) + + records = draw( + st.lists(record(collection, id_strategy), min_size=min_size, max_size=max_size) + ) + + records = {r["id"]: r for r in records}.values() # Remove duplicates + + ids = [r["id"] for r in records] + embeddings = ( + [r["embedding"] for r in records] if collection.has_embeddings else None + ) + metadatas = [r["metadata"] for r in records] + documents = [r["document"] for r in records] if collection.has_documents else None + + # in the case where we have a single record, sometimes exercise + # the code that handles individual values rather than lists + if len(records) == 1: + if draw(st.booleans()): + ids = ids[0] + if collection.has_embeddings and draw(st.booleans()): + embeddings = embeddings[0] + if draw(st.booleans()): + metadatas = metadatas[0] + if collection.has_documents and draw(st.booleans()): + documents = documents[0] + + return { + "ids": ids, + "embeddings": embeddings, + "metadatas": metadatas, + "documents": documents, + } + + +# This class is mostly cloned from hypothesis.stateful.RuleStrategy, +# but always runs all the rules, instead of using a FeatureStrategy to +# enable/disable rules. Disabled rules cause the entire test to be marked invalid and, +# combined with the complexity of our other strategies, leads to an +# unacceptably increased incidence of hypothesis.errors.Unsatisfiable. +class DeterministicRuleStrategy(SearchStrategy): + def __init__(self, machine): + super().__init__() + self.machine = machine + self.rules = list(machine.rules()) + + # The order is a bit arbitrary. Primarily we're trying to group rules + # that write to the same location together, and to put rules with no + # target first as they have less effect on the structure. We order from + # fewer to more arguments on grounds that it will plausibly need less + # data. This probably won't work especially well and we could be + # smarter about it, but it's better than just doing it in definition + # order.
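A minimal sketch of how this strategy gets installed on a state machine, mirroring its use in `test_embeddings.py` later in this diff; the toy machine and its rule are illustrative, and overriding the private `_rules_strategy` attribute is assumed to work as it does for the Hypothesis version this PR targets:

```python
import hypothesis.strategies as st
from hypothesis.stateful import RuleBasedStateMachine, rule, run_state_machine_as_test

class ToyMachine(RuleBasedStateMachine):
    def __init__(self):
        super().__init__()
        # Swap out Hypothesis' default RuleStrategy so every rule stays enabled
        # on every step; DeterministicRuleStrategy is the class defined here.
        self._rules_strategy = DeterministicRuleStrategy(self)  # type: ignore
        self.total = 0

    @rule(x=st.integers(min_value=0, max_value=10))
    def add(self, x):
        self.total += x
        assert self.total >= 0

# run_state_machine_as_test(ToyMachine) would drive the toy machine as a test.
```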
+ self.rules.sort( + key=lambda rule: ( + sorted(rule.targets), + len(rule.arguments), + rule.function.__name__, + ) + ) + + def __repr__(self): + return "{}(machine={}({{...}}))".format( + self.__class__.__name__, + self.machine.__class__.__name__, + ) + + def do_draw(self, data): + if not any(self.is_valid(rule) for rule in self.rules): + msg = f"No progress can be made from state {self.machine!r}" + raise InvalidDefinition(msg) from None + + rule = data.draw(st.sampled_from([r for r in self.rules if self.is_valid(r)])) + argdata = data.draw(rule.arguments_strategy) + return (rule, argdata) + + def is_valid(self, rule): + if not all(precond(self.machine) for precond in rule.preconditions): + return False + + for b in rule.bundles: + bundle = self.machine.bundle(b.name) + if not bundle: + return False + return True + + +@st.composite +def where_clause(draw, collection): + """Generate a filter that could be used in a query against the given collection""" + + known_keys = sorted(collection.known_metadata_keys.keys()) + + key = draw(st.sampled_from(known_keys)) + value = draw(collection.known_metadata_keys[key]) + + legal_ops = [None, "$eq", "$ne"] + if not isinstance(value, str): + legal_ops = ["$gt", "$lt", "$lte", "$gte"] + legal_ops + + op = draw(st.sampled_from(legal_ops)) + + if op is None: + return {key: value} + else: + return {key: {op: value}} + + +@st.composite +def where_doc_clause(draw, collection): + """Generate a where_document filter that could be used against the given collection""" + if collection.known_document_keywords: + word = draw(st.sampled_from(collection.known_document_keywords)) + else: + word = draw(safe_text) + return {"$contains": word} + + +@st.composite +def binary_operator_clause(draw, base_st): + op = draw(st.sampled_from(["$and", "$or"])) + return {op: [draw(base_st), draw(base_st)]} + + +@st.composite +def recursive_where_clause(draw, collection): + base_st = where_clause(collection) + return draw(st.recursive(base_st, binary_operator_clause)) + + +@st.composite +def recursive_where_doc_clause(draw, collection): + base_st = where_doc_clause(collection) + return draw(st.recursive(base_st, binary_operator_clause)) + + +class Filter(TypedDict): + where: Optional[Dict[str, Union[str, int, float]]] + ids: Optional[Union[str, List[str]]] + where_document: Optional[types.WhereDocument] + + +@st.composite +def filters( + draw, + collection_st: st.SearchStrategy[Collection], + recordset_st: st.SearchStrategy[RecordSet], + include_all_ids=False, +) -> Filter: + collection = draw(collection_st) + recordset = draw(recordset_st) + + where_clause = draw(st.one_of(st.none(), recursive_where_clause(collection))) + where_document_clause = draw( + st.one_of(st.none(), recursive_where_doc_clause(collection)) + ) + + ids = recordset["ids"] + # Record sets can be a value instead of a list of values if there is only one record + if isinstance(ids, str): + ids = [ids] + + if not include_all_ids: + ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids)))) + if ids is not None: + # Remove duplicates since hypothesis samples with replacement + ids = list(set(ids)) + + # Test both the single value list and the unwrapped single value case + if ids is not None and len(ids) == 1 and draw(st.booleans()): + ids = ids[0] + + return {"where": where_clause, "where_document": where_document_clause, "ids": ids} diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py new file mode 100644 index 00000000000..be31e1a63a7 --- /dev/null +++ 
b/chromadb/test/property/test_add.py @@ -0,0 +1,74 @@ +import pytest +import hypothesis.strategies as st +from hypothesis import given, settings +from chromadb.api import API +import chromadb.test.property.strategies as strategies +import chromadb.test.property.invariants as invariants + +collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") + + +@given(collection=collection_st, embeddings=strategies.recordsets(collection_st)) +@settings(deadline=None) +def test_add( + api: API, collection: strategies.Collection, embeddings: strategies.RecordSet +): + api.reset() + + # TODO: Generative embedding functions + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + coll.add(**embeddings) + + embeddings = invariants.wrap_all(embeddings) + invariants.count(coll, embeddings) + n_results = max(1, (len(embeddings["ids"]) // 10)) + invariants.ann_accuracy( + coll, + embeddings, + n_results=n_results, + embedding_function=collection.embedding_function, + ) + + +# TODO: This test fails right now because the ids are not sorted by the input order +@pytest.mark.xfail( + reason="This is expected to fail right now. We should change the API to sort the \ + ids by input order." +) +def test_out_of_order_ids(api: API): + api.reset() + ooo_ids = [ + "40", + "05", + "8", + "6", + "10", + "01", + "00", + "3", + "04", + "20", + "02", + "9", + "30", + "11", + "13", + "2", + "0", + "7", + "06", + "5", + "50", + "12", + "03", + "4", + "1", + ] + coll = api.create_collection("test", embedding_function=lambda x: [1, 2, 3]) + coll.add(ids=ooo_ids, embeddings=[[1, 2, 3] for _ in range(len(ooo_ids))]) + get_ids = coll.get(ids=ooo_ids)["ids"] + assert get_ids == ooo_ids diff --git a/chromadb/test/property/test_collections.py b/chromadb/test/property/test_collections.py new file mode 100644 index 00000000000..de1cde067ff --- /dev/null +++ b/chromadb/test/property/test_collections.py @@ -0,0 +1,122 @@ +import pytest +import logging +import hypothesis.strategies as st +from typing import List +import chromadb +from chromadb.api import API +from chromadb.api.models.Collection import Collection +import chromadb.test.property.strategies as strategies +from hypothesis.stateful import ( + Bundle, + RuleBasedStateMachine, + rule, + initialize, + multiple, + consumes, + run_state_machine_as_test, +) + + +class CollectionStateMachine(RuleBasedStateMachine): + def __init__(self, api): + super().__init__() + self.existing = set() + self.model = {} + self.api = api + + collections = Bundle("collections") + + @initialize() + def initialize(self): + self.api.reset() + self.existing = set() + + @rule(target=collections, coll=strategies.collections()) + def create_coll(self, coll): + if coll.name in self.existing: + with pytest.raises(Exception): + c = self.api.create_collection(name=coll.name, + metadata=coll.metadata, + embedding_function=coll.embedding_function) + return multiple() + + c = self.api.create_collection(name=coll.name, + metadata=coll.metadata, + embedding_function=coll.embedding_function) + self.existing.add(coll.name) + + assert c.name == coll.name + assert c.metadata == coll.metadata + return coll + + @rule(coll=collections) + def get_coll(self, coll): + if coll.name in self.existing: + c = self.api.get_collection(name=coll.name) + assert c.name == coll.name + assert c.metadata == coll.metadata + else: + with pytest.raises(Exception): + self.api.get_collection(name=coll.name) + + 
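A plain-pytest sketch of the behavior the `create_coll` and `get_coll` rules above encode (collection names are arbitrary): creating the same name twice should raise, and getting a never-created name should raise.

```python
import pytest
import chromadb

def test_collection_name_uniqueness_sketch():
    client = chromadb.Client()
    client.reset()

    client.create_collection("dupe-check")
    with pytest.raises(Exception):
        client.create_collection("dupe-check")  # same name again -> error

    with pytest.raises(Exception):
        client.get_collection("never-created")  # unknown name -> error
```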
@rule(coll=consumes(collections)) + def delete_coll(self, coll): + if coll.name in self.existing: + self.api.delete_collection(name=coll.name) + self.existing.remove(coll.name) + else: + with pytest.raises(Exception): + self.api.delete_collection(name=coll.name) + + with pytest.raises(Exception): + self.api.get_collection(name=coll.name) + + @rule() + def list_collections(self): + colls = self.api.list_collections() + assert len(colls) == len(self.existing) + for c in colls: + assert c.name in self.existing + + @rule( + target=collections, + coll=st.one_of(consumes(collections), strategies.collections()), + ) + def get_or_create_coll(self, coll): + c = self.api.get_or_create_collection(name=coll.name, + metadata=coll.metadata, + embedding_function=coll.embedding_function) + assert c.name == coll.name + if coll.metadata is not None: + assert c.metadata == coll.metadata + self.existing.add(coll.name) + return coll + + @rule( + target=collections, + coll=consumes(collections), + new_metadata=strategies.collection_metadata, + new_name=st.one_of(st.none(), strategies.collection_name()), + ) + def modify_coll(self, coll, new_metadata, new_name): + c = self.api.get_collection(name=coll.name) + + if new_metadata is not None: + coll.metadata = new_metadata + + if new_name is not None: + self.existing.remove(coll.name) + self.existing.add(new_name) + coll.name = new_name + + c.modify(metadata=new_metadata, name=new_name) + c = self.api.get_collection(name=coll.name) + + assert c.name == coll.name + assert c.metadata == coll.metadata + return coll + + +def test_collections(caplog, api): + caplog.set_level(logging.ERROR) + run_state_machine_as_test(lambda: CollectionStateMachine(api)) \ No newline at end of file diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py new file mode 100644 index 00000000000..af6d0487aa8 --- /dev/null +++ b/chromadb/test/property/test_cross_version_persist.py @@ -0,0 +1,252 @@ +import sys +import os +import shutil +import subprocess +import tempfile +from typing import Generator, Tuple +from hypothesis import given, settings +import hypothesis.strategies as st +import pytest +import json +from urllib import request +from chromadb.api import API +import chromadb.test.property.strategies as strategies +import chromadb.test.property.invariants as invariants +from importlib.util import spec_from_file_location, module_from_spec +from packaging import version as packaging_version +import re +import multiprocessing +from chromadb import Client +from chromadb.config import Settings +import sys + +MINIMUM_VERSION = "0.3.20" +COLLECTION_NAME_LOWERCASE_VERSION = "0.3.21" +version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$") + + +def _patch_uppercase_coll_name( + collection: strategies.Collection, embeddings: strategies.RecordSet +): + """Old versions didn't handle uppercase characters in collection names""" + collection.name = collection.name.lower() + + +def _patch_empty_dict_metadata( + collection: strategies.Collection, embeddings: strategies.RecordSet +): + """Old versions do the wrong thing when metadata is a single empty dict""" + if embeddings["metadatas"] == {}: + embeddings["metadatas"] = None + + +version_patches = [ + ("0.3.21", _patch_uppercase_coll_name), + ("0.3.21", _patch_empty_dict_metadata), +] + + +def patch_for_version( + version, collection: strategies.Collection, embeddings: strategies.RecordSet +): + """Override aspects of the collection and embeddings, before testing, to account for + breaking 
changes in old versions.""" + + for patch_version, patch in version_patches: + if packaging_version.Version(version) <= packaging_version.Version( + patch_version + ): + patch(collection, embeddings) + + +def versions(): + """Returns the pinned minimum version and the latest version of chromadb.""" + url = "https://pypi.org/pypi/chromadb/json" + data = json.load(request.urlopen(request.Request(url))) + versions = list(data["releases"].keys()) + # Older versions on pypi contain "devXYZ" suffixes + versions = [v for v in versions if version_re.match(v)] + versions.sort(key=packaging_version.Version) + return [MINIMUM_VERSION, versions[-1]] + + +def configurations(versions): + return [ + ( + version, + Settings( + chroma_api_impl="local", + chroma_db_impl="duckdb+parquet", + persist_directory=tempfile.gettempdir() + "/tests/" + version + "/", + ), + ) + for version in versions + ] + + +test_old_versions = versions() +base_install_dir = tempfile.gettempdir() + "/persistence_test_chromadb_versions" + + +# This fixture is not shared with the rest of the tests because it is unique in how it +# installs the versions of chromadb +@pytest.fixture(scope="module", params=configurations(test_old_versions)) +def version_settings(request) -> Generator[Tuple[str, Settings], None, None]: + configuration = request.param + version = configuration[0] + install_version(version) + yield configuration + # Cleanup the installed version + path = get_path_to_version_install(version) + shutil.rmtree(path) + # Cleanup the persisted data + data_path = configuration[1].persist_directory + if os.path.exists(data_path): + shutil.rmtree(data_path) + + +def get_path_to_version_install(version): + return base_install_dir + "/" + version + + +def get_path_to_version_library(version): + return get_path_to_version_install(version) + "/chromadb/__init__.py" + + +def install_version(version): + # Check if already installed + version_library = get_path_to_version_library(version) + if os.path.exists(version_library): + return + path = get_path_to_version_install(version) + install(f"chromadb=={version}", path) + + +def install(pkg, path): + # -q -q to suppress pip output to ERROR level + # https://pip.pypa.io/en/stable/cli/pip/#quiet + print(f"Installing chromadb version {pkg} to {path}") + return subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "-q", + "-q", + "install", + pkg, + "--target={}".format(path), + ] + ) + + +def switch_to_version(version): + module_name = "chromadb" + # Remove old version from sys.modules, except test modules + old_modules = { + n: m + for n, m in sys.modules.items() + if n == module_name or (n.startswith(module_name + ".")) + } + for n in old_modules: + del sys.modules[n] + + # Load the target version and override the path to the installed version + # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + path = get_path_to_version_library(version) + sys.path.insert(0, get_path_to_version_install(version)) + spec = spec_from_file_location(module_name, path) + assert spec is not None and spec.loader is not None + module = module_from_spec(spec) + spec.loader.exec_module(module) + assert module.__version__ == version + sys.modules[module_name] = module + return module + + +def persist_generated_data_with_old_version( + version, + settings, + collection_strategy: strategies.Collection, + embeddings_strategy: strategies.RecordSet, +): + old_module = switch_to_version(version) + api: API = old_module.Client(settings) + api.reset() + coll = 
api.create_collection( + name=collection_strategy.name, + metadata=collection_strategy.metadata, + embedding_function=lambda x: None, + ) + coll.add(**embeddings_strategy) + # We can't use the invariants module here because it uses the current version + # Just use some basic checks for sanity and manual testing where you break the new + # version + + check_embeddings = invariants.wrap_all(embeddings_strategy) + # Check count + assert coll.count() == len(check_embeddings["embeddings"] or []) + # Check ids + result = coll.get() + actual_ids = result["ids"] + embedding_id_to_index = {id: i for i, id in enumerate(check_embeddings["ids"])} + actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id]) + assert actual_ids == check_embeddings["ids"] + api.persist() + del api + + +# Since we can't pickle the embedding function, we always generate record sets with embeddings +collection_st = st.shared( + strategies.collections(with_hnsw_params=True, has_embeddings=True), key="coll" +) + + +@given( + collection_strategy=collection_st, + embeddings_strategy=strategies.recordsets(collection_st), +) +@pytest.mark.skipif( + sys.version_info.major < 3 + or (sys.version_info.major == 3 and sys.version_info.minor <= 7), + reason="The mininum supported versions of chroma do not work with python <= 3.7", +) +@settings(deadline=None) +def test_cycle_versions( + version_settings: Tuple[str, Settings], + collection_strategy: strategies.Collection, + embeddings_strategy: strategies.RecordSet, +): + # # Test backwards compatibility + # # For the current version, ensure that we can load a collection from + # # the previous versions + version, settings = version_settings + + patch_for_version(version, collection_strategy, embeddings_strategy) + + # Can't pickle a function, and we won't need them + collection_strategy.embedding_function = None + collection_strategy.known_metadata_keys = {} + + # Run the task in a separate process to avoid polluting the current process + # with the old version. 
Using spawn instead of fork to avoid sharing the + # current process memory which would cause the old version to be loaded + ctx = multiprocessing.get_context("spawn") + p = ctx.Process( + target=persist_generated_data_with_old_version, + args=(version, settings, collection_strategy, embeddings_strategy), + ) + p.start() + p.join() + + # Switch to the current version (local working directory) and check the invariants + # are preserved for the collection + api = Client(settings) + coll = api.get_collection( + name=collection_strategy.name, embedding_function=lambda x: None + ) + invariants.count(coll, embeddings_strategy) + invariants.metadatas_match(coll, embeddings_strategy) + invariants.documents_match(coll, embeddings_strategy) + invariants.ids_match(coll, embeddings_strategy) + invariants.ann_accuracy(coll, embeddings_strategy) diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py new file mode 100644 index 00000000000..0a9a682e5d8 --- /dev/null +++ b/chromadb/test/property/test_embeddings.py @@ -0,0 +1,264 @@ +import numpy as np +import pytest +import logging +from hypothesis import given +import hypothesis.strategies as st +from typing import Set, List, Optional, cast +from dataclasses import dataclass +import chromadb.errors as errors +import chromadb +from chromadb.api import API +from chromadb.api.models.Collection import Collection +import chromadb.test.property.strategies as strategies +from hypothesis.stateful import ( + Bundle, + RuleBasedStateMachine, + rule, + initialize, + precondition, + consumes, + run_state_machine_as_test, + multiple, + invariant, +) +from collections import defaultdict +import chromadb.test.property.invariants as invariants +import hypothesis + + +traces = defaultdict(lambda: 0) + + +def trace(key): + global traces + traces[key] += 1 + + +def print_traces(): + global traces + for key, value in traces.items(): + print(f"{key}: {value}") + + +dtype_shared_st = st.shared(st.sampled_from(strategies.float_types), key="dtype") +dimension_shared_st = st.shared( + st.integers(min_value=2, max_value=2048), key="dimension" +) + + +@dataclass +class EmbeddingStateMachineStates: + initialize = "initialize" + add_embeddings = "add_embeddings" + delete_by_ids = "delete_by_ids" + update_embeddings = "update_embeddings" + upsert_embeddings = "upsert_embeddings" + + +collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") + + +class EmbeddingStateMachine(RuleBasedStateMachine): + collection: Collection + embedding_ids: Bundle = Bundle("embedding_ids") + + def __init__(self, api=None): + super().__init__() + # For debug only, to run as class-based test + if not api: + api = chromadb.Client(configurations()[0]) + self.api = api + self._rules_strategy = strategies.DeterministicRuleStrategy(self) + + @initialize(collection=collection_st) + def initialize(self, collection: strategies.Collection): + self.api.reset() + self.collection = self.api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + self.embedding_function = collection.embedding_function + trace("init") + self.on_state_change(EmbeddingStateMachineStates.initialize) + self.embeddings = { + "ids": [], + "embeddings": [], + "metadatas": [], + "documents": [], + } + + @rule(target=embedding_ids, embedding_set=strategies.recordsets(collection_st)) + def add_embeddings(self, embedding_set): + trace("add_embeddings") + 
self.on_state_change(EmbeddingStateMachineStates.add_embeddings) + + normalized_embedding_set = invariants.wrap_all(embedding_set) + + if len(normalized_embedding_set["ids"]) > 0: + trace("add_more_embeddings") + + if set(normalized_embedding_set["ids"]).intersection( + set(self.embeddings["ids"]) + ): + with pytest.raises(errors.IDAlreadyExistsError): + self.collection.add(**embedding_set) + return multiple() + else: + self.collection.add(**embedding_set) + self._upsert_embeddings(embedding_set) + return multiple(*normalized_embedding_set["ids"]) + + @precondition(lambda self: len(self.embeddings["ids"]) > 20) + @rule(ids=st.lists(consumes(embedding_ids), min_size=1, max_size=20)) + def delete_by_ids(self, ids): + trace("remove embeddings") + self.on_state_change(EmbeddingStateMachineStates.delete_by_ids) + indices_to_remove = [self.embeddings["ids"].index(id) for id in ids] + + self.collection.delete(ids=ids) + self._remove_embeddings(set(indices_to_remove)) + + # Removing the precondition causes the tests to frequently fail as "unsatisfiable" + # Using a value < 5 causes retries and lowers the number of valid samples + @precondition(lambda self: len(self.embeddings["ids"]) >= 5) + @rule( + embedding_set=strategies.recordsets( + collection_strategy=collection_st, + id_strategy=embedding_ids, + min_size=1, + max_size=5, + ) + ) + def update_embeddings(self, embedding_set): + trace("update embeddings") + self.on_state_change(EmbeddingStateMachineStates.update_embeddings) + self.collection.update(**embedding_set) + self._upsert_embeddings(embedding_set) + + # Using a value < 3 causes more retries and lowers the number of valid samples + @precondition(lambda self: len(self.embeddings["ids"]) >= 3) + @rule( + embedding_set=strategies.recordsets( + collection_strategy=collection_st, + id_strategy=st.one_of(embedding_ids, strategies.safe_text), + min_size=1, + max_size=5, + ) + ) + def upsert_embeddings(self, embedding_set): + trace("upsert embeddings") + self.on_state_change(EmbeddingStateMachineStates.upsert_embeddings) + self.collection.upsert(**embedding_set) + self._upsert_embeddings(embedding_set) + + @invariant() + def count(self): + invariants.count(self.collection, self.embeddings) # type: ignore + + @invariant() + def no_duplicates(self): + invariants.no_duplicates(self.collection) + + @invariant() + def ann_accuracy(self): + invariants.ann_accuracy( + collection=self.collection, record_set=self.embeddings, min_recall=0.95, embedding_function=self.embedding_function # type: ignore + ) + + def _upsert_embeddings(self, embeddings: strategies.RecordSet): + embeddings = invariants.wrap_all(embeddings) + for idx, id in enumerate(embeddings["ids"]): + if id in self.embeddings["ids"]: + target_idx = self.embeddings["ids"].index(id) + if "embeddings" in embeddings and embeddings["embeddings"] is not None: + self.embeddings["embeddings"][target_idx] = embeddings[ + "embeddings" + ][idx] + else: + self.embeddings["embeddings"][target_idx] = self.embedding_function( + [embeddings["documents"][idx]] + )[0] + if "metadatas" in embeddings and embeddings["metadatas"] is not None: + self.embeddings["metadatas"][target_idx] = embeddings["metadatas"][ + idx + ] + if "documents" in embeddings and embeddings["documents"] is not None: + self.embeddings["documents"][target_idx] = embeddings["documents"][ + idx + ] + else: + # Add path + self.embeddings["ids"].append(id) + if "embeddings" in embeddings and embeddings["embeddings"] is not None: + 
self.embeddings["embeddings"].append(embeddings["embeddings"][idx]) + else: + self.embeddings["embeddings"].append( + self.embedding_function([embeddings["documents"][idx]])[0] + ) + if "metadatas" in embeddings and embeddings["metadatas"] is not None: + self.embeddings["metadatas"].append(embeddings["metadatas"][idx]) + else: + self.embeddings["metadatas"].append(None) + if "documents" in embeddings and embeddings["documents"] is not None: + self.embeddings["documents"].append(embeddings["documents"][idx]) + else: + self.embeddings["documents"].append(None) + + def _remove_embeddings(self, indices_to_remove: Set[int]): + indices_list = list(indices_to_remove) + indices_list.sort(reverse=True) + + for i in indices_list: + del self.embeddings["ids"][i] + del self.embeddings["embeddings"][i] + del self.embeddings["metadatas"][i] + del self.embeddings["documents"][i] + + def on_state_change(self, new_state): + pass + + +def test_embeddings_state(caplog, api): + caplog.set_level(logging.ERROR) + run_state_machine_as_test(lambda: EmbeddingStateMachine(api)) + print_traces() + + +def test_multi_add(api: API): + api.reset() + coll = api.create_collection(name="foo") + coll.add(ids=["a"], embeddings=[[0.0]]) + assert coll.count() == 1 + + with pytest.raises(errors.IDAlreadyExistsError): + coll.add(ids=["a"], embeddings=[[0.0]]) + + assert coll.count() == 1 + + results = coll.get() + assert results["ids"] == ["a"] + + coll.delete(ids=["a"]) + assert coll.count() == 0 + + +def test_dup_add(api: API): + api.reset() + coll = api.create_collection(name="foo") + with pytest.raises(errors.DuplicateIDError): + coll.add(ids=["a", "a"], embeddings=[[0.0], [1.1]]) + with pytest.raises(errors.DuplicateIDError): + coll.upsert(ids=["a", "a"], embeddings=[[0.0], [1.1]]) + + +# TODO: Use SQL escaping correctly internally +@pytest.mark.xfail(reason="We don't properly escape SQL internally, causing problems") +def test_escape_chars_in_ids(api: API): + api.reset() + id = "\x1f" + coll = api.create_collection(name="foo") + coll.add(ids=[id], embeddings=[[0.0]]) + assert coll.count() == 1 + coll.delete(ids=[id]) + assert coll.count() == 0 diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py new file mode 100644 index 00000000000..1008f895fd4 --- /dev/null +++ b/chromadb/test/property/test_filtering.py @@ -0,0 +1,177 @@ +from hypothesis import given, settings, HealthCheck +from chromadb.api import API +from chromadb.errors import NoDatapointsException +from chromadb.test.property import invariants +import chromadb.test.property.strategies as strategies +import hypothesis.strategies as st +import logging +import random + + +def _filter_where_clause(clause, mm): + """Return true if the where clause is true for the given metadata map""" + + key, expr = list(clause.items())[0] + + # Handle the shorthand for equal: {key: val} where val is a simple value + if isinstance(expr, str) or isinstance(expr, int) or isinstance(expr, float): + return _filter_where_clause({key: {"$eq": expr}}, mm) + + if key == "$and": + return all(_filter_where_clause(clause, mm) for clause in expr) + if key == "$or": + return any(_filter_where_clause(clause, mm) for clause in expr) + + op, val = list(expr.items())[0] + + if op == "$eq": + return key in mm and mm[key] == val + elif op == "$ne": + return key in mm and mm[key] != val + elif op == "$gt": + return key in mm and mm[key] > val + elif op == "$gte": + return key in mm and mm[key] >= val + elif op == "$lt": + return key in mm and mm[key] < val + 
elif op == "$lte": + return key in mm and mm[key] <= val + else: + raise ValueError("Unknown operator: {}".format(key)) + + +def _filter_where_doc_clause(clause, doc): + key, expr = list(clause.items())[0] + if key == "$and": + return all(_filter_where_doc_clause(clause, doc) for clause in expr) + elif key == "$or": + return any(_filter_where_doc_clause(clause, doc) for clause in expr) + elif key == "$contains": + return expr in doc + else: + raise ValueError("Unknown operator: {}".format(key)) + + +EMPTY_DICT = {} +EMPTY_STRING = "" + + +def _filter_embedding_set(recordset: strategies.RecordSet, filter: strategies.Filter): + """Return IDs from the embedding set that match the given filter object""" + + recordset = invariants.wrap_all(recordset) + + ids = set(recordset["ids"]) + + filter_ids = filter["ids"] + if filter_ids is not None: + filter_ids = invariants.maybe_wrap(filter_ids) + assert filter_ids is not None + # If the filter ids is an empty list then we treat that as get all + if len(filter_ids) != 0: + ids = ids.intersection(filter_ids) + + for i in range(len(recordset["ids"])): + if filter["where"]: + metadatas = recordset["metadatas"] or [EMPTY_DICT] * len(recordset["ids"]) + if not _filter_where_clause(filter["where"], metadatas[i]): + ids.discard(recordset["ids"][i]) + + if filter["where_document"]: + documents = recordset["documents"] or [EMPTY_STRING] * len(recordset["ids"]) + if not _filter_where_doc_clause(filter["where_document"], documents[i]): + ids.discard(recordset["ids"][i]) + + return list(ids) + + +collection_st = st.shared( + strategies.collections(add_filterable_data=True, with_hnsw_params=True), + key="coll", +) +recordset_st = st.shared( + strategies.recordsets(collection_st, max_size=1000), key="recordset" +) + + +@settings( + suppress_health_check=[ + HealthCheck.function_scoped_fixture, + HealthCheck.large_base_example, + ] +) +@given( + collection=collection_st, + recordset=recordset_st, + filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1), +) +def test_filterable_metadata_get(caplog, api: API, collection, recordset, filters): + caplog.set_level(logging.ERROR) + + api.reset() + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + coll.add(**recordset) + + for filter in filters: + result_ids = coll.get(**filter)["ids"] + expected_ids = _filter_embedding_set(recordset, filter) + assert sorted(result_ids) == sorted(expected_ids) + + +@settings( + suppress_health_check=[ + HealthCheck.function_scoped_fixture, + HealthCheck.large_base_example, + ] +) +@given( + collection=collection_st, + recordset=recordset_st, + filters=st.lists( + strategies.filters(collection_st, recordset_st, include_all_ids=True), + min_size=1, + ), +) +def test_filterable_metadata_query( + caplog, + api: API, + collection: strategies.Collection, + recordset: strategies.RecordSet, + filters, +): + caplog.set_level(logging.ERROR) + + api.reset() + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, + embedding_function=collection.embedding_function, + ) + coll.add(**recordset) + recordset = invariants.wrap_all(recordset) + total_count = len(recordset["ids"]) + # Pick a random vector + if collection.has_embeddings: + random_query = recordset["embeddings"][random.randint(0, total_count - 1)] + else: + random_query = collection.embedding_function( + recordset["documents"][random.randint(0, total_count - 1)] + ) + for filter in filters: + try: + 
result_ids = set( + coll.query( + query_embeddings=random_query, + n_results=total_count, + where=filter["where"], + where_document=filter["where_document"], + )["ids"][0] + ) + except NoDatapointsException: + result_ids = set() + expected_ids = set(_filter_embedding_set(recordset, filter)) + assert len(result_ids.intersection(expected_ids)) == len(result_ids) diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py new file mode 100644 index 00000000000..d4632de2e30 --- /dev/null +++ b/chromadb/test/property/test_persist.py @@ -0,0 +1,156 @@ +import logging +import multiprocessing +from typing import Generator, Callable +from hypothesis import given +import hypothesis.strategies as st +import pytest +import chromadb +from chromadb.api import API +from chromadb.config import Settings +import chromadb.test.property.strategies as strategies +import chromadb.test.property.invariants as invariants +from chromadb.test.property.test_embeddings import ( + EmbeddingStateMachine, + EmbeddingStateMachineStates, +) +from hypothesis.stateful import run_state_machine_as_test, rule, precondition +import os +import shutil +import pytest +import tempfile + +CreatePersistAPI = Callable[[], API] + +configurations = [ + Settings( + chroma_api_impl="local", + chroma_db_impl="duckdb+parquet", + persist_directory=tempfile.gettempdir() + "/tests", + ) +] + + +@pytest.fixture(scope="module", params=configurations) +def settings(request) -> Generator[Settings, None, None]: + configuration = request.param + yield configuration + save_path = configuration.persist_directory + # Remove if it exists + if os.path.exists(save_path): + shutil.rmtree(save_path) + + +collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") + + +@given( + collection_strategy=collection_st, + embeddings_strategy=strategies.recordsets(collection_st), +) +def test_persist( + settings: Settings, + collection_strategy: strategies.Collection, + embeddings_strategy: strategies.RecordSet, +): + api_1 = chromadb.Client(settings) + api_1.reset() + coll = api_1.create_collection( + name=collection_strategy.name, + metadata=collection_strategy.metadata, + embedding_function=collection_strategy.embedding_function, + ) + + coll.add(**embeddings_strategy) + + invariants.count(coll, embeddings_strategy) + invariants.metadatas_match(coll, embeddings_strategy) + invariants.documents_match(coll, embeddings_strategy) + invariants.ids_match(coll, embeddings_strategy) + invariants.ann_accuracy( + coll, + embeddings_strategy, + embedding_function=collection_strategy.embedding_function, + ) + + api_1.persist() + del api_1 + + api_2 = chromadb.Client(settings) + coll = api_2.get_collection( + name=collection_strategy.name, + embedding_function=collection_strategy.embedding_function, + ) + invariants.count(coll, embeddings_strategy) + invariants.metadatas_match(coll, embeddings_strategy) + invariants.documents_match(coll, embeddings_strategy) + invariants.ids_match(coll, embeddings_strategy) + invariants.ann_accuracy( + coll, + embeddings_strategy, + embedding_function=collection_strategy.embedding_function, + ) + + +def load_and_check(settings: Settings, collection_name: str, embeddings_set, conn): + try: + api = chromadb.Client(settings) + coll = api.get_collection( + name=collection_name, embedding_function=lambda x: None + ) + invariants.count(coll, embeddings_set) + invariants.metadatas_match(coll, embeddings_set) + invariants.documents_match(coll, embeddings_set) + invariants.ids_match(coll, 
embeddings_set) + invariants.ann_accuracy(coll, embeddings_set) + except Exception as e: + conn.send(e) + raise e + + +class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates): + persist = "persist" + + +class PersistEmbeddingsStateMachine(EmbeddingStateMachine): + def __init__(self, api: API, settings: Settings): + self.api = api + self.settings = settings + self.last_persist_delay = 10 + self.api.reset() + super().__init__(self.api) + + @precondition(lambda self: len(self.embeddings["ids"]) >= 1) + @precondition(lambda self: self.last_persist_delay <= 0) + @rule() + def persist(self): + self.on_state_change(PersistEmbeddingsStateMachineStates.persist) + self.api.persist() + collection_name = self.collection.name + # Create a new process and then inside the process run the invariants + # TODO: Once we switch off of duckdb and onto sqlite we can remove this + ctx = multiprocessing.get_context("spawn") + conn1, conn2 = multiprocessing.Pipe() + p = ctx.Process( + target=load_and_check, + args=(self.settings, collection_name, self.embeddings, conn2), + ) + p.start() + p.join() + + if conn1.poll(): + e = conn1.recv() + raise e + + def on_state_change(self, new_state): + if new_state == PersistEmbeddingsStateMachineStates.persist: + self.last_persist_delay = 10 + else: + self.last_persist_delay -= 1 + + +def test_persist_embeddings_state(caplog, settings: Settings): + caplog.set_level(logging.ERROR) + api = chromadb.Client(settings) + run_state_machine_as_test( + lambda: PersistEmbeddingsStateMachine(settings=settings, api=api) + ) diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 8c010545711..277e8530504 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -13,17 +13,6 @@ import numpy as np -@pytest.fixture -def local_api(): - return chromadb.Client( - Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb", - persist_directory=tempfile.gettempdir(), - ) - ) - - @pytest.fixture def local_persist_api(): return chromadb.Client( @@ -34,7 +23,6 @@ def local_persist_api(): ) ) - # https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached @pytest.fixture def local_persist_api_cache_bust(): @@ -47,67 +35,6 @@ def local_persist_api_cache_bust(): ) -@pytest.fixture -def fastapi_integration_api(): - return chromadb.Client() # configured by environment variables - - -def _build_fastapi_api(): - return chromadb.Client( - Settings( - chroma_api_impl="rest", chroma_server_host="localhost", chroma_server_http_port="6666" - ) - ) - - -@pytest.fixture -def fastapi_api(): - return _build_fastapi_api() - - -def run_server(): - settings = Settings( - chroma_api_impl="local", - chroma_db_impl="duckdb", - persist_directory=tempfile.gettempdir() + "/test_server", - ) - server = chromadb.server.fastapi.FastAPI(settings) - uvicorn.run(server.app(), host="0.0.0.0", port=6666, log_level="info") - - -def await_server(attempts=0): - api = _build_fastapi_api() - - try: - api.heartbeat() - except ConnectionError as e: - if attempts > 10: - raise e - else: - time.sleep(2) - await_server(attempts + 1) - - -@pytest.fixture(scope="module", autouse=True) -def fastapi_server(): - proc = Process(target=run_server, args=(), daemon=True) - proc.start() - await_server() - yield - proc.kill() - - -test_apis = [local_api, fastapi_api] - -if "CHROMA_INTEGRATION_TEST" in os.environ: - print("Including integration tests") - test_apis.append(fastapi_integration_api) - -if "CHROMA_INTEGRATION_TEST_ONLY" in 
os.environ: - print("Including integration tests only") - test_apis = [fastapi_integration_api] - - @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_loading(api_fixture, request): api = request.getfixturevalue("local_persist_api") @@ -203,22 +130,18 @@ def test_persist(api_fixture, request): assert api.list_collections() == [] -@pytest.mark.parametrize("api_fixture", test_apis) -def test_heartbeat(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_heartbeat(api): assert isinstance(api.heartbeat(), int) batch_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], - "ids": ["https://example.com", "https://example.com"], + "ids": ["https://example.com/1", "https://example.com/2"], } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_add(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_add(api): api.reset() @@ -229,9 +152,7 @@ def test_add(api_fixture, request): assert collection.count() == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_get_or_create(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_get_or_create(api): api.reset() @@ -251,14 +172,11 @@ def test_get_or_create(api_fixture, request): minimal_records = { "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], - "ids": ["https://example.com", "https://example.com"], + "ids": ["https://example.com/1", "https://example.com/2"], } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_add_minimal(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) - +def test_add_minimal(api): api.reset() collection = api.create_collection("testspace") @@ -268,9 +186,7 @@ def test_add_minimal(api_fixture, request): assert collection.count() == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_get_from_db(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_get_from_db(api): api.reset() collection = api.create_collection("testspace") @@ -280,9 +196,7 @@ def test_get_from_db(api_fixture, request): assert len(records[key]) == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_reset_db(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_reset_db(api): api.reset() @@ -294,9 +208,7 @@ def test_reset_db(api_fixture, request): assert len(api.list_collections()) == 0 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_get_nearest_neighbors(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_get_nearest_neighbors(api): api.reset() collection = api.create_collection("testspace") @@ -331,10 +243,7 @@ def test_get_nearest_neighbors(api_fixture, request): assert len(nn[key]) == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_get_nearest_neighbors_filter(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) - +def test_get_nearest_neighbors_filter(api, request): api.reset() collection = api.create_collection("testspace") collection.add(**batch_records) @@ -349,9 +258,7 @@ def test_get_nearest_neighbors_filter(api_fixture, request): assert str(e.value).__contains__("found") -@pytest.mark.parametrize("api_fixture", test_apis) -def test_delete(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_delete(api): api.reset() collection = api.create_collection("testspace") @@ -362,9 +269,7 @@ def test_delete(api_fixture, request): assert collection.count() == 
0 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_delete_with_index(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_delete_with_index(api): api.reset() collection = api.create_collection("testspace") @@ -373,9 +278,7 @@ def test_delete_with_index(api_fixture, request): collection.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_count(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_count(api): api.reset() collection = api.create_collection("testspace") @@ -384,9 +287,7 @@ def test_count(api_fixture, request): assert collection.count() == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_modify(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_modify(api): api.reset() collection = api.create_collection("testspace") @@ -396,9 +297,7 @@ def test_modify(api_fixture, request): assert collection.name == "testspace2" -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_cru(api_fixture, request): - api: API = request.getfixturevalue(api_fixture.__name__) +def test_metadata_cru(api): api.reset() metadata_a = {"a": 1, "b": 2} @@ -448,9 +347,9 @@ def test_metadata_cru(api_fixture, request): assert collection.metadata is None -@pytest.mark.parametrize("api_fixture", test_apis) -def test_increment_index_on(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_increment_index_on(api): + api.reset() collection = api.create_collection("testspace") @@ -468,9 +367,9 @@ def test_increment_index_on(api_fixture, request): assert len(nn[key]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_increment_index_off(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_increment_index_off(api): + api.reset() collection = api.create_collection("testspace") @@ -488,9 +387,9 @@ def test_increment_index_off(api_fixture, request): assert len(nn[key]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def skipping_indexing_will_fail(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def skipping_indexing_will_fail(api): + api.reset() collection = api.create_collection("testspace") @@ -503,9 +402,9 @@ def skipping_indexing_will_fail(api_fixture, request): assert str(e.value).__contains__("index not found") -@pytest.mark.parametrize("api_fixture", test_apis) -def test_add_a_collection(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_add_a_collection(api): + api.reset() api.create_collection("testspace") @@ -519,9 +418,9 @@ def test_add_a_collection(api_fixture, request): collection = api.get_collection("testspace2") -@pytest.mark.parametrize("api_fixture", test_apis) -def test_list_collections(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_list_collections(api): + api.reset() api.create_collection("testspace") @@ -532,9 +431,9 @@ def test_list_collections(api_fixture, request): assert len(collections) == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_reset(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_reset(api): + api.reset() api.create_collection("testspace") @@ -549,9 +448,9 @@ def test_reset(api_fixture, request): assert len(collections) == 0 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_peek(api_fixture, request): - api 
= request.getfixturevalue(api_fixture.__name__) + +def test_peek(api): + api.reset() collection = api.create_collection("testspace") @@ -574,9 +473,9 @@ def test_peek(api_fixture, request): } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_add_get_int_float(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_add_get_int_float(api): + api.reset() collection = api.create_collection("test_int") @@ -590,9 +489,9 @@ def test_metadata_add_get_int_float(api_fixture, request): assert type(items["metadatas"][0]["float_value"]) == float -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_add_query_int_float(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_add_query_int_float(api): + api.reset() collection = api.create_collection("test_int") @@ -606,9 +505,9 @@ def test_metadata_add_query_int_float(api_fixture, request): assert type(items["metadatas"][0][0]["float_value"]) == float -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_get_where_string(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_get_where_string(api): + api.reset() collection = api.create_collection("test_int") @@ -619,9 +518,9 @@ def test_metadata_get_where_string(api_fixture, request): assert items["metadatas"][0]["string_value"] == "one" -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_get_where_int(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_get_where_int(api): + api.reset() collection = api.create_collection("test_int") @@ -632,9 +531,9 @@ def test_metadata_get_where_int(api_fixture, request): assert items["metadatas"][0]["string_value"] == "one" -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_get_where_float(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_get_where_float(api): + api.reset() collection = api.create_collection("test_int") @@ -646,9 +545,9 @@ def test_metadata_get_where_float(api_fixture, request): assert items["metadatas"][0]["float_value"] == 1.001 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_update_get_int_float(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_update_get_int_float(api): + api.reset() collection = api.create_collection("test_int") @@ -670,9 +569,9 @@ def test_metadata_update_get_int_float(api_fixture, request): } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_validation_add(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_validation_add(api): + api.reset() collection = api.create_collection("test_metadata_validation") @@ -680,9 +579,9 @@ def test_metadata_validation_add(api_fixture, request): collection.add(**bad_metadata_records) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_metadata_validation_update(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_metadata_validation_update(api): + api.reset() collection = api.create_collection("test_metadata_validation") @@ -691,9 +590,9 @@ def test_metadata_validation_update(api_fixture, request): collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}}) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_validation_get(api_fixture, request): - api = 
request.getfixturevalue(api_fixture.__name__) + +def test_where_validation_get(api): + api.reset() collection = api.create_collection("test_where_validation") @@ -701,9 +600,9 @@ def test_where_validation_get(api_fixture, request): collection.get(where={"value": {"nested": "5"}}) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_validation_query(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_validation_query(api): + api.reset() collection = api.create_collection("test_where_validation") @@ -721,9 +620,9 @@ def test_where_validation_query(api_fixture, request): } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_lt(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_lt(api): + api.reset() collection = api.create_collection("test_where_lt") @@ -732,9 +631,9 @@ def test_where_lt(api_fixture, request): assert len(items["metadatas"]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_lte(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_lte(api): + api.reset() collection = api.create_collection("test_where_lte") @@ -743,9 +642,9 @@ def test_where_lte(api_fixture, request): assert len(items["metadatas"]) == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_gt(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_gt(api): + api.reset() collection = api.create_collection("test_where_lte") @@ -754,9 +653,9 @@ def test_where_gt(api_fixture, request): assert len(items["metadatas"]) == 2 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_gte(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_gte(api): + api.reset() collection = api.create_collection("test_where_lte") @@ -765,9 +664,9 @@ def test_where_gte(api_fixture, request): assert len(items["metadatas"]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_ne_string(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_ne_string(api): + api.reset() collection = api.create_collection("test_where_lte") @@ -776,9 +675,9 @@ def test_where_ne_string(api_fixture, request): assert len(items["metadatas"]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_ne_eq_number(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_ne_eq_number(api): + api.reset() collection = api.create_collection("test_where_lte") @@ -789,9 +688,9 @@ def test_where_ne_eq_number(api_fixture, request): assert len(items["metadatas"]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_valid_operators(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_valid_operators(api): + api.reset() collection = api.create_collection("test_where_valid_operators") @@ -851,9 +750,9 @@ def test_where_valid_operators(api_fixture, request): } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_dimensionality_validation_add(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_dimensionality_validation_add(api): + api.reset() collection = api.create_collection("test_dimensionality_validation") @@ -864,9 +763,9 @@ def test_dimensionality_validation_add(api_fixture, request): assert "dimensionality" in str(e.value) -@pytest.mark.parametrize("api_fixture", test_apis) 
-def test_dimensionality_validation_query(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_dimensionality_validation_query(api): + api.reset() collection = api.create_collection("test_dimensionality_validation_query") @@ -877,9 +776,9 @@ def test_dimensionality_validation_query(api_fixture, request): assert "dimensionality" in str(e.value) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_number_of_elements_validation_query(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_number_of_elements_validation_query(api): + api.reset() collection = api.create_collection("test_number_of_elements_validation") @@ -890,9 +789,9 @@ def test_number_of_elements_validation_query(api_fixture, request): assert "number of elements" in str(e.value) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_query_document_valid_operators(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_query_document_valid_operators(api): + api.reset() collection = api.create_collection("test_where_valid_operators") @@ -936,9 +835,9 @@ def test_query_document_valid_operators(api_fixture, request): } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_get_where_document(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_get_where_document(api): + api.reset() collection = api.create_collection("test_get_where_document") @@ -954,9 +853,9 @@ def test_get_where_document(api_fixture, request): assert len(items["metadatas"]) == 0 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_query_where_document(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_query_where_document(api): + api.reset() collection = api.create_collection("test_query_where_document") @@ -979,9 +878,9 @@ def test_query_where_document(api_fixture, request): assert "datapoints" in str(e.value) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_delete_where_document(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_delete_where_document(api): + api.reset() collection = api.create_collection("test_delete_where_document") @@ -1015,9 +914,9 @@ def test_delete_where_document(api_fixture, request): } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_logical_operators(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_logical_operators(api): + api.reset() collection = api.create_collection("test_logical_operators") @@ -1055,9 +954,9 @@ def test_where_logical_operators(api_fixture, request): assert len(items["metadatas"]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_where_document_logical_operators(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_where_document_logical_operators(api): + api.reset() collection = api.create_collection("test_document_logical_operators") @@ -1107,9 +1006,9 @@ def test_where_document_logical_operators(api_fixture, request): } -@pytest.mark.parametrize("api_fixture", test_apis) -def test_query_include(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_query_include(api): + api.reset() collection = api.create_collection("test_query_include") @@ -1141,9 +1040,9 @@ def test_query_include(api_fixture, request): assert items["ids"][0][1] == "id2" -@pytest.mark.parametrize("api_fixture", test_apis) -def 
test_get_include(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_get_include(api): + api.reset() collection = api.create_collection("test_get_include") @@ -1174,9 +1073,9 @@ def test_get_include(api_fixture, request): # make sure query results are returned in the right order -@pytest.mark.parametrize("api_fixture", test_apis) -def test_query_order(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_query_order(api): + api.reset() collection = api.create_collection("test_query_order") @@ -1193,9 +1092,9 @@ def test_query_order(api_fixture, request): # test to make sure add, get, delete error on invalid id input -@pytest.mark.parametrize("api_fixture", test_apis) -def test_invalid_id(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_invalid_id(api): + api.reset() collection = api.create_collection("test_invalid_id") @@ -1215,9 +1114,9 @@ def test_invalid_id(api_fixture, request): assert "ID" in str(e.value) -@pytest.mark.parametrize("api_fixture", test_apis) -def test_index_params(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_index_params(api): + # first standard add api.reset() @@ -1254,10 +1153,10 @@ def test_index_params(api_fixture, request): assert items["distances"][0][0] < -5 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_invalid_index_params(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_invalid_index_params(api): + + api.reset() with pytest.raises(Exception): @@ -1273,8 +1172,7 @@ def test_invalid_index_params(api_fixture, request): collection.add(**records) -@pytest.mark.parametrize("api_fixture", [local_persist_api]) -def test_persist_index_loading_params(api_fixture, request): +def test_persist_index_loading_params(api, request): api = request.getfixturevalue("local_persist_api") api.reset() collection = api.create_collection("test", metadata={"hnsw:space": "ip"}) @@ -1297,9 +1195,9 @@ def test_persist_index_loading_params(api_fixture, request): assert len(nn[key]) == 1 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_add_large(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_add_large(api): + api.reset() @@ -1317,9 +1215,9 @@ def test_add_large(api_fixture, request): # test get_version -@pytest.mark.parametrize("api_fixture", test_apis) -def test_get_version(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_get_version(api): + api.reset() version = api.get_version() @@ -1330,9 +1228,9 @@ def test_get_version(api_fixture, request): # test delete_collection -@pytest.mark.parametrize("api_fixture", test_apis) -def test_delete_collection(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) + +def test_delete_collection(api): + api.reset() collection = api.create_collection("test_delete_collection") collection.add(**records) @@ -1342,15 +1240,15 @@ def test_delete_collection(api_fixture, request): assert len(api.list_collections()) == 0 -@pytest.mark.parametrize("api_fixture", test_apis) -def test_multiple_collections(api_fixture, request): + +def test_multiple_collections(api): embeddings1 = np.random.rand(10, 512).astype(np.float32).tolist() embeddings2 = np.random.rand(10, 512).astype(np.float32).tolist() ids1 = [f"http://example.com/1/{i}" for i in range(len(embeddings1))] ids2 = [f"http://example.com/2/{i}" for i in range(len(embeddings2))] - api = 
request.getfixturevalue(api_fixture.__name__) + api.reset() coll1 = api.create_collection("coll1") coll1.add(embeddings=embeddings1, ids=ids1) @@ -1369,10 +1267,10 @@ def test_multiple_collections(api_fixture, request): assert results2["ids"][0][0] == ids2[0] -@pytest.mark.parametrize("api_fixture", test_apis) -def test_update_query(api_fixture, request): - api = request.getfixturevalue(api_fixture.__name__) +def test_update_query(api): + + api.reset() collection = api.create_collection("test_update_query") collection.add(**records) @@ -1397,3 +1295,47 @@ def test_update_query(api_fixture, request): assert results["documents"][0][0] == updated_records["documents"][0] assert results["metadatas"][0][0]["foo"] == "bar" assert results["embeddings"][0][0] == updated_records["embeddings"][0] + + +initial_records = { + "embeddings": [[0, 0, 0], [1.2, 2.24, 3.2], [2.2, 3.24, 4.2]], + "ids": ["id1", "id2", "id3"], + "metadatas": [{"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2}, {"string_value": "three"}], + "documents": ["this document is first", "this document is second", "this document is third"], +} + +new_records = { + "embeddings": [[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]], + "ids": ["id1", "id4"], + "metadatas": [{"int_value": 1, "string_value": "one_of_one", "float_value": 1.001}, {"int_value": 4}], + "documents": ["this document is even more first", "this document is new and fourth"], +} + +def test_upsert(api): + api.reset() + collection = api.create_collection("test") + + collection.add(**initial_records) + assert collection.count() == 3 + + collection.upsert(**new_records) + assert collection.count() == 4 + + get_result = collection.get(include=['embeddings', 'metadatas', 'documents'], ids=new_records['ids'][0]) + assert get_result['embeddings'][0] == new_records['embeddings'][0] + assert get_result['metadatas'][0] == new_records['metadatas'][0] + assert get_result['documents'][0] == new_records['documents'][0] + + query_result = collection.query(query_embeddings=get_result['embeddings'], n_results=1, include=['embeddings', 'metadatas', 'documents']) + assert query_result['embeddings'][0][0] == new_records['embeddings'][0] + assert query_result['metadatas'][0][0] == new_records['metadatas'][0] + assert query_result['documents'][0][0] == new_records['documents'][0] + + collection.delete(ids=initial_records['ids'][2]) + collection.upsert(ids=initial_records['ids'][2], embeddings=[[1.1, 0.99, 2.21]], metadatas=[{"string_value": "a new string value"}]) + assert collection.count() == 4 + + get_result = collection.get(include=['embeddings', 'metadatas', 'documents'], ids=['id3']) + assert get_result['embeddings'][0] == [1.1, 0.99, 2.21] + assert get_result['metadatas'][0] == {"string_value": "a new string value"} + assert get_result['documents'][0] == None \ No newline at end of file diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 28e1103edd9..057ecac9fae 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -3,16 +3,21 @@ class SentenceTransformerEmbeddingFunction(EmbeddingFunction): + + models = {} + # If you have a beefier machine, try "gtr-t5-large". 
# for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html def __init__(self, model_name: str = "all-MiniLM-L6-v2"): - try: - from sentence_transformers import SentenceTransformer - except ImportError: - raise ValueError( - "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`" - ) - self._model = SentenceTransformer(model_name) + if model_name not in self.models: + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ValueError( + "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`" + ) + self.models[model_name] = SentenceTransformer(model_name) + self._model = self.models[model_name] def __call__(self, texts: Documents) -> Embeddings: return self._model.encode(list(texts), convert_to_numpy=True).tolist() diff --git a/clients/js/src/generated/api.ts b/clients/js/src/generated/api.ts index 20d03625c5b..f4c5bd3638d 100644 --- a/clients/js/src/generated/api.ts +++ b/clients/js/src/generated/api.ts @@ -614,6 +614,49 @@ export const ApiApiFetchParamCreator = function (configuration?: Configuration) options: localVarRequestOptions, }; }, + /** + * @summary Upsert + * @param {string} collectionName + * @param {Api.AddEmbedding} request + * @param {RequestInit} [options] Override http request option. + * @throws {RequiredError} + */ + upsert(collectionName: string, request: Api.AddEmbedding, options: RequestInit = {}): FetchArgs { + // verify required parameter 'collectionName' is not null or undefined + if (collectionName === null || collectionName === undefined) { + throw new RequiredError('collectionName', 'Required parameter collectionName was null or undefined when calling upsert.'); + } + // verify required parameter 'request' is not null or undefined + if (request === null || request === undefined) { + throw new RequiredError('request', 'Required parameter request was null or undefined when calling upsert.'); + } + let localVarPath = `/api/v1/collections/{collection_name}/upsert` + .replace('{collection_name}', encodeURIComponent(String(collectionName))); + const localVarPathQueryStart = localVarPath.indexOf("?"); + const localVarRequestOptions: RequestInit = Object.assign({ method: 'POST' }, options); + const localVarHeaderParameter: Headers = options.headers ? new Headers(options.headers) : new Headers(); + const localVarQueryParameter = new URLSearchParams(localVarPathQueryStart !== -1 ? localVarPath.substring(localVarPathQueryStart + 1) : ""); + if (localVarPathQueryStart !== -1) { + localVarPath = localVarPath.substring(0, localVarPathQueryStart); + } + + localVarHeaderParameter.set('Content-Type', 'application/json'); + + localVarRequestOptions.headers = localVarHeaderParameter; + + if (request !== undefined) { + localVarRequestOptions.body = JSON.stringify(request || {}); + } + + const localVarQueryParameterString = localVarQueryParameter.toString(); + if (localVarQueryParameterString) { + localVarPath += "?" + localVarQueryParameterString; + } + return { + url: localVarPath, + options: localVarRequestOptions, + }; + }, /** * @summary Version * @param {RequestInit} [options] Override http request option. 
@@ -1113,6 +1156,36 @@ export const ApiApiFp = function(configuration?: Configuration) { }); }; }, + /** + * @summary Upsert + * @param {string} collectionName + * @param {Api.AddEmbedding} request + * @param {RequestInit} [options] Override http request option. + * @throws {RequiredError} + */ + upsert(collectionName: string, request: Api.AddEmbedding, options?: RequestInit): (fetch?: FetchAPI, basePath?: string) => Promise { + const localVarFetchArgs = ApiApiFetchParamCreator(configuration).upsert(collectionName, request, options); + return (fetch: FetchAPI = defaultFetch, basePath: string = BASE_PATH) => { + return fetch(basePath + localVarFetchArgs.url, localVarFetchArgs.options).then((response) => { + const contentType = response.headers.get('Content-Type'); + const mimeType = contentType ? contentType.replace(/;.*/, '') : undefined; + + if (response.status === 200) { + if (mimeType === 'application/json') { + return response.json() as any; + } + throw response; + } + if (response.status === 422) { + if (mimeType === 'application/json') { + throw response; + } + throw response; + } + throw response; + }); + }; + }, /** * @summary Version * @param {RequestInit} [options] Override http request option. @@ -1324,6 +1397,17 @@ export class ApiApi extends BaseAPI { return ApiApiFp(this.configuration).updateCollection(collectionName, request, options)(this.fetch, this.basePath); } + /** + * @summary Upsert + * @param {string} collectionName + * @param {Api.AddEmbedding} request + * @param {RequestInit} [options] Override http request option. + * @throws {RequiredError} + */ + public upsert(collectionName: string, request: Api.AddEmbedding, options?: RequestInit) { + return ApiApiFp(this.configuration).upsert(collectionName, request, options)(this.fetch, this.basePath); + } + /** * @summary Version * @param {RequestInit} [options] Override http request option. 
diff --git a/clients/js/src/generated/models.ts b/clients/js/src/generated/models.ts index 0ca3d7b2c9b..b17f701b622 100644 --- a/clients/js/src/generated/models.ts +++ b/clients/js/src/generated/models.ts @@ -315,6 +315,9 @@ export namespace Api { } + export interface Upsert200Response { + } + export interface ValidationError { loc: (string | number)[]; msg: string; diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts index 4b53a35a58b..48015dc714e 100644 --- a/clients/js/src/index.ts +++ b/clients/js/src/index.ts @@ -170,16 +170,23 @@ export class Collection { this.metadata = metadata; } - public async add( + private async validate( + require_embeddings_or_documents: boolean, // set to false in the case of Update ids: string | string[], embeddings: number[] | number[][] | undefined, metadatas?: object | object[], documents?: string | string[], - increment_index: boolean = true ) { - if (embeddings === undefined && documents === undefined) { - throw new Error("embeddings and documents cannot both be undefined"); - } else if (embeddings === undefined && documents !== undefined) { + + if (require_embeddings_or_documents) { + if ((embeddings === undefined) && (documents === undefined)) { + throw new Error( + "embeddings and documents cannot both be undefined", + ); + } + } + + if ((embeddings === undefined) && (documents !== undefined)) { const documentsArray = toArray(documents); if (this.embeddingFunction !== undefined) { embeddings = await this.embeddingFunction.generate(documentsArray); @@ -222,21 +229,84 @@ export class Collection { ); } - return await this.api - .add(this.name, { + const uniqueIds = new Set(idsArray); + if (uniqueIds.size !== idsArray.length) { + const duplicateIds = idsArray.filter((item, index) => idsArray.indexOf(item) !== index); + throw new Error( + `Expected IDs to be unique, found duplicates for: ${duplicateIds}`, + ); + } + + return [idsArray, embeddingsArray, metadatasArray, documentsArray] + } + + public async add( + ids: string | string[], + embeddings: number[] | number[][] | undefined, + metadatas?: object | object[], + documents?: string | string[], + increment_index: boolean = true, + ) { + + const [idsArray, embeddingsArray, metadatasArray, documentsArray] = await this.validate( + true, + ids, + embeddings, + metadatas, + documents + ) + + const response = await this.api.add(this.name, + { + // @ts-ignore ids: idsArray, - embeddings: embeddingsArray, - //@ts-ignore + embeddings: embeddingsArray as number[][], // We know this is defined because of the validate function + // @ts-ignore documents: documentsArray, metadatas: metadatasArray, incrementIndex: increment_index, }) - .then(function (response: any) { - return JSON.parse(response); - }) + .then(handleSuccess) + .catch(handleError); + + return response + } + + public async upsert( + ids: string | string[], + embeddings: number[] | number[][] | undefined, + metadatas?: object | object[], + documents?: string | string[], + increment_index: boolean = true, + ) { + + const [idsArray, embeddingsArray, metadatasArray, documentsArray] = await this.validate( + true, + ids, + embeddings, + metadatas, + documents + ) + + const response = await this.api.upsert(this.name, + { + //@ts-ignore + ids: idsArray, + embeddings: embeddingsArray as number[][], // We know this is defined because of the validate function + //@ts-ignore + documents: documentsArray, + metadatas: metadatasArray, + increment_index: increment_index, + }, + ) + .then(handleSuccess) .catch(handleError); + + return response + } + public 
async count() { + const response = await this.api.count(this.name); + return handleSuccess(response); diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts index 956a3fcafa9..6ac078b5121 100644 --- a/clients/js/test/add.collections.test.ts +++ b/clients/js/test/add.collections.test.ts @@ -1,6 +1,7 @@ -import { expect, test } from "@jest/globals"; -import chroma from "./initClient"; -import { DOCUMENTS, EMBEDDINGS, IDS } from "./data"; +import { expect, test } from '@jest/globals'; +import chroma from './initClient' +import { DOCUMENTS, EMBEDDINGS, IDS } from './data'; +import { METADATAS } from './data'; import { IncludeEnum } from "../src/types"; test("it should add single embeddings to a collection", async () => { @@ -37,3 +38,25 @@ test("add documents", async () => { const results = await collection.get(["test1"]); expect(results.documents[0]).toBe("This is a test"); }); + +test('it should return an error when inserting an ID that already exists in the Collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + await collection.add(IDS, EMBEDDINGS, METADATAS) + const results = await collection.add(IDS, EMBEDDINGS, METADATAS); + expect(results.error).toBeDefined() + expect(results.error).toContain("IDAlreadyExists") +}) + +test('it should return an error when inserting duplicate IDs in the same batch', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const ids = IDS.concat(["test1"]) + const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + const metadatas = METADATAS.concat([{ test: 'test1', 'float_value': 0.1 }]) + try { + await collection.add(ids, embeddings, metadatas); + } catch (e: any) { + expect(e.message).toMatch('duplicates') + } +}) \ No newline at end of file diff --git a/clients/js/test/client.test.ts b/clients/js/test/client.test.ts index 1a39d6159a6..06595c0149b 100644 --- a/clients/js/test/client.test.ts +++ b/clients/js/test/client.test.ts @@ -3,38 +3,188 @@ import { ChromaClient } from "../src/index"; import chroma from "./initClient"; test("it should create the client connection", async () => { - expect(chroma).toBeDefined(); - expect(chroma).toBeInstanceOf(ChromaClient); + expect(chroma).toBeDefined(); + expect(chroma).toBeInstanceOf(ChromaClient); }); test("it should get the version", async () => { - const version = await chroma.version(); - expect(version).toBeDefined(); - expect(version).toMatch(/^[0-9]+\.[0-9]+\.[0-9]+$/); + const version = await chroma.version(); + expect(version).toBeDefined(); + expect(version).toMatch(/^[0-9]+\.[0-9]+\.[0-9]+$/); }); test("it should get the heartbeat", async () => { - const heartbeat = await chroma.heartbeat(); - expect(heartbeat).toBeDefined(); - expect(heartbeat).toBeGreaterThan(0); + const heartbeat = await chroma.heartbeat(); + expect(heartbeat).toBeDefined(); + expect(heartbeat).toBeGreaterThan(0); }); test("it should reset the database", async () => { - await chroma.reset(); - const collections = await chroma.listCollections(); - expect(collections).toBeDefined(); - expect(collections).toBeInstanceOf(Array); - expect(collections.length).toBe(0); - - const collection = await chroma.createCollection("test"); - const collections2 = await chroma.listCollections(); - expect(collections2).toBeDefined(); - expect(collections2).toBeInstanceOf(Array); - expect(collections2.length).toBe(1); - - await chroma.reset(); - const collections3 = await chroma.listCollections(); -
expect(collections3).toBeDefined(); - expect(collections3).toBeInstanceOf(Array); - expect(collections3.length).toBe(0); + await chroma.reset(); + const collections = await chroma.listCollections(); + expect(collections).toBeDefined(); + expect(collections).toBeInstanceOf(Array); + expect(collections.length).toBe(0); + + const collection = await chroma.createCollection("test"); + const collections2 = await chroma.listCollections(); + expect(collections2).toBeDefined(); + expect(collections2).toBeInstanceOf(Array); + expect(collections2.length).toBe(1); + + await chroma.reset(); + const collections3 = await chroma.listCollections(); + expect(collections3).toBeDefined(); + expect(collections3).toBeInstanceOf(Array); + expect(collections3.length).toBe(0); }); + +test('it should list collections', async () => { + await chroma.reset() + let collections = await chroma.listCollections() + expect(collections).toBeDefined() + expect(collections).toBeInstanceOf(Array) + expect(collections.length).toBe(0) + const collection = await chroma.createCollection('test') + collections = await chroma.listCollections() + expect(collections.length).toBe(1) +}) + +test('it should get a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const collection2 = await chroma.getCollection('test') + expect(collection).toBeDefined() + expect(collection2).toBeDefined() + expect(collection).toHaveProperty('name') + expect(collection2).toHaveProperty('name') + expect(collection.name).toBe(collection2.name) +}) + +test('it should delete a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + let collections = await chroma.listCollections() + expect(collections.length).toBe(1) + var resp = await chroma.deleteCollection('test') + collections = await chroma.listCollections() + expect(collections.length).toBe(0) +}) + +test('it should add single embeddings to a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const id = 'test1' + const embedding = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + const metadata = { test: 'test' } + await collection.add(id, embedding, metadata) + const count = await collection.count() + expect(count).toBe(1) +}) + +test('it should add batch embeddings to a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const ids = ['test1', 'test2', 'test3'] + const embeddings = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + ] + await collection.add(ids, embeddings) + const count = await collection.count() + expect(count).toBe(3) +}) + +test('it should query a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const ids = ['test1', 'test2', 'test3'] + const embeddings = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + ] + await collection.add(ids, embeddings) + const results = await collection.query([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 2) + expect(results).toBeDefined() + expect(results).toBeInstanceOf(Object) + // expect(results.embeddings[0].length).toBe(2) + expect(['test1', 'test2']).toEqual(expect.arrayContaining(results.ids[0])); + expect(['test3']).not.toEqual(expect.arrayContaining(results.ids[0])); +}) + +test('it should peek a collection', async () => { + await chroma.reset() + const collection = await 
chroma.createCollection('test') + const ids = ['test1', 'test2', 'test3'] + const embeddings = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + ] + await collection.add(ids, embeddings) + const results = await collection.peek(2) + expect(results).toBeDefined() + expect(results).toBeInstanceOf(Object) + expect(results.ids.length).toBe(2) + expect(['test1', 'test2']).toEqual(expect.arrayContaining(results.ids)); +}) + +test('it should get a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const ids = ['test1', 'test2', 'test3'] + const embeddings = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + ] + const metadatas = [{ test: 'test1' }, { test: 'test2' }, { test: 'test3' }] + await collection.add(ids, embeddings, metadatas) + const results = await collection.get(['test1']) + expect(results).toBeDefined() + expect(results).toBeInstanceOf(Object) + expect(results.ids.length).toBe(1) + expect(['test1']).toEqual(expect.arrayContaining(results.ids)); + expect(['test2']).not.toEqual(expect.arrayContaining(results.ids)); + + const results2 = await collection.get(undefined, { 'test': 'test1' }) + expect(results2).toBeDefined() + expect(results2).toBeInstanceOf(Object) + expect(results2.ids.length).toBe(1) +}) + +test('it should delete a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const ids = ['test1', 'test2', 'test3'] + const embeddings = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + ] + const metadatas = [{ test: 'test1' }, { test: 'test2' }, { test: 'test3' }] + await collection.add(ids, embeddings, metadatas) + let count = await collection.count() + expect(count).toBe(3) + var resp = await collection.delete(undefined, { 'test': 'test1' }) + count = await collection.count() + expect(count).toBe(2) +}) + +test('wrong code returns an error', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const ids = ['test1', 'test2', 'test3'] + const embeddings = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + ] + const metadatas = [{ test: 'test1' }, { test: 'test2' }, { test: 'test3' }] + await collection.add(ids, embeddings, metadatas) + const results = await collection.get(undefined, { "test": { "$contains": "hello" } }); + expect(results.error).toBeDefined() + expect(results.error).toBe("ValueError('Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got $contains')") +}) diff --git a/clients/js/test/upsert.collections.test.ts b/clients/js/test/upsert.collections.test.ts new file mode 100644 index 00000000000..4bbf3cb0981 --- /dev/null +++ b/clients/js/test/upsert.collections.test.ts @@ -0,0 +1,29 @@ +import { expect, test } from '@jest/globals'; +import chroma from './initClient' +import { DOCUMENTS, EMBEDDINGS, IDS } from './data'; +import { METADATAS } from './data'; + + +test('it should upsert embeddings to a collection', async () => { + await chroma.reset() + const collection = await chroma.createCollection('test') + const ids = ['test1', 'test2'] + const embeddings = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + ] + await collection.add(ids, embeddings) + const count = await collection.count() + expect(count).toBe(2) + + const ids2 = ["test2", "test3"] + const 
embeddings2 = [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 15], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ] + + await collection.upsert(ids2, embeddings2) + + const count2 = await collection.count() + expect(count2).toBe(3) +}) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e3f33840277..55ca469031f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,12 +25,13 @@ dependencies = [ 'fastapi >= 0.85.1', 'uvicorn[standard] >= 0.18.3', 'numpy >= 1.21.6', - 'posthog >= 2.4.0' + 'posthog >= 2.4.0', + 'typing_extensions >= 4.5.0' ] [tool.black] -line-length = 100 -required-version = "22.10.0" # Black will refuse to run if it's not this version. +line-length = 88 +required-version = "23.3.0" # Black will refuse to run if it's not this version. target-version = ['py36', 'py37', 'py38', 'py39', 'py310'] [tool.pytest.ini_options] diff --git a/requirements.txt b/requirements.txt index 267522d04c8..5bb66359137 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ hnswlib==0.7.0 clickhouse-connect==0.5.7 pydantic==1.9.0 sentence-transformers==2.2.2 -posthog==2.4.0 \ No newline at end of file +posthog==2.4.0 +typing_extensions==4.5.0 \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index df8913b3c04..78456ffcafa 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -2,4 +2,6 @@ build pytest setuptools_scm httpx -black==22.10.0 # match what's in pyproject.toml +black==23.3.0 # match what's in pyproject.toml +hypothesis +hypothesis[numpy] \ No newline at end of file
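For readers skimming this patch: the new `upsert` endpoint (exercised by `test_upsert` in `chromadb/test/test_api.py` and by `upsert.collections.test.ts` above) behaves like `add` for IDs that do not exist yet and like `update` for IDs that already do. A minimal Python sketch of that behaviour, assuming an in-memory client; the collection name, IDs and embeddings are illustrative only, not taken from the patch:

```python
import chromadb

# Minimal sketch of the behaviour exercised by test_upsert above.
# The collection name, IDs and embeddings here are illustrative only.
client = chromadb.Client()  # in-memory client
collection = client.create_collection("upsert-demo")

collection.add(
    ids=["id1", "id2"],
    embeddings=[[0.0, 0.0, 0.0], [1.2, 2.24, 3.2]],
)

# Existing IDs are updated in place, new IDs are inserted.
collection.upsert(
    ids=["id2", "id3"],
    embeddings=[[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]],
)

assert collection.count() == 3
```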
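The property tests added under `chromadb/test/property/` pair Hypothesis strategies (`strategies.collections`, `strategies.recordsets`, `strategies.filters`) with the shared invariant checks in the `invariants` module. Below is a stripped-down, hypothetical sketch of the same pattern written without the project-internal helpers, checking only the count invariant; the strategy bounds and collection name are assumptions made for illustration:

```python
import chromadb
import hypothesis.strategies as st
from hypothesis import given, settings

# Unique string IDs; each ID gets a fixed-width (3-d) embedding drawn below.
ids_st = st.lists(st.uuids().map(str), min_size=1, max_size=5, unique=True)


@settings(max_examples=10, deadline=None)
@given(ids=ids_st, data=st.data())
def test_count_invariant(ids, data):
    embeddings = [
        data.draw(st.lists(st.floats(-1.0, 1.0), min_size=3, max_size=3))
        for _ in ids
    ]
    api = chromadb.Client()
    api.reset()
    coll = api.create_collection("property-demo")
    coll.add(ids=ids, embeddings=embeddings)
    # Mirrors invariants.count: everything added is reflected in count().
    assert coll.count() == len(ids)
```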