Skip to content

Commit

Permalink
Support parameters for HNSW indexes using metadata (#245)
Browse files Browse the repository at this point in the history
* refactor hnswlib class

* update index usage

* update index usage in duckdb.py

* clean up unused method

* bug fixes

* fix bugs

* fix validation ordering

* persistence test

* Added params persistence test

---------

Co-authored-by: atroyn <[email protected]>
  • Loading branch information
levand and atroyn authored Mar 28, 2023
1 parent 90ae684 commit 5cef7f9
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 210 deletions.
4 changes: 0 additions & 4 deletions chromadb/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ def raw_sql(self, raw_sql):
def create_index(self, collection_uuid: str):
pass

@abstractmethod
def has_index(self, collection_name):
pass

@abstractmethod
def persist(self):
pass
80 changes: 49 additions & 31 deletions chromadb/db/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from chromadb.api.types import Documents, Embeddings, IDs, Metadatas, Where, WhereDocument
from chromadb.db import DB
from chromadb.db.index.hnswlib import Hnswlib
from chromadb.db.index.hnswlib import Hnswlib, delete_all_indexes
from chromadb.errors import (
NoDatapointsException,
InvalidDimensionException,
Expand Down Expand Up @@ -51,7 +51,6 @@ class Clickhouse(DB):
#
def __init__(self, settings):
self._conn = None
self._idx = Hnswlib(settings)
self._settings = settings

def _init_conn(self):
Expand Down Expand Up @@ -81,6 +80,25 @@ def _create_table_embeddings(self, conn):
) ENGINE = MergeTree() ORDER BY collection_uuid"""
)

index_cache = {}

def _index(self, collection_id):
    """Return the HNSW index for ``collection_id``, building it on first use.

    On a cache miss the collection row is fetched with
    ``get_collection_by_id`` and its third field (the collection metadata)
    is passed to ``Hnswlib`` along with the id and the settings. The new
    instance is stored in ``index_cache`` so later calls reuse it.
    """
    # Fast path: an index for this collection has already been created.
    if collection_id in self.index_cache:
        return self.index_cache[collection_id]

    # Slow path: the metadata is needed to construct the index.
    collection = self.get_collection_by_id(collection_id)
    metadata = collection[2]
    self.index_cache[collection_id] = Hnswlib(collection_id, self._settings, metadata)
    return self.index_cache[collection_id]

def _delete_index(self, collection_id):
    """Delete the collection's index and remove it from the cache.

    The index is obtained through ``_index`` first, so ``delete()`` is
    called on it even when it was not cached beforehand.
    """
    self._index(collection_id).delete()
    # No default: a missing key raises KeyError, as ``del`` would.
    self.index_cache.pop(collection_id)

#
# UTILITY METHODS
#
Expand Down Expand Up @@ -156,6 +174,19 @@ def get_collection(self, name: str):
# json.loads the metadata
return [[x[0], x[1], json.loads(x[2])] for x in res]

def get_collection_by_id(self, collection_uuid: str):
    """Return the collection row whose uuid equals ``collection_uuid``.

    The row is returned as a three-element list; the third element (the
    collection metadata) is decoded from its JSON text.

    Raises:
        IndexError: if no collection has the given uuid.
    """
    # NOTE(review): the uuid is interpolated straight into the SQL text.
    # If callers can pass untrusted values, switch to the client's
    # parameter binding — confirm what the connection object supports.
    rows = (
        self._get_conn()
        .query(
            f"""
            SELECT * FROM collections WHERE uuid = '{collection_uuid}'
            """
        )
        .result_rows
    )

    # Same exception type an empty result produced before, but with a
    # message that names the missing collection.
    if not rows:
        raise IndexError(f"No collection found with uuid {collection_uuid}")

    row = rows[0]
    # json.loads the metadata
    return [row[0], row[1], json.loads(row[2])]

def list_collections(self) -> Sequence:
    """Return every row of the ``collections`` table as a three-element list.

    The third column of each row is decoded from JSON text; the first two
    are passed through unchanged.
    """
    rows = self._get_conn().query("""SELECT * FROM collections""").result_rows
    collections = []
    for row in rows:
        collections.append([row[0], row[1], json.loads(row[2])])
    return collections
Expand Down Expand Up @@ -195,7 +226,7 @@ def delete_collection(self, name: str):
"""
)

self._idx.delete_index(collection_uuid)
self._delete_index(collection_uuid)

