Skip to content

Commit

Permalink
Google Assistant SDK: Always enable conversation agent and support mu…
Browse files Browse the repository at this point in the history
…ltiple languages (#93201)

* Enable agent and support multiple languages

* fix test
  • Loading branch information
tronikos authored Jun 30, 2023
1 parent 1dcaec4 commit 17ceacd
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 122 deletions.
39 changes: 9 additions & 30 deletions homeassistant/components/google_assistant_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,11 @@
)
from homeassistant.helpers.typing import ConfigType

from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
)
from .const import DATA_MEM_STORAGE, DATA_SESSION, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .helpers import (
GoogleAssistantSDKAudioView,
InMemoryStorage,
async_send_text_commands,
default_language_code,
)

SERVICE_SEND_TEXT_COMMAND = "send_text_command"
Expand Down Expand Up @@ -82,8 +75,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

await async_setup_service(hass)

entry.async_on_unload(entry.add_update_listener(update_listener))
await update_listener(hass, entry)
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, entry, agent)

return True

Expand All @@ -100,8 +93,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
for service_name in hass.services.async_services()[DOMAIN]:
hass.services.async_remove(DOMAIN, service_name)

if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
conversation.async_unset_agent(hass, entry)
conversation.async_unset_agent(hass, entry)

return True

Expand All @@ -125,15 +117,6 @@ async def send_text_command(call: ServiceCall) -> None:
)


async def update_listener(hass, entry):
"""Handle options update."""
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, entry, agent)
else:
conversation.async_unset_agent(hass, entry)


class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
"""Google Assistant SDK conversation agent."""

Expand All @@ -143,6 +126,7 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
self.entry = entry
self.assistant: TextAssistant | None = None
self.session: OAuth2Session | None = None
self.language: str | None = None

@property
def attribution(self):
Expand All @@ -155,10 +139,7 @@ def attribution(self):
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
return [language_code]
return SUPPORTED_LANGUAGE_CODES

async def async_process(
self, user_input: conversation.ConversationInput
Expand All @@ -172,12 +153,10 @@ async def async_process(
if not session.valid_token:
await session.async_ensure_token_valid()
self.assistant = None
if not self.assistant:
if not self.assistant or user_input.language != self.language:
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
self.assistant = TextAssistant(credentials, language_code)
self.language = user_input.language
self.assistant = TextAssistant(credentials, self.language)

resp = self.assistant.assist(user_input.text)
text_response = resp[0] or "<empty response>"
Expand Down
14 changes: 1 addition & 13 deletions homeassistant/components/google_assistant_sdk/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow

from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DEFAULT_NAME,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .helpers import default_language_code

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -114,12 +108,6 @@ async def async_step_init(
CONF_LANGUAGE_CODE,
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
): vol.In(SUPPORTED_LANGUAGE_CODES),
vol.Required(
CONF_ENABLE_CONVERSATION_AGENT,
default=self.config_entry.options.get(
CONF_ENABLE_CONVERSATION_AGENT
),
): bool,
}
),
)
1 change: 0 additions & 1 deletion homeassistant/components/google_assistant_sdk/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

DEFAULT_NAME: Final = "Google Assistant SDK"

CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
CONF_LANGUAGE_CODE: Final = "language_code"

DATA_MEM_STORAGE: Final = "mem_storage"
Expand Down
4 changes: 1 addition & 3 deletions homeassistant/components/google_assistant_sdk/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@
"step": {
"init": {
"data": {
"enable_conversation_agent": "Enable the conversation agent",
"language_code": "Language code"
},
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
}
}
}
},
Expand Down
44 changes: 9 additions & 35 deletions tests/components/google_assistant_sdk/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,65 +223,39 @@ async def test_options_flow(
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
assert set(data_schema) == {"language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
user_input={"language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
assert config_entry.options == {"language_code": "es-ES"}

# Retrigger options flow, not change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
assert set(data_schema) == {"language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
user_input={"language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
assert config_entry.options == {"language_code": "es-ES"}

# Retrigger options flow, change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
assert set(data_schema) == {"language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": False, "language_code": "en-US"},
user_input={"language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "en-US",
}

# Retrigger options flow, enable conversation agent
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": True, "language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": True,
"language_code": "en-US",
}
assert config_entry.options == {"language_code": "en-US"}
87 changes: 47 additions & 40 deletions tests/components/google_assistant_sdk/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from homeassistant.components import conversation
from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.components.google_assistant_sdk.const import SUPPORTED_LANGUAGE_CODES
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow

Expand All @@ -29,13 +30,9 @@ async def fetch_api_url(hass_client, url):
return response.status, contents


@pytest.mark.parametrize(
"enable_conversation_agent", [False, True], ids=["", "enable_conversation_agent"]
)
async def test_setup_success(
hass: HomeAssistant,
setup_integration: ComponentSetup,
enable_conversation_agent: bool,
) -> None:
"""Test successful setup and unload."""
await setup_integration()
Expand All @@ -44,12 +41,6 @@ async def test_setup_success(
assert len(entries) == 1
assert entries[0].state is ConfigEntryState.LOADED

if enable_conversation_agent:
hass.config_entries.async_update_entry(
entries[0], options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

await hass.config_entries.async_unload(entries[0].entry_id)
await hass.async_block_till_done()

Expand Down Expand Up @@ -333,30 +324,21 @@ async def test_conversation_agent(
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
assert agent.supported_languages == ["en-US"]
assert agent.attribution.keys() == {"name", "url"}
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text1, None, Context(), "en-US", config_entry.entry_id
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text2, None, Context(), "en-US", config_entry.entry_id
)

# Assert constructor is called only once since it's reused across requests
Expand All @@ -381,21 +363,14 @@ async def test_conversation_agent_refresh_token(
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text1, None, Context(), "en-US", config_entry.entry_id
)

# Expire the token between requests
Expand All @@ -411,11 +386,8 @@ async def test_conversation_agent_refresh_token(
},
)

await hass.services.async_call(
"conversation",
"process",
{"text": text2, "agent_id": config_entry.entry_id},
blocking=True,
await conversation.async_converse(
hass, text2, None, Context(), "en-US", config_entry.entry_id
)

# Assert constructor is called twice since the token was expired
Expand All @@ -426,3 +398,38 @@ async def test_conversation_agent_refresh_token(
)
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])


async def test_conversation_agent_language_changed(
hass: HomeAssistant,
config_entry: MockConfigEntry,
setup_integration: ComponentSetup,
) -> None:
"""Test GoogleAssistantConversationAgent when language is changed."""
await setup_integration()

assert await async_setup_component(hass, "conversation", {})

entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED

text1 = "tell me a joke"
text2 = "cuéntame un chiste"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await conversation.async_converse(
hass, text1, None, Context(), "en-US", config_entry.entry_id
)
await conversation.async_converse(
hass, text2, None, Context(), "es-ES", config_entry.entry_id
)

# Assert constructor is called twice since the language was changed
assert mock_text_assistant.call_count == 2
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "es-ES")])
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])

0 comments on commit 17ceacd

Please sign in to comment.