Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support parameters for HNSW indexes using metadata #245

Merged
merged 9 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions chromadb/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ def raw_sql(self, raw_sql):
def create_index(self, collection_uuid: str):
pass

@abstractmethod
def has_index(self, collection_name):
pass

@abstractmethod
def persist(self):
pass
80 changes: 49 additions & 31 deletions chromadb/db/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from chromadb.api.types import Documents, Embeddings, IDs, Metadatas, Where, WhereDocument
from chromadb.db import DB
from chromadb.db.index.hnswlib import Hnswlib
from chromadb.db.index.hnswlib import Hnswlib, delete_all_indexes
from chromadb.errors import (
NoDatapointsException,
InvalidDimensionException,
Expand Down Expand Up @@ -51,7 +51,6 @@ class Clickhouse(DB):
#
def __init__(self, settings):
self._conn = None
self._idx = Hnswlib(settings)
self._settings = settings

def _init_conn(self):
Expand Down Expand Up @@ -81,6 +80,25 @@ def _create_table_embeddings(self, conn):
) ENGINE = MergeTree() ORDER BY collection_uuid"""
)

index_cache = {}

def _index(self, collection_id):
"""Retrieve an HNSW index instance for the given collection"""

if collection_id not in self.index_cache:
coll = self.get_collection_by_id(collection_id)
collection_metadata = coll[2]
index = Hnswlib(collection_id, self._settings, collection_metadata)
self.index_cache[collection_id] = index
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to use an LRU cache with some simple heuristics? This unbounded cache seems like an easy way to get into hairy performance issues, but I'm not sure whether it's premature to do so, since most users don't have enough data to saturate RAM.


return self.index_cache[collection_id]

def _delete_index(self, collection_id):
"""Delete an index from the cache"""
index = self._index(collection_id)
index.delete()
del self.index_cache[collection_id]

#
# UTILITY METHODS
#
Expand Down Expand Up @@ -156,6 +174,19 @@ def get_collection(self, name: str):
# json.loads the metadata
return [[x[0], x[1], json.loads(x[2])] for x in res]

def get_collection_by_id(self, collection_uuid: str):
res = (
self._get_conn()
.query(
f"""
SELECT * FROM collections WHERE uuid = '{collection_uuid}'
"""
)
.result_rows
)
# json.loads the metadata
return [[x[0], x[1], json.loads(x[2])] for x in res][0]

def list_collections(self) -> Sequence:
res = self._get_conn().query(f"""SELECT * FROM collections""").result_rows
return [[x[0], x[1], json.loads(x[2])] for x in res]
Expand Down Expand Up @@ -195,7 +226,7 @@ def delete_collection(self, name: str):
"""
)

self._idx.delete_index(collection_uuid)
self._delete_index(collection_uuid)

