Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate add to prevent duplicate IDs. #363

Merged: 10 commits were merged on Apr 17, 2023 (the source and target branch names are missing from this capture).
8 changes: 7 additions & 1 deletion chromadb/api/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Dict, List, Optional, Sequence, Callable, cast
from chromadb import __version__

import chromadb.errors as errors
from chromadb.api import API
from chromadb.db import DB
from chromadb.api.types import (
Expand Down Expand Up @@ -126,6 +126,12 @@ def _add(
increment_index: bool = True,
):

existing_ids = self._get(collection_name, ids=ids, include=[])["ids"]
if len(existing_ids) > 0:
raise errors.IDAlreadyExistsError(
f"IDs {existing_ids} already exist in collection {collection_name}"
)

collection_uuid = self._db.get_collection_uuid_from_name(collection_name)
added_uuids = self._db.add(
collection_uuid,
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Literal, Optional, Union, Dict, Sequence, TypedDict, Protocol, TypeVar, List
import chromadb.errors as errors

ID = str
IDs = List[ID]
Expand Down Expand Up @@ -84,6 +85,9 @@ def validate_ids(ids: IDs) -> IDs:
for id in ids:
if not isinstance(id, str):
raise ValueError(f"Expected ID to be a str, got {id}")
if len(ids) != len(set(ids)):
dups = set([x for x in ids if ids.count(x) > 1])
[Review note, not part of the code: levand marked the conversation on the line above as resolved.]
raise errors.DuplicateIDError(f"Expected IDs to be unique, found duplicates for: {dups}")
return ids


Expand Down
8 changes: 8 additions & 0 deletions chromadb/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@ class InvalidDimensionException(Exception):

class NotEnoughElementsException(Exception):
    """Error type deriving directly from ``Exception``.

    NOTE(review): the places that raise this are not visible in this diff
    hunk; judging by the name it presumably signals that fewer elements are
    available than an operation requires -- confirm against the callers.
    """


class IDAlreadyExistsError(ValueError):
    """Raised when an add is attempted with IDs already stored in the collection.

    Subclasses ``ValueError``, so callers that already catch ``ValueError``
    around an add keep working.
    """


class DuplicateIDError(ValueError):
    """Raised when a single list of IDs contains the same ID more than once.

    Subclasses ``ValueError``, so it is caught by handlers written for the
    other ``ValueError``s that ID validation raises.
    """
21 changes: 11 additions & 10 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hypothesis
import hypothesis.strategies as st
from typing import Optional, Sequence, TypedDict, cast
from typing import Optional, Sequence, TypedDict, Callable, List, cast
import hypothesis.extra.numpy as npst
import numpy as np
import chromadb.api.types as types
Expand Down Expand Up @@ -105,11 +105,11 @@ def create_embeddings(dim: int, count: int, dtype: np.dtype):
).astype(dtype)


def documents_strategy(count: int):
def documents_strategy(count: int) -> st.SearchStrategy[Optional[List[str]]]:
# TODO: Handle non-unique documents
# TODO: Handle empty string documents
return st.one_of(
st.lists(st.text(min_size=1), min_size=count, max_size=count, unique=True), st.none()
st.none(), st.lists(st.text(min_size=1), min_size=count, max_size=count, unique=True)
)


Expand All @@ -122,11 +122,8 @@ def metadata_strategy():
)


def metadatas_strategy(count: int):
return st.one_of(
st.lists(metadata_strategy(), min_size=count, max_size=count),
st.none(),
)
def metadatas_strategy(count: int) -> st.SearchStrategy[Optional[List[types.Metadata]]]:
return st.one_of(st.none(), st.lists(metadata_strategy(), min_size=count, max_size=count))


@st.composite
Expand All @@ -136,6 +133,10 @@ def embedding_set(
count_st: st.SearchStrategy[int] = st.integers(min_value=1, max_value=512),
dtype_st: st.SearchStrategy[np.dtype] = st.sampled_from(float_types),
id_st: st.SearchStrategy[str] = st.text(alphabet=legal_id_characters, min_size=1, max_size=64),
documents_st_fn: Callable[[int], st.SearchStrategy[Optional[List[str]]]] = documents_strategy,
metadatas_st_fn: Callable[
[int], st.SearchStrategy[Optional[List[types.Metadata]]]
] = metadatas_strategy,
dimension: Optional[int] = None,
count: Optional[int] = None,
dtype: Optional[np.dtype] = None,
Expand All @@ -157,8 +158,8 @@ def embedding_set(

# TODO: Test documents only
# TODO: Generative embedding function to guarantee unique embeddings for unique documents
documents = draw(documents_strategy(count))
metadatas = draw(metadatas_strategy(count))
documents = draw(documents_st_fn(count))
metadatas = draw(metadatas_st_fn(count))

embeddings = create_embeddings(dimension, count, dtype)

Expand Down
34 changes: 25 additions & 9 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hypothesis.strategies as st
from typing import List, Set, TypedDict, Sequence
import chromadb
import chromadb.errors as errors
from chromadb.api import API
from chromadb.api.models.Collection import Collection
from chromadb.test.configurations import configurations
Expand Down Expand Up @@ -67,7 +68,7 @@ class EmbeddingStateMachine(RuleBasedStateMachine):

def __init__(self, api):
super().__init__()
self.api = chromadb.Client(configurations()[0])
self.api = api

@initialize(
collection=strategies.collections(),
Expand All @@ -92,9 +93,14 @@ def add_embeddings(self, embedding_set):
if len(self.embeddings["ids"]) > 0:
trace("add_more_embeddings")

self.collection.add(**embedding_set)
self._add_embeddings(embedding_set)
return multiple(*embedding_set["ids"])
if set(embedding_set["ids"]).intersection(set(self.embeddings["ids"])):
with pytest.raises(errors.IDAlreadyExistsError):
self.collection.add(**embedding_set)
return multiple()
else:
self.collection.add(**embedding_set)
self._add_embeddings(embedding_set)
return multiple(*embedding_set["ids"])

@precondition(lambda self: len(self.embeddings["ids"]) > 20)
@rule(ids=st.lists(consumes(embedding_ids), min_size=1, max_size=20))
Expand All @@ -116,6 +122,9 @@ def delete_by_ids(self, ids):
dimension_st=dimension_st,
id_st=embedding_ids,
count_st=st.integers(min_value=1, max_value=5),
documents_st_fn=lambda c: st.lists(
st.text(min_size=1), min_size=c, max_size=c, unique=True
),
)
)
def update_embeddings(self, embedding_set):
Expand Down Expand Up @@ -152,8 +161,8 @@ def _add_embeddings(self, embeddings: strategies.EmbeddingSet):
else:
documents = [None] * len(embeddings["ids"])

self.embeddings["metadatas"] += metadatas # type: ignore
self.embeddings["documents"] += documents # type: ignore
self.embeddings["metadatas"].extend(metadatas) # type: ignore
self.embeddings["documents"].extend(documents) # type: ignore

def _remove_embeddings(self, indices_to_remove: Set[int]):

Expand Down Expand Up @@ -190,18 +199,25 @@ def test_multi_add(api):
coll.add(ids=["a"], embeddings=[[0.0]])
assert coll.count() == 1

with pytest.raises(ValueError):
with pytest.raises(errors.IDAlreadyExistsError):
coll.add(ids=["a"], embeddings=[[0.0]])

assert coll.count() == 1

results = coll.query(query_embeddings=[[0.0]], n_results=2)
assert results["ids"] == [["a"]]
results = coll.get()
assert results["ids"] == ["a"]

coll.delete(ids=["a"])
assert coll.count() == 0


def test_dup_add(api):
    """A single ``add`` call that repeats an ID within its own batch is rejected.

    The collection is freshly created after a reset, so the expected
    ``DuplicateIDError`` comes from the repeated "a" inside the one call
    rather than from a clash with a previously added ID.
    """
    api.reset()
    coll = api.create_collection(name="foo")
    with pytest.raises(errors.DuplicateIDError):
        coll.add(ids=["a", "a"], embeddings=[[0.0], [1.1]])


def test_escape_chars_in_ids(api):
api.reset()
id = "\x1f"
Expand Down
4 changes: 2 additions & 2 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def test_heartbeat(api_fixture, request):

batch_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com", "https://example.com"],
"ids": ["https://example.com/1", "https://example.com/2"],
}


Expand Down Expand Up @@ -251,7 +251,7 @@ def test_get_or_create(api_fixture, request):

minimal_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com", "https://example.com"],
"ids": ["https://example.com/1", "https://example.com/2"],
}


Expand Down