Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Upsert to Collection with Hypothesis tests #385

Merged
merged 21 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _add(
⚠️ It is recommended to use the more specific methods below when possible.

Args:
collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
metadata (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
Expand All @@ -166,17 +166,40 @@ def _update(
⚠️ It is recommended to use the more specific methods below when possible.

Args:
collection_name (Union[str, Sequence[str]]): The model space(s) to add the embeddings to
collection_name (Union[str, Sequence[str]]): The collection(s) to add the embeddings to
embedding (Sequence[Sequence[float]]): The sequence of embeddings to add
"""
pass

@abstractmethod
def _upsert(
self,
collection_name: str,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
increment_index: bool = True,
):
"""Add or update entries in the embedding store.
If an entry with the same id already exists, it will be updated, otherwise it will be added.

Args:
collection_name (str): The collection to add the embeddings to
ids (Optional[Union[str, Sequence[str]]], optional): The ids to associate with the embeddings. Defaults to None.
embeddings (Sequence[Sequence[float]]): The sequence of embeddings to add
metadatas (Optional[Union[Dict, Sequence[Dict]]], optional): The metadata to associate with the embeddings. Defaults to None.
documents (Optional[Union[str, Sequence[str]]], optional): The documents to associate with the embeddings. Defaults to None.
increment_index (bool, optional): If True, will incrementally add to the ANN index of the collection. Defaults to True.
"""
pass

@abstractmethod
def _count(self, collection_name: str) -> int:
"""Returns the number of embeddings in the database

Args:
collection_name (str): The model space to count the embeddings in.
collection_name (str): The collection to count the embeddings in.

Returns:
int: The number of embeddings in the collection
Expand Down Expand Up @@ -282,11 +305,11 @@ def raw_sql(self, sql: str) -> pd.DataFrame:

@abstractmethod
def create_index(self, collection_name: Optional[str] = None) -> bool:
"""Creates an index for the given model space
"""Creates an index for the given collection
⚠️ This method should not be used directly.

Args:
collection_name (Optional[str], optional): The model space to create the index for. Uses the client's model space if None. Defaults to None.
collection_name (Optional[str], optional): The collection to create the index for. Uses the client's collection if None. Defaults to None.

Returns:
bool: True if the index was created successfully
Expand Down
32 changes: 31 additions & 1 deletion chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ def _add(
self._api_url + "/collections/" + collection_name + "/add",
data=json.dumps(
{
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
"ids": ids,
"increment_index": increment_index,
}
),
Expand Down Expand Up @@ -224,6 +224,36 @@ def _update(
resp.raise_for_status()
return True

def _upsert(
self,
collection_name: str,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
increment_index: bool = True,
):
"""
Updates a batch of embeddings in the database
- pass in column oriented data lists
"""

resp = requests.post(
self._api_url + "/collections/" + collection_name + "/upsert",
data=json.dumps(
{
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
"increment_index": increment_index,
}
),
)

resp.raise_for_status()
return True

def _query(
self,
collection_name,
Expand Down
63 changes: 62 additions & 1 deletion chromadb/api/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def check_index_name(index_name):
raise ValueError(msg)
if ".." in index_name:
raise ValueError(msg)
if re.match("^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$", index_name):
if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
raise ValueError(msg)


Expand Down Expand Up @@ -160,6 +160,67 @@ def _update(

return True

def _upsert(
self,
collection_name: str,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
increment_index: bool = True,
):
# Determine which ids need to be added and which need to be updated based on the ids already in the collection
existing_ids = set(self._get(collection_name, ids=ids, include=[])['ids'])


ids_to_add = []
ids_to_update = []
embeddings_to_add: Embeddings = []
embeddings_to_update: Embeddings = []
metadatas_to_add: Optional[Metadatas] = [] if metadatas else None
metadatas_to_update: Optional[Metadatas] = [] if metadatas else None
documents_to_add: Optional[Documents] = [] if documents else None
documents_to_update: Optional[Documents] = [] if documents else None

for i, id in enumerate(ids):
if id in existing_ids:
ids_to_update.append(id)
if embeddings is not None:
embeddings_to_update.append(embeddings[i])
if metadatas is not None:
metadatas_to_update.append(metadatas[i])
if documents is not None:
documents_to_update.append(documents[i])
else:
ids_to_add.append(id)
if embeddings is not None:
embeddings_to_add.append(embeddings[i])
if metadatas is not None:
metadatas_to_add.append(metadatas[i])
if documents is not None:
documents_to_add.append(documents[i])

if len(ids_to_add) > 0:
self._add(
ids_to_add,
collection_name,
embeddings_to_add,
metadatas_to_add,
documents_to_add,
increment_index=increment_index,
)

if len(ids_to_update) > 0:
self._update(
collection_name,
ids_to_update,
embeddings_to_update,
metadatas_to_update,
documents_to_update,
)

return True

def _get(
self,
collection_name: str,
Expand Down
126 changes: 69 additions & 57 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, cast, List, Dict
from typing import TYPE_CHECKING, Optional, cast, List, Dict, Tuple
from pydantic import BaseModel, PrivateAttr

from chromadb.api.types import (
Expand Down Expand Up @@ -79,34 +79,9 @@ def add(
ids: The ids to associate with the embeddings. Optional.
"""

ids = validate_ids(maybe_cast_one_to_many(ids))
embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
documents = maybe_cast_one_to_many(documents) if documents else None

# Check that one of embeddings or documents is provided
if embeddings is None and documents is None:
raise ValueError("You must provide either embeddings or documents, or both")

# Check that, if they're provided, the lengths of the arrays match the length of ids
if embeddings is not None and len(embeddings) != len(ids):
raise ValueError(
f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}"
)
if metadatas is not None and len(metadatas) != len(ids):
raise ValueError(
f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}"
)
if documents is not None and len(documents) != len(ids):
raise ValueError(
f"Number of documents {len(documents)} must match number of ids {len(ids)}"
)

# If document embeddings are not provided, we need to compute them
if embeddings is None and documents is not None:
if self._embedding_function is None:
raise ValueError("You must provide embeddings or a function to compute them")
embeddings = self._embedding_function(documents)
ids, embeddings, metadatas, documents = self._validate_embedding_set(
ids, embeddings, metadatas, documents
)

self._client._add(ids, self.name, embeddings, metadatas, documents, increment_index)

Expand Down Expand Up @@ -237,18 +212,76 @@ def update(
documents: The documents to associate with the embeddings. Optional.
"""

ids, embeddings, metadatas, documents = self._validate_embedding_set(
ids, embeddings, metadatas, documents, require_embeddings_or_documents=False
)

self._client._update(self.name, ids, embeddings, metadatas, documents)

def upsert(
self,
ids: OneOrMany[ID],
embeddings: Optional[OneOrMany[Embedding]] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
increment_index: bool = True,
):
"""Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist.

Args:
ids: The ids of the embeddings to update
embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional.
metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional.
documents: The documents to associate with the embeddings. Optional.
"""

ids, embeddings, metadatas, documents = self._validate_embedding_set(
ids, embeddings, metadatas, documents
)

self._client._upsert(
collection_name=self.name,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
increment_index=increment_index,
)

def delete(
self,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
):
"""Delete the embeddings based on ids and/or a where filter

Args:
ids: The ids of the embeddings to delete
where: A Where type dict used to filter the delection by. E.g. {"color" : "red", "price": 4.20}. Optional.
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. {$contains: {"text": "hello"}}. Optional.
"""
ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
where = validate_where(where) if where else None
where_document = validate_where_document(where_document) if where_document else None
return self._client._delete(self.name, ids, where, where_document)

def create_index(self):
self._client.create_index(self.name)

def _validate_embedding_set(
self, ids, embeddings, metadatas, documents, require_embeddings_or_documents=True
) -> Tuple[IDs, Optional[List[Embedding]], Optional[List[Metadata]], Optional[List[Document]]]:

ids = validate_ids(maybe_cast_one_to_many(ids))
embeddings = maybe_cast_one_to_many(embeddings) if embeddings else None
metadatas = validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas else None
documents = maybe_cast_one_to_many(documents) if documents else None

# Must update one of embeddings, metadatas, or documents
if embeddings is None and documents is None and metadatas is None:
raise ValueError("You must update at least one of embeddings, documents or metadatas.")

# Check that one of embeddings or documents is provided
if embeddings is not None and documents is None:
raise ValueError("You must provide updated documents with updated embeddings")
if require_embeddings_or_documents:
if embeddings is None and documents is None:
raise ValueError("You must provide either embeddings or documents, or both")

# Check that, if they're provided, the lengths of the arrays match the length of ids
if embeddings is not None and len(embeddings) != len(ids):
Expand All @@ -270,25 +303,4 @@ def update(
raise ValueError("You must provide embeddings or a function to compute them")
embeddings = self._embedding_function(documents)

self._client._update(self.name, ids, embeddings, metadatas, documents)

def delete(
self,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
):
"""Delete the embeddings based on ids and/or a where filter

Args:
ids: The ids of the embeddings to delete
where: A Where type dict used to filter the delection by. E.g. {"color" : "red", "price": 4.20}. Optional.
where_document: A WhereDocument type dict used to filter the deletion by the document content. E.g. {$contains: {"text": "hello"}}. Optional.
"""
ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None
where = validate_where(where) if where else None
where_document = validate_where_document(where_document) if where_document else None
return self._client._delete(self.name, ids, where, where_document)

def create_index(self):
self._client.create_index(self.name)
return ids, embeddings, metadatas, documents
3 changes: 2 additions & 1 deletion chromadb/db/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ def update(

# Update the index
if embeddings is not None:
update_uuids = [x[1] for x in existing_items]
uuid_mapping = {r[4]: r[1] for r in existing_items}
update_uuids = [uuid_mapping[id] for id in ids]
index = self._index(collection_uuid)
index.add(update_uuids, embeddings, update=True)

Expand Down
13 changes: 13 additions & 0 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __init__(self, settings):
self.router.add_api_route(
"/api/v1/collections/{collection_name}/update", self.update, methods=["POST"]
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}/upsert", self.upsert, methods=["POST"]
)
self.router.add_api_route(
"/api/v1/collections/{collection_name}/get", self.get, methods=["POST"]
)
Expand Down Expand Up @@ -176,6 +179,16 @@ def update(self, collection_name: str, add: UpdateEmbedding):
metadatas=add.metadatas,
)

def upsert(self, collection_name: str, upsert: AddEmbedding):
return self._api._upsert(
collection_name=collection_name,
ids=upsert.ids,
embeddings=upsert.embeddings,
documents=upsert.documents,
metadatas=upsert.metadatas,
increment_index=upsert.increment_index,
)

def get(self, collection_name, get: GetEmbedding):
return self._api._get(
collection_name=collection_name,
Expand Down
Loading