#
# ITEM METHODS
Expand Down Expand Up @@ -271,9 +302,8 @@ def update(

# Update the index
if embeddings is not None:
update_uuids = [x[1] for x in existing_items]
self._idx.delete_from_index(collection_uuid, update_uuids)
self._idx.add_incremental(collection_uuid, update_uuids, embeddings)
index = self._index(collection_uuid)
index.add(ids, embeddings, update=True)

def _get(self, where={}, columns: Optional[List] = None):
select_columns = db_schema_to_keys() if columns is None else columns
Expand Down Expand Up @@ -437,7 +467,8 @@ def delete(

deleted_uuids = self._delete(where_str)

self._idx.delete_from_index(collection_uuid, deleted_uuids)
index = self._index(collection_uuid)
index.delete_from_index(deleted_uuids)

return deleted_uuids

Expand Down Expand Up @@ -476,21 +507,6 @@ def get_nearest_neighbors(
if collection_name is not None:
collection_uuid = self.get_collection_uuid_from_name(collection_name)

self._idx.load_if_not_loaded(collection_uuid)

idx_metadata = self._idx.get_metadata()
# Check query embeddings dimensionality
if idx_metadata["dimensionality"] != len(embeddings[0]):
raise InvalidDimensionException(
f"Query embeddings dimensionality {len(embeddings[0])} does not match index dimensionality {idx_metadata['dimensionality']}"
)

# Check number of requested results
if n_results > idx_metadata["elements"]:
raise NotEnoughElementsException(
f"Number of requested results {n_results} cannot be greater than number of elements in index {idx_metadata['elements']}"
)

if len(where) != 0 or len(where_document) != 0:
results = self.get(
collection_uuid=collection_uuid, where=where, where_document=where_document
Expand All @@ -504,9 +520,9 @@ def get_nearest_neighbors(
)
else:
ids = None
uuids, distances = self._idx.get_nearest_neighbors(
collection_uuid, embeddings, n_results, ids
)

index = self._index(collection_uuid)
uuids, distances = index.get_nearest_neighbors(embeddings, n_results, ids)

return uuids, distances

Expand All @@ -523,13 +539,16 @@ def create_index(self, collection_uuid: str):
uuids = [x[1] for x in get]
embeddings = [x[2] for x in get]

self._idx.run(collection_uuid, uuids, embeddings)
index = self._index(collection_uuid)
index.add(uuids, embeddings)

def add_incremental(self, collection_uuid, uuids, embeddings):
self._idx.add_incremental(collection_uuid, uuids, embeddings)
index = self._index(collection_uuid)
index.add(uuids, embeddings)

def has_index(self, collection_uuid: str):
return self._idx.has_index(collection_uuid)
def reset_indexes(self):
delete_all_indexes(self._settings)
self.index_cache = {}

def reset(self):
conn = self._get_conn()
Expand All @@ -538,8 +557,7 @@ def reset(self):
self._create_table_collections(conn)
self._create_table_embeddings(conn)

self._idx.reset()
self._idx = Hnswlib(self._settings)
self.reset_indexes()

def raw_sql(self, sql):
return self._get_conn().query(sql).result_rows
14 changes: 8 additions & 6 deletions chromadb/db/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from chromadb.api.types import Documents, Embeddings, IDs, Metadatas
from chromadb.db import DB
from chromadb.db.index.hnswlib import Hnswlib
from chromadb.db.clickhouse import (
Clickhouse,
db_array_schema_to_clickhouse_schema,
Expand Down Expand Up @@ -49,7 +48,6 @@ def __init__(self, settings):
self._conn = duckdb.connect()
self._create_table_collections()
self._create_table_embeddings()
self._idx = Hnswlib(settings)
self._settings = settings

# https://duckdb.org/docs/extensions/overview
Expand Down Expand Up @@ -107,6 +105,10 @@ def get_collection(self, name: str) -> Sequence:
# json.loads the metadata
return [[x[0], x[1], json.loads(x[2])] for x in res]

def get_collection_by_id(self, uuid: str) -> Sequence:
res = self._conn.execute(f"""SELECT * FROM collections WHERE uuid = ?""", [uuid]).fetchone()
return [res[0], res[1], json.loads(res[2])]

def list_collections(self) -> Sequence:
res = self._conn.execute(f"""SELECT * FROM collections""").fetchall()
return [[x[0], x[1], json.loads(x[2])] for x in res]
Expand All @@ -116,7 +118,8 @@ def delete_collection(self, name: str):
self._conn.execute(
f"""DELETE FROM embeddings WHERE collection_uuid = ?""", [collection_uuid]
)
self._idx.delete_index(collection_uuid)

self._delete_index(collection_uuid)
self._conn.execute(f"""DELETE FROM collections WHERE name = ?""", [name])

def update_collection(
Expand Down Expand Up @@ -350,12 +353,11 @@ def reset(self):
self._create_table_collections()
self._create_table_embeddings()

self._idx.reset()
self._idx = Hnswlib(self._settings)
self.reset_indexes()

def __del__(self):
logger.info("Exiting: Cleaning up .chroma directory")
self._idx.reset()
self.reset_indexes()

def persist(self):
raise NotImplementedError(
Expand Down
18 changes: 5 additions & 13 deletions chromadb/db/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,21 @@

class Index(ABC):
@abstractmethod
def __init__(self, settings):
def __init__(self, id, settings, metadata):
pass

@abstractmethod
def delete(self, collection_name):
def delete(self):
pass

@abstractmethod
def delete_from_index(self, collection_name, uuids):
def delete_from_index(self, ids):
pass

@abstractmethod
def reset(self):
def add(self, ids, embeddings, update=False):
pass

@abstractmethod
def run(self, collection_name, uuids, embeddings):
pass

@abstractmethod
def has_index(self, collection_name):
pass

@abstractmethod
def get_nearest_neighbors(self, collection_name, embedding, n_results, ids):
def get_nearest_neighbors(self, embedding, n_results, ids):
pass
Loading