Skip to content

Commit

Permalink
add bind_tools
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin committed Apr 25, 2024
1 parent 93ad2c9 commit ebea355
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 31 deletions.
133 changes: 116 additions & 17 deletions libs/genai/langchain_google_genai/_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
from __future__ import annotations

from typing import (
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
TypedDict,
Union,
)

import google.ai.generativelanguage as glm
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionCallingConfigType,
FunctionDeclarationType,
ToolDict,
ToolType,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as callable_as_lc_tool
from langchain_core.utils.json_schema import dereference_refs

TYPE_ENUM = {
Expand All @@ -27,17 +34,25 @@
"object": glm.Type.OBJECT,
}

TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()}


def convert_to_genai_function_declarations(
tool: Union[GoogleTool, ToolDict, Sequence[FunctionDeclarationType]],
) -> List[ToolType]:
tool: Union[
GoogleTool, ToolDict, FunctionDeclarationType, Sequence[FunctionDeclarationType]
],
) -> ToolType:
"""Convert any tool-like object to a ToolType.
https://github.com/google-gemini/generative-ai-python/blob/668695ebe3e9de496a36eeb95cb2ed2faba9b939/google/generativeai/types/content_types.py#L574
"""
if isinstance(tool, GoogleTool):
return [tool]
return tool
# check whether a dict is supported by glm, otherwise we parse it explicitly
if isinstance(tool, dict):
first_function_declaration = tool.get("function_declarations", [None])[0]
if isinstance(first_function_declaration, glm.FunctionDeclaration):
return [tool]
return tool
schema = None
try:
schema = first_function_declaration.parameters
Expand All @@ -46,27 +61,60 @@ def convert_to_genai_function_declarations(
if schema is None:
schema = first_function_declaration.get("parameters")
if schema is None or isinstance(schema, glm.Schema):
return [tool]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
)
for fc in tool["function_declarations"]
]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
return tool
return glm.Tool(
function_declarations=[
_convert_to_genai_function(fc) for fc in tool["function_declarations"]
],
)
for fc in tool
]
elif isinstance(tool, type) and issubclass(tool, BaseModel):
return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])
elif callable(tool):
return _convert_tool_to_genai_function(callable_as_lc_tool()(tool))
elif isinstance(tool, list):
return glm.Tool(
function_declarations=[_convert_to_genai_function(fc) for fc in tool]
)
return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])


def tool_to_dict(tool: Union[glm.Tool, GoogleTool]) -> ToolDict:
if isinstance(tool, GoogleTool):
tool = tool._proto
function_declarations = []
for function_declaration_proto in tool.function_declarations:
properties: Dict[str, Any] = {}
for property in function_declaration_proto.parameters.properties:
property_type = function_declaration_proto.parameters.properties[
property
].type
property_dict = {"type": TYPE_ENUM_REVERSE[property_type]}
property_description = function_declaration_proto.parameters.properties[
property
].description
if property_description:
property_dict["description"] = property_description
properties[property] = property_dict
function_declaration = {
"name": function_declaration_proto.name,
"description": function_declaration_proto.description,
"parameters": {"type": "object", "properties": properties},
}
if function_declaration_proto.parameters.required:
function_declaration["parameters"][ # type: ignore[index]
"required"
] = function_declaration_proto.parameters.required
function_declarations.append(function_declaration)
return {"function_declarations": function_declarations}


def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDeclaration:
if isinstance(fc, BaseTool):
print(fc)
return _convert_tool_to_genai_function(fc)
elif isinstance(fc, type) and issubclass(fc, BaseModel):
return _convert_pydantic_to_genai_function(fc)
elif callable(fc):
return _convert_tool_to_genai_function(callable_as_lc_tool()(fc))
elif isinstance(fc, dict):
return glm.FunctionDeclaration(
name=fc["name"],
Expand Down Expand Up @@ -140,3 +188,54 @@ def _convert_pydantic_to_genai_function(
"type_": TYPE_ENUM[schema["type"]],
},
)


_ToolChoiceType = Union[
dict, List[str], str, Literal["auto", "none", "any"], Literal[True]
]


class _ToolConfigDict(TypedDict):
function_calling_config: FunctionCallingConfigType


