-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for multiple spaces #457
Changes from all commits
00c64a4
52b4616
85c089d
36abd89
fdde8ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,11 @@ | |
from chromadb.api.types import IndexMetadata | ||
import hnswlib | ||
from chromadb.db.index import Index | ||
from chromadb.errors import NoIndexException, InvalidDimensionException, NotEnoughElementsException | ||
from chromadb.errors import ( | ||
NoIndexException, | ||
InvalidDimensionException, | ||
NotEnoughElementsException, | ||
) | ||
import logging | ||
import re | ||
from uuid import UUID | ||
|
@@ -24,7 +28,6 @@ | |
|
||
|
||
class HnswParams: | ||
|
||
space: str | ||
construction_ef: int | ||
search_ef: int | ||
|
@@ -33,7 +36,6 @@ class HnswParams: | |
resize_factor: float | ||
|
||
def __init__(self, metadata): | ||
|
||
metadata = metadata or {} | ||
|
||
# Convert all values to strings for future compatibility. | ||
|
@@ -44,7 +46,9 @@ def __init__(self, metadata): | |
if param not in valid_params: | ||
raise ValueError(f"Unknown HNSW parameter: {param}") | ||
if not re.match(valid_params[param], value): | ||
raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}") | ||
raise ValueError( | ||
f"Invalid value for HNSW parameter: {param} = {value}" | ||
) | ||
|
||
self.space = metadata.get("hnsw:space", "l2") | ||
self.construction_ef = int(metadata.get("hnsw:construction_ef", 100)) | ||
|
@@ -71,7 +75,7 @@ class Hnswlib(Index): | |
_index_metadata: IndexMetadata | ||
_params: HnswParams | ||
_id_to_label: Dict[str, int] | ||
_label_to_id: Dict[int, str] | ||
_label_to_id: Dict[int, UUID] | ||
|
||
def __init__(self, id, settings, metadata): | ||
self._save_folder = settings.persist_directory + "/index" | ||
|
@@ -128,7 +132,7 @@ def add(self, ids, embeddings, update=False): | |
|
||
labels = [] | ||
for id in ids: | ||
if id in self._id_to_label: | ||
if hexid(id) in self._id_to_label: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the bugfix |
||
if update: | ||
labels.append(self._id_to_label[hexid(id)]) | ||
else: | ||
|
@@ -141,7 +145,9 @@ def add(self, ids, embeddings, update=False): | |
labels.append(next_label) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is weird that sometimes we store the UUID vs the id but leaving this for now to minimize the bugs i might introduce. |
||
|
||
if self._index_metadata["elements"] > self._index.get_max_elements(): | ||
new_size = max(self._index_metadata["elements"] * self._params.resize_factor, 1000) | ||
new_size = max( | ||
self._index_metadata["elements"] * self._params.resize_factor, 1000 | ||
) | ||
self._index.resize_index(int(new_size)) | ||
|
||
self._index.add_items(embeddings, labels) | ||
|
@@ -196,7 +202,6 @@ def _exists(self): | |
return | ||
|
||
def _load(self): | ||
|
||
if not os.path.exists(f"{self._save_folder}/index_{self._id}.bin"): | ||
return | ||
|
||
|
@@ -208,7 +213,9 @@ def _load(self): | |
with open(f"{self._save_folder}/index_metadata_{self._id}.pkl", "rb") as f: | ||
self._index_metadata = pickle.load(f) | ||
|
||
p = hnswlib.Index(space=self._params.space, dim=self._index_metadata["dimensionality"]) | ||
p = hnswlib.Index( | ||
space=self._params.space, dim=self._index_metadata["dimensionality"] | ||
) | ||
self._index = p | ||
self._index.load_index( | ||
f"{self._save_folder}/index_{self._id}.bin", | ||
|
@@ -218,9 +225,10 @@ def _load(self): | |
self._index.set_num_threads(self._params.num_threads) | ||
|
||
def get_nearest_neighbors(self, query, k, ids=None): | ||
|
||
if self._index is None: | ||
raise NoIndexException("Index not found, please create an instance before querying") | ||
raise NoIndexException( | ||
"Index not found, please create an instance before querying" | ||
) | ||
|
||
# Check dimensionality | ||
self._check_dimensionality(query) | ||
|
@@ -245,8 +253,12 @@ def get_nearest_neighbors(self, query, k, ids=None): | |
logger.debug(f"time to pre process our knn query: {time.time() - s2}") | ||
|
||
s3 = time.time() | ||
database_labels, distances = self._index.knn_query(query, k=k, filter=filter_function) | ||
database_labels, distances = self._index.knn_query( | ||
query, k=k, filter=filter_function | ||
) | ||
logger.debug(f"time to run knn query: {time.time() - s3}") | ||
|
||
ids = [[self._label_to_id[label] for label in labels] for labels in database_labels] | ||
ids = [ | ||
[self._label_to_id[label] for label in labels] for labels in database_labels | ||
] | ||
return ids, distances |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,31 +25,37 @@ | |
version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$") | ||
|
||
|
||
def _patch_uppercase_coll_name(collection: strategies.Collection, | ||
embeddings: strategies.RecordSet): | ||
def _patch_uppercase_coll_name( | ||
collection: strategies.Collection, embeddings: strategies.RecordSet | ||
): | ||
"""Old versions didn't handle uppercase characters in collection names""" | ||
collection.name = collection.name.lower() | ||
|
||
|
||
def _patch_empty_dict_metadata(collection: strategies.Collection, | ||
embeddings: strategies.RecordSet): | ||
def _patch_empty_dict_metadata( | ||
collection: strategies.Collection, embeddings: strategies.RecordSet | ||
): | ||
"""Old versions do the wrong thing when metadata is a single empty dict""" | ||
if embeddings["metadatas"] == {}: | ||
embeddings["metadatas"] = None | ||
|
||
|
||
version_patches = [("0.3.21", _patch_uppercase_coll_name), | ||
("0.3.21", _patch_empty_dict_metadata)] | ||
version_patches = [ | ||
("0.3.21", _patch_uppercase_coll_name), | ||
("0.3.21", _patch_empty_dict_metadata), | ||
] | ||
|
||
|
||
def patch_for_version(version, | ||
collection: strategies.Collection, | ||
embeddings: strategies.RecordSet): | ||
def patch_for_version( | ||
version, collection: strategies.Collection, embeddings: strategies.RecordSet | ||
): | ||
"""Override aspects of the collection and embeddings, before testing, to account for | ||
breaking changes in old versions.""" | ||
|
||
for patch_version, patch in version_patches: | ||
if packaging_version.Version(version) <= packaging_version.Version(patch_version): | ||
if packaging_version.Version(version) <= packaging_version.Version( | ||
patch_version | ||
): | ||
patch(collection, embeddings) | ||
|
||
|
||
|
@@ -84,9 +90,7 @@ def configurations(versions): | |
|
||
# This fixture is not shared with the rest of the tests because it is unique in how it | ||
# installs the versions of chromadb | ||
@pytest.fixture( | ||
scope="module", params=configurations(test_old_versions) | ||
) | ||
@pytest.fixture(scope="module", params=configurations(test_old_versions)) | ||
def version_settings(request) -> Generator[Tuple[str, Settings], None, None]: | ||
configuration = request.param | ||
version = configuration[0] | ||
|
@@ -172,7 +176,7 @@ def persist_generated_data_with_old_version( | |
coll = api.create_collection( | ||
name=collection_strategy.name, | ||
metadata=collection_strategy.metadata, | ||
embedding_function=collection_strategy.embedding_function, | ||
embedding_function=lambda x: None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes the tests faster because we pass none across the process boundary because we can't pickle a lambda |
||
) | ||
coll.add(**embeddings_strategy) | ||
# We can't use the invariants module here because it uses the current version | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See below comment on how its odd that we store the string and the uuid in opposite direction. Updating the typing.