diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py index b292223c..82985948 100644 --- a/libs/genai/langchain_google_genai/_function_utils.py +++ b/libs/genai/langchain_google_genai/_function_utils.py @@ -1,21 +1,28 @@ from __future__ import annotations from typing import ( + Any, + Dict, List, + Literal, + Optional, Sequence, Type, + TypedDict, Union, ) import google.ai.generativelanguage as glm from google.generativeai.types import Tool as GoogleTool # type: ignore[import] from google.generativeai.types.content_types import ( # type: ignore[import] + FunctionCallingConfigType, FunctionDeclarationType, ToolDict, ToolType, ) from langchain_core.pydantic_v1 import BaseModel from langchain_core.tools import BaseTool +from langchain_core.tools import tool as callable_as_lc_tool from langchain_core.utils.json_schema import dereference_refs TYPE_ENUM = { @@ -27,17 +34,25 @@ "object": glm.Type.OBJECT, } +TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()} + def convert_to_genai_function_declarations( - tool: Union[GoogleTool, ToolDict, Sequence[FunctionDeclarationType]], -) -> List[ToolType]: + tool: Union[ + GoogleTool, ToolDict, FunctionDeclarationType, Sequence[FunctionDeclarationType] + ], +) -> ToolType: + """Convert any tool-like object to a ToolType. 
def tool_to_dict(tool: Union[glm.Tool, GoogleTool]) -> ToolDict:
    """Convert a Google tool object into a plain-dict ``ToolDict``.

    Only the top level of each parameter schema is carried over: every
    property contributes its type name and, when non-empty, its description.

    Args:
        tool: A ``glm.Tool`` message, or a ``GoogleTool`` whose wrapped
            ``_proto`` message is used instead.

    Returns:
        A dict with a single ``"function_declarations"`` key holding one
        ``{"name", "description", "parameters"}`` entry per declaration.

    Raises:
        KeyError: If a property type has no entry in ``TYPE_ENUM_REVERSE``.
    """
    if isinstance(tool, GoogleTool):
        # Unwrap so both accepted input types share the same code path below.
        tool = tool._proto
    function_declarations = []
    for declaration in tool.function_declarations:
        schema = declaration.parameters
        properties: Dict[str, Any] = {}
        for prop_name in schema.properties:
            prop_schema = schema.properties[prop_name]
            prop: Dict[str, Any] = {"type": TYPE_ENUM_REVERSE[prop_schema.type]}
            if prop_schema.description:
                prop["description"] = prop_schema.description
            properties[prop_name] = prop
        parameters: Dict[str, Any] = {"type": "object", "properties": properties}
        if schema.required:
            # Copy into a built-in list so the result holds only plain Python
            # containers and no reference back into the proto message.
            parameters["required"] = list(schema.required)
        function_declarations.append(
            {
                "name": declaration.name,
                "description": declaration.description,
                "parameters": parameters,
            }
        )
    return {"function_declarations": function_declarations}


_ToolChoiceType = Union[
    dict, List[str], str, Literal["auto", "none", "any"], Literal[True]
]


class _ToolConfigDict(TypedDict):
    """Tool configuration wrapper with a single function-calling entry."""

    function_calling_config: FunctionCallingConfigType


def _tool_choice_to_tool_config(
    tool_choice: _ToolChoiceType,
    all_names: List[str],
) -> _ToolConfigDict:
    """Translate a ``tool_choice`` value into a ``_ToolConfigDict``.

    Args:
        tool_choice: One of

            * ``True`` or ``"any"``: mode ``"any"`` restricted to ``all_names``;
            * ``"auto"`` / ``"none"``: that mode, with no name restriction;
            * any other string: mode ``"any"`` restricted to that one name;
            * a list of names: mode ``"any"`` restricted to those names;
            * a dict with a ``"mode"`` key, or with a
              ``"function_calling_config"`` key wrapping such a dict; its
              ``mode`` and optional ``allowed_function_names`` are passed
              through unchanged.
        all_names: Names of every available function.

    Returns:
        A ``_ToolConfigDict`` whose ``function_calling_config`` contains
        ``mode`` and ``allowed_function_names`` (``None`` when unrestricted).

    Raises:
        ValueError: If ``tool_choice`` matches none of the forms above.
    """
    allowed_function_names: Optional[List[str]] = None
    if tool_choice is True or tool_choice == "any":
        mode = "any"
        # Copy so later mutation of the caller's list cannot alter the config.
        allowed_function_names = list(all_names)
    elif tool_choice == "auto":
        mode = "auto"
    elif tool_choice == "none":
        mode = "none"
    elif isinstance(tool_choice, str):
        mode = "any"
        allowed_function_names = [tool_choice]
    elif isinstance(tool_choice, list):
        mode = "any"
        # Same aliasing concern as above: hand back an independent list.
        allowed_function_names = list(tool_choice)
    elif isinstance(tool_choice, dict):
        if "mode" in tool_choice:
            mode = tool_choice["mode"]
            allowed_function_names = tool_choice.get("allowed_function_names")
        elif "function_calling_config" in tool_choice:
            mode = tool_choice["function_calling_config"]["mode"]
            allowed_function_names = tool_choice["function_calling_config"].get(
                "allowed_function_names"
            )
        else:
            raise ValueError(
                f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match "
                f"Google GenerativeAI ToolConfig or FunctionCallingConfig format."
            )
    else:
        raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
    return _ToolConfigDict(
        function_calling_config={
            "mode": mode,
            "allowed_function_names": allowed_function_names,
        }
    )
def bind_tools(
    self,
    tools: Sequence[Union[ToolDict, GoogleTool]],
    tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
    *,
    tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind tool-like objects to this chat model.

    Assumes model is compatible with google-generativeAI tool-calling API.

    Args:
        tools: A list of tool definitions to bind to this chat model.
            Can be a pydantic model, callable, or BaseTool. Pydantic
            models, callables, and BaseTools will be automatically converted to
            their schema dictionary representation.
        tool_config: Ready-made tool configuration, forwarded unchanged.
            Mutually exclusive with ``tool_choice``.
        tool_choice: Shorthand that is converted into a tool configuration.
            ``True`` or ``"any"`` allows any of the bound functions,
            ``"auto"`` and ``"none"`` select that mode, any other string or a
            list of strings restricts calls to the named function(s), and a
            dict is read as a ToolConfig / FunctionCallingConfig. Falsy values
            are ignored.
        **kwargs: Any additional parameters to pass to the
            :class:`~langchain.runnable.Runnable` constructor.

    Returns:
        The result of ``self.bind`` with ``tools`` set to a one-element list
        holding the dict form of all converted declarations, plus
        ``tool_config``.

    Raises:
        ValueError: If both ``tool_choice`` and ``tool_config`` are given.
    """
    if tool_choice and tool_config:
        raise ValueError(
            "Must specify at most one of tool_choice and tool_config, received "
            f"both:\n\n{tool_choice=}\n\n{tool_config=}"
        )
    if isinstance(tools, tuple):
        # The converter dispatches on ``isinstance(tool, list)``; a tuple would
        # otherwise be treated as a single function declaration.
        tools = list(tools)
    # Bind dicts for easier serialization/deserialization.
    genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
    if tool_choice:
        all_names = [
            f["name"] for t in genai_tools for f in t["function_declarations"]
        ]
        tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
    return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
_VISION_MODEL = "gemini-pro-vision" _B64_string = """iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501 @@ -283,3 +291,81 @@ def 
def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
    """Check tool calls are as expected.

    Asserts that ``response`` is an ``AIMessage`` with empty string content
    whose ``function_call`` entry and single parsed tool call both target
    ``expected_name`` with the arguments name="Erick", age=27.0.
    """
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert response.content == ""
    function_call = response.additional_kwargs.get("function_call")
    assert function_call
    assert function_call["name"] == expected_name
    arguments_str = function_call.get("arguments")
    assert arguments_str
    arguments = json.loads(arguments_str)
    assert arguments == {
        "name": "Erick",
        "age": 27.0,
    }
    tool_calls = response.tool_calls
    assert len(tool_calls) == 1
    tool_call = tool_calls[0]
    assert tool_call["name"] == expected_name
    assert tool_call["args"] == {"age": 27.0, "name": "Erick"}


@pytest.mark.extended
def test_chat_vertexai_gemini_function_calling() -> None:
    """Exercise ``bind_tools`` with a pydantic model, a function and a tool."""

    class MyModel(BaseModel):
        name: str
        age: int

    safety = {
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
    }
    # Test .bind_tools with BaseModel
    message = HumanMessage(content="My name is Erick and I am 27 years old")
    model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
        [MyModel]
    )
    response = model.invoke([message])
    _check_tool_calls(response, "MyModel")

    # Test .bind_tools with function
    def my_model(name: str, age: int) -> None:
        """Invoke this with names and ages."""
        pass

    model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
        [my_model]
    )
    response = model.invoke([message])
    _check_tool_calls(response, "my_model")

    # Test .bind_tools with tool
    @tool
    def my_tool(name: str, age: int) -> None:
        """Invoke this with names and ages."""
        pass

    model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
        [my_tool]
    )
    response = model.invoke([message])
    _check_tool_calls(response, "my_tool")

    # Test streaming
    # Start from None so an empty stream fails the isinstance assertion below
    # instead of raising UnboundLocalError.
    gathered = None
    for chunk in model.stream([message]):
        gathered = chunk if gathered is None else gathered + chunk  # type: ignore
    assert isinstance(gathered, AIMessageChunk)
    assert len(gathered.tool_call_chunks) == 1
    tool_call_chunk = gathered.tool_call_chunks[0]
    assert tool_call_chunk["name"] == "my_tool"
    assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
@pytest.mark.parametrize("choice", (True, "foo", ["foo"], "any"))
def test__tool_choice_to_tool_config(choice: Any) -> None:
    """Each choice must resolve to mode "any" restricted to ``["foo"]``."""
    result = _tool_choice_to_tool_config(choice, ["foo"])
    wanted = _ToolConfigDict(
        function_calling_config={
            "mode": "any",
            "allowed_function_names": ["foo"],
        },
    )
    assert wanted == result


def test_tool_to_dict_glm_tool() -> None:
    """A hand-built ``glm.Tool`` must round-trip through ``tool_to_dict``."""
    number_args = {key: glm.Schema(type=glm.Type.NUMBER) for key in ("a", "b")}
    multiply = glm.FunctionDeclaration(
        name="multiply",
        description="Returns the product of two numbers.",
        parameters=glm.Schema(
            type=glm.Type.OBJECT,
            properties=number_args,
            required=["a", "b"],
        ),
    )
    genai_tool = glm.Tool(function_declarations=[multiply])
    as_dict = tool_to_dict(genai_tool)
    assert genai_tool == convert_to_genai_function_declarations(as_dict)


def test_tool_to_dict_pydantic() -> None:
    """A tool derived from a pydantic model must round-trip as well."""

    class MyModel(BaseModel):
        name: str
        age: int

    genai_tool = convert_to_genai_function_declarations([MyModel])
    as_dict = tool_to_dict(genai_tool)
    assert genai_tool == convert_to_genai_function_declarations(as_dict)