Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate add to prevent duplicate IDs. #363

Merged
merged 10 commits into from
Apr 17, 2023
4 changes: 4 additions & 0 deletions chromadb/api/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def _add(
increment_index: bool = True,
):

existing_ids = set(self._get(collection_name, ids=ids, include=[])["ids"])
levand marked this conversation as resolved.
Show resolved Hide resolved
if len(existing_ids) > 0:
raise ValueError(f"IDs {existing_ids} already exist in collection {collection_name}")

collection_uuid = self._db.get_collection_uuid_from_name(collection_name)
added_uuids = self._db.add(
collection_uuid,
Expand Down
3 changes: 3 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def validate_ids(ids: IDs) -> IDs:
for id in ids:
if not isinstance(id, str):
raise ValueError(f"Expected ID to be a str, got {id}")
if len(ids) != len(set(ids)):
dups = set([x for x in ids if ids.count(x) > 1])
levand marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Expected IDs to be unique, found duplicates for: {dups}")
return ids


Expand Down
21 changes: 11 additions & 10 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hypothesis
import hypothesis.strategies as st
from typing import Optional, Sequence, TypedDict, cast
from typing import Optional, Sequence, TypedDict, Callable, List, cast
import hypothesis.extra.numpy as npst
import numpy as np
import chromadb.api.types as types
Expand Down Expand Up @@ -105,11 +105,11 @@ def create_embeddings(dim: int, count: int, dtype: np.dtype):
).astype(dtype)


def documents_strategy(count: int):
def documents_strategy(count: int) -> st.SearchStrategy[Optional[List[str]]]:
# TODO: Handle non-unique documents
# TODO: Handle empty string documents
return st.one_of(
st.lists(st.text(min_size=1), min_size=count, max_size=count, unique=True), st.none()
st.none(), st.lists(st.text(min_size=1), min_size=count, max_size=count, unique=True)
)


Expand All @@ -122,11 +122,8 @@ def metadata_strategy():
)


def metadatas_strategy(count: int):
return st.one_of(
st.lists(metadata_strategy(), min_size=count, max_size=count),
st.none(),
)
def metadatas_strategy(count: int) -> st.SearchStrategy[Optional[List[types.Metadata]]]:
return st.one_of(st.none(), st.lists(metadata_strategy(), min_size=count, max_size=count))


@st.composite
Expand All @@ -136,6 +133,10 @@ def embedding_set(
count_st: st.SearchStrategy[int] = st.integers(min_value=1, max_value=512),
dtype_st: st.SearchStrategy[np.dtype] = st.sampled_from(float_types),
id_st: st.SearchStrategy[str] = st.text(alphabet=legal_id_characters, min_size=1, max_size=64),
documents_st_fn: Callable[[int], st.SearchStrategy[Optional[List[str]]]] = documents_strategy,
metadatas_st_fn: Callable[
[int], st.SearchStrategy[Optional[List[types.Metadata]]]
] = metadatas_strategy,
dimension: Optional[int] = None,
count: Optional[int] = None,
dtype: Optional[np.dtype] = None,
Expand All @@ -157,8 +158,8 @@ def embedding_set(

# TODO: Test documents only
# TODO: Generative embedding function to guarantee unique embeddings for unique documents
documents = draw(documents_strategy(count))
metadatas = draw(metadatas_strategy(count))
documents = draw(documents_st_fn(count))
metadatas = draw(metadatas_st_fn(count))

embeddings = create_embeddings(dimension, count, dtype)

Expand Down
27 changes: 21 additions & 6 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class EmbeddingStateMachine(RuleBasedStateMachine):

def __init__(self, api):
super().__init__()
self.api = chromadb.Client(configurations()[0])
self.api = api

@initialize(
collection=strategies.collections(),
Expand All @@ -92,9 +92,14 @@ def add_embeddings(self, embedding_set):
if len(self.embeddings["ids"]) > 0:
trace("add_more_embeddings")

self.collection.add(**embedding_set)
self._add_embeddings(embedding_set)
return multiple(*embedding_set["ids"])
if set(embedding_set["ids"]).intersection(set(self.embeddings["ids"])):
with pytest.raises(ValueError):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test that the value error is correct.

self.collection.add(**embedding_set)
return multiple()
else:
self.collection.add(**embedding_set)
self._add_embeddings(embedding_set)
return multiple(*embedding_set["ids"])

@precondition(lambda self: len(self.embeddings["ids"]) > 20)
@rule(ids=st.lists(consumes(embedding_ids), min_size=1, max_size=20))
Expand All @@ -116,6 +121,9 @@ def delete_by_ids(self, ids):
dimension_st=dimension_st,
id_st=embedding_ids,
count_st=st.integers(min_value=1, max_value=5),
documents_st_fn=lambda c: st.lists(
st.text(min_size=1), min_size=c, max_size=c, unique=True
),
)
)
def update_embeddings(self, embedding_set):
Expand Down Expand Up @@ -195,13 +203,20 @@ def test_multi_add(api):

assert coll.count() == 1

results = coll.query(query_embeddings=[[0.0]], n_results=2)
assert results["ids"] == [["a"]]
results = coll.get()
assert results["ids"] == ["a"]

coll.delete(ids=["a"])
assert coll.count() == 0


def test_dup_add(api):
api.reset()
coll = api.create_collection(name="foo")
with pytest.raises(ValueError):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create more specific error type.

coll.add(ids=["a", "a"], embeddings=[[0.0], [1.1]])


def test_escape_chars_in_ids(api):
api.reset()
id = "\x1f"
Expand Down
4 changes: 2 additions & 2 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def test_heartbeat(api_fixture, request):

batch_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com", "https://example.com"],
"ids": ["https://example.com/1", "https://example.com/2"],
}


