Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add bind_tools #185

Merged
merged 2 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 116 additions & 17 deletions libs/genai/langchain_google_genai/_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
from __future__ import annotations

from typing import (
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
TypedDict,
Union,
)

import google.ai.generativelanguage as glm
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionCallingConfigType,
FunctionDeclarationType,
ToolDict,
ToolType,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as callable_as_lc_tool
from langchain_core.utils.json_schema import dereference_refs

TYPE_ENUM = {
Expand All @@ -27,17 +34,25 @@
"object": glm.Type.OBJECT,
}

TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()}


def convert_to_genai_function_declarations(
tool: Union[GoogleTool, ToolDict, Sequence[FunctionDeclarationType]],
) -> List[ToolType]:
tool: Union[
GoogleTool, ToolDict, FunctionDeclarationType, Sequence[FunctionDeclarationType]
],
) -> ToolType:
"""Convert any tool-like object to a ToolType.

https://github.com/google-gemini/generative-ai-python/blob/668695ebe3e9de496a36eeb95cb2ed2faba9b939/google/generativeai/types/content_types.py#L574
"""
if isinstance(tool, GoogleTool):
return [tool]
return tool
# check whether a dict is supported by glm, otherwise we parse it explicitly
if isinstance(tool, dict):
first_function_declaration = tool.get("function_declarations", [None])[0]
if isinstance(first_function_declaration, glm.FunctionDeclaration):
return [tool]
return tool
schema = None
try:
schema = first_function_declaration.parameters
Expand All @@ -46,27 +61,60 @@ def convert_to_genai_function_declarations(
if schema is None:
schema = first_function_declaration.get("parameters")
if schema is None or isinstance(schema, glm.Schema):
return [tool]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
)
for fc in tool["function_declarations"]
]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
return tool
return glm.Tool(
function_declarations=[
_convert_to_genai_function(fc) for fc in tool["function_declarations"]
],
)
for fc in tool
]
elif isinstance(tool, type) and issubclass(tool, BaseModel):
return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])
elif callable(tool):
return _convert_tool_to_genai_function(callable_as_lc_tool()(tool))
elif isinstance(tool, list):
return glm.Tool(
function_declarations=[_convert_to_genai_function(fc) for fc in tool]
)
return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])


def tool_to_dict(tool: Union[glm.Tool, GoogleTool]) -> ToolDict:
if isinstance(tool, GoogleTool):
tool = tool._proto
function_declarations = []
for function_declaration_proto in tool.function_declarations:
properties: Dict[str, Any] = {}
for property in function_declaration_proto.parameters.properties:
property_type = function_declaration_proto.parameters.properties[
property
].type
property_dict = {"type": TYPE_ENUM_REVERSE[property_type]}
property_description = function_declaration_proto.parameters.properties[
property
].description
if property_description:
property_dict["description"] = property_description
properties[property] = property_dict
function_declaration = {
"name": function_declaration_proto.name,
"description": function_declaration_proto.description,
"parameters": {"type": "object", "properties": properties},
}
if function_declaration_proto.parameters.required:
function_declaration["parameters"][ # type: ignore[index]
"required"
] = function_declaration_proto.parameters.required
function_declarations.append(function_declaration)
return {"function_declarations": function_declarations}


def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDeclaration:
if isinstance(fc, BaseTool):
print(fc)
return _convert_tool_to_genai_function(fc)
elif isinstance(fc, type) and issubclass(fc, BaseModel):
return _convert_pydantic_to_genai_function(fc)
elif callable(fc):
return _convert_tool_to_genai_function(callable_as_lc_tool()(fc))
elif isinstance(fc, dict):
return glm.FunctionDeclaration(
name=fc["name"],
Expand Down Expand Up @@ -140,3 +188,54 @@ def _convert_pydantic_to_genai_function(
"type_": TYPE_ENUM[schema["type"]],
},
)


_ToolChoiceType = Union[
dict, List[str], str, Literal["auto", "none", "any"], Literal[True]
]


class _ToolConfigDict(TypedDict):
function_calling_config: FunctionCallingConfigType


def _tool_choice_to_tool_config(
tool_choice: _ToolChoiceType,
all_names: List[str],
) -> _ToolConfigDict:
allowed_function_names: Optional[List[str]] = None
if tool_choice is True or tool_choice == "any":
mode = "any"
allowed_function_names = all_names
elif tool_choice == "auto":
mode = "auto"
elif tool_choice == "none":
mode = "none"
elif isinstance(tool_choice, str):
mode = "any"
allowed_function_names = [tool_choice]
elif isinstance(tool_choice, list):
mode = "any"
allowed_function_names = tool_choice
elif isinstance(tool_choice, dict):
if "mode" in tool_choice:
mode = tool_choice["mode"]
allowed_function_names = tool_choice.get("allowed_function_names")
elif "function_calling_config" in tool_choice:
mode = tool_choice["function_calling_config"]["mode"]
allowed_function_names = tool_choice["function_calling_config"].get(
"allowed_function_names"
)
else:
raise ValueError(
f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match "
f"Google GenerativeAI ToolConfig or FunctionCallingConfig format."
)
else:
raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
return _ToolConfigDict(
function_calling_config={
"mode": mode,
"allowed_function_names": allowed_function_names,
}
)
55 changes: 49 additions & 6 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@
import proto # type: ignore[import]
import requests
from google.generativeai.types import SafetySettingDict # type: ignore[import]
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolConfigDict,
ToolDict,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
Expand All @@ -56,6 +57,7 @@
from langchain_core.output_parsers.openai_tools import parse_tool_calls
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.utils import get_from_dict_or_env
from tenacity import (
before_sleep_log,
Expand All @@ -67,7 +69,11 @@

from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai._function_utils import (
_tool_choice_to_tool_config,
_ToolChoiceType,
_ToolConfigDict,
convert_to_genai_function_declarations,
tool_to_dict,
)
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

Expand Down Expand Up @@ -738,17 +744,20 @@ def _prepare_chat(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
tools: Optional[Sequence[ToolDict]] = None,
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, ToolConfigDict]] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
client = self.client
formatted_tools = None
raw_tools = tools if tools else functions
if raw_tools:
formatted_tools = convert_to_genai_function_declarations(raw_tools)
if tools:
formatted_tools = [
convert_to_genai_function_declarations(tool) for tool in tools
]
elif functions:
formatted_tools = [convert_to_genai_function_declarations(functions)]

if formatted_tools or safety_settings:
client = genai.GenerativeModel(
Expand Down Expand Up @@ -789,3 +798,37 @@ def get_num_tokens(self, text: str) -> int:
token_count = result["token_count"]

return token_count

def bind_tools(
self,
tools: Sequence[Union[ToolDict, GoogleTool]],
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
*,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.

Assumes model is compatible with google-generativeAI tool-calling API.

Args:
tools: A list of tool definitions to bind to this chat model.
Can be a pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
if tool_choice and tool_config:
raise ValueError(
"Must specify at most one of tool_choice and tool_config, received "
f"both:\n\n{tool_choice=}\n\n{tool_config=}"
)
# Bind dicts for easier serialization/deserialization.
genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
if tool_choice:
all_names = [
f["name"] for t in genai_tools for f in t["function_declarations"]
]
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
90 changes: 88 additions & 2 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
from typing import Generator

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import tool

from langchain_google_genai import (
Expand All @@ -14,7 +22,7 @@
)
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError

_MODEL = "gemini-pro" # TODO: Use nano when it's available.
_MODEL = "models/gemini-1.0-pro-001" # TODO: Use nano when it's available.
_VISION_MODEL = "gemini-pro-vision"
_B64_string = """iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501

Expand Down Expand Up @@ -283,3 +291,81 @@ def search(
assert isinstance(result, AIMessage)
assert "brown" in result.content
assert len(result.tool_calls) > 0


def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
"""Check tool calls are as expected."""
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == expected_name
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert arguments == {
"name": "Erick",
"age": 27.0,
}
tool_calls = response.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["name"] == expected_name
assert tool_call["args"] == {"age": 27.0, "name": "Erick"}


@pytest.mark.extended
def test_chat_vertexai_gemini_function_calling() -> None:
class MyModel(BaseModel):
name: str
age: int

safety = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
}
# Test .bind_tools with BaseModel
message = HumanMessage(content="My name is Erick and I am 27 years old")
model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[MyModel]
)
response = model.invoke([message])
_check_tool_calls(response, "MyModel")

# Test .bind_tools with function
def my_model(name: str, age: int) -> None:
"""Invoke this with names and ages."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[my_model]
)
response = model.invoke([message])
_check_tool_calls(response, "my_model")

# Test .bind_tools with tool
@tool
def my_tool(name: str, age: int) -> None:
"""Invoke this with names and ages."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[my_tool]
)
response = model.invoke([message])
_check_tool_calls(response, "my_tool")

# Test streaming
stream = model.stream([message])
first = True
for chunk in stream:
if first:
gathered = chunk
first = False
else:
gathered = gathered + chunk # type: ignore
assert isinstance(gathered, AIMessageChunk)
assert len(gathered.tool_call_chunks) == 1
tool_call_chunk = gathered.tool_call_chunks[0]
assert tool_call_chunk["name"] == "my_tool"
assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
Loading
Loading