diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py
index 526ebbba296..e3b6ff35e3b 100644
--- a/chromadb/api/__init__.py
+++ b/chromadb/api/__init__.py
@@ -145,7 +145,7 @@ def _add(
         ⚠️ It is recommended to use the more specific methods below when possible.
 
         Args:
-            collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
+            collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
             embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
             metadata (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
             documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
@@ -166,17 +166,40 @@ def _update(
         ⚠️ It is recommended to use the more specific methods below when possible.
 
         Args:
-            collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
+            collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
             embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
         """
         pass
 
+    @abstractmethod
+    def _upsert(
+        self,
+        collection_name: str,
+        ids: IDs,
+        embeddings: Optional[Embeddings] = None,
+        metadatas: Optional[Metadatas] = None,
+        documents: Optional[Documents] = None,
+        increment_index: bool = True,
+    ):
+        """Add or update entries in the embedding store.
+        If an entry with the same id already exists, it will be updated, otherwise it will be added.
+
+        Args:
+            collection_name (str): The collection to add the embeddings to
+            ids (Optional[Union[str, Sequence[str]]], optional): The ids to associate with the embeddings. Defaults to None.
+            embeddings (Sequence[Sequence[float]]): The sequence of embeddings to add
+            metadatas (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
+            documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
+            increment_index (bool, optional): If True, will incrementally add to the ANN index of the collection. Defaults to True.
+        """
+        pass
+
     @abstractmethod
     def _count(self, collection_name: str) -> int:
         """Returns the number of embeddings in the database
 
         Args:
-            collection_name (str): The model space to count the embeddings in.
+            collection_name (str): The collection to count the embeddings in.
 
         Returns:
             int: The number of embeddings in the collection
@@ -282,11 +305,11 @@ def raw_sql(self, sql: str) -> pd.DataFrame:
 
     @abstractmethod
     def create_index(self, collection_name: Optional[str] = None) -> bool:
-        """Creates an index for the given model space
+        """Creates an index for the given collection
         ⚠️ This method should not be used directly.
 
         Args:
-            collection_name (Optional[str], optional): The model space to create the index for. Uses the client's model space if None. Defaults to None.
+            collection_name (Optional[str], optional): The collection to create the index for. Uses the client's collection if None. Defaults to None.
 
         Returns:
             bool: True if the index was created successfully
diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py
index 0be1a087fa7..cc5644f02ee 100644
--- a/chromadb/api/fastapi.py
+++ b/chromadb/api/fastapi.py
@@ -180,10 +180,10 @@ def _add(
             self._api_url + "/collections/" + collection_name + "/add",
             data=json.dumps(
                 {
+                    "ids": ids,
                     "embeddings": embeddings,
                     "metadatas": metadatas,
                     "documents": documents,
-                    "ids": ids,
                     "increment_index": increment_index,
                 }
             ),
@@ -224,6 +224,36 @@ def _update(
         resp.raise_for_status()
         return True
 
+    def _upsert(
+        self,
+        collection_name: str,
+        ids: IDs,
+        embeddings: Embeddings,
+        metadatas: Optional[Metadatas] = None,
+        documents: Optional[Documents] = None,
+        increment_index: bool = True,
+    ):
+        """
+        Updates a batch of embeddings in the database
+        - pass in column oriented data lists
+        """
+
+        resp = requests.post(
+            self._api_url + "/collections/" + collection_name + "/upsert",
+            data=json.dumps(
+                {
+                    "ids": ids,
+                    "embeddings": embeddings,
+                    "metadatas": metadatas,
+                    "documents": documents,
+                    "increment_index": increment_index,
+                }
+            ),
+        )
+
+        resp.raise_for_status()
+        return True
+
     def _query(
         self,
         collection_name,
diff --git a/chromadb/api/local.py b/chromadb/api/local.py
index 7699b42ed1b..1658f23ec81 100644
--- a/chromadb/api/local.py
+++ b/chromadb/api/local.py
@@ -41,7 +41,7 @@ def check_index_name(index_name):
         raise ValueError(msg)
     if ".." in index_name:
         raise ValueError(msg)
-    if re.match("^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$", index_name):
+    if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
         raise ValueError(msg)
 
 
@@ -160,6 +160,67 @@ def _update(
 
         return True
 
+    def _upsert(
+        self,
+        collection_name: str,
+        ids: IDs,
+        embeddings: Embeddings,
+        metadatas: Optional[Metadatas] = None,
+        documents: Optional[Documents] = None,
+        increment_index: bool = True,
+    ):
+        # Determine which ids need to be added and which need to be updated based on the ids already in the collection
+        existing_ids = set(self._get(collection_name, ids=ids, include=[])['ids'])
+
+
+        ids_to_add = []
+        ids_to_update = []
+        embeddings_to_add: Embeddings = []
+        embeddings_to_update: Embeddings = []
+        metadatas_to_add: Optional[Metadatas] = [] if metadatas else None
+        metadatas_to_update: Optional[Metadatas] = [] if metadatas else None
+        documents_to_add: Optional[Documents] = [] if documents else None
+        documents_to_update: Optional[Documents] = [] if documents else None
+
+        for i, id in enumerate(ids):
+            if id in existing_ids:
+                ids_to_update.append(id)
+                if embeddings is not None:
+                    embeddings_to_update.append(embeddings[i])
+                if metadatas is not None:
+                    metadatas_to_update.append(metadatas[i])
+                if documents is not None:
+                    documents_to_update.append(documents[i])
+            else:
+                ids_to_add.append(id)
+                if embeddings is not None:
+                    embeddings_to_add.append(embeddings[i])
+                if metadatas is not None:
+                    metadatas_to_add.append(metadatas[i])
+                if documents is not None:
+                    documents_to_add.append(documents[i])
+
+        if len(ids_to_add) > 0:
+            self._add(
+                ids_to_add,
+                collection_name,
+                embeddings_to_add,
+                metadatas_to_add,
+                documents_to_add,
+                increment_index=increment_index,
+            )
+
+        if len(ids_to_update) > 0:
+            self._update(
+                collection_name,
+                ids_to_update,
+                embeddings_to_update,
+                metadatas_to_update,
+                documents_to_update,
+            )
+
+        return True
+
     def _get(
         self,
         collection_name: str,
diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py
index fa1151772d0..e855d8e8872 100644
--- a/chromadb/api/models/Collection.py
+++ b/chromadb/api/models/Collection.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional, cast, List, Dict
+from typing import TYPE_CHECKING, Optional, cast, List, Dict, Tuple
 
 from pydantic import BaseModel, PrivateAttr
 from chromadb.api.types import (
@@ -79,34 +79,9 @@ def add(
             ids: The ids to associate with the embeddings. Optional.
         """
 
-        ids = validate_ids(maybe_cast_one_to_many(ids))
-        embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
-        metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
-        documents = maybe_cast_one_to_many(documents) if documents else None
-
-        # Check that one of embeddings or documents is provided
-        if embeddings is None and documents is None:
-            raise ValueError("You must provide either embeddings or documents, or both")
-
-        # Check that, if they're provided, the lengths of the arrays match the length of ids
-        if embeddings is not None and len(embeddings) != len(ids):
-            raise ValueError(
-                f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}"
-            )
-        if metadatas is not None and len(metadatas) != len(ids):
-            raise ValueError(
-                f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}"
-            )
-        if documents is not None and len(documents) != len(ids):
-            raise ValueError(
-                f"Number of documents {len(documents)} must match number of ids {len(ids)}"
-            )
-
-        # If document embeddings are not provided, we need to compute them
-        if embeddings is None and documents is not None:
-            if self._embedding_function is None:
-                raise ValueError("You must provide embeddings or a function to compute them")
-            embeddings = self._embedding_function(documents)
+        ids, embeddings, metadatas, documents = self._validate_embedding_set(
+            ids, embeddings, metadatas, documents
+        )
 
         self._client._add(ids, self.name, embeddings, metadatas, documents, increment_index)
 
@@ -237,18 +212,76 @@ def update(
             documents: The documents to associate with the embeddings. Optional.
         """
 
+        ids, embeddings, metadatas, documents = self._validate_embedding_set(
+            ids, embeddings, metadatas, documents, require_embeddings_or_documents=False
+        )
+
+        self._client._update(self.name, ids, embeddings, metadatas, documents)
+
+    def upsert(
+        self,
+        ids: OneOrMany[ID],
+        embeddings: Optional[OneOrMany[Embedding]] = None,
+        metadatas: Optional[OneOrMany[Metadata]] = None,
+        documents: Optional[OneOrMany[Document]] = None,
+        increment_index: bool = True,
+    ):
+        """Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.
+
+        Args:
+            ids: The ids of the embeddings to update
+            embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
+            metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
+            documents: The documents to associate with the embeddings. Optional.
+        """
+
+        ids, embeddings, metadatas, documents = self._validate_embedding_set(
+            ids, embeddings, metadatas, documents
+        )
+
+        self._client._upsert(
+            collection_name=self.name,
+            ids=ids,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            documents=documents,
+            increment_index=increment_index,
+        )
+
+    def delete(
+        self,
+        ids: Optional[IDs] = None,
+        where: Optional[Where] = None,
+        where_document: Optional[WhereDocument] = None,
+    ):
+        """Delete the embeddings based on ids and/or a where filter
+
+        Args:
+            ids: The ids of the embeddings to delete
+            where: A Where type dict used to filter the delection by. E.g. {"color" : "red", "price": 4.20}. Optional.
+            where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. {$contains: {"text": "hello"}}. Optional.
+        """
+        ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
+        where = validate_where(where) if where else None
+        where_document = validate_where_document(where_document) if where_document else None
+        return self._client._delete(self.name, ids, where, where_document)
+
+    def create_index(self):
+        self._client.create_index(self.name)
+
+    def _validate_embedding_set(
+        self, ids, embeddings, metadatas, documents, require_embeddings_or_documents=True
+    ) -> Tuple[IDs, Optional[List[Embedding]], Optional[List[Metadata]], Optional[List[Document]]]:
+
         ids = validate_ids(maybe_cast_one_to_many(ids))
         embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
         metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
         documents = maybe_cast_one_to_many(documents) if documents else None
 
-        # Must update one of embeddings, metadatas, or documents
-        if embeddings is None and documents is None and metadatas is None:
-            raise ValueError("You must update at least one of embeddings, documents or metadatas.")
-
-        # Check that one of embeddings or documents is provided
-        if embeddings is not None and documents is None:
-            raise ValueError("You must provide updated documents with updated embeddings")
+        if require_embeddings_or_documents:
+            if embeddings is None and documents is None:
+                raise ValueError("You must provide either embeddings or documents, or both")
 
         # Check that, if they're provided, the lengths of the arrays match the length of ids
         if embeddings is not None and len(embeddings) != len(ids):
@@ -270,25 +303,4 @@ def update(
                 raise ValueError("You must provide embeddings or a function to compute them")
             embeddings = self._embedding_function(documents)
 
-        self._client._update(self.name, ids, embeddings, metadatas, documents)
-
-    def delete(
-        self,
-        ids: Optional[IDs] = None,
-        where: Optional[Where] = None,
-        where_document: Optional[WhereDocument] = None,
-    ):
-        """Delete the embeddings based on ids and/or a where filter
-
-        Args:
-            ids: The ids of the embeddings to delete
-            where: A Where type dict used to filter the delection by. E.g. {"color" : "red", "price": 4.20}. Optional.
-            where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. {$contains: {"text": "hello"}}. Optional.
-        """
-        ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
-        where = validate_where(where) if where else None
-        where_document = validate_where_document(where_document) if where_document else None
-        return self._client._delete(self.name, ids, where, where_document)
-
-    def create_index(self):
-        self._client.create_index(self.name)
+        return ids, embeddings, metadatas, documents
diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py
index cba6e1ad7fc..f1f461e760f 100644
--- a/chromadb/server/fastapi/__init__.py
+++ b/chromadb/server/fastapi/__init__.py
@@ -85,6 +85,9 @@ def __init__(self, settings):
         self.router.add_api_route(
             "/api/v1/collections/{collection_name}/update", self.update, methods=["POST"]
         )
+        self.router.add_api_route(
+            "/api/v1/collections/{collection_name}/upsert", self.upsert, methods=["POST"]
+        )
         self.router.add_api_route(
             "/api/v1/collections/{collection_name}/get", self.get, methods=["POST"]
         )
@@ -176,6 +179,16 @@ def update(self, collection_name: str, add: UpdateEmbedding):
             metadatas=add.metadatas,
         )
 
+    def upsert(self, collection_name: str, upsert: AddEmbedding):
+        return self._api._upsert(
+            collection_name=collection_name,
+            ids=upsert.ids,
+            embeddings=upsert.embeddings,
+            documents=upsert.documents,
+            metadatas=upsert.metadatas,
+            increment_index=upsert.increment_index,
+        )
+
     def get(self, collection_name, get: GetEmbedding):
         return self._api._get(
             collection_name=collection_name,
diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py
index 4010e44d2dc..07d19361e43 100644
--- a/chromadb/test/property/strategies.py
+++ b/chromadb/test/property/strategies.py
@@ -140,15 +140,15 @@ def metadatas_strategy(count: int) -> st.SearchStrategy[Optional[List[types.Meta
     )
 
 
+default_id_st = st.text(alphabet=legal_id_characters, min_size=1, max_size=64)
+
 @st.composite
 def embedding_set(
     draw,
     dimension_st: st.SearchStrategy[int] = st.integers(min_value=2, max_value=2048),
     count_st: st.SearchStrategy[int] = st.integers(min_value=1, max_value=512),
     dtype_st: st.SearchStrategy[np.dtype] = st.sampled_from(float_types),
-    id_st: st.SearchStrategy[str] = st.text(
-        alphabet=legal_id_characters, min_size=1, max_size=64
-    ),
+    id_st: st.SearchStrategy[str] = default_id_st,
     documents_st_fn: Callable[
         [int], st.SearchStrategy[Optional[List[str]]]
     ] = documents_strategy,
diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py
index fda8bb765f2..51d53444944 100644
--- a/chromadb/test/property/test_embeddings.py
+++ b/chromadb/test/property/test_embeddings.py
@@ -92,7 +92,7 @@ def add_embeddings(self, embedding_set):
             return multiple()
         else:
             self.collection.add(**embedding_set)
-            self._add_embeddings(embedding_set)
+            self._upsert_embeddings(embedding_set)
             return multiple(*embedding_set["ids"])
 
     @precondition(lambda self: len(self.embeddings["ids"]) > 20)
@@ -122,7 +122,25 @@ def delete_by_ids(self, ids):
     def update_embeddings(self, embedding_set):
         trace("update embeddings")
         self.collection.update(**embedding_set)
-        self._update_embeddings(embedding_set)
+        self._upsert_embeddings(embedding_set)
+
+    # Using a value < 3 causes more retries and lowers the number of valid samples
+    @precondition(lambda self: len(self.embeddings["ids"]) >= 3)
+    @rule(
+        embedding_set=strategies.embedding_set(
+            dtype_st=dtype_shared_st,
+            dimension_st=dimension_shared_st,
+            id_st=st.one_of(embedding_ids, strategies.default_id_st),
+            count_st=st.integers(min_value=1, max_value=5),
+            documents_st_fn=lambda c: st.lists(
+                st.text(min_size=1), min_size=c, max_size=c, unique=True
+            ),
+        ),
+    )
+    def upsert_embeddings(self, embedding_set):
+        trace("upsert embeddings")
+        self.collection.upsert(**embedding_set)
+        self._upsert_embeddings(embedding_set)
 
     @invariant()
     def count(self):
@@ -138,22 +156,30 @@ def ann_accuracy(self):
             collection=self.collection, embeddings=self.embeddings, min_recall=0.95
         )
 
-    def _add_embeddings(self, embeddings: strategies.EmbeddingSet):
-        self.embeddings["ids"].extend(embeddings["ids"])
-        self.embeddings["embeddings"].extend(embeddings["embeddings"])  # type: ignore
-
-        if "metadatas" in embeddings and embeddings["metadatas"] is not None:
-            metadatas = embeddings["metadatas"]
-        else:
-            metadatas = [None] * len(embeddings["ids"])
-
-        if "documents" in embeddings and embeddings["documents"] is not None:
-            documents = embeddings["documents"]
-        else:
-            documents = [None] * len(embeddings["ids"])
-
-        self.embeddings["metadatas"].extend(metadatas)  # type: ignore
-        self.embeddings["documents"].extend(documents)  # type: ignore
+    def _upsert_embeddings(self, embeddings: strategies.EmbeddingSet):
+        for idx, id in enumerate(embeddings["ids"]):
+            if id in self.embeddings["ids"]:
+                target_idx = self.embeddings["ids"].index(id)
+                if "embeddings" in embeddings and embeddings["embeddings"] is not None:
+                    self.embeddings["embeddings"][target_idx] = embeddings["embeddings"][idx]
+                if "metadatas" in embeddings and embeddings["metadatas"] is not None:
+                    self.embeddings["metadatas"][target_idx] = embeddings["metadatas"][idx]
+                if "documents" in embeddings and embeddings["documents"] is not None:
+                    self.embeddings["documents"][target_idx] = embeddings["documents"][idx]
+            else:
+                self.embeddings["ids"].append(id)
+                if "embeddings" in embeddings and embeddings["embeddings"] is not None:
+                    self.embeddings["embeddings"].append(embeddings["embeddings"][idx])
+                else:
+                    self.embeddings["embeddings"].append(None)
+                if "metadatas" in embeddings and embeddings["metadatas"] is not None:
+                    self.embeddings["metadatas"].append(embeddings["metadatas"][idx])
+                else:
+                    self.embeddings["metadatas"].append(None)
+                if "documents" in embeddings and embeddings["documents"] is not None:
+                    self.embeddings["documents"].append(embeddings["documents"][idx])
+                else:
+                    self.embeddings["documents"].append(None)
 
     def _remove_embeddings(self, indices_to_remove: Set[int]):
         indices_list = list(indices_to_remove)
@@ -165,17 +191,6 @@ def _remove_embeddings(self, indices_to_remove: Set[int]):
             del self.embeddings["metadatas"][i]
             del self.embeddings["documents"][i]
 
-    def _update_embeddings(self, embeddings: strategies.EmbeddingSet):
-        for i in range(len(embeddings["ids"])):
-            idx = self.embeddings["ids"].index(embeddings["ids"][i])
-            if embeddings["embeddings"]:
-                self.embeddings["embeddings"][idx] = embeddings["embeddings"][i]
-            if embeddings["metadatas"]:
-                self.embeddings["metadatas"][idx] = embeddings["metadatas"][i]
-            if embeddings["documents"]:
-                self.embeddings["documents"][idx] = embeddings["documents"][i]
-
-
 def test_embeddings_state(caplog, api):
     caplog.set_level(logging.ERROR)
     run_state_machine_as_test(lambda: EmbeddingStateMachine(api))
@@ -205,6 +220,8 @@ def test_dup_add(api: API):
     coll = api.create_collection(name="foo")
     with pytest.raises(errors.DuplicateIDError):
         coll.add(ids=["a", "a"], embeddings=[[0.0], [1.1]])
+    with pytest.raises(errors.DuplicateIDError):
+        coll.upsert(ids=["a", "a"], embeddings=[[0.0], [1.1]])
 
 
 # TODO: Use SQL escaping correctly internally
diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py
index ce5a460750a..70c9cdd0603 100644
--- a/chromadb/test/test_api.py
+++ b/chromadb/test/test_api.py
@@ -1397,3 +1397,49 @@ def test_update_query(api_fixture, request):
     assert results["documents"][0][0] == updated_records["documents"][0]
     assert results["metadatas"][0][0]["foo"] == "bar"
     assert results["embeddings"][0][0] == updated_records["embeddings"][0]
+
+
+initial_records = {
+    "embeddings": [[0, 0, 0], [1.2, 2.24, 3.2], [2.2, 3.24, 4.2]],
+    "ids": ["id1", "id2", "id3"],
+    "metadatas": [{"int_value": 1, "string_value": "one", "float_value": 1.001}, {"int_value": 2}, {"string_value": "three"}],
+    "documents": ["this document is first", "this document is second", "this document is third"],
+}
+
+new_records = {
+    "embeddings": [[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]],
+    "ids": ["id1", "id4"],
+    "metadatas": [{"int_value": 1, "string_value": "one_of_one", "float_value": 1.001}, {"int_value": 4}],
+    "documents": ["this document is even more first", "this document is new and fourth"],
+}
+
+
+@pytest.mark.parametrize("api_fixture", test_apis)
+def test_upsert(api_fixture, request):
+    api = request.getfixturevalue(api_fixture.__name__)
+    api.reset()
+    collection = api.create_collection("test")
+
+    collection.add(**initial_records)
+    assert collection.count() == 3
+
+    collection.upsert(**new_records)
+    assert collection.count() == 4
+
+    get_result = collection.get(include=['embeddings', 'metadatas', 'documents'], ids=new_records['ids'][0])
+    assert get_result['embeddings'][0] == new_records['embeddings'][0]
+    assert get_result['metadatas'][0] == new_records['metadatas'][0]
+    assert get_result['documents'][0] == new_records['documents'][0]
+
+    query_result = collection.query(query_embeddings=get_result['embeddings'], n_results=1, include=['embeddings', 'metadatas', 'documents'])
+    assert query_result['embeddings'][0][0] == new_records['embeddings'][0]
+    assert query_result['metadatas'][0][0] == new_records['metadatas'][0]
+    assert query_result['documents'][0][0] == new_records['documents'][0]
+
+    collection.delete(ids=initial_records['ids'][2])
+    collection.upsert(ids=initial_records['ids'][2], embeddings=[[1.1, 0.99, 2.21]], metadatas=[{"string_value": "a new string value"}])
+    assert collection.count() == 4
+
+    get_result = collection.get(include=['embeddings', 'metadatas', 'documents'], ids=['id3'])
+    assert get_result['embeddings'][0] == [1.1, 0.99, 2.21]
+    assert get_result['metadatas'][0] == {"string_value": "a new string value"}
+    assert get_result['documents'][0] == None
\ No newline at end of file
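
For reference, here is a minimal usage sketch of the Collection.upsert API this diff introduces. It is not part of the patch; the client setup, collection name, and ids are illustrative.

# Usage sketch for the new upsert: existing ids are updated in place, new ids are added.
import chromadb

client = chromadb.Client()
collection = client.create_collection("upsert-demo")  # hypothetical collection name

# "id1" does not exist yet, so this call adds it.
collection.upsert(ids=["id1"], embeddings=[[0.0, 0.0, 0.0]], documents=["first version"])

# "id1" already exists and is updated in place; "id2" is new and is added.
collection.upsert(
    ids=["id1", "id2"],
    embeddings=[[0.1, 0.2, 0.3], [1.0, 1.0, 1.0]],
    documents=["second version", "a brand new entry"],
)

assert collection.count() == 2

In the local implementation above, this maps onto a single _get to find which ids already exist, followed by _add for the new ids and _update for the existing ones; the FastAPI client instead posts the whole batch to the new /upsert route.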