Skip to content

Commit

Permalink
added toolmessage support (#181)
Browse files Browse the repository at this point in the history
* added toolmessage support
  • Loading branch information
lkuligin authored Apr 24, 2024
1 parent f2f33c6 commit 93ad2c9
Show file tree
Hide file tree
Showing 7 changed files with 918 additions and 272 deletions.
44 changes: 35 additions & 9 deletions libs/genai/langchain_google_genai/_function_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from __future__ import annotations

from typing import (
Dict,
List,
Sequence,
Type,
Union,
)

import google.ai.generativelanguage as glm
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolDict,
ToolType,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.json_schema import dereference_refs

FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]

TYPE_ENUM = {
"string": glm.Type.STRING,
"number": glm.Type.NUMBER,
Expand All @@ -25,18 +29,41 @@


def convert_to_genai_function_declarations(
function_calls: List[FunctionCallType],
) -> List[glm.Tool]:
tool: Union[GoogleTool, ToolDict, Sequence[FunctionDeclarationType]],
) -> List[ToolType]:
if isinstance(tool, GoogleTool):
return [tool]
# check whether a dict is supported by glm, otherwise we parse it explicitly
if isinstance(tool, dict):
first_function_declaration = tool.get("function_declarations", [None])[0]
if isinstance(first_function_declaration, glm.FunctionDeclaration):
return [tool]
schema = None
try:
schema = first_function_declaration.parameters
except AttributeError:
pass
if schema is None:
schema = first_function_declaration.get("parameters")
if schema is None or isinstance(schema, glm.Schema):
return [tool]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
)
for fc in tool["function_declarations"]
]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
)
for fc in function_calls
for fc in tool
]


def _convert_to_genai_function(fc: FunctionCallType) -> glm.FunctionDeclaration:
def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDeclaration:
if isinstance(fc, BaseTool):
print(fc)
return _convert_tool_to_genai_function(fc)
elif isinstance(fc, type) and issubclass(fc, BaseModel):
return _convert_pydantic_to_genai_function(fc)
Expand Down Expand Up @@ -64,7 +91,6 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
if tool.args_schema:
schema = dereference_refs(tool.args_schema.schema())
schema.pop("definitions", None)

