Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multiple spaces #457

Merged
merged 5 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions chromadb/db/index/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from chromadb.api.types import IndexMetadata
import hnswlib
from chromadb.db.index import Index
from chromadb.errors import NoIndexException, InvalidDimensionException, NotEnoughElementsException
from chromadb.errors import (
NoIndexException,
InvalidDimensionException,
NotEnoughElementsException,
)
import logging
import re
from uuid import UUID
Expand All @@ -24,7 +28,6 @@


class HnswParams:

space: str
construction_ef: int
search_ef: int
Expand All @@ -33,7 +36,6 @@ class HnswParams:
resize_factor: float

def __init__(self, metadata):

metadata = metadata or {}

# Convert all values to strings for future compatibility.
Expand All @@ -44,7 +46,9 @@ def __init__(self, metadata):
if param not in valid_params:
raise ValueError(f"Unknown HNSW parameter: {param}")
if not re.match(valid_params[param], value):
raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}")
raise ValueError(
f"Invalid value for HNSW parameter: {param} = {value}"
)

self.space = metadata.get("hnsw:space", "l2")
self.construction_ef = int(metadata.get("hnsw:construction_ef", 100))
Expand All @@ -71,7 +75,7 @@ class Hnswlib(Index):
_index_metadata: IndexMetadata
_params: HnswParams
_id_to_label: Dict[str, int]
_label_to_id: Dict[int, str]
_label_to_id: Dict[int, UUID]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comment below on how it's odd that we store the string and the UUID in opposite directions. Updating the typing accordingly.


def __init__(self, id, settings, metadata):
self._save_folder = settings.persist_directory + "/index"
Expand Down Expand Up @@ -128,7 +132,7 @@ def add(self, ids, embeddings, update=False):

labels = []
for id in ids:
if id in self._id_to_label:
if hexid(id) in self._id_to_label:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the bugfix

if update:
labels.append(self._id_to_label[hexid(id)])
else:
Expand All @@ -141,7 +145,9 @@ def add(self, ids, embeddings, update=False):
labels.append(next_label)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is weird that we sometimes store the UUID and sometimes the id, but I'm leaving this as-is for now to minimize the bugs I might introduce.


if self._index_metadata["elements"] > self._index.get_max_elements():
new_size = max(self._index_metadata["elements"] * self._params.resize_factor, 1000)
new_size = max(
self._index_metadata["elements"] * self._params.resize_factor, 1000
)
self._index.resize_index(int(new_size))

self._index.add_items(embeddings, labels)
Expand Down Expand Up @@ -196,7 +202,6 @@ def _exists(self):
return

def _load(self):

if not os.path.exists(f"{self._save_folder}/index_{self._id}.bin"):
return

Expand All @@ -208,7 +213,9 @@ def _load(self):
with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "rb") as f:
self._index_metadata = pickle.load(f)

p = hnswlib.Index(space=self._params.space, dim=self._index_metadata["dimensionality"])
p = hnswlib.Index(
space=self._params.space, dim=self._index_metadata["dimensionality"]
)
self._index = p
self._index.load_index(
f"{self._save_folder}/index_{self._id}.bin",
Expand All @@ -218,9 +225,10 @@ def _load(self):
self._index.set_num_threads(self._params.num_threads)

def get_nearest_neighbors(self, query, k, ids=None):

if self._index is None:
raise NoIndexException("Index not found, please create an instance before querying")
raise NoIndexException(
"Index not found, please create an instance before querying"
)

# Check dimensionality
self._check_dimensionality(query)
Expand All @@ -245,8 +253,12 @@ def get_nearest_neighbors(self, query, k, ids=None):
logger.debug(f"time to pre process our knn query: {time.time() - s2}")

s3 = time.time()
database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function)
database_labels, distances = self._index.knn_query(
query, k=k, filter=filter_function
)
logger.debug(f"time to run knn query: {time.time() - s3}")

ids = [[self._label_to_id[label] for label in labels] for labels in database_labels]
ids = [
[self._label_to_id[label] for label in labels] for labels in database_labels
]
return ids, distances
26 changes: 24 additions & 2 deletions chromadb/test/property/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ def no_duplicates(collection: Collection):
assert len(ids) == len(set(ids))


