Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the full conversation input available to sentence triggers #131982

Merged
merged 2 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions homeassistant/components/conversation/default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[
[str, RecognizeResult, str | None], Awaitable[str | None]
[ConversationInput, RecognizeResult], Awaitable[str | None]
]
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file"
Expand Down Expand Up @@ -1286,9 +1286,7 @@ async def _handle_trigger_result(

# Gather callback responses in parallel
trigger_callbacks = [
self._trigger_sentences[trigger_id].callback(
user_input.text, trigger_result, user_input.device_id
)
self._trigger_sentences[trigger_id].callback(user_input, trigger_result)
for trigger_id, trigger_result in result.matched_triggers.items()
]

Expand Down
11 changes: 11 additions & 0 deletions homeassistant/components/conversation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ class ConversationInput:
agent_id: str | None = None
"""Agent to use for processing."""

def as_dict(self) -> dict[str, Any]:
"""Return input as a dict."""
return {
"text": self.text,
"context": self.context.as_dict(),
"conversation_id": self.conversation_id,
"device_id": self.device_id,
"language": self.language,
"agent_id": self.agent_id,
}


@dataclass(slots=True)
class ConversationResult:
Expand Down
8 changes: 5 additions & 3 deletions homeassistant/components/conversation/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from homeassistant.helpers.typing import UNDEFINED, ConfigType

from .const import DATA_DEFAULT_ENTITY, DOMAIN
from .models import ConversationInput


def has_no_punctuation(value: list[str]) -> list[str]:
Expand Down Expand Up @@ -62,7 +63,7 @@ async def async_attach_trigger(
job = HassJob(action)

async def call_action(
sentence: str, result: RecognizeResult, device_id: str | None
user_input: ConversationInput, result: RecognizeResult
) -> str | None:
"""Call action with right context."""

Expand All @@ -83,12 +84,13 @@ async def call_action(
trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data,
"platform": DOMAIN,
"sentence": sentence,
"sentence": user_input.text,
"details": details,
"slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items()
},
"device_id": device_id,
"device_id": user_input.device_id,
"user_input": user_input.as_dict(),
}

# Wait for the automation to complete
Expand Down
2 changes: 1 addition & 1 deletion tests/components/conversation/test_default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context())
assert callback.call_count == 1
assert callback.call_args[0][0] == sentence
assert callback.call_args[0][0].text == sentence
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence
Expand Down
144 changes: 127 additions & 17 deletions tests/components/conversation/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,31 @@ async def test_if_fires_on_event(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)

context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "Ha ha ha"},
blocking=True,
return_response=True,
context=context,
)
assert service_response["response"]["speech"]["plain"]["speech"] == "Done"

Expand All @@ -61,13 +74,21 @@ async def test_if_fires_on_event(
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "Ha ha ha",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "Ha ha ha",
},
}


Expand Down Expand Up @@ -152,7 +173,19 @@ async def test_response_same_sentence(
{"delay": "0:0:0.100"},
{
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
{"set_conversation_response": "response 2"},
],
Expand All @@ -168,13 +201,14 @@ async def test_response_same_sentence(
]
},
)

context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "test sentence"},
blocking=True,
return_response=True,
context=context,
)
await hass.async_block_till_done()

Expand All @@ -188,12 +222,20 @@ async def test_response_same_sentence(
assert service_calls[1].data["data"] == {
"alias": None,
"id": "trigger1",
"idx": "0",
"idx": 0,
"platform": "conversation",
"sentence": "test sentence",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "test sentence",
},
}


Expand Down Expand Up @@ -231,13 +273,14 @@ async def test_response_same_sentence_with_error(
]
},
)

context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "test sentence"},
blocking=True,
return_response=True,
context=context,
)
await hass.async_block_till_done()

Expand Down Expand Up @@ -320,19 +363,32 @@ async def test_same_trigger_multiple_sentences(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)

context = Context()
await hass.services.async_call(
"conversation",
"process",
{
"text": "hello",
},
blocking=True,
context=context,
)

# Only triggers once
Expand All @@ -342,13 +398,21 @@ async def test_same_trigger_multiple_sentences(
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "hello",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "hello",
},
}


Expand All @@ -371,7 +435,19 @@ async def test_same_sentence_multiple_triggers(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
},
{
Expand All @@ -384,7 +460,19 @@ async def test_same_sentence_multiple_triggers(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
},
],
Expand Down Expand Up @@ -488,19 +576,33 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)

context = Context()
await hass.services.async_call(
"conversation",
"process",
{
"text": "play the white album by the beatles",
},
blocking=True,
context=context,
)

await hass.async_block_till_done()
Expand All @@ -509,8 +611,8 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "play the white album by the beatles",
"slots": {
Expand All @@ -530,6 +632,14 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
},
},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "play the white album by the beatles",
},
}


Expand Down