Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set distance function and other index params #228

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
QueryResult,
GetResult,
WhereDocument,
HnswIndexParams,
)
import json

Expand Down Expand Up @@ -56,6 +57,7 @@ def create_collection(
metadata: Optional[Dict] = None,
get_or_create: bool = False,
embedding_function: Optional[Callable] = None,
index_params: Optional[HnswIndexParams] = None,
) -> Collection:
"""Creates a new collection in the database

Expand All @@ -64,6 +66,7 @@ def create_collection(
metadata (Optional[Dict], optional): A dictionary of metadata to associate with the collection. Defaults to None.
get_or_create (bool, optional): If True, will return the collection if it already exists. Defaults to False.
embedding_function (Optional[Callable], optional): A function that takes documents and returns an embedding. Defaults to None.
index_params (Optional[HnswIndexParams], optional): The parameters to use for the HNSW index. Defaults to None.

Returns:
Collection: the created collection
Expand All @@ -83,7 +86,12 @@ def delete_collection(
"""

@abstractmethod
def get_or_create_collection(self, name: str, metadata: Optional[Dict] = None) -> Collection:
def get_or_create_collection(
self,
name: str,
metadata: Optional[Dict] = None,
index_params: Optional[HnswIndexParams] = None,
) -> Collection:
"""Calls create_collection with get_or_create=True

Args:
Expand Down
20 changes: 18 additions & 2 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Callable, Dict, Optional
from chromadb.api import API
from chromadb.api.local import check_hnsw_index_params
from chromadb.api.types import (
Documents,
Embeddings,
HnswIndexParams,
IDs,
Include,
Metadatas,
Expand Down Expand Up @@ -45,11 +47,21 @@ def create_collection(
metadata: Optional[Dict] = None,
embedding_function: Optional[Callable] = None,
get_or_create: bool = False,
index_params: Optional[HnswIndexParams] = None,
) -> Collection:
"""Creates a collection"""
index_params = check_hnsw_index_params(index_params)

resp = requests.post(
self._api_url + "/collections",
data=json.dumps({"name": name, "metadata": metadata, "get_or_create": get_or_create}),
data=json.dumps(
{
"name": name,
"metadata": metadata,
"get_or_create": get_or_create,
"index_params": index_params,
},
),
)
resp.raise_for_status()
resp_json = resp.json()
Expand All @@ -58,6 +70,7 @@ def create_collection(
name=resp_json["name"],
embedding_function=embedding_function,
metadata=resp_json["metadata"],
index_params=resp_json["index_params"],
)

def get_collection(
Expand All @@ -81,10 +94,13 @@ def get_or_create_collection(
name: str,
metadata: Optional[Dict] = None,
embedding_function: Optional[Callable] = None,
index_params: Optional[HnswIndexParams] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something we'll have to deal with alongside persisting embedding functions, but fine for now.

) -> Collection:
"""Get a collection, or create it if it doesn't exist"""

return self.create_collection(name, metadata, embedding_function, get_or_create=True)
return self.create_collection(
name, metadata, embedding_function, get_or_create=True, index_params=index_params
)

def _modify(self, current_name: str, new_name: str, new_metadata: Optional[Dict] = None):
"""Updates a collection"""
Expand Down
45 changes: 31 additions & 14 deletions chromadb/api/local.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import uuid
import time
from typing import Dict, List, Optional, Sequence, Callable, Type, cast
from typing import Dict, List, Literal, Optional, Sequence, Callable, Type, cast, TypedDict
from chromadb.api import API
from chromadb.db import DB
from chromadb.api.types import (
Expand All @@ -15,21 +15,26 @@
QueryResult,
Where,
WhereDocument,
HnswIndexParams,
)
from chromadb.api.models.Collection import Collection

import re

from chromadb.db.index.hnswlib import check_hnsw_index_params


# mimics s3 bucket requirements for naming
def check_index_name(index_name):
msg = ("Expected collection name that "
"(1) contains 3-63 characters, "
"(2) starts and ends with an alphanumeric character, "
"(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), "
"(4) contains no two consecutive periods (..) and "
"(5) is not a valid IPv4 address, "
f"got {index_name}")
msg = (
"Expected collection name that "
"(1) contains 3-63 characters, "
"(2) starts and ends with an alphanumeric character, "
"(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), "
"(4) contains no two consecutive periods (..) and "
"(5) is not a valid IPv4 address, "
f"got {index_name}"
)
if len(index_name) < 3 or len(index_name) > 63:
raise ValueError(msg)
if not re.match("^[a-z0-9][a-z0-9._-]*[a-z0-9]$", index_name):
Expand All @@ -39,7 +44,6 @@ def check_index_name(index_name):
if re.match("^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$", index_name):
raise ValueError(msg)


class LocalAPI(API):
def __init__(self, settings, db: DB):
self._db = db
Expand All @@ -56,21 +60,30 @@ def create_collection(
metadata: Optional[Dict] = None,
embedding_function: Optional[Callable] = None,
get_or_create: bool = False,
index_params: Optional[HnswIndexParams] = None,
) -> Collection:
check_index_name(name)
index_params = check_hnsw_index_params(index_params)

res = self._db.create_collection(name, metadata, get_or_create)
res = self._db.create_collection(name, metadata, get_or_create, index_params)
return Collection(
client=self, name=name, embedding_function=embedding_function, metadata=res[0][2]
client=self,
name=name,
embedding_function=embedding_function,
metadata=res[0][2],
index_params=res[0][3],
)

def get_or_create_collection(
self,
name: str,
metadata: Optional[Dict] = None,
embedding_function: Optional[Callable] = None,
index_params: Optional[HnswIndexParams] = None,
) -> Collection:
return self.create_collection(name, metadata, embedding_function, get_or_create=True)
return self.create_collection(
name, metadata, embedding_function, get_or_create=True, index_params=index_params
)

