Skip to content

Commit

Permalink
Handle announcement finished for ESPHome TTS response (home-assistant…
Browse files Browse the repository at this point in the history
…#125625)

* Handle announcement finished for TTS response

* Adjust test
  • Loading branch information
synesthesiam authored Sep 13, 2024
1 parent 970d28b commit 3eed5de
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 1 deletion.
13 changes: 13 additions & 0 deletions homeassistant/components/esphome/assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from aioesphomeapi import (
MediaPlayerFormatPurpose,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
Expand Down Expand Up @@ -166,6 +167,7 @@ async def async_added_to_hass(self) -> None:
handle_start=self.handle_pipeline_start,
handle_stop=self.handle_pipeline_stop,
handle_audio=self.handle_audio,
handle_announcement_finished=self.handle_announcement_finished,
)
)
else:
Expand All @@ -174,6 +176,7 @@ async def async_added_to_hass(self) -> None:
self.cli.subscribe_voice_assistant(
handle_start=self.handle_pipeline_start,
handle_stop=self.handle_pipeline_stop,
handle_announcement_finished=self.handle_announcement_finished,
)
)

Expand All @@ -194,6 +197,10 @@ async def async_added_to_hass(self) -> None:
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)

if not (feature_flags & VoiceAssistantFeature.SPEAKER):
# Will use media player for TTS/announcements
self._update_tts_format()

async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
await super().async_will_remove_from_hass()
Expand Down Expand Up @@ -382,6 +389,12 @@ def handle_timer_event(
timer_info.is_active,
)

async def handle_announcement_finished(
self, announce_finished: VoiceAssistantAnnounceFinished
) -> None:
"""Handle announcement finished message (also sent for TTS)."""
self.tts_response_finished()

def _update_tts_format(self) -> None:
"""Update the TTS format from the first media player."""
for supported_format in chain(*self.entry_data.media_player_formats.values()):
Expand Down
34 changes: 33 additions & 1 deletion tests/components/esphome/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
HomeassistantServiceCall,
ReconnectLogic,
UserService,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantFeature,
)
Expand Down Expand Up @@ -214,6 +215,13 @@ def __init__(
]
| None
)
self.voice_assistant_handle_announcement_finished_callback: (
Callable[
[VoiceAssistantAnnounceFinished],
Coroutine[Any, Any, None],
]
| None
)
self.device_info = device_info

def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
Expand Down Expand Up @@ -295,11 +303,21 @@ def set_subscribe_voice_assistant_callbacks(
]
| None
) = None,
handle_announcement_finished: (
Callable[
[VoiceAssistantAnnounceFinished],
Coroutine[Any, Any, None],
]
| None
) = None,
) -> None:
"""Set the voice assistant subscription callbacks."""
self.voice_assistant_handle_start_callback = handle_start
self.voice_assistant_handle_stop_callback = handle_stop
self.voice_assistant_handle_audio_callback = handle_audio
self.voice_assistant_handle_announcement_finished_callback = (
handle_announcement_finished
)

async def mock_voice_assistant_handle_start(
self,
Expand All @@ -322,6 +340,13 @@ async def mock_voice_assistant_handle_audio(self, audio: bytes) -> None:
assert self.voice_assistant_handle_audio_callback is not None
await self.voice_assistant_handle_audio_callback(audio)

async def mock_voice_assistant_handle_announcement_finished(
self, finished: VoiceAssistantAnnounceFinished
) -> None:
"""Mock voice assistant handle announcement finished."""
assert self.voice_assistant_handle_announcement_finished_callback is not None
await self.voice_assistant_handle_announcement_finished_callback(finished)


async def _mock_generic_device_entry(
hass: HomeAssistant,
Expand Down Expand Up @@ -402,10 +427,17 @@ def _subscribe_voice_assistant(
]
| None
) = None,
handle_announcement_finished: (
Callable[
[VoiceAssistantAnnounceFinished],
Coroutine[Any, Any, None],
]
| None
) = None,
) -> Callable[[], None]:
"""Subscribe to voice assistant."""
mock_device.set_subscribe_voice_assistant_callbacks(
handle_start, handle_stop, handle_audio
handle_start, handle_stop, handle_audio, handle_announcement_finished
)

def unsub():
Expand Down
159 changes: 159 additions & 0 deletions tests/components/esphome/test_assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MediaPlayerInfo,
MediaPlayerSupportedFormat,
UserService,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
Expand Down Expand Up @@ -603,6 +604,160 @@ async def test_udp_errors() -> None:
protocol.transport.sendto.assert_not_called()


async def test_pipeline_media_player(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
mock_wav: bytes,
) -> None:
"""Test a complete pipeline run with the TTS response sent to a media player instead of a speaker.
This test is not as comprehensive as test_pipeline_api_audio since we're
mainly focused on tts_response_finished getting automatically called.
"""
conversation_id = "test-conversation-id"
media_url = "http://test.url"
media_id = "test-media-id"

mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.API_AUDIO
},
)
await hass.async_block_till_done()

satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None

async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
stt_stream = kwargs["stt_stream"]

async for _chunk in stt_stream:
break

event_callback = kwargs["event_callback"]

# STT
event_callback(
PipelineEvent(
type=PipelineEventType.STT_START,
data={"engine": "test-stt-engine", "metadata": {}},
)
)

event_callback(
PipelineEvent(
type=PipelineEventType.STT_END,
data={"stt_output": {"text": "test-stt-text"}},
)
)

# Intent
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_START,
data={
"engine": "test-intent-engine",
"language": hass.config.language,
"intent_input": "test-intent-text",
"conversation_id": conversation_id,
"device_id": device_id,
},
)
)

event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_END,
data={"intent_output": {"conversation_id": conversation_id}},
)
)

# TTS
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_START,
data={
"engine": "test-stt-engine",
"language": hass.config.language,
"voice": "test-voice",
"tts_input": "test-tts-text",
},
)
)

# Should return mock_wav audio
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={"tts_output": {"url": media_url, "media_id": media_id}},
)
)

event_callback(PipelineEvent(type=PipelineEventType.RUN_END))

pipeline_finished = asyncio.Event()
original_handle_pipeline_finished = satellite.handle_pipeline_finished

def handle_pipeline_finished():
original_handle_pipeline_finished()
pipeline_finished.set()

async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
return ("wav", mock_wav)

tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished

def tts_response_finished():
original_tts_response_finished()
tts_finished.set()

with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
async with asyncio.timeout(1):
await satellite.handle_pipeline_start(
conversation_id=conversation_id,
flags=VoiceAssistantCommandFlag(0), # stt
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase="",
)

await satellite.handle_pipeline_stop(abort=False)
await pipeline_finished.wait()

assert satellite.state == AssistSatelliteState.RESPONDING

# Will trigger tts_response_finished
await mock_device.mock_voice_assistant_handle_announcement_finished(
VoiceAssistantAnnounceFinished(success=True)
)
await tts_finished.wait()

assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD


async def test_timer_events(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
Expand Down Expand Up @@ -952,6 +1107,7 @@ async def test_announce_message(
async def send_voice_assistant_announcement_await_response(
media_id: str, timeout: float, text: str
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "https://www.home-assistant.io/resolved.mp3"
assert text == "test-text"

Expand Down Expand Up @@ -983,6 +1139,7 @@ async def send_voice_assistant_announcement_await_response(
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD


async def test_announce_media_id(
Expand Down Expand Up @@ -1016,6 +1173,7 @@ async def test_announce_media_id(
async def send_voice_assistant_announcement_await_response(
media_id: str, timeout: float, text: str
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "https://www.home-assistant.io/resolved.mp3"

done.set()
Expand All @@ -1038,6 +1196,7 @@ async def send_voice_assistant_announcement_await_response(
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD


async def test_satellite_unloaded_on_disconnect(
Expand Down

0 comments on commit 3eed5de

Please sign in to comment.