Skip to content

Commit

Permalink
Assist Pipeline minor cleanup (#121187)
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob authored Jul 5, 2024
1 parent 2b9bddc commit 22718ca
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import array
import asyncio
from collections import defaultdict, deque
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
from collections.abc import AsyncGenerator, AsyncIterable, Callable
from dataclasses import asdict, dataclass, field
from enum import StrEnum
import logging
Expand Down Expand Up @@ -118,8 +118,10 @@ def validate_language(data: dict[str, Any]) -> Any:
@callback
def _async_resolve_default_pipeline_settings(
hass: HomeAssistant,
stt_engine_id: str | None,
tts_engine_id: str | None,
*,
conversation_engine_id: str | None = None,
stt_engine_id: str | None = None,
tts_engine_id: str | None = None,
pipeline_name: str,
) -> dict[str, str | None]:
"""Resolve settings for a default pipeline.
Expand All @@ -137,12 +139,13 @@ def _async_resolve_default_pipeline_settings(
wake_word_entity = None
wake_word_id = None

if conversation_engine_id is None:
conversation_engine_id = conversation.HOME_ASSISTANT_AGENT

# Find a matching language supported by the Home Assistant conversation agent
conversation_languages = language_util.matches(
hass.config.language,
conversation.async_get_conversation_languages(
hass, conversation.HOME_ASSISTANT_AGENT
),
conversation.async_get_conversation_languages(hass, conversation_engine_id),
country=hass.config.country,
)
if conversation_languages:
Expand Down Expand Up @@ -201,7 +204,7 @@ def _async_resolve_default_pipeline_settings(
tts_engine_id = None

return {
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_engine": conversation_engine_id,
"conversation_language": conversation_language,
"language": hass.config.language,
"name": pipeline_name,
Expand All @@ -224,7 +227,7 @@ async def _async_create_default_pipeline(
default stt / tts engines.
"""
pipeline_settings = _async_resolve_default_pipeline_settings(
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant"
hass, pipeline_name="Home Assistant"
)
return await pipeline_store.async_create_item(pipeline_settings)

Expand All @@ -243,7 +246,10 @@ async def async_create_default_pipeline(
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
pipeline_settings = _async_resolve_default_pipeline_settings(
hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name
hass,
stt_engine_id=stt_engine_id,
tts_engine_id=tts_engine_id,
pipeline_name=pipeline_name,
)
if (
pipeline_settings["stt_engine"] != stt_engine_id
Expand Down Expand Up @@ -274,11 +280,11 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P


@callback
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
"""Get all pipelines."""
pipeline_data: PipelineData = hass.data[DOMAIN]

return pipeline_data.pipeline_store.data.values()
return list(pipeline_data.pipeline_store.data.values())


async def async_update_pipeline(
Expand Down Expand Up @@ -1675,7 +1681,7 @@ def ws_list_item(
connection.send_result(
msg["id"],
{
"pipelines": self.storage_collection.async_items(),
"pipelines": async_get_pipelines(hass),
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
},
)
Expand Down

0 comments on commit 22718ca

Please sign in to comment.