# These match what the spec of hnswlib is
distance_functions = {
"l2": lambda x, y: np.linalg.norm(x - y) ** 2,
"cosine": lambda x, y: 1 - np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)),
"ip": lambda x, y: 1 - np.dot(x, y),
}


def _exact_distances(
query: types.Embeddings,
targets: types.Embeddings,
Expand Down Expand Up @@ -148,9 +156,20 @@ def ann_accuracy(
# If we don't have embeddings, we can't do an ANN search
return

# l2 is the default distance function
distance_function = distance_functions["l2"]
if "hnsw:space" in collection.metadata:
space = collection.metadata["hnsw:space"]
if space == "cosine":
distance_function = distance_functions["cosine"]
if space == "ip":
distance_function = distance_functions["ip"]

# Perform exact distance computation
indices, distances = _exact_distances(
embeddings["embeddings"], embeddings["embeddings"]
embeddings["embeddings"],
embeddings["embeddings"],
distance_fn=distance_function,
)

query_results = collection.query(
Expand All @@ -176,7 +195,10 @@ def ann_accuracy(
if id not in expected_ids:
continue
index = id_to_index[id]
assert np.allclose(distances_i[index], query_results["distances"][i][j])
# TODO: IP distance is resulting in more noise than expected so atol=1e-5
assert np.allclose(
distances_i[index], query_results["distances"][i][j], atol=1e-5
)
assert np.allclose(
embeddings["embeddings"][index], query_results["embeddings"][i][j]
)
Expand Down
5 changes: 5 additions & 0 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def collections(draw, add_filterable_data=False, with_hnsw_params=False):
if metadata is None:
metadata = {}
metadata.update(test_hnsw_config)
# Sometimes, select a space at random
if draw(st.booleans()):
# TODO: pull the distance functions from a source of truth that lives not
# in tests once https://github.com/chroma-core/issues/issues/61 lands
metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))

known_metadata_keys = {}
if add_filterable_data:
Expand Down
32 changes: 18 additions & 14 deletions chromadb/test/property/test_cross_version_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,37 @@
version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")


def _patch_uppercase_coll_name(
    collection: strategies.Collection, embeddings: strategies.RecordSet
):
    """Old versions didn't handle uppercase characters in collection names."""
    # Force the name to lowercase so older servers accept it.
    lowered = collection.name.lower()
    collection.name = lowered


def _patch_empty_dict_metadata(
    collection: strategies.Collection, embeddings: strategies.RecordSet
):
    """Old versions do the wrong thing when metadata is a single empty dict."""
    # Swap a bare `{}` for None before handing the records to an old version.
    if embeddings["metadatas"] == {}:
        embeddings["metadatas"] = None


# (cutoff_version, patch) pairs: a patch is applied by patch_for_version when
# the version under test is <= its cutoff.
version_patches = [
    ("0.3.21", _patch_uppercase_coll_name),
    ("0.3.21", _patch_empty_dict_metadata),
]


def patch_for_version(
    version, collection: strategies.Collection, embeddings: strategies.RecordSet
):
    """Override aspects of the collection and embeddings, before testing, to account for
    breaking changes in old versions."""
    # Parse once; the target version is invariant across the patch loop.
    target = packaging_version.Version(version)
    for patch_version, patch in version_patches:
        if target <= packaging_version.Version(patch_version):
            patch(collection, embeddings)


Expand Down Expand Up @@ -84,9 +90,7 @@ def configurations(versions):

# This fixture is not shared with the rest of the tests because it is unique in how it
# installs the versions of chromadb
@pytest.fixture(
scope="module", params=configurations(test_old_versions)
)
@pytest.fixture(scope="module", params=configurations(test_old_versions))
def version_settings(request) -> Generator[Tuple[str, Settings], None, None]:
configuration = request.param
version = configuration[0]
Expand Down Expand Up @@ -172,7 +176,7 @@ def persist_generated_data_with_old_version(
coll = api.create_collection(
name=collection_strategy.name,
metadata=collection_strategy.metadata,
embedding_function=collection_strategy.embedding_function,
embedding_function=lambda x: None,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes the tests faster: we pass None across the process boundary, since we can't pickle a lambda.

)
coll.add(**embeddings_strategy)
# We can't use the invariants module here because it uses the current version
Expand Down
41 changes: 26 additions & 15 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ class EmbeddingStateMachineStates:
update_embeddings = "update_embeddings"
upsert_embeddings = "upsert_embeddings"


collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")


class EmbeddingStateMachine(RuleBasedStateMachine):
collection: Collection
embedding_ids: Bundle = Bundle("embedding_ids")

def __init__(self, api = None):
def __init__(self, api=None):
super().__init__()
# For debug only, to run as class-based test
if not api:
Expand All @@ -73,7 +75,7 @@ def initialize(self, collection: strategies.Collection):
self.collection = self.api.create_collection(
name=collection.name,
metadata=collection.metadata,
embedding_function=collection.embedding_function
embedding_function=collection.embedding_function,
)
trace("init")
self.on_state_change(EmbeddingStateMachineStates.initialize)
Expand All @@ -84,8 +86,7 @@ def initialize(self, collection: strategies.Collection):
"documents": [],
}

@rule(target=embedding_ids,
embedding_set=strategies.recordsets(collection_st))
@rule(target=embedding_ids, embedding_set=strategies.recordsets(collection_st))
def add_embeddings(self, embedding_set):
trace("add_embeddings")
self.on_state_change(EmbeddingStateMachineStates.add_embeddings)
Expand All @@ -95,7 +96,9 @@ def add_embeddings(self, embedding_set):
if len(normalized_embedding_set["ids"]) > 0:
trace("add_more_embeddings")

if set(normalized_embedding_set["ids"]).intersection(set(self.embeddings["ids"])):
if set(normalized_embedding_set["ids"]).intersection(
set(self.embeddings["ids"])
):
with pytest.raises(errors.IDAlreadyExistsError):
self.collection.add(**embedding_set)
return multiple()
Expand All @@ -117,10 +120,14 @@ def delete_by_ids(self, ids):
# Removing the precondition causes the tests to frequently fail as "unsatisfiable"
# Using a value < 5 causes retries and lowers the number of valid samples
@precondition(lambda self: len(self.embeddings["ids"]) >= 5)
@rule(embedding_set=strategies.recordsets(collection_strategy=collection_st,
id_strategy=embedding_ids,
min_size=1,
max_size=5))
@rule(
embedding_set=strategies.recordsets(
collection_strategy=collection_st,
id_strategy=embedding_ids,
min_size=1,
max_size=5,
)
)
def update_embeddings(self, embedding_set):
trace("update embeddings")
self.on_state_change(EmbeddingStateMachineStates.update_embeddings)
Expand All @@ -129,10 +136,14 @@ def update_embeddings(self, embedding_set):

# Using a value < 3 causes more retries and lowers the number of valid samples
@precondition(lambda self: len(self.embeddings["ids"]) >= 3)
@rule(embedding_set=strategies.recordsets(
collection_strategy=collection_st,
id_strategy=st.one_of(embedding_ids, strategies.safe_text),
min_size=1, max_size=5))
@rule(
embedding_set=strategies.recordsets(
collection_strategy=collection_st,
id_strategy=st.one_of(embedding_ids, strategies.safe_text),
min_size=1,
max_size=5,
)
)
def upsert_embeddings(self, embedding_set):
trace("upsert embeddings")
self.on_state_change(EmbeddingStateMachineStates.upsert_embeddings)
Expand All @@ -141,7 +152,7 @@ def upsert_embeddings(self, embedding_set):

@invariant()
def count(self):
invariants.count(self.collection, self.embeddings) #type: ignore
invariants.count(self.collection, self.embeddings) # type: ignore

@invariant()
def no_duplicates(self):
Expand All @@ -150,7 +161,7 @@ def no_duplicates(self):
@invariant()
def ann_accuracy(self):
invariants.ann_accuracy(
collection=self.collection, embeddings=self.embeddings, min_recall=0.95 #type: ignore
collection=self.collection, embeddings=self.embeddings, min_recall=0.95 # type: ignore
)

def _upsert_embeddings(self, embeddings: strategies.RecordSet):
Expand Down