diff --git a/.core_files.yaml b/.core_files.yaml index 4a11d5da27c5e..e852a56760179 100644 --- a/.core_files.yaml +++ b/.core_files.yaml @@ -14,6 +14,7 @@ core: &core base_platforms: &base_platforms - homeassistant/components/air_quality/** - homeassistant/components/alarm_control_panel/** + - homeassistant/components/assist_satellite/** - homeassistant/components/binary_sensor/** - homeassistant/components/button/** - homeassistant/components/calendar/** diff --git a/.strict-typing b/.strict-typing index 1a5133efe897f..84c22d1cfcac8 100644 --- a/.strict-typing +++ b/.strict-typing @@ -95,6 +95,7 @@ homeassistant.components.aruba.* homeassistant.components.arwn.* homeassistant.components.aseko_pool_live.* homeassistant.components.assist_pipeline.* +homeassistant.components.assist_satellite.* homeassistant.components.asuswrt.* homeassistant.components.autarco.* homeassistant.components.auth.* diff --git a/CODEOWNERS b/CODEOWNERS index edd10858e8d94..d2a60cbb24667 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -143,6 +143,8 @@ build.json @home-assistant/supervisor /tests/components/aseko_pool_live/ @milanmeu /homeassistant/components/assist_pipeline/ @balloob @synesthesiam /tests/components/assist_pipeline/ @balloob @synesthesiam +/homeassistant/components/assist_satellite/ @home-assistant/core @synesthesiam +/tests/components/assist_satellite/ @home-assistant/core @synesthesiam /homeassistant/components/asuswrt/ @kennedyshead @ollo69 /tests/components/asuswrt/ @kennedyshead @ollo69 /homeassistant/components/atag/ @MatsNL diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 0a03402105abd..ec6d8a646b6ec 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -17,6 +17,7 @@ DATA_LAST_WAKE_UP, DOMAIN, EVENT_RECORDING, + OPTION_PREFERRED, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH, @@ -58,6 +59,7 @@ "PipelineNotFound", "WakeWordSettings", "EVENT_RECORDING", + "OPTION_PREFERRED", "SAMPLES_PER_CHUNK", "SAMPLE_RATE", "SAMPLE_WIDTH", diff --git a/homeassistant/components/assist_pipeline/const.py b/homeassistant/components/assist_pipeline/const.py index f7306b89a54db..300cb5aad2aa6 100644 --- a/homeassistant/components/assist_pipeline/const.py +++ b/homeassistant/components/assist_pipeline/const.py @@ -22,3 +22,5 @@ MS_PER_CHUNK = 10 SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit + +OPTION_PREFERRED = "preferred" diff --git a/homeassistant/components/assist_pipeline/select.py b/homeassistant/components/assist_pipeline/select.py index 5d011424e6ea7..c7e4846aad73c 100644 --- a/homeassistant/components/assist_pipeline/select.py +++ b/homeassistant/components/assist_pipeline/select.py @@ -9,12 +9,10 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import collection, entity_registry as er, restore_state -from .const import DOMAIN +from .const import DOMAIN, OPTION_PREFERRED from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection from .vad import VadSensitivity -OPTION_PREFERRED = "preferred" - @callback def get_chosen_pipeline( diff --git a/homeassistant/components/assist_satellite/__init__.py b/homeassistant/components/assist_satellite/__init__.py new file mode 100644 index 0000000000000..3d6e04bcc7568 --- /dev/null +++ b/homeassistant/components/assist_satellite/__init__.py @@ -0,0 +1,65 @@ +"""Base class for assist satellite entities.""" + +import logging + +import voluptuous as vol + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.typing import ConfigType + +from .const import DOMAIN, AssistSatelliteEntityFeature +from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription +from .errors import SatelliteBusyError +from .websocket_api import async_register_websocket_api + +__all__ = [ + "DOMAIN", + "AssistSatelliteEntity", + "AssistSatelliteEntityDescription", + "AssistSatelliteEntityFeature", + "SatelliteBusyError", +] + +_LOGGER = logging.getLogger(__name__) + +PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity]( + _LOGGER, DOMAIN, hass + ) + await component.async_setup(config) + + component.async_register_entity_service( + "announce", + vol.All( + cv.make_entity_service_schema( + { + vol.Optional("message"): str, + vol.Optional("media_id"): str, + } + ), + cv.has_at_least_one_key("message", "media_id"), + ), + "async_internal_announce", + [AssistSatelliteEntityFeature.ANNOUNCE], + ) + async_register_websocket_api(hass) + + return True + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN] + return await component.async_setup_entry(entry) + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN] + return await component.async_unload_entry(entry) diff --git a/homeassistant/components/assist_satellite/const.py b/homeassistant/components/assist_satellite/const.py new file mode 100644 index 0000000000000..3a9ce896fb215 --- /dev/null +++ b/homeassistant/components/assist_satellite/const.py @@ -0,0 +1,12 @@ +"""Constants for assist satellite.""" + +from enum import IntFlag + +DOMAIN = "assist_satellite" + + +class AssistSatelliteEntityFeature(IntFlag): + """Supported features of Assist satellite entity.""" + + ANNOUNCE = 1 + """Device supports remotely triggered announcements.""" diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py new file mode 100644 index 0000000000000..8364a81b1fb8e --- /dev/null +++ b/homeassistant/components/assist_satellite/entity.py @@ -0,0 +1,332 @@ +"""Assist satellite entity.""" + +from abc import abstractmethod +import asyncio +from collections.abc import AsyncIterable +from enum import StrEnum +import logging +import time +from typing import Any, Final, final + +from homeassistant.components import media_source, stt, tts +from homeassistant.components.assist_pipeline import ( + OPTION_PREFERRED, + AudioSettings, + PipelineEvent, + PipelineEventType, + PipelineStage, + async_get_pipeline, + async_get_pipelines, + async_pipeline_from_audio_stream, + vad, +) +from homeassistant.components.media_player import async_process_play_media_url +from homeassistant.components.tts.media_source import ( + generate_media_source_id as tts_generate_media_source_id, +) +from homeassistant.core import Context, callback +from homeassistant.helpers import entity +from homeassistant.helpers.entity import EntityDescription +from homeassistant.util import ulid + +from .const import AssistSatelliteEntityFeature +from .errors import AssistSatelliteError, SatelliteBusyError + +_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes + +_LOGGER = logging.getLogger(__name__) + + +class AssistSatelliteState(StrEnum): + """Valid states of an Assist satellite entity.""" + + LISTENING_WAKE_WORD = "listening_wake_word" + """Device is streaming audio for wake word detection to Home Assistant.""" + + LISTENING_COMMAND = "listening_command" + """Device is streaming audio with the voice command to Home Assistant.""" + + PROCESSING = "processing" + """Home Assistant is processing the voice command.""" + + RESPONDING = "responding" + """Device is speaking the response.""" + + +class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True): + """A class that describes Assist satellite entities.""" + + +class AssistSatelliteEntity(entity.Entity): + """Entity encapsulating the state and functionality of an Assist satellite.""" + + entity_description: AssistSatelliteEntityDescription + _attr_should_poll = False + _attr_supported_features = AssistSatelliteEntityFeature(0) + _attr_pipeline_entity_id: str | None = None + _attr_vad_sensitivity_entity_id: str | None = None + + _conversation_id: str | None = None + _conversation_id_time: float | None = None + + _run_has_tts: bool = False + _is_announcing = False + _wake_word_intercept_future: asyncio.Future[str | None] | None = None + + __assist_satellite_state: AssistSatelliteState | None = None + + @final + @property + def state(self) -> str | None: + """Return state of the entity.""" + return self.__assist_satellite_state + + @property + def pipeline_entity_id(self) -> str | None: + """Entity ID of the pipeline to use for the next conversation.""" + return self._attr_pipeline_entity_id + + @property + def vad_sensitivity_entity_id(self) -> str | None: + """Entity ID of the VAD sensitivity to use for the next conversation.""" + return self._attr_vad_sensitivity_entity_id + + async def async_intercept_wake_word(self) -> str | None: + """Intercept the next wake word from the satellite. + + Returns the detected wake word phrase or None. + """ + if self._wake_word_intercept_future is not None: + raise SatelliteBusyError("Wake word interception already in progress") + + # Will cause next wake word to be intercepted in + # async_accept_pipeline_from_satellite + self._wake_word_intercept_future = asyncio.Future() + + _LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id) + + try: + return await self._wake_word_intercept_future + finally: + self._wake_word_intercept_future = None + + async def async_internal_announce( + self, + message: str | None = None, + media_id: str | None = None, + ) -> None: + """Play and show an announcement on the satellite. + + If media_id is not provided, message is synthesized to + audio with the selected pipeline. + + If media_id is provided, it is played directly. It is possible + to omit the message and the satellite will not show any text. + + Calls async_announce with message and media id. + """ + if message is None: + message = "" + + if not media_id: + # Synthesize audio and get URL + pipeline_id = self._resolve_pipeline() + pipeline = async_get_pipeline(self.hass, pipeline_id) + + tts_options: dict[str, Any] = {} + if pipeline.tts_voice is not None: + tts_options[tts.ATTR_VOICE] = pipeline.tts_voice + + media_id = tts_generate_media_source_id( + self.hass, + message, + engine=pipeline.tts_engine, + language=pipeline.tts_language, + options=tts_options, + ) + + if media_source.is_media_source_id(media_id): + media = await media_source.async_resolve_media( + self.hass, + media_id, + None, + ) + media_id = media.url + + # Resolve to full URL + media_id = async_process_play_media_url(self.hass, media_id) + + if self._is_announcing: + raise SatelliteBusyError + + self._is_announcing = True + + try: + # Block until announcement is finished + await self.async_announce(message, media_id) + finally: + self._is_announcing = False + + async def async_announce(self, message: str, media_id: str) -> None: + """Announce media on the satellite. + + Should block until the announcement is done playing. + """ + raise NotImplementedError + + async def async_accept_pipeline_from_satellite( + self, + audio_stream: AsyncIterable[bytes], + start_stage: PipelineStage = PipelineStage.STT, + end_stage: PipelineStage = PipelineStage.TTS, + wake_word_phrase: str | None = None, + ) -> None: + """Triggers an Assist pipeline in Home Assistant from a satellite.""" + if self._wake_word_intercept_future and start_stage in ( + PipelineStage.WAKE_WORD, + PipelineStage.STT, + ): + if start_stage == PipelineStage.WAKE_WORD: + self._wake_word_intercept_future.set_exception( + AssistSatelliteError( + "Only on-device wake words currently supported" + ) + ) + return + + # Intercepting wake word and immediately end pipeline + _LOGGER.debug( + "Intercepted wake word: %s (entity_id=%s)", + wake_word_phrase, + self.entity_id, + ) + + if wake_word_phrase is None: + self._wake_word_intercept_future.set_exception( + AssistSatelliteError("No wake word phrase provided") + ) + else: + self._wake_word_intercept_future.set_result(wake_word_phrase) + self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END)) + return + + device_id = self.registry_entry.device_id if self.registry_entry else None + + # Refresh context if necessary + if ( + (self._context is None) + or (self._context_set is None) + or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS) + ): + self.async_set_context(Context()) + + assert self._context is not None + + # Reset conversation id if necessary + if (self._conversation_id_time is None) or ( + (time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC + ): + self._conversation_id = None + + if self._conversation_id is None: + self._conversation_id = ulid.ulid() + + # Update timeout + self._conversation_id_time = time.monotonic() + + # Set entity state based on pipeline events + self._run_has_tts = False + + await async_pipeline_from_audio_stream( + self.hass, + context=self._context, + event_callback=self._internal_on_pipeline_event, + stt_metadata=stt.SpeechMetadata( + language="", # set in async_pipeline_from_audio_stream + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=audio_stream, + pipeline_id=self._resolve_pipeline(), + conversation_id=self._conversation_id, + device_id=device_id, + tts_audio_output="wav", + wake_word_phrase=wake_word_phrase, + audio_settings=AudioSettings( + silence_seconds=self._resolve_vad_sensitivity() + ), + start_stage=start_stage, + end_stage=end_stage, + ) + + @abstractmethod + def on_pipeline_event(self, event: PipelineEvent) -> None: + """Handle pipeline events.""" + + @callback + def _internal_on_pipeline_event(self, event: PipelineEvent) -> None: + """Set state based on pipeline stage.""" + if event.type is PipelineEventType.WAKE_WORD_START: + self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD) + elif event.type is PipelineEventType.STT_START: + self._set_state(AssistSatelliteState.LISTENING_COMMAND) + elif event.type is PipelineEventType.INTENT_START: + self._set_state(AssistSatelliteState.PROCESSING) + elif event.type is PipelineEventType.TTS_START: + # Wait until tts_response_finished is called to return to waiting state + self._run_has_tts = True + self._set_state(AssistSatelliteState.RESPONDING) + elif event.type is PipelineEventType.RUN_END: + if not self._run_has_tts: + self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD) + + self.on_pipeline_event(event) + + @callback + def _set_state(self, state: AssistSatelliteState) -> None: + """Set the entity's state.""" + self.__assist_satellite_state = state + self.async_write_ha_state() + + @callback + def tts_response_finished(self) -> None: + """Tell entity that the text-to-speech response has finished playing.""" + self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD) + + @callback + def _resolve_pipeline(self) -> str | None: + """Resolve pipeline from select entity to id. + + Return None to make async_get_pipeline look up the preferred pipeline. + """ + if not (pipeline_entity_id := self.pipeline_entity_id): + return None + + if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None: + raise RuntimeError("Pipeline entity not found") + + if pipeline_entity_state.state != OPTION_PREFERRED: + # Resolve pipeline by name + for pipeline in async_get_pipelines(self.hass): + if pipeline.name == pipeline_entity_state.state: + return pipeline.id + + return None + + @callback + def _resolve_vad_sensitivity(self) -> float: + """Resolve VAD sensitivity from select entity to enum.""" + vad_sensitivity = vad.VadSensitivity.DEFAULT + + if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id: + if ( + vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id) + ) is None: + raise RuntimeError("VAD sensitivity entity not found") + + vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state) + + return vad.VadSensitivity.to_seconds(vad_sensitivity) diff --git a/homeassistant/components/assist_satellite/errors.py b/homeassistant/components/assist_satellite/errors.py new file mode 100644 index 0000000000000..cd05f374521c3 --- /dev/null +++ b/homeassistant/components/assist_satellite/errors.py @@ -0,0 +1,11 @@ +"""Errors for assist satellite.""" + +from homeassistant.exceptions import HomeAssistantError + + +class AssistSatelliteError(HomeAssistantError): + """Base class for assist satellite errors.""" + + +class SatelliteBusyError(AssistSatelliteError): + """Satellite is busy and cannot handle the request.""" diff --git a/homeassistant/components/assist_satellite/icons.json b/homeassistant/components/assist_satellite/icons.json new file mode 100644 index 0000000000000..a98c3aefc5bab --- /dev/null +++ b/homeassistant/components/assist_satellite/icons.json @@ -0,0 +1,12 @@ +{ + "entity_component": { + "_": { + "default": "mdi:account-voice" + } + }, + "services": { + "announce": { + "service": "mdi:bullhorn" + } + } +} diff --git a/homeassistant/components/assist_satellite/manifest.json b/homeassistant/components/assist_satellite/manifest.json new file mode 100644 index 0000000000000..b4f894563518c --- /dev/null +++ b/homeassistant/components/assist_satellite/manifest.json @@ -0,0 +1,9 @@ +{ + "domain": "assist_satellite", + "name": "Assist Satellite", + "codeowners": ["@home-assistant/core", "@synesthesiam"], + "dependencies": ["assist_pipeline", "stt", "tts"], + "documentation": "https://www.home-assistant.io/integrations/assist_satellite", + "integration_type": "entity", + "quality_scale": "internal" +} diff --git a/homeassistant/components/assist_satellite/services.yaml b/homeassistant/components/assist_satellite/services.yaml new file mode 100644 index 0000000000000..e7fefc4705f5c --- /dev/null +++ b/homeassistant/components/assist_satellite/services.yaml @@ -0,0 +1,16 @@ +announce: + target: + entity: + domain: assist_satellite + supported_features: + - assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE + fields: + message: + required: false + example: "Time to wake up!" + selector: + text: + media_id: + required: false + selector: + text: diff --git a/homeassistant/components/assist_satellite/strings.json b/homeassistant/components/assist_satellite/strings.json new file mode 100644 index 0000000000000..1d07882daaefc --- /dev/null +++ b/homeassistant/components/assist_satellite/strings.json @@ -0,0 +1,30 @@ +{ + "title": "Assist satellite", + "entity_component": { + "_": { + "name": "Assist satellite", + "state": { + "listening_wake_word": "Wake word", + "listening_command": "Voice command", + "responding": "Responding", + "processing": "Processing" + } + } + }, + "services": { + "announce": { + "name": "Announce", + "description": "Let the satellite announce a message.", + "fields": { + "message": { + "name": "Message", + "description": "The message to announce." + }, + "media_id": { + "name": "Media ID", + "description": "The media ID to announce instead of using text-to-speech." + } + } + } + } +} diff --git a/homeassistant/components/assist_satellite/websocket_api.py b/homeassistant/components/assist_satellite/websocket_api.py new file mode 100644 index 0000000000000..10687f4210e7b --- /dev/null +++ b/homeassistant/components/assist_satellite/websocket_api.py @@ -0,0 +1,46 @@ +"""Assist satellite Websocket API.""" + +from typing import Any + +import voluptuous as vol + +from homeassistant.components import websocket_api +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.entity_component import EntityComponent + +from .const import DOMAIN +from .entity import AssistSatelliteEntity + + +@callback +def async_register_websocket_api(hass: HomeAssistant) -> None: + """Register the websocket API.""" + websocket_api.async_register_command(hass, websocket_intercept_wake_word) + + +@callback +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_satellite/intercept_wake_word", + vol.Required("entity_id"): cv.entity_domain(DOMAIN), + } +) +@websocket_api.require_admin +@websocket_api.async_response +async def websocket_intercept_wake_word( + hass: HomeAssistant, + connection: websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Intercept the next wake word from a satellite.""" + component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN] + satellite = component.get_entity(msg["entity_id"]) + if satellite is None: + connection.send_error( + msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found" + ) + return + + wake_word_phrase = await satellite.async_intercept_wake_word() + connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase}) diff --git a/homeassistant/const.py b/homeassistant/const.py index 1ee73408f98e3..ee90ebfc28b29 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -41,6 +41,7 @@ class Platform(StrEnum): AIR_QUALITY = "air_quality" ALARM_CONTROL_PANEL = "alarm_control_panel" + ASSIST_SATELLITE = "assist_satellite" BINARY_SENSOR = "binary_sensor" BUTTON = "button" CALENDAR = "calendar" diff --git a/mypy.ini b/mypy.ini index 3854477b94b65..2686fbe30620d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -705,6 +705,16 @@ disallow_untyped_defs = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.assist_satellite.*] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.asuswrt.*] check_untyped_defs = true disallow_incomplete_defs = true diff --git a/script/hassfest/docker/Dockerfile b/script/hassfest/docker/Dockerfile index 4dbea0e4c959b..a37fa9c57fcdf 100644 --- a/script/hassfest/docker/Dockerfile +++ b/script/hassfest/docker/Dockerfile @@ -23,7 +23,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.2.27,source=/uv,target=/bin/uv \ -c /usr/src/homeassistant/homeassistant/package_constraints.txt \ -r /usr/src/homeassistant/requirements.txt \ stdlib-list==0.10.0 pipdeptree==2.23.1 tqdm==4.66.4 ruff==0.6.2 \ - PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0 + PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2 LABEL "name"="hassfest" LABEL "maintainer"="Home Assistant " diff --git a/tests/components/assist_satellite/__init__.py b/tests/components/assist_satellite/__init__.py new file mode 100644 index 0000000000000..7e06ea3a4b95a --- /dev/null +++ b/tests/components/assist_satellite/__init__.py @@ -0,0 +1,3 @@ +"""Tests for Assist Satellite.""" + +ENTITY_ID = "assist_satellite.test_entity" diff --git a/tests/components/assist_satellite/conftest.py b/tests/components/assist_satellite/conftest.py new file mode 100644 index 0000000000000..a14e9e9452bfb --- /dev/null +++ b/tests/components/assist_satellite/conftest.py @@ -0,0 +1,107 @@ +"""Test helpers for Assist Satellite.""" + +import pathlib +from unittest.mock import Mock + +import pytest + +from homeassistant.components.assist_pipeline import PipelineEvent +from homeassistant.components.assist_satellite import ( + DOMAIN as AS_DOMAIN, + AssistSatelliteEntity, + AssistSatelliteEntityFeature, +) +from homeassistant.config_entries import ConfigEntry, ConfigFlow +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component + +from tests.common import ( + MockConfigEntry, + MockModule, + mock_config_flow, + mock_integration, + mock_platform, + setup_test_component_platform, +) + +TEST_DOMAIN = "test" + + +@pytest.fixture(autouse=True) +def mock_tts(mock_tts_cache_dir: pathlib.Path) -> None: + """Mock TTS cache dir fixture.""" + + +class MockAssistSatellite(AssistSatelliteEntity): + """Mock Assist Satellite Entity.""" + + _attr_name = "Test Entity" + _attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE + + def __init__(self) -> None: + """Initialize the mock entity.""" + self.events = [] + self.announcements = [] + + def on_pipeline_event(self, event: PipelineEvent) -> None: + """Handle pipeline events.""" + self.events.append(event) + + async def async_announce(self, message: str, media_id: str) -> None: + """Announce media on a device.""" + self.announcements.append((message, media_id)) + + +@pytest.fixture +def entity() -> MockAssistSatellite: + """Mock Assist Satellite Entity.""" + return MockAssistSatellite() + + +@pytest.fixture +def config_entry(hass: HomeAssistant) -> ConfigEntry: + """Mock config entry.""" + entry = MockConfigEntry(domain=TEST_DOMAIN) + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +async def init_components( + hass: HomeAssistant, + config_entry: ConfigEntry, + entity: MockAssistSatellite, +) -> None: + """Initialize components.""" + assert await async_setup_component(hass, "homeassistant", {}) + + async def async_setup_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Set up test config entry.""" + await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN]) + return True + + async def async_unload_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Unload test config entry.""" + await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN) + return True + + mock_integration( + hass, + MockModule( + TEST_DOMAIN, + async_setup_entry=async_setup_entry_init, + async_unload_entry=async_unload_entry_init, + ), + ) + setup_test_component_platform(hass, AS_DOMAIN, [entity], from_config_entry=True) + mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock()) + + with mock_config_flow(TEST_DOMAIN, ConfigFlow): + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + return config_entry diff --git a/tests/components/assist_satellite/test_entity.py b/tests/components/assist_satellite/test_entity.py new file mode 100644 index 0000000000000..f957a826828e8 --- /dev/null +++ b/tests/components/assist_satellite/test_entity.py @@ -0,0 +1,332 @@ +"""Test the Assist Satellite entity.""" + +import asyncio +from unittest.mock import patch + +import pytest + +from homeassistant.components import stt +from homeassistant.components.assist_pipeline import ( + OPTION_PREFERRED, + AudioSettings, + Pipeline, + PipelineEvent, + PipelineEventType, + PipelineStage, + async_get_pipeline, + async_update_pipeline, + vad, +) +from homeassistant.components.assist_satellite import SatelliteBusyError +from homeassistant.components.assist_satellite.entity import AssistSatelliteState +from homeassistant.components.media_source import PlayMedia +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import STATE_UNKNOWN +from homeassistant.core import Context, HomeAssistant + +from . import ENTITY_ID +from .conftest import MockAssistSatellite + + +async def test_entity_state( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test entity state represent events.""" + + state = hass.states.get(ENTITY_ID) + assert state is not None + assert state.state == STATE_UNKNOWN + + context = Context() + audio_stream = object() + + entity.async_set_context(context) + + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" + ) as mock_start_pipeline: + await entity.async_accept_pipeline_from_satellite(audio_stream) + + assert mock_start_pipeline.called + kwargs = mock_start_pipeline.call_args[1] + assert kwargs["context"] is context + assert kwargs["event_callback"] == entity._internal_on_pipeline_event + assert kwargs["stt_metadata"] == stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ) + assert kwargs["stt_stream"] is audio_stream + assert kwargs["pipeline_id"] is None + assert kwargs["device_id"] is None + assert kwargs["tts_audio_output"] == "wav" + assert kwargs["wake_word_phrase"] is None + assert kwargs["audio_settings"] == AudioSettings( + silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT) + ) + assert kwargs["start_stage"] == PipelineStage.STT + assert kwargs["end_stage"] == PipelineStage.TTS + + for event_type, expected_state in ( + (PipelineEventType.RUN_START, STATE_UNKNOWN), + (PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD), + (PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD), + (PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD), + (PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND), + (PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND), + (PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND), + (PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND), + (PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING), + (PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING), + (PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING), + (PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING), + (PipelineEventType.ERROR, AssistSatelliteState.RESPONDING), + ): + kwargs["event_callback"](PipelineEvent(event_type, {})) + state = hass.states.get(ENTITY_ID) + assert state.state == expected_state, event_type + + entity.tts_response_finished() + state = hass.states.get(ENTITY_ID) + assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD + + +@pytest.mark.parametrize( + ("service_data", "expected_params"), + [ + ( + {"message": "Hello"}, + ("Hello", "https://www.home-assistant.io/resolved.mp3"), + ), + ( + { + "message": "Hello", + "media_id": "http://example.com/bla.mp3", + }, + ("Hello", "http://example.com/bla.mp3"), + ), + ( + {"media_id": "http://example.com/bla.mp3"}, + ("", "http://example.com/bla.mp3"), + ), + ], +) +async def test_announce( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + service_data: dict, + expected_params: tuple[str, str], +) -> None: + """Test announcing on a device.""" + await async_update_pipeline( + hass, + async_get_pipeline(hass), + tts_engine="tts.mock_entity", + tts_language="en", + tts_voice="test-voice", + ) + + with ( + patch( + "homeassistant.components.assist_satellite.entity.tts_generate_media_source_id", + return_value="media-source://bla", + ), + patch( + "homeassistant.components.media_source.async_resolve_media", + return_value=PlayMedia( + url="https://www.home-assistant.io/resolved.mp3", + mime_type="audio/mp3", + ), + ), + ): + await hass.services.async_call( + "assist_satellite", + "announce", + service_data, + target={"entity_id": "assist_satellite.test_entity"}, + blocking=True, + ) + + assert entity.announcements[0] == expected_params + + +async def test_announce_busy( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, +) -> None: + """Test that announcing while an announcement is in progress raises an error.""" + media_id = "https://www.home-assistant.io/resolved.mp3" + announce_started = asyncio.Event() + got_error = asyncio.Event() + + async def async_announce(message, media_id): + announce_started.set() + + # Block so we can do another announcement + await got_error.wait() + + with patch.object(entity, "async_announce", new=async_announce): + announce_task = asyncio.create_task( + entity.async_internal_announce(media_id=media_id) + ) + async with asyncio.timeout(1): + await announce_started.wait() + + # Try to do a second announcement + with pytest.raises(SatelliteBusyError): + await entity.async_internal_announce(media_id=media_id) + + # Avoid lingering task + got_error.set() + await announce_task + + +async def test_context_refresh( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test that the context will be automatically refreshed.""" + audio_stream = object() + + # Remove context + entity._context = None + + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream" + ): + await entity.async_accept_pipeline_from_satellite(audio_stream) + + # Context should have been refreshed + assert entity._context is not None + + +async def test_pipeline_entity( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test getting pipeline from an entity.""" + audio_stream = object() + pipeline = Pipeline( + conversation_engine="test", + conversation_language="en", + language="en", + name="test-pipeline", + stt_engine=None, + stt_language=None, + tts_engine=None, + tts_language=None, + tts_voice=None, + wake_word_entity=None, + wake_word_id=None, + ) + + pipeline_entity_id = "select.pipeline" + hass.states.async_set(pipeline_entity_id, pipeline.name) + entity._attr_pipeline_entity_id = pipeline_entity_id + + done = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs): + assert pipeline_id == pipeline.id + done.set() + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch( + "homeassistant.components.assist_satellite.entity.async_get_pipelines", + return_value=[pipeline], + ), + ): + async with asyncio.timeout(1): + await entity.async_accept_pipeline_from_satellite(audio_stream) + await done.wait() + + +async def test_pipeline_entity_preferred( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test getting pipeline from an entity with a preferred state.""" + audio_stream = object() + + pipeline_entity_id = "select.pipeline" + hass.states.async_set(pipeline_entity_id, OPTION_PREFERRED) + entity._attr_pipeline_entity_id = pipeline_entity_id + + done = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs): + # Preferred pipeline + assert pipeline_id is None + done.set() + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + ): + async with asyncio.timeout(1): + await entity.async_accept_pipeline_from_satellite(audio_stream) + await done.wait() + + +async def test_vad_sensitivity_entity( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test getting vad sensitivity from an entity.""" + audio_stream = object() + + vad_sensitivity_entity_id = "select.vad_sensitivity" + hass.states.async_set(vad_sensitivity_entity_id, vad.VadSensitivity.AGGRESSIVE) + entity._attr_vad_sensitivity_entity_id = vad_sensitivity_entity_id + + done = asyncio.Event() + + async def async_pipeline_from_audio_stream( + *args, audio_settings: AudioSettings, **kwargs + ): + # Verify vad sensitivity + assert audio_settings.silence_seconds == vad.VadSensitivity.to_seconds( + vad.VadSensitivity.AGGRESSIVE + ) + done.set() + + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ): + async with asyncio.timeout(1): + await entity.async_accept_pipeline_from_satellite(audio_stream) + await done.wait() + + +async def test_pipeline_entity_not_found( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test that setting the pipeline entity id to a non-existent entity raises an error.""" + audio_stream = object() + + # Set to an entity that doesn't exist + entity._attr_pipeline_entity_id = "select.pipeline" + + with pytest.raises(RuntimeError): + await entity.async_accept_pipeline_from_satellite(audio_stream) + + +async def test_vad_sensitivity_entity_not_found( + hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite +) -> None: + """Test that setting the vad sensitivity entity id to a non-existent entity raises an error.""" + audio_stream = object() + + # Set to an entity that doesn't exist + entity._attr_vad_sensitivity_entity_id = "select.vad_sensitivity" + + with pytest.raises(RuntimeError): + await entity.async_accept_pipeline_from_satellite(audio_stream) diff --git a/tests/components/assist_satellite/test_websocket_api.py b/tests/components/assist_satellite/test_websocket_api.py new file mode 100644 index 0000000000000..af49334e629bc --- /dev/null +++ b/tests/components/assist_satellite/test_websocket_api.py @@ -0,0 +1,192 @@ +"""Test WebSocket API.""" + +import asyncio + +from homeassistant.components.assist_pipeline import PipelineStage +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant + +from . import ENTITY_ID +from .conftest import MockAssistSatellite + +from tests.common import MockUser +from tests.typing import WebSocketGenerator + + +async def test_intercept_wake_word( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test intercepting a wake word.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": ENTITY_ID, + } + ) + + for _ in range(3): + await asyncio.sleep(0) + + await entity.async_accept_pipeline_from_satellite( + object(), + start_stage=PipelineStage.STT, + wake_word_phrase="ok, nabu", + ) + + response = await ws_client.receive_json() + + assert response["success"] + assert response["result"] == {"wake_word_phrase": "ok, nabu"} + + +async def test_intercept_wake_word_requires_on_device_wake_word( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test intercepting a wake word fails if detection happens in HA.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": ENTITY_ID, + } + ) + + for _ in range(3): + await asyncio.sleep(0) + + await entity.async_accept_pipeline_from_satellite( + object(), + # Emulate wake word processing in Home Assistant + start_stage=PipelineStage.WAKE_WORD, + ) + + response = await ws_client.receive_json() + assert not response["success"] + assert response["error"] == { + "code": "home_assistant_error", + "message": "Only on-device wake words currently supported", + } + + +async def test_intercept_wake_word_requires_wake_word_phrase( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test intercepting a wake word fails if detection happens in HA.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": ENTITY_ID, + } + ) + + for _ in range(3): + await asyncio.sleep(0) + + await entity.async_accept_pipeline_from_satellite( + object(), + start_stage=PipelineStage.STT, + # We are not passing wake word phrase + ) + + response = await ws_client.receive_json() + assert not response["success"] + assert response["error"] == { + "code": "home_assistant_error", + "message": "No wake word phrase provided", + } + + +async def test_intercept_wake_word_require_admin( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, + hass_admin_user: MockUser, +) -> None: + """Test intercepting a wake word requires admin access.""" + # Remove admin permission and verify we're not allowed + hass_admin_user.groups = [] + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": ENTITY_ID, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"] == { + "code": "unauthorized", + "message": "Unauthorized", + } + + +async def test_intercept_wake_word_invalid_satellite( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test intercepting a wake word requires admin access.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": "assist_satellite.invalid", + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"] == { + "code": "not_found", + "message": "Entity not found", + } + + +async def test_intercept_wake_word_twice( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test intercepting a wake word requires admin access.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": ENTITY_ID, + } + ) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/intercept_wake_word", + "entity_id": ENTITY_ID, + } + ) + response = await ws_client.receive_json() + + assert not response["success"] + assert response["error"] == { + "code": "home_assistant_error", + "message": "Wake word interception already in progress", + }