diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index 0cfa760d6bb..0fee8698020 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -5,7 +5,11 @@ from chromadb.api.types import IndexMetadata import hnswlib from chromadb.db.index import Index -from chromadb.errors import NoIndexException, InvalidDimensionException, NotEnoughElementsException +from chromadb.errors import ( + NoIndexException, + InvalidDimensionException, + NotEnoughElementsException, +) import logging import re from uuid import UUID @@ -24,7 +28,6 @@ class HnswParams: - space: str construction_ef: int search_ef: int @@ -33,7 +36,6 @@ class HnswParams: resize_factor: float def __init__(self, metadata): - metadata = metadata or {} # Convert all values to strings for future compatibility. @@ -44,7 +46,9 @@ def __init__(self, metadata): if param not in valid_params: raise ValueError(f"Unknown HNSW parameter: {param}") if not re.match(valid_params[param], value): - raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}") + raise ValueError( + f"Invalid value for HNSW parameter: {param} = {value}" + ) self.space = metadata.get("hnsw:space", "l2") self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) @@ -71,7 +75,7 @@ class Hnswlib(Index): _index_metadata: IndexMetadata _params: HnswParams _id_to_label: Dict[str, int] - _label_to_id: Dict[int, str] + _label_to_id: Dict[int, UUID] def __init__(self, id, settings, metadata): self._save_folder = settings.persist_directory + "/index" @@ -128,7 +132,7 @@ def add(self, ids, embeddings, update=False): labels = [] for id in ids: - if id in self._id_to_label: + if hexid(id) in self._id_to_label: if update: labels.append(self._id_to_label[hexid(id)]) else: @@ -141,7 +145,9 @@ def add(self, ids, embeddings, update=False): labels.append(next_label) if self._index_metadata["elements"] > self._index.get_max_elements(): - new_size = max(self._index_metadata["elements"] * self._params.resize_factor, 1000) + new_size = max( + self._index_metadata["elements"] * self._params.resize_factor, 1000 + ) self._index.resize_index(int(new_size)) self._index.add_items(embeddings, labels) @@ -196,7 +202,6 @@ def _exists(self): return def _load(self): - if not os.path.exists(f"{self._save_folder}/index_{self._id}.bin"): return @@ -208,7 +213,9 @@ def _load(self): with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "rb") as f: self._index_metadata = pickle.load(f) - p = hnswlib.Index(space=self._params.space, dim=self._index_metadata["dimensionality"]) + p = hnswlib.Index( + space=self._params.space, dim=self._index_metadata["dimensionality"] + ) self._index = p self._index.load_index( f"{self._save_folder}/index_{self._id}.bin", @@ -218,9 +225,10 @@ def _load(self): self._index.set_num_threads(self._params.num_threads) def get_nearest_neighbors(self, query, k, ids=None): - if self._index is None: - raise NoIndexException("Index not found, please create an instance before querying") + raise NoIndexException( + "Index not found, please create an instance before querying" + ) # Check dimensionality self._check_dimensionality(query) @@ -245,8 +253,12 @@ def get_nearest_neighbors(self, query, k, ids=None): logger.debug(f"time to pre process our knn query: {time.time() - s2}") s3 = time.time() - database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function) + database_labels, distances = self._index.knn_query( + query, k=k, filter=filter_function + ) logger.debug(f"time to run knn query: {time.time() - s3}") - ids = [[self._label_to_id[label] for label in labels] for labels in database_labels] + ids = [ + [self._label_to_id[label] for label in labels] for labels in database_labels + ] return ids, distances diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index 8c98f101007..ed0b5796f7f 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -112,6 +112,14 @@ def no_duplicates(collection: Collection): assert len(ids) == len(set(ids)) +# These match what the spec of hnswlib is +distance_functions = { + "l2": lambda x, y: np.linalg.norm(x - y) ** 2, + "cosine": lambda x, y: 1 - np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)), + "ip": lambda x, y: 1 - np.dot(x, y), +} + + def _exact_distances( query: types.Embeddings, targets: types.Embeddings, @@ -148,9 +156,20 @@ def ann_accuracy( # If we don't have embeddings, we can't do an ANN search return + # l2 is the default distance function + distance_function = distance_functions["l2"] + if "hnsw:space" in collection.metadata: + space = collection.metadata["hnsw:space"] + if space == "cosine": + distance_function = distance_functions["cosine"] + if space == "ip": + distance_function = distance_functions["ip"] + # Perform exact distance computation indices, distances = _exact_distances( - embeddings["embeddings"], embeddings["embeddings"] + embeddings["embeddings"], + embeddings["embeddings"], + distance_fn=distance_function, ) query_results = collection.query( @@ -176,7 +195,10 @@ def ann_accuracy( if id not in expected_ids: continue index = id_to_index[id] - assert np.allclose(distances_i[index], query_results["distances"][i][j]) + # TODO: IP distance is resulting in more noise than expected so atol=1e-5 + assert np.allclose( + distances_i[index], query_results["distances"][i][j], atol=1e-5 + ) assert np.allclose( embeddings["embeddings"][index], query_results["embeddings"][i][j] ) diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 121e1616ee3..26ecfa566fb 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -146,6 +146,11 @@ def collections(draw, add_filterable_data=False, with_hnsw_params=False): if metadata is None: metadata = {} metadata.update(test_hnsw_config) + # Sometimes, select a space at random + if draw(st.booleans()): + # TODO: pull the distance functions from a source of truth that lives not + # in tests once https://github.com/chroma-core/issues/issues/61 lands + metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"])) known_metadata_keys = {} if add_filterable_data: diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 7a926ad4626..d6f9a508aea 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -25,31 +25,37 @@ version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$") -def _patch_uppercase_coll_name(collection: strategies.Collection, - embeddings: strategies.RecordSet): +def _patch_uppercase_coll_name( + collection: strategies.Collection, embeddings: strategies.RecordSet +): """Old versions didn't handle uppercase characters in collection names""" collection.name = collection.name.lower() -def _patch_empty_dict_metadata(collection: strategies.Collection, - embeddings: strategies.RecordSet): +def _patch_empty_dict_metadata( + collection: strategies.Collection, embeddings: strategies.RecordSet +): """Old versions do the wrong thing when metadata is a single empty dict""" if embeddings["metadatas"] == {}: embeddings["metadatas"] = None -version_patches = [("0.3.21", _patch_uppercase_coll_name), - ("0.3.21", _patch_empty_dict_metadata)] +version_patches = [ + ("0.3.21", _patch_uppercase_coll_name), + ("0.3.21", _patch_empty_dict_metadata), +] -def patch_for_version(version, - collection: strategies.Collection, - embeddings: strategies.RecordSet): +def patch_for_version( + version, collection: strategies.Collection, embeddings: strategies.RecordSet +): """Override aspects of the collection and embeddings, before testing, to account for breaking changes in old versions.""" for patch_version, patch in version_patches: - if packaging_version.Version(version) <= packaging_version.Version(patch_version): + if packaging_version.Version(version) <= packaging_version.Version( + patch_version + ): patch(collection, embeddings) @@ -84,9 +90,7 @@ def configurations(versions): # This fixture is not shared with the rest of the tests because it is unique in how it # installs the versions of chromadb -@pytest.fixture( - scope="module", params=configurations(test_old_versions) -) +@pytest.fixture(scope="module", params=configurations(test_old_versions)) def version_settings(request) -> Generator[Tuple[str, Settings], None, None]: configuration = request.param version = configuration[0] @@ -172,7 +176,7 @@ def persist_generated_data_with_old_version( coll = api.create_collection( name=collection_strategy.name, metadata=collection_strategy.metadata, - embedding_function=collection_strategy.embedding_function, + embedding_function=lambda x: None, ) coll.add(**embeddings_strategy) # We can't use the invariants module here because it uses the current version diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 162c9a2fc2f..edb070089b3 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -53,13 +53,15 @@ class EmbeddingStateMachineStates: update_embeddings = "update_embeddings" upsert_embeddings = "upsert_embeddings" + collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") + class EmbeddingStateMachine(RuleBasedStateMachine): collection: Collection embedding_ids: Bundle = Bundle("embedding_ids") - def __init__(self, api = None): + def __init__(self, api=None): super().__init__() # For debug only, to run as class-based test if not api: @@ -73,7 +75,7 @@ def initialize(self, collection: strategies.Collection): self.collection = self.api.create_collection( name=collection.name, metadata=collection.metadata, - embedding_function=collection.embedding_function + embedding_function=collection.embedding_function, ) trace("init") self.on_state_change(EmbeddingStateMachineStates.initialize) @@ -84,8 +86,7 @@ def initialize(self, collection: strategies.Collection): "documents": [], } - @rule(target=embedding_ids, - embedding_set=strategies.recordsets(collection_st)) + @rule(target=embedding_ids, embedding_set=strategies.recordsets(collection_st)) def add_embeddings(self, embedding_set): trace("add_embeddings") self.on_state_change(EmbeddingStateMachineStates.add_embeddings) @@ -95,7 +96,9 @@ def add_embeddings(self, embedding_set): if len(normalized_embedding_set["ids"]) > 0: trace("add_more_embeddings") - if set(normalized_embedding_set["ids"]).intersection(set(self.embeddings["ids"])): + if set(normalized_embedding_set["ids"]).intersection( + set(self.embeddings["ids"]) + ): with pytest.raises(errors.IDAlreadyExistsError): self.collection.add(**embedding_set) return multiple() @@ -117,10 +120,14 @@ def delete_by_ids(self, ids): # Removing the precondition causes the tests to frequently fail as "unsatisfiable" # Using a value < 5 causes retries and lowers the number of valid samples @precondition(lambda self: len(self.embeddings["ids"]) >= 5) - @rule(embedding_set=strategies.recordsets(collection_strategy=collection_st, - id_strategy=embedding_ids, - min_size=1, - max_size=5)) + @rule( + embedding_set=strategies.recordsets( + collection_strategy=collection_st, + id_strategy=embedding_ids, + min_size=1, + max_size=5, + ) + ) def update_embeddings(self, embedding_set): trace("update embeddings") self.on_state_change(EmbeddingStateMachineStates.update_embeddings) @@ -129,10 +136,14 @@ def update_embeddings(self, embedding_set): # Using a value < 3 causes more retries and lowers the number of valid samples @precondition(lambda self: len(self.embeddings["ids"]) >= 3) - @rule(embedding_set=strategies.recordsets( - collection_strategy=collection_st, - id_strategy=st.one_of(embedding_ids, strategies.safe_text), - min_size=1, max_size=5)) + @rule( + embedding_set=strategies.recordsets( + collection_strategy=collection_st, + id_strategy=st.one_of(embedding_ids, strategies.safe_text), + min_size=1, + max_size=5, + ) + ) def upsert_embeddings(self, embedding_set): trace("upsert embeddings") self.on_state_change(EmbeddingStateMachineStates.upsert_embeddings) @@ -141,7 +152,7 @@ def upsert_embeddings(self, embedding_set): @invariant() def count(self): - invariants.count(self.collection, self.embeddings) #type: ignore + invariants.count(self.collection, self.embeddings) # type: ignore @invariant() def no_duplicates(self): @@ -150,7 +161,7 @@ def no_duplicates(self): @invariant() def ann_accuracy(self): invariants.ann_accuracy( - collection=self.collection, embeddings=self.embeddings, min_recall=0.95 #type: ignore + collection=self.collection, embeddings=self.embeddings, min_recall=0.95 # type: ignore ) def _upsert_embeddings(self, embeddings: strategies.RecordSet):