Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cohere[patch]: Add additional kwargs support for Cohere SDK params #19533

Merged
37 changes: 19 additions & 18 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def get_role(message: BaseMessage) -> str:

def get_cohere_chat_request(
messages: List[BaseMessage],
*,
connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
**kwargs: Any
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.

Expand All @@ -60,19 +58,22 @@ def get_cohere_chat_request(
Returns:
The request for the Cohere chat API.
"""
documents = (
None
if "source_documents" not in kwargs
else [
{
"snippet": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(kwargs["source_documents"])
]
)
kwargs.pop("source_documents", None)
maybe_connectors = connectors if documents is None else None
additional_kwargs = messages[-1].additional_kwargs

# cohere SDK will fail loudly if both connectors and documents are provided
documents = [
{
"snippet": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(additional_kwargs.get("documents", []))
]
if not documents:
documents = None

connectors = [{"id": connector} for connector in additional_kwargs.get("connectors", [])]
if not connectors:
connectors = None

# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality
Expand All @@ -85,10 +86,10 @@ def get_cohere_chat_request(
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
],
**additional_kwargs, # params below will overwrite these values
"documents": documents,
"connectors": maybe_connectors,
"connectors": connectors,
"prompt_truncation": prompt_truncation,
giannis2two marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
}

return {k: v for k, v in req.items() if v is not None}
Expand Down