Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added **kwargs for embedding funcs #7664

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions langchain/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _invocation_params(self) -> Dict:
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
def _get_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None, **kwargs
) -> List[List[float]]:
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
try:
Expand Down Expand Up @@ -329,6 +329,7 @@ def _get_len_safe_embeddings(
self,
input=tokens[i : i + _chunk_size],
**self._invocation_params,
**kwargs
)
batched_embeddings += [r["embedding"] for r in response["data"]]

Expand Down Expand Up @@ -424,7 +425,7 @@ async def _aget_len_safe_embeddings(

return embeddings

def _embedding_func(self, text: str, *, engine: str) -> List[float]:
def _embedding_func(self, text: str, *, engine: str, **kwargs) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# handle large input text
if len(text) > self.embedding_ctx_length:
Expand All @@ -438,6 +439,7 @@ def _embedding_func(self, text: str, *, engine: str) -> List[float]:
self,
input=[text],
**self._invocation_params,
**kwargs
)[
"data"
][0]["embedding"]
Expand All @@ -461,7 +463,7 @@ async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
)["data"][0]["embedding"]

def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
self, texts: List[str], chunk_size: Optional[int] = 0, **kwargs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this signature is defined on base embeddings class, probably don't want to break the interface just here and im not sure we should add it to the base class.

what if we added something like this instead? https://github.com/hwchase17/langchain/blob/c7b687e944883df972cabdf00064112587306daf/langchain/llms/openai.py#L137

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that I need to also pass a parameter model. If I place stuff in model_kwargs (in other places where I can)

llm = AzureOpenAI(deployment_name="text-davinci-003",model_name="text-davinci-003", model_kwargs={ "user" : "JL", "model" : "text-davinci-003" })
__root__
  Parameters {'model'} should be specified explicitly. Instead they were passed in as part of `model_kwargs` parameter. (type=value_error)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for embeddings you can already pass model in directly, no?

OpenAIEmbeddings(model=...)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes but I need to place it also on the request body because that is parsed by our custom proxy 😢

) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.

Expand All @@ -475,7 +477,7 @@ def embed_documents(
"""
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
return self._get_len_safe_embeddings(texts, engine=self.deployment)
return self._get_len_safe_embeddings(texts, engine=self.deployment, **kwargs)

async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
Expand All @@ -494,7 +496,7 @@ async def aembed_documents(
# than the maximum context and use length-safe embedding function.
return await self._aget_len_safe_embeddings(texts, engine=self.deployment)

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str, **kwargs) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.

Args:
Expand All @@ -503,7 +505,7 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
embedding = self._embedding_func(text, engine=self.deployment)
embedding = self._embedding_func(text, engine=self.deployment, **kwargs)
return embedding

async def aembed_query(self, text: str) -> List[float]:
Expand Down
Loading