Skip to content

Commit

Permalink
cohere[patch]: Add cohere tools agent (#19602)
Browse files Browse the repository at this point in the history
**Description**: Adds a cohere tools agent and related notebook.

---------

Co-authored-by: BeatrixCohere <[email protected]>
Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
3 people authored Mar 28, 2024
1 parent 5c41f40 commit 3685f8c
Show file tree
Hide file tree
Showing 10 changed files with 510 additions and 20 deletions.
2 changes: 2 additions & 0 deletions libs/partners/cohere/langchain_cohere/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from langchain_cohere.chat_models import ChatCohere
from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.embeddings import CohereEmbeddings
from langchain_cohere.rag_retrievers import CohereRagRetriever
from langchain_cohere.rerank import CohereRerank
Expand All @@ -9,4 +10,5 @@
"CohereEmbeddings",
"CohereRagRetriever",
"CohereRerank",
"create_cohere_tools_agent",
]
97 changes: 84 additions & 13 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
import json
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
)

from cohere.types import NonStreamedChatResponse, ToolCall
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
Expand All @@ -18,7 +31,11 @@
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool

from langchain_cohere.cohere_agent import _format_to_cohere_tools
from langchain_cohere.llms import BaseCohere


Expand Down Expand Up @@ -143,6 +160,14 @@ def _default_params(self) -> Dict[str, Any]:
}
return {k: v for k, v in base_params.items() if v is not None}

