Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Google Assistant SDK: Always enable conversation agent and support multiple languages #93201

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 9 additions & 30 deletions homeassistant/components/google_assistant_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,11 @@
)
from homeassistant.helpers.typing import ConfigType

from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
)
from .const import DATA_MEM_STORAGE, DATA_SESSION, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .helpers import (
GoogleAssistantSDKAudioView,
InMemoryStorage,
async_send_text_commands,
default_language_code,
)

SERVICE_SEND_TEXT_COMMAND = "send_text_command"
Expand Down Expand Up @@ -82,8 +75,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

await async_setup_service(hass)

entry.async_on_unload(entry.add_update_listener(update_listener))
await update_listener(hass, entry)
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, entry, agent)

return True

Expand All @@ -100,8 +93,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
for service_name in hass.services.async_services()[DOMAIN]:
hass.services.async_remove(DOMAIN, service_name)

if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
conversation.async_unset_agent(hass, entry)
conversation.async_unset_agent(hass, entry)

return True

Expand All @@ -125,15 +117,6 @@ async def send_text_command(call: ServiceCall) -> None:
)


async def update_listener(hass, entry):
"""Handle options update."""
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, entry, agent)
else:
conversation.async_unset_agent(hass, entry)


class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
"""Google Assistant SDK conversation agent."""

Expand All @@ -143,6 +126,7 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
self.entry = entry
self.assistant: TextAssistant | None = None
self.session: OAuth2Session | None = None
self.language: str | None = None

@property
def attribution(self):
Expand All @@ -155,10 +139,7 @@ def attribution(self):
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
return [language_code]
return SUPPORTED_LANGUAGE_CODES

async def async_process(
self, user_input: conversation.ConversationInput
Expand All @@ -172,12 +153,10 @@ async def async_process(
if not session.valid_token:
await session.async_ensure_token_valid()
self.assistant = None
if not self.assistant:
if not self.assistant or user_input.language != self.language:
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
self.assistant = TextAssistant(credentials, language_code)
self.language = user_input.language
self.assistant = TextAssistant(credentials, self.language)

resp = self.assistant.assist(user_input.text)
text_response = resp[0] or "<empty response>"
Expand Down
14 changes: 1 addition & 13 deletions homeassistant/components/google_assistant_sdk/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow

from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DEFAULT_NAME,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .helpers import default_language_code

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -114,12 +108,6 @@ async def async_step_init(
CONF_LANGUAGE_CODE,
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
): vol.In(SUPPORTED_LANGUAGE_CODES),
vol.Required(
CONF_ENABLE_CONVERSATION_AGENT,
default=self.config_entry.options.get(
CONF_ENABLE_CONVERSATION_AGENT
),
): bool,
}
),
)
1 change: 0 additions & 1 deletion homeassistant/components/google_assistant_sdk/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

DEFAULT_NAME: Final = "Google Assistant SDK"

CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
CONF_LANGUAGE_CODE: Final = "language_code"

DATA_MEM_STORAGE: Final = "mem_storage"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@
"step": {
"init": {
"data": {
"enable_conversation_agent": "Enable the conversation agent",
"language_code": "Language code"
},
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
}
}
}
},
Expand Down
44 changes: 9 additions & 35 deletions tests/components/google_assistant_sdk/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,65 +223,39 @@ async def test_options_flow(
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
assert set(data_schema) == {"language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
user_input={"language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
assert config_entry.options == {"language_code": "es-ES"}

# Retrigger options flow, not change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
assert set(data_schema) == {"language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
user_input={"language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
assert config_entry.options == {"language_code": "es-ES"}

# Retrigger options flow, change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
assert set(data_schema) == {"language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": False, "language_code": "en-US"},
user_input={"language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "en-US",
}

# Retrigger options flow, enable conversation agent
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": True, "language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": True,
"language_code": "en-US",
}
assert config_entry.options == {"language_code": "en-US"}
87 changes: 47 additions & 40 deletions tests/components/google_assistant_sdk/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from homeassistant.components import conversation
from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.components.google_assistant_sdk.const import SUPPORTED_LANGUAGE_CODES
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow

Expand All @@ -29,13 +30,9 @@ async def fetch_api_url(hass_client, url):
return response.status, contents


@pytest.mark.parametrize(
"enable_conversation_agent", [False, True], ids=["", "enable_conversation_agent"]
)
async def test_setup_success(
hass: HomeAssistant,
setup_integration: ComponentSetup,
enable_conversation_agent: bool,
) -> None:
"""Test successful setup and unload."""
await setup_integration()
Expand All @@ -44,12 +41,6 @@ async def test_setup_success(
assert len(entries) == 1
assert entries[0].state is ConfigEntryState.LOADED

if enable_conversation_agent:
hass.config_entries.async_update_entry(
entries[0], options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

await hass.config_entries.async_unload(entries[0].entry_id)
await hass.async_block_till_done()

Expand Down Expand Up @@ -333,30 +324,21 @@ async def test_conversation_agent(
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
assert agent.supported_languages == ["en-US"]
assert agent.attribution.keys() == {"name", "url"}
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text1, None, Context(), "en-US", config_entry.entry_id
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text2, None, Context(), "en-US", config_entry.entry_id
)

# Assert constructor is called only once since it's reused across requests
Expand All @@ -381,21 +363,14 @@ async def test_conversation_agent_refresh_token(
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text1, None, Context(), "en-US", config_entry.entry_id
)

# Expire the token between requests
Expand All @@ -411,11 +386,8 @@ async def test_conversation_agent_refresh_token(
},
)

await hass.services.async_call(
"conversation",
"process",
{"text": text2, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text2, None, Context(), "en-US", config_entry.entry_id
)

# Assert constructor is called twice since the token was expired
Expand All @@ -426,3 +398,38 @@ async def test_conversation_agent_refresh_token(
)
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])


async def test_conversation_agent_language_changed(
hass: HomeAssistant,
config_entry: MockConfigEntry,
setup_integration: ComponentSetup,
) -> None:
"""Test GoogleAssistantConversationAgent when language is changed."""
await setup_integration()

assert await async_setup_component(hass, "conversation", {})

entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED

text1 = "tell me a joke"
text2 = "cuéntame un chiste"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await conversation.async_converse(
hass, text1, None, Context(), "en-US", config_entry.entry_id
)
await conversation.async_converse(
hass, text2, None, Context(), "es-ES", config_entry.entry_id
)

# Assert constructor is called twice since the language was changed
assert mock_text_assistant.call_count == 2
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "es-ES")])
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])