From a14f15854ac1ebf45a1bee064d86caab8ab32187 Mon Sep 17 00:00:00 2001 From: Nayjest Date: Thu, 9 Nov 2023 00:22:58 +0200 Subject: [PATCH] support of OpenAI package v1.X.X for utils.OpenAIEmbeddingFunction, deployment_id parameter for openai v0.X.X (#1338) ## Description of changes - Add support of OpenAI package v1.X.X for utils.OpenAIEmbeddingFunction - Add Azure OpenAI Deployment ID parameter for openai v0.X.X lib in utils.OpenAIEmbeddingFunction ## Test plan *How are these changes tested?* Tested as dependency of https://github.com/Nayjest/ai-microcore with Azure & openai packages v0.28.1 & v1.0.1, v1.1.0 --- chromadb/utils/embedding_functions.py | 52 ++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 5e38936ef6c..141964465a5 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -85,6 +85,7 @@ def __init__( api_base: Optional[str] = None, api_type: Optional[str] = None, api_version: Optional[str] = None, + deployment_id: Optional[str] = None, ): """ Initialize the OpenAIEmbeddingFunction. @@ -103,6 +104,7 @@ def __init__( api_version (str, optional): The api version for the API. If not provided, it will use the api version for the OpenAI API. This can be used to point to a different deployment, such as an Azure deployment. + deployment_id (str, optional): Deployment ID for Azure OpenAI. """ try: @@ -126,27 +128,61 @@ def __init__( if api_version is not None: openai.api_version = api_version + self._api_type = api_type if api_type is not None: openai.api_type = api_type if organization_id is not None: openai.organization = organization_id - self._client = openai.Embedding + self._v1 = openai.__version__.startswith('1.') + if self._v1: + if api_type == "azure": + self._client = openai.AzureOpenAI( + api_key=api_key, + api_version=api_version, + azure_endpoint=api_base + ).embeddings + else: + self._client = openai.OpenAI( + api_key=api_key, + base_url=api_base + ).embeddings + else: + self._client = openai.Embedding self._model_name = model_name + self._deployment_id = deployment_id def __call__(self, input: Documents) -> Embeddings: # replace newlines, which can negatively affect performance. input = [t.replace("\n", " ") for t in input] # Call the OpenAI Embedding API - embeddings = self._client.create(input=input, engine=self._model_name)["data"] - - # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore - - # Return just the embeddings - return [result["embedding"] for result in sorted_embeddings] + if self._v1: + embeddings = self._client.create( + input=input, + model=self._deployment_id or self._model_name + ).data + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e.index) # type: ignore + + # Return just the embeddings + return [result.embedding for result in sorted_embeddings] + else: + if self._api_type == "azure": + embeddings = self._client.create( + input=input, + engine=self._deployment_id or self._model_name + )["data"] + else: + embeddings = self._client.create(input=input, model=self._model_name)["data"] + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore + + # Return just the embeddings + return [result["embedding"] for result in sorted_embeddings] class CohereEmbeddingFunction(EmbeddingFunction[Documents]):