Skip to content

Commit

Permalink
Expose scripts with no fields as entities (#123061)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shulyaka authored Oct 23, 2024
1 parent 3ddef56 commit e0e61b5
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 83 deletions.
175 changes: 95 additions & 80 deletions homeassistant/helpers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,9 @@ def _async_get_tools(
):
continue

tools.append(ScriptTool(self.hass, state.entity_id))
script_tool = ScriptTool(self.hass, state.entity_id)
if script_tool.parameters.schema:
tools.append(script_tool)

return tools

Expand Down Expand Up @@ -451,12 +453,17 @@ def _get_exposed_entities(
entities = {}

for state in hass.states.async_all():
if state.domain == SCRIPT_DOMAIN:
continue

if not async_should_expose(hass, assistant, state.entity_id):
continue

description: str | None = None
if state.domain == SCRIPT_DOMAIN:
description, parameters = _get_cached_script_parameters(
hass, state.entity_id
)
if parameters.schema: # Only list scripts without input fields here
continue

entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
Expand Down Expand Up @@ -485,6 +492,9 @@ def _get_exposed_entities(
"state": state.state,
}

if description:
info["description"] = description

if area_names:
info["areas"] = ", ".join(area_names)

Expand Down Expand Up @@ -610,6 +620,83 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901
return {"type": "string"}


def _get_cached_script_parameters(
hass: HomeAssistant, entity_id: str
) -> tuple[str | None, vol.Schema]:
"""Get script description and schema."""
entity_registry = er.async_get(hass)

description = None
parameters = vol.Schema({})
entity_entry = entity_registry.async_get(entity_id)
if entity_entry and entity_entry.unique_id:
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)

if parameters_cache is None:
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}

@callback
def clear_cache(event: Event) -> None:
"""Clear script parameter cache on script reload or delete."""
if (
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
and event.data[ATTR_SERVICE] in parameters_cache
):
parameters_cache.pop(event.data[ATTR_SERVICE])

cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)

@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()

hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
)

if entity_entry.unique_id in parameters_cache:
return parameters_cache[entity_entry.unique_id]

if service_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id
):
description = service_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {})

for field, config in fields.items():
field_description = config.get("description")
if not field_description:
field_description = config.get("name")
key: vol.Marker
if config.get("required"):
key = vol.Required(field, description=field_description)
else:
key = vol.Optional(field, description=field_description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string

parameters = vol.Schema(schema)

aliases: list[str] = []
if entity_entry.name:
aliases.append(entity_entry.name)
if entity_entry.aliases:
aliases.extend(entity_entry.aliases)
if aliases:
if description:
description = description + ". Aliases: " + str(list(aliases))
else:
description = "Aliases: " + str(list(aliases))

parameters_cache[entity_entry.unique_id] = (description, parameters)

return description, parameters


class ScriptTool(Tool):
"""LLM Tool representing a Script."""

Expand All @@ -619,86 +706,14 @@ def __init__(
script_entity_id: str,
) -> None:
"""Init the class."""
entity_registry = er.async_get(hass)

self.name = split_entity_id(script_entity_id)[1]
if self.name[0].isdigit():
self.name = "_" + self.name
self._entity_id = script_entity_id
self.parameters = vol.Schema({})
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id:
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)

if parameters_cache is None:
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}

@callback
def clear_cache(event: Event) -> None:
"""Clear script parameter cache on script reload or delete."""
if (
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
and event.data[ATTR_SERVICE] in parameters_cache
):
parameters_cache.pop(event.data[ATTR_SERVICE])

cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)

@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()

hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
)

if entity_entry.unique_id in parameters_cache:
self.description, self.parameters = parameters_cache[
entity_entry.unique_id
]
return

if service_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id
):
self.description = service_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {})

for field, config in fields.items():
description = config.get("description")
if not description:
description = config.get("name")
key: vol.Marker
if config.get("required"):
key = vol.Required(field, description=description)
else:
key = vol.Optional(field, description=description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string

self.parameters = vol.Schema(schema)

aliases: list[str] = []
if entity_entry.name:
aliases.append(entity_entry.name)
if entity_entry.aliases:
aliases.extend(entity_entry.aliases)
if aliases:
if self.description:
self.description = (
self.description + ". Aliases: " + str(list(aliases))
)
else:
self.description = "Aliases: " + str(list(aliases))

parameters_cache[entity_entry.unique_id] = (
self.description,
self.parameters,
)

self.description, self.parameters = _get_cached_script_parameters(
hass, script_entity_id
)

async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
Expand Down
22 changes: 19 additions & 3 deletions tests/helpers/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,11 +374,16 @@ async def test_assist_api_prompt(
"beer": {"description": "Number of beers"},
"wine": {},
},
}
},
"script_with_no_fields": {
"description": "This is another test script",
"sequence": [],
},
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)

entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
Expand Down Expand Up @@ -511,6 +516,10 @@ def create_entity(
)
)
exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
- names: script_with_no_fields
domain: script
state: 'off'
description: This is another test script
- names: Kitchen
domain: light
state: 'on'
Expand Down Expand Up @@ -657,13 +666,18 @@ async def test_script_tool(
"extra_field": {"selector": {"area": {}}},
},
},
"script_with_no_fields": {
"description": "This is another test script",
"sequence": [],
},
"unexposed_script": {
"sequence": [],
},
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)

entity_registry.async_update_entity(
"script.test_script", name="script name", aliases={"script alias"}
Expand Down Expand Up @@ -700,7 +714,8 @@ async def test_script_tool(
"test_script": (
"This is a test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema),
)
),
"script_with_no_fields": ("This is another test script", vol.Schema({})),
}

tool_input = llm.ToolInput(
Expand Down Expand Up @@ -781,7 +796,8 @@ async def test_script_tool(
"test_script": (
"This is a new test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema),
)
),
"script_with_no_fields": ("This is another test script", vol.Schema({})),
}


Expand Down

0 comments on commit e0e61b5

Please sign in to comment.