Expand Down Expand Up @@ -251,7 +251,7 @@ def test_get_or_create(api_fixture, request):

minimal_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com", "https://example.com"],
"ids": ["https://example.com/1", "https://example.com/2"],
}


Expand Down
8 changes: 8 additions & 0 deletions clients/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ export class Collection {
);
}

const uniqueIds = new Set(idsArray);
if (uniqueIds.size !== idsArray.length) {
const duplicateIds = idsArray.filter((item, index) => idsArray.indexOf(item) !== index);
throw new Error(
`Expected IDs to be unique, found duplicates for: ${duplicateIds}`,
);
}

const response = await this.api.add({
collectionName: this.name,
addEmbedding: {
Expand Down
34 changes: 34 additions & 0 deletions clients/js/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,38 @@ test('wrong code returns an error', async () => {
const results = await collection.get(undefined, { "test": { "$contains": "hello" } });
expect(results.error).toBeDefined()
expect(results.error).toBe("ValueError('Expected one of $gt, $lt, $gte, $lte, $ne, $eq, got $contains')")
})

// Adding the same IDs to a collection a second time should come back as an
// error payload (a ValueError string) rather than a successful response.
test('it should return an error when inserting duplicate IDs', async () => {
  await chroma.reset()
  const collection = await chroma.createCollection('test')

  const ids = ['test1', 'test2', 'test3']
  const ascending = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  const descending = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
  const embeddings = [[...ascending], [...ascending], [...descending]]
  // One metadata record per ID, tagged with that ID.
  const metadatas = ids.map((id) => ({ test: id }))

  // First insert seeds the collection; the second reuses the same IDs.
  await collection.add(ids, embeddings, metadatas)
  const secondAdd = await collection.add(ids, embeddings, metadatas)

  expect(secondAdd.error).toBeDefined()
  expect(secondAdd.error).toContain("ValueError")
})

// A single `add` call whose ID list contains a repeated ID must fail with an
// error whose message mentions "duplicates".
test('validation errors when inserting duplicate IDs in the same batch', async () => {
  await chroma.reset()
  const collection = await chroma.createCollection('test')
  // 'test1' appears twice within the same batch.
  const ids = ['test1', 'test2', 'test3', 'test1']
  const embeddings = [
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    [10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
    [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
  ]
  const metadatas = [{ test: 'test1' }, { test: 'test2' }, { test: 'test3' }, { test: 'test4' }]
  // Assert on the rejected promise itself: with a bare try/catch the test
  // would pass without running any assertion if `add` resolved normally.
  await expect(collection.add(ids, embeddings, metadatas)).rejects.toThrow('duplicates')
})