def get_collection(
self,
Expand Down Expand Up @@ -119,8 +132,12 @@ def _add(
documents: Optional[Documents] = None,
increment_index: bool = True,
):
collection = self._db.get_fields_from_collection_name(
collection_name, fields=["uuid", "index_params"]
)
collection_uuid = collection[0]
index_params = collection[1]

collection_uuid = self._db.get_collection_uuid_from_name(collection_name)
added_uuids = self._db.add(
collection_uuid,
embeddings=embeddings,
Expand All @@ -130,7 +147,7 @@ def _add(
)

if increment_index:
self._db.add_incremental(collection_uuid, added_uuids, embeddings)
self._db.add_incremental(collection_uuid, added_uuids, embeddings, index_params)

return True # NIT: should this return the ids of the successfully added items?

Expand Down
6 changes: 4 additions & 2 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from chromadb.api.types import (
Embedding,
HnswIndexParams,
Include,
Metadata,
Document,
Expand Down Expand Up @@ -32,6 +33,7 @@
class Collection(BaseModel):
name: str
metadata: Optional[Dict] = None
index_params: Optional[HnswIndexParams] = None
_client: "API" = PrivateAttr()
_embedding_function: Optional[EmbeddingFunction] = PrivateAttr()

Expand All @@ -41,6 +43,7 @@ def __init__(
name: str,
embedding_function: Optional[EmbeddingFunction] = None,
metadata: Optional[Dict] = None,
index_params: Optional[HnswIndexParams] = None,
):

self._client = client
Expand All @@ -49,12 +52,11 @@ def __init__(
else:
import chromadb.utils.embedding_functions as ef


logger.warning(
"No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction"
)
self._embedding_function = ef.SentenceTransformerEmbeddingFunction()
super().__init__(name=name, metadata=metadata)
super().__init__(name=name, metadata=metadata, index_params=index_params)

def __repr__(self):
return f"Collection(name={self.name})"
Expand Down
18 changes: 14 additions & 4 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ class QueryResult(TypedDict):
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]

# Typed dictionary describing the parameters used for a collection's HNSW index.
# NOTE(review): all three keys are declared as required with Optional values, so
# static type checkers will expect every key to be present (possibly as None) --
# confirm this is intended rather than a total=False TypedDict.
class HnswIndexParams(TypedDict):
# Distance space for the index; one of "l2", "cosine" or "ip"
# ("ip" is presumably inner product -- TODO confirm against the index backend).
space: Optional[Literal["l2", "cosine", "ip"]]
# presumably the HNSW "ef" search-breadth setting -- TODO confirm
ef: Optional[int]
# presumably the HNSW "M" graph-connectivity setting -- TODO confirm
M: Optional[int]

class IndexMetadata(TypedDict):
dimensionality: int
elements: int
time_created: float

index_params: HnswIndexParams

class EmbeddingFunction(Protocol):
def __call__(self, texts: Documents) -> Embeddings:
Expand Down Expand Up @@ -169,10 +173,14 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument:
if not isinstance(where_document, dict):
raise ValueError(f"Expected where document to be a dictionary, got {where_document}")
if len(where_document) != 1:
raise ValueError(f"Epected where document to have exactly one operator, got {where_document}")
raise ValueError(
f"Epected where document to have exactly one operator, got {where_document}"
)
for operator, operand in where_document.items():
if operator not in ["$contains", "$and", "$or"]:
raise ValueError(f"Expected where document operator to be one of $contains, $and, $or, got {operator}")
raise ValueError(
f"Expected where document operator to be one of $contains, $and, $or, got {operator}"
)
if operator == "$and" or operator == "$or":
if not isinstance(operand, list):
raise ValueError(
Expand Down Expand Up @@ -205,5 +213,7 @@ def validate_include(include: Include, allow_distances: bool) -> Include:
if allow_distances:
allowed_values.append("distances")
if item not in allowed_values:
raise ValueError(f"Expected include item to be one of {', '.join(allowed_values)}, got {item}")
raise ValueError(
f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
)
return include
29 changes: 26 additions & 3 deletions chromadb/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from typing import Dict, List, Sequence, Optional, Tuple
from uuid import UUID
import numpy.typing as npt
from chromadb.api.types import Embeddings, Documents, IDs, Metadatas, Where, WhereDocument
from chromadb.api.types import (
Embeddings,
Documents,
IDs,
Metadatas,
Where,
WhereDocument,
HnswIndexParams,
)


class DB(ABC):
Expand All @@ -12,7 +20,11 @@ def __init__(self):

@abstractmethod
def create_collection(
self, name: str, metadata: Optional[Dict] = None, get_or_create: bool = False
self,
name: str,
metadata: Optional[Dict] = None,
get_or_create: bool = False,
index_params: Optional[HnswIndexParams] = None,
) -> Sequence:
pass

Expand All @@ -38,6 +50,10 @@ def delete_collection(self, name: str):
def get_collection_uuid_from_name(self, collection_name: str) -> str:
pass

@abstractmethod
def get_fields_from_collection_name(self, collection_name: str, fields: List[str]) -> List:
pass

@abstractmethod
def add(
self,
Expand All @@ -46,11 +62,18 @@ def add(
metadatas: Optional[Metadatas],
documents: Optional[Documents],
ids: List[UUID],
index_params: Optional[HnswIndexParams] = None,
) -> List[UUID]:
pass

@abstractmethod
def add_incremental(self, collection_uuid: str, ids: List[UUID], embeddings: Embeddings):
def add_incremental(
self,
collection_uuid: str,
ids: List[UUID],
embeddings: Embeddings,
index_params: Optional[HnswIndexParams] = None,
):
pass

@abstractmethod
Expand Down
Loading