mend
lkuligin committed Apr 24, 2024
1 parent 59633d1 commit ad4feaa
Showing 3 changed files with 337 additions and 18 deletions.
libs/genai/langchain_google_genai/chat_models.py (104 changes: 87 additions & 17 deletions)
@@ -4,6 +4,7 @@
 import json
 import logging
 import os
+import uuid
 import warnings
 from io import BytesIO
 from typing import (
@@ -46,9 +47,13 @@
     BaseMessage,
     FunctionMessage,
     HumanMessage,
+    InvalidToolCall,
     SystemMessage,
+    ToolCall,
+    ToolCallChunk,
     ToolMessage,
 )
+from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
@@ -401,24 +406,89 @@ def _parse_chat_history(


 def _parse_response_candidate(
-    response_candidate: glm.Candidate, stream: bool
+    response_candidate: glm.Candidate, streaming: bool = False
 ) -> AIMessage:
-    first_part = response_candidate.content.parts[0]
-    if first_part.function_call:
-        function_call = proto.Message.to_dict(first_part.function_call)
-        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
-        return (AIMessageChunk if stream else AIMessage)(
-            content="", additional_kwargs={"function_call": function_call}
-        )
+    content: Union[None, str, List[str]] = None
+    additional_kwargs = {}
+    tool_calls = []
+    invalid_tool_calls = []
+    tool_call_chunks = []
+
+    for part in response_candidate.content.parts:
+        try:
+            text: Optional[str] = part.text
+        except AttributeError:
+            text = None
+
+        if text is not None:
+            if not content:
+                content = text
+            elif isinstance(content, str) and text:
+                content = [content, text]
+            elif isinstance(content, list) and text:
+                content.append(text)
+            elif text:
+                raise Exception("Unexpected content type")
+
+        if part.function_call:
+            # TODO: support multiple function calls
+            if "function_call" in additional_kwargs:
+                raise Exception("Multiple function calls are not currently supported")
+            function_call = {"name": part.function_call.name}
+            # dump to match other function calling llm for now
+            function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
+            function_call["arguments"] = json.dumps(
+                {k: function_call_args_dict[k] for k in function_call_args_dict}
+            )
+            additional_kwargs["function_call"] = function_call
+
+            if streaming:
+                tool_call_chunks.append(
+                    ToolCallChunk(
+                        name=function_call.get("name"),
+                        args=function_call.get("arguments"),
+                        id=function_call.get("id", str(uuid.uuid4())),
+                        index=function_call.get("index"),  # type: ignore
+                    )
+                )
+            else:
+                try:
+                    tool_calls_dicts = parse_tool_calls(
+                        [{"function": function_call}],
+                        return_id=False,
+                    )
+                    tool_calls = [
+                        ToolCall(
+                            name=tool_call["name"],
+                            args=tool_call["args"],
+                            id=tool_call.get("id", str(uuid.uuid4())),
+                        )
+                        for tool_call in tool_calls_dicts
+                    ]
+                except Exception as e:
+                    invalid_tool_calls = [
+                        InvalidToolCall(
+                            name=function_call.get("name"),
+                            args=function_call.get("arguments"),
+                            id=function_call.get("id", str(uuid.uuid4())),
+                            error=str(e),
+                        )
+                    ]
+    if content is None:
+        content = ""
+
+    if streaming:
+        return AIMessageChunk(
+            content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
+            additional_kwargs=additional_kwargs,
+            tool_call_chunks=tool_call_chunks,
+        )
-    else:
-        parts = response_candidate.content.parts
-
-        if len(parts) == 1 and parts[0].text:
-            content: Union[str, List[Union[str, Dict]]] = parts[0].text
-        else:
-            content = [proto.Message.to_dict(part) for part in parts]
-        return (AIMessageChunk if stream else AIMessage)(
-            content=content, additional_kwargs={}
-        )
+
+    return AIMessage(
+        content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
+        additional_kwargs=additional_kwargs,
+        tool_calls=tool_calls,
+        invalid_tool_calls=invalid_tool_calls,
+    )


@@ -441,7 +511,7 @@ def _response_to_result(
         ]
         generations.append(
             (ChatGenerationChunk if stream else ChatGeneration)(
-                message=_parse_response_candidate(candidate, stream=stream),
+                message=_parse_response_candidate(candidate, streaming=stream),
                 generation_info=generation_info,
             )
         )
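
For context, a minimal sketch of what the new non-streaming branch produces for a single Gemini function-call part, using the same langchain-core helpers the commit imports; the search tool and its query argument are invented for illustration:

import json
import uuid

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.output_parsers.openai_tools import parse_tool_calls

# The commit normalizes the proto function call into an OpenAI-style dict,
# re-serializing the args to a JSON string to match other function-calling LLMs.
function_call = {"name": "search", "arguments": json.dumps({"query": "sparrow"})}

# Same helper call as in the diff; with return_id=False each parsed dict
# carries only "name" and the decoded "args".
tool_calls_dicts = parse_tool_calls([{"function": function_call}], return_id=False)
tool_calls = [
    ToolCall(name=tc["name"], args=tc["args"], id=str(uuid.uuid4()))
    for tc in tool_calls_dicts
]

message = AIMessage(
    content="",
    additional_kwargs={"function_call": function_call},
    tool_calls=tool_calls,
)
print(message.tool_calls)  # [{'name': 'search', 'args': {'query': 'sparrow'}, 'id': '...'}]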
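A similar sketch for the streaming branch, assuming langchain-core's usual chunk-merging behavior: each function-call part becomes a ToolCallChunk, and adding AIMessageChunks merges chunks that share an index. Gemini returns complete function calls, so the split arguments below are artificial:

from langchain_core.messages import AIMessageChunk, ToolCallChunk

first = AIMessageChunk(
    content="",
    tool_call_chunks=[
        ToolCallChunk(name="search", args='{"query": ', id="call-1", index=0)
    ],
)
second = AIMessageChunk(
    content="",
    tool_call_chunks=[ToolCallChunk(name=None, args='"sparrow"}', id=None, index=0)],
)

# Chunk addition concatenates the args of matching indices and re-parses them,
# so the merged message exposes one complete tool call.
merged = first + second
print(merged.tool_calls)  # [{'name': 'search', 'args': {'query': 'sparrow'}, 'id': 'call-1'}]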
libs/genai/tests/integration_tests/test_chat_models.py (5 changes: 4 additions & 1 deletion)
@@ -259,7 +259,10 @@ def search(
         }
     )
     request = HumanMessage(
-        content="Please tell the primary color of following birds: sparrow, hawk, crow",
+        content=(
+            "Please tell the primary color of the following birds: "
+            "sparrow, hawk, crow, by using search"
+        )
     )
     response = llm_with_search_force.invoke([request])
 
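The caller-visible effect, sketched by reusing llm_with_search_force and request from the test above: after this commit the forced function call surfaces as a structured tool call, not only as the legacy additional_kwargs payload.

response = llm_with_search_force.invoke([request])

# New in this commit: parsed tool calls alongside the legacy
# additional_kwargs["function_call"] dict.
assert response.tool_calls
tool_call = response.tool_calls[0]
print(tool_call["name"], tool_call["args"])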
(diff for the third changed file did not load)
