Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly handle escaped unicode characters passed to tools in Google Generative AI #119117

Merged
merged 6 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import codecs
from typing import Any, Literal

from google.api_core.exceptions import GoogleAPICallError
Expand Down Expand Up @@ -106,14 +107,14 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
)


def _adjust_value(value: Any) -> Any:
"""Reverse unnecessary single quotes escaping."""
def _escape_decode(value: Any) -> Any:
"""Recursively call codecs.escape_decode on all values."""
if isinstance(value, str):
return value.replace("\\'", "'")
return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8") # type: ignore[attr-defined]
if isinstance(value, list):
return [_adjust_value(item) for item in value]
return [_escape_decode(item) for item in value]
if isinstance(value, dict):
return {k: _adjust_value(v) for k, v in value.items()}
return {k: _escape_decode(v) for k, v in value.items()}
return value


Expand Down Expand Up @@ -334,10 +335,7 @@ async def async_process(
for function_call in function_calls:
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = {
key: _adjust_value(value)
for key, value in tool_call["args"].items()
}
tool_args = _escape_decode(tool_call["args"])
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.components.google_generative_ai_conversation.conversation import (
_escape_decode,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
Expand Down Expand Up @@ -504,3 +507,18 @@ async def test_conversation_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"


async def test_escape_decode() -> None:
"""Test _escape_decode."""
assert _escape_decode(
{
"param1": ["test_value", "param1\\'s value"],
"param2": "param2\\'s value",
"param3": {"param31": "Cheminée", "param32": "Chemin\\303\\251e"},
}
) == {
"param1": ["test_value", "param1's value"],
"param2": "param2's value",
"param3": {"param31": "Cheminée", "param32": "Cheminée"},
}