feat/few updates #13

Merged: 2 commits, merged on Jan 6, 2025
Changes from all commits
58 changes: 33 additions & 25 deletions README.md
@@ -5,7 +5,7 @@ This module leverages chromaDB to cache embeddings for reusability.
It helps store and fetch embeddings from chromaDB.

In a nutshell:
- it creates a vector store with the model name / provider as the collection name
- it encodes any sentence with the specified embedding function via the ``encode()`` method
- among the sentences passed to ``encode``, it only calls the model (or the API) for those whose embeddings are not already available in chromaDB

@@ -15,48 +15,56 @@ Here are the installation steps:
- 1) If you haven't already, clone this repository.
- 2) Activate your Python environment (or shell).
- 3) From the root of this repository, run ``pip install .`` to install it as a package.

⚠️ if you run into an error due to fairseq, this is due to an incompatibility between fairseq, Windows 11, and Python > 3.9. Either downgrade your Python version, or install this fix in your environment:
```bash
pip install git+https://github.com/liyaodev/fairseq.git
```
- 4) [Optional] To make the package lighter, the sentence-transformers dependencies are optional. If you plan to use models from HuggingFace 🤗 through the sentence-transformers package, use ``pip install ".[st]"`` instead.

***TODO***: Add instructions for installation as a PyPI package.

## Usage

Fast & simple usage:

```py
from chromacache import ChromaCache
from chromacache.embedding_functions import OpenAIEmbeddingFunction  # or any embedding function available

cc = ChromaCache(OpenAIEmbeddingFunction())
embeddings = cc.encode(["my sentence", "my other sentence"])
```
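
Because embeddings are persisted in chromaDB, repeated calls to ``encode`` only hit the model (or the API) for sentences that are not cached yet. A minimal sketch of this behaviour (sentence values are illustrative):

```py
from chromacache import ChromaCache
from chromacache.embedding_functions import OpenAIEmbeddingFunction

cc = ChromaCache(OpenAIEmbeddingFunction())
cc.encode(["my sentence"])  # embedded via the API, then stored in chromaDB
cc.encode(["my sentence", "a new sentence"])  # only "a new sentence" is sent to the API
```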

Example usage with huggingface's ``datasets`` package:
```py
import datasets

from chromacache import ChromaCache
from chromacache.embedding_functions import AzureEmbeddingFunction  # or any embedding function available

emb_function = AzureEmbeddingFunction(
    model_name="text-embedding-3-large",
    dimensions=768,
    max_requests_per_minute=300,
)
cc = ChromaCache(
    emb_function,
    batch_size=16,
    path_to_chromadb="path/to/my/chromadb/folder",
    max_token_length=8191,  # adapt to the model you use
)
# let's assume this returns a 'Dataset' object with a 'text' column we want to embed
mydataset = datasets.load_dataset("PathToMyDataset/OnHuggingFace")
mydataset = mydataset.add_column(
    "embeddings", cc.encode(mydataset["text"])
)
```

## Extra arguments

The ``ChromaCache`` supports extra arguments (a usage sketch follows the list):
- ***batch_size***: int = 32, the batch size at which sentences are processed. If the provider's API rejects requests for being too large, decrease this value.
- ***save_embbedings***: bool = True, whether or not the embeddings should be saved.
- ***path_to_chromadb***: str = "./ChromaDB", where the chromadb database should be stored.
- ***max_token_length***: int = 8191, texts longer than this number of tokens are truncated to avoid API errors.
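
For instance, to embed with smaller batches and without persisting anything to disk (a minimal sketch; the keyword names are exactly those listed above):

```py
from chromacache import ChromaCache
from chromacache.embedding_functions import OpenAIEmbeddingFunction

cc = ChromaCache(
    OpenAIEmbeddingFunction(),
    batch_size=8,  # smaller requests for rate-limited APIs
    save_embbedings=False,  # do not persist the embeddings
    max_token_length=2048,  # truncate long texts before calling the API
)
embeddings = cc.encode(["some text", "some other text"])
```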

Moreover, all [capabilities of the chromaDB collections](https://docs.trychroma.com/reference/py-collection) can be leveraged directly through the ``collection`` attribute of the ChromaCache. Hence, you can query, delete, or update any collection.
For example, to query the collection for the 5 most relevant documents:
```py
cc = ChromaCache(VoyageAIEmbeddingFunction("voyage-code-2"))
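# a hedged sketch: cc.collection is a regular chromaDB collection,
# so we can ask it for the 5 documents most similar to a query
results = cc.collection.query(query_texts=["my query"], n_results=5)
print(results["documents"])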
```
37 changes: 30 additions & 7 deletions chromacache/chromacache.py
@@ -9,10 +9,12 @@

import chromadb
from chromadb import EmbeddingFunction
import tiktoken
from tqdm import tqdm


class ChromaCache:
"""Handles the mecanics of producing and saving embeddings in chromaDB
"""Handles the mechanics of producing and saving embeddings in chromaDB
    It needs an embedding function, as described in chromaDB's docs: https://docs.trychroma.com/embeddings
    This embedding function specifies how embeddings are obtained, from a model or an API.
"""
@@ -22,8 +24,8 @@ def __init__(
embedding_function: EmbeddingFunction = None,
batch_size: int = 32,
save_embbedings: bool = True,
path_to_chromadb="./ChromaDB",
**kwargs,
path_to_chromadb: str = "./ChromaDB",
max_token_length: int = 8191,
):
self.batch_size = batch_size
self.save_embbeddings = save_embbedings
@@ -36,19 +38,24 @@

# setup the chromaDB collection
self.client = chromadb.PersistentClient(path=path_to_chromadb)
        collection_name = embedding_function.collection_name.replace("/", "-")[:63]
self.collection = self.client.get_or_create_collection(
name=collection_name,
embedding_function=self.embedding_function,
metadata={"hnsw:space": "cosine"},
)

self.max_token_length = max_token_length
# Use tiktoken to compute token length
        # As we may not know the exact tokenizer used for the model, we generically use the ada-002 one
self.tokenizer = tiktoken.get_encoding("cl100k_base")

@staticmethod
def list_embedding_functions():
# TODO : function to list the available embedding functions
raise NotImplementedError

    def encode(self, sentences: Documents) -> Embeddings:
"""Encodes the provided sentences et gets their embeddings
using the EmbeddingFunction that has been set.
It works like so :
@@ -58,7 +65,6 @@ def encode(self, sentences: Documents, **kwargs) -> Embeddings:

Args:
sentences (Documents): the list of strings that must be encoded

Returns:
            Embeddings: the list of embeddings corresponding to the list of strings
@@ -73,8 +79,11 @@ def encode(self, sentences: Documents, **kwargs) -> Embeddings:
# use a dict to store a mapping of {sentence: embedding}
# we have to do this because collection.get() returns embeddings in a random order...
sent_emb_mapping = {}

        for i in tqdm(range(0, len(unique_sentences), self.batch_size)):
# create batch and truncate
batch_sentences = unique_sentences[i : i + self.batch_size]
batch_sentences = self.truncate_documents(batch_sentences)
# check if we have the embedding in chroma
sentences_in_chroma = self.collection.get(
ids=batch_sentences, include=["documents", "embeddings"]
@@ -119,3 +128,17 @@
)
# return embeddings in correct order
return [sent_emb_mapping[s] for s in sentences]

def truncate_documents(self, sentences: Documents) -> Documents:
"""Truncates the sentences considering the max context window of the model

Args:
            sentences (Documents): a list of sentences (documents)

Returns:
Documents: the truncated documents
"""
return [
self.tokenizer.decode(self.tokenizer.encode(s)[: self.max_token_length])
for s in sentences
]
68 changes: 12 additions & 56 deletions chromacache/embedding_functions/AbstractEmbeddingFunction.py
@@ -8,77 +8,33 @@

from abc import ABC, abstractmethod

from chromadb import Documents, EmbeddingFunction, Embeddings


class AbstractEmbeddingFunction(EmbeddingFunction, ABC):  # type: ignore --> missing typing in chromaDB
"""Base class for all embedding functions"""

def __init__(
self,
        model_name: str,
    ) -> None:
        self.model_name = model_name

@property
@abstractmethod
    def collection_name(self) -> str:
        """Used as the collection name by chroma cache. Must lead to a unique name per model."""

    def __call__(self, documents: Documents) -> Embeddings:
        """Encodes the documents

Args:
documents (Documents): List of documents

Returns:
Embeddings: the encoded sentences
"""
        return self.encode_documents(documents)

@abstractmethod
def encode_documents(self, documents: Documents) -> Embeddings:
@@ -94,4 +50,4 @@ def encode_documents(self, documents: Documents) -> Embeddings:
Returns:
Embeddings: list of embeddings
"""
        pass
64 changes: 39 additions & 25 deletions chromacache/embedding_functions/CohereEmbeddingFunction.py
@@ -1,36 +1,50 @@
import os
from typing import Literal

from chromadb import Documents, Embeddings
from dotenv import load_dotenv
from litellm import embedding

from .LiteLLMEmbeddingFunction import LiteLLMEmbeddingFunction

load_dotenv()


class CohereEmbeddingFunction(LiteLLMEmbeddingFunction):
def __init__(
self,
model_name: str = "Cohere/Cohere-embed-multilingual-light-v3.0",
max_token_length: int = 512,
model_name: str = "embed-multilingual-light-v3.0",
input_type: Literal[
"search_document", "search_query", "classification", "clustering"
] = "search_document",
dimensions: int | None = None,
max_requests_per_minute: int = 2000, # Prod, or 100 for trial
):

api_key = os.getenv("COHERE_API_KEY", None)
if api_key is None:
raise ValueError(
"Please make sure 'COHERE_API_KEY' is setup as an environment variable"
)

LiteLLMEmbeddingFunction.__init__(
self, model_name, dimensions, max_requests_per_minute
)
self.input_type = input_type

@property
def api_key_name(self):
return "COHERE_API_KEY"

@property
def litellm_provider_prefix(self):
return "cohere"

def encode_documents(self, documents: Documents) -> Embeddings:
"""Takes a list of strings and returns the corresponding embedding

Args:
documents (Documents): list of documents (strings)

Returns:
Embeddings: list of embeddings
"""
        # replace empty strings to avoid errors with APIs
documents = [d if d else " " for d in documents]
response = embedding(
model=f"{self.litellm_provider_prefix}/{self.model_name}",
input=documents,
input_type=self.input_type,
dimensions=self.dimensions,
)

return [resp["embedding"] for resp in response.data] # type: ignore --> missing typing for response.data