diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index c6aa14bff151d5..ce6f3e8d024a2c 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -5,7 +5,7 @@ import array import asyncio from collections import defaultdict, deque -from collections.abc import AsyncGenerator, AsyncIterable, Callable, Iterable +from collections.abc import AsyncGenerator, AsyncIterable, Callable from dataclasses import asdict, dataclass, field from enum import StrEnum import logging @@ -118,8 +118,10 @@ def validate_language(data: dict[str, Any]) -> Any: @callback def _async_resolve_default_pipeline_settings( hass: HomeAssistant, - stt_engine_id: str | None, - tts_engine_id: str | None, + *, + conversation_engine_id: str | None = None, + stt_engine_id: str | None = None, + tts_engine_id: str | None = None, pipeline_name: str, ) -> dict[str, str | None]: """Resolve settings for a default pipeline. @@ -137,12 +139,13 @@ def _async_resolve_default_pipeline_settings( wake_word_entity = None wake_word_id = None + if conversation_engine_id is None: + conversation_engine_id = conversation.HOME_ASSISTANT_AGENT + # Find a matching language supported by the Home Assistant conversation agent conversation_languages = language_util.matches( hass.config.language, - conversation.async_get_conversation_languages( - hass, conversation.HOME_ASSISTANT_AGENT - ), + conversation.async_get_conversation_languages(hass, conversation_engine_id), country=hass.config.country, ) if conversation_languages: @@ -201,7 +204,7 @@ def _async_resolve_default_pipeline_settings( tts_engine_id = None return { - "conversation_engine": conversation.HOME_ASSISTANT_AGENT, + "conversation_engine": conversation_engine_id, "conversation_language": conversation_language, "language": hass.config.language, "name": pipeline_name, @@ -224,7 +227,7 @@ async def _async_create_default_pipeline( default stt / tts engines. """ pipeline_settings = _async_resolve_default_pipeline_settings( - hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant" + hass, pipeline_name="Home Assistant" ) return await pipeline_store.async_create_item(pipeline_settings) @@ -243,7 +246,10 @@ async def async_create_default_pipeline( pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_store = pipeline_data.pipeline_store pipeline_settings = _async_resolve_default_pipeline_settings( - hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name + hass, + stt_engine_id=stt_engine_id, + tts_engine_id=tts_engine_id, + pipeline_name=pipeline_name, ) if ( pipeline_settings["stt_engine"] != stt_engine_id @@ -274,11 +280,11 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P @callback -def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]: +def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]: """Get all pipelines.""" pipeline_data: PipelineData = hass.data[DOMAIN] - return pipeline_data.pipeline_store.data.values() + return list(pipeline_data.pipeline_store.data.values()) async def async_update_pipeline( @@ -1675,7 +1681,7 @@ def ws_list_item( connection.send_result( msg["id"], { - "pipelines": self.storage_collection.async_items(), + "pipelines": async_get_pipelines(hass), "preferred_pipeline": self.storage_collection.async_get_preferred_item(), }, )