Merge pull request #385 from chroma-core/lukev/anton-upsert
Add Upsert to Collection with Hypothesis tests
levand authored Apr 20, 2023
2 parents 29dff1a + b8f4db4 commit d205f87
Showing 8 changed files with 298 additions and 96 deletions.
33 changes: 28 additions & 5 deletions chromadb/api/__init__.py
@@ -145,7 +145,7 @@ def _add(
⚠️ It is recommended to use the more specific methods below when possible.
Args:
collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
metadata (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
@@ -166,17 +166,40 @@ def _update(
⚠️ It is recommended to use the more specific methods below when possible.
Args:
collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
"""
pass

@abstractmethod
def _upsert(
self,
collection_name: str,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
increment_index: bool = True,
):
"""Add or update entries in the embedding store.
If an entry with the same id already exists, it will be updated, otherwise it will be added.
Args:
collection_name (str): The collection to add the embeddings to
ids (Union[str, Sequence[str]]): The ids to associate with the embeddings
embeddings (Sequence[Sequence[float]]): The sequence of embeddings to add
metadatas (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
increment_index (bool, optional): If True, will incrementally add to the ANN index of the collection. Defaults to True.
"""
pass

@abstractmethod
def _count(self, collection_name: str) -> int:
"""Returns the number of embeddings in the database
Args:
collection_name (str): The model space to count the embeddings in.
collection_name (str): The collection to count the embeddings in.
Returns:
int: The number of embeddings in the collection
@@ -282,11 +305,11 @@ def raw_sql(self, sql: str) -> pd.DataFrame:

@abstractmethod
def create_index(self, collection_name: Optional[str] = None) -> bool:
"""Creates an index for the given model space
"""Creates an index for the given collection
⚠️ This method should not be used directly.
Args:
collection_name (Optional[str], optional): The model space to create the index for. Uses the client's model space if None. Defaults to None.
collection_name (Optional[str], optional): The collection to create the index for. Uses the client's collection if None. Defaults to None.
Returns:
bool: True if the index was created successfully
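For orientation, here is a minimal sketch of the column-oriented call shape the new abstract _upsert specifies. It uses the in-memory client and an illustrative collection name and vectors, none of which are part of this commit; _upsert is the internal method that the public Collection.upsert (added further below) delegates to.

```python
import chromadb

# In-memory client; its API object exposes the private _upsert declared above.
client = chromadb.Client()
client.create_collection("demo")  # illustrative collection name

# Column-oriented lists: the i-th id, embedding, metadata and document belong together.
client._upsert(
    collection_name="demo",
    ids=["id-1", "id-2"],
    embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    metadatas=[{"source": "a"}, {"source": "b"}],
    documents=["first document", "second document"],
)
```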
32 changes: 31 additions & 1 deletion chromadb/api/fastapi.py
@@ -180,10 +180,10 @@ def _add(
self._api_url + "/collections/" + collection_name + "/add",
data=json.dumps(
{
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
"ids": ids,
"increment_index": increment_index,
}
),
@@ -224,6 +224,36 @@ def _update(
resp.raise_for_status()
return True

def _upsert(
self,
collection_name: str,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
increment_index: bool = True,
):
"""
Updates a batch of embeddings in the database
- pass in column oriented data lists
"""

resp = requests.post(
self._api_url + "/collections/" + collection_name + "/upsert",
data=json.dumps(
{
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
"increment_index": increment_index,
}
),
)

resp.raise_for_status()
return True

def _query(
self,
collection_name,
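The REST client above simply serializes the column-oriented lists as JSON and POSTs them to the new /upsert route (registered in chromadb/server/fastapi/__init__.py further below). A sketch of the equivalent raw request, assuming a Chroma server reachable at localhost:8000 and an existing collection named my_collection; both are assumptions for illustration, not part of the diff.

```python
import json
import requests

base_url = "http://localhost:8000/api/v1"  # assumed server address

payload = {
    "ids": ["id-1", "id-2"],
    "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    "metadatas": [{"source": "a"}, {"source": "b"}],
    "documents": ["first document", "second document"],
    "increment_index": True,
}

resp = requests.post(f"{base_url}/collections/my_collection/upsert", data=json.dumps(payload))
resp.raise_for_status()
print(resp.json())  # the route returns the backend's result (True for the bundled local API)
```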
63 changes: 62 additions & 1 deletion chromadb/api/local.py
@@ -41,7 +41,7 @@ def check_index_name(index_name):
raise ValueError(msg)
if ".." in index_name:
raise ValueError(msg)
if re.match("^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$", index_name):
if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
raise ValueError(msg)


@@ -160,6 +160,67 @@ def _update(

return True

def _upsert(
self,
collection_name: str,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
increment_index: bool = True,
):
# Determine which ids need to be added and which need to be updated based on the ids already in the collection
existing_ids = set(self._get(collection_name, ids=ids, include=[])['ids'])


ids_to_add = []
ids_to_update = []
embeddings_to_add: Embeddings = []
embeddings_to_update: Embeddings = []
metadatas_to_add: Optional[Metadatas] = [] if metadatas else None
metadatas_to_update: Optional[Metadatas] = [] if metadatas else None
documents_to_add: Optional[Documents] = [] if documents else None
documents_to_update: Optional[Documents] = [] if documents else None

for i, id in enumerate(ids):
if id in existing_ids:
ids_to_update.append(id)
if embeddings is not None:
embeddings_to_update.append(embeddings[i])
if metadatas is not None:
metadatas_to_update.append(metadatas[i])
if documents is not None:
documents_to_update.append(documents[i])
else:
ids_to_add.append(id)
if embeddings is not None:
embeddings_to_add.append(embeddings[i])
if metadatas is not None:
metadatas_to_add.append(metadatas[i])
if documents is not None:
documents_to_add.append(documents[i])

if len(ids_to_add) > 0:
self._add(
ids_to_add,
collection_name,
embeddings_to_add,
metadatas_to_add,
documents_to_add,
increment_index=increment_index,
)

if len(ids_to_update) > 0:
self._update(
collection_name,
ids_to_update,
embeddings_to_update,
metadatas_to_update,
documents_to_update,
)

return True

def _get(
self,
collection_name: str,
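The heart of the local implementation is the routing step: a single _get call establishes which of the requested ids already exist, and each id, together with its aligned embedding, metadata and document, is then forwarded either to _add or to _update. A toy illustration of that partitioning with made-up ids and vectors:

```python
# Pretend the collection already contains "id-1"; "id-2" is new.
existing_ids = {"id-1"}
ids = ["id-1", "id-2"]
embeddings = [[0.1, 0.2], [0.3, 0.4]]

to_update = [(i, e) for i, e in zip(ids, embeddings) if i in existing_ids]
to_add = [(i, e) for i, e in zip(ids, embeddings) if i not in existing_ids]

assert to_update == [("id-1", [0.1, 0.2])]  # would be routed to _update
assert to_add == [("id-2", [0.3, 0.4])]     # would be routed to _add
```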
126 changes: 69 additions & 57 deletions chromadb/api/models/Collection.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, cast, List, Dict
from typing import TYPE_CHECKING, Optional, cast, List, Dict, Tuple
from pydantic import BaseModel, PrivateAttr

from chromadb.api.types import (
@@ -79,34 +79,9 @@ def add(
ids: The ids to associate with the embeddings. Optional.
"""

ids = validate_ids(maybe_cast_one_to_many(ids))
embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
documents = maybe_cast_one_to_many(documents) if documents else None

# Check that one of embeddings or documents is provided
if embeddings is None and documents is None:
raise ValueError("You must provide either embeddings or documents, or both")

# Check that, if they're provided, the lengths of the arrays match the length of ids
if embeddings is not None and len(embeddings) != len(ids):
raise ValueError(
f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}"
)
if metadatas is not None and len(metadatas) != len(ids):
raise ValueError(
f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}"
)
if documents is not None and len(documents) != len(ids):
raise ValueError(
f"Number of documents {len(documents)} must match number of ids {len(ids)}"
)

# If document embeddings are not provided, we need to compute them
if embeddings is None and documents is not None:
if self._embedding_function is None:
raise ValueError("You must provide embeddings or a function to compute them")
embeddings = self._embedding_function(documents)
ids, embeddings, metadatas, documents = self._validate_embedding_set(
ids, embeddings, metadatas, documents
)

self._client._add(ids, self.name, embeddings, metadatas, documents, increment_index)

@@ -237,18 +212,76 @@ def update(
documents: The documents to associate with the embeddings. Optional.
"""

ids, embeddings, metadatas, documents = self._validate_embedding_set(
ids, embeddings, metadatas, documents, require_embeddings_or_documents=False
)

self._client._update(self.name, ids, embeddings, metadatas, documents)

def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[OneOrMany[Embedding]] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
increment_index: bool = True,
):
"""Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.
Args:
ids: The ids of the embeddings to update
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
"""

ids, embeddings, metadatas, documents = self._validate_embedding_set(
ids, embeddings, metadatas, documents
)

self._client._upsert(
collection_name=self.name,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
increment_index=increment_index,
)

def delete(
self,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
):
"""Delete the embeddings based on ids and/or a where filter
Args:
ids: The ids of the embeddings to delete
where: A Where type dict used to filter the deletion by. E.g. {"color" : "red", "price": 4.20}. Optional.
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. {$contains: {"text": "hello"}}. Optional.
"""
ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
where = validate_where(where) if where else None
where_document = validate_where_document(where_document) if where_document else None
return self._client._delete(self.name, ids, where, where_document)

def create_index(self):
self._client.create_index(self.name)

def _validate_embedding_set(
self, ids, embeddings, metadatas, documents, require_embeddings_or_documents=True
) -> Tuple[IDs, Optional[List[Embedding]], Optional[List[Metadata]], Optional[List[Document]]]:

ids = validate_ids(maybe_cast_one_to_many(ids))
embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
documents = maybe_cast_one_to_many(documents) if documents else None

# Must update one of embeddings, metadatas, or documents
if embeddings is None and documents is None and metadatas is None:
raise ValueError("You must update at least one of embeddings, documents or metadatas.")

# Check that one of embeddings or documents is provided
if embeddings is not None and documents is None:
raise ValueError("You must provide updated documents with updated embeddings")
if require_embeddings_or_documents:
if embeddings is None and documents is None:
raise ValueError("You must provide either embeddings or documents, or both")

# Check that, if they're provided, the lengths of the arrays match the length of ids
if embeddings is not None and len(embeddings) != len(ids):
@@ -270,25 +303,4 @@ def update(
raise ValueError("You must provide embeddings or a function to compute them")
embeddings = self._embedding_function(documents)

self._client._update(self.name, ids, embeddings, metadatas, documents)

def delete(
self,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
):
"""Delete the embeddings based on ids and/or a where filter
Args:
ids: The ids of the embeddings to delete
where: A Where type dict used to filter the deletion by. E.g. {"color" : "red", "price": 4.20}. Optional.
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. {$contains: {"text": "hello"}}. Optional.
"""
ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
where = validate_where(where) if where else None
where_document = validate_where_document(where_document) if where_document else None
return self._client._delete(self.name, ids, where, where_document)

def create_index(self):
self._client.create_index(self.name)
return ids, embeddings, metadatas, documents
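From the caller's side, the new public method behaves like add for unseen ids and like update for existing ones. A short usage sketch with an illustrative collection name and hand-written embeddings (passing embeddings explicitly keeps the sketch independent of any embedding function):

```python
import chromadb

client = chromadb.Client()
collection = client.create_collection("articles")  # illustrative name

# First write: "id-1" does not exist yet, so upsert behaves like add.
collection.upsert(
    ids=["id-1"],
    embeddings=[[0.1, 0.2, 0.3]],
    metadatas=[{"version": 1}],
    documents=["original text"],
)

# Second write with the same id: the existing entry is updated in place.
collection.upsert(
    ids=["id-1"],
    embeddings=[[0.9, 0.8, 0.7]],
    metadatas=[{"version": 2}],
    documents=["revised text"],
)

assert collection.count() == 1  # still one entry, now holding the revised values
```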
13 changes: 13 additions & 0 deletions chromadb/server/fastapi/__init__.py
@@ -85,6 +85,9 @@ def __init__(self, settings):
self.router.add_api_route(
"/api/v1/collections/{collection_name}/update", self.update, methods=["POST"]
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}/upsert", self.upsert, methods=["POST"]
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}/get", self.get, methods=["POST"]
)
@@ -176,6 +179,16 @@ def update(self, collection_name: str, add: UpdateEmbedding):
metadatas=add.metadatas,
)

def upsert(self, collection_name: str, upsert: AddEmbedding):
return self._api._upsert(
collection_name=collection_name,
ids=upsert.ids,
embeddings=upsert.embeddings,
documents=upsert.documents,
metadatas=upsert.metadatas,
increment_index=upsert.increment_index,
)

def get(self, collection_name, get: GetEmbedding):
return self._api._get(
collection_name=collection_name,
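With the route registered, the same Collection.upsert call works against a running Chroma server: the REST-backed client serializes the batch and this handler forwards it to the backend's _upsert. A sketch assuming a server on localhost:8000 and a collection that does not already exist there; both are assumptions for illustration.

```python
import chromadb
from chromadb.config import Settings

# REST-backed client pointed at an assumed local server.
client = chromadb.Client(Settings(
    chroma_api_impl="rest",
    chroma_server_host="localhost",
    chroma_server_http_port="8000",
))

collection = client.create_collection("articles")  # assumes the collection is not already present
collection.upsert(
    ids=["id-1"],
    embeddings=[[0.1, 0.2, 0.3]],
    documents=["served over HTTP"],
)
```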
6 changes: 3 additions & 3 deletions chromadb/test/property/strategies.py
@@ -140,15 +140,15 @@ def metadatas_strategy(count: int) -> st.SearchStrategy[Optional[List[types.Meta
)


default_id_st = st.text(alphabet=legal_id_characters, min_size=1, max_size=64)

@st.composite
def embedding_set(
draw,
dimension_st: st.SearchStrategy[int] = st.integers(min_value=2, max_value=2048),
count_st: st.SearchStrategy[int] = st.integers(min_value=1, max_value=512),
dtype_st: st.SearchStrategy[np.dtype] = st.sampled_from(float_types),
id_st: st.SearchStrategy[str] = st.text(
alphabet=legal_id_characters, min_size=1, max_size=64
),
id_st: st.SearchStrategy[str] = default_id_st,
documents_st_fn: Callable[
[int], st.SearchStrategy[Optional[List[str]]]
] = documents_strategy,
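Pulling the id strategy out into a module-level default_id_st lets other property tests reuse or override the same id generator. A toy property test showing the pattern; the alphabet below is a stand-in, not the project's actual legal_id_characters:

```python
import string
from hypothesis import given, strategies as st

# Stand-in alphabet; the real one is defined in chromadb/test/property/strategies.py.
legal_id_characters = string.ascii_letters + string.digits + "_-"
default_id_st = st.text(alphabet=legal_id_characters, min_size=1, max_size=64)

@given(ids=st.lists(default_id_st, min_size=1, max_size=10, unique=True))
def test_ids_are_well_formed(ids):
    assert all(1 <= len(i) <= 64 for i in ids)
```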