diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index d4a6c5bfa69073..89fba79fced4b0 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -23,7 +23,7 @@ from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType -from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN, LOGGER +from .const import CONF_PROMPT, DOMAIN, LOGGER SERVICE_GENERATE_CONTENT = "generate_content" CONF_IMAGE_FILENAME = "image_filename" @@ -97,11 +97,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: genai.configure(api_key=entry.data[CONF_API_KEY]) try: - await hass.async_add_executor_job( - partial( - genai.get_model, entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) - ) - ) + await hass.async_add_executor_job(partial(genai.list_models)) except ClientError as err: if err.reason == "API_KEY_INVALID": LOGGER.error("Invalid API key: %s", err) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index ab1c976273f973..6bf65de86f07ff 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -4,7 +4,6 @@ from functools import partial import logging -import types from types import MappingProxyType from typing import Any @@ -18,11 +17,15 @@ ConfigFlowResult, OptionsFlow, ) -from homeassistant.const import CONF_API_KEY +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm from homeassistant.helpers.selector import ( NumberSelector, NumberSelectorConfig, + SelectOptionDict, + SelectSelector, + SelectSelectorConfig, TemplateSelector, ) @@ -50,17 +53,6 @@ } ) -DEFAULT_OPTIONS = types.MappingProxyType( - { - CONF_PROMPT: DEFAULT_PROMPT, - CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL, - CONF_TEMPERATURE: DEFAULT_TEMPERATURE, - CONF_TOP_P: DEFAULT_TOP_P, - CONF_TOP_K: DEFAULT_TOP_K, - CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS, - } -) - async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: """Validate the user input allows us to connect. 
@@ -99,7 +91,9 @@ async def async_step_user( errors["base"] = "unknown" else: return self.async_create_entry( - title="Google Generative AI Conversation", data=user_input + title="Google Generative AI", + data=user_input, + options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, ) return self.async_show_form( @@ -126,53 +120,96 @@ async def async_step_init( ) -> ConfigFlowResult: """Manage the options.""" if user_input is not None: - return self.async_create_entry( - title="Google Generative AI Conversation", data=user_input - ) - schema = google_generative_ai_config_option_schema(self.config_entry.options) + if user_input[CONF_LLM_HASS_API] == "none": + user_input.pop(CONF_LLM_HASS_API) + return self.async_create_entry(title="", data=user_input) + schema = await google_generative_ai_config_option_schema( + self.hass, self.config_entry.options + ) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), ) -def google_generative_ai_config_option_schema( +async def google_generative_ai_config_option_schema( + hass: HomeAssistant, options: MappingProxyType[str, Any], ) -> dict: """Return a schema for Google Generative AI completion options.""" - if not options: - options = DEFAULT_OPTIONS + api_models = await hass.async_add_executor_job(partial(genai.list_models)) + + models: list[SelectOptionDict] = [ + SelectOptionDict( + label="Gemini 1.5 Flash (recommended)", + value="models/gemini-1.5-flash-latest", + ), + ] + models.extend( + SelectOptionDict( + label=api_model.display_name, + value=api_model.name, + ) + for api_model in sorted(api_models, key=lambda x: x.display_name) + if ( + api_model.name + not in ( + "models/gemini-1.0-pro", # duplicate of gemini-pro + "models/gemini-1.5-flash-latest", + ) + and "vision" not in api_model.name + and "generateContent" in api_model.supported_generation_methods + ) + ) + + apis: list[SelectOptionDict] = [ + SelectOptionDict( + label="No control", + value="none", + ) + ] + apis.extend( + SelectOptionDict( + label=api.name, + value=api.id, + ) + for api in llm.async_get_apis(hass) + ) + return { + vol.Optional( + CONF_CHAT_MODEL, + description={"suggested_value": options.get(CONF_CHAT_MODEL)}, + default=DEFAULT_CHAT_MODEL, + ): SelectSelector(SelectSelectorConfig(options=models)), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + default="none", + ): SelectSelector(SelectSelectorConfig(options=apis)), vol.Optional( CONF_PROMPT, - description={"suggested_value": options[CONF_PROMPT]}, + description={"suggested_value": options.get(CONF_PROMPT)}, default=DEFAULT_PROMPT, ): TemplateSelector(), - vol.Optional( - CONF_CHAT_MODEL, - description={ - "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) - }, - default=DEFAULT_CHAT_MODEL, - ): str, vol.Optional( CONF_TEMPERATURE, - description={"suggested_value": options[CONF_TEMPERATURE]}, + description={"suggested_value": options.get(CONF_TEMPERATURE)}, default=DEFAULT_TEMPERATURE, ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), vol.Optional( CONF_TOP_P, - description={"suggested_value": options[CONF_TOP_P]}, + description={"suggested_value": options.get(CONF_TOP_P)}, default=DEFAULT_TOP_P, ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), vol.Optional( CONF_TOP_K, - description={"suggested_value": options[CONF_TOP_K]}, + description={"suggested_value": options.get(CONF_TOP_K)}, default=DEFAULT_TOP_K, ): int, vol.Optional( CONF_MAX_TOKENS, - description={"suggested_value": 
options[CONF_MAX_TOKENS]}, + description={"suggested_value": options.get(CONF_MAX_TOKENS)}, default=DEFAULT_MAX_TOKENS, ): int, } diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index f7e71989efd834..ba47b2acfe3a77 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -21,11 +21,8 @@ {%- endif %} {%- endfor %} {%- endfor %} - -Answer the user's questions about the world truthfully. - -If the user wants to control a device, reject the request and suggest using the Home Assistant app. """ + CONF_CHAT_MODEL = "chat_model" DEFAULT_CHAT_MODEL = "models/gemini-pro" CONF_TEMPERATURE = "temperature" @@ -36,3 +33,4 @@ DEFAULT_TOP_K = 1 CONF_MAX_TOKENS = "max_tokens" DEFAULT_MAX_TOKENS = 150 +DEFAULT_ALLOW_HASS_ACCESS = False diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 90a3104f662aa7..8e16e8eaceb2af 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -2,18 +2,21 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal +import google.ai.generativelanguage as glm from google.api_core.exceptions import ClientError import google.generativeai as genai import google.generativeai.types as genai_types +import voluptuous as vol +from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation from homeassistant.config_entries import ConfigEntry -from homeassistant.const import MATCH_ALL +from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import TemplateError -from homeassistant.helpers import intent, template +from homeassistant.exceptions import HomeAssistantError, TemplateError +from homeassistant.helpers import intent, llm, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid @@ -30,9 +33,13 @@ DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, + DOMAIN, LOGGER, ) +# Max number of back and forth with the LLM to generate a response +MAX_TOOL_ITERATIONS = 10 + async def async_setup_entry( hass: HomeAssistant, @@ -44,6 +51,55 @@ async def async_setup_entry( async_add_entities([agent]) +SUPPORTED_SCHEMA_KEYS = { + "type", + "format", + "description", + "nullable", + "enum", + "items", + "properties", + "required", +} + + +def _format_schema(schema: dict[str, Any]) -> dict[str, Any]: + """Format the schema to protobuf.""" + result = {} + for key, val in schema.items(): + if key not in SUPPORTED_SCHEMA_KEYS: + continue + if key == "type": + key = "type_" + val = val.upper() + elif key == "format": + key = "format_" + elif key == "items": + val = _format_schema(val) + elif key == "properties": + val = {k: _format_schema(v) for k, v in val.items()} + result[key] = val + return result + + +def _format_tool(tool: llm.Tool) -> dict[str, Any]: + """Format tool specification.""" + + parameters = _format_schema(convert(tool.parameters)) + + return glm.Tool( + { + "function_declarations": [ + { + "name": tool.name, + "description": tool.description, + "parameters": parameters, + } + ] + } + ) + + class GoogleGenerativeAIConversationEntity( 
conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -80,6 +136,26 @@ async def async_process( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" + intent_response = intent.IntentResponse(language=user_input.language) + llm_api: llm.API | None = None + tools: list[dict[str, Any]] | None = None + + if self.entry.options.get(CONF_LLM_HASS_API): + try: + llm_api = llm.async_get_api( + self.hass, self.entry.options[CONF_LLM_HASS_API] + ) + except HomeAssistantError as err: + LOGGER.error("Error getting LLM API: %s", err) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Error preparing LLM API: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=user_input.conversation_id + ) + tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] + raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) model = genai.GenerativeModel( model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL), @@ -93,8 +169,8 @@ async def async_process( CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS ), }, + tools=tools or None, ) - LOGGER.debug("Model: %s", model) if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id @@ -103,9 +179,8 @@ async def async_process( conversation_id = ulid.ulid_now() messages = [{}, {}] - intent_response = intent.IntentResponse(language=user_input.language) try: - prompt = self._async_generate_prompt(raw_prompt) + prompt = self._async_generate_prompt(raw_prompt, llm_api) except TemplateError as err: LOGGER.error("Error rendering prompt: %s", err) intent_response.async_set_error( @@ -122,40 +197,84 @@ async def async_process( LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) chat = model.start_chat(history=messages) - try: - chat_response = await chat.send_message_async(user_input.text) - except ( - ClientError, - ValueError, - genai_types.BlockedPromptException, - genai_types.StopCandidateException, - ) as err: - LOGGER.error("Error sending message: %s", err) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to Google Generative AI: {err}", + chat_request = user_input.text + # To prevent infinite loops, we limit the number of iterations + for _iteration in range(MAX_TOOL_ITERATIONS): + try: + chat_response = await chat.send_message_async(chat_request) + except ( + ClientError, + ValueError, + genai_types.BlockedPromptException, + genai_types.StopCandidateException, + ) as err: + LOGGER.error("Error sending message: %s", err) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to Google Generative AI: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + LOGGER.debug("Response: %s", chat_response.parts) + if not chat_response.parts: + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + "Sorry, I had a problem talking to Google Generative AI. 
Likely blocked", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + self.history[conversation_id] = chat.history + tool_call = chat_response.parts[0].function_call + + if not tool_call or not llm_api: + break + + tool_input = llm.ToolInput( + tool_name=tool_call.name, + tool_args=dict(tool_call.args), + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + LOGGER.debug( + "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args ) + try: + function_response = await llm_api.async_call_tool(tool_input) + except (HomeAssistantError, vol.Invalid) as e: + function_response = {"error": type(e).__name__} + if str(e): + function_response["error_text"] = str(e) - LOGGER.debug("Response: %s", chat_response.parts) - if not chat_response.parts: - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - "Sorry, I had a problem talking to Google Generative AI. Likely blocked", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + LOGGER.debug("Tool response: %s", function_response) + chat_request = glm.Content( + parts=[ + glm.Part( + function_response=glm.FunctionResponse( + name=tool_call.name, response=function_response + ) + ) + ] ) - self.history[conversation_id] = chat.history + intent_response.async_set_speech(chat_response.text) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) - def _async_generate_prompt(self, raw_prompt: str) -> str: + def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str: """Generate a prompt for the user.""" + raw_prompt += "\n" + if llm_api: + raw_prompt += llm_api.prompt_template + else: + raw_prompt += llm.PROMPT_NO_API_CONFIGURED + return template.Template(raw_prompt, self.hass).async_render( { "ha_name": self.hass.config.location_name, diff --git a/homeassistant/components/google_generative_ai_conversation/manifest.json b/homeassistant/components/google_generative_ai_conversation/manifest.json index bcbba23e9a7e8d..00ba74f16b2c21 100644 --- a/homeassistant/components/google_generative_ai_conversation/manifest.json +++ b/homeassistant/components/google_generative_ai_conversation/manifest.json @@ -1,12 +1,12 @@ { "domain": "google_generative_ai_conversation", - "name": "Google Generative AI Conversation", - "after_dependencies": ["assist_pipeline"], + "name": "Google Generative AI", + "after_dependencies": ["assist_pipeline", "intent"], "codeowners": ["@tronikos"], "config_flow": true, "dependencies": ["conversation"], "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation", "integration_type": "service", "iot_class": "cloud_polling", - "requirements": ["google-generativeai==0.5.4"] + "requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.3"] } diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json index 306072f33a8ea4..a6be0c694c17a6 100644 --- a/homeassistant/components/google_generative_ai_conversation/strings.json +++ b/homeassistant/components/google_generative_ai_conversation/strings.json @@ -3,7 +3,8 @@ "step": { "user": { "data": { - "api_key": "[%key:common::config_flow::data::api_key%]" + 
"api_key": "[%key:common::config_flow::data::api_key%]", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" } } }, @@ -18,11 +19,12 @@ "init": { "data": { "prompt": "Prompt Template", - "model": "[%key:common::generic::model%]", + "chat_model": "[%key:common::generic::model%]", "temperature": "Temperature", "top_p": "Top P", "top_k": "Top K", - "max_tokens": "Maximum tokens to return in response" + "max_tokens": "Maximum tokens to return in response", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" } } } diff --git a/homeassistant/const.py b/homeassistant/const.py index 66b4b3e4dcf360..77de43f730f8a2 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -113,6 +113,7 @@ class Platform(StrEnum): CONF_ADDRESS: Final = "address" CONF_AFTER: Final = "after" CONF_ALIAS: Final = "alias" +CONF_LLM_HASS_API = "llm_hass_api" CONF_ALLOWLIST_EXTERNAL_URLS: Final = "allowlist_external_urls" CONF_API_KEY: Final = "api_key" CONF_API_TOKEN: Final = "api_token" diff --git a/homeassistant/generated/integrations.json b/homeassistant/generated/integrations.json index 938aa216747b13..e5b061cad2366f 100644 --- a/homeassistant/generated/integrations.json +++ b/homeassistant/generated/integrations.json @@ -2271,7 +2271,7 @@ "integration_type": "service", "config_flow": true, "iot_class": "cloud_polling", - "name": "Google Generative AI Conversation" + "name": "Google Generative AI" }, "google_mail": { "integration_type": "service", diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index db1b46f656a73d..2edc6d650f4a4c 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -17,18 +17,17 @@ from . import intent from .singleton import singleton +LLM_API_ASSIST = "assist" + +PROMPT_NO_API_CONFIGURED = "If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant." + @singleton("llm") @callback def _async_get_apis(hass: HomeAssistant) -> dict[str, API]: """Get all the LLM APIs.""" return { - "assist": AssistAPI( - hass=hass, - id="assist", - name="Assist", - prompt_template="Call the intent tools to control the system. Just pass the name to the intent.", - ), + LLM_API_ASSIST: AssistAPI(hass=hass), } @@ -170,6 +169,15 @@ class AssistAPI(API): INTENT_GET_TEMPERATURE, } + def __init__(self, hass: HomeAssistant) -> None: + """Init the class.""" + super().__init__( + hass=hass, + id=LLM_API_ASSIST, + name="Assist", + prompt_template="Call the intent tools to control the system. 
Just pass the name to the intent.", + ) + @callback def async_get_tools(self) -> list[Tool]: """Return a list of LLM tools.""" diff --git a/homeassistant/strings.json b/homeassistant/strings.json index 97bba2fb3b7b3c..b31e83394bbacf 100644 --- a/homeassistant/strings.json +++ b/homeassistant/strings.json @@ -88,6 +88,7 @@ "access_token": "Access token", "api_key": "API key", "api_token": "API token", + "llm_hass_api": "Control Home Assistant", "ssl": "Uses an SSL certificate", "verify_ssl": "Verify SSL certificate", "elevation": "Elevation", diff --git a/requirements_all.txt b/requirements_all.txt index c9bd31d33aaaca..102425ab4596a9 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2825,6 +2825,9 @@ voip-utils==0.1.0 # homeassistant.components.volkszaehler volkszaehler==0.4.0 +# homeassistant.components.google_generative_ai_conversation +voluptuous-openapi==0.0.3 + # homeassistant.components.volvooncall volvooncall==0.10.3 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 9fc9ff2dc1ef4f..a0821223ac4079 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2190,6 +2190,9 @@ vilfo-api-client==0.5.0 # homeassistant.components.voip voip-utils==0.1.0 +# homeassistant.components.google_generative_ai_conversation +voluptuous-openapi==0.0.3 + # homeassistant.components.volvooncall volvooncall==0.10.3 diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index d5b4e8672e3f8a..4dfa6379d737dc 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -5,7 +5,9 @@ import pytest from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -25,6 +27,15 @@ def mock_config_entry(hass): return entry +@pytest.fixture +def mock_config_entry_with_assist(hass, mock_config_entry): + """Mock a config entry with assist.""" + hass.config_entries.async_update_entry( + mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + ) + return mock_config_entry + + @pytest.fixture async def mock_init_component(hass: HomeAssistant, mock_config_entry: ConfigEntry): """Initialize integration.""" diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index bf37fe0f2d950f..f97c331705e4fc 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_default_prompt[None] +# name: test_default_prompt[False-None] list([ tuple( '', @@ -13,6 +13,7 @@ 'top_p': 1.0, }), 'model_name': 'models/gemini-pro', + 'tools': None, }), ), tuple( @@ -36,9 +37,7 @@ - Test Device 4 - 1 (3) - Answer the user's questions about the world truthfully. - - If the user wants to control a device, reject the request and suggest using the Home Assistant app. + If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. 
''', 'role': 'user', }), @@ -59,7 +58,7 @@ ), ]) # --- -# name: test_default_prompt[conversation.google_generative_ai_conversation] +# name: test_default_prompt[False-conversation.google_generative_ai_conversation] list([ tuple( '', @@ -73,6 +72,7 @@ 'top_p': 1.0, }), 'model_name': 'models/gemini-pro', + 'tools': None, }), ), tuple( @@ -96,9 +96,7 @@ - Test Device 4 - 1 (3) - Answer the user's questions about the world truthfully. - - If the user wants to control a device, reject the request and suggest using the Home Assistant app. + If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. ''', 'role': 'user', }), @@ -119,48 +117,118 @@ ), ]) # --- -# name: test_generate_content_service_with_image +# name: test_default_prompt[True-None] list([ tuple( '', tuple( ), dict({ - 'model_name': 'gemini-pro-vision', + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 0.9, + 'top_k': 1, + 'top_p': 1.0, + }), + 'model_name': 'models/gemini-pro', + 'tools': None, }), ), tuple( - '().generate_content_async', + '().start_chat', tuple( - list([ - 'Describe this image from my doorbell camera', + ), + dict({ + 'history': list([ dict({ - 'data': b'image bytes', - 'mime_type': 'image/jpeg', + 'parts': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Call the intent tools to control the system. Just pass the name to the intent. + ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', }), ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + 'hello', ), dict({ }), ), ]) # --- -# name: test_generate_content_service_without_images +# name: test_default_prompt[True-conversation.google_generative_ai_conversation] list([ tuple( '', tuple( ), dict({ - 'model_name': 'gemini-pro', + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 0.9, + 'top_k': 1, + 'top_p': 1.0, + }), + 'model_name': 'models/gemini-pro', + 'tools': None, }), ), tuple( - '().generate_content_async', + '().start_chat', tuple( - list([ - 'Write an opening speech for a Home Assistant release party', + ), + dict({ + 'history': list([ + dict({ + 'parts': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Call the intent tools to control the system. Just pass the name to the intent. 
+ ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', + }), ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + 'hello', ), dict({ }), diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index 3bac01db42df83..57c9633a743aee 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -1,6 +1,6 @@ """Test the Google Generative AI Conversation config flow.""" -from unittest.mock import patch +from unittest.mock import Mock, patch from google.api_core.exceptions import ClientError from google.rpc.error_details_pb2 import ErrorInfo @@ -18,12 +18,35 @@ DEFAULT_TOP_P, DOMAIN, ) +from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType +from homeassistant.helpers import llm from tests.common import MockConfigEntry +@pytest.fixture +def mock_models(): + """Mock the model list API.""" + model_15_flash = Mock( + display_name="Gemini 1.5 Flash", + supported_generation_methods=["generateContent"], + ) + model_15_flash.name = "models/gemini-1.5-flash-latest" + + model_10_pro = Mock( + display_name="Gemini 1.0 Pro", + supported_generation_methods=["generateContent"], + ) + model_10_pro.name = "models/gemini-pro" + with patch( + "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models", + return_value=[model_10_pro], + ): + yield + + async def test_form(hass: HomeAssistant) -> None: """Test we get the form.""" # Pretend we already set up a config entry. @@ -60,11 +83,14 @@ async def test_form(hass: HomeAssistant) -> None: assert result2["data"] == { "api_key": "bla", } + assert result2["options"] == { + CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + } assert len(mock_setup_entry.mock_calls) == 1 async def test_options( - hass: HomeAssistant, mock_config_entry, mock_init_component + hass: HomeAssistant, mock_config_entry, mock_init_component, mock_models ) -> None: """Test the options form.""" options_flow = await hass.config_entries.options.async_init( @@ -85,6 +111,9 @@ async def test_options( assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS + assert ( + CONF_LLM_HASS_API not in options["data"] + ), "Options flow should not set this key" @pytest.mark.parametrize( diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index e56838c4b31112..b267d605b44111 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -5,10 +5,18 @@ from google.api_core.exceptions import ClientError import pytest from syrupy.assertion import SnapshotAssertion +import voluptuous as vol from homeassistant.components import conversation +from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import area_registry as ar, device_registry as dr, intent +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + intent, + llm, +) from tests.common import MockConfigEntry @@ -16,6 +24,7 @@ 
 @pytest.mark.parametrize(
     "agent_id", [None, "conversation.google_generative_ai_conversation"]
 )
+@pytest.mark.parametrize("allow_hass_access", [False, True])
 async def test_default_prompt(
     hass: HomeAssistant,
     mock_config_entry: MockConfigEntry,
@@ -24,6 +33,7 @@ async def test_default_prompt(
     device_registry: dr.DeviceRegistry,
     snapshot: SnapshotAssertion,
     agent_id: str | None,
+    allow_hass_access: bool,
 ) -> None:
     """Test that the default prompt works."""
     entry = MockConfigEntry(title=None)
@@ -34,6 +44,15 @@ async def test_default_prompt(
     if agent_id is None:
         agent_id = mock_config_entry.entry_id
 
+    if allow_hass_access:
+        hass.config_entries.async_update_entry(
+            mock_config_entry,
+            options={
+                **mock_config_entry.options,
+                CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+            },
+        )
+
     device_registry.async_get_or_create(
         config_entry_id=entry.entry_id,
         connections={("test", "1234")},
@@ -100,12 +119,20 @@ async def test_default_prompt(
         model=3,
         suggested_area="Test Area 2",
     )
-    with patch("google.generativeai.GenerativeModel") as mock_model:
+    with (
+        patch("google.generativeai.GenerativeModel") as mock_model,
+        patch(
+            "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
+            return_value=[],
+        ) as mock_get_tools,
+    ):
         mock_chat = AsyncMock()
         mock_model.return_value.start_chat.return_value = mock_chat
         chat_response = MagicMock()
         mock_chat.send_message_async.return_value = chat_response
-        chat_response.parts = ["Hi there!"]
+        mock_part = MagicMock()
+        mock_part.function_call = None
+        chat_response.parts = [mock_part]
         chat_response.text = "Hi there!"
         result = await conversation.async_converse(
             hass,
@@ -118,6 +145,171 @@ async def test_default_prompt(
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
     assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+    assert mock_get_tools.called == allow_hass_access
+
+
+@patch(
+    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
+)
+async def test_function_call(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test function calling."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {
+            vol.Optional("param1", description="Test parameters"): [
+                vol.All(str, vol.Lower)
+            ]
+        }
+    )
+
+    mock_get_tools.return_value = [mock_tool]
+
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_chat = AsyncMock()
+        mock_model.return_value.start_chat.return_value = mock_chat
+        chat_response = MagicMock()
+        mock_chat.send_message_async.return_value = chat_response
+        mock_part = MagicMock()
+        mock_part.function_call.name = "test_tool"
+        mock_part.function_call.args = {"param1": ["test_value"]}
+
+        def tool_call(hass, tool_input):
+            mock_part.function_call = False
+            chat_response.text = "Hi there!"
+            return {"result": "Test response"}
+
+        mock_tool.async_call.side_effect = tool_call
+        chat_response.parts = [mock_part]
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
+    mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0]
+    mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call)
+    assert mock_tool_call == {
+        "parts": [
+            {
+                "function_response": {
+                    "name": "test_tool",
+                    "response": {
+                        "result": "Test response",
+                    },
+                },
+            },
+        ],
+        "role": "",
+    }
+
+    mock_tool.async_call.assert_awaited_once_with(
+        hass,
+        llm.ToolInput(
+            tool_name="test_tool",
+            tool_args={"param1": ["test_value"]},
+            platform="google_generative_ai_conversation",
+            context=context,
+            user_prompt="Please call the test function",
+            language="en",
+            assistant="conversation",
+        ),
+    )
+
+
+@patch(
+    "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
+)
+async def test_function_exception(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test exception in function calling."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {
+            vol.Optional("param1", description="Test parameters"): vol.All(
+                vol.Coerce(int), vol.Range(0, 100)
+            )
+        }
+    )
+
+    mock_get_tools.return_value = [mock_tool]
+
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_chat = AsyncMock()
+        mock_model.return_value.start_chat.return_value = mock_chat
+        chat_response = MagicMock()
+        mock_chat.send_message_async.return_value = chat_response
+        mock_part = MagicMock()
+        mock_part.function_call.name = "test_tool"
+        mock_part.function_call.args = {"param1": 1}
+
+        def tool_call(hass, tool_input):
+            mock_part.function_call = False
+            chat_response.text = "Hi there!"
+            raise HomeAssistantError("Test tool exception")
+
+        mock_tool.async_call.side_effect = tool_call
+        chat_response.parts = [mock_part]
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
+ mock_tool_call = mock_chat.send_message_async.mock_calls[1][1][0] + mock_tool_call = type(mock_tool_call).to_dict(mock_tool_call) + assert mock_tool_call == { + "parts": [ + { + "function_response": { + "name": "test_tool", + "response": { + "error": "HomeAssistantError", + "error_text": "Test tool exception", + }, + }, + }, + ], + "role": "", + } + mock_tool.async_call.assert_awaited_once_with( + hass, + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": 1}, + platform="google_generative_ai_conversation", + context=context, + user_prompt="Please call the test function", + language="en", + assistant="conversation", + ), + ) async def test_error_handling( diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 861a63ec3efb69..8b3de48e5ae099 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -18,12 +18,13 @@ async def test_get_api_no_existing(hass: HomeAssistant) -> None: async def test_register_api(hass: HomeAssistant) -> None: """Test registering an llm api.""" - api = llm.AssistAPI( - hass=hass, - id="test", - name="Test", - prompt_template="Test", - ) + + class MyAPI(llm.API): + def async_get_tools(self) -> list[llm.Tool]: + """Return a list of tools.""" + return [] + + api = MyAPI(hass=hass, id="test", name="Test", prompt_template="") llm.async_register_api(hass, api) assert llm.async_get_api(hass, "test") is api
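Reviewer note: the schema pipeline this change introduces (`llm.Tool.parameters` -> `voluptuous_openapi.convert` -> `_format_schema` -> `glm.Tool`) can be sanity-checked outside Home Assistant. The sketch below is illustrative only and not part of the diff: `SUPPORTED_SCHEMA_KEYS` and `_format_schema` are copied from conversation.py above, the example schema mirrors the one in `test_function_call`, and the printed result is approximate (it assumes voluptuous-openapi 0.0.3 emits plain JSON-schema dicts, which is what `_format_tool` relies on).

    from typing import Any

    import voluptuous as vol
    from voluptuous_openapi import convert

    # Keys that glm.Schema understands; _format_schema drops everything else.
    SUPPORTED_SCHEMA_KEYS = {
        "type",
        "format",
        "description",
        "nullable",
        "enum",
        "items",
        "properties",
        "required",
    }

    def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
        """Rename JSON-schema keys to protobuf field names (copied from the diff)."""
        result = {}
        for key, val in schema.items():
            if key not in SUPPORTED_SCHEMA_KEYS:
                continue
            if key == "type":
                key = "type_"  # "type" would shadow a Python builtin in the proto class
                val = val.upper()  # glm.Type enum members are upper case, e.g. OBJECT
            elif key == "format":
                key = "format_"
            elif key == "items":
                val = _format_schema(val)
            elif key == "properties":
                val = {k: _format_schema(v) for k, v in val.items()}
            result[key] = val
        return result

    # Hypothetical tool parameters, mirroring the schema used in test_function_call.
    params = vol.Schema(
        {vol.Optional("param1", description="Test parameters"): [vol.All(str, vol.Lower)]}
    )
    print(_format_schema(convert(params)))
    # Roughly:
    # {'type_': 'OBJECT',
    #  'properties': {'param1': {'type_': 'ARRAY',
    #                            'items': {'type_': 'STRING'},
    #                            'description': 'Test parameters'}}}

The trailing-underscore renames are a quirk of the proto-plus generated `glm.Schema` class, which renames fields that would shadow Python builtins; that is also why `_format_tool` cannot pass `convert()`'s output to `glm.Tool` unmodified.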