def _tool_choice_to_tool_config(
tool_choice: _ToolChoiceType,
all_names: List[str],
) -> _ToolConfigDict:
allowed_function_names: Optional[List[str]] = None
if tool_choice is True or tool_choice == "any":
mode = "any"
allowed_function_names = all_names
elif tool_choice == "auto":
mode = "auto"
elif tool_choice == "none":
mode = "none"
elif isinstance(tool_choice, str):
mode = "any"
allowed_function_names = [tool_choice]
elif isinstance(tool_choice, list):
mode = "any"
allowed_function_names = tool_choice
elif isinstance(tool_choice, dict):
if "mode" in tool_choice:
mode = tool_choice["mode"]
allowed_function_names = tool_choice.get("allowed_function_names")
elif "function_calling_config" in tool_choice:
mode = tool_choice["function_calling_config"]["mode"]
allowed_function_names = tool_choice["function_calling_config"].get(
"allowed_function_names"
)
else:
raise ValueError(
f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match "
f"Google GenerativeAI ToolConfig or FunctionCallingConfig format."
)
else:
raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
return _ToolConfigDict(
function_calling_config={
"mode": mode,
"allowed_function_names": allowed_function_names,
}
)
55 changes: 49 additions & 6 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@
import proto # type: ignore[import]
import requests
from google.generativeai.types import SafetySettingDict # type: ignore[import]
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolConfigDict,
ToolDict,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
Expand All @@ -56,6 +57,7 @@
from langchain_core.output_parsers.openai_tools import parse_tool_calls
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.utils import get_from_dict_or_env
from tenacity import (
before_sleep_log,
Expand All @@ -67,7 +69,11 @@

from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai._function_utils import (
_tool_choice_to_tool_config,
_ToolChoiceType,
_ToolConfigDict,
convert_to_genai_function_declarations,
tool_to_dict,
)
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

Expand Down Expand Up @@ -738,17 +744,20 @@ def _prepare_chat(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
tools: Optional[Sequence[ToolDict]] = None,
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
functions: Optional[Sequence[FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, ToolConfigDict]] = None,
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
client = self.client
formatted_tools = None
raw_tools = tools if tools else functions
if raw_tools:
formatted_tools = convert_to_genai_function_declarations(raw_tools)
if tools:
formatted_tools = [
convert_to_genai_function_declarations(tool) for tool in tools
]
elif functions:
formatted_tools = [convert_to_genai_function_declarations(functions)]

if formatted_tools or safety_settings:
client = genai.GenerativeModel(
Expand Down Expand Up @@ -789,3 +798,37 @@ def get_num_tokens(self, text: str) -> int:
token_count = result["token_count"]

return token_count

def bind_tools(
self,
tools: Sequence[Union[ToolDict, GoogleTool]],
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
*,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with google-generativeAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
if tool_choice and tool_config:
raise ValueError(
"Must specify at most one of tool_choice and tool_config, received "
f"both:\n\n{tool_choice=}\n\n{tool_config=}"
)
# Bind dicts for easier serialization/deserialization.
genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
if tool_choice:
all_names = [
f["name"] for t in genai_tools for f in t["function_declarations"]
]
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
90 changes: 88 additions & 2 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
from typing import Generator

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import tool

from langchain_google_genai import (
Expand All @@ -14,7 +22,7 @@
)
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError

_MODEL = "gemini-pro" # TODO: Use nano when it's available.
_MODEL = "models/gemini-1.0-pro-001" # TODO: Use nano when it's available.
_VISION_MODEL = "gemini-pro-vision"
_B64_string = """iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501

Expand Down Expand Up @@ -283,3 +291,81 @@ def search(
assert isinstance(result, AIMessage)
assert "brown" in result.content
assert len(result.tool_calls) > 0


def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
"""Check tool calls are as expected."""
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == expected_name
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert arguments == {
"name": "Erick",
"age": 27.0,
}
tool_calls = response.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["name"] == expected_name
assert tool_call["args"] == {"age": 27.0, "name": "Erick"}


@pytest.mark.extended
def test_chat_vertexai_gemini_function_calling() -> None:
class MyModel(BaseModel):
name: str
age: int

safety = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
}
# Test .bind_tools with BaseModel
message = HumanMessage(content="My name is Erick and I am 27 years old")
model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[MyModel]
)
response = model.invoke([message])
_check_tool_calls(response, "MyModel")

# Test .bind_tools with function
def my_model(name: str, age: int) -> None:
"""Invoke this with names and ages."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[my_model]
)
response = model.invoke([message])
_check_tool_calls(response, "my_model")

# Test .bind_tools with tool
@tool
def my_tool(name: str, age: int) -> None:
"""Invoke this with names and ages."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[my_tool]
)
response = model.invoke([message])
_check_tool_calls(response, "my_tool")

# Test streaming
stream = model.stream([message])
first = True
for chunk in stream:
if first:
gathered = chunk
first = False
else:
gathered = gathered + chunk # type: ignore
assert isinstance(gathered, AIMessageChunk)
assert len(gathered.tool_call_chunks) == 1
tool_call_chunk = gathered.tool_call_chunks[0]
assert tool_call_chunk["name"] == "my_tool"
assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
Loading

0 comments on commit ebea355

Please sign in to comment.