Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langchain_cohere:Cohere package misc fixs tool use agent and cohere chat #19705

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 237 additions & 0 deletions libs/partners/cohere/docs/cohere_agent.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 0\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cohere Tools\n",
"\n",
"The following notebook goes over how to use the Cohere tools agent:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prerequisites for this notebook:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: langchain in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.13)\n",
"Requirement already satisfied: langchain-cohere in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.0rc2)\n",
"Requirement already satisfied: PyYAML>=5.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (6.0.1)\n",
"Requirement already satisfied: SQLAlchemy<3,>=1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.0.27)\n",
"Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (3.9.3)\n",
"Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.6.4)\n",
"Requirement already satisfied: jsonpatch<2.0,>=1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.33)\n",
"Requirement already satisfied: langchain-community<0.1,>=0.0.29 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.29)\n",
"Requirement already satisfied: langchain-core<0.2.0,>=0.1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.35)\n",
"Requirement already satisfied: langchain-text-splitters<0.1,>=0.0.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.1)\n",
"Requirement already satisfied: langsmith<0.2.0,>=0.1.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.31)\n",
"Requirement already satisfied: numpy<2,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.24.4)\n",
"Requirement already satisfied: pydantic<3,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.6.4)\n",
"Requirement already satisfied: requests<3,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.31.0)\n",
"Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (8.2.3)\n",
"Requirement already satisfied: cohere<6.0.0,>=5.1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-cohere) (5.1.4)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n",
"Requirement already satisfied: httpx>=0.21.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (0.27.0)\n",
"Requirement already satisfied: typing_extensions>=4.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (4.10.0)\n",
"Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.20.2)\n",
"Requirement already satisfied: typing-inspect<1,>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n",
"Requirement already satisfied: jsonpointer>=1.9 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from jsonpatch<2.0,>=1.33->langchain) (2.4)\n",
"Requirement already satisfied: packaging<24.0,>=23.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.33->langchain) (23.2)\n",
"Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langsmith<0.2.0,>=0.1.17->langchain) (3.9.15)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.16.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (2.16.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2024.2.2)\n",
"Requirement already satisfied: anyio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (4.3.0)\n",
"Requirement already satisfied: httpcore==1.* in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.0.4)\n",
"Requirement already satisfied: sniffio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.3.1)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (0.14.0)\n",
"Requirement already satisfied: mypy-extensions>=0.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n",
"Requirement already satisfied: wikipedia in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (1.4.0)\n",
"Requirement already satisfied: beautifulsoup4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (4.12.3)\n",
"Requirement already satisfied: requests<3.0.0,>=2.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (2.31.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2024.2.2)\n",
"Requirement already satisfied: soupsieve>1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from beautifulsoup4->wikipedia) (2.5)\n"
]
}
],
"source": [
"# install package\n",
"!pip install langchain langchain-cohere\n",
"!pip install wikipedia"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentExecutor\n",
"from langchain.retrievers import WikipediaRetriever\n",
"from langchain.tools.retriever import create_retriever_tool\n",
"from langchain_cohere import create_cohere_tools_agent\n",
"from langchain_cohere.chat_models import ChatCohere\n",
"from langchain_core.prompts import ChatPromptTemplate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we create the prompt template and cohere model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create the prompt\n",
"prompt = ChatPromptTemplate.from_template(\n",
" \"Write all output in capital letters. {input}\"\n",
")\n",
"\n",
"# Create the Cohere chat model\n",
"chat = ChatCohere(cohere_api_key=\"API_KEY\", model=\"command-r\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we use a Wikipedia retrieval tool "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"retriever = WikipediaRetriever()\n",
"retriever_tool = create_retriever_tool(\n",
" retriever,\n",
" \"wikipedia\",\n",
" \"Search for information on Wikipedia\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, create the cohere tool agent and call with the input"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mwikipedia\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'Who founded Cohere?',\n",
" 'text': 'COHERE WAS FOUNDED BY AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.',\n",
" 'additional_info': {'documents': [{'answer': '',\n",
" 'id': 'wikipedia:0:0',\n",
" 'tool_name': 'wikipedia'}],\n",
" 'citations': [ChatCitation(start=22, end=63, text='AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.', document_ids=['wikipedia:0:0'])],\n",
" 'search_results': None,\n",
" 'search_queries': None,\n",
" 'is_search_required': None,\n",
" 'generation_id': '3b7e96be-8aad-4fa0-9ae3-7a38e800c289',\n",
" 'token_count': {'prompt_tokens': 740,\n",
" 'response_tokens': 27,\n",
" 'total_tokens': 767,\n",
" 'billed_tokens': 48}}}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent = create_cohere_tools_agent(\n",
" llm=chat,\n",
" tools=[retriever_tool],\n",
" prompt=prompt,\n",
")\n",
"agent_executor = AgentExecutor(agent=agent, tools=[retriever_tool], verbose=True)\n",
"agent_executor.invoke({\"input\": \"Who founded Cohere?\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
29 changes: 13 additions & 16 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,22 @@ def get_cohere_chat_request(
additional_kwargs = messages[-1].additional_kwargs

# cohere SDK will fail loudly if both connectors and documents are provided
if (
len(additional_kwargs.get("documents", [])) > 0
and documents
and len(documents) > 0
):
if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
raise ValueError(
"Received documents both as a keyword argument and as an prompt additional"
"keywword argument. Please choose only one option."
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
)

formatted_docs = [
{
"text": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(additional_kwargs.get("documents", []))
] or documents
if not formatted_docs:
formatted_docs = None
formatted_docs: Optional[List[Dict[str, Any]]] = None
if additional_kwargs.get("documents"):
formatted_docs = [
{
"text": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(additional_kwargs.get("documents", []))
]
elif documents:
formatted_docs = documents

# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality
Expand Down
45 changes: 32 additions & 13 deletions libs/partners/cohere/langchain_cohere/cohere_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import json
from typing import Any, Dict, List, Sequence, Tuple, Type, Union

from cohere.types import Tool, ToolParameterDefinitionsValue
from cohere.types import (
ChatRequestToolResultsItem,
Tool,
ToolCall,
ToolParameterDefinitionsValue,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
Expand Down Expand Up @@ -30,9 +36,7 @@ def llm_with_tools(input_: Dict) -> Runnable:
RunnablePassthrough.assign(
# Intermediate steps are in tool results.
# Edit below to change the prompt parameters.
input=lambda x: prompt.format_messages(
input=x["input"], agent_scratchpad=[]
),
input=lambda x: prompt.format_messages(**x, agent_scratchpad=[]),
tools=lambda x: _format_to_cohere_tools(tools),
tool_results=lambda x: _format_to_cohere_tools_messages(
x["intermediate_steps"]
Expand All @@ -52,20 +56,35 @@ def _format_to_cohere_tools(

def _format_to_cohere_tools_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> list:
) -> List[Dict[str, Any]]:
"""Convert (AgentAction, tool output) tuples into tool messages."""
if len(intermediate_steps) == 0:
return []
tool_results = []
for agent_action, observation in intermediate_steps:
# agent_action.tool_input can be a dict, serialised dict, or string.
# Cohere API only accepts a dict.
tool_call_parameters: Dict[str, Any]
if isinstance(agent_action.tool_input, dict):
# tool_input is a dict, use as-is.
tool_call_parameters = agent_action.tool_input
else:
try:
# tool_input is serialised dict.
tool_call_parameters = json.loads(agent_action.tool_input)
if not isinstance(tool_call_parameters, dict):
raise ValueError()
except ValueError:
# tool_input is a string, last ditch attempt at having something useful.
tool_call_parameters = {"input": agent_action.tool_input}
tool_results.append(
{
"call": {
"name": agent_action.tool,
"parameters": agent_action.tool_input,
},
"outputs": [{"answer": observation}],
}
ChatRequestToolResultsItem(
call=ToolCall(
name=agent_action.tool,
parameters=tool_call_parameters,
),
outputs=[{"answer": observation}],
).dict()
)

return tool_results
Expand Down Expand Up @@ -143,7 +162,7 @@ def parse_result(
) -> Union[List[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
if result[0].message.additional_kwargs["tool_calls"]:
if "tool_calls" in result[0].message.additional_kwargs:
actions = []
for tool in result[0].message.additional_kwargs["tool_calls"]:
function = tool.get("function", {})
Expand Down
Loading
Loading