From fe019bfb3675880de152c4cde630d873269306ca Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 13 Dec 2024 10:49:07 -0600 Subject: [PATCH] Move hassil fork internal --- hassil/VERSION | 1 + hassil/__init__.py | 13 + hassil/__main__.py | 89 ++++ hassil/_resources.py | 20 + hassil/errors.py | 13 + hassil/expression.py | 194 +++++++++ hassil/fst.py | 508 ++++++++++++++++++++++ hassil/intents.py | 463 ++++++++++++++++++++ hassil/models.py | 62 +++ hassil/parse_expression.py | 418 ++++++++++++++++++ hassil/parser.py | 315 ++++++++++++++ hassil/py.typed | 0 hassil/recognize.py | 650 ++++++++++++++++++++++++++++ hassil/sample.py | 314 ++++++++++++++ hassil/sample_template.py | 31 ++ hassil/string_matcher.py | 838 +++++++++++++++++++++++++++++++++++++ hassil/trie.py | 87 ++++ hassil/util.py | 216 ++++++++++ requirements.txt | 1 - 19 files changed, 4232 insertions(+), 1 deletion(-) create mode 100644 hassil/VERSION create mode 100644 hassil/__init__.py create mode 100644 hassil/__main__.py create mode 100644 hassil/_resources.py create mode 100644 hassil/errors.py create mode 100644 hassil/expression.py create mode 100644 hassil/fst.py create mode 100644 hassil/intents.py create mode 100644 hassil/models.py create mode 100644 hassil/parse_expression.py create mode 100644 hassil/parser.py create mode 100644 hassil/py.typed create mode 100644 hassil/recognize.py create mode 100644 hassil/sample.py create mode 100644 hassil/sample_template.py create mode 100644 hassil/string_matcher.py create mode 100644 hassil/trie.py create mode 100644 hassil/util.py diff --git a/hassil/VERSION b/hassil/VERSION new file mode 100644 index 0000000..e010258 --- /dev/null +++ b/hassil/VERSION @@ -0,0 +1 @@ +2.0.5 diff --git a/hassil/__init__.py b/hassil/__init__.py new file mode 100644 index 0000000..ec38dc6 --- /dev/null +++ b/hassil/__init__.py @@ -0,0 +1,13 @@ +"""Home Assistant Intent Language parser""" + +from .expression import ( + ListReference, + RuleReference, + Sentence, + Sequence, + SequenceType, + TextChunk, +) +from .intents import Intents +from .parse_expression import parse_sentence +from .recognize import is_match, recognize, recognize_all, recognize_best diff --git a/hassil/__main__.py b/hassil/__main__.py new file mode 100644 index 0000000..3d86073 --- /dev/null +++ b/hassil/__main__.py @@ -0,0 +1,89 @@ +"""Command-line interface to hassil.""" + +import argparse +import logging +import os +import sys +from pathlib import Path + +import yaml + +from .intents import Intents, TextSlotList +from .recognize import recognize +from .util import merge_dict + +_LOGGER = logging.getLogger("hassil") + + +def main(): + """Main entry point""" + parser = argparse.ArgumentParser() + parser.add_argument("yaml", nargs="+", help="YAML files or directories") + parser.add_argument( + "--areas", + nargs="+", + help="Area names", + default=[], + ) + parser.add_argument("--names", nargs="+", default=[], help="Device/entity names") + parser.add_argument( + "--debug", action="store_true", help="Print DEBUG messages to the console" + ) + args = parser.parse_args() + + level = logging.DEBUG if args.debug else logging.INFO + logging.basicConfig(level=level) + _LOGGER.debug(args) + + slot_lists = { + "area": TextSlotList.from_strings(args.areas), + "name": TextSlotList.from_strings(args.names), + } + + input_dict = {"intents": {}} + for yaml_path_str in args.yaml: + yaml_path = Path(yaml_path_str) + if yaml_path.is_dir(): + yaml_file_paths = yaml_path.glob("*.yaml") + else: + yaml_file_paths = [yaml_path] + + for 
yaml_file_path in yaml_file_paths: + _LOGGER.debug("Loading file: %s", yaml_file_path) + with open(yaml_file_path, "r", encoding="utf-8") as yaml_file: + merge_dict(input_dict, yaml.safe_load(yaml_file)) + + assert input_dict, "No intent YAML files loaded" + intents = Intents.from_dict(input_dict) + + _LOGGER.info("Area names: %s", args.areas) + _LOGGER.info("Device/Entity names: %s", args.names) + + if os.isatty(sys.stdout.fileno()): + print("Reading sentences from stdin...", file=sys.stderr) + + try: + for line in sys.stdin: + line = line.strip() + if not line: + continue + + try: + result = recognize(line, intents, slot_lists=slot_lists) + if result is not None: + print( + { + "intent": result.intent.name, + **{e.name: e.value for e in result.entities_list}, + } + ) + else: + print("") + except Exception: + _LOGGER.exception(line) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/hassil/_resources.py b/hassil/_resources.py new file mode 100644 index 0000000..82fc32f --- /dev/null +++ b/hassil/_resources.py @@ -0,0 +1,20 @@ +"""Shared access to package resources""" + +import os +import typing +from pathlib import Path + +try: + import importlib.resources + + files = importlib.resources.files # type: ignore +except (ImportError, AttributeError): + # Backport for Python < 3.9 + import importlib_resources # type: ignore + + files = importlib_resources.files + +_PACKAGE = "hassil" +_DIR = Path(typing.cast(os.PathLike, files(_PACKAGE))) + +__version__ = (_DIR / "VERSION").read_text(encoding="utf-8").strip() diff --git a/hassil/errors.py b/hassil/errors.py new file mode 100644 index 0000000..e7fefae --- /dev/null +++ b/hassil/errors.py @@ -0,0 +1,13 @@ +"""Errors for hassil.""" + + +class HassilError(Exception): + """Base class for hassil errors""" + + +class MissingListError(HassilError): + """Error when a {slot_list} is missing.""" + + +class MissingRuleError(HassilError): + """Error when an is missing.""" diff --git a/hassil/expression.py b/hassil/expression.py new file mode 100644 index 0000000..9ffbfde --- /dev/null +++ b/hassil/expression.py @@ -0,0 +1,194 @@ +"""Classes for representing sentence templates.""" + +import re +from abc import ABC +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Iterator, List, Optional + + +@dataclass +class Expression(ABC): + """Base class for expressions.""" + + +@dataclass +class TextChunk(Expression): + """Contiguous chunk of text (with whitespace).""" + + # Text with casing/whitespace normalized + text: str = "" + + # Set in __post_init__ + original_text: str = None # type: ignore + + parent: "Optional[Sequence]" = None + + def __post_init__(self): + if self.original_text is None: + self.original_text = self.text + + @property + def is_empty(self) -> bool: + """True if the chunk is empty""" + return self.text == "" + + @staticmethod + def empty() -> "TextChunk": + """Returns an empty text chunk""" + return TextChunk() + + +class SequenceType(str, Enum): + """Type of a sequence. Optionals are alternatives with an empty option.""" + + # Sequence of expressions + GROUP = "group" + + # Expressions where only one will be recognized + ALTERNATIVE = "alternative" + + # Permutations of a set of expressions + PERMUTATION = "permutation" + + +@dataclass +class Sequence(Expression): + """Ordered sequence of expressions. 
Supports groups, optionals, and alternatives.""" + + # Items in the sequence + items: List[Expression] = field(default_factory=list) + + # Group or alternative + type: SequenceType = SequenceType.GROUP + + is_optional: bool = False + + def text_chunk_count(self) -> int: + """Return the number of TextChunk expressions in this sequence (recursive).""" + num_text_chunks = 0 + for item in self.items: + if isinstance(item, TextChunk): + num_text_chunks += 1 + elif isinstance(item, Sequence): + seq: Sequence = item + num_text_chunks += seq.text_chunk_count() + + return num_text_chunks + + def list_names( + self, + expansion_rules: Optional[Dict[str, "Sentence"]] = None, + ) -> Iterator[str]: + """Return names of list references (recursive).""" + for item in self.items: + yield from self._list_names(item, expansion_rules) + + def _list_names( + self, + item: Expression, + expansion_rules: Optional[Dict[str, "Sentence"]] = None, + ) -> Iterator[str]: + """Return names of list references (recursive).""" + if isinstance(item, ListReference): + list_ref: ListReference = item + yield list_ref.list_name + elif isinstance(item, Sequence): + seq: Sequence = item + yield from seq.list_names(expansion_rules) + elif isinstance(item, RuleReference): + rule_ref: RuleReference = item + if expansion_rules and (rule_ref.rule_name in expansion_rules): + rule_body = expansion_rules[rule_ref.rule_name] + yield from self._list_names(rule_body, expansion_rules) + + +@dataclass +class RuleReference(Expression): + """Reference to an expansion rule by .""" + + # Name of referenced rule + rule_name: str = "" + + +@dataclass +class ListReference(Expression): + """Reference to a list by {name}.""" + + list_name: str = "" + prefix: Optional[str] = None + suffix: Optional[str] = None + _slot_name: Optional[str] = None + + def __post_init__(self): + if ":" in self.list_name: + # list_name:slot_name + self.list_name, self._slot_name = self.list_name.split(":", maxsplit=1) + else: + self._slot_name = self.list_name + + @property + def slot_name(self) -> str: + """Name of slot to put list value into.""" + assert self._slot_name is not None + return self._slot_name + + +@dataclass +class Sentence(Sequence): + """Sequence representing a complete sentence template.""" + + text: Optional[str] = None + pattern: Optional[re.Pattern] = None + + def compile(self, expansion_rules: Dict[str, "Sentence"]) -> None: + if self.pattern is not None: + # Already compiled + return + + pattern_chunks: List[str] = [] + self._compile_expression(self, pattern_chunks, expansion_rules) + + pattern_str = "".join(pattern_chunks).replace(r"\ ", r"[ ]*") + self.pattern = re.compile(f"^{pattern_str}$", re.IGNORECASE) + + def _compile_expression( + self, exp: Expression, pattern_chunks: List[str], rules: Dict[str, "Sentence"] + ): + if isinstance(exp, TextChunk): + # Literal text + chunk: TextChunk = exp + if chunk.text: + escaped_text = re.escape(chunk.text) + pattern_chunks.append(escaped_text) + elif isinstance(exp, Sequence): + # Linear sequence or alternative choices + seq: Sequence = exp + if seq.type == SequenceType.GROUP: + # Linear sequence + for item in seq.items: + self._compile_expression(item, pattern_chunks, rules) + elif seq.type == SequenceType.ALTERNATIVE: + # Alternative choices + if seq.items: + pattern_chunks.append("(?:") + for item in seq.items: + self._compile_expression(item, pattern_chunks, rules) + pattern_chunks.append("|") + pattern_chunks[-1] = ")" + else: + raise ValueError(seq) + elif isinstance(exp, ListReference): + # Slot 
list + pattern_chunks.append("(?:.+)") + + elif isinstance(exp, RuleReference): + # Expansion rule + rule_ref: RuleReference = exp + if rule_ref.rule_name not in rules: + raise ValueError(rule_ref) + + e_rule = rules[rule_ref.rule_name] + self._compile_expression(e_rule, pattern_chunks, rules) + else: + raise ValueError(exp) diff --git a/hassil/fst.py b/hassil/fst.py new file mode 100644 index 0000000..2e821b2 --- /dev/null +++ b/hassil/fst.py @@ -0,0 +1,508 @@ +import math +from collections import defaultdict, deque +from dataclasses import dataclass, field +from functools import reduce +from typing import Dict, List, Optional, Set, TextIO, Tuple + +from unicode_rbnf import RbnfEngine + +from .intents import ( + Intents, + IntentData, + SlotList, + TextSlotList, + RangeSlotList, + WildcardSlotList, +) +from .expression import ( + Expression, + ListReference, + RuleReference, + TextChunk, + Sequence, + SequenceType, + Sentence, +) +from .util import check_excluded_context, check_required_context + +EPS = "" +SPACE = "" + + +@dataclass +class FstArc: + to_state: int + in_label: str = EPS + out_label: str = EPS + log_prob: Optional[float] = None + + +@dataclass +class Fst: + arcs: Dict[int, List[FstArc]] = field(default_factory=lambda: defaultdict(list)) + final_states: Set[int] = field(default_factory=set) + start: int = 0 + current_state: int = 0 + + def next_state(self) -> int: + self.current_state += 1 + return self.current_state + + def next_edge( + self, + from_state: int, + in_label: Optional[str] = None, + out_label: Optional[str] = None, + log_prob: Optional[float] = None, + ) -> int: + to_state = self.next_state() + self.add_edge(from_state, to_state, in_label, out_label, log_prob) + return to_state + + def add_edge( + self, + from_state: int, + to_state: int, + in_label: Optional[str] = None, + out_label: Optional[str] = None, + log_prob: Optional[float] = None, + ) -> None: + if in_label is None: + in_label = EPS + + if out_label is None: + out_label = in_label + + if (" " in in_label) or (" " in out_label): + raise ValueError( + f"Cannot have white space in labels: from={in_label}, to={out_label}" + ) + + if (not in_label) or (not out_label): + raise ValueError(f"Labels cannot be empty: from={in_label}, to={out_label}") + + self.arcs[from_state].append(FstArc(to_state, in_label, out_label, log_prob)) + + def accept(self, state: int) -> None: + self.final_states.add(state) + + def write(self, fst_file: TextIO, symbols_file: TextIO) -> None: + symbols = {EPS: 0} + + for state, arcs in self.arcs.items(): + for arc in arcs: + if arc.in_label not in symbols: + symbols[arc.in_label] = len(symbols) + + if arc.out_label not in symbols: + symbols[arc.out_label] = len(symbols) + + if arc.log_prob is None: + print( + state, arc.to_state, arc.in_label, arc.out_label, file=fst_file + ) + else: + print( + state, + arc.to_state, + arc.in_label, + arc.out_label, + arc.log_prob, + file=fst_file, + ) + + for state in self.final_states: + print(state, file=fst_file) + + for symbol, symbol_id in symbols.items(): + print(symbol, symbol_id, file=symbols_file) + + def replace(self, replacements: "Dict[str, Fst]") -> "Fst": + pass + + def remove_spaces(self) -> "Fst": + fst_no_spaces = Fst() + q = deque([(self.start, fst_no_spaces.start, [])]) + + while q: + state, next_state, word_parts = q.popleft() + is_final = state in self.final_states + + if is_final and word_parts: + word = "".join(word_parts) + fst_no_spaces.accept(fst_no_spaces.next_edge(next_state, word, word)) + + for arc in 
self.arcs[state]: + if arc.in_label == SPACE: + # End word + if word_parts: + word = "".join(word_parts) + q.append( + ( + arc.to_state, + fst_no_spaces.next_edge(next_state, word, word), + [], + ) + ) + else: + q.append((arc.to_state, next_state, [])) + else: + # Continue word + if arc.in_label != EPS: + q.append( + (arc.to_state, next_state, word_parts + [arc.in_label]) + ) + else: + q.append((arc.to_state, next_state, word_parts)) + + return fst_no_spaces + + +@dataclass +class NumToWords: + engine: RbnfEngine + cache: Dict[Tuple[int, int, int], Sequence] = field(default_factory=dict) + + +def expression_to_fst( + expression: Expression, + state: int, + fst: Fst, + intent_data: IntentData, + intents: Intents, + slot_lists: Optional[Dict[str, SlotList]] = None, + num_to_words: Optional[NumToWords] = None, +) -> int: + if isinstance(expression, TextChunk): + chunk: TextChunk = expression + + space_before = False + space_after = False + + if chunk.original_text == " ": + return fst.next_edge(state, SPACE) + + if chunk.original_text.startswith(" "): + space_before = True + + if chunk.original_text.endswith(" "): + space_after = True + + word = chunk.original_text.strip() + if not word: + return state + + if space_before: + state = fst.next_edge(state, SPACE) + + sub_words = word.split() + last_sub_word_idx = len(sub_words) - 1 + for sub_word_idx, sub_word in enumerate(sub_words): + state = fst.next_edge(state, sub_word) + if sub_word_idx != last_sub_word_idx: + # Add spaces between words + state = fst.next_edge(state, SPACE) + + if space_after: + state = fst.next_edge(state, SPACE) + + return state + + if isinstance(expression, Sequence): + seq: Sequence = expression + if seq.type == SequenceType.ALTERNATIVE: + start = state + end = fst.next_state() + + for item in seq.items: + state = expression_to_fst( + item, start, fst, intent_data, intents, slot_lists, num_to_words + ) + if state == start: + # Empty item + continue + + fst.add_edge(state, end) + + if seq.is_optional: + fst.add_edge(start, end) + + return end + + if seq.type == SequenceType.GROUP: + for item in seq.items: + state = expression_to_fst( + item, state, fst, intent_data, intents, slot_lists, num_to_words + ) + + return state + + if isinstance(expression, ListReference): + # {list} + list_ref: ListReference = expression + + slot_list: Optional[SlotList] = None + if slot_lists is not None: + slot_list = slot_lists.get(list_ref.list_name) + + if slot_list is None: + slot_list = intent_data.slot_lists.get(list_ref.list_name) + + if slot_list is None: + slot_list = intents.slot_lists.get(list_ref.list_name) + + if isinstance(slot_list, TextSlotList): + text_list: TextSlotList = slot_list + + values = [] + for value in text_list.values: + if (intent_data.requires_context is not None) and ( + not check_required_context( + intent_data.requires_context, + value.context, + allow_missing_keys=True, + ) + ): + continue + + if (intent_data.excludes_context is not None) and ( + not check_excluded_context( + intent_data.excludes_context, + value.context, + ) + ): + continue + + values.append(value.text_in) + + if values: + return expression_to_fst( + Sequence(values, type=SequenceType.ALTERNATIVE), + state, + fst, + intent_data, + intents, + slot_lists, + num_to_words, + ) + + elif isinstance(slot_list, RangeSlotList): + range_list: RangeSlotList = slot_list + number_sequence: Optional[Sequence] = None + num_cache_key = (range_list.start, range_list.stop + 1, range_list.step) + + if num_to_words is not None: + number_sequence = 
num_to_words.cache.get(num_cache_key) + + if number_sequence is None: + values = [] + # TODO + # for number in range( + # range_list.start, range_list.stop + 1, range_list.step + # ): + # values.append(TextChunk(str(number))) + + if num_to_words is not None: + for number in range( + range_list.start, range_list.stop + 1, range_list.step + ): + number_result = num_to_words.engine.format_number(number) + number_words = { + w.replace("-", " ") + for w in number_result.text_by_ruleset.values() + } + values.extend((TextChunk(w) for w in number_words)) + + number_sequence = Sequence(values, type=SequenceType.ALTERNATIVE) + + if num_to_words is not None: + num_to_words.cache[num_cache_key] = number_sequence + + return expression_to_fst( + number_sequence, + state, + fst, + intent_data, + intents, + slot_lists, + num_to_words, + ) + else: + word = f"{{{list_ref.list_name}}}" + return expression_to_fst( + TextChunk(word), + state, + fst, + intent_data, + intents, + slot_lists, + num_to_words, + ) + + if isinstance(expression, RuleReference): + # + rule_ref: RuleReference = expression + + rule_body: Optional[Sentence] = intent_data.expansion_rules.get( + rule_ref.rule_name + ) + if rule_body is None: + rule_body = intents.expansion_rules.get(rule_ref.rule_name) + + if rule_body is None: + raise ValueError(f"Missing expansion rule <{rule_ref.rule_name}>") + + return expression_to_fst( + rule_body, state, fst, intent_data, intents, slot_lists, num_to_words + ) + + return state + + +def get_count( + e: Expression, + intents: Intents, + intent_data: IntentData, +) -> int: + if isinstance(e, Sequence): + seq: Sequence = e + item_counts = [get_count(item, intents, intent_data) for item in seq.items] + + if seq.type == SequenceType.ALTERNATIVE: + return sum(item_counts) + + if seq.type == SequenceType.GROUP: + return reduce(lambda x, y: x * y, item_counts, 1) + + if isinstance(e, ListReference): + list_ref: ListReference = e + slot_list: Optional[SlotList] = None + + slot_list = intent_data.slot_lists.get(list_ref.list_name) + if not slot_list: + slot_list = intents.slot_lists.get(list_ref.list_name) + + if isinstance(slot_list, TextSlotList): + text_list: TextSlotList = slot_list + return sum( + get_count(v.text_in, intents, intent_data) for v in text_list.values + ) + + if isinstance(slot_list, RangeSlotList): + range_list: RangeSlotList = slot_list + if range_list.step == 1: + return range_list.stop - range_list.start + 1 + + return len(range(range_list.start, range_list.stop + 1, range_list.step)) + + if isinstance(e, RuleReference): + rule_ref: RuleReference = e + rule_body: Optional[Sentence] = None + + rule_body = intent_data.expansion_rules.get(rule_ref.rule_name) + if not rule_body: + rule_body = intents.expansion_rules.get(rule_ref.rule_name) + + if rule_body: + return get_count(rule_body, intents, intent_data) + + return 1 + + +def lcm(*nums: int) -> int: + """Returns the least common multiple of the given integers""" + if nums: + nums_lcm = nums[0] + for n in nums[1:]: + nums_lcm = (nums_lcm * n) // math.gcd(nums_lcm, n) + + return nums_lcm + + return 1 + + +def intents_to_fst( + intents: Intents, + slot_lists: Optional[Dict[str, SlotList]] = None, + number_language: Optional[str] = None, + exclude_intents: Optional[Set[str]] = None, + include_intents: Optional[Set[str]] = None, +) -> Fst: + num_to_words: Optional[NumToWords] = None + if number_language: + num_to_words = NumToWords(engine=RbnfEngine.for_language(number_language)) + + filtered_intents = [] + # sentence_counts: Dict[str, int] = 
{} + sentence_counts: Dict[Sentence, int] = {} + + for intent in intents.intents.values(): + if (exclude_intents is not None) and (intent.name in exclude_intents): + continue + + if (include_intents is not None) and (intent.name not in include_intents): + continue + + # num_sentences = 0 + for i, data in enumerate(intent.data): + for j, sentence in enumerate(data.sentences): + # num_sentences += get_count(sentence, intents, data) + sentence_counts[(intent.name, i, j)] = get_count( + sentence, intents, data + ) + + filtered_intents.append(intent) + # sentence_counts[intent.name] = num_sentences + + fst_with_spaces = Fst() + final = fst_with_spaces.next_state() + + num_sentences_lcm = lcm(*sentence_counts.values()) + # intent_weights = { + # intent_name: num_sentences_lcm // max(1, count) + # for intent_name, count in sentence_counts.items() + # } + # weight_sum = max(1, sum(intent_weights.values())) + # total_sentences = max(1, sum(sentence_counts.values())) + + sentence_weights = { + key: num_sentences_lcm // max(1, count) + for key, count in sentence_counts.items() + } + weight_sum = max(1, sum(sentence_weights.values())) + + for intent in filtered_intents: + # weight = intent_weights[intent.name] / weight_sum + # weight = 1 / len(filtered_intents) + # print(intent.name, weight) + # intent_prob = -math.log(weight) + # intent_state = fst_with_spaces.next_edge( + # fst_with_spaces.start, SPACE, SPACE, #log_prob=intent_prob + # ) + + for i, data in enumerate(intent.data): + for j, sentence in enumerate(data.sentences): + weight = sentence_weights[(intent.name, i, j)] + sentence_prob = weight / weight_sum + # print(sentence.text, sentence_prob) + sentence_state = fst_with_spaces.next_edge( + fst_with_spaces.start, + SPACE, + SPACE, + # log_prob=-math.log(sentence_prob), + ) + state = expression_to_fst( + sentence, + # intent_state, + sentence_state, + fst_with_spaces, + data, + intents, + slot_lists, + num_to_words, + ) + fst_with_spaces.add_edge(state, final) + + fst_with_spaces.accept(final) + + return fst_with_spaces diff --git a/hassil/intents.py b/hassil/intents.py new file mode 100644 index 0000000..61cd10a --- /dev/null +++ b/hassil/intents.py @@ -0,0 +1,463 @@ +"""Classes/methods for loading YAML intent files.""" + +from abc import ABC +from dataclasses import dataclass, field +from enum import Enum +from functools import cached_property +from pathlib import Path +from typing import IO, Any, Dict, Iterable, List, Optional, Set, Tuple, Union, cast + +from yaml import safe_load + +from .expression import Expression, Sentence, TextChunk +from .parse_expression import parse_sentence +from .util import is_template, merge_dict, normalize_text + + +@dataclass +class SlotList(ABC): + """Base class for slot lists.""" + + name: Optional[str] + """Name of the slot list.""" + + +class RangeType(str, Enum): + """Number range type.""" + + NUMBER = "number" + PERCENTAGE = "percentage" + TEMPERATURE = "temperature" + + +@dataclass +class RangeSlotList(SlotList): + """Slot list for a range of numbers.""" + + start: int + stop: int + step: int = 1 + type: RangeType = RangeType.NUMBER + multiplier: Optional[float] = None + digits: bool = True + words: bool = True + words_language: Optional[str] = None + + def __post_init__(self): + """Validate number range""" + assert self.start < self.stop, "start must be less than stop" + assert self.step > 0, "step must be positive" + assert self.digits or self.words, "must have digits, words, or both" + + +@dataclass +class TextSlotValue: + """Single value in a text 
slot list.""" + + text_in: Expression + """Input text for this value""" + + value_out: Any + """Output value put into slot""" + + context: Optional[Dict[str, Any]] = None + """Items added to context if value is matched""" + + metadata: Optional[Dict[str, Any]] = None + """Additional metadata to be returned if value is matched""" + + @staticmethod + def from_tuple( + value_tuple: Union[ + Tuple[str, Any], + Tuple[str, Any, Dict[str, Any]], + Tuple[str, Any, Dict[str, Any], Dict[str, Any]], + ], + allow_template: bool = True, + ) -> "TextSlotValue": + """Construct text slot value from a tuple.""" + text_in, value_out, context, metadata = ( + value_tuple[0], + value_tuple[1], + None, + None, + ) + + if len(value_tuple) > 2: + context = cast(Tuple[str, Any, Dict[str, Any]], value_tuple)[2] + if len(value_tuple) > 3: + metadata = cast( + Tuple[str, Any, Dict[str, Any], Dict[str, Any]], value_tuple + )[3] + + return TextSlotValue( + text_in=_maybe_parse_template(text_in, allow_template), + value_out=value_out, + context=context, + metadata=metadata, + ) + + +@dataclass +class TextSlotList(SlotList): + """Slot list with pre-defined text values.""" + + values: List[TextSlotValue] + + @staticmethod + def from_strings( + strings: Iterable[str], + allow_template: bool = True, + name: Optional[str] = None, + ) -> "TextSlotList": + """ + Construct a text slot list from strings. + + Input and output values are the same text. + """ + return TextSlotList( + name=name, + values=[ + TextSlotValue( + text_in=_maybe_parse_template(text, allow_template), + value_out=text, + ) + for text in strings + ], + ) + + @staticmethod + def from_tuples( + tuples: Iterable[ + Union[ + Tuple[str, Any], + Tuple[str, Any, Dict[str, Any]], + Tuple[str, Any, Dict[str, Any], Dict[str, Any]], + ] + ], + allow_template: bool = True, + name: Optional[str] = None, + ) -> "TextSlotList": + """ + Construct a text slot list from text/value pairs. + + Input values are the left (text), output values are the right (any). 
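+
+        A minimal sketch (this name/value pair is hypothetical):
+
+            TextSlotList.from_tuples([("living room", "light.living_room")])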
+ """ + return TextSlotList( + name=name, + values=[ + TextSlotValue.from_tuple(value_tuple, allow_template) + for value_tuple in tuples + ], + ) + + +@dataclass +class WildcardSlotList(SlotList): + """Matches as much text as possible.""" + + +@dataclass +class IntentDataSettings: + """Settings for intent data.""" + + filter_with_regex: bool = True + """Use regular expressions compiled from sentence patterns to filter possible matches.""" + + +@dataclass(frozen=True) +class IntentData: + """Block of sentences and known slots for an intent.""" + + sentence_texts: List[str] + """Sentence templates that match this intent.""" + + slots: Dict[str, Any] = field(default_factory=dict) + """Slot values that are assumed if intent is matched.""" + + response: Optional[str] = None + """Key for response to intent.""" + + requires_context: Dict[str, Any] = field(default_factory=dict) + """Context items required before match is successful.""" + + excludes_context: Dict[str, Any] = field(default_factory=dict) + """Context items that must not be present for match to be successful.""" + + expansion_rules: Dict[str, Sentence] = field(default_factory=dict) + """Local expansion rules in the context of a single intent.""" + + slot_lists: Dict[str, SlotList] = field(default_factory=dict) + """Local slot lists in the context of a single intent.""" + + wildcard_list_names: Set[str] = field(default_factory=set) + """List names that are wildcards.""" + + metadata: Optional[Dict[str, Any]] = None + """Metadata that will be passed into the result if matched.""" + + required_keywords: Optional[Set[str]] = None + """Keywords that must be present for any sentence to match.""" + + settings: IntentDataSettings = field(default_factory=IntentDataSettings) + """Settings for block of sentences.""" + + @cached_property + def sentences(self) -> List[Sentence]: + """Sentence templates that match this intent.""" + sentences = [ + parse_sentence(text, keep_text=True) for text in self.sentence_texts + ] + + # Sort sentences so that wildcards with more literal text chunks are processed first. 
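+        # Otherwise a shorter template whose wildcards match greedily could
+        # shadow a longer, more specific one.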
+ # This will reorder certain wildcards, for example: + # + # - "play {album} by {artist}" + # - "play {album} by {artist} in {room}" + # + # will be reordered to: + # + # - "play {album} by {artist} in {room}" + # - "play {album} by {artist}" + sentences = sorted(sentences, key=self._sentence_order) + + return sentences + + def _sentence_order(self, sentence: Sentence) -> int: + has_wildcards = False + if self.wildcard_list_names: + # Look for wildcard list references + for list_name in sentence.list_names(): + if list_name in self.wildcard_list_names: + has_wildcards = True + break + + if has_wildcards: + # Sentences with more text chunks should be processed sooner + return -sentence.text_chunk_count() + + return 0 + + +@dataclass +class Intent: + """A named intent with sentences + slots.""" + + name: str + data: List[IntentData] = field(default_factory=list) + + +@dataclass +class IntentsSettings: + """Settings for intents.""" + + ignore_whitespace: bool = False + """True if whitespace should be ignored during matching.""" + + filter_with_regex: bool = True + """Use regular expressions compiled from sentence patterns to filter possible matches.""" + + +@dataclass +class Intents: + """Collection of intents, rules, and lists for a language.""" + + language: str + """Language code (e.g., en).""" + + intents: Dict[str, Intent] + """Intents mapped by name.""" + + slot_lists: Dict[str, SlotList] = field(default_factory=dict) + """Slot lists mapped by name.""" + + expansion_rules: Dict[str, Sentence] = field(default_factory=dict) + """Expansion rules mapped by name.""" + + skip_words: List[str] = field(default_factory=list) + """Words that can be skipped during recognition.""" + + settings: IntentsSettings = field(default_factory=IntentsSettings) + """Settings that may change recognition.""" + + @staticmethod + def from_files(file_paths: Iterable[Union[str, Path]]) -> "Intents": + """Load intents from YAML file paths.""" + intents_dict: Dict[str, Any] = {} + for file_path in file_paths: + with open(file_path, "r", encoding="utf-8") as yaml_file: + merge_dict(intents_dict, safe_load(yaml_file)) + + return Intents.from_dict(intents_dict) + + @staticmethod + def from_yaml(yaml_file: IO[str]) -> "Intents": + """Load intents from a YAML file.""" + return Intents.from_dict(safe_load(yaml_file)) + + @staticmethod + def from_dict(input_dict: Dict[str, Any]) -> "Intents": + """Parse intents from a dict.""" + # language: "" + # settings: + # ignore_whitespace: false + # filter_with_regex: false + # intents: + # IntentName: + # data: + # - sentences: + # - "" + # slots: + # : + # : + # - + # metadata: + # key: value + # expansion_rules: + # : "" + # lists: + # : + # values: + # - "" + # + wildcard_list_names: Set[str] = { + list_name + for list_name, list_dict in input_dict.get("lists", {}).items() + if list_dict.get("wildcard", False) + } + return Intents( + language=input_dict["language"], + intents={ + intent_name: Intent( + name=intent_name, + data=[ + IntentData( + sentence_texts=data_dict["sentences"], + slots=data_dict.get("slots", {}), + requires_context=data_dict.get("requires_context", {}), + excludes_context=data_dict.get("excludes_context", {}), + expansion_rules={ + rule_name: parse_sentence(rule_body, keep_text=True) + for rule_name, rule_body in data_dict.get( + "expansion_rules", {} + ).items() + }, + slot_lists={ + list_name: _parse_list(list_name, list_dict) + for list_name, list_dict in data_dict.get( + "lists", {} + ).items() + }, + response=data_dict.get("response"), + 
wildcard_list_names=wildcard_list_names, + metadata=data_dict.get("metadata"), + required_keywords=( + set(data_dict["required_keywords"]) + if "required_keywords" in data_dict + else None + ), + settings=_parse_data_settings( + data_dict.get("settings", {}) + ), + ) + for data_dict in intent_dict["data"] + ], + ) + for intent_name, intent_dict in input_dict["intents"].items() + }, + slot_lists={ + list_name: _parse_list(list_name, list_dict) + for list_name, list_dict in input_dict.get("lists", {}).items() + }, + expansion_rules={ + rule_name: parse_sentence(rule_body, keep_text=True) + for rule_name, rule_body in input_dict.get( + "expansion_rules", {} + ).items() + }, + skip_words=input_dict.get("skip_words", []), + settings=_parse_settings(input_dict.get("settings", {})), + ) + + +def _parse_list( + list_name: str, + list_dict: Dict[str, Any], + allow_template: bool = True, +) -> SlotList: + """Parses a slot list from a dict.""" + if "values" in list_dict: + # Text values + text_values: List[TextSlotValue] = [] + for value in list_dict["values"]: + if isinstance(value, str) and allow_template and is_template(value): + # Wrap template + value = {"in": value} + + if isinstance(value, str): + # String value + text_values.append( + TextSlotValue( + text_in=_maybe_parse_template(value, allow_template), + value_out=value, + ) + ) + else: + # Object with "in" and "out" + text_values.append( + TextSlotValue( + text_in=_maybe_parse_template(value["in"], allow_template), + value_out=value.get("out"), + context=value.get("context"), + metadata=value.get("metadata"), + ) + ) + + return TextSlotList(name=list_name, values=text_values) + + if "range" in list_dict: + # Number range + range_dict = list_dict["range"] + range_multiplier = range_dict.get("multiplier") + return RangeSlotList( + name=list_name, + type=RangeType(range_dict.get("type", "number")), + start=int(range_dict["from"]), + stop=int(range_dict["to"]), + step=int(range_dict.get("step", 1)), + multiplier=( + float(range_multiplier) if range_multiplier is not None else None + ), + digits=bool(range_dict.get("digits", True)), + words=bool(range_dict.get("words", True)), + words_language=range_dict.get("words_language"), + ) + + if list_dict.get("wildcard", False): + # Wildcard + return WildcardSlotList(name=list_name) + + raise ValueError(f"Unknown slot list type: {list_dict}") + + +def _parse_settings(settings_dict: Dict[str, Any]) -> IntentsSettings: + """Parse intent settings.""" + return IntentsSettings( + ignore_whitespace=settings_dict.get("ignore_whitespace", False), + filter_with_regex=settings_dict.get("filter_with_regex", True), + ) + + +def _parse_data_settings(settings_dict: Dict[str, Any]) -> IntentDataSettings: + """Parse intent data settings.""" + return IntentDataSettings( + filter_with_regex=settings_dict.get("filter_with_regex", True), + ) + + +def _maybe_parse_template(text: str, allow_template: bool = True) -> Expression: + """Parse string as a sentence template if it has template syntax.""" + if allow_template and is_template(text): + return parse_sentence(text) + + return TextChunk(normalize_text(text)) diff --git a/hassil/models.py b/hassil/models.py new file mode 100644 index 0000000..916b3d0 --- /dev/null +++ b/hassil/models.py @@ -0,0 +1,62 @@ +"""Shared models.""" + +from abc import ABC +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +from .util import PUNCTUATION_ALL + + +@dataclass +class MatchEntity: + """Named entity that has been matched from a {slot_list}""" + + name: 
str + """Name of the entity.""" + + value: Any + """Value of the entity.""" + + text: str + """Original value text.""" + + metadata: Optional[Dict[str, Any]] = None + """Entity metadata.""" + + is_wildcard: bool = False + """True if entity is a wildcard.""" + + is_wildcard_open: bool = True + """While True, wildcard can continue matching.""" + + @property + def text_clean(self) -> str: + """Trimmed text with punctuation removed.""" + return PUNCTUATION_ALL.sub("", self.text.strip()) + + +@dataclass +class UnmatchedEntity(ABC): + """Base class for unmatched entities.""" + + name: str + """Name of entity that should have matched.""" + + +@dataclass +class UnmatchedTextEntity(UnmatchedEntity): + """Text entity that should have matched.""" + + text: str + """Text that failed to match slot values.""" + + is_open: bool = True + """While True, entity can continue matching.""" + + +@dataclass +class UnmatchedRangeEntity(UnmatchedEntity): + """Range entity that should have matched.""" + + value: Union[int, float] + """Value of entity that was out of range.""" diff --git a/hassil/parse_expression.py b/hassil/parse_expression.py new file mode 100644 index 0000000..a627045 --- /dev/null +++ b/hassil/parse_expression.py @@ -0,0 +1,418 @@ +from dataclasses import dataclass +from itertools import permutations +from typing import List, Optional + +from .expression import ( + Expression, + ListReference, + RuleReference, + Sentence, + Sequence, + SequenceType, + TextChunk, +) +from .parser import ( + GROUP_END, + GROUP_START, + LIST_END, + LIST_START, + OPT_END, + OPT_START, + RULE_END, + RULE_START, + ParseChunk, + ParseError, + ParseType, + next_chunk, + remove_delimiters, +) +from .util import normalize_text + + +@dataclass +class ParseMetadata: + """Debug metadata for more helpful parsing errors.""" + + file_name: str + line_number: int + intent_name: Optional[str] = None + + +class ParseExpressionError(ParseError): + def __init__(self, chunk: ParseChunk, metadata: Optional[ParseMetadata] = None): + super().__init__() + self.chunk = chunk + self.metadata = metadata + + def __str__(self) -> str: + return f"Error in chunk {self.chunk} at {self.metadata}" + + +def ensure_alternative(seq: Sequence): + if seq.type != SequenceType.ALTERNATIVE: + seq.type = SequenceType.ALTERNATIVE + + # Collapse items into a single group + seq.items = [ + Sequence( + type=SequenceType.GROUP, + items=seq.items, + ) + ] + + +def ensure_permutation(seq: Sequence): + if seq.type != SequenceType.PERMUTATION: + seq.type = SequenceType.PERMUTATION + + # Collapse items into a single group + seq.items = [ + Sequence( + type=SequenceType.GROUP, + items=seq.items, + ) + ] + + +def parse_group_or_alt_or_perm( + seq_chunk: ParseChunk, metadata: Optional[ParseMetadata] = None +) -> Sequence: + seq = Sequence(type=SequenceType.GROUP) + if seq_chunk.parse_type == ParseType.GROUP: + seq_text = remove_delimiters(seq_chunk.text, GROUP_START, GROUP_END) + elif seq_chunk.parse_type == ParseType.OPT: + seq_text = remove_delimiters(seq_chunk.text, OPT_START, OPT_END) + else: + raise ParseExpressionError(seq_chunk, metadata=metadata) + + item_chunk = next_chunk(seq_text) + last_seq_text = seq_text + + while item_chunk is not None: + if item_chunk.parse_type in ( + ParseType.WORD, + ParseType.GROUP, + ParseType.OPT, + ParseType.LIST, + ParseType.RULE, + ): + item = parse_expression(item_chunk, metadata=metadata) + + if seq.type in (SequenceType.ALTERNATIVE, SequenceType.PERMUTATION): + # Add to most recent group + if not seq.items: + 
seq.items.append(Sequence(type=SequenceType.GROUP)) + + # Must be group or alternative + last_item = seq.items[-1] + if not isinstance(last_item, Sequence): + raise ParseExpressionError(seq_chunk, metadata=metadata) + + last_item.items.append(item) + else: + # Add to parent group + seq.items.append(item) + + if isinstance(item, TextChunk): + item_tc: TextChunk = item + item_tc.parent = seq + elif item_chunk.parse_type == ParseType.ALT: + ensure_alternative(seq) + + # Begin new group + seq.items.append(Sequence(type=SequenceType.GROUP)) + elif item_chunk.parse_type == ParseType.PERM: + ensure_permutation(seq) + + # Begin new group + seq.items.append(Sequence(type=SequenceType.GROUP)) + else: + raise ParseExpressionError(seq_chunk, metadata=metadata) + + # Next chunk + seq_text = seq_text[item_chunk.end_index :] + + if seq_text == last_seq_text: + # No change, unable to proceed + raise ParseExpressionError(seq_chunk, metadata=metadata) + + item_chunk = next_chunk(seq_text) + last_seq_text = seq_text + + if seq.type == SequenceType.PERMUTATION: + permuted_items: List[Expression] = [] + + for permutation in permutations(seq.items): + permutation_with_spaces = add_spaces_between_items(list(permutation)) + permuted_items.append( + Sequence(type=SequenceType.GROUP, items=permutation_with_spaces) + ) + + seq = Sequence(type=SequenceType.ALTERNATIVE, items=permuted_items) + + return seq + + +def parse_expression( + chunk: ParseChunk, metadata: Optional[ParseMetadata] = None +) -> Expression: + if chunk.parse_type == ParseType.WORD: + return TextChunk(text=normalize_text(chunk.text), original_text=chunk.text) + + if chunk.parse_type == ParseType.GROUP: + return parse_group_or_alt_or_perm(chunk, metadata=metadata) + + if chunk.parse_type == ParseType.OPT: + seq = parse_group_or_alt_or_perm(chunk, metadata=metadata) + ensure_alternative(seq) + seq.items.append(TextChunk(text="", parent=seq)) + seq.is_optional = True + return seq + + if chunk.parse_type == ParseType.LIST: + return ListReference( + list_name=remove_delimiters(chunk.text, LIST_START, LIST_END), + ) + + if chunk.parse_type == ParseType.RULE: + rule_name = remove_delimiters( + chunk.text, + RULE_START, + RULE_END, + ) + + return RuleReference(rule_name=rule_name) + + raise ParseExpressionError(chunk, metadata=metadata) + + +def parse_sentence( + text: str, keep_text=False, metadata: Optional[ParseMetadata] = None +) -> Sentence: + """Parse a single sentence.""" + original_text = text + text = text.strip() + # text = fix_pattern_whitespace(text.strip()) + + # Wrap in a group because sentences need to always be sequences. 
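+    # e.g., "turn on [the] light" becomes "(turn on [the] light)"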
+ text = f"({text})" + + chunk = next_chunk(text) + if chunk is None: + raise ParseError(f"Unexpected empty chunk in: {text}") + + if chunk.parse_type != ParseType.GROUP: + raise ParseError(f"Expected (group) in: {text}") + + if chunk.start_index != 0: + raise ParseError(f"Expected (group) to start at index 0 in: {text}") + + if chunk.end_index != len(text): + raise ParseError(f"Expected chunk to end at index {chunk.end_index} in: {text}") + + seq = parse_expression(chunk, metadata=metadata) + if not isinstance(seq, Sequence): + raise ParseError(f"Expected Sequence, got: {seq}") + + # Unpack redundant sequence + if len(seq.items) == 1: + first_item = seq.items[0] + if isinstance(first_item, Sequence): + seq = first_item + + return Sentence( + type=seq.type, + items=seq.items, + text=original_text if keep_text else None, + is_optional=seq.is_optional, + ) + + +# def fix_pattern_whitespace(text: str) -> str: +# if PERM_SEP in text: +# # Fix within permutations +# text = PERM_SEP.join( +# GROUP_START + fix_pattern_whitespace(perm_chunk) + GROUP_END +# for perm_chunk in text.split(PERM_SEP) +# ) + +# # Recursively process (group) +# group_start_index = text.find(GROUP_START) +# while group_start_index > 0: +# # TODO: Can't cross OPT boundary +# group_end_index = find_end_delimiter( +# text, group_start_index + 1, GROUP_START, GROUP_END +# ) +# if group_end_index is None: +# return text # will fail parsing + +# before_group, text_without_group, after_group = ( +# text[:group_start_index], +# text[group_start_index + 1 : group_end_index - 1], +# text[group_end_index:], +# ) + +# text = ( +# fix_pattern_whitespace(before_group) +# + GROUP_START +# + fix_pattern_whitespace(text_without_group) +# + GROUP_END +# + fix_pattern_whitespace(after_group) +# ) +# group_start_index = text.find(GROUP_START, group_end_index) + +# # Fix whitespace after optional (beginning of sentence) +# left_text, right_text = "", text +# while right_text.startswith(OPT_START): +# opt_end_index = find_end_delimiter(right_text, 1, OPT_START, OPT_END) +# if (opt_end_index is None) or (opt_end_index >= len(right_text)): +# break + +# if not right_text[opt_end_index].isspace(): +# # No adjustment needed +# break + +# # Move whitespace into optional and group +# left_text += ( +# OPT_START +# + GROUP_START +# + right_text[1 : opt_end_index - 1] +# + GROUP_END +# + " " +# + OPT_END +# ) +# right_text = right_text[opt_end_index:].lstrip() + +# text = left_text + right_text + +# # Fix whitespace before optional (end of sentence) +# left_text, right_text = text, "" +# while left_text.endswith(OPT_END): +# opt_end_index = len(left_text) +# opt_start_index = left_text.rfind(OPT_START) +# maybe_opt_end_index: Optional[int] = None + +# # Keep looking back for the "[" that starts this optional +# while opt_start_index > 0: +# maybe_opt_end_index = find_end_delimiter( +# left_text, opt_start_index + 1, OPT_START, OPT_END +# ) +# if maybe_opt_end_index == opt_end_index: +# break # found the matching "[" + +# # Look farther back +# opt_start_index = left_text.rfind(OPT_START, 0, opt_start_index) + +# if (maybe_opt_end_index != opt_end_index) or (opt_start_index <= 0): +# break + +# if not left_text[opt_start_index - 1].isspace(): +# # No adjustment needed +# break + +# # Move whitespace into optional and group +# right_text = ( +# (OPT_START + " " + GROUP_START + left_text[opt_start_index + 1 : -1]) +# + GROUP_END +# + OPT_END +# + right_text +# ) + +# left_text = left_text[:opt_start_index].rstrip() + +# text = left_text + right_text + +# 
# Fix whitespace around optional (middle of a sentence) +# left_text, right_text = "", text +# match = re.search(rf"\s({re.escape(OPT_START)})", right_text) +# while match is not None: +# opt_start_index = match.start(1) +# opt_end_index = find_end_delimiter( +# right_text, opt_start_index + 1, OPT_START, OPT_END +# ) +# if (opt_end_index is None) or (opt_end_index >= len(text)): +# break + +# if right_text[opt_end_index].isspace(): +# # Move whitespace inside optional, add group +# left_text += ( +# right_text[: opt_start_index - 1] +# + OPT_START +# + " " +# + GROUP_START +# + right_text[opt_start_index + 1 : opt_end_index - 1].lstrip() +# + GROUP_END +# + OPT_END +# ) +# else: +# left_text += right_text[:opt_end_index] + +# right_text = right_text[opt_end_index:] +# if not right_text: +# break + +# match = re.search(rf"\s({re.escape(OPT_START)})", right_text) + +# text = left_text + right_text + +# return text + + +def add_spaces_between_items(items: List[Expression]) -> List[Expression]: + """Add spaces between each 2 items of a sequence, used for permutations""" + spaced_items: List[Expression] = [] + + # Unpack single item sequences to make pattern matching easier below + unpacked_items: List[Expression] = [] + for item in items: + while ( + isinstance(item, Sequence) + and (item.type == SequenceType.GROUP) + and (len(item.items) == 1) + ): + item = item.items[0] + + unpacked_items.append(item) + + previous_item: Optional[Expression] = None + for item_idx, item in enumerate(unpacked_items): + if item_idx > 0: + # Only add whitespace after the first item + if isinstance(previous_item, Sequence) and previous_item.is_optional: + # Modify the previous optional to include a space at the end of + # each item. + opt: Sequence = previous_item + fixed_items: List[Expression] = [] + for opt_item in opt.items: + fix_item = True + if isinstance(opt_item, TextChunk): + opt_tc: TextChunk = opt_item + if not opt_tc.text: + # Don't fix empty text chunks + fix_item = False + else: + # Remove ending whitespace since we'll be adding a + # whitespace text chunk after. 
+ opt_tc.text = opt_tc.text.rstrip() + + if fix_item: + fixed_items.append( + Sequence( + type=SequenceType.GROUP, + items=[opt_item, TextChunk(" ")], + ) + ) + else: + fixed_items.append(opt_item) + + spaced_items[-1] = Sequence( + type=SequenceType.ALTERNATIVE, is_optional=True, items=fixed_items + ) + else: + # Add a space in front + spaced_items.append(TextChunk(text=" ")) + + spaced_items.append(item) + previous_item = item + + return spaced_items diff --git a/hassil/parser.py b/hassil/parser.py new file mode 100644 index 0000000..222ffdb --- /dev/null +++ b/hassil/parser.py @@ -0,0 +1,315 @@ +import re +from dataclasses import dataclass +from enum import Enum, auto +from typing import Optional + +GROUP_START = "(" +GROUP_END = ")" +OPT_START = "[" +OPT_END = "]" +LIST_START = "{" +LIST_END = "}" +RULE_START = "<" +RULE_END = ">" + +DELIM = { + GROUP_START: GROUP_END, + OPT_START: OPT_END, + LIST_START: LIST_END, + RULE_START: RULE_END, +} +DELIM_START = tuple(DELIM.keys()) +DELIM_END = tuple(DELIM.values()) + +WORD_SEP = " " +ALT_SEP = "|" +PERM_SEP = ";" +ESCAPE_CHAR = "\\" + + +class ParseType(Enum): + """Parse chunk types.""" + + GROUP = auto() + OPT = auto() + ALT = auto() + PERM = auto() + RULE = auto() + LIST = auto() + WORD = auto() + END = auto() + + +@dataclass +class ParseChunk: + """Block of text that means something to the parser.""" + + text: str + start_index: int + end_index: int + parse_type: ParseType + + +def find_end_delimiter( + text: str, start_index: int, start_char: str, end_char: str +) -> Optional[int]: + """Finds the index of an ending delimiter.""" + if start_index > 0: + text = text[start_index:] + + stack = 1 + is_escaped = False + for i, c in enumerate(text): + if is_escaped: + is_escaped = False + continue + + if c == ESCAPE_CHAR: + is_escaped = True + continue + + if c == end_char: + stack -= 1 + if stack < 0: + return None + + if stack == 0: + return start_index + i + 1 + + if c == start_char: + stack += 1 + + return None + + +def find_end_word(text: str, start_index: int) -> Optional[int]: + """Finds the end index of a word.""" + if start_index > 0: + text = text[start_index:] + + is_escaped = False + separator_found = False + for i, c in enumerate(text): + if is_escaped: + is_escaped = False + continue + + if c == ESCAPE_CHAR: + is_escaped = True + continue + + if (i > 0) and (c == WORD_SEP): + separator_found = True + continue + + if separator_found and (c != WORD_SEP): + # Start of next word + return start_index + i + + if (c == ALT_SEP) or (c == PERM_SEP) or (c in DELIM_START) or (c in DELIM_END): + return start_index + i + + if text: + # Entire text is a word + return start_index + len(text) + + return None + + +def peek_type(text, start_index: int) -> ParseType: + """Gets the parse chunk type based on the next character.""" + if start_index >= len(text): + return ParseType.END + + c = text[start_index] + if c == GROUP_START: + return ParseType.GROUP + + if c == OPT_START: + return ParseType.OPT + + if c == ALT_SEP: + return ParseType.ALT + + if c == PERM_SEP: + return ParseType.PERM + + if c == LIST_START: + return ParseType.LIST + + if c == RULE_START: + return ParseType.RULE + + return ParseType.WORD + + +class ParseError(Exception): + """Base class for parse errors""" + + +def skip_text(text: str, start_index: int, skip: str) -> int: + """Skips a string in text, taking escapes into account.""" + if start_index > 0: + text = text[start_index:] + + if not text: + raise ParseError(f"Cannot skip '{skip}' in empty text") + + text_index = 0 + for 
c_text in text: + if c_text == ESCAPE_CHAR: + text_index += 1 + continue + + if c_text != skip[0]: + break + + text_index += 1 + skip = skip[1:] + + if not skip: + break + + if skip: + raise ParseError(f"Failed to skip '{skip}' in: {text}") + + return start_index + text_index + + +def next_chunk(text: str, start_index: int = 0) -> Optional[ParseChunk]: + """Gets the next parsable chunk from text.""" + next_type = peek_type(text, start_index) + + if next_type == ParseType.WORD: + # Single word + word_end_index = find_end_word(text, start_index) + if word_end_index is None: + raise ParseError( + f"Unable to find end of word from index {start_index} in: {text}" + ) + + word_text = remove_escapes(text[start_index:word_end_index]) + + return ParseChunk( + text=word_text, + start_index=start_index, + end_index=word_end_index, + parse_type=ParseType.WORD, + ) + + if next_type == ParseType.GROUP: + # Skip '(' + group_start_index = skip_text(text, start_index, GROUP_START) + group_end_index = find_end_delimiter( + text, group_start_index, GROUP_START, GROUP_END + ) + if group_end_index is None: + raise ParseError( + f"Unable to find end of group ')' from index {start_index} in: {text}" + ) + + group_text = remove_escapes(text[start_index:group_end_index]) + + return ParseChunk( + text=group_text, + start_index=start_index, + end_index=group_end_index, + parse_type=ParseType.GROUP, + ) + + if next_type == ParseType.OPT: + # Skip '[' + opt_start_index = skip_text(text, start_index, OPT_START) + opt_end_index = find_end_delimiter(text, opt_start_index, OPT_START, OPT_END) + if opt_end_index is None: + raise ParseError( + f"Unable to find end of optional ']' from index {start_index} in: {text}" + ) + + opt_text = remove_escapes(text[start_index:opt_end_index]) + + return ParseChunk( + text=opt_text, + start_index=start_index, + end_index=opt_end_index, + parse_type=ParseType.OPT, + ) + + if next_type == ParseType.LIST: + # Skip '{' + list_start_index = skip_text(text, start_index, LIST_START) + list_end_index = find_end_delimiter( + text, list_start_index, LIST_START, LIST_END + ) + if list_end_index is None: + raise ParseError( + f"Unable to find end of list '}}' from index {start_index} in: {text}" + ) + + return ParseChunk( + text=remove_escapes(text[start_index:list_end_index]), + start_index=start_index, + end_index=list_end_index, + parse_type=ParseType.LIST, + ) + + if next_type == ParseType.RULE: + # Skip '<' + rule_start_index = skip_text(text, start_index, RULE_START) + rule_end_index = find_end_delimiter( + text, rule_start_index, RULE_START, RULE_END + ) + if rule_end_index is None: + raise ParseError( + f"Unable to find end of rule '>' from index {start_index} in: {text}" + ) + + return ParseChunk( + text=remove_escapes(text[start_index:rule_end_index]), + start_index=start_index, + end_index=rule_end_index, + parse_type=ParseType.RULE, + ) + + if next_type == ParseType.ALT: + return ParseChunk( + text=text[start_index : start_index + 1], + start_index=start_index, + end_index=start_index + 1, + parse_type=ParseType.ALT, + ) + + if next_type == ParseType.PERM: + return ParseChunk( + text=text[start_index : start_index + 1], + start_index=start_index, + end_index=start_index + 1, + parse_type=ParseType.PERM, + ) + + return None + + +def remove_delimiters( + text: str, start_char: str, end_char: Optional[str] = None +) -> str: + """Removes the surrounding delimiters in text.""" + if end_char is None: + assert len(text) > 1, "Text is too short" + assert text[0] == start_char, "Wrong start 
char" + return text[1:] + + assert len(text) > 2, "Text is too short" + assert text[0] == start_char, "Wrong start char" + assert text[-1] == end_char, "Wrong end char" + return text[1:-1] + + +def remove_escapes(text: str) -> str: + """Remove backslash escape sequences""" + return re.sub(r"\\(.)", r"\1", text) + + +def escape_text(text: str) -> str: + """Escape parentheses, etc.""" + return re.sub(r"([()\[\]{}<>])", r"\\\1", text) diff --git a/hassil/py.typed b/hassil/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/hassil/recognize.py b/hassil/recognize.py new file mode 100644 index 0000000..b614d39 --- /dev/null +++ b/hassil/recognize.py @@ -0,0 +1,650 @@ +"""Methods for recognizing intents from text.""" + +import collections.abc +import itertools +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, Iterable, List, MutableSequence, Optional, Tuple + +from .expression import Sentence +from .intents import Intent, IntentData, Intents, SlotList +from .models import MatchEntity, UnmatchedEntity, UnmatchedTextEntity +from .string_matcher import MatchContext, MatchSettings, match_expression +from .util import ( + WHITESPACE, + check_excluded_context, + check_required_context, + normalize_text, + remove_punctuation, + remove_skip_words, +) + +MISSING_ENTITY = "" + +_LOGGER = logging.getLogger() + + +@dataclass +class RecognizeResult: + """Result of recognition.""" + + intent: Intent + """Matched intent""" + + intent_data: IntentData + """Matched intent data""" + + entities: Dict[str, MatchEntity] = field(default_factory=dict) + """Matched entities mapped by name.""" + + entities_list: List[MatchEntity] = field(default_factory=list) + """Matched entities as a list (duplicates allowed).""" + + response: Optional[str] = None + """Key for intent response.""" + + context: Dict[str, Any] = field(default_factory=dict) + """Context values acquired during matching.""" + + unmatched_entities: Dict[str, UnmatchedEntity] = field(default_factory=dict) + """Unmatched entities mapped by name.""" + + unmatched_entities_list: List[UnmatchedEntity] = field(default_factory=list) + """Unmatched entities as a list (duplicates allowed).""" + + text_chunks_matched: int = 0 + """Number of literal text chunks that were successfully matched.""" + + intent_sentence: Optional[Sentence] = None + """Sentence template that was matched.""" + + intent_metadata: Optional[Dict[str, Any]] = None + """Metadata from the intent sentence that was matched.""" + + +def recognize( + text: str, + intents: Intents, + slot_lists: Optional[Dict[str, SlotList]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, + skip_words: Optional[List[str]] = None, + intent_context: Optional[Dict[str, Any]] = None, + default_response: Optional[str] = "default", + allow_unmatched_entities: bool = False, + language: Optional[str] = None, +) -> Optional[RecognizeResult]: + """Return the first match of input text/words against a collection of intents. + + text: Text to recognize + intents: Compiled intents + slot_lists: Pre-defined text lists, ranges, or wildcards + expansion_rules: Named template snippets + skip_words: Strings to ignore in text + intent_context: Slot values to use when not found in text + default_response: Response key to use if not set in intent + allow_unmatched_entities: True if entity values outside slot lists are allowed (slower) + language: Optional language to use when converting digits to words + + Returns the first result. 
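+
+    A minimal usage sketch (the sentence and YAML content are hypothetical):
+
+        intents = Intents.from_yaml(yaml_file)
+        result = recognize("turn on the kitchen light", intents)
+        if result is not None:
+            print(result.intent.name, result.entities)
+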
+ If allow_unmatched_entities is True, you should check for unmatched entities. + """ + for result in recognize_all( + text, + intents, + slot_lists=slot_lists, + expansion_rules=expansion_rules, + skip_words=skip_words, + intent_context=intent_context, + default_response=default_response, + allow_unmatched_entities=allow_unmatched_entities, + language=language, + ): + return result + + return None + + +def recognize_all( + text: str, + intents: Intents, + slot_lists: Optional[Dict[str, SlotList]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, + skip_words: Optional[Iterable[str]] = None, + intent_context: Optional[Dict[str, Any]] = None, + default_response: Optional[str] = "default", + allow_unmatched_entities: bool = False, + language: Optional[str] = None, +) -> Iterable[RecognizeResult]: + """Return all matches for input text/words against a collection of intents. + + text: Text to recognize + intents: Compiled intents + slot_lists: Pre-defined text lists, ranges, or wildcards + expansion_rules: Named template snippets + skip_words: Strings to ignore in text + intent_context: Slot values to use when not found in text + default_response: Response key to use if not set in intent + allow_unmatched_entities: True if entity values outside slot lists are allowed (slower) + language: Optional language to use when converting digits to words + + Yields results as they're matched. + If allow_unmatched_entities is True, you should check for unmatched entities. + """ + text = normalize_text(remove_punctuation(text)).strip() + + if skip_words is None: + skip_words = intents.skip_words + else: + # Combine skip words + skip_words = list(itertools.chain(skip_words, intents.skip_words)) + + if skip_words: + text = remove_skip_words(text, skip_words, intents.settings.ignore_whitespace) + + text_keywords = text.split() + + if slot_lists is None: + slot_lists = intents.slot_lists + else: + # Combine with intents + slot_lists = {**intents.slot_lists, **slot_lists} + + if slot_lists is None: + slot_lists = {} + + if expansion_rules is None: + expansion_rules = intents.expansion_rules + else: + # Combine rules + expansion_rules = {**intents.expansion_rules, **expansion_rules} + + if intent_context is None: + intent_context = {} + + # Filter intents based on context and keywords + available_intents: MutableSequence[ + Tuple[Intent, IntentData, MatchSettings, Optional[List[Sentence]]] + ] = [] + + for intent in intents.intents.values(): + for intent_data in intent.data: + if ( + intent_data.required_keywords + and intent_data.required_keywords.isdisjoint(text_keywords) + ): + # No keyword overlap + continue + + if intent_context: + # Skip sentence templates that can't possibly be matched due to + # requires/excludes context. + # + # Additional context can be added during matching, so we can + # only be sure about keys that exist right now. 
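A sketch of the pre-filtering rule described in the comment above: with allow_missing_keys=True, only context keys that already exist can disqualify a sentence template (uses check_required_context from hassil/util.py, defined later in this patch; the "domain" values are illustrative):

from hassil.util import check_required_context

required = {"domain": "light"}
assert check_required_context(required, {"domain": "light"}, allow_missing_keys=True)
assert not check_required_context(required, {"domain": "switch"}, allow_missing_keys=True)
# A key that is missing right now is not disqualifying during pre-filtering:
assert check_required_context(required, {}, allow_missing_keys=True)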
+ if intent_data.requires_context and ( + not check_required_context( + intent_data.requires_context, + intent_context, + allow_missing_keys=True, + ) + ): + continue + + if intent_data.excludes_context and ( + not check_excluded_context( + intent_data.excludes_context, intent_context + ) + ): + continue + + match_settings = MatchSettings( + slot_lists={ + **slot_lists, + **intent_data.slot_lists, + }, + expansion_rules={ + **expansion_rules, + **intent_data.expansion_rules, + }, + ignore_whitespace=intents.settings.ignore_whitespace, + allow_unmatched_entities=allow_unmatched_entities, + language=language or intents.language, + ) + + available_intents.append((intent, intent_data, match_settings, None)) + + # Filter with regex + if intents.settings.filter_with_regex and (not allow_unmatched_entities): + matching_intents: MutableSequence[ + Tuple[Intent, IntentData, MatchSettings, Optional[List[Sentence]]] + ] = [] + + for intent, intent_data, match_settings, _intent_sentences in available_intents: + if not intent_data.settings.filter_with_regex: + # All sentences + matching_intents.append((intent, intent_data, match_settings, None)) + continue + + matching_intent_sentences = [] + for intent_sentence in intent_data.sentences: + # Compile to regex once + intent_sentence.compile(match_settings.expansion_rules) + assert intent_sentence.pattern is not None + + regex_match = intent_sentence.pattern.match(text) + if regex_match is not None: + matching_intent_sentences.append(intent_sentence) + + if matching_intent_sentences: + matching_intents.append( + (intent, intent_data, match_settings, matching_intent_sentences) + ) + + if matching_intents: + available_intents = matching_intents + + # Fall back to string matcher + if intents.settings.ignore_whitespace: + text = WHITESPACE.sub("", text) + else: + # Artificial word boundary + text += " " + + for intent, intent_data, match_settings, intent_sentences in available_intents: + if not intent_sentences: + intent_sentences = intent_data.sentences + + # Check each sentence template + for intent_sentence in intent_sentences: + # Create initial context + match_context = MatchContext( + text=text, + intent_context=intent_context, + intent_sentence=intent_sentence, + intent_data=intent_data, + ) + maybe_match_contexts = match_expression( + match_settings, match_context, intent_sentence + ) + yield from _process_match_contexts( + maybe_match_contexts, + intent, + intent_data, + default_response=default_response, + allow_unmatched_entities=allow_unmatched_entities, + ) + + +def _merge_match_contexts( + match_contexts: Iterable[MatchContext], merged_context: MatchContext +) -> MatchContext: + for match_context in match_contexts: + if match_context.text: + # Needed for open wildcards + merged_context.text = match_context.text + + merged_context.entities.extend(match_context.entities) + merged_context.intent_context.update(match_context.intent_context) + + return merged_context + + +def _process_match_contexts( + match_contexts: Iterable[MatchContext], + intent: Intent, + intent_data: IntentData, + default_response: Optional[str] = None, + allow_unmatched_entities: bool = False, +) -> Iterable[RecognizeResult]: + for maybe_match_context in match_contexts: + # Close any open wildcards or unmatched entities + final_text = maybe_match_context.text.strip() + if final_text: + if unmatched_entity := maybe_match_context.get_open_entity(): + # Consume the rest of the text (unmatched entity) + unmatched_entity.text += final_text + unmatched_entity.is_open = False + 
maybe_match_context.text = "" + elif wildcard := maybe_match_context.get_open_wildcard(): + # Consume the rest of the text (wildcard) + wildcard.text += final_text + wildcard.value = wildcard.text + wildcard.is_wildcard_open = False + maybe_match_context.text = "" + + if not maybe_match_context.is_match: + # Incomplete match with text still left at the end + continue + + # Verify excluded context + if intent_data.excludes_context and ( + not check_excluded_context( + intent_data.excludes_context, + maybe_match_context.intent_context, + ) + ): + continue + + # Verify required context + slots_from_context: List[MatchEntity] = [] + if intent_data.requires_context and ( + not _copy_and_check_required_context( + intent_data.requires_context, + maybe_match_context, + slots_from_context, + allow_unmatched_entities=allow_unmatched_entities, + ) + ): + continue + + # Clean up wildcard entities + for entity in maybe_match_context.entities: + if not entity.is_wildcard: + continue + + entity.text = entity.text.strip() + if isinstance(entity.value, str): + entity.value = entity.value.strip() + + # Add fixed entities + entity_names = set(entity.name for entity in maybe_match_context.entities) + for slot_name, slot_value in intent_data.slots.items(): + if slot_name not in entity_names: + maybe_match_context.entities.append( + MatchEntity(name=slot_name, value=slot_value, text="") + ) + + # Add context slots + for slot_entity in slots_from_context: + if slot_entity.name not in entity_names: + maybe_match_context.entities.append(slot_entity) + + # Return each match + response = default_response + if intent_data.response is not None: + response = intent_data.response + + intent_metadata: Optional[Dict[str, Any]] = None + if maybe_match_context.intent_data is not None: + intent_metadata = maybe_match_context.intent_data.metadata + + yield RecognizeResult( + intent=intent, + intent_data=intent_data, + entities={entity.name: entity for entity in maybe_match_context.entities}, + entities_list=maybe_match_context.entities, + response=response, + context=maybe_match_context.intent_context, + unmatched_entities={ + entity.name: entity for entity in maybe_match_context.unmatched_entities + }, + unmatched_entities_list=maybe_match_context.unmatched_entities, + text_chunks_matched=maybe_match_context.text_chunks_matched, + intent_sentence=maybe_match_context.intent_sentence, + intent_metadata=intent_metadata, + ) + + +def is_match( + text: str, + sentence: Sentence, + slot_lists: Optional[Dict[str, SlotList]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, + skip_words: Optional[Iterable[str]] = None, + entities: Optional[Dict[str, Any]] = None, + intent_context: Optional[Dict[str, Any]] = None, + ignore_whitespace: bool = False, + allow_unmatched_entities: bool = False, + language: Optional[str] = None, +) -> Optional[MatchContext]: + """Return the first match of input text/words against a sentence expression.""" + text = normalize_text(remove_punctuation(text)).strip() + + if skip_words: + text = remove_skip_words(text, skip_words, ignore_whitespace) + + if ignore_whitespace: + text = WHITESPACE.sub("", text) + else: + # Artifical word boundary + text += " " + + if slot_lists is None: + slot_lists = {} + + if expansion_rules is None: + expansion_rules = {} + + if intent_context is None: + intent_context = {} + + settings = MatchSettings( + slot_lists=slot_lists, + expansion_rules=expansion_rules, + ignore_whitespace=ignore_whitespace, + allow_unmatched_entities=allow_unmatched_entities, + 
language=language, + ) + + match_context = MatchContext( + text=text, + intent_context=intent_context, + intent_sentence=sentence, + ) + + for maybe_match_context in match_expression(settings, match_context, sentence): + if maybe_match_context.is_match: + return maybe_match_context + + return None + + +def _copy_and_check_required_context( + required_context: Dict[str, Any], + maybe_match_context: MatchContext, + slots_from_context: List[MatchEntity], + allow_unmatched_entities: bool = False, +) -> bool: + """Check required context and copy slots into new entities.""" + for ( + context_key, + context_value, + ) in required_context.items(): + copy_to_slot: Optional[str] = None + if isinstance(context_value, collections.abc.Mapping): + # Unpack dict + # <context_key>: + # value: ... + # slot: true/false or "name" + maybe_copy_to_slot = context_value.get("slot") + if isinstance(maybe_copy_to_slot, str): + # Slot name provided + copy_to_slot = maybe_copy_to_slot + elif maybe_copy_to_slot: + # True + copy_to_slot = context_key + + context_value = context_value.get("value") + + actual_value = maybe_match_context.intent_context.get(context_key) + actual_text = "" + actual_metadata: Optional[Dict[str, Any]] = None + + if isinstance(actual_value, collections.abc.Mapping): + # Unpack dict + actual_text = actual_value.get("text", "") + actual_metadata = actual_value.get("metadata") + actual_value = actual_value.get("value") + + if allow_unmatched_entities and (actual_value is None): + # Look in unmatched entities + for unmatched_context_entity in maybe_match_context.unmatched_entities: + if (unmatched_context_entity.name == context_key) and isinstance( + unmatched_context_entity, UnmatchedTextEntity + ): + actual_value = unmatched_context_entity.text + break + + if actual_value == context_value and context_value is not None: + # Exact match to context value, except when context value is required and not provided + if copy_to_slot: + slots_from_context.append( + MatchEntity( + name=copy_to_slot, + value=actual_value, + text=actual_text, + metadata=actual_metadata, + ) + ) + continue + + if (context_value is None) and (actual_value is not None): + # Any value matches, as long as it's set + if copy_to_slot: + slots_from_context.append( + MatchEntity( + name=copy_to_slot, + value=actual_value, + text=actual_text, + metadata=actual_metadata, + ) + ) + continue + + if ( + isinstance(context_value, collections.abc.Collection) + and not isinstance(context_value, str) + and (actual_value in context_value) + ): + # Actual value was in context value list + if copy_to_slot: + slots_from_context.append( + MatchEntity( + name=copy_to_slot, + value=actual_value, + text=actual_text, + metadata=actual_metadata, + ) + ) + continue + + if allow_unmatched_entities: + # Create missing entity as unmatched + has_unmatched_entity = False + for unmatched_context_entity in maybe_match_context.unmatched_entities: + if unmatched_context_entity.name == context_key: + has_unmatched_entity = True + break + + if not has_unmatched_entity: + maybe_match_context.unmatched_entities.append( + UnmatchedTextEntity( + name=context_key, + text=MISSING_ENTITY, + is_open=False, + ) + ) + else: + # Did not match required context + return False + + return True + + +def recognize_best( + text: str, + intents: Intents, + slot_lists: Optional[Dict[str, SlotList]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, + skip_words: Optional[Iterable[str]] = None, + intent_context: Optional[Dict[str, Any]] = None, + default_response: Optional[str] = 
"default", + allow_unmatched_entities: bool = False, + language: Optional[str] = None, + best_metadata_key: Optional[str] = None, + best_slot_name: Optional[str] = None, +) -> Optional[RecognizeResult]: + """Find the best result with the following priorities: + + 1. The result that has "best_metadata_key" in its metadata + 2. The result that has an entity for "best_slot_name" and longest text + 3. The result that matches the most literal text + + See "recognize_all" for other parameters. + """ + metadata_found = False + slot_found = False + best_results: List[RecognizeResult] = [] + best_slot_quality: Optional[int] = None + + for result in recognize_all( + text, + intents, + slot_lists=slot_lists, + expansion_rules=expansion_rules, + skip_words=skip_words, + intent_context=intent_context, + default_response=default_response, + allow_unmatched_entities=allow_unmatched_entities, + language=language, + ): + # Prioritize intents with a specific metadata key + if best_metadata_key is not None: + is_metadata = ( + result.intent_metadata is not None + and result.intent_metadata.get(best_metadata_key) + ) + + if metadata_found and not is_metadata: + continue + + if (not metadata_found) and is_metadata: + metadata_found = True + + # Clear builtin results + slot_found = False + best_results = [] + best_slot_quality = None + + # Prioritize results with a specific slot + if best_slot_name: + entity = result.entities.get(best_slot_name) + is_slot = (entity is not None) and not entity.is_wildcard + + if slot_found and not is_slot: + continue + + if (not slot_found) and is_slot: + slot_found = True + + # Clear non-slot results + best_results = [] + + if is_slot and (entity is not None) and isinstance(entity.value, str): + # Prioritize results with a better slot value + slot_quality = len(entity.text) + if (best_slot_quality is None) or (slot_quality > best_slot_quality): + best_slot_quality = slot_quality + + # Clear worse slot results + best_results = [] + elif slot_quality < best_slot_quality: + continue + + # Accumulate results. We will resolve the ambiguity below. + best_results.append(result) + + if best_results: + # Prioritize matches with fewer wildcards and more literal text matched. + return sorted(best_results, key=_get_result_score)[0] + + return None + + +def _get_result_score(result: RecognizeResult) -> Tuple[int, int]: + """Get sort score for a result with (wildcards, -text_matched). + + text_matched is negated since we are sorting with lowest first. 
+ """ + num_wildcards = sum(1 for e in result.entities_list if e.is_wildcard) + return (num_wildcards, -result.text_chunks_matched) diff --git a/hassil/sample.py b/hassil/sample.py new file mode 100644 index 0000000..c627250 --- /dev/null +++ b/hassil/sample.py @@ -0,0 +1,314 @@ +"""CLI tool for sampling sentences from intents.""" + +import argparse +import itertools +import json +import logging +import sys +from functools import partial +from pathlib import Path +from typing import Dict, Iterable, Optional, Set, Tuple + +import yaml +from unicode_rbnf import RbnfEngine + +from .errors import MissingListError, MissingRuleError +from .expression import ( + Expression, + ListReference, + RuleReference, + Sentence, + Sequence, + SequenceType, + TextChunk, +) +from .intents import Intents, RangeSlotList, SlotList, TextSlotList, WildcardSlotList +from .util import merge_dict, normalize_whitespace + +_LOGGER = logging.getLogger("hassil.sample") + +# lang -> engine +_ENGINE_CACHE: Dict[str, RbnfEngine] = {} + + +def sample_intents( + intents: Intents, + slot_lists: Optional[Dict[str, SlotList]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, + max_sentences_per_intent: Optional[int] = None, + intent_names: Optional[Set[str]] = None, + language: Optional[str] = None, + exclude_sentences_with_wildcards: bool = True, + expand_ranges: bool = True, +) -> Iterable[Tuple[str, str]]: + """Sample text strings for sentences from intents.""" + if slot_lists is None: + slot_lists = intents.slot_lists + else: + # Combine with intents + slot_lists = {**intents.slot_lists, **slot_lists} + + if slot_lists is None: + slot_lists = {} + + if expansion_rules is None: + expansion_rules = intents.expansion_rules + else: + # Combine rules + expansion_rules = {**intents.expansion_rules, **expansion_rules} + + for intent_name, intent in intents.intents.items(): + if intent_names and (intent_name not in intent_names): + # Skip intent + continue + + num_intent_sentences = 0 + skip_intent = False + + for intent_data in intent.data: + if intent_data.expansion_rules: + local_expansion_rules = { + **expansion_rules, + **intent_data.expansion_rules, + } + else: + local_expansion_rules = expansion_rules + + for intent_sentence in intent_data.sentences: + if exclude_sentences_with_wildcards and any( + list_name in intent_data.wildcard_list_names + for list_name in intent_sentence.list_names(local_expansion_rules) + ): + continue + + sentence_texts = sample_expression( + intent_sentence, + slot_lists, + local_expansion_rules, + language=language, + expand_ranges=expand_ranges, + ) + for sentence_text in sentence_texts: + yield (intent_name, sentence_text) + num_intent_sentences += 1 + + if (max_sentences_per_intent is not None) and ( + 0 < max_sentences_per_intent <= num_intent_sentences + ): + skip_intent = True + break + + if skip_intent: + break + + if skip_intent: + break + + +def sample_expression( + expression: Expression, + slot_lists: Optional[Dict[str, SlotList]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, + language: Optional[str] = None, + expand_lists: bool = True, + expand_ranges: bool = True, +) -> Iterable[str]: + """Sample possible text strings from an expression.""" + if isinstance(expression, TextChunk): + chunk: TextChunk = expression + yield chunk.original_text + elif isinstance(expression, Sequence): + seq: Sequence = expression + if seq.type == SequenceType.ALTERNATIVE: + for item in seq.items: + yield from sample_expression( + item, + slot_lists, + expansion_rules, + 
language=language, + expand_lists=expand_lists, + expand_ranges=expand_ranges, + ) + elif seq.type == SequenceType.GROUP: + seq_sentences = map( + partial( + sample_expression, + slot_lists=slot_lists, + expansion_rules=expansion_rules, + language=language, + expand_lists=expand_lists, + expand_ranges=expand_ranges, + ), + seq.items, + ) + sentence_texts = itertools.product(*seq_sentences) + for sentence_words in sentence_texts: + yield normalize_whitespace("".join(sentence_words)) + else: + raise ValueError(f"Unexpected sequence type: {seq}") + elif isinstance(expression, ListReference): + # {list} + list_ref: ListReference = expression + + if not expand_lists: + yield f"{{{list_ref.list_name}}}" + return + + if (not slot_lists) or (list_ref.list_name not in slot_lists): + raise MissingListError(f"Missing slot list {{{list_ref.list_name}}}") + + slot_list = slot_lists[list_ref.list_name] + if isinstance(slot_list, TextSlotList): + text_list: TextSlotList = slot_list + + if not text_list.values: + # Not necessarily an error, but may be a surprise + _LOGGER.warning("No values for list: %s", list_ref.list_name) + + for text_value in text_list.values: + yield from sample_expression( + text_value.text_in, + slot_lists, + expansion_rules, + language=language, + expand_lists=expand_lists, + expand_ranges=expand_ranges, + ) + elif isinstance(slot_list, RangeSlotList): + range_list: RangeSlotList = slot_list + + if not expand_ranges: + if range_list.name: + yield f"{{{range_list.name}}}" + else: + yield "{number}" + return + + if range_list.digits: + number_strs = map( + str, range(range_list.start, range_list.stop + 1, range_list.step) + ) + yield from number_strs + + if range_list.words: + words_language = range_list.words_language or language + if words_language: + engine = _ENGINE_CACHE.get(words_language) + if engine is None: + engine = RbnfEngine.for_language(words_language) + _ENGINE_CACHE[words_language] = engine + + assert engine is not None + + # digits -> words + for word_number in range( + range_list.start, range_list.stop + 1, range_list.step + ): + # Use all unique words for a number, including different + # genders, cases, etc. 
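A sketch of the digits-to-words expansion used here, relying only on the unicode_rbnf calls that appear in this module (exact ruleset names and word forms vary by language):

from unicode_rbnf import RbnfEngine

engine = RbnfEngine.for_language("en")
format_result = engine.format_number(21)
# One text per ruleset; duplicates collapse in the set, as in the loop above.
print(set(format_result.text_by_ruleset.values()))  # e.g. {"twenty-one", ...}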
+ format_result = engine.format_number(word_number) + unique_number_strs = set(format_result.text_by_ruleset.values()) + yield from unique_number_strs + else: + _LOGGER.warning( + "No language set, so cannot convert %s digits to words", + list_ref.slot_name, + ) + elif isinstance(slot_list, WildcardSlotList): + wildcard_list: WildcardSlotList = slot_list + if wildcard_list.name: + yield f"{{{wildcard_list.name}}}" + else: + yield "{wildcard}" + else: + raise ValueError(f"Unexpected slot list type: {slot_list}") + elif isinstance(expression, RuleReference): + # + rule_ref: RuleReference = expression + if (not expansion_rules) or (rule_ref.rule_name not in expansion_rules): + raise MissingRuleError(f"Missing expansion rule <{rule_ref.rule_name}>") + + rule_body = expansion_rules[rule_ref.rule_name] + yield from sample_expression( + rule_body, + slot_lists, + expansion_rules, + language=language, + expand_lists=expand_lists, + expand_ranges=expand_ranges, + ) + else: + raise ValueError(f"Unexpected expression: {expression}") + + +def main(): + """Main entry point""" + parser = argparse.ArgumentParser() + parser.add_argument("yaml", nargs="+", help="YAML files or directories") + parser.add_argument( + "-n", + "--max-sentences-per-intent", + type=int, + help="Limit number of sentences per intent", + ) + parser.add_argument( + "--intents", nargs="+", help="Only sample sentences from these intents" + ) + parser.add_argument( + "--areas", + nargs="+", + help="Area names", + default=["area"], + ) + parser.add_argument( + "--names", nargs="+", default=["entity"], help="Device/entity names" + ) + parser.add_argument("--language", help="Language for digits to words") + parser.add_argument( + "--debug", action="store_true", help="Print DEBUG messages to the console" + ) + args = parser.parse_args() + + level = logging.DEBUG if args.debug else logging.INFO + logging.basicConfig(level=level) + _LOGGER.debug(args) + + slot_lists = { + "area": TextSlotList.from_strings(args.areas), + "name": TextSlotList.from_strings(args.names), + } + + input_dict = {"intents": {}} + for yaml_path_str in args.yaml: + yaml_path = Path(yaml_path_str) + if yaml_path.is_dir(): + yaml_file_paths = yaml_path.glob("*.yaml") + else: + yaml_file_paths = [yaml_path] + + for yaml_file_path in yaml_file_paths: + _LOGGER.debug("Loading file: %s", yaml_file_path) + with open(yaml_file_path, "r", encoding="utf-8") as yaml_file: + merge_dict(input_dict, yaml.safe_load(yaml_file)) + + assert input_dict, "No intent YAML files loaded" + intents = Intents.from_dict(input_dict) + + intents_and_texts = sample_intents( + intents, + slot_lists, + max_sentences_per_intent=args.max_sentences_per_intent, + intent_names=set(args.intents) if args.intents else None, + language=args.language, + ) + for intent_name, sentence_text in intents_and_texts: + json.dump( + {"intent": intent_name, "text": sentence_text.strip()}, + sys.stdout, + ensure_ascii=False, + ) + print("") + + +if __name__ == "__main__": + main() diff --git a/hassil/sample_template.py b/hassil/sample_template.py new file mode 100644 index 0000000..83d6f99 --- /dev/null +++ b/hassil/sample_template.py @@ -0,0 +1,31 @@ +"""CLI tool for sampling sentences from a template.""" + +import argparse +import logging + +from .parse_expression import parse_sentence +from .sample import sample_expression + +_LOGGER = logging.getLogger("hassil.sample_template") + + +def main(): + """Main entry point""" + parser = argparse.ArgumentParser() + parser.add_argument("sentence", help="Sentence template") + 
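The library equivalent of the CLI being defined here is two calls, parse_sentence plus sample_expression (a sketch; the template and the order of the outputs are illustrative):

from hassil import parse_sentence
from hassil.sample import sample_expression

for text in sample_expression(parse_sentence("turn (on|off) [the] light")):
    print(text)
# e.g. "turn on the light", "turn on light", "turn off the light", ...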
parser.add_argument( + "--debug", action="store_true", help="Print DEBUG messages to the console" + ) + args = parser.parse_args() + + level = logging.DEBUG if args.debug else logging.INFO + logging.basicConfig(level=level) + _LOGGER.debug(args) + + sentence = parse_sentence(args.sentence) + for text in sample_expression(sentence): + print(text) + + +if __name__ == "__main__": + main() diff --git a/hassil/string_matcher.py b/hassil/string_matcher.py new file mode 100644 index 0000000..06becd5 --- /dev/null +++ b/hassil/string_matcher.py @@ -0,0 +1,838 @@ +"""Original hassil matcher.""" + +import logging +import re +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +from unicode_rbnf import RbnfEngine + +from .errors import MissingListError, MissingRuleError +from .expression import ( + Expression, + ListReference, + RuleReference, + Sentence, + Sequence, + SequenceType, + TextChunk, +) +from .intents import IntentData, RangeSlotList, SlotList, TextSlotList, WildcardSlotList +from .models import ( + MatchEntity, + UnmatchedEntity, + UnmatchedRangeEntity, + UnmatchedTextEntity, +) +from .trie import Trie +from .util import ( + PUNCTUATION_ALL, + WHITESPACE, + check_excluded_context, + check_required_context, + match_first, + match_start, +) + +NUMBER_START = re.compile(r"^(\s*-?[0-9]+)") +NUMBER_ANYWHERE = re.compile(r"(\s*-?[0-9]+)") +BREAK_WORDS_TABLE = str.maketrans("-_", " ") + +# lang -> engine +_ENGINE_CACHE: Dict[str, RbnfEngine] = {} + +# lang -> number -> words +_RANGE_TRIE_CACHE: Dict[str, Dict[Tuple[int, int, int], Trie]] = defaultdict(dict) + +_LOGGER = logging.getLogger() + + +@dataclass +class MatchSettings: + """Settings used in match_expression.""" + + slot_lists: Dict[str, SlotList] = field(default_factory=dict) + """Available slot lists mapped by name.""" + + expansion_rules: Dict[str, Sentence] = field(default_factory=dict) + """Available expansion rules mapped by name.""" + + ignore_whitespace: bool = False + """True if whitespace should be ignored during matching.""" + + allow_unmatched_entities: bool = False + """True if unmatched entities are kept for better error messages (slower).""" + + language: Optional[str] = None + """Optional language to use when converting digits to words.""" + + +@dataclass +class MatchContext: + """Context passed to match_expression.""" + + text: str + """Input text remaining to be processed.""" + + entities: List[MatchEntity] = field(default_factory=list) + """Entities that have been found in input text.""" + + intent_context: Dict[str, Any] = field(default_factory=dict) + """Context items from outside or acquired during matching.""" + + is_start_of_word: bool = True + """True if current text is the start of a word.""" + + unmatched_entities: List[UnmatchedEntity] = field(default_factory=list) + """Entities that failed to match (requires allow_unmatched_entities=True).""" + + close_wildcards: bool = False + """True if open wildcards should be closed during init.""" + + close_unmatched: bool = False + """True if open unmatched entities should be closed during init.""" + + text_chunks_matched: int = 0 + """Number of literal text chunks that were matched.""" + + intent_sentence: Optional[Sentence] = None + """Sentence template that is being matched.""" + + intent_data: Optional[IntentData] = None + """Data from sentence template group in intents.""" + + def __post_init__(self): + if self.close_wildcards: + for entity in self.entities: + 
entity.is_wildcard_open = False + + if self.close_unmatched: + for unmatched_entity in self.unmatched_entities: + if isinstance(unmatched_entity, UnmatchedTextEntity): + unmatched_entity.is_open = False + + @property + def is_match(self) -> bool: + """True if no text is left that isn't just whitespace or punctuation""" + text = PUNCTUATION_ALL.sub("", self.text).strip() + if text: + return False + + # Wildcards cannot be empty + for entity in self.entities: + if entity.is_wildcard and (not entity.text.strip()): + return False + + # Unmatched entities cannot be empty + for unmatched_entity in self.unmatched_entities: + if isinstance(unmatched_entity, UnmatchedTextEntity) and ( + not unmatched_entity.text.strip() + ): + return False + + return True + + def get_open_wildcard(self) -> Optional[MatchEntity]: + """Get the last open wildcard or None.""" + if not self.entities: + return None + + last_entity = self.entities[-1] + if last_entity.is_wildcard and last_entity.is_wildcard_open: + return last_entity + + return None + + def get_open_entity(self) -> Optional[UnmatchedTextEntity]: + """Get the last open unmatched text entity or None.""" + if not self.unmatched_entities: + return None + + last_entity = self.unmatched_entities[-1] + if isinstance(last_entity, UnmatchedTextEntity) and last_entity.is_open: + return last_entity + + return None + + +def match_expression( + settings: MatchSettings, context: MatchContext, expression: Expression +) -> Iterable[MatchContext]: + """Yield matching contexts for an expression""" + if isinstance(expression, TextChunk): + chunk: TextChunk = expression + + if settings.ignore_whitespace: + # Remove all whitespace + chunk_text = WHITESPACE.sub("", chunk.text) + context_text = WHITESPACE.sub("", context.text) + else: + # Keep whitespace + chunk_text = chunk.text + context_text = context.text + + if context.is_start_of_word: + # Ignore extra whitespace at the beginning of chunk and text + # since we know we're at the start of a word. + chunk_text = chunk_text.lstrip() + context_text = context_text.lstrip() + + # True if remaining text to be matched is empty or whitespace. + # + # If so, we can't say this is a successful match yet because the + # sentence template may have remaining non-optional expressions. + # + # So we have to continue matching, skipping over empty or whitespace + # chunks until the template is exhausted. + is_context_text_empty = len(context_text.strip()) == 0 + + if chunk.is_empty: + # Skip empty chunk (NOT whitespace) + yield context + else: + wildcard = context.get_open_wildcard() + if (wildcard is not None) and (not wildcard.text.strip()): + if not chunk_text.strip(): + # Skip space + yield MatchContext( + text=context_text, + is_start_of_word=True, + # Copy over + entities=context.entities, + intent_context=context.intent_context, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + ) + return + + # Wildcard cannot be empty + start_idx = match_first(context_text, chunk_text) + if start_idx < 0: + # Cannot possibly match + return + + if start_idx == 0: + # Possible degenerate case where the next word in the + # template duplicates. + start_idx = match_first(context_text, chunk_text, 1) + if start_idx < 0: + # Cannot possibly match + return + + # Produce all possible matches where the wildcard consumes text + # up to where the chunk matches in the string. 
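In miniature, what the loop below does: matching "play {title} on tv" against "play day and night on tv" must try every occurrence of the next literal chunk (" on ") as the wildcard boundary, producing one candidate context per split (a sketch using match_first from hassil/util.py; the strings are invented):

from hassil.util import match_first

text = "day and night on tv "
start_idx = match_first(text, " on ")
while start_idx > 0:
    print(repr(text[:start_idx]))  # candidate wildcard text: 'day and night'
    start_idx = match_first(text, " on ", start_idx + 1)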
+ entities_without_wildcard = context.entities[:-1] + while start_idx > 0: + wildcard_text = context_text[:start_idx] + yield from match_expression( + settings, + MatchContext( + text=context_text[start_idx:], + is_start_of_word=True, + entities=entities_without_wildcard + + [ + MatchEntity( + name=wildcard.name, + text=wildcard_text, + value=wildcard_text, + is_wildcard=True, + is_wildcard_open=False, # always close + ) + ], + # Copy over + intent_context=context.intent_context, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + ), + expression, + ) + start_idx = match_first(context_text, chunk_text, start_idx + 1) + + # Do not continue with matching + return + + end_pos = match_start(context_text, chunk_text) + if end_pos is not None: + # Successful match for chunk + context_text = context_text[end_pos:] + + # Close wildcards/unmatched entities on non-empty chunk + chunk_text_stripped = chunk_text.strip() + is_chunk_non_empty = len(chunk_text_stripped) > 0 + + text_chunks_matched = context.text_chunks_matched + if is_chunk_non_empty: + text_chunks_matched += len(chunk_text_stripped) + + yield MatchContext( + text=context_text, + # must use chunk.text because it hasn't been stripped + is_start_of_word=chunk.text.endswith(" "), + text_chunks_matched=text_chunks_matched, + # Copy over + entities=context.entities, + intent_context=context.intent_context, + unmatched_entities=context.unmatched_entities, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + close_wildcards=is_chunk_non_empty, + close_unmatched=is_chunk_non_empty, + ) + elif is_context_text_empty and chunk_text.isspace(): + # No text left to match, so extra whitespace is OK to skip + yield context + else: + # Try breaking words apart + context_text = context_text.translate(BREAK_WORDS_TABLE) + end_pos = match_start(context_text, chunk_text) + + if end_pos is not None: + context_text = context_text[end_pos:] + + # Close wildcards/unmatched entities on non-empty chunk + is_chunk_non_empty = len(chunk_text.strip()) > 0 + + yield MatchContext( + text=context_text, + # Copy over + entities=context.entities, + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + close_wildcards=is_chunk_non_empty, + close_unmatched=is_chunk_non_empty, + ) + elif wildcard is not None: + # Add to wildcard by skipping ahead in the text until we find + # the current chunk text. 
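Both skip-ahead paths below lean on two small helpers defined in hassil/util.py at the end of this patch; their contracts in a sketch (example strings invented):

from hassil.util import match_first, match_start

# match_start: case-insensitive prefix match, returning the end position.
assert match_start("Turn ON the light ", "turn on ") == 8
assert match_start("the light ", "turn ") is None
# match_first: index of the first occurrence at or after start_idx, else -1.
assert match_first("all lights on now ", " on ") == 10
assert match_first("all lights on now ", " off ") == -1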
+ skip_idx = match_first(context_text, chunk_text) + if skip_idx >= 0: + wildcard_text = context_text[:skip_idx] + + # Wildcards cannot be empty + if wildcard_text: + entities = [ + e for e in context.entities if e.name != wildcard.name + ] + entities.append( + MatchEntity( + name=wildcard.name, + value=wildcard_text, + text=wildcard_text, + is_wildcard=True, + is_wildcard_open=False, # always close + ) + ) + yield MatchContext( + text=context.text[skip_idx + len(chunk_text) :], + # Copy over + # entities=context.entities, + intent_context=context.intent_context, + is_start_of_word=True, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + entities=entities, + ) + elif settings.allow_unmatched_entities and ( + unmatched_entity := context.get_open_entity() + ): + # Add to the most recent unmatched entity by skipping ahead in + # the text until we find the current chunk text. + re_chunk_text = re.escape(chunk_text.strip()) + if settings.ignore_whitespace: + chunk_match = re.search(re_chunk_text, context_text) + else: + # Only skip to a word boundary + chunk_match = re.search( + rf"\s{re_chunk_text}(\s|$)", context_text + ) + + if chunk_match: + unmatched_entity_text = ( + unmatched_entity.text + + context_text[: chunk_match.start() + 1] + ) + + # Unmatched entities cannot be empty + if unmatched_entity_text: + # Make a copy of modified unmatched entity + unmatched_entities = [ + e + for e in context.unmatched_entities + if e.name != unmatched_entity.name + ] + unmatched_entities.append( + UnmatchedTextEntity( + name=unmatched_entity.name, + text=unmatched_entity_text, + is_open=False, # always close + ) + ) + + yield MatchContext( + text=context.text[chunk_match.end() :], + # Copy over + entities=context.entities, + intent_context=context.intent_context, + is_start_of_word=True, + text_chunks_matched=context.text_chunks_matched + + len(chunk.text.strip()), + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + unmatched_entities=unmatched_entities, + ) + else: + # Match failed + pass + elif isinstance(expression, Sequence): + seq: Sequence = expression + if seq.type == SequenceType.ALTERNATIVE: + # Any may match (words | in | alternative) + # NOTE: [optional] = (optional | ) + for item in seq.items: + yield from match_expression(settings, context, item) + + elif seq.type == SequenceType.GROUP: + if seq.items: + # All must match (words in group) + group_contexts = [context] + for item in seq.items: + # Next step + group_contexts = [ + item_context + for group_context in group_contexts + for item_context in match_expression( + settings, group_context, item + ) + ] + if not group_contexts: + break + + yield from group_contexts + else: + raise ValueError(f"Unexpected sequence type: {seq}") + + elif isinstance(expression, ListReference): + # {list} + list_ref: ListReference = expression + if (not settings.slot_lists) or (list_ref.list_name not in settings.slot_lists): + raise MissingListError(f"Missing slot list {{{list_ref.list_name}}}") + + wildcard = context.get_open_wildcard() + slot_list = settings.slot_lists[list_ref.list_name] + if isinstance(slot_list, TextSlotList): + if context.text: + text_list: TextSlotList = slot_list + # Any value may match + has_matches = False + + required_context: Optional[Dict[str, Any]] = None + excluded_context: Optional[Dict[str, Any]] = None + if context.intent_data is not None: + required_context = 
context.intent_data.requires_context + excluded_context = context.intent_data.excludes_context + + for slot_value in text_list.values: + # Filter possible values with required/excluded context + if required_context and ( + not check_required_context( + required_context, + slot_value.context, + allow_missing_keys=True, + ) + ): + continue + + if excluded_context and ( + not check_excluded_context(excluded_context, slot_value.context) + ): + continue + + if (isinstance(slot_value.text_in, TextChunk)) and ( + len(context.text) < len(slot_value.text_in.text) + ): + # Not enough text left to match + continue + + value_contexts = match_expression( + settings, + MatchContext( + # Copy over + text=context.text, + entities=context.entities, + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + ), + slot_value.text_in, + ) + + for value_context in value_contexts: + has_matches = True + value_wildcard: Optional[MatchEntity] = None + if ( + value_context.entities + and value_context.entities[-1].is_wildcard + ): + value_wildcard = value_context.entities[-1] + + if value_wildcard is not None and context.text.startswith( + value_wildcard.text + ): + # Remove wildcard text from value + remaining_text = context.text[len(value_wildcard.text) :] + else: + remaining_text = context.text + + entities = value_context.entities + [ + MatchEntity( + name=list_ref.slot_name, + value=slot_value.value_out, + text=( + remaining_text[: -len(value_context.text)] + if value_context.text + else remaining_text + ), + metadata=slot_value.metadata, + ) + ] + + if slot_value.context: + # Merge context from matched list value + yield MatchContext( + entities=entities, + intent_context={ + **context.intent_context, + **slot_value.context, + }, + # Copy over + text=value_context.text, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + ) + else: + yield MatchContext( + entities=entities, + # Copy over + text=value_context.text, + intent_context=value_context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + ) + + if (not has_matches) and settings.allow_unmatched_entities: + # Report mismatch + yield MatchContext( + # Copy over + text=context.text, + entities=context.entities, + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + unmatched_entities=context.unmatched_entities + + [UnmatchedTextEntity(name=list_ref.slot_name, text="")], + close_wildcards=True, + ) + + elif isinstance(slot_list, RangeSlotList): + if context.text: + # List that represents a number range. + range_list: RangeSlotList = slot_list + + number_matches: List[re.Match] = [] + if wildcard is None: + # Look for digits at the start of the incoming text + number_match = NUMBER_START.match(context.text) + if number_match is not None: + number_matches.append(number_match) + else: + # Look for digit(s) anywhere in the string. 
+ # The wildcard will consume text up to that point. + number_matches.extend(NUMBER_ANYWHERE.finditer(context.text)) + + digits_match = False + if range_list.digits and number_matches: + for number_match in number_matches: + number_text = number_match[1] + word_number: Union[int, float] = int(number_text) + + # Check if number is within range of our list + if range_list.step == 1: + # Unit step + in_range = ( + range_list.start <= word_number <= range_list.stop + ) + else: + # Non-unit step + in_range = word_number in range( + range_list.start, range_list.stop + 1, range_list.step + ) + + if in_range: + # Number is in range + digits_match = True + range_value = word_number + if range_list.multiplier is not None: + range_value *= range_list.multiplier + + entities = context.entities + [ + MatchEntity( + name=list_ref.slot_name, + value=range_value, + text=number_match.group(1), + ) + ] + + if wildcard is None: + yield MatchContext( + text=context.text[number_match.end() :], + entities=entities, + # Copy over + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + ) + else: + # Wildcard consumes text before number + wildcard.text += context.text[: number_match.end() - 1] + wildcard.value = wildcard.text + yield MatchContext( + text=context.text[number_match.end() :], + entities=entities, + # Copy over + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + close_wildcards=True, + ) + elif settings.allow_unmatched_entities and (wildcard is None): + # Report out of range + yield MatchContext( + # Copy over + text=context.text[len(number_text) :], + entities=context.entities, + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + unmatched_entities=context.unmatched_entities + + [ + UnmatchedRangeEntity( + name=list_ref.slot_name, value=word_number + ) + ], + ) + + # Only check number words if: + # 1. Words are enabled for this list + # 2. We didn't already match digits + # 3. the incoming text doesn't start with digits + words_match: bool = False + if range_list.words and (not digits_match) and (not number_matches): + words_language = range_list.words_language or settings.language + if words_language: + range_settings = ( + range_list.start, + range_list.stop, + range_list.step, + ) + range_trie = _RANGE_TRIE_CACHE[words_language].get( + range_settings + ) + try: + if range_trie is None: + range_trie = _build_range_trie( + words_language, range_list + ) + _RANGE_TRIE_CACHE[words_language][ + range_settings + ] = range_trie + + for ( + number_end_pos, + number_text, + range_value, + ) in range_trie.find(context.text): + number_start_pos = number_end_pos - len(number_text) + if (wildcard is None) and (number_start_pos > 0): + # Can't possibly match because the number + # string isn't at the start of the text. 
+ continue + + entities = context.entities + [ + MatchEntity( + name=list_ref.slot_name, + value=range_value, + text=number_text, + ) + ] + if wildcard is None: + yield from match_expression( + settings, + MatchContext( + text=context.text, + entities=entities, + # Copy over + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + ), + TextChunk(number_text), + ) + else: + # Wildcard consumes text before number + wildcard.text += context.text[:number_start_pos] + wildcard.value = wildcard.text + yield from match_expression( + settings, + MatchContext( + text=context.text[number_start_pos:], + entities=entities, + # Copy over + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + close_wildcards=True, + ), + TextChunk(number_text), + ) + except ValueError as error: + _LOGGER.debug( + "Unexpected error converting numbers to words for language '%s': %s", + settings.language, + str(error), + ) + + if ( + (not digits_match) + and (not words_match) + and settings.allow_unmatched_entities + ): + # Report not a number + yield MatchContext( + # Copy over + text=context.text, + entities=context.entities, + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + unmatched_entities=context.unmatched_entities + + [UnmatchedTextEntity(name=list_ref.slot_name, text="")], + close_wildcards=True, + ) + elif isinstance(slot_list, WildcardSlotList): + if context.text: + # Start wildcard entities + yield MatchContext( + # Copy over + text=context.text, + intent_context=context.intent_context, + is_start_of_word=context.is_start_of_word, + unmatched_entities=context.unmatched_entities, + text_chunks_matched=context.text_chunks_matched, + intent_sentence=context.intent_sentence, + intent_data=context.intent_data, + # + entities=context.entities + + [ + MatchEntity( + name=list_ref.slot_name, value="", text="", is_wildcard=True + ) + ], + close_unmatched=True, + ) + else: + raise ValueError(f"Unexpected slot list type: {slot_list}") + + elif isinstance(expression, RuleReference): + # + rule_ref: RuleReference = expression + if (not settings.expansion_rules) or ( + rule_ref.rule_name not in settings.expansion_rules + ): + raise MissingRuleError(f"Missing expansion rule <{rule_ref.rule_name}>") + + yield from match_expression( + settings, context, settings.expansion_rules[rule_ref.rule_name] + ) + else: + raise ValueError(f"Unexpected expression: {expression}") + + +def _build_range_trie(language: str, range_list: RangeSlotList) -> Trie: + range_trie = Trie() + + # Load number formatting engine + engine = _ENGINE_CACHE.get(language) + if engine is None: + engine = RbnfEngine.for_language(language) + _ENGINE_CACHE[language] = engine + + for word_number in range(range_list.start, range_list.stop + 1, range_list.step): + range_value: Union[float, int] = word_number + if range_list.multiplier is not None: + range_value *= range_list.multiplier + + format_result = engine.format_number(word_number) + used_words = set() + + for words in 
format_result.text_by_ruleset.values(): + if words in used_words: + continue + + range_trie.insert(words, range_value) + used_words.add(words) + + words = words.translate(BREAK_WORDS_TABLE) + if words in used_words: + continue + + range_trie.insert(words, range_value) + used_words.add(words) + + return range_trie diff --git a/hassil/trie.py b/hassil/trie.py new file mode 100644 index 0000000..845fe5b --- /dev/null +++ b/hassil/trie.py @@ -0,0 +1,87 @@ +"""Specialized implementation of a trie. + +See: https://en.wikipedia.org/wiki/Trie +""" + +from collections import deque +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Tuple + + +@dataclass +class TrieNode: + """Node in trie.""" + + id: int + text: Optional[str] = None + values: Optional[List[Any]] = None + children: "Optional[Dict[str, TrieNode]]" = None + + +class Trie: + """A specialized trie data structure that finds all known words in a string.""" + + def __init__(self) -> None: + self.roots: Dict[str, TrieNode] = {} + self._next_id = 0 + + def insert(self, text: str, value: Any) -> None: + """Insert a word and value into the trie.""" + current_node: Optional[TrieNode] = None + current_children: Optional[Dict[str, TrieNode]] = self.roots + + last_idx = len(text) - 1 + for i, c in enumerate(text): + if current_children is None: + assert current_node is not None + current_node.children = current_children = {} + + current_node = current_children.get(c) + if current_node is None: + current_node = TrieNode(id=self.next_id()) + current_children[c] = current_node + + if i == last_idx: + current_node.text = text + if current_node.values is None: + current_node.values = [value] + else: + current_node.values.append(value) + + current_children = current_node.children + + def find(self, text: str, unique: bool = True) -> Iterable[Tuple[int, str, Any]]: + """Yield (end_pos, text, value) pairs of all words found in the string.""" + q = deque([(self.roots, i) for i in range(len(text))]) + visited = set() + + while q: + item = q.popleft() + current_children, current_position = item + if current_position >= len(text): + continue + + current_char = text[current_position] + + node = current_children.get(current_char) + if (node is not None) and (node.id not in visited): + + if node.text is not None: + # End is one past the current position + if unique: + visited.add(node.id) + + if node.values: + for value in node.values: + yield (current_position + 1, node.text, value) + else: + # null value + yield (current_position + 1, node.text, None) + + if node.children and (current_position < len(text)): + q.append((node.children, current_position + 1)) + + def next_id(self) -> int: + current_id = self._next_id + self._next_id += 1 + return current_id diff --git a/hassil/util.py b/hassil/util.py new file mode 100644 index 0000000..6b60d12 --- /dev/null +++ b/hassil/util.py @@ -0,0 +1,216 @@ +"""Utility methods""" + +import collections +import re +import unicodedata +from typing import Any, Dict, Iterable, Optional + +WHITESPACE = re.compile(r"\s+") +WHITESPACE_CAPTURE = re.compile(r"(\s+)") +WHITESPACE_SEPARATOR = " " + +TEMPLATE_SYNTAX = re.compile(r".*[(){}<>\[\]|].*") + +PUNCTUATION_STR = ".。,,?¿?؟!¡!;;::’" +PUNCTUATION_PATTERN = rf"[{re.escape(PUNCTUATION_STR)}]+" +PUNCTUATION_ALL = re.compile(rf"{PUNCTUATION_PATTERN}") +PUNCTUATION_START = re.compile(rf"^{PUNCTUATION_PATTERN}") +PUNCTUATION_END = re.compile(rf"{PUNCTUATION_PATTERN}$") +PUNCTUATION_END_SPACE = re.compile(rf"{PUNCTUATION_PATTERN}\s*$") 
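Stepping back to hassil/trie.py just above: find() scans every start position and yields (end_pos, text, value) for each inserted word it can complete, which is how number words are located mid-sentence (a quick sketch; the inserted words mirror string_matcher's usage):

from hassil.trie import Trie

trie = Trie()
trie.insert("twenty", 20)
trie.insert("twenty two", 22)
for end_pos, number_text, value in trie.find("set it to twenty two "):
    print(end_pos, repr(number_text), value)
# -> 16 'twenty' 20
# -> 20 'twenty two' 22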
+PUNCTUATION_START_WORD = re.compile(rf"(?<=\W){PUNCTUATION_PATTERN}(?=\w)") +PUNCTUATION_END_WORD = re.compile(rf"(?<=\w){PUNCTUATION_PATTERN}(?=\W)") +PUNCTUATION_WORD = re.compile(rf"(?<=\W){PUNCTUATION_PATTERN}(?=\W)") + + +def merge_dict(base_dict, new_dict): + """Merges new_dict into base_dict.""" + for key, value in new_dict.items(): + if key in base_dict: + old_value = base_dict[key] + if isinstance(old_value, collections.abc.MutableMapping): + # Combine dictionary + assert isinstance( + value, collections.abc.Mapping + ), f"Not a dict: {value}" + merge_dict(old_value, value) + elif isinstance(old_value, collections.abc.MutableSequence): + # Combine list + assert isinstance( + value, collections.abc.Sequence + ), f"Not a list: {value}" + old_value.extend(value) + else: + # Overwrite + base_dict[key] = value + else: + base_dict[key] = value + + +def remove_escapes(text: str) -> str: + """Remove backslash escape sequences.""" + return re.sub(r"\\(.)", r"\1", text) + + +def normalize_whitespace(text: str) -> str: + """Makes all whitespace inside a string single spaced.""" + return WHITESPACE_CAPTURE.sub(WHITESPACE_SEPARATOR, text) + + +def normalize_text(text: str) -> str: + """Normalize whitespace and unicode forms.""" + text = normalize_whitespace(text) + text = unicodedata.normalize("NFC", text) + + return text + + +def is_template(text: str) -> bool: + """True if text contains template syntax""" + return TEMPLATE_SYNTAX.match(text) is not None + + +def check_required_context( + required_context: Dict[str, Any], + match_context: Optional[Dict[str, Any]], + allow_missing_keys: bool = False, +) -> bool: + """Return True if match context does not violate required context. + + Setting allow_missing_keys to True only checks existing keys in match + context. + """ + for ( + required_key, + required_value, + ) in required_context.items(): + if (not match_context) or (required_key not in match_context): + # Match is missing key + if allow_missing_keys: + # Only checking existing keys + continue + + return False + + if isinstance(required_value, collections.abc.Mapping): + # Unpack dict + # <context_key>: + # value: ... + required_value = required_value.get("value") + + # Ensure value matches + actual_value = match_context[required_key] + + if isinstance(actual_value, collections.abc.Mapping): + # Unpack dict + # <context_key>: + # value: ... + actual_value = actual_value.get("value") + + if (not isinstance(required_value, str)) and isinstance( + required_value, collections.abc.Collection + ): + if actual_value not in required_value: + # Match value not in required list + return False + elif (required_value is not None) and (actual_value != required_value): + # Match value doesn't equal required value + return False + + return True + + +def check_excluded_context( + excluded_context: Dict[str, Any], match_context: Optional[Dict[str, Any]] +) -> bool: + """Return True if match context does not violate excluded context.""" + for ( + excluded_key, + excluded_value, + ) in excluded_context.items(): + if (not match_context) or (excluded_key not in match_context): + continue + + if isinstance(excluded_value, collections.abc.Mapping): + # Unpack dict + # <context_key>: + # value: ... + excluded_value = excluded_value.get("value") + + # Ensure value does not match + actual_value = match_context[excluded_key] + + if isinstance(actual_value, collections.abc.Mapping): + # Unpack dict + # <context_key>: + # value: ... 
+ actual_value = actual_value.get("value") + + if (not isinstance(excluded_value, str)) and isinstance( + excluded_value, collections.abc.Collection + ): + if actual_value in excluded_value: + # Match value is in excluded list + return False + elif actual_value == excluded_value: + # Match value equals excluded value + return False + + return True + + +def remove_skip_words( + text: str, skip_words: Iterable[str], ignore_whitespace: bool +) -> str: + if not skip_words: + return text + + if ignore_whitespace: + skip_words_pattern = re.compile( + r"(" + + "|".join( + re.escape(w.strip()) for w in sorted(skip_words, key=len, reverse=True) + ) + + r")", + re.IGNORECASE, + ) + return skip_words_pattern.sub("", text) + + skip_words_pattern = re.compile( + r"(?<=\W)(" + + "|".join( + re.escape(w.strip()) for w in sorted(skip_words, key=len, reverse=True) + ) + + r")(?=\W)", + re.IGNORECASE, + ) + text = skip_words_pattern.sub(" ", f" {text} ").strip() + return normalize_whitespace(text) + + +def remove_punctuation(text: str) -> str: + text = PUNCTUATION_START.sub("", text) + text = PUNCTUATION_END.sub("", text) + text = PUNCTUATION_START_WORD.sub("", text) + text = PUNCTUATION_END_WORD.sub("", text) + text = PUNCTUATION_WORD.sub("", text) + + return text + + +def match_start(text: str, prefix: str) -> Optional[int]: + match = re.match(rf"^{re.escape(prefix)}", text, re.IGNORECASE) + if match is None: + return None + + return match.end() + + +def match_first(text: str, prefix: str, start_idx: int = 0) -> int: + if start_idx > 0: + text = text[start_idx:] + + match = re.search(rf"{re.escape(prefix)}", text, re.IGNORECASE) + if match is None: + return -1 + + return start_idx + match.start() diff --git a/requirements.txt b/requirements.txt index 81ea19f..96320f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ wyoming==1.5.4 -hassil~=2.0.0 unicode-rbnf>=2,<3 regex==2024.11.6 Flask[async]~=3.1.0
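Finally, an end-to-end sketch of the best-result selection implemented by recognize_best in hassil/recognize.py above (the SetTimer intent and its range list are invented for illustration; the YAML keys follow the format parsed by hassil/intents.py):

from hassil import Intents, recognize_best

intents = Intents.from_dict(
    {
        "language": "en",
        "intents": {
            "SetTimer": {
                "data": [{"sentences": ["(set|start) a timer for {minutes} minute[s]"]}]
            }
        },
        "lists": {"minutes": {"range": {"from": 1, "to": 120}}},
    }
)

result = recognize_best("start a timer for 5 minutes", intents, best_slot_name="minutes")
if result is not None:
    print(result.intent.name, result.entities["minutes"].value)  # SetTimer 5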