From bc6fcc6a5484d4a9fbe9b7312247316a9e47184d Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 11:38:59 -0400 Subject: [PATCH 1/9] refactor hnswlib class --- chromadb/db/index/__init__.py | 18 +-- chromadb/db/index/hnswlib.py | 288 +++++++++++++++++----------------- 2 files changed, 151 insertions(+), 155 deletions(-) diff --git a/chromadb/db/index/__init__.py b/chromadb/db/index/__init__.py index 7d504314993..06a132e1fa2 100644 --- a/chromadb/db/index/__init__.py +++ b/chromadb/db/index/__init__.py @@ -3,29 +3,21 @@ class Index(ABC): @abstractmethod - def __init__(self, settings): + def __init__(self, id, settings, metadata): pass @abstractmethod - def delete(self, collection_name): + def delete(self): pass @abstractmethod - def delete_from_index(self, collection_name, uuids): + def delete_from_index(self, ids): pass @abstractmethod - def reset(self): + def add(self, ids, embeddings, update=False): pass @abstractmethod - def run(self, collection_name, uuids, embeddings): - pass - - @abstractmethod - def has_index(self, collection_name): - pass - - @abstractmethod - def get_nearest_neighbors(self, collection_name, embedding, n_results, ids): + def get_nearest_neighbors(self, embedding, n_results, ids): pass diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index 47b8a5ca136..a4b68cd24fe 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -9,112 +9,149 @@ from chromadb.db.index import Index from chromadb.errors import NoIndexException, InvalidDimensionException import logging +import re +from uuid import UUID logger = logging.getLogger(__name__) +valid_params = { + "hnsw:space": r"^(l2|cosine|ip)$", + "hnsw:construction_ef": r"^\d+$", + "hnsw:search_ef": r"^\d+$", + "hnsw:M": r"^\d+$", + "hnsw:num_threads": r"^\d+$", + "hnsw:resize_factor": r"^\d+(\.\d+)?$", +} + + +class HnswParams: + + space: str + construction_ef: int + search_ef: int + M: int + num_threads: int + resize_factor: float + + def __init__(self, metadata): + + for param, value in metadata.items(): + if param.startswith("hnsw:"): + if param not in valid_params: + raise ValueError(f"Unknown HNSW parameter: {param}") + if not re.match(valid_params[param], value): + raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}") + + self.space = metadata.get("hnsw:space", "l2") + self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) + self.search_ef = int(metadata.get("hnsw:search_ef", 10)) + self.M = int(metadata.get("hnsw:M", 16)) + self.num_threads = int(metadata.get("hnsw:num_threads", 4)) + self.resize_factor = float(metadata.get("hnsw:resize_factor", 1.2)) + + +def hexid(id): + """Backwards compatibility for old indexes which called uuid.hex on UUID ids""" + return id.hex if isinstance(id, UUID) else id + + class Hnswlib(Index): - _collection_uuid = None - _index = None - _index_metadata: Optional[IndexMetadata] = None + _id: str + _index: hnswlib.Index + _index_metadata: IndexMetadata + _params: HnswParams - _id_to_uuid = {} - _uuid_to_id = {} + # Mapping of IDs to HNSW integer labels + _id_to_label = {} + _label_to_id = {} - def __init__(self, settings): + def __init__(self, id, settings, metadata): self._save_folder = settings.persist_directory + "/index" + self._params = HnswParams(metadata) + self._id = id - def run(self, collection_uuid, uuids, embeddings, space="l2", ef=10, num_threads=4): + self._load() + + def _init_index(self, dimensionality): # more comments available at the source: https://github.com/nmslib/hnswlib - dimensionality = len(embeddings[0]) - for uuid, i in zip(uuids, range(len(uuids))): - self._id_to_uuid[i] = uuid - self._uuid_to_id[uuid.hex] = i index = hnswlib.Index( - space=space, dim=dimensionality + space=self._params.space, dim=dimensionality ) # possible options are l2, cosine or ip - index.init_index(max_elements=len(embeddings), ef_construction=100, M=16) - index.set_ef(ef) - index.set_num_threads(num_threads) - index.add_items(embeddings, range(len(uuids))) + index.init_index( + max_elements=1000, + ef_construction=self._params.construction_ef, + M=self._params.M, + ) + index.set_ef(self._params.search_ef) + index.set_num_threads(self._params.num_threads) self._index = index - self._collection_uuid = collection_uuid self._index_metadata = { "dimensionality": dimensionality, - "elements": len(embeddings), + "elements": 0, "time_created": time.time(), } self._save() - def get_metadata(self) -> IndexMetadata: - if self._index_metadata is None: - raise NoIndexException("Index is not initialized") - return self._index_metadata + def add(self, ids, embeddings, update=False): + """Add or update embeddings to the index""" - def add_incremental(self, collection_uuid, uuids, embeddings): - if self._collection_uuid != collection_uuid: - self._load(collection_uuid) + dim = len(embeddings[0]) if self._index is None: - self.run(collection_uuid, uuids, embeddings) - - elif self._index is not None: - idx_dimension = self.get_metadata()["dimensionality"] - # Check dimensionality - if idx_dimension != len(embeddings[0]): - raise InvalidDimensionException( - f"Dimensionality of new embeddings ({len(embeddings[0])}) does not match index dimensionality ({idx_dimension})" - ) - - current_elements = self._index_metadata["elements"] - new_elements = len(uuids) - - self._index.resize_index(current_elements + new_elements) - - # first map the uuids to ids, offset by the current number of elements - for uuid, i in zip(uuids, range(len(uuids))): - offset = current_elements + i - self._id_to_uuid[offset] = uuid - self._uuid_to_id[uuid.hex] = offset - - # add the new elements to the index - self._index.add_items( - embeddings, range(current_elements, current_elements + new_elements) - ) + self._init_index(dim) - # update the metadata - self._index_metadata["elements"] += new_elements + # Check dimensionality + idx_dim = self._index.get_dim() + if dim != idx_dim: + raise InvalidDimensionException( + f"Dimensionality of new embeddings ({dim}) does not match index dimensionality ({idx_dim})" + ) + labels = [] + for id in ids: + if id in self._id_to_label: + if update: + labels.append(self._id_to_label[hexid(id)]) + else: + raise ValueError(f"ID {id} already exists in index") + else: + self._index_metadata["elements"] += 1 + next_label = self._index_metadata["elements"] + self._id_to_label[hexid(id)] = next_label + self._label_to_id[next_label] = id + + if self._index_metadata["elements"] > self._index.get_max_elements(): + new_size = min(self._index_metadata["elements"] * self._params.resize_factor, 1000) + self._index.resize_index(new_size) + + self._index.add_items(embeddings, labels) self._save() - def delete(self, collection_uuid): + def delete(self): # delete files, dont throw error if they dont exist try: - os.remove(f"{self._save_folder}/id_to_uuid_{collection_uuid}.pkl") - os.remove(f"{self._save_folder}/uuid_to_id_{collection_uuid}.pkl") - os.remove(f"{self._save_folder}/index_metadata_{collection_uuid}.pkl") - os.remove(f"{self._save_folder}/index_{collection_uuid}.bin") + os.remove(f"{self._save_folder}/id_to_uuid_{self._id}.pkl") + os.remove(f"{self._save_folder}/uuid_to_id_{self._id}.pkl") + os.remove(f"{self._save_folder}/index_{self._id}.bin") + os.remove(f"{self._save_folder}/index_metadata_{self._id}.pkl") except: pass - if self._collection_uuid == collection_uuid: - self._index = None - self._collection_uuid = None - self._index_metadata = None - self._id_to_uuid = {} - self._uuid_to_id = {} - - def delete_from_index(self, collection_uuid, uuids): - if self._collection_uuid != collection_uuid: - self._load(collection_uuid) + self._index = None + self._collection_uuid = None + self._id_to_label = {} + self._label_to_id = {} + def delete_from_index(self, ids): if self._index is not None: - for uuid in uuids: - self._index.mark_deleted(self._uuid_to_id[uuid.hex]) - del self._id_to_uuid[self._uuid_to_id[uuid.hex]] - del self._uuid_to_id[uuid.hex] + for id in ids: + label = self._id_to_label[hexid(id)] + self._index.mark_deleted(label) + del self._label_to_id[label] + del self._id_to_label[hexid(id)] self._save() @@ -125,99 +162,66 @@ def _save(self): if self._index is None: return - self._index.save_index(f"{self._save_folder}/index_{self._collection_uuid}.bin") + self._index.save_index(f"{self._save_folder}/index_{self._id}.bin") # pickle the mappers - with open(f"{self._save_folder}/id_to_uuid_{self._collection_uuid}.pkl", "wb") as f: - pickle.dump(self._id_to_uuid, f, pickle.HIGHEST_PROTOCOL) - with open(f"{self._save_folder}/uuid_to_id_{self._collection_uuid}.pkl", "wb") as f: - pickle.dump(self._uuid_to_id, f, pickle.HIGHEST_PROTOCOL) - with open(f"{self._save_folder}/index_metadata_{self._collection_uuid}.pkl", "wb") as f: + # Use old filenames for backwards compatibility + with open(f"{self._save_folder}/id_to_uuid_{self._id}.pkl", "wb") as f: + pickle.dump(self._label_to_id, f, pickle.HIGHEST_PROTOCOL) + with open(f"{self._save_folder}/uuid_to_id_{self._id}.pkl", "wb") as f: + pickle.dump(self._id_to_label, f, pickle.HIGHEST_PROTOCOL) + with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "wb") as f: pickle.dump(self._index_metadata, f, pickle.HIGHEST_PROTOCOL) logger.debug(f"Index saved to {self._save_folder}/index.bin") - def load_if_not_loaded(self, collection_uuid): - if self._collection_uuid != collection_uuid: - self._load(collection_uuid) + def _exists(self): + return - def _load(self, collection_uuid): - # if we are calling load, we clearly need a different index than the one we have - self._index = None - - # unpickle the mappers - try: - with open(f"{self._save_folder}/id_to_uuid_{collection_uuid}.pkl", "rb") as f: - self._id_to_uuid = pickle.load(f) - with open(f"{self._save_folder}/uuid_to_id_{collection_uuid}.pkl", "rb") as f: - self._uuid_to_id = pickle.load(f) - with open(f"{self._save_folder}/index_metadata_{collection_uuid}.pkl", "rb") as f: - self._index_metadata = pickle.load(f) - p = hnswlib.Index(space="l2", dim=self._index_metadata["dimensionality"]) - self._index = p - self._index.load_index( - f"{self._save_folder}/index_{collection_uuid}.bin", - max_elements=self._index_metadata["elements"], - ) - - self._collection_uuid = collection_uuid - except: - logger.debug("Index not found") + def _load(self): - def has_index(self, collection_uuid): - return os.path.isfile(f"{self._save_folder}/index_{collection_uuid}.bin") + if not os.path.exists(f"{self._save_folder}/index_{self._id}.bin"): + return - def get_nearest_neighbors(self, collection_uuid, query, k, uuids=None): - if self._collection_uuid != collection_uuid: - self._load(collection_uuid) + # unpickle the mappers + with open(f"{self._save_folder}/id_to_uuid_{self._id}.pkl", "rb") as f: + self._id_to_uuid = pickle.load(f) + with open(f"{self._save_folder}/uuid_to_id_{self._id}.pkl", "rb") as f: + self._uuid_to_id = pickle.load(f) + with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "rb") as f: + self._index_metadata = pickle.load(f) + + p = hnswlib.Index(space=self._params.space, dim=self._index_metadata["dimensionality"]) + self._index = p + self._index.load_index( + f"{self._save_folder}/index_{self._id}.bin", + max_elements=self._index_metadata["elements"], + ) + self._index.set_ef(self._params.search_ef) + self._index.set_num_threads(self._params.num_threads) + + def get_nearest_neighbors(self, query, k, ids=None): if self._index is None: raise NoIndexException("Index not found, please create an instance before querying") s2 = time.time() # get ids from uuids as a set, if they are available - ids = {} - if uuids is not None: - ids = {self._uuid_to_id[uuid.hex] for uuid in uuids} - if len(ids) < k: + labels = {} + if ids is not None: + labels = {self._id_to_label[hexid(id)] for id in ids} + if len(labels) < k: k = len(ids) filter_function = None - if len(ids) != 0: + if len(labels) != 0: filter_function = lambda id: id in ids logger.debug(f"time to pre process our knn query: {time.time() - s2}") s3 = time.time() - database_ids, distances = self._index.knn_query(query, k=k, filter=filter_function) + database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function) logger.debug(f"time to run knn query: {time.time() - s3}") - uuids = [[self._id_to_uuid[id] for id in ids] for ids in database_ids] - return uuids, distances - - def reset(self): - self._id_to_uuid = {} - self._uuid_to_id = {} - self._index = None - self._collection_uuid = None - - if os.path.exists(f"{self._save_folder}"): - for f in os.listdir(f"{self._save_folder}"): - os.remove(os.path.join(f"{self._save_folder}", f)) - # recreate the directory - if not os.path.exists(f"{self._save_folder}"): - os.makedirs(f"{self._save_folder}") - - def delete_index(self, uuid): - uuid = str(uuid) - if self._collection_uuid == uuid: - self._index = None - self._collection_uuid = None - self._index_metadata = None - self._id_to_uuid = {} - self._uuid_to_id = {} - - if os.path.exists(f"{self._save_folder}"): - for f in os.listdir(f"{self._save_folder}"): - if uuid in f: - os.remove(os.path.join(f"{self._save_folder}", f)) + ids = [[self._label_to_id[label] for label in labels] for label in database_labels] + return ids, distances From 28a7369c83e0e9811a420d59e8eff191c15d8941 Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 13:07:26 -0400 Subject: [PATCH 2/9] update index usage --- chromadb/db/clickhouse.py | 64 +++++++++++++++++++----------------- chromadb/db/index/hnswlib.py | 36 ++++++++++++++++---- 2 files changed, 64 insertions(+), 36 deletions(-) diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py index 0bdf7464f0c..94fa858c290 100644 --- a/chromadb/db/clickhouse.py +++ b/chromadb/db/clickhouse.py @@ -1,6 +1,6 @@ from chromadb.api.types import Documents, Embeddings, IDs, Metadatas, Where, WhereDocument from chromadb.db import DB -from chromadb.db.index.hnswlib import Hnswlib +from chromadb.db.index.hnswlib import Hnswlib, index_exists, delete_all_indexes from chromadb.errors import ( NoDatapointsException, InvalidDimensionException, @@ -51,7 +51,6 @@ class Clickhouse(DB): # def __init__(self, settings): self._conn = None - self._idx = Hnswlib(settings) self._settings = settings def _init_conn(self): @@ -81,6 +80,24 @@ def _create_table_embeddings(self, conn): ) ENGINE = MergeTree() ORDER BY collection_uuid""" ) + index_cache = {} + + def _index(self, collection_id): + """Retrieve an HNSW index instance for the given collection""" + + if collection_id not in self.index_cache: + collection_metadata = self.get_collection(collection_id)[2] + index = Hnswlib(self._settings, collection_id, collection_metadata) + self.index_cache[collection_id] = index + + return self.index_cache[collection_id] + + def _delete_index(self, collection_id): + """Delete an index from the cache""" + index = self._index(collection_id) + index.delete() + del self.index_cache[collection_id] + # # UTILITY METHODS # @@ -195,7 +212,7 @@ def delete_collection(self, name: str): """ ) - self._idx.delete_index(collection_uuid) + self._delete_index(collection_uuid) # # ITEM METHODS @@ -271,9 +288,8 @@ def update( # Update the index if embeddings is not None: - update_uuids = [x[1] for x in existing_items] - self._idx.delete_from_index(collection_uuid, update_uuids) - self._idx.add_incremental(collection_uuid, update_uuids, embeddings) + index = self._index(collection_uuid) + index.add(ids, embeddings, update=True) def _get(self, where={}, columns: Optional[List] = None): select_columns = db_schema_to_keys() if columns is None else columns @@ -437,7 +453,8 @@ def delete( deleted_uuids = self._delete(where_str) - self._idx.delete_from_index(collection_uuid, deleted_uuids) + index = self._index(collection_uuid) + index.delete_from_index(deleted_uuids) return deleted_uuids @@ -476,21 +493,6 @@ def get_nearest_neighbors( if collection_name is not None: collection_uuid = self.get_collection_uuid_from_name(collection_name) - self._idx.load_if_not_loaded(collection_uuid) - - idx_metadata = self._idx.get_metadata() - # Check query embeddings dimensionality - if idx_metadata["dimensionality"] != len(embeddings[0]): - raise InvalidDimensionException( - f"Query embeddings dimensionality {len(embeddings[0])} does not match index dimensionality {idx_metadata['dimensionality']}" - ) - - # Check number of requested results - if n_results > idx_metadata["elements"]: - raise NotEnoughElementsException( - f"Number of requested results {n_results} cannot be greater than number of elements in index {idx_metadata['elements']}" - ) - if len(where) != 0 or len(where_document) != 0: results = self.get( collection_uuid=collection_uuid, where=where, where_document=where_document @@ -504,9 +506,9 @@ def get_nearest_neighbors( ) else: ids = None - uuids, distances = self._idx.get_nearest_neighbors( - collection_uuid, embeddings, n_results, ids - ) + + index = self._index(collection_uuid) + uuids, distances = index.get_nearest_neighbors(collection_uuid, embeddings, n_results, ids) return uuids, distances @@ -523,13 +525,15 @@ def create_index(self, collection_uuid: str): uuids = [x[1] for x in get] embeddings = [x[2] for x in get] - self._idx.run(collection_uuid, uuids, embeddings) + index = self._index(collection_uuid) + index.add(collection_uuid, uuids, embeddings) def add_incremental(self, collection_uuid, uuids, embeddings): - self._idx.add_incremental(collection_uuid, uuids, embeddings) + index = self._index(collection_uuid) + index.add(collection_uuid, uuids, embeddings) def has_index(self, collection_uuid: str): - return self._idx.has_index(collection_uuid) + return index_exists(self._settings, collection_uuid) def reset(self): conn = self._get_conn() @@ -538,8 +542,8 @@ def reset(self): self._create_table_collections(conn) self._create_table_embeddings(conn) - self._idx.reset() - self._idx = Hnswlib(self._settings) + delete_all_indexes(self._settings) + self.index_cache = {} def raw_sql(self, sql): return self._get_conn().query(sql).result_rows diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index a4b68cd24fe..7a49f717b06 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -7,7 +7,7 @@ import hnswlib import numpy as np from chromadb.db.index import Index -from chromadb.errors import NoIndexException, InvalidDimensionException +from chromadb.errors import NoIndexException, InvalidDimensionException, NotEnoughElementsException import logging import re from uuid import UUID @@ -36,6 +36,8 @@ class HnswParams: def __init__(self, metadata): + metadata = metadata or {} + for param, value in metadata.items(): if param.startswith("hnsw:"): if param not in valid_params: @@ -56,6 +58,15 @@ def hexid(id): return id.hex if isinstance(id, UUID) else id +def index_exists(settings, id): + return os.path.exists(f"{settings.persist_directory}/index/index_{id}.bin") + + +def delete_all_indexes(settings): + for file in os.listdir(f"{settings.persist_directory}/index"): + os.remove(f"{settings.persist_directory}/index/{file}") + + class Hnswlib(Index): _id: str _index: hnswlib.Index @@ -95,6 +106,15 @@ def _init_index(self, dimensionality): } self._save() + def _check_dimensionality(self, data): + """Assert that the given data matches the index dimensionality""" + dim = len(data[0]) + idx_dim = self._index.get_dim() + if dim != idx_dim: + raise InvalidDimensionException( + f"Dimensionality of ({dim}) does not match index dimensionality ({idx_dim})" + ) + def add(self, ids, embeddings, update=False): """Add or update embeddings to the index""" @@ -104,11 +124,7 @@ def add(self, ids, embeddings, update=False): self._init_index(dim) # Check dimensionality - idx_dim = self._index.get_dim() - if dim != idx_dim: - raise InvalidDimensionException( - f"Dimensionality of new embeddings ({dim}) does not match index dimensionality ({idx_dim})" - ) + self._check_dimensionality(embeddings) labels = [] for id in ids: @@ -205,6 +221,14 @@ def get_nearest_neighbors(self, query, k, ids=None): if self._index is None: raise NoIndexException("Index not found, please create an instance before querying") + if k > self._index_metadata["elements"]: + raise NotEnoughElementsException( + f"Number of requested results {k} cannot be greater than number of elements in index {self._index_metadata['elements']}" + ) + + # Check dimensionality + self._check_dimensionality(query) + s2 = time.time() # get ids from uuids as a set, if they are available labels = {} From c5d4c84eeb4a2505afff924f4ab80b5aa4847743 Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 13:12:45 -0400 Subject: [PATCH 3/9] udpate index usage in duckdb.py --- chromadb/db/clickhouse.py | 7 +++++-- chromadb/db/duckdb.py | 10 ++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py index 94fa858c290..14b2344ce17 100644 --- a/chromadb/db/clickhouse.py +++ b/chromadb/db/clickhouse.py @@ -535,6 +535,10 @@ def add_incremental(self, collection_uuid, uuids, embeddings): def has_index(self, collection_uuid: str): return index_exists(self._settings, collection_uuid) + def reset_indexes(self): + delete_all_indexes(self._settings) + self.index_cache = {} + def reset(self): conn = self._get_conn() conn.command("DROP TABLE collections") @@ -542,8 +546,7 @@ def reset(self): self._create_table_collections(conn) self._create_table_embeddings(conn) - delete_all_indexes(self._settings) - self.index_cache = {} + self.reset_indexes() def raw_sql(self, sql): return self._get_conn().query(sql).result_rows diff --git a/chromadb/db/duckdb.py b/chromadb/db/duckdb.py index 79d748eecc4..1788e669570 100644 --- a/chromadb/db/duckdb.py +++ b/chromadb/db/duckdb.py @@ -1,6 +1,5 @@ from chromadb.api.types import Documents, Embeddings, IDs, Metadatas from chromadb.db import DB -from chromadb.db.index.hnswlib import Hnswlib from chromadb.db.clickhouse import ( Clickhouse, db_array_schema_to_clickhouse_schema, @@ -49,7 +48,6 @@ def __init__(self, settings): self._conn = duckdb.connect() self._create_table_collections() self._create_table_embeddings() - self._idx = Hnswlib(settings) self._settings = settings # https://duckdb.org/docs/extensions/overview @@ -116,7 +114,8 @@ def delete_collection(self, name: str): self._conn.execute( f"""DELETE FROM embeddings WHERE collection_uuid = ?""", [collection_uuid] ) - self._idx.delete_index(collection_uuid) + + self._delete_index(collection_uuid) self._conn.execute(f"""DELETE FROM collections WHERE name = ?""", [name]) def update_collection( @@ -350,12 +349,11 @@ def reset(self): self._create_table_collections() self._create_table_embeddings() - self._idx.reset() - self._idx = Hnswlib(self._settings) + self.reset_indexes() def __del__(self): logger.info("Exiting: Cleaning up .chroma directory") - self._idx.reset() + self.reset_indexes() def persist(self): raise NotImplementedError( From 846719f80e9e7b683adcc750b7b5e65fb1619103 Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 13:14:03 -0400 Subject: [PATCH 4/9] clean up unused method --- chromadb/db/__init__.py | 4 ---- chromadb/db/clickhouse.py | 5 +---- chromadb/db/index/hnswlib.py | 4 ---- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/chromadb/db/__init__.py b/chromadb/db/__init__.py index 7f6bad5106b..1a722a49f18 100644 --- a/chromadb/db/__init__.py +++ b/chromadb/db/__init__.py @@ -115,10 +115,6 @@ def raw_sql(self, raw_sql): def create_index(self, collection_uuid: str): pass - @abstractmethod - def has_index(self, collection_name): - pass - @abstractmethod def persist(self): pass diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py index 14b2344ce17..7c608e30f13 100644 --- a/chromadb/db/clickhouse.py +++ b/chromadb/db/clickhouse.py @@ -1,6 +1,6 @@ from chromadb.api.types import Documents, Embeddings, IDs, Metadatas, Where, WhereDocument from chromadb.db import DB -from chromadb.db.index.hnswlib import Hnswlib, index_exists, delete_all_indexes +from chromadb.db.index.hnswlib import Hnswlib, delete_all_indexes from chromadb.errors import ( NoDatapointsException, InvalidDimensionException, @@ -532,9 +532,6 @@ def add_incremental(self, collection_uuid, uuids, embeddings): index = self._index(collection_uuid) index.add(collection_uuid, uuids, embeddings) - def has_index(self, collection_uuid: str): - return index_exists(self._settings, collection_uuid) - def reset_indexes(self): delete_all_indexes(self._settings) self.index_cache = {} diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index 7a49f717b06..31428b77c39 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -58,10 +58,6 @@ def hexid(id): return id.hex if isinstance(id, UUID) else id -def index_exists(settings, id): - return os.path.exists(f"{settings.persist_directory}/index/index_{id}.bin") - - def delete_all_indexes(settings): for file in os.listdir(f"{settings.persist_directory}/index"): os.remove(f"{settings.persist_directory}/index/{file}") From 3c467c52b8ba5d3b397155211c522f97c02fb4b2 Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 13:39:30 -0400 Subject: [PATCH 5/9] bug fixes --- chromadb/db/clickhouse.py | 24 +++++++++++++++++++----- chromadb/db/duckdb.py | 4 ++++ chromadb/db/index/hnswlib.py | 11 ++++++++--- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py index 7c608e30f13..698b00f8214 100644 --- a/chromadb/db/clickhouse.py +++ b/chromadb/db/clickhouse.py @@ -86,8 +86,9 @@ def _index(self, collection_id): """Retrieve an HNSW index instance for the given collection""" if collection_id not in self.index_cache: - collection_metadata = self.get_collection(collection_id)[2] - index = Hnswlib(self._settings, collection_id, collection_metadata) + coll = self.get_collection_by_id(collection_id) + collection_metadata = coll[2] + index = Hnswlib(collection_id, self._settings, collection_metadata) self.index_cache[collection_id] = index return self.index_cache[collection_id] @@ -173,6 +174,19 @@ def get_collection(self, name: str): # json.loads the metadata return [[x[0], x[1], json.loads(x[2])] for x in res] + def get_collection_by_id(self, collection_uuid: str): + res = ( + self._get_conn() + .query( + f""" + SELECT * FROM collections WHERE uuid = '{collection_uuid}' + """ + ) + .result_rows + ) + # json.loads the metadata + return [[x[0], x[1], json.loads(x[2])] for x in res][0] + def list_collections(self) -> Sequence: res = self._get_conn().query(f"""SELECT * FROM collections""").result_rows return [[x[0], x[1], json.loads(x[2])] for x in res] @@ -508,7 +522,7 @@ def get_nearest_neighbors( ids = None index = self._index(collection_uuid) - uuids, distances = index.get_nearest_neighbors(collection_uuid, embeddings, n_results, ids) + uuids, distances = index.get_nearest_neighbors(embeddings, n_results, ids) return uuids, distances @@ -526,11 +540,11 @@ def create_index(self, collection_uuid: str): embeddings = [x[2] for x in get] index = self._index(collection_uuid) - index.add(collection_uuid, uuids, embeddings) + index.add(uuids, embeddings) def add_incremental(self, collection_uuid, uuids, embeddings): index = self._index(collection_uuid) - index.add(collection_uuid, uuids, embeddings) + index.add(uuids, embeddings) def reset_indexes(self): delete_all_indexes(self._settings) diff --git a/chromadb/db/duckdb.py b/chromadb/db/duckdb.py index 1788e669570..9e69fbff4af 100644 --- a/chromadb/db/duckdb.py +++ b/chromadb/db/duckdb.py @@ -105,6 +105,10 @@ def get_collection(self, name: str) -> Sequence: # json.loads the metadata return [[x[0], x[1], json.loads(x[2])] for x in res] + def get_collection_by_id(self, uuid: str) -> Sequence: + res = self._conn.execute(f"""SELECT * FROM collections WHERE uuid = ?""", [uuid]).fetchone() + return [res[0], res[1], json.loads(res[2])] + def list_collections(self) -> Sequence: res = self._conn.execute(f"""SELECT * FROM collections""").fetchall() return [[x[0], x[1], json.loads(x[2])] for x in res] diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index 31428b77c39..2ee377febfc 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -59,8 +59,9 @@ def hexid(id): def delete_all_indexes(settings): - for file in os.listdir(f"{settings.persist_directory}/index"): - os.remove(f"{settings.persist_directory}/index/{file}") + if os.path.exists(f"{settings.persist_directory}/index"): + for file in os.listdir(f"{settings.persist_directory}/index"): + os.remove(f"{settings.persist_directory}/index/{file}") class Hnswlib(Index): @@ -77,6 +78,7 @@ def __init__(self, id, settings, metadata): self._save_folder = settings.persist_directory + "/index" self._params = HnswParams(metadata) self._id = id + self._index = None self._load() @@ -105,7 +107,7 @@ def _init_index(self, dimensionality): def _check_dimensionality(self, data): """Assert that the given data matches the index dimensionality""" dim = len(data[0]) - idx_dim = self._index.get_dim() + idx_dim = self._index.dim if dim != idx_dim: raise InvalidDimensionException( f"Dimensionality of ({dim}) does not match index dimensionality ({idx_dim})" @@ -134,11 +136,14 @@ def add(self, ids, embeddings, update=False): next_label = self._index_metadata["elements"] self._id_to_label[hexid(id)] = next_label self._label_to_id[next_label] = id + labels.append(next_label) if self._index_metadata["elements"] > self._index.get_max_elements(): new_size = min(self._index_metadata["elements"] * self._params.resize_factor, 1000) self._index.resize_index(new_size) + print("embeddings: ", embeddings) + print("labels: ", labels) self._index.add_items(embeddings, labels) self._save() From 40c065de32769a9850ada1f02adb250073e199ee Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 16:11:31 -0400 Subject: [PATCH 6/9] fix bugs --- chromadb/db/index/hnswlib.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index 2ee377febfc..c4380f3d987 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -142,8 +142,6 @@ def add(self, ids, embeddings, update=False): new_size = min(self._index_metadata["elements"] * self._params.resize_factor, 1000) self._index.resize_index(new_size) - print("embeddings: ", embeddings) - print("labels: ", labels) self._index.add_items(embeddings, labels) self._save() @@ -236,11 +234,11 @@ def get_nearest_neighbors(self, query, k, ids=None): if ids is not None: labels = {self._id_to_label[hexid(id)] for id in ids} if len(labels) < k: - k = len(ids) + k = len(labels) filter_function = None if len(labels) != 0: - filter_function = lambda id: id in ids + filter_function = lambda label: label in labels logger.debug(f"time to pre process our knn query: {time.time() - s2}") @@ -248,5 +246,5 @@ def get_nearest_neighbors(self, query, k, ids=None): database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function) logger.debug(f"time to run knn query: {time.time() - s3}") - ids = [[self._label_to_id[label] for label in labels] for label in database_labels] + ids = [[self._label_to_id[label] for label in labels] for labels in database_labels] return ids, distances From 596caad20d8b2ac729a05c906ad4e5175fa77111 Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 16:18:28 -0400 Subject: [PATCH 7/9] fix validation ordering --- chromadb/db/index/hnswlib.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index c4380f3d987..d4f298f1ea4 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -220,14 +220,14 @@ def get_nearest_neighbors(self, query, k, ids=None): if self._index is None: raise NoIndexException("Index not found, please create an instance before querying") + # Check dimensionality + self._check_dimensionality(query) + if k > self._index_metadata["elements"]: raise NotEnoughElementsException( f"Number of requested results {k} cannot be greater than number of elements in index {self._index_metadata['elements']}" ) - # Check dimensionality - self._check_dimensionality(query) - s2 = time.time() # get ids from uuids as a set, if they are available labels = {} From ad9d8697db68a58063bff640bbc4d7dbb74576ce Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Tue, 28 Mar 2023 17:27:16 -0400 Subject: [PATCH 8/9] persistence test --- chromadb/db/index/hnswlib.py | 3 ++ chromadb/test/test_api.py | 98 ++++++++++++++++++++++++++++++++---- 2 files changed, 92 insertions(+), 9 deletions(-) diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index d4f298f1ea4..c99f12321dd 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -38,6 +38,9 @@ def __init__(self, metadata): metadata = metadata or {} + # Convert all values to strings for future compatibility. + metadata = {k: str(v) for k, v in metadata.items()} + for param, value in metadata.items(): if param.startswith("hnsw:"): if param not in valid_params: diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 1884c2ac3af..9226b6c7365 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -687,7 +687,7 @@ def test_metadata_validation_add(api_fixture, request): api.reset() collection = api.create_collection("test_metadata_validation") - with pytest.raises(ValueError, match='metadata'): + with pytest.raises(ValueError, match="metadata"): collection.add(**bad_metadata_records) @@ -698,7 +698,7 @@ def test_metadata_validation_update(api_fixture, request): api.reset() collection = api.create_collection("test_metadata_validation") collection.add(**metadata_records) - with pytest.raises(ValueError, match='metadata'): + with pytest.raises(ValueError, match="metadata"): collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}}) @@ -708,7 +708,7 @@ def test_where_validation_get(api_fixture, request): api.reset() collection = api.create_collection("test_where_validation") - with pytest.raises(ValueError, match='where'): + with pytest.raises(ValueError, match="where"): collection.get(where={"value": {"nested": "5"}}) @@ -718,7 +718,7 @@ def test_where_validation_query(api_fixture, request): api.reset() collection = api.create_collection("test_where_validation") - with pytest.raises(ValueError, match='where'): + with pytest.raises(ValueError, match="where"): collection.query(query_embeddings=[0, 0, 0], where={"value": {"nested": "5"}}) @@ -908,13 +908,13 @@ def test_query_document_valid_operators(api_fixture, request): api.reset() collection = api.create_collection("test_where_valid_operators") collection.add(**operator_records) - with pytest.raises(ValueError, match='where document'): + with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$lt": {"$nested": 2}}) - with pytest.raises(ValueError, match='where document'): + with pytest.raises(ValueError, match="where document"): collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2}) - with pytest.raises(ValueError, match='where document'): + with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$contains": []}) # Test invalid $and, $or @@ -1177,10 +1177,10 @@ def test_get_include(api_fixture, request): assert items["embeddings"] == None assert items["ids"][0] == "id1" - with pytest.raises(ValueError, match='include'): + with pytest.raises(ValueError, match="include"): items = collection.get(include=["metadatas", "undefined"]) - with pytest.raises(ValueError, match='include'): + with pytest.raises(ValueError, match="include"): items = collection.get(include=None) @@ -1224,3 +1224,83 @@ def test_invalid_id(api_fixture, request): with pytest.raises(ValueError) as e: collection.delete(ids=["valid", 0]) assert "ID" in str(e.value) + + +@pytest.mark.parametrize("api_fixture", test_apis) +def test_index_params(api_fixture, request): + api = request.getfixturevalue(api_fixture.__name__) + + # first standard add + api.reset() + collection = api.create_collection(name="test_index_params") + collection.add(**records) + items = collection.query( + query_embeddings=[0.6, 1.12, 1.6], + n_results=1, + ) + assert items["distances"][0][0] > 4 + + # cosine + api.reset() + collection = api.create_collection( + name="test_index_params", + metadata={"hnsw:space": "cosine", "hnsw:construction_ef": 20, "hnsw:M": 5}, + ) + collection.add(**records) + items = collection.query( + query_embeddings=[0.6, 1.12, 1.6], + n_results=1, + ) + assert items["distances"][0][0] > 0 + assert items["distances"][0][0] < 1 + + # ip + api.reset() + collection = api.create_collection(name="test_index_params", metadata={"hnsw:space": "ip"}) + collection.add(**records) + items = collection.query( + query_embeddings=[0.6, 1.12, 1.6], + n_results=1, + ) + assert items["distances"][0][0] < -5 + + +@pytest.mark.parametrize("api_fixture", test_apis) +def test_invalid_index_params(api_fixture, request): + + api = request.getfixturevalue(api_fixture.__name__) + api.reset() + + with pytest.raises(Exception): + collection = api.create_collection( + name="test_index_params", metadata={"hnsw:foobar": "blarg"} + ) + collection.add(**records) + + with pytest.raises(Exception): + collection = api.create_collection( + name="test_index_params", metadata={"hnsw:space": "foobar"} + ) + collection.add(**records) + + +@pytest.mark.parametrize("api_fixture", [local_persist_api]) +def test_persist_index_loading(api_fixture, request): + api = request.getfixturevalue("local_persist_api") + api.reset() + collection = api.create_collection("test") + collection.add(ids="id1", documents="hello") + + api.persist() + del api + + api2 = request.getfixturevalue("local_persist_api_cache_bust") + collection = api2.get_collection("test") + + nn = collection.query( + query_texts="hello", + n_results=1, + include=["embeddings", "documents", "metadatas", "distances"], + ) + for key in nn.keys(): + assert len(nn[key]) == 1 From 55e578c1fec2efc4409035b9ba57b1f428b380fd Mon Sep 17 00:00:00 2001 From: atroyn Date: Tue, 28 Mar 2023 16:13:50 -0700 Subject: [PATCH 9/9] Added params persistence test --- chromadb/test/test_api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 9226b6c7365..e81c0321966 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -1285,10 +1285,10 @@ def test_invalid_index_params(api_fixture, request): @pytest.mark.parametrize("api_fixture", [local_persist_api]) -def test_persist_index_loading(api_fixture, request): +def test_persist_index_loading_params(api_fixture, request): api = request.getfixturevalue("local_persist_api") api.reset() - collection = api.create_collection("test") + collection = api.create_collection("test", metadata={"hnsw:space": "ip"}) collection.add(ids="id1", documents="hello") api.persist() @@ -1297,6 +1297,8 @@ def test_persist_index_loading(api_fixture, request): api2 = request.getfixturevalue("local_persist_api_cache_bust") collection = api2.get_collection("test") + assert collection.metadata["hnsw:space"] == "ip" + nn = collection.query( query_texts="hello", n_results=1,