def bind_tools(
    self,
    tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind ``tools`` to this chat model after converting them to Cohere format.

    Args:
        tools: BaseTool instances, JSON schema dicts, or BaseModel types; each
            one is converted by ``_format_to_cohere_tools``.
        **kwargs: Additional keyword arguments forwarded unchanged to ``bind``.

    Returns:
        The runnable produced by ``bind`` with the converted tools attached
        under the ``tools`` keyword.
    """
    cohere_tools = _format_to_cohere_tools(tools)
    return super().bind(tools=cohere_tools, **kwargs)

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
Expand All @@ -169,6 +194,14 @@ def _stream(
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
elif data.event_type == "stream-end":
generation_info = self._get_generation_info(data.response)
yield ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs=generation_info
),
generation_info=generation_info,
)

async def _astream(
self,
Expand All @@ -191,16 +224,34 @@ async def _astream(
if run_manager:
await run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk

def _get_generation_info(self, response: Any) -> Dict[str, Any]:
elif data.event_type == "stream-end":
generation_info = self._get_generation_info(data.response)
yield ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs=generation_info
),
generation_info=generation_info,
)

def _get_generation_info(self, response: NonStreamedChatResponse) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""
return {
generation_info = {
"documents": response.documents,
"citations": response.citations,
"search_results": response.search_results,
"search_queries": response.search_queries,
"token_count": response.token_count,
"is_search_required": response.is_search_required,
"generation_id": response.generation_id,
}
if response.tool_calls:
# Only populate tool_calls when 1) present on the response and
# 2) has one or more calls.
generation_info["tool_calls"] = _format_cohere_tool_calls(
response.generation_id or "", response.tool_calls
)
if hasattr(response, "token_count"):
generation_info["token_count"] = response.token_count
return generation_info

def _generate(
self,
Expand All @@ -218,10 +269,8 @@ def _generate(
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request)

message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response)
generation_info = self._get_generation_info(response)
message = AIMessage(content=response.text, additional_kwargs=generation_info)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
Expand All @@ -244,10 +293,8 @@ async def _agenerate(
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request)

message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response)
generation_info = self._get_generation_info(response)
message = AIMessage(content=response.text, additional_kwargs=generation_info)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
Expand All @@ -257,3 +304,27 @@ async def _agenerate(
def get_num_tokens(self, text: str) -> int:
    """Return the number of tokens the client's tokenizer yields for ``text``."""
    tokenized = self.client.tokenize(text)
    return len(tokenized.tokens)


def _format_cohere_tool_calls(
    generation_id: str, tool_calls: Optional[List[ToolCall]] = None
) -> List[Dict]:
    """
    Translate Cohere tool calls into the tool call dicts used elsewhere in Langchain.

    Each resulting dict uses ``generation_id`` as its ``id``, carries the call's
    name, and stores the call's parameters as a JSON-encoded string under
    ``function.arguments``. A missing or empty ``tool_calls`` gives ``[]``.
    """
    return [
        {
            "id": generation_id,
            "function": {
                "name": call.name,
                "arguments": json.dumps(call.parameters),
            },
            "type": "function",
        }
        for call in (tool_calls or [])
    ]
168 changes: 168 additions & 0 deletions libs/partners/cohere/langchain_cohere/cohere_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from typing import Any, Dict, List, Sequence, Tuple, Type, Union

from cohere.types import Tool, ToolParameterDefinitionsValue
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.outputs import Generation
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function


def create_cohere_tools_agent(
    llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
) -> Runnable:
    """Assemble a tools agent runnable around ``llm``.

    The returned runnable reads the ``input`` and ``intermediate_steps`` keys
    of whatever it is invoked with. It formats ``prompt`` with that input and
    an empty ``agent_scratchpad``, binds the Cohere-formatted ``tools`` and
    the tool results derived from the intermediate steps onto ``llm``, and
    parses the model output with ``_CohereToolsAgentOutputParser``.

    Args:
        llm: Model to bind tools and tool results onto.
        tools: Tools made available to the model.
        prompt: Chat prompt formatted with ``input`` and ``agent_scratchpad``.

    Returns:
        The composed agent runnable.
    """

    def llm_with_tools(input_: Dict) -> Runnable:
        # Empty collections are bound as None rather than as empty lists.
        results = input_["tool_results"]
        available = input_["tools"]
        bound_llm = llm.bind(
            tools=available if len(available) > 0 else None,
            tool_results=results if len(results) > 0 else None,
        )
        # Only the formatted prompt messages are forwarded to the bound model.
        return RunnableLambda(lambda x: x["input"]) | bound_llm

    prepare_inputs = RunnablePassthrough.assign(
        # Intermediate steps are in tool results.
        # Edit below to change the prompt parameters.
        input=lambda x: prompt.format_messages(
            input=x["input"], agent_scratchpad=[]
        ),
        tools=lambda x: _format_to_cohere_tools(tools),
        tool_results=lambda x: _format_to_cohere_tools_messages(
            x["intermediate_steps"]
        ),
    )
    return prepare_inputs | llm_with_tools | _CohereToolsAgentOutputParser()


def _format_to_cohere_tools(
    tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
) -> List[Dict[str, Any]]:
    """Convert every entry of ``tools`` to a Cohere tool dict, keeping order."""
    converted: List[Dict[str, Any]] = []
    for candidate in tools:
        converted.append(_convert_to_cohere_tool(candidate))
    return converted


def _format_to_cohere_tools_messages(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> list:
    """Convert (AgentAction, tool output) tuples into tool messages.

    Args:
        intermediate_steps: Pairs of the action that was taken and the
            observation its tool produced.

    Returns:
        One dict per step, in order. Each has a ``call`` entry holding the
        action's tool name and tool input, and an ``outputs`` list with a
        single ``{"answer": observation}`` item. No steps gives ``[]``.
    """
    # A comprehension already yields [] for no steps, so no separate length
    # guard is needed (and any iterable of pairs is accepted).
    return [
        {
            "call": {
                "name": agent_action.tool,
                "parameters": agent_action.tool_input,
            },
            "outputs": [{"answer": observation}],
        }
        for agent_action, observation in intermediate_steps
    ]


def _convert_to_cohere_tool(
    tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
) -> Dict[str, Any]:
    """
    Convert a BaseTool instance, JSON schema dict, or BaseModel type to a Cohere tool.

    Args:
        tool: One of
            - a ``BaseTool`` instance (name, description and ``args`` are used),
            - a JSON schema dict with ``title``, ``description`` and
              ``properties`` keys,
            - a ``BaseModel`` subclass (converted via
              ``convert_to_openai_function`` first).

    Returns:
        The ``Tool`` rendered as a dict via ``.dict()``.

    Raises:
        ValueError: If a dict lacks one of the required keys, or if ``tool``
            is none of the supported kinds.
    """
    if isinstance(tool, BaseTool):
        return Tool(
            name=tool.name,
            description=tool.description,
            parameter_definitions={
                param_name: ToolParameterDefinitionsValue(
                    description=param_definition.get("description"),
                    type=param_definition.get("type"),
                    # A parameter without a declared default is treated as required.
                    required="default" not in param_definition,
                )
                for param_name, param_definition in tool.args.items()
            },
        ).dict()
    elif isinstance(tool, dict):
        if not all(k in tool for k in ("title", "description", "properties")):
            raise ValueError(
                "Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type."  # noqa: E501
            )
        return Tool(
            name=tool.get("title"),
            description=tool.get("description"),
            parameter_definitions={
                param_name: ToolParameterDefinitionsValue(
                    description=param_definition.get("description"),
                    type=param_definition.get("type"),
                    # A parameter without a declared default is treated as required.
                    required="default" not in param_definition,
                )
                for param_name, param_definition in tool.get("properties", {}).items()
            },
        ).dict()
    # issubclass() raises TypeError when its first argument is not a class, so
    # check for a class first; other objects then reach the ValueError below.
    elif isinstance(tool, type) and issubclass(tool, BaseModel):
        as_json_schema_function = convert_to_openai_function(tool)
        parameters = as_json_schema_function.get("parameters", {})
        properties = parameters.get("properties", {})
        required_names = parameters.get("required", [])
        return Tool(
            name=as_json_schema_function.get("name"),
            description=as_json_schema_function.get(
                # The Cohere API requires the description field.
                "description",
                as_json_schema_function.get("name"),
            ),
            parameter_definitions={
                param_name: ToolParameterDefinitionsValue(
                    description=param_definition.get("description"),
                    type=param_definition.get("type"),
                    required=param_name in required_names,
                )
                for param_name, param_definition in properties.items()
            },
        ).dict()
    else:
        raise ValueError(
            f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type."  # noqa: E501
        )


class _CohereToolsAgentOutputParser(
    BaseOutputParser[Union[List[AgentAction], AgentFinish]]
):
    """Parses a message into agent actions/finish."""

    def parse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[List[AgentAction], AgentFinish]:
        """Turn the first generation into agent actions or a finish.

        Args:
            result: Generations to parse; only the first one is inspected and
                it must be a ``ChatGeneration``.
            partial: Accepted for interface compatibility; not used here.

        Returns:
            A list with one ``AgentAction`` per entry of the message's
            ``tool_calls`` additional kwarg when that entry is present and
            non-empty; otherwise an ``AgentFinish`` carrying the message
            content under ``text`` and its additional kwargs under
            ``additional_info``.

        Raises:
            ValueError: If the first generation is not a ``ChatGeneration``.
        """
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            # Report the offending element's type, not the enclosing list's.
            raise ValueError(f"Expected ChatGeneration, got {type(generation)}")

        message = generation.message
        # The "tool_calls" key may be absent altogether (e.g. on a final
        # answer), so look it up without raising KeyError.
        tool_calls = message.additional_kwargs.get("tool_calls")
        if tool_calls:
            actions = []
            for tool_call in tool_calls:
                function = tool_call.get("function", {})
                name = function.get("name")
                actions.append(
                    AgentAction(
                        tool=name,
                        tool_input=function.get("arguments"),
                        log=name,
                    )
                )
            return actions

        return AgentFinish(
            return_values={
                "text": message.content,
                "additional_info": message.additional_kwargs,
            },
            log="",
        )

    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
        """Always raise: this parser works on messages, not plain text."""
        raise ValueError("Can only parse messages")
1 change: 1 addition & 0 deletions libs/partners/cohere/langchain_cohere/rag_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _get_docs(response: Any) -> List[Document]:
docs = (
[]
if "documents" not in response.generation_info
or len(response.generation_info["documents"]) == 0
else [
Document(page_content=doc["snippet"], metadata=doc)
for doc in response.generation_info["documents"]
Expand Down
9 changes: 4 additions & 5 deletions libs/partners/cohere/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions libs/partners/cohere/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-cohere"
version = "0.1.0rc1"
version = "0.1.0rc2"
description = "An integration package connecting Cohere and LangChain"
authors = []
readme = "README.md"
Expand All @@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.32"
cohere = "^5.1.1"
cohere = "^5.1.4"

[tool.poetry.group.test]
optional = true
Expand Down
Loading

0 comments on commit 3685f8c

Please sign in to comment.