From c49ca5ed56818a91594295d6e44b7edb1ef24665 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 8 Jun 2024 11:53:47 -0400 Subject: [PATCH] Ensure intent tools have safe names (#119144) --- homeassistant/helpers/llm.py | 13 +++++++++++-- tests/helpers/test_llm.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 3c240692d52e08..903e52af1a2e09 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -5,8 +5,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum +from functools import cache, partial from typing import Any +import slugify as unicode_slug import voluptuous as vol from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE @@ -175,10 +177,11 @@ class IntentTool(Tool): def __init__( self, + name: str, intent_handler: intent.IntentHandler, ) -> None: """Init the class.""" - self.name = intent_handler.intent_type + self.name = name self.description = ( intent_handler.description or f"Execute Home Assistant {self.name} intent" ) @@ -261,6 +264,9 @@ def __init__(self, hass: HomeAssistant) -> None: id=LLM_API_ASSIST, name="Assist", ) + self.cached_slugify = cache( + partial(unicode_slug.slugify, separator="_", lowercase=False) + ) async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance: """Return the instance of the API.""" @@ -373,7 +379,10 @@ def _async_get_tools( or intent_handler.platforms & exposed_domains ] - return [IntentTool(intent_handler) for intent_handler in intent_handlers] + return [ + IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler) + for intent_handler in intent_handlers + ] def _get_exposed_entities( diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 3f61ed8a0ed6ae..6ac17a2fe0efb7 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -249,6 +249,39 @@ async def test_assist_api_get_timer_tools( assert "HassStartTimer" in [tool.name for tool in api.tools] +async def test_assist_api_tools( + hass: HomeAssistant, llm_context: llm.LLMContext +) -> None: + """Test getting timer tools with Assist API.""" + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "intent", {}) + + llm_context.device_id = "test_device" + + async_register_timer_handler(hass, "test_device", lambda *args: None) + + class MyIntentHandler(intent.IntentHandler): + intent_type = "Super crazy intent with unique nĂ¥me" + description = "my intent handler" + + intent.async_register(hass, MyIntentHandler()) + + api = await llm.async_get_api(hass, "assist", llm_context) + assert [tool.name for tool in api.tools] == [ + "HassTurnOn", + "HassTurnOff", + "HassSetPosition", + "HassStartTimer", + "HassCancelTimer", + "HassIncreaseTimer", + "HassDecreaseTimer", + "HassPauseTimer", + "HassUnpauseTimer", + "HassTimerStatus", + "Super_crazy_intent_with_unique_name", + ] + + async def test_assist_api_description( hass: HomeAssistant, llm_context: llm.LLMContext ) -> None: