Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: FastEmbed embedding function support #1986

Closed
wants to merge 14 commits into from
13 changes: 13 additions & 0 deletions chromadb/test/ef/test_fastembed_ef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from chromadb.utils.embedding_functions import FastEmbedEmbeddingFunction

# Skip these tests when the optional 'fastembed' package is not installed.
fastembed = pytest.importorskip("fastembed", reason="fastembed not installed")


def test_fastembed() -> None:
    """Smoke test: embedding two documents yields two 384-dim vectors."""
    embedding_fn = FastEmbedEmbeddingFunction(model_name="BAAI/bge-small-en-v1.5")
    documents = ["Here is an article about llamas...", "this is another article"]
    vectors = embedding_fn(documents)
    assert len(vectors) == 2
    assert len(vectors[0]) == 384
91 changes: 80 additions & 11 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,9 +743,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:


class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
def __init__(
self, api_key: str = "", api_url = "https://infer.roboflow.com"
) -> None:
def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None:
"""
Create a RoboflowEmbeddingFunction.

Expand All @@ -757,7 +755,7 @@ def __init__(
api_key = os.environ.get("ROBOFLOW_API_KEY")

self._api_url = api_url
self._api_key = api_key
self._api_key = api_key

try:
self._PILImage = importlib.import_module("PIL.Image")
Expand Down Expand Up @@ -789,10 +787,10 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
json=infer_clip_payload,
)

result = res.json()['embeddings']
result = res.json()["embeddings"]

embeddings.append(result[0])

elif is_document(item):
infer_clip_payload = {
"text": input,
Expand All @@ -803,13 +801,13 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
json=infer_clip_payload,
)

result = res.json()['embeddings']
result = res.json()["embeddings"]

embeddings.append(result[0])

return embeddings


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
Expand Down Expand Up @@ -909,7 +907,8 @@ def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore
)

class ChromaLangchainEmbeddingFunction(
LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore
LangchainEmbeddings,
EmbeddingFunction[Union[Documents, Images]], # type: ignore
):
"""
This class is used as bridge between langchain embedding functions and custom chroma embedding functions.
Expand Down Expand Up @@ -962,7 +961,7 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore

return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
Expand Down Expand Up @@ -1018,7 +1017,77 @@ def __call__(self, input: Documents) -> Embeddings:
],
)



class FastEmbedEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to generate embeddings for a list of texts using FastEmbed - https://qdrant.github.io/fastembed/.
    Find the list of supported models at https://qdrant.github.io/fastembed/examples/Supported_Models/.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
        batch_size: int = 256,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        parallel: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize fastembed.TextEmbedding

        Args:
            model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
            batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
                Defaults to 256.
            cache_dir (str, optional): The path to the model cache directory.\
                Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
            threads (int, optional): The number of threads single onnxruntime session can use.
            parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for offline encoding of large datasets.\
                If `0`, use all available cores.\
                If `None`, don't use data-parallel processing, use default onnxruntime threading instead.\
                Defaults to None.
            **kwargs: Additional options to pass to fastembed.TextEmbedding

        Raises:
            ValueError: If the 'fastembed' package is not installed, or if the\
                model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """
        try:
            from fastembed import TextEmbedding
        except ImportError as e:
            # Raise ValueError (not ImportError) for consistency with the other
            # optional-dependency embedding functions in this module; chain the
            # original ImportError so the root cause stays visible.
            raise ValueError(
                "The 'fastembed' package is not installed. Please install it with `pip install fastembed`"
            ) from e
        # batch_size/parallel are encoding-time options; keep them for __call__.
        self._batch_size = batch_size
        self._parallel = parallel
        self._model = TextEmbedding(
            model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> fastembed_ef = FastEmbedEmbeddingFunction(model_name="sentence-transformers/all-MiniLM-L6-v2")
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = fastembed_ef(texts)
        """
        # TextEmbedding.embed returns a lazy generator of numpy arrays; convert
        # each vector to a plain list so the result matches the Embeddings type.
        embeddings = self._model.embed(
            input, batch_size=self._batch_size, parallel=self._parallel
        )
        return cast(
            Embeddings,
            [embedding.tolist() for embedding in embeddings],
        )


# List of all classes in this module
_classes = [
name
Expand Down
Loading