Skip to content

Commit

Permalink
Make the full conversation input available to sentence triggers (#131982
Browse files Browse the repository at this point in the history
)

Co-authored-by: Michael Hansen <[email protected]>
  • Loading branch information
balloob and synesthesiam authored Dec 1, 2024
1 parent ffeefd4 commit 6103cea
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 25 deletions.
6 changes: 2 additions & 4 deletions homeassistant/components/conversation/default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[
[str, RecognizeResult, str | None], Awaitable[str | None]
[ConversationInput, RecognizeResult], Awaitable[str | None]
]
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file"
Expand Down Expand Up @@ -1286,9 +1286,7 @@ async def _handle_trigger_result(

# Gather callback responses in parallel
trigger_callbacks = [
self._trigger_sentences[trigger_id].callback(
user_input.text, trigger_result, user_input.device_id
)
self._trigger_sentences[trigger_id].callback(user_input, trigger_result)
for trigger_id, trigger_result in result.matched_triggers.items()
]

Expand Down
11 changes: 11 additions & 0 deletions homeassistant/components/conversation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ class ConversationInput:
agent_id: str | None = None
"""Agent to use for processing."""

def as_dict(self) -> dict[str, Any]:
"""Return input as a dict."""
return {
"text": self.text,
"context": self.context.as_dict(),
"conversation_id": self.conversation_id,
"device_id": self.device_id,
"language": self.language,
"agent_id": self.agent_id,
}


@dataclass(slots=True)
class ConversationResult:
Expand Down
8 changes: 5 additions & 3 deletions homeassistant/components/conversation/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from homeassistant.helpers.typing import UNDEFINED, ConfigType

from .const import DATA_DEFAULT_ENTITY, DOMAIN
from .models import ConversationInput


def has_no_punctuation(value: list[str]) -> list[str]:
Expand Down Expand Up @@ -62,7 +63,7 @@ async def async_attach_trigger(
job = HassJob(action)

async def call_action(
sentence: str, result: RecognizeResult, device_id: str | None
user_input: ConversationInput, result: RecognizeResult
) -> str | None:
"""Call action with right context."""

Expand All @@ -83,12 +84,13 @@ async def call_action(
trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data,
"platform": DOMAIN,
"sentence": sentence,
"sentence": user_input.text,
"details": details,
"slots": { # direct access to values
entity_name: entity["value"] for entity_name, entity in details.items()
},
"device_id": device_id,
"device_id": user_input.device_id,
"user_input": user_input.as_dict(),
}

# Wait for the automation to complete
Expand Down
2 changes: 1 addition & 1 deletion tests/components/conversation/test_default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context())
assert callback.call_count == 1
assert callback.call_args[0][0] == sentence
assert callback.call_args[0][0].text == sentence
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence
Expand Down
144 changes: 127 additions & 17 deletions tests/components/conversation/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,31 @@ async def test_if_fires_on_event(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)

context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "Ha ha ha"},
blocking=True,
return_response=True,
context=context,
)
assert service_response["response"]["speech"]["plain"]["speech"] == "Done"

Expand All @@ -61,13 +74,21 @@ async def test_if_fires_on_event(
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "Ha ha ha",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "Ha ha ha",
},
}


Expand Down Expand Up @@ -152,7 +173,19 @@ async def test_response_same_sentence(
{"delay": "0:0:0.100"},
{
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
{"set_conversation_response": "response 2"},
],
Expand All @@ -168,13 +201,14 @@ async def test_response_same_sentence(
]
},
)

context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "test sentence"},
blocking=True,
return_response=True,
context=context,
)
await hass.async_block_till_done()

Expand All @@ -188,12 +222,20 @@ async def test_response_same_sentence(
assert service_calls[1].data["data"] == {
"alias": None,
"id": "trigger1",
"idx": "0",
"idx": 0,
"platform": "conversation",
"sentence": "test sentence",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "test sentence",
},
}


Expand Down Expand Up @@ -231,13 +273,14 @@ async def test_response_same_sentence_with_error(
]
},
)

context = Context()
service_response = await hass.services.async_call(
"conversation",
"process",
{"text": "test sentence"},
blocking=True,
return_response=True,
context=context,
)
await hass.async_block_till_done()

Expand Down Expand Up @@ -320,19 +363,32 @@ async def test_same_trigger_multiple_sentences(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)

context = Context()
await hass.services.async_call(
"conversation",
"process",
{
"text": "hello",
},
blocking=True,
context=context,
)

# Only triggers once
Expand All @@ -342,13 +398,21 @@ async def test_same_trigger_multiple_sentences(
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "hello",
"slots": {},
"details": {},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "hello",
},
}


Expand All @@ -371,7 +435,19 @@ async def test_same_sentence_multiple_triggers(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
},
{
Expand All @@ -384,7 +460,19 @@ async def test_same_sentence_multiple_triggers(
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
},
],
Expand Down Expand Up @@ -488,19 +576,33 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
},
"action": {
"service": "test.automation",
"data_template": {"data": "{{ trigger }}"},
"data_template": {
"data": {
"alias": "{{ trigger.alias }}",
"id": "{{ trigger.id }}",
"idx": "{{ trigger.idx }}",
"platform": "{{ trigger.platform }}",
"sentence": "{{ trigger.sentence }}",
"slots": "{{ trigger.slots }}",
"details": "{{ trigger.details }}",
"device_id": "{{ trigger.device_id }}",
"user_input": "{{ trigger.user_input }}",
}
},
},
}
},
)

context = Context()
await hass.services.async_call(
"conversation",
"process",
{
"text": "play the white album by the beatles",
},
blocking=True,
context=context,
)

await hass.async_block_till_done()
Expand All @@ -509,8 +611,8 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
assert service_calls[1].service == "automation"
assert service_calls[1].data["data"] == {
"alias": None,
"id": "0",
"idx": "0",
"id": 0,
"idx": 0,
"platform": "conversation",
"sentence": "play the white album by the beatles",
"slots": {
Expand All @@ -530,6 +632,14 @@ async def test_wildcards(hass: HomeAssistant, service_calls: list[ServiceCall])
},
},
"device_id": None,
"user_input": {
"agent_id": None,
"context": context.as_dict(),
"conversation_id": None,
"device_id": None,
"language": "en",
"text": "play the white album by the beatles",
},
}


Expand Down

0 comments on commit 6103cea

Please sign in to comment.