return glm.FunctionDeclaration(
name=tool.name or schema["title"],
description=tool.description or schema["description"],
Expand All @@ -76,7 +102,7 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"required": schema.get("required", []),
"type_": TYPE_ENUM[schema["type"]],
},
)
Expand Down
167 changes: 142 additions & 25 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import os
import uuid
import warnings
from io import BytesIO
from typing import (
Expand All @@ -29,6 +30,12 @@
import google.generativeai as genai # type: ignore[import]
import proto # type: ignore[import]
import requests
from google.generativeai.types import SafetySettingDict # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolConfigDict,
ToolDict,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -40,8 +47,13 @@
BaseMessage,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.output_parsers.openai_tools import parse_tool_calls
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
Expand Down Expand Up @@ -350,6 +362,40 @@ def _parse_chat_history(
)
)
]
elif isinstance(message, ToolMessage):
role = "user"
prev_message: Optional[BaseMessage] = (
input_messages[i - 1] if i > 0 else None
)
if (
prev_message
and isinstance(prev_message, AIMessage)
and prev_message.tool_calls
):
# message.name can be null for ToolMessage
name: str = prev_message.tool_calls[0]["name"]
else:
name = message.name # type: ignore
tool_response: Any
if not isinstance(message.content, str):
tool_response = message.content
else:
try:
tool_response = json.loads(message.content)
except json.JSONDecodeError:
tool_response = message.content # leave as str representation
parts = [
glm.Part(
function_response=glm.FunctionResponse(
name=name,
response=(
{"output": tool_response}
if not isinstance(tool_response, dict)
else tool_response
),
)
)
]
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
Expand All @@ -360,24 +406,89 @@ def _parse_chat_history(


def _parse_response_candidate(
response_candidate: glm.Candidate, stream: bool
response_candidate: glm.Candidate, streaming: bool = False
) -> AIMessage:
first_part = response_candidate.content.parts[0]
if first_part.function_call:
function_call = proto.Message.to_dict(first_part.function_call)
function_call["arguments"] = json.dumps(function_call.pop("args", {}))
return (AIMessageChunk if stream else AIMessage)(
content="", additional_kwargs={"function_call": function_call}
content: Union[None, str, List[str]] = None
additional_kwargs = {}
tool_calls = []
invalid_tool_calls = []
tool_call_chunks = []

for part in response_candidate.content.parts:
try:
text: Optional[str] = part.text
except AttributeError:
text = None

if text is not None:
if not content:
content = text
elif isinstance(content, str) and text:
content = [content, text]
elif isinstance(content, list) and text:
content.append(text)
elif text:
raise Exception("Unexpected content type")

if part.function_call:
# TODO: support multiple function calls
if "function_call" in additional_kwargs:
raise Exception("Multiple function calls are not currently supported")
function_call = {"name": part.function_call.name}
# dump to match other function calling llm for now
function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
function_call["arguments"] = json.dumps(
{k: function_call_args_dict[k] for k in function_call_args_dict}
)
additional_kwargs["function_call"] = function_call

if streaming:
tool_call_chunks.append(
ToolCallChunk(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id", str(uuid.uuid4())),
index=function_call.get("index"), # type: ignore
)
)
else:
try:
tool_calls_dicts = parse_tool_calls(
[{"function": function_call}],
return_id=False,
)
tool_calls = [
ToolCall(
name=tool_call["name"],
args=tool_call["args"],
id=tool_call.get("id", str(uuid.uuid4())),
)
for tool_call in tool_calls_dicts
]
except Exception as e:
invalid_tool_calls = [
InvalidToolCall(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id", str(uuid.uuid4())),
error=str(e),
)
]
if content is None:
content = ""

if streaming:
return AIMessageChunk(
content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
else:
parts = response_candidate.content.parts

if len(parts) == 1 and parts[0].text:
content: Union[str, List[Union[str, Dict]]] = parts[0].text
else:
content = [proto.Message.to_dict(part) for part in parts]
return (AIMessageChunk if stream else AIMessage)(
content=content, additional_kwargs={}

return AIMessage(
content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)


Expand All @@ -400,7 +511,7 @@ def _response_to_result(
]
generations.append(
(ChatGenerationChunk if stream else ChatGeneration)(
message=_parse_response_candidate(candidate, stream=stream),
message=_parse_response_candidate(candidate, streaming=stream),
generation_info=generation_info,
)
)
Expand Down Expand Up @@ -627,20 +738,26 @@ def _prepare_chat(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
tools: Optional[Sequence[ToolDict]] = None,
functions: Optional[Sequence[FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, ToolConfigDict]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
client = self.client
functions = kwargs.pop("functions", None)
safety_settings = kwargs.pop("safety_settings", self.safety_settings)
if functions or safety_settings:
tools = (
convert_to_genai_function_declarations(functions) if functions else None
)
formatted_tools = None
raw_tools = tools if tools else functions
if raw_tools:
formatted_tools = convert_to_genai_function_declarations(raw_tools)

if formatted_tools or safety_settings:
client = genai.GenerativeModel(
model_name=self.model, tools=tools, safety_settings=safety_settings
model_name=self.model,
tools=formatted_tools,
safety_settings=safety_settings,
)

params = self._prepare_params(stop, **kwargs)
params = self._prepare_params(stop, tool_config=tool_config, **kwargs)
system_instruction, history = _parse_chat_history(
messages,
convert_system_message_to_human=self.convert_system_message_to_human,
Expand Down
Loading

0 comments on commit 93ad2c9

Please sign in to comment.