#
# ITEM METHODS
Expand Down Expand Up @@ -271,9 +302,8 @@ def update(

# Update the index
if embeddings is not None:
update_uuids = [x[1] for x in existing_items]
self._idx.delete_from_index(collection_uuid, update_uuids)
self._idx.add_incremental(collection_uuid, update_uuids, embeddings)
index = self._index(collection_uuid)
index.add(ids, embeddings, update=True)

def _get(self, where={}, columns: Optional[List] = None):
select_columns = db_schema_to_keys() if columns is None else columns
Expand Down Expand Up @@ -437,7 +467,8 @@ def delete(

deleted_uuids = self._delete(where_str)

self._idx.delete_from_index(collection_uuid, deleted_uuids)
index = self._index(collection_uuid)
index.delete_from_index(deleted_uuids)

return deleted_uuids

Expand Down Expand Up @@ -476,21 +507,6 @@ def get_nearest_neighbors(
if collection_name is not None:
collection_uuid = self.get_collection_uuid_from_name(collection_name)

self._idx.load_if_not_loaded(collection_uuid)

idx_metadata = self._idx.get_metadata()
# Check query embeddings dimensionality
if idx_metadata["dimensionality"] != len(embeddings[0]):
raise InvalidDimensionException(
f"Query embeddings dimensionality {len(embeddings[0])} does not match index dimensionality {idx_metadata['dimensionality']}"
)

# Check number of requested results
if n_results > idx_metadata["elements"]:
raise NotEnoughElementsException(
f"Number of requested results {n_results} cannot be greater than number of elements in index {idx_metadata['elements']}"
)

if len(where) != 0 or len(where_document) != 0:
results = self.get(
collection_uuid=collection_uuid, where=where, where_document=where_document
Expand All @@ -504,9 +520,9 @@ def get_nearest_neighbors(
)
else:
ids = None
uuids, distances = self._idx.get_nearest_neighbors(
collection_uuid, embeddings, n_results, ids
)

index = self._index(collection_uuid)
uuids, distances = index.get_nearest_neighbors(embeddings, n_results, ids)

return uuids, distances

Expand All @@ -523,13 +539,16 @@ def create_index(self, collection_uuid: str):
uuids = [x[1] for x in get]
embeddings = [x[2] for x in get]

self._idx.run(collection_uuid, uuids, embeddings)
index = self._index(collection_uuid)
index.add(uuids, embeddings)

def add_incremental(self, collection_uuid, uuids, embeddings):
self._idx.add_incremental(collection_uuid, uuids, embeddings)
index = self._index(collection_uuid)
index.add(uuids, embeddings)

def has_index(self, collection_uuid: str):
return self._idx.has_index(collection_uuid)
def reset_indexes(self):
delete_all_indexes(self._settings)
self.index_cache = {}

def reset(self):
conn = self._get_conn()
Expand All @@ -538,8 +557,7 @@ def reset(self):
self._create_table_collections(conn)
self._create_table_embeddings(conn)

self._idx.reset()
self._idx = Hnswlib(self._settings)
self.reset_indexes()

def raw_sql(self, sql):
    """Run ``sql`` on the connection and return the result's ``result_rows``."""
    connection = self._get_conn()
    result = connection.query(sql)
    return result.result_rows
14 changes: 8 additions & 6 deletions chromadb/db/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from chromadb.api.types import Documents, Embeddings, IDs, Metadatas
from chromadb.db import DB
from chromadb.db.index.hnswlib import Hnswlib
from chromadb.db.clickhouse import (
Clickhouse,
db_array_schema_to_clickhouse_schema,
Expand Down Expand Up @@ -48,7 +47,6 @@ def __init__(self, settings):
self._conn = duckdb.connect()
self._create_table_collections()
self._create_table_embeddings()
self._idx = Hnswlib(settings)
self._settings = settings

# https://duckdb.org/docs/extensions/overview
Expand Down Expand Up @@ -106,6 +104,10 @@ def get_collection(self, name: str) -> Sequence:
# json.loads the metadata
return [[x[0], x[1], json.loads(x[2])] for x in res]

def get_collection_by_id(self, uuid: str) -> Sequence:
    """Return the collection row whose uuid equals ``uuid``.

    The row is returned as a three-element list with its third column
    decoded from JSON text. The uuid is bound as a query parameter.
    """
    # NOTE(review): presumably fetchone() yields None when nothing matches,
    # in which case the indexing below fails — confirm and handle if needed.
    cursor = self._conn.execute("""SELECT * FROM collections WHERE uuid = ?""", [uuid])
    row = cursor.fetchone()
    metadata = json.loads(row[2])
    return [row[0], row[1], metadata]

def list_collections(self) -> Sequence:
    """Return every row of the ``collections`` table as a three-element list.

    The third column of each row is decoded from JSON text; the first two
    are passed through unchanged.
    """
    rows = self._conn.execute("""SELECT * FROM collections""").fetchall()
    collections = []
    for row in rows:
        collections.append([row[0], row[1], json.loads(row[2])])
    return collections
Expand All @@ -115,7 +117,8 @@ def delete_collection(self, name: str):
self._conn.execute(
f"""DELETE FROM embeddings WHERE collection_uuid = ?""", [collection_uuid]
)
self._idx.delete_index(collection_uuid)

self._delete_index(collection_uuid)
self._conn.execute(f"""DELETE FROM collections WHERE name = ?""", [name])

def update_collection(
Expand Down Expand Up @@ -349,12 +352,11 @@ def reset(self):
self._create_table_collections()
self._create_table_embeddings()

self._idx.reset()
self._idx = Hnswlib(self._settings)
self.reset_indexes()

def __del__(self):
logger.info("Exiting: Cleaning up .chroma directory")
self._idx.reset()
self.reset_indexes()

def persist(self):
raise NotImplementedError(
Expand Down
18 changes: 5 additions & 13 deletions chromadb/db/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,21 @@

class Index(ABC):
@abstractmethod
def __init__(self, settings):
def __init__(self, id, settings, metadata):
pass

@abstractmethod
def delete(self, collection_name):
def delete(self):
pass

@abstractmethod
def delete_from_index(self, collection_name, uuids):
def delete_from_index(self, ids):
pass

@abstractmethod
def reset(self):
def add(self, ids, embeddings, update=False):
pass

@abstractmethod
def run(self, collection_name, uuids, embeddings):
pass

@abstractmethod
def has_index(self, collection_name):
pass

@abstractmethod
def get_nearest_neighbors(self, collection_name, embedding, n_results, ids):
def get_nearest_neighbors(self, embedding, n_results, ids):
pass
Loading

0 comments on commit 5cef7f9

Please sign in to comment.