From 3eed5de36785abc2deb011e4300d0de70508f798 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 13 Sep 2024 15:31:38 -0500 Subject: [PATCH] Handle announcement finished for ESPHome TTS response (#125625) * Handle announcement finished for TTS response * Adjust test --- .../components/esphome/assist_satellite.py | 13 ++ tests/components/esphome/conftest.py | 34 +++- .../esphome/test_assist_satellite.py | 159 ++++++++++++++++++ 3 files changed, 205 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py index 370c3b9c8fd7d2..08dd2ac0774689 100644 --- a/homeassistant/components/esphome/assist_satellite.py +++ b/homeassistant/components/esphome/assist_satellite.py @@ -14,6 +14,7 @@ from aioesphomeapi import ( MediaPlayerFormatPurpose, + VoiceAssistantAnnounceFinished, VoiceAssistantAudioSettings, VoiceAssistantCommandFlag, VoiceAssistantEventType, @@ -166,6 +167,7 @@ async def async_added_to_hass(self) -> None: handle_start=self.handle_pipeline_start, handle_stop=self.handle_pipeline_stop, handle_audio=self.handle_audio, + handle_announcement_finished=self.handle_announcement_finished, ) ) else: @@ -174,6 +176,7 @@ async def async_added_to_hass(self) -> None: self.cli.subscribe_voice_assistant( handle_start=self.handle_pipeline_start, handle_stop=self.handle_pipeline_stop, + handle_announcement_finished=self.handle_announcement_finished, ) ) @@ -194,6 +197,10 @@ async def async_added_to_hass(self) -> None: assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE ) + if not (feature_flags & VoiceAssistantFeature.SPEAKER): + # Will use media player for TTS/announcements + self._update_tts_format() + async def async_will_remove_from_hass(self) -> None: """Run when entity will be removed from hass.""" await super().async_will_remove_from_hass() @@ -382,6 +389,12 @@ def handle_timer_event( timer_info.is_active, ) + async def handle_announcement_finished( + self, announce_finished: VoiceAssistantAnnounceFinished + ) -> None: + """Handle announcement finished message (also sent for TTS).""" + self.tts_response_finished() + def _update_tts_format(self) -> None: """Update the TTS format from the first media player.""" for supported_format in chain(*self.entry_data.media_player_formats.values()): diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index a95d28359d2d62..2b7c127efd3141 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -19,6 +19,7 @@ HomeassistantServiceCall, ReconnectLogic, UserService, + VoiceAssistantAnnounceFinished, VoiceAssistantAudioSettings, VoiceAssistantFeature, ) @@ -214,6 +215,13 @@ def __init__( ] | None ) + self.voice_assistant_handle_announcement_finished_callback: ( + Callable[ + [VoiceAssistantAnnounceFinished], + Coroutine[Any, Any, None], + ] + | None + ) self.device_info = device_info def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None: @@ -295,11 +303,21 @@ def set_subscribe_voice_assistant_callbacks( ] | None ) = None, + handle_announcement_finished: ( + Callable[ + [VoiceAssistantAnnounceFinished], + Coroutine[Any, Any, None], + ] + | None + ) = None, ) -> None: """Set the voice assistant subscription callbacks.""" self.voice_assistant_handle_start_callback = handle_start self.voice_assistant_handle_stop_callback = handle_stop self.voice_assistant_handle_audio_callback = handle_audio + self.voice_assistant_handle_announcement_finished_callback = ( + handle_announcement_finished + ) async def mock_voice_assistant_handle_start( self, @@ -322,6 +340,13 @@ async def mock_voice_assistant_handle_audio(self, audio: bytes) -> None: assert self.voice_assistant_handle_audio_callback is not None await self.voice_assistant_handle_audio_callback(audio) + async def mock_voice_assistant_handle_announcement_finished( + self, finished: VoiceAssistantAnnounceFinished + ) -> None: + """Mock voice assistant handle announcement finished.""" + assert self.voice_assistant_handle_announcement_finished_callback is not None + await self.voice_assistant_handle_announcement_finished_callback(finished) + async def _mock_generic_device_entry( hass: HomeAssistant, @@ -402,10 +427,17 @@ def _subscribe_voice_assistant( ] | None ) = None, + handle_announcement_finished: ( + Callable[ + [VoiceAssistantAnnounceFinished], + Coroutine[Any, Any, None], + ] + | None + ) = None, ) -> Callable[[], None]: """Subscribe to voice assistant.""" mock_device.set_subscribe_voice_assistant_callbacks( - handle_start, handle_stop, handle_audio + handle_start, handle_stop, handle_audio, handle_announcement_finished ) def unsub(): diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py index 2e6727d88bb399..eb4f980221947a 100644 --- a/tests/components/esphome/test_assist_satellite.py +++ b/tests/components/esphome/test_assist_satellite.py @@ -15,6 +15,7 @@ MediaPlayerInfo, MediaPlayerSupportedFormat, UserService, + VoiceAssistantAnnounceFinished, VoiceAssistantAudioSettings, VoiceAssistantCommandFlag, VoiceAssistantEventType, @@ -603,6 +604,160 @@ async def test_udp_errors() -> None: protocol.transport.sendto.assert_not_called() +async def test_pipeline_media_player( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], + mock_wav: bytes, +) -> None: + """Test a complete pipeline run with the TTS response sent to a media player instead of a speaker. + + This test is not as comprehensive as test_pipeline_api_audio since we're + mainly focused on tts_response_finished getting automatically called. + """ + conversation_id = "test-conversation-id" + media_url = "http://test.url" + media_id = "test-media-id" + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.API_AUDIO + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + + async def async_pipeline_from_audio_stream(*args, device_id, **kwargs): + stt_stream = kwargs["stt_stream"] + + async for _chunk in stt_stream: + break + + event_callback = kwargs["event_callback"] + + # STT + event_callback( + PipelineEvent( + type=PipelineEventType.STT_START, + data={"engine": "test-stt-engine", "metadata": {}}, + ) + ) + + event_callback( + PipelineEvent( + type=PipelineEventType.STT_END, + data={"stt_output": {"text": "test-stt-text"}}, + ) + ) + + # Intent + event_callback( + PipelineEvent( + type=PipelineEventType.INTENT_START, + data={ + "engine": "test-intent-engine", + "language": hass.config.language, + "intent_input": "test-intent-text", + "conversation_id": conversation_id, + "device_id": device_id, + }, + ) + ) + + event_callback( + PipelineEvent( + type=PipelineEventType.INTENT_END, + data={"intent_output": {"conversation_id": conversation_id}}, + ) + ) + + # TTS + event_callback( + PipelineEvent( + type=PipelineEventType.TTS_START, + data={ + "engine": "test-stt-engine", + "language": hass.config.language, + "voice": "test-voice", + "tts_input": "test-tts-text", + }, + ) + ) + + # Should return mock_wav audio + event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={"tts_output": {"url": media_url, "media_id": media_id}}, + ) + ) + + event_callback(PipelineEvent(type=PipelineEventType.RUN_END)) + + pipeline_finished = asyncio.Event() + original_handle_pipeline_finished = satellite.handle_pipeline_finished + + def handle_pipeline_finished(): + original_handle_pipeline_finished() + pipeline_finished.set() + + async def async_get_media_source_audio( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + return ("wav", mock_wav) + + tts_finished = asyncio.Event() + original_tts_response_finished = satellite.tts_response_finished + + def tts_response_finished(): + original_tts_response_finished() + tts_finished.set() + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch( + "homeassistant.components.tts.async_get_media_source_audio", + new=async_get_media_source_audio, + ), + patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished), + patch.object(satellite, "tts_response_finished", tts_response_finished), + ): + async with asyncio.timeout(1): + await satellite.handle_pipeline_start( + conversation_id=conversation_id, + flags=VoiceAssistantCommandFlag(0), # stt + audio_settings=VoiceAssistantAudioSettings(), + wake_word_phrase="", + ) + + await satellite.handle_pipeline_stop(abort=False) + await pipeline_finished.wait() + + assert satellite.state == AssistSatelliteState.RESPONDING + + # Will trigger tts_response_finished + await mock_device.mock_voice_assistant_handle_announcement_finished( + VoiceAssistantAnnounceFinished(success=True) + ) + await tts_finished.wait() + + assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD + + async def test_timer_events( hass: HomeAssistant, device_registry: dr.DeviceRegistry, @@ -952,6 +1107,7 @@ async def test_announce_message( async def send_voice_assistant_announcement_await_response( media_id: str, timeout: float, text: str ): + assert satellite.state == AssistSatelliteState.RESPONDING assert media_id == "https://www.home-assistant.io/resolved.mp3" assert text == "test-text" @@ -983,6 +1139,7 @@ async def send_voice_assistant_announcement_await_response( blocking=True, ) await done.wait() + assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD async def test_announce_media_id( @@ -1016,6 +1173,7 @@ async def test_announce_media_id( async def send_voice_assistant_announcement_await_response( media_id: str, timeout: float, text: str ): + assert satellite.state == AssistSatelliteState.RESPONDING assert media_id == "https://www.home-assistant.io/resolved.mp3" done.set() @@ -1038,6 +1196,7 @@ async def send_voice_assistant_announcement_await_response( blocking=True, ) await done.wait() + assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD async def test_satellite_unloaded_on_disconnect(