Skip to content

Commit

Permalink
Use correct service name with Wyoming satellite + local wake word det…
Browse files Browse the repository at this point in the history
…ection (#111870)

* Use correct service name with satellite + local wake word detection

* Don't load platforms for satellite services

* Update homeassistant/components/wyoming/data.py

Co-authored-by: Paulus Schoutsen <[email protected]>

* Fix ruff error

---------

Co-authored-by: Paulus Schoutsen <[email protected]>
  • Loading branch information
synesthesiam and balloob authored Feb 29, 2024
1 parent 66b17a8 commit f0deae3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
33 changes: 18 additions & 15 deletions homeassistant/components/wyoming/data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Base class for Wyoming providers."""

from __future__ import annotations

import asyncio

from wyoming.client import AsyncTcpClient
from wyoming.info import Describe, Info, Satellite
from wyoming.info import Describe, Info

from homeassistant.const import Platform

Expand All @@ -23,14 +24,19 @@ def __init__(self, host: str, port: int, info: Info) -> None:
self.host = host
self.port = port
self.info = info
platforms = []
self.platforms = []

if (self.info.satellite is not None) and self.info.satellite.installed:
# Don't load platforms for satellite services, such as local wake
# word detection.
return

if any(asr.installed for asr in info.asr):
platforms.append(Platform.STT)
self.platforms.append(Platform.STT)
if any(tts.installed for tts in info.tts):
platforms.append(Platform.TTS)
self.platforms.append(Platform.TTS)
if any(wake.installed for wake in info.wake):
platforms.append(Platform.WAKE_WORD)
self.platforms = platforms
self.platforms.append(Platform.WAKE_WORD)

def has_services(self) -> bool:
"""Return True if services are installed that Home Assistant can use."""
Expand All @@ -43,6 +49,12 @@ def has_services(self) -> bool:

def get_name(self) -> str | None:
"""Return name of first installed usable service."""

# Wyoming satellite
# Must be checked first because satellites may contain wake services, etc.
if (self.info.satellite is not None) and self.info.satellite.installed:
return self.info.satellite.name

# ASR = automated speech recognition (speech-to-text)
asr_installed = [asr for asr in self.info.asr if asr.installed]
if asr_installed:
Expand All @@ -58,15 +70,6 @@ def get_name(self) -> str | None:
if wake_installed:
return wake_installed[0].name

# satellite
satellite_installed: Satellite | None = None

if (self.info.satellite is not None) and self.info.satellite.installed:
satellite_installed = self.info.satellite

if satellite_installed:
return satellite_installed.name

return None

@classmethod
Expand Down
31 changes: 27 additions & 4 deletions tests/components/wyoming/test_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Test tts."""

from __future__ import annotations

from unittest.mock import patch

from syrupy.assertion import SnapshotAssertion
from wyoming.info import Info

from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
from homeassistant.core import HomeAssistant
Expand All @@ -27,10 +29,13 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None:
"""Test loading info and error raising."""
mock_client = MockAsyncTcpClient([STT_INFO.event()])

with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
with (
patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
mock_client,
),
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
):
info = await load_wyoming_info(
"localhost",
1234,
Expand Down Expand Up @@ -75,3 +80,21 @@ async def test_service_name(hass: HomeAssistant) -> None:
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == SATELLITE_INFO.satellite.name


async def test_satellite_with_wake_word(hass: HomeAssistant) -> None:
"""Test that wake word info with satellite doesn't overwrite the service name."""
# Info for local wake word detection
satellite_info = Info(
satellite=SATELLITE_INFO.satellite,
wake=WAKE_WORD_INFO.wake,
)

with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
MockAsyncTcpClient([satellite_info.event()]),
):
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == satellite_info.satellite.name
assert not service.platforms

0 comments on commit f0deae3

Please sign in to comment.