Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assist Pipeline minor cleanup #121187

Merged
merged 2 commits into from
Jul 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import array
import asyncio
from collections import defaultdict, deque
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable
from collections.abc import AsyncGenerator, AsyncIterable, Callable
from dataclasses import asdict, dataclass, field
from enum import StrEnum
import logging
Expand Down Expand Up @@ -118,8 +118,10 @@ def validate_language(data: dict[str, Any]) -> Any:
@callback
def _async_resolve_default_pipeline_settings(
hass: HomeAssistant,
stt_engine_id: str | None,
tts_engine_id: str | None,
*,
conversation_engine_id: str | None = None,
stt_engine_id: str | None = None,
tts_engine_id: str | None = None,
pipeline_name: str,
) -> dict[str, str | None]:
"""Resolve settings for a default pipeline.
Expand All @@ -137,12 +139,13 @@ def _async_resolve_default_pipeline_settings(
wake_word_entity = None
wake_word_id = None

if conversation_engine_id is None:
conversation_engine_id = conversation.HOME_ASSISTANT_AGENT

# Find a matching language supported by the Home Assistant conversation agent
conversation_languages = language_util.matches(
hass.config.language,
conversation.async_get_conversation_languages(
hass, conversation.HOME_ASSISTANT_AGENT
),
conversation.async_get_conversation_languages(hass, conversation_engine_id),
country=hass.config.country,
)
if conversation_languages:
Expand Down Expand Up @@ -201,7 +204,7 @@ def _async_resolve_default_pipeline_settings(
tts_engine_id = None

return {
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_engine": conversation_engine_id,
"conversation_language": conversation_language,
"language": hass.config.language,
"name": pipeline_name,
Expand All @@ -224,7 +227,7 @@ async def _async_create_default_pipeline(
default stt / tts engines.
"""
pipeline_settings = _async_resolve_default_pipeline_settings(
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant"
hass, pipeline_name="Home Assistant"
)
return await pipeline_store.async_create_item(pipeline_settings)

Expand All @@ -243,7 +246,10 @@ async def async_create_default_pipeline(
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
pipeline_settings = _async_resolve_default_pipeline_settings(
hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name
hass,
stt_engine_id=stt_engine_id,
tts_engine_id=tts_engine_id,
pipeline_name=pipeline_name,
)
if (
pipeline_settings["stt_engine"] != stt_engine_id
Expand Down Expand Up @@ -274,11 +280,11 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P


@callback
def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]:
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
"""Get all pipelines."""
pipeline_data: PipelineData = hass.data[DOMAIN]

return pipeline_data.pipeline_store.data.values()
return list(pipeline_data.pipeline_store.data.values())


async def async_update_pipeline(
Expand Down Expand Up @@ -1675,7 +1681,7 @@ def ws_list_item(
connection.send_result(
msg["id"],
{
"pipelines": self.storage_collection.async_items(),
"pipelines": async_get_pipelines(hass),
"preferred_pipeline": self.storage_collection.async_get_preferred_item(),
},
)
Expand Down