diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 94e3fd98e88..924d967b101 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -14,6 +14,7 @@ QueryResult, GetResult, WhereDocument, + HnswIndexParams, ) import json @@ -56,6 +57,7 @@ def create_collection( metadata: Optional[Dict] = None, get_or_create: bool = False, embedding_function: Optional[Callable] = None, + index_params: Optional[HnswIndexParams] = None, ) -> Collection: """Creates a new collection in the database @@ -64,6 +66,7 @@ def create_collection( metadata (Optional[Dict], optional): A dictionary of metadata to associate with the collection. Defaults to None. get_or_create (bool, optional): If True, will return the collection if it already exists. Defaults to False. embedding_function (Optional[Callable], optional): A function that takes documents and returns an embedding. Defaults to None. + index_params (Optional[HnswIndexParams], optional): The parameters to use for the HNSW index. Defaults to None. Returns: dict: the created collection @@ -83,7 +86,12 @@ def delete_collection( """ @abstractmethod - def get_or_create_collection(self, name: str, metadata: Optional[Dict] = None) -> Collection: + def get_or_create_collection( + self, + name: str, + metadata: Optional[Dict] = None, + index_params: Optional[HnswIndexParams] = None, + ) -> Collection: """Calls create_collection with get_or_create=True Args: diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index ef6d1c4dcb2..29dfb78c67d 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,8 +1,10 @@ from typing import Callable, Dict, Optional from chromadb.api import API +from chromadb.api.local import check_hnsw_index_params from chromadb.api.types import ( Documents, Embeddings, + HnswIndexParams, IDs, Include, Metadatas, @@ -45,11 +47,21 @@ def create_collection( metadata: Optional[Dict] = None, embedding_function: Optional[Callable] = None, get_or_create: bool = False, + index_params: Optional[HnswIndexParams] = None, ) -> Collection: """Creates a collection""" + index_params = check_hnsw_index_params(index_params) + resp = requests.post( self._api_url + "/collections", - data=json.dumps({"name": name, "metadata": metadata, "get_or_create": get_or_create}), + data=json.dumps( + { + "name": name, + "metadata": metadata, + "get_or_create": get_or_create, + "index_params": index_params, + }, + ), ) resp.raise_for_status() resp_json = resp.json() @@ -58,6 +70,7 @@ def create_collection( name=resp_json["name"], embedding_function=embedding_function, metadata=resp_json["metadata"], + index_params=resp_json["index_params"], ) def get_collection( @@ -81,10 +94,13 @@ def get_or_create_collection( name: str, metadata: Optional[Dict] = None, embedding_function: Optional[Callable] = None, + index_params: Optional[HnswIndexParams] = None, ) -> Collection: """Get a collection, or return it if it exists""" - return self.create_collection(name, metadata, embedding_function, get_or_create=True) + return self.create_collection( + name, metadata, embedding_function, get_or_create=True, index_params=index_params + ) def _modify(self, current_name: str, new_name: str, new_metadata: Optional[Dict] = None): """Updates a collection""" diff --git a/chromadb/api/local.py b/chromadb/api/local.py index 9bb1ddead76..0e242e34e27 100644 --- a/chromadb/api/local.py +++ b/chromadb/api/local.py @@ -1,7 +1,7 @@ import json import uuid import time -from typing import Dict, List, Optional, Sequence, Callable, Type, cast +from typing import 
Dict, List, Literal, Optional, Sequence, Callable, Type, cast, TypedDict from chromadb.api import API from chromadb.db import DB from chromadb.api.types import ( @@ -15,21 +15,26 @@ QueryResult, Where, WhereDocument, + HnswIndexParams, ) from chromadb.api.models.Collection import Collection import re +from chromadb.db.index.hnswlib import check_hnsw_index_params + # mimics s3 bucket requirements for naming def check_index_name(index_name): - msg = ("Expected collection name that " - "(1) contains 3-63 characters, " - "(2) starts and ends with an alphanumeric character, " - "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), " - "(4) contains no two consecutive periods (..) and " - "(5) is not a valid IPv4 address, " - f"got {index_name}") + msg = ( + "Expected collection name that " + "(1) contains 3-63 characters, " + "(2) starts and ends with an alphanumeric character, " + "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), " + "(4) contains no two consecutive periods (..) and " + "(5) is not a valid IPv4 address, " + f"got {index_name}" + ) if len(index_name) < 3 or len(index_name) > 63: raise ValueError(msg) if not re.match("^[a-z0-9][a-z0-9._-]*[a-z0-9]$", index_name): @@ -39,7 +44,6 @@ def check_index_name(index_name): if re.match("^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$", index_name): raise ValueError(msg) - class LocalAPI(API): def __init__(self, settings, db: DB): self._db = db @@ -56,12 +60,18 @@ def create_collection( metadata: Optional[Dict] = None, embedding_function: Optional[Callable] = None, get_or_create: bool = False, + index_params: Optional[HnswIndexParams] = None, ) -> Collection: check_index_name(name) + index_params = check_hnsw_index_params(index_params) - res = self._db.create_collection(name, metadata, get_or_create) + res = self._db.create_collection(name, metadata, get_or_create, index_params) return Collection( - client=self, name=name, embedding_function=embedding_function, metadata=res[0][2] + client=self, + name=name, + embedding_function=embedding_function, + metadata=res[0][2], + index_params=res[0][3], ) def get_or_create_collection( @@ -69,8 +79,11 @@ def get_or_create_collection( name: str, metadata: Optional[Dict] = None, embedding_function: Optional[Callable] = None, + index_params: Optional[HnswIndexParams] = None, ) -> Collection: - return self.create_collection(name, metadata, embedding_function, get_or_create=True) + return self.create_collection( + name, metadata, embedding_function, get_or_create=True, index_params=index_params + ) def get_collection( self, @@ -119,8 +132,12 @@ def _add( documents: Optional[Documents] = None, increment_index: bool = True, ): + collection = self._db.get_fields_from_collection_name( + collection_name, fields=["uuid", "index_params"] + ) + collection_uuid = collection[0] + index_params = collection[1] - collection_uuid = self._db.get_collection_uuid_from_name(collection_name) added_uuids = self._db.add( collection_uuid, embeddings=embeddings, @@ -130,7 +147,7 @@ def _add( ) if increment_index: - self._db.add_incremental(collection_uuid, added_uuids, embeddings) + self._db.add_incremental(collection_uuid, added_uuids, embeddings, index_params) return True # NIT: should this return the ids of the succesfully added items? 
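Not part of the patch: a minimal client-side usage sketch of the new index_params argument introduced above, assuming a local chromadb.Client() with default settings; the parameter values are illustrative only, and any key left out falls back to DEFAULT_INDEX_PARAMS (space="l2", ef=10, M=16).

import chromadb

client = chromadb.Client()

# space selects the hnswlib distance metric ("l2", "cosine", or "ip");
# ef and M trade recall against index size. Values here are illustrative.
collection = client.create_collection(
    name="my_collection",
    index_params={"space": "cosine", "ef": 20, "M": 32},
)
collection.add(
    ids=["id1", "id2"],
    embeddings=[[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]],
)
results = collection.query(query_embeddings=[[0.1, 0.2, 0.3]], n_results=1)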
diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 2d29c35647d..3324751427b 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -3,6 +3,7 @@ from chromadb.api.types import ( Embedding, + HnswIndexParams, Include, Metadata, Document, @@ -32,6 +33,7 @@ class Collection(BaseModel): name: str metadata: Optional[Dict] = None + index_params: Optional[HnswIndexParams] = None _client: "API" = PrivateAttr() _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() @@ -41,6 +43,7 @@ def __init__( name: str, embedding_function: Optional[EmbeddingFunction] = None, metadata: Optional[Dict] = None, + index_params: Optional[HnswIndexParams] = None, ): self._client = client @@ -49,12 +52,11 @@ def __init__( else: import chromadb.utils.embedding_functions as ef - logger.warning( "No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction" ) self._embedding_function = ef.SentenceTransformerEmbeddingFunction() - super().__init__(name=name, metadata=metadata) + super().__init__(name=name, metadata=metadata, index_params=index_params) def __repr__(self): return f"Collection(name={self.name})" diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 441c6b5786f..ac65c58acd6 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -46,12 +46,16 @@ class QueryResult(TypedDict): metadatas: Optional[List[List[Metadata]]] distances: Optional[List[List[float]]] +class HnswIndexParams(TypedDict): + space: Optional[Literal["l2", "cosine", "ip"]] + ef: Optional[int] + M: Optional[int] class IndexMetadata(TypedDict): dimensionality: int elements: int time_created: float - + index_params: HnswIndexParams class EmbeddingFunction(Protocol): def __call__(self, texts: Documents) -> Embeddings: @@ -169,10 +173,14 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument: if not isinstance(where_document, dict): raise ValueError(f"Expected where document to be a dictionary, got {where_document}") if len(where_document) != 1: - raise ValueError(f"Epected where document to have exactly one operator, got {where_document}") + raise ValueError( + f"Epected where document to have exactly one operator, got {where_document}" + ) for operator, operand in where_document.items(): if operator not in ["$contains", "$and", "$or"]: - raise ValueError(f"Expected where document operator to be one of $contains, $and, $or, got {operator}") + raise ValueError( + f"Expected where document operator to be one of $contains, $and, $or, got {operator}" + ) if operator == "$and" or operator == "$or": if not isinstance(operand, list): raise ValueError( @@ -205,5 +213,7 @@ def validate_include(include: Include, allow_distances: bool) -> Include: if allow_distances: allowed_values.append("distances") if item not in allowed_values: - raise ValueError(f"Expected include item to be one of {', '.join(allowed_values)}, got {item}") + raise ValueError( + f"Expected include item to be one of {', '.join(allowed_values)}, got {item}" + ) return include diff --git a/chromadb/db/__init__.py b/chromadb/db/__init__.py index a739b854d9b..a3dc20c287d 100644 --- a/chromadb/db/__init__.py +++ b/chromadb/db/__init__.py @@ -2,7 +2,15 @@ from typing import Dict, List, Sequence, Optional, Tuple from uuid import UUID import numpy.typing as npt -from chromadb.api.types import Embeddings, Documents, IDs, Metadatas, Where, WhereDocument +from chromadb.api.types import ( + Embeddings, + Documents, + IDs, + Metadatas, + 
Where, + WhereDocument, + HnswIndexParams, +) class DB(ABC): @@ -12,7 +20,11 @@ def __init__(self): @abstractmethod def create_collection( - self, name: str, metadata: Optional[Dict] = None, get_or_create: bool = False + self, + name: str, + metadata: Optional[Dict] = None, + get_or_create: bool = False, + index_params: Optional[HnswIndexParams] = None, ) -> Sequence: pass @@ -38,6 +50,10 @@ def delete_collection(self, name: str): def get_collection_uuid_from_name(self, collection_name: str) -> str: pass + @abstractmethod + def get_fields_from_collection_name(self, collection_name: str, fields: List[str]) -> List: + pass + @abstractmethod def add( self, @@ -46,11 +62,18 @@ def add( metadatas: Optional[Metadatas], documents: Optional[Documents], ids: List[UUID], + index_params: Optional[HnswIndexParams] = None, ) -> List[UUID]: pass @abstractmethod - def add_incremental(self, collection_uuid: str, ids: List[UUID], embeddings: Embeddings): + def add_incremental( + self, + collection_uuid: str, + ids: List[UUID], + embeddings: Embeddings, + index_params: Optional[HnswIndexParams] = None, + ): pass @abstractmethod diff --git a/chromadb/db/clickhouse.py b/chromadb/db/clickhouse.py index 7ac1050b3c7..f206b2a79f8 100644 --- a/chromadb/db/clickhouse.py +++ b/chromadb/db/clickhouse.py @@ -1,6 +1,14 @@ -from chromadb.api.types import Documents, Embeddings, IDs, Metadatas, Where, WhereDocument +from chromadb.api.types import ( + Documents, + Embeddings, + IDs, + Metadatas, + Where, + WhereDocument, + HnswIndexParams, +) from chromadb.db import DB -from chromadb.db.index.hnswlib import Hnswlib +from chromadb.db.index.hnswlib import DEFAULT_INDEX_PARAMS, Hnswlib from chromadb.errors import ( NoDatapointsException, InvalidDimensionException, @@ -18,7 +26,13 @@ logger = logging.getLogger(__name__) -COLLECTION_TABLE_SCHEMA = [{"uuid": "UUID"}, {"name": "String"}, {"metadata": "String"}] +COLLECTION_TABLE_SCHEMA = [ + {"uuid": "UUID"}, + {"name": "String"}, + {"metadata": "String"}, +] + +COLLECTION_TABLE_SCHEMA_MIGRATIONS = [{"index_params": "String"}] EMBEDDING_TABLE_SCHEMA = [ {"collection_uuid": "UUID"}, @@ -30,6 +44,12 @@ ] +def get_migration_default_value(key: str) -> Optional[str]: + if key == "index_params": + return json.dumps(DEFAULT_INDEX_PARAMS) + return None + + def db_array_schema_to_clickhouse_schema(table_schema): return_str = "" for element in table_schema: @@ -60,6 +80,7 @@ def _init_conn(self): host=self._settings.clickhouse_host, port=int(self._settings.clickhouse_port) ) self._create_table_collections(self._conn) + self._run_collection_table_migrations(self._conn) self._create_table_embeddings(self._conn) def _get_conn(self) -> Client: @@ -74,6 +95,19 @@ def _create_table_collections(self, conn): ) ENGINE = MergeTree() ORDER BY uuid""" ) + def _run_collection_table_migrations(self, conn): + for migration in COLLECTION_TABLE_SCHEMA_MIGRATIONS: + for k, v in migration.items(): + + DEFAULT = get_migration_default_value(k) + + try: + conn.command( + f"""ALTER TABLE collections ADD COLUMN IF NOT EXISTS {k} {v} DEFAULT {DEFAULT}""" + ) + except Exception as e: + logger.info(f"migration {k} failed, skipping") + def _create_table_embeddings(self, conn): conn.command( f"""CREATE TABLE IF NOT EXISTS embeddings ( @@ -95,6 +129,14 @@ def get_collection_uuid_from_name(self, name: str) -> str: ) return res.result_rows[0][0] + def get_fields_from_collection_name(self, name: str, fields: List[str]) -> List: + res = self._get_conn().query( + f""" + SELECT {','.join(fields)} FROM collections WHERE name 
= '{name}' + """ + ) + return res.result_rows[0] + def _create_where_clause( self, collection_uuid: str, @@ -121,25 +163,31 @@ def _create_where_clause( # COLLECTION METHODS # def create_collection( - self, name: str, metadata: Optional[Dict] = None, get_or_create: bool = False + self, + name: str, + metadata: Optional[Dict] = None, + get_or_create: bool = False, + index_params: Optional[HnswIndexParams] = None, ) -> Sequence: # poor man's unique constraint dupe_check = self.get_collection(name) if len(dupe_check) > 0: if get_or_create: - logger.info(f"collection with name {name} already exists, returning existing collection") + logger.info( + f"collection with name {name} already exists, returning existing collection" + ) return dupe_check else: raise ValueError(f"Collection with name {name} already exists") collection_uuid = uuid.uuid4() - data_to_insert = [[collection_uuid, name, json.dumps(metadata)]] + data_to_insert = [[collection_uuid, name, json.dumps(metadata), json.dumps(index_params)]] self._get_conn().insert( - "collections", data_to_insert, column_names=["uuid", "name", "metadata"] + "collections", data_to_insert, column_names=["uuid", "name", "metadata", "index_params"] ) - return [[collection_uuid, name, metadata]] + return [[collection_uuid, name, metadata, index_params]] def get_collection(self, name: str): res = ( @@ -151,8 +199,8 @@ def get_collection(self, name: str): ) .result_rows ) - # json.loads the metadata - return [[x[0], x[1], json.loads(x[2])] for x in res] + # json.loads the metadata and index_params + return [[x[0], x[1], json.loads(x[2]), json.loads(x[3])] for x in res] def list_collections(self) -> Sequence: res = self._get_conn().query(f"""SELECT * FROM collections""").result_rows @@ -530,8 +578,8 @@ def create_index(self, collection_uuid: str): self._idx.run(collection_uuid, uuids, embeddings) - def add_incremental(self, collection_uuid, uuids, embeddings): - self._idx.add_incremental(collection_uuid, uuids, embeddings) + def add_incremental(self, collection_uuid, uuids, embeddings, index_params=None): + self._idx.add_incremental(collection_uuid, uuids, embeddings, index_params) def has_index(self, collection_uuid: str): return self._idx.has_index(collection_uuid) diff --git a/chromadb/db/duckdb.py b/chromadb/db/duckdb.py index 0474300c564..46a0bf05bbc 100644 --- a/chromadb/db/duckdb.py +++ b/chromadb/db/duckdb.py @@ -1,12 +1,13 @@ -from chromadb.api.types import Documents, Embeddings, IDs, Metadatas +from chromadb.api.types import Documents, Embeddings, HnswIndexParams, IDs, Metadatas from chromadb.db import DB -from chromadb.db.index.hnswlib import Hnswlib +from chromadb.db.index.hnswlib import DEFAULT_INDEX_PARAMS, Hnswlib from chromadb.db.clickhouse import ( Clickhouse, db_array_schema_to_clickhouse_schema, EMBEDDING_TABLE_SCHEMA, db_schema_to_keys, COLLECTION_TABLE_SCHEMA, + COLLECTION_TABLE_SCHEMA_MIGRATIONS, ) from typing import List, Optional, Sequence, Dict import pandas as pd @@ -19,6 +20,7 @@ logger = logging.getLogger(__name__) + def clickhouse_to_duckdb_schema(table_schema): for item in table_schema: if "embedding" in item: @@ -58,7 +60,7 @@ def __init__(self, settings): def _create_table_collections(self): self._conn.execute( f"""CREATE TABLE collections ( - {db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(COLLECTION_TABLE_SCHEMA))} + {db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(COLLECTION_TABLE_SCHEMA + COLLECTION_TABLE_SCHEMA_MIGRATIONS))} ) """ ) @@ -78,31 +80,49 @@ def 
get_collection_uuid_from_name(self, name): f"""SELECT uuid FROM collections WHERE name = ?""", [name] ).fetchall()[0][0] + def get_fields_from_collection_name(self, name: str, fields: List[str]) -> List: + res = self._conn.execute( + f"""SELECT {", ".join(fields)} FROM collections WHERE name = ?""", + [name], + ).fetchall() + + # if index_params was passed, json.loads it + if "index_params" in fields: + res = [[x[0], json.loads(x[1])] for x in res] + + return res[0] + # # COLLECTION METHODS # def create_collection( - self, name: str, metadata: Optional[Dict] = None, get_or_create: bool = False + self, + name: str, + metadata: Optional[Dict] = None, + get_or_create: bool = False, + index_params: Optional[HnswIndexParams] = None, ) -> Sequence: # poor man's unique constraint dupe_check = self.get_collection(name) if len(dupe_check) > 0: if get_or_create == True: - logger.info(f"collection with name {name} already exists, returning existing collection") + logger.info( + f"collection with name {name} already exists, returning existing collection" + ) return dupe_check else: raise ValueError(f"Collection with name {name} already exists") self._conn.execute( - f"""INSERT INTO collections (uuid, name, metadata) VALUES (?, ?, ?)""", - [str(uuid.uuid4()), name, json.dumps(metadata)], + f"""INSERT INTO collections (uuid, name, metadata, index_params) VALUES (?, ?, ?, ?)""", + [str(uuid.uuid4()), name, json.dumps(metadata), json.dumps(index_params)], ) - return [[str(uuid.uuid4()), name, metadata]] + return [[str(uuid.uuid4()), name, metadata, index_params]] def get_collection(self, name: str) -> Sequence: res = self._conn.execute(f"""SELECT * FROM collections WHERE name = ?""", [name]).fetchall() # json.loads the metadata - return [[x[0], x[1], json.loads(x[2])] for x in res] + return [[x[0], x[1], json.loads(x[2]), json.loads(x[3])] for x in res] def list_collections(self) -> Sequence: res = self._conn.execute(f"""SELECT * FROM collections""").fetchall() @@ -431,6 +451,9 @@ def load(self): logger.info(f"No existing DB found in {self._save_folder}, skipping load") else: path = self._save_folder + "/chroma-collections.parquet" + + self._run_migrations(path) + self._conn.execute(f"INSERT INTO collections SELECT * FROM read_parquet('{path}');") logger.info( f"""loaded in {self._conn.query(f"SELECT COUNT() FROM collections").fetchall()[0][0]} collections""" @@ -448,3 +471,13 @@ def reset(self): shutil.rmtree(self._save_folder) os.mkdir(self._save_folder) + + def _run_migrations(self, path_to_collections_parquet): + df = pd.read_parquet(path_to_collections_parquet) + headers = df.columns.tolist() + + # this column was added to enable users to specify index params + if "index_params" not in headers: + df["index_params"] = json.dumps(DEFAULT_INDEX_PARAMS) + + df.to_parquet(path_to_collections_parquet) diff --git a/chromadb/db/index/hnswlib.py b/chromadb/db/index/hnswlib.py index 47b8a5ca136..8bf90aa03e6 100644 --- a/chromadb/db/index/hnswlib.py +++ b/chromadb/db/index/hnswlib.py @@ -3,7 +3,7 @@ import time from typing import Optional import uuid -from chromadb.api.types import IndexMetadata +from chromadb.api.types import IndexMetadata, HnswIndexParams import hnswlib import numpy as np from chromadb.db.index import Index @@ -12,6 +12,31 @@ logger = logging.getLogger(__name__) +DEFAULT_INDEX_PARAMS = HnswIndexParams(space="l2", ef=10, M=16) + +def check_hnsw_index_params(index_params: Optional[HnswIndexParams] = None) -> HnswIndexParams: + + params = DEFAULT_INDEX_PARAMS + + if index_params is None: + 
return params + + if ('space' in index_params) and (index_params['space'] is not None): + if index_params["space"] not in ["l2", "cosine", "ip"]: + raise ValueError(f"Expected space to be l2, cosine, or ip, got {index_params['space']}") + params["space"] = index_params["space"] + + if ('M' in index_params) and (index_params["M"] is not None): + if index_params["M"] < 0: + raise ValueError(f"Expected M to be >= 0, got {index_params['M']}") + params["M"] = index_params["M"] + + if ('ef' in index_params) and (index_params["ef"] is not None): + if index_params["ef"] < 0: + raise ValueError(f"Expected ef to be >= 0, got {index_params['ef']}") + params["ef"] = index_params["ef"] + + return params class Hnswlib(Index): _collection_uuid = None @@ -24,7 +49,7 @@ class Hnswlib(Index): def __init__(self, settings): self._save_folder = settings.persist_directory + "/index" - def run(self, collection_uuid, uuids, embeddings, space="l2", ef=10, num_threads=4): + def run(self, collection_uuid, uuids, embeddings, index_params: HnswIndexParams = DEFAULT_INDEX_PARAMS, num_threads=4): # more comments available at the source: https://github.com/nmslib/hnswlib dimensionality = len(embeddings[0]) for uuid, i in zip(uuids, range(len(uuids))): @@ -32,20 +57,21 @@ def run(self, collection_uuid, uuids, embeddings, space="l2", ef=10, num_threads self._uuid_to_id[uuid.hex] = i index = hnswlib.Index( - space=space, dim=dimensionality + space=index_params["space"], dim=dimensionality ) # possible options are l2, cosine or ip - index.init_index(max_elements=len(embeddings), ef_construction=100, M=16) - index.set_ef(ef) + index.init_index(max_elements=len(embeddings), ef_construction=100, M=index_params["M"]) + index.set_ef(index_params["ef"]) index.set_num_threads(num_threads) index.add_items(embeddings, range(len(uuids))) self._index = index self._collection_uuid = collection_uuid - self._index_metadata = { - "dimensionality": dimensionality, - "elements": len(embeddings), - "time_created": time.time(), - } + self._index_metadata = IndexMetadata( + dimensionality=dimensionality, + elements=len(embeddings), + time_created=time.time(), + index_params=index_params, + ) self._save() def get_metadata(self) -> IndexMetadata: @@ -53,12 +79,18 @@ def get_metadata(self) -> IndexMetadata: raise NoIndexException("Index is not initialized") return self._index_metadata - def add_incremental(self, collection_uuid, uuids, embeddings): + def add_incremental(self, collection_uuid, uuids, embeddings, index_params: Optional[HnswIndexParams] = None): if self._collection_uuid != collection_uuid: self._load(collection_uuid) if self._index is None: - self.run(collection_uuid, uuids, embeddings) + index_params = check_hnsw_index_params(index_params) + self.run( + collection_uuid, + uuids, + embeddings, + index_params=index_params, + ) elif self._index is not None: idx_dimension = self.get_metadata()["dimensionality"] @@ -153,7 +185,9 @@ def _load(self, collection_uuid): self._uuid_to_id = pickle.load(f) with open(f"{self._save_folder}/index_metadata_{collection_uuid}.pkl", "rb") as f: self._index_metadata = pickle.load(f) - p = hnswlib.Index(space="l2", dim=self._index_metadata["dimensionality"]) + + space = self._index_metadata['index_params']['space'] if 'index_params' in self._index_metadata else "l2" + p = hnswlib.Index(space=space, dim=self._index_metadata["dimensionality"]) self._index = p self._index.load_index( f"{self._save_folder}/index_{collection_uuid}.bin", diff --git a/chromadb/server/fastapi/__init__.py 
b/chromadb/server/fastapi/__init__.py index b0895f62ff0..a975e423d01 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -134,6 +134,7 @@ def create_collection(self, collection: CreateCollection): name=collection.name, metadata=collection.metadata, get_or_create=collection.get_or_create, + index_params=collection.index_params, ) def get_collection(self, collection_name: str): diff --git a/chromadb/server/fastapi/types.py b/chromadb/server/fastapi/types.py index 5accd05e735..b09f66d54b2 100644 --- a/chromadb/server/fastapi/types.py +++ b/chromadb/server/fastapi/types.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from typing import List, Union, Any -from chromadb.api.types import Include +from typing import List, Union, Any, Optional +from chromadb.api.types import Include, HnswIndexParams # type supports single and batch mode class AddEmbedding(BaseModel): @@ -64,9 +64,11 @@ class DeleteEmbedding(BaseModel): class CreateCollection(BaseModel): name: str metadata: dict = None + index_params: Optional[HnswIndexParams] = None get_or_create: bool = False class UpdateCollection(BaseModel): new_name: str = None new_metadata: dict = None + new_index_params: Optional[HnswIndexParams] = None diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 1884c2ac3af..6221bd81283 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -687,7 +687,7 @@ def test_metadata_validation_add(api_fixture, request): api.reset() collection = api.create_collection("test_metadata_validation") - with pytest.raises(ValueError, match='metadata'): + with pytest.raises(ValueError, match="metadata"): collection.add(**bad_metadata_records) @@ -698,7 +698,7 @@ def test_metadata_validation_update(api_fixture, request): api.reset() collection = api.create_collection("test_metadata_validation") collection.add(**metadata_records) - with pytest.raises(ValueError, match='metadata'): + with pytest.raises(ValueError, match="metadata"): collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}}) @@ -708,7 +708,7 @@ def test_where_validation_get(api_fixture, request): api.reset() collection = api.create_collection("test_where_validation") - with pytest.raises(ValueError, match='where'): + with pytest.raises(ValueError, match="where"): collection.get(where={"value": {"nested": "5"}}) @@ -718,7 +718,7 @@ def test_where_validation_query(api_fixture, request): api.reset() collection = api.create_collection("test_where_validation") - with pytest.raises(ValueError, match='where'): + with pytest.raises(ValueError, match="where"): collection.query(query_embeddings=[0, 0, 0], where={"value": {"nested": "5"}}) @@ -908,13 +908,13 @@ def test_query_document_valid_operators(api_fixture, request): api.reset() collection = api.create_collection("test_where_valid_operators") collection.add(**operator_records) - with pytest.raises(ValueError, match='where document'): + with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$lt": {"$nested": 2}}) - with pytest.raises(ValueError, match='where document'): + with pytest.raises(ValueError, match="where document"): collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2}) - with pytest.raises(ValueError, match='where document'): + with pytest.raises(ValueError, match="where document"): collection.get(where_document={"$contains": []}) # Test invalid $and, $or @@ -1177,10 +1177,10 @@ def test_get_include(api_fixture, request): assert items["embeddings"] == None assert 
items["ids"][0] == "id1" - with pytest.raises(ValueError, match='include'): + with pytest.raises(ValueError, match="include"): items = collection.get(include=["metadatas", "undefined"]) - with pytest.raises(ValueError, match='include'): + with pytest.raises(ValueError, match="include"): items = collection.get(include=None) @@ -1224,3 +1224,72 @@ def test_invalid_id(api_fixture, request): with pytest.raises(ValueError) as e: collection.delete(ids=["valid", 0]) assert "ID" in str(e.value) + + +# test setting index_params +@pytest.mark.parametrize("api_fixture", test_apis) +def test_index_params(api_fixture, request): + api = request.getfixturevalue(api_fixture.__name__) + + # first standard add + api.reset() + collection = api.create_collection(name="test_index_params") + collection.add(**records) + items = collection.query( + query_embeddings=[0.6, 1.12, 1.6], + n_results=1, + ) + assert items["distances"][0][0] > 4 + + # cosine + api.reset() + collection = api.create_collection( + name="test_index_params", index_params={"space": "cosine", "ef": 20, "M": 5} + ) + collection.add(**records) + items = collection.query( + query_embeddings=[0.6, 1.12, 1.6], + n_results=1, + ) + assert items["distances"][0][0] > 0 + assert items["distances"][0][0] < 1 + + # ip + api.reset() + collection = api.create_collection(name="test_index_params", index_params={"space": "ip"}) + collection.add(**records) + items = collection.query( + query_embeddings=[0.6, 1.12, 1.6], + n_results=1, + ) + assert items["distances"][0][0] < -5 + +# test loading from disk where index_params isnt set yet (migration case) +# + +# test persisting and loading these indexes from disk +@pytest.mark.parametrize("api_fixture", [local_persist_api]) +def test_persist_index_loading(api_fixture, request): + api = request.getfixturevalue("local_persist_api") + api.reset() + collection = api.create_collection("test") + collection.add(ids="id1", documents="hello") + + api.persist() + del api + + api2 = request.getfixturevalue("local_persist_api_cache_bust") + collection = api2.get_collection("test") + + nn = collection.query( + query_texts="hello", + n_results=1, + include=["embeddings", "documents", "metadatas", "distances"], + ) + for key in nn.keys(): + assert len(nn[key]) == 1 + + +# test allow updating the index type but only if the index is empty + +# test other index params, EF and M