added ToolMessage support #181

Merged
merged 3 commits on Apr 24, 2024
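This PR wires ToolMessage through both directions of the Gemini chat integration: _parse_chat_history converts a ToolMessage into a glm.FunctionResponse part on the way in, and _parse_response_candidate emits structured tool_calls and tool_call_chunks on the way out. A minimal sketch of the round trip this enables (the add tool, its arguments, and the model name are illustrative, not part of the PR):

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")

messages = [
    HumanMessage(content="What is 3 + 4?"),
    # An AIMessage carrying tool_calls, as a tool-bound model would return it.
    AIMessage(
        content="",
        tool_calls=[{"name": "add", "args": {"a": 3, "b": 4}, "id": "call-1"}],
    ),
    # The locally computed tool result; _parse_chat_history turns this into a
    # glm.FunctionResponse part sent with role "user".
    ToolMessage(content="7", tool_call_id="call-1"),
]
response = llm.invoke(messages)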
Changes from all commits
44 changes: 35 additions & 9 deletions libs/genai/langchain_google_genai/_function_utils.py
@@ -1,19 +1,23 @@
from __future__ import annotations

from typing import (
Dict,
List,
Sequence,
Type,
Union,
)

import google.ai.generativelanguage as glm
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolDict,
ToolType,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.json_schema import dereference_refs

FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]

TYPE_ENUM = {
"string": glm.Type.STRING,
"number": glm.Type.NUMBER,
@@ -25,18 +29,41 @@


def convert_to_genai_function_declarations(
-    function_calls: List[FunctionCallType],
-) -> List[glm.Tool]:
tool: Union[GoogleTool, ToolDict, Sequence[FunctionDeclarationType]],
) -> List[ToolType]:
if isinstance(tool, GoogleTool):
return [tool]
# check whether a dict is supported by glm, otherwise we parse it explicitly
if isinstance(tool, dict):
first_function_declaration = tool.get("function_declarations", [None])[0]
if isinstance(first_function_declaration, glm.FunctionDeclaration):
return [tool]
schema = None
try:
schema = first_function_declaration.parameters
except AttributeError:
pass
if schema is None:
schema = first_function_declaration.get("parameters")
if schema is None or isinstance(schema, glm.Schema):
return [tool]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
)
for fc in tool["function_declarations"]
]
return [
glm.Tool(
function_declarations=[_convert_to_genai_function(fc)],
)
-        for fc in function_calls
for fc in tool
]


-def _convert_to_genai_function(fc: FunctionCallType) -> glm.FunctionDeclaration:
def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDeclaration:
if isinstance(fc, BaseTool):
return _convert_tool_to_genai_function(fc)
elif isinstance(fc, type) and issubclass(fc, BaseModel):
return _convert_pydantic_to_genai_function(fc)
@@ -64,7 +91,6 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
if tool.args_schema:
schema = dereference_refs(tool.args_schema.schema())
schema.pop("definitions", None)

return glm.FunctionDeclaration(
name=tool.name or schema["title"],
description=tool.description or schema["description"],
@@ -76,7 +102,7 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
}
for k, v in schema["properties"].items()
},
"required": schema["required"],
"required": schema.get("required", []),
"type_": TYPE_ENUM[schema["type"]],
},
)
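With this change convert_to_genai_function_declarations accepts a GoogleTool (returned unchanged), a glm-style tool dict (passed through when its declarations are already parsed), or a sequence of function declarations converted one by one. A short sketch of the sequence case, assuming an invented pydantic schema:

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai._function_utils import (
    convert_to_genai_function_declarations,
)

class Multiply(BaseModel):
    """Multiply two integers."""

    a: int = Field(description="first factor")
    b: int = Field(description="second factor")

# A sequence of declarations (pydantic models, BaseTools, or dicts) is
# converted into a list of glm.Tool objects.
tools = convert_to_genai_function_declarations([Multiply])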
167 changes: 142 additions & 25 deletions libs/genai/langchain_google_genai/chat_models.py
@@ -4,6 +4,7 @@
import json
import logging
import os
import uuid
import warnings
from io import BytesIO
from typing import (
@@ -29,6 +30,12 @@
import google.generativeai as genai # type: ignore[import]
import proto # type: ignore[import]
import requests
from google.generativeai.types import SafetySettingDict # type: ignore[import]
from google.generativeai.types.content_types import ( # type: ignore[import]
FunctionDeclarationType,
ToolConfigDict,
ToolDict,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -40,8 +47,13 @@
BaseMessage,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.output_parsers.openai_tools import parse_tool_calls
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
@@ -350,6 +362,40 @@ def _parse_chat_history(
)
)
]
elif isinstance(message, ToolMessage):
role = "user"
prev_message: Optional[BaseMessage] = (
input_messages[i - 1] if i > 0 else None
)
if (
prev_message
and isinstance(prev_message, AIMessage)
and prev_message.tool_calls
):
# message.name can be null for ToolMessage
name: str = prev_message.tool_calls[0]["name"]
else:
name = message.name # type: ignore
tool_response: Any
if not isinstance(message.content, str):
tool_response = message.content
else:
try:
tool_response = json.loads(message.content)
except json.JSONDecodeError:
tool_response = message.content # leave as str representation
parts = [
glm.Part(
function_response=glm.FunctionResponse(
name=name,
response=(
{"output": tool_response}
if not isinstance(tool_response, dict)
else tool_response
),
)
)
]
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
@@ -360,24 +406,89 @@


def _parse_response_candidate(
-    response_candidate: glm.Candidate, stream: bool
response_candidate: glm.Candidate, streaming: bool = False
) -> AIMessage:
-    first_part = response_candidate.content.parts[0]
-    if first_part.function_call:
-        function_call = proto.Message.to_dict(first_part.function_call)
-        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
-        return (AIMessageChunk if stream else AIMessage)(
-            content="", additional_kwargs={"function_call": function_call}
-        )
content: Union[None, str, List[str]] = None
additional_kwargs = {}
tool_calls = []
invalid_tool_calls = []
tool_call_chunks = []

for part in response_candidate.content.parts:
try:
text: Optional[str] = part.text
except AttributeError:
text = None

if text is not None:
if not content:
content = text
elif isinstance(content, str) and text:
content = [content, text]
elif isinstance(content, list) and text:
content.append(text)
elif text:
raise Exception("Unexpected content type")

if part.function_call:
# TODO: support multiple function calls
if "function_call" in additional_kwargs:
raise Exception("Multiple function calls are not currently supported")
function_call = {"name": part.function_call.name}
# dump the args to JSON to match other function-calling LLMs for now
function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
function_call["arguments"] = json.dumps(
{k: function_call_args_dict[k] for k in function_call_args_dict}
)
additional_kwargs["function_call"] = function_call

if streaming:
tool_call_chunks.append(
ToolCallChunk(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id", str(uuid.uuid4())),
index=function_call.get("index"), # type: ignore
)
)
else:
try:
tool_calls_dicts = parse_tool_calls(
[{"function": function_call}],
return_id=False,
)
tool_calls = [
ToolCall(
name=tool_call["name"],
args=tool_call["args"],
id=tool_call.get("id", str(uuid.uuid4())),
)
for tool_call in tool_calls_dicts
]
except Exception as e:
invalid_tool_calls = [
InvalidToolCall(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id", str(uuid.uuid4())),
error=str(e),
)
]
if content is None:
content = ""

if streaming:
return AIMessageChunk(
content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
-    else:
-        parts = response_candidate.content.parts
-
-        if len(parts) == 1 and parts[0].text:
-            content: Union[str, List[Union[str, Dict]]] = parts[0].text
-        else:
-            content = [proto.Message.to_dict(part) for part in parts]
-        return (AIMessageChunk if stream else AIMessage)(
-            content=content, additional_kwargs={}
-        )

return AIMessage(
content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
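One way to see the new parsing is to hand _parse_response_candidate a hand-built candidate. A sketch, assuming proto-plus coerces the args dict into a protobuf Struct as usual:

import google.ai.generativelanguage as glm
from langchain_google_genai.chat_models import _parse_response_candidate

# A candidate with a single function_call part (values are illustrative).
candidate = glm.Candidate(
    content=glm.Content(
        parts=[
            glm.Part(
                function_call=glm.FunctionCall(name="add", args={"a": 3, "b": 4})
            )
        ]
    )
)

# Non-streaming: args are parsed into structured tool_calls.
msg = _parse_response_candidate(candidate)
print(msg.tool_calls)  # [{'name': 'add', 'args': {'a': 3.0, 'b': 4.0}, 'id': ...}]

# Streaming: the same part becomes a tool_call_chunk whose args stay a JSON
# string, to be merged across chunks by the caller.
chunk = _parse_response_candidate(candidate, streaming=True)
print(chunk.tool_call_chunks)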


@@ -400,7 +511,7 @@ def _response_to_result(
]
generations.append(
(ChatGenerationChunk if stream else ChatGeneration)(
-            message=_parse_response_candidate(candidate, stream=stream),
message=_parse_response_candidate(candidate, streaming=stream),
generation_info=generation_info,
)
)
@@ -627,20 +738,26 @@ def _prepare_chat(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
tools: Optional[Sequence[ToolDict]] = None,
functions: Optional[Sequence[FunctionDeclarationType]] = None,
safety_settings: Optional[SafetySettingDict] = None,
tool_config: Optional[Union[Dict, ToolConfigDict]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
client = self.client
-        functions = kwargs.pop("functions", None)
-        safety_settings = kwargs.pop("safety_settings", self.safety_settings)
-        if functions or safety_settings:
-            tools = (
-                convert_to_genai_function_declarations(functions) if functions else None
-            )
formatted_tools = None
raw_tools = tools if tools else functions
if raw_tools:
formatted_tools = convert_to_genai_function_declarations(raw_tools)

if formatted_tools or safety_settings:
client = genai.GenerativeModel(
-                model_name=self.model, tools=tools, safety_settings=safety_settings
model_name=self.model,
tools=formatted_tools,
safety_settings=safety_settings,
)

-        params = self._prepare_params(stop, **kwargs)
params = self._prepare_params(stop, tool_config=tool_config, **kwargs)
system_instruction, history = _parse_chat_history(
messages,
convert_system_message_to_human=self.convert_system_message_to_human,
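At the call-site level, the reworked _prepare_chat means tools, functions, safety_settings, and tool_config arrive as explicit keyword arguments rather than being popped out of **kwargs, with tools taking precedence over functions when both are supplied. A hedged sketch of the resulting surface (the GetWeather schema is invented for illustration):

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI

class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str = Field(description="city name")

llm = ChatGoogleGenerativeAI(model="gemini-pro")

# tools flows through _prepare_chat into genai.GenerativeModel after
# conversion by convert_to_genai_function_declarations.
response = llm.invoke("What is the weather in Paris?", tools=[GetWeather])
print(response.tool_calls)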