From fcec04942b85f32961ba43eb1d392c00b0035cb2 Mon Sep 17 00:00:00 2001
From: FrancescoSaverioZuppichini
Date: Thu, 13 Jul 2023 17:14:40 +0200
Subject: [PATCH] added **kwargs for embedding funcs

---
 langchain/embeddings/openai.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index c7ebf277e8719..c8311f9a30b40 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -276,7 +276,7 @@ def _invocation_params(self) -> Dict:
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
     def _get_len_safe_embeddings(
-        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None, **kwargs
     ) -> List[List[float]]:
         embeddings: List[List[float]] = [[] for _ in range(len(texts))]
         try:
@@ -329,6 +329,7 @@ def _get_len_safe_embeddings(
                 self,
                 input=tokens[i : i + _chunk_size],
                 **self._invocation_params,
+                **kwargs
             )
             batched_embeddings += [r["embedding"] for r in response["data"]]
 
@@ -424,7 +425,7 @@ async def _aget_len_safe_embeddings(
 
         return embeddings
 
-    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
+    def _embedding_func(self, text: str, *, engine: str, **kwargs) -> List[float]:
         """Call out to OpenAI's embedding endpoint."""
         # handle large input text
         if len(text) > self.embedding_ctx_length:
@@ -438,6 +439,7 @@ def _embedding_func(self, text: str, *, engine: str) -> List[float]:
                 self,
                 input=[text],
                 **self._invocation_params,
+                **kwargs
             )[
                 "data"
             ][0]["embedding"]
@@ -461,7 +463,7 @@ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
         )["data"][0]["embedding"]
 
     def embed_documents(
-        self, texts: List[str], chunk_size: Optional[int] = 0
+        self, texts: List[str], chunk_size: Optional[int] = 0, **kwargs
     ) -> List[List[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.
 
@@ -475,7 +477,7 @@ def embed_documents(
         """
         # NOTE: to keep things simple, we assume the list may contain texts longer
         # than the maximum context and use length-safe embedding function.
-        return self._get_len_safe_embeddings(texts, engine=self.deployment)
+        return self._get_len_safe_embeddings(texts, engine=self.deployment, **kwargs)
 
     async def aembed_documents(
         self, texts: List[str], chunk_size: Optional[int] = 0
@@ -494,7 +496,7 @@ async def aembed_documents(
         # than the maximum context and use length-safe embedding function.
         return await self._aget_len_safe_embeddings(texts, engine=self.deployment)
 
-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str, **kwargs) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
 
         Args:
@@ -503,7 +505,7 @@ def embed_query(self, text: str) -> List[float]:
         Returns:
             Embedding for the text.
         """
-        embedding = self._embedding_func(text, engine=self.deployment)
+        embedding = self._embedding_func(text, engine=self.deployment, **kwargs)
         return embedding
 
     async def aembed_query(self, text: str) -> List[float]: