From d87ef1f6452faf50f627e22d6b2c9b8a366c86d4 Mon Sep 17 00:00:00 2001 From: GiviMAD Date: Wed, 13 Sep 2023 21:09:13 +0200 Subject: [PATCH] [voice] Add dialog group and location (#3798) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Miguel Álvarez --- .../org/openhab/core/voice/DialogContext.java | 94 +++---------- .../core/voice/DialogRegistration.java | 8 ++ .../core/voice/internal/DialogProcessor.java | 54 +++++--- .../VoiceConsoleCommandExtension.java | 33 +++-- .../core/voice/internal/VoiceManagerImpl.java | 11 +- .../text/AbstractRuleBasedInterpreter.java | 125 +++++++++++++----- .../voice/text/HumanLanguageInterpreter.java | 14 ++ .../org/openhab/core/voice/text/Rule.java | 9 +- .../text/StandardInterpreterTest.java | 28 ++++ 9 files changed, 231 insertions(+), 145 deletions(-) diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogContext.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogContext.java index 03e0f3311e9..348e03bf92e 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogContext.java +++ b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogContext.java @@ -29,78 +29,10 @@ * @author Miguel Álvarez - Initial contribution */ @NonNullByDefault -public class DialogContext { - private final @Nullable KSService ks; - private final @Nullable String keyword; - private final STTService stt; - private final TTSService tts; - private final @Nullable Voice voice; - private final List hlis; - private final AudioSource source; - private final AudioSink sink; - private final Locale locale; - private final @Nullable String listeningItem; - private final @Nullable String listeningMelody; - - public DialogContext(@Nullable KSService ks, @Nullable String keyword, STTService stt, TTSService tts, - @Nullable Voice voice, List hlis, AudioSource source, AudioSink sink, - Locale locale, 
@Nullable String listeningItem, @Nullable String listeningMelody) { - this.ks = ks; - this.keyword = keyword; - this.stt = stt; - this.tts = tts; - this.voice = voice; - this.hlis = hlis; - this.source = source; - this.sink = sink; - this.locale = locale; - this.listeningItem = listeningItem; - this.listeningMelody = listeningMelody; - } - - public @Nullable KSService ks() { - return ks; - } - - public @Nullable String keyword() { - return keyword; - } - - public STTService stt() { - return stt; - } - - public TTSService tts() { - return tts; - } - - public @Nullable Voice voice() { - return voice; - } - - public List hlis() { - return hlis; - } - - public AudioSource source() { - return source; - } - - public AudioSink sink() { - return sink; - } - - public Locale locale() { - return locale; - } - - public @Nullable String listeningItem() { - return listeningItem; - } - - public @Nullable String listeningMelody() { - return listeningMelody; - } +public record DialogContext(@Nullable KSService ks, @Nullable String keyword, STTService stt, TTSService tts, + @Nullable Voice voice, List hlis, AudioSource source, AudioSink sink, Locale locale, + String dialogGroup, @Nullable String locationItem, @Nullable String listeningItem, + @Nullable String listeningMelody) { /** * Builder for {@link DialogContext} @@ -116,6 +48,8 @@ public static class Builder { private @Nullable Voice voice; private List hlis = List.of(); // options + private String dialogGroup = "default"; + private @Nullable String locationItem; private @Nullable String listeningItem; private @Nullable String listeningMelody; private String keyword; @@ -189,6 +123,20 @@ public Builder withVoice(@Nullable Voice voice) { return this; } + public Builder withDialogGroup(@Nullable String dialogGroup) { + if (dialogGroup != null) { + this.dialogGroup = dialogGroup; + } + return this; + } + + public Builder withLocationItem(@Nullable String locationItem) { + if (locationItem != null) { + this.locationItem = 
locationItem; + } + return this; + } + public Builder withListeningItem(@Nullable String listeningItem) { if (listeningItem != null) { this.listeningItem = listeningItem; @@ -244,7 +192,7 @@ public DialogContext build() throws IllegalStateException { throw new IllegalStateException("Cannot build dialog context: " + String.join(", ", errors) + "."); } else { return new DialogContext(ksService, keyword, sttService, ttsService, voice, hliServices, audioSource, - audioSink, locale, listeningItem, listeningMelody); + audioSink, locale, dialogGroup, locationItem, listeningItem, listeningMelody); } } } diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogRegistration.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogRegistration.java index 12421a7e480..655026135d8 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogRegistration.java +++ b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/DialogRegistration.java @@ -65,6 +65,14 @@ public class DialogRegistration { * Linked listening item */ public @Nullable String listeningItem; + /** + * Linked location item + */ + public @Nullable String locationItem; + /** + * Dialog group name + */ + public @Nullable String dialogGroup; /** * Custom listening melody */ diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/DialogProcessor.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/DialogProcessor.java index f23b962140f..d171d989cd5 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/DialogProcessor.java +++ b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/DialogProcessor.java @@ -17,6 +17,7 @@ import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.WeakHashMap; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -69,13 
+70,13 @@ * @author Miguel Álvarez - Close audio streams + use RecognitionStartEvent * @author Miguel Álvarez - Use dialog context * @author Miguel Álvarez - Add sounds + * @author Miguel Álvarez - Add dialog groups * */ @NonNullByDefault public class DialogProcessor implements KSListener, STTListener { - private final Logger logger = LoggerFactory.getLogger(DialogProcessor.class); - + private final WeakHashMap activeDialogGroups; public final DialogContext dialogContext; private @Nullable List listeningMelody; private final EventPublisher eventPublisher; @@ -105,11 +106,12 @@ public class DialogProcessor implements KSListener, STTListener { private @Nullable ToneSynthesizer toneSynthesizer; public DialogProcessor(DialogContext context, DialogEventListener eventListener, EventPublisher eventPublisher, - TranslationProvider i18nProvider, Bundle bundle) { + WeakHashMap activeDialogGroups, TranslationProvider i18nProvider, Bundle bundle) { this.dialogContext = context; this.eventListener = eventListener; this.eventPublisher = eventPublisher; this.i18nProvider = i18nProvider; + this.activeDialogGroups = activeDialogGroups; this.bundle = bundle; var ks = context.ks(); this.ksFormat = ks != null @@ -182,7 +184,15 @@ public void start() throws IllegalStateException { * Starts a single dialog */ public void startSimpleDialog() { - abortSTT(); + synchronized (activeDialogGroups) { + if (!activeDialogGroups.containsKey(dialogContext.dialogGroup())) { + logger.debug("Acquiring dialog group '{}'", dialogContext.dialogGroup()); + activeDialogGroups.put(dialogContext.dialogGroup(), dialogContext); + } else { + logger.warn("Ignoring keyword spotting event, dialog group '{}' running", dialogContext.dialogGroup()); + return; + } + } closeStreamSTT(); isSTTServerAborting = false; AudioFormat fmt = sttFormat; @@ -196,6 +206,7 @@ public void startSimpleDialog() { AudioStream stream = dialogContext.source().getInputStream(fmt); streamSTT = stream; sttServiceHandle = 
dialogContext.stt().recognize(this, stream, dialogContext.locale(), new HashSet<>()); + return; } catch (AudioException e) { logger.warn("Error creating the audio stream: {}", e.getMessage()); } catch (STTException e) { @@ -208,6 +219,11 @@ public void startSimpleDialog() { say(text.replace("{0}", "")); } } + // In case of error release dialog group + synchronized (activeDialogGroups) { + logger.debug("Releasing dialog group '{}' due to errors", dialogContext.dialogGroup()); + activeDialogGroups.remove(dialogContext.dialogGroup()); + } } /** @@ -264,6 +280,10 @@ private void abortSTT() { sttServiceHandle = null; } isSTTServerAborting = true; + synchronized (activeDialogGroups) { + logger.debug("Releasing dialog group '{}'", dialogContext.dialogGroup()); + activeDialogGroups.remove(dialogContext.dialogGroup()); + } } private void closeStreamSTT() { @@ -292,20 +312,18 @@ private void toggleProcessing(boolean value) { @Override public void ksEventReceived(KSEvent ksEvent) { - if (!processing) { - isSTTServerAborting = false; - if (ksEvent instanceof KSpottedEvent) { - logger.debug("KSpottedEvent event received"); - try { - startSimpleDialog(); - } catch (IllegalStateException e) { - logger.warn("{}", e.getMessage()); - } - } else if (ksEvent instanceof KSErrorEvent kse) { - logger.debug("KSErrorEvent event received"); - String text = i18nProvider.getText(bundle, "error.ks-error", null, dialogContext.locale()); - say(text == null ? kse.getMessage() : text.replace("{0}", kse.getMessage())); + isSTTServerAborting = false; + if (ksEvent instanceof KSpottedEvent) { + logger.debug("KSpottedEvent event received"); + try { + startSimpleDialog(); + } catch (IllegalStateException e) { + logger.warn("{}", e.getMessage()); } + } else if (ksEvent instanceof KSErrorEvent kse) { + logger.debug("KSErrorEvent event received"); + String text = i18nProvider.getText(bundle, "error.ks-error", null, dialogContext.locale()); + say(text == null ? 
kse.getMessage() : text.replace("{0}", kse.getMessage())); } } @@ -322,7 +340,7 @@ public synchronized void sttEventReceived(STTEvent sttEvent) { String error = null; for (HumanLanguageInterpreter interpreter : dialogContext.hlis()) { try { - answer = interpreter.interpret(dialogContext.locale(), question); + answer = interpreter.interpret(dialogContext.locale(), question, dialogContext); logger.debug("Interpretation result: {}", answer); error = null; break; diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceConsoleCommandExtension.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceConsoleCommandExtension.java index c4eb1ff0941..7817952d31c 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceConsoleCommandExtension.java +++ b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceConsoleCommandExtension.java @@ -97,17 +97,17 @@ public List getUsages() { buildCommandUsage(SUBCMD_DIALOG_REGS, "lists the existing dialog registrations and their selected audio/voice services"), buildCommandUsage(SUBCMD_REGISTER_DIALOG - + " [--source ] [--sink ] [--hlis ] [--tts [--voice ]] [--stt ] [--ks ks [--keyword ]] [--listening-item ]", + + " [--source ] [--sink ] [--hlis ] [--tts [--voice ]] [--stt ] [--ks ks [--keyword ]] [--listening-item ] [--location-item ] [--dialog-group ]", "register a new dialog processing using the default services or the services identified with provided arguments, it will be persisted and keep running whenever is possible."), buildCommandUsage(SUBCMD_UNREGISTER_DIALOG + " [source]", "unregister the dialog processing for the default audio source or the audio source identified with provided argument, stopping it if started"), buildCommandUsage(SUBCMD_START_DIALOG - + " [--source ] [--sink ] [--hlis ] [--tts [--voice ]] [--stt ] [--ks ks [--keyword ]] [--listening-item ]", + + " [--source ] [--sink ] 
[--hlis ] [--tts [--voice ]] [--stt ] [--ks ks [--keyword ]] [--listening-item ] [--location-item ] [--dialog-group ]", "start a new dialog processing using the default services or the services identified with provided arguments"), buildCommandUsage(SUBCMD_STOP_DIALOG + " []", "stop the dialog processing for the default audio source or the audio source identified with provided argument"), buildCommandUsage(SUBCMD_LISTEN_ANSWER - + " [--source ] [--sink ] [--hlis ] [--tts [--voice ]] [--stt ] [--listening-item ]", + + " [--source ] [--sink ] [--hlis ] [--tts [--voice ]] [--stt ] [--listening-item ] [--location-item ] [--dialog-group ]", "Execute a simple dialog sequence without keyword spotting using the default services or the services identified with provided arguments"), buildCommandUsage(SUBCMD_INTERPRETERS, "lists the interpreters"), buildCommandUsage(SUBCMD_KEYWORD_SPOTTERS, "lists the keyword spotters"), @@ -309,11 +309,12 @@ private void listDialogRegistrations(Console console) { Collection registrations = voiceManager.getDialogRegistrations(); if (!registrations.isEmpty()) { registrations.stream().sorted(comparing(dr -> dr.sourceId)).forEach(dr -> { - console.println( - String.format(" Source: %s - Sink: %s (STT: %s, TTS: %s, HLIs: %s, KS: %s, Keyword: %s)", - dr.sourceId, dr.sinkId, getOrDefault(dr.sttId), getOrDefault(dr.ttsId), - dr.hliIds.isEmpty() ? getOrDefault(null) : String.join("->", dr.hliIds), - getOrDefault(dr.ksId), getOrDefault(dr.keyword))); + String locationText = dr.locationItem != null ? String.format(" Location: %s", dr.locationItem) : ""; + console.println(String.format( + " Source: %s - Sink: %s (STT: %s, TTS: %s, HLIs: %s, KS: %s, Keyword: %s, Dialog Group: %s)%s", + dr.sourceId, dr.sinkId, getOrDefault(dr.sttId), getOrDefault(dr.ttsId), + dr.hliIds.isEmpty() ? 
getOrDefault(null) : String.join("->", dr.hliIds), getOrDefault(dr.ksId), + getOrDefault(dr.keyword), getOrDefault(dr.dialogGroup), locationText)); }); } else { console.println("No dialog registrations."); @@ -330,11 +331,12 @@ private void listDialogs(Console console) { dialogContexts.stream().sorted(comparing(s -> s.source().getId())).forEach(c -> { var ks = c.ks(); String ksText = ks != null ? String.format(", KS: %s, Keyword: %s", ks.getId(), c.keyword()) : ""; - console.println( - String.format(" Source: %s - Sink: %s (STT: %s, TTS: %s, HLIs: %s%s)", c.source().getId(), - c.sink().getId(), c.stt().getId(), c.tts().getId(), c.hlis().stream() - .map(HumanLanguageInterpreter::getId).collect(Collectors.joining("->")), - ksText)); + String locationText = c.locationItem() != null ? String.format(" Location: %s", c.locationItem()) : ""; + console.println(String.format( + " Source: %s - Sink: %s (STT: %s, TTS: %s, HLIs: %s%s, Dialog Group: %s)%s", c.source().getId(), + c.sink().getId(), c.stt().getId(), c.tts().getId(), + c.hlis().stream().map(HumanLanguageInterpreter::getId).collect(Collectors.joining("->")), + ksText, c.dialogGroup(), locationText)); }); } else { console.println("No running dialogs."); @@ -450,6 +452,8 @@ private DialogContext.Builder parseDialogContext(String[] args) { .withHLIs(voiceManager.getHLIsByIds(parameters.remove("hlis"))) // .withKS(voiceManager.getKS(parameters.remove("ks"))) // .withListeningItem(parameters.remove("listening-item")) // + .withLocationItem(parameters.remove("location-item")) // + .withDialogGroup(parameters.remove("dialog-group")) // .withKeyword(parameters.remove("keyword")); if (!parameters.isEmpty()) { throw new IllegalStateException( @@ -483,6 +487,9 @@ private DialogRegistration parseDialogRegistration(String[] args) { dr.ttsId = parameters.remove("tts"); dr.voiceId = parameters.remove("voice"); dr.listeningItem = parameters.remove("listening-item"); + dr.locationItem = parameters.remove("location-item"); + 
dr.dialogGroup = parameters.remove("dialog-group"); + String hliIds = parameters.remove("hlis"); if (hliIds != null) { dr.hliIds = Arrays.stream(hliIds.split(",")).map(String::trim).collect(Collectors.toList()); diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceManagerImpl.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceManagerImpl.java index c70a2d84c22..6b5c5285002 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceManagerImpl.java +++ b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/internal/VoiceManagerImpl.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.WeakHashMap; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -116,6 +117,8 @@ public class VoiceManagerImpl implements VoiceManager, ConfigOptionProvider, Dia private final Map ttsServices = new HashMap<>(); private final Map humanLanguageInterpreters = new HashMap<>(); + private final WeakHashMap activeDialogGroups = new WeakHashMap<>(); + private final LocaleProvider localeProvider; private final AudioManager audioManager; private final EventPublisher eventPublisher; @@ -526,7 +529,8 @@ public void startDialog(DialogContext context) throws IllegalStateException { if (processor == null) { logger.debug("Starting a new dialog for source {} ({})", context.source().getLabel(null), context.source().getId()); - processor = new DialogProcessor(context, this, this.eventPublisher, this.i18nProvider, b); + processor = new DialogProcessor(context, this, this.eventPublisher, this.activeDialogGroups, + this.i18nProvider, b); dialogProcessors.put(context.source().getId(), processor); processor.start(); } else { @@ -582,7 +586,8 @@ public void listenAndAnswer(DialogContext context) throws IllegalStateException 
isSingleDialog = true; activeProcessor = singleDialogProcessors.get(audioSource.getId()); } - var processor = new DialogProcessor(context, this, this.eventPublisher, this.i18nProvider, b); + var processor = new DialogProcessor(context, this, this.eventPublisher, this.activeDialogGroups, + this.i18nProvider, b); if (activeProcessor == null) { logger.debug("Executing a simple dialog for source {} ({})", audioSource.getLabel(null), audioSource.getId()); @@ -970,6 +975,8 @@ private void buildDialogRegistrations() { .withVoice(getVoice(dr.voiceId)) // .withHLIs(getHLIsByIds(dr.hliIds)) // .withLocale(dr.locale) // + .withDialogGroup(dr.dialogGroup) // + .withLocationItem(dr.locationItem) // .withListeningItem(dr.listeningItem) // .withMelody(dr.listeningMelody) // .build()); diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/AbstractRuleBasedInterpreter.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/AbstractRuleBasedInterpreter.java index 0d3a0039d46..9e11ebfe157 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/AbstractRuleBasedInterpreter.java +++ b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/AbstractRuleBasedInterpreter.java @@ -36,10 +36,12 @@ import org.openhab.core.items.MetadataKey; import org.openhab.core.items.MetadataRegistry; import org.openhab.core.items.events.ItemEventFactory; +import org.openhab.core.library.CoreItemFactory; import org.openhab.core.library.types.DecimalType; import org.openhab.core.library.types.StringType; import org.openhab.core.types.Command; import org.openhab.core.types.State; +import org.openhab.core.voice.DialogContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +51,7 @@ * @author Tilman Kamp - Initial contribution * @author Kai Kreuzer - Improved error handling * @author Miguel Álvarez - Reduce collisions on exact match and use item synonyms + * @author Miguel Álvarez - Reduce 
collisions using dialog location */ @NonNullByDefault public abstract class AbstractRuleBasedInterpreter implements HumanLanguageInterpreter { @@ -79,7 +82,7 @@ public abstract class AbstractRuleBasedInterpreter implements HumanLanguageInter private final Map> languageRules = new HashMap<>(); private final Map> allItemTokens = new HashMap<>(); - private final Map>>>> itemTokens = new HashMap<>(); + private final Map> itemTokens = new HashMap<>(); private final ItemRegistry itemRegistry; private final EventPublisher eventPublisher; @@ -143,6 +146,12 @@ protected void deactivate() { @Override public String interpret(Locale locale, String text) throws InterpretationException { + return interpret(locale, text, null); + } + + @Override + public String interpret(Locale locale, String text, @Nullable DialogContext dialogContext) + throws InterpretationException { ResourceBundle language = ResourceBundle.getBundle(LANGUAGE_SUPPORT, locale); Rule[] rules = getRules(locale); if (rules.length == 0) { @@ -157,7 +166,7 @@ public String interpret(Locale locale, String text) throws InterpretationExcepti InterpretationResult lastResult = null; for (Rule rule : rules) { - if ((result = rule.execute(language, tokens)).isSuccess()) { + if ((result = rule.execute(language, tokens, dialogContext)).isSuccess()) { return result.getResponse(); } else { if (!InterpretationResult.SYNTAX_ERROR.equals(result)) { @@ -208,13 +217,13 @@ Set getAllItemTokens(Locale locale) { * @param locale The locale that is to be used for preparing the tokens. 
* @return the list of identifier token sets per item */ - Map>>> getItemTokens(Locale locale) { - Map>>> localeTokens = itemTokens.get(locale); + Map getItemTokens(Locale locale) { + Map localeTokens = itemTokens.get(locale); if (localeTokens == null) { itemTokens.put(locale, localeTokens = new HashMap<>()); for (Item item : itemRegistry.getItems()) { if (item.getGroupNames().isEmpty()) { - addItem(locale, localeTokens, new ArrayList<>(), item); + addItem(locale, localeTokens, new ArrayList<>(), item, new ArrayList<>()); } } } @@ -227,23 +236,27 @@ private String[] getItemSynonyms(Item item) { return (synonymsMetadata != null) ? synonymsMetadata.getValue().split(",") : new String[] {}; } - private void addItem(Locale locale, Map>>> target, List> tokens, - Item item) { - addItem(locale, target, tokens, item, item.getLabel()); + private void addItem(Locale locale, Map target, List> tokens, + Item item, ArrayList locationParentNames) { + addItem(locale, target, tokens, item, item.getLabel(), locationParentNames); for (String synonym : getItemSynonyms(item)) { - addItem(locale, target, tokens, item, synonym); + addItem(locale, target, tokens, item, synonym, locationParentNames); } } - private void addItem(Locale locale, Map>>> target, List> tokens, - Item item, @Nullable String itemLabel) { + private void addItem(Locale locale, Map target, List> tokens, + Item item, @Nullable String itemLabel, ArrayList locationParentNames) { List> nt = new ArrayList<>(tokens); nt.add(tokenize(locale, itemLabel)); - List>> list = target.computeIfAbsent(item, k -> new ArrayList<>()); - list.add(nt); + ItemInterpretationMetadata metadata = target.computeIfAbsent(item, k -> new ItemInterpretationMetadata()); + metadata.pathToItem.add(nt); + metadata.locationParentNames.addAll(locationParentNames); if (item instanceof GroupItem groupItem) { + if (item.hasTag(CoreItemFactory.LOCATION)) { + locationParentNames.add(item.getName()); + } for (Item member : groupItem.getMembers()) { - 
addItem(locale, target, nt, member); + addItem(locale, target, nt, member, locationParentNames); } } } @@ -353,7 +366,8 @@ protected Rule itemRule(Object headExpression, @Nullable Object tailExpression) Expression expression = tail == null ? seq(headExpression, name()) : seq(headExpression, name(tail), tail); return new Rule(expression) { @Override - public InterpretationResult interpretAST(ResourceBundle language, ASTNode node) { + public InterpretationResult interpretAST(ResourceBundle language, ASTNode node, + @Nullable DialogContext dialogContext) { String[] name = node.findValueAsStringArray(NAME); ASTNode cmdNode = node.findNode(CMD); Object tag = cmdNode.getTag(); @@ -368,7 +382,7 @@ public InterpretationResult interpretAST(ResourceBundle language, ASTNode node) } if (name != null) { try { - return new InterpretationResult(true, executeSingle(language, name, command)); + return new InterpretationResult(true, executeSingle(language, name, command, dialogContext)); } catch (InterpretationException ex) { return new InterpretationResult(ex); } @@ -538,11 +552,11 @@ protected ExpressionCardinality plus(Object expression) { * @return response text * @throws InterpretationException in case that there is no or more than on item matching the fragments */ - protected String executeSingle(ResourceBundle language, String[] labelFragments, Command command) - throws InterpretationException { - List items = getMatchingItems(language, labelFragments, command.getClass()); + protected String executeSingle(ResourceBundle language, String[] labelFragments, Command command, + @Nullable DialogContext dialogContext) throws InterpretationException { + List items = getMatchingItems(language, labelFragments, command.getClass(), dialogContext); if (items.isEmpty()) { - if (!getMatchingItems(language, labelFragments, null).isEmpty()) { + if (!getMatchingItems(language, labelFragments, null, dialogContext).isEmpty()) { throw new InterpretationException( 
language.getString(COMMAND_NOT_ACCEPTED).replace("", command.toString())); } else { @@ -596,13 +610,14 @@ protected String executeSingle(ResourceBundle language, String[] labelFragments, * @return All matching items from the item registry. */ protected List getMatchingItems(ResourceBundle language, String[] labelFragments, - @Nullable Class commandType) { - Set items = new HashSet<>(); - Set exactMatchItems = new HashSet<>(); - Map>>> map = getItemTokens(language.getLocale()); - for (Entry>>> entry : map.entrySet()) { + @Nullable Class commandType, @Nullable DialogContext dialogContext) { + Map itemsData = new HashMap<>(); + Map exactMatchItemsData = new HashMap<>(); + Map map = getItemTokens(language.getLocale()); + for (Entry entry : map.entrySet()) { Item item = entry.getKey(); - for (List> itemLabelFragmentsPath : entry.getValue()) { + ItemInterpretationMetadata interpretationMetadata = entry.getValue(); + for (List> itemLabelFragmentsPath : interpretationMetadata.pathToItem) { boolean exactMatch = false; logger.trace("Checking tokens {} against the item tokens {}", labelFragments, itemLabelFragmentsPath); List lowercaseLabelFragments = Arrays.stream(labelFragments) @@ -617,13 +632,13 @@ protected List getMatchingItems(ResourceBundle language, String[] labelFra unmatchedFragments.removeAll(itemLabelFragments); } boolean allMatched = unmatchedFragments.isEmpty(); - logger.trace("All labels matched: {}", allMatched); + logger.trace("Matched: {}", allMatched); logger.trace("Exact match: {}", exactMatch); if (allMatched) { if (commandType == null || item.getAcceptedCommandTypes().contains(commandType)) { - insertDiscardingMembers(items, item); + insertDiscardingMembers(itemsData, item, interpretationMetadata); if (exactMatch) { - insertDiscardingMembers(exactMatchItems, item); + insertDiscardingMembers(exactMatchItemsData, item, interpretationMetadata); } } } @@ -632,19 +647,49 @@ protected List getMatchingItems(ResourceBundle language, String[] labelFra if 
(logger.isDebugEnabled()) { String typeDetails = commandType != null ? " that accept " + commandType.getSimpleName() : ""; logger.debug("Partial matched items against {}{}: {}", labelFragments, typeDetails, - items.stream().map(Item::getName).collect(Collectors.joining(", "))); + itemsData.keySet().stream().map(Item::getName).collect(Collectors.joining(", "))); logger.debug("Exact matched items against {}{}: {}", labelFragments, typeDetails, - exactMatchItems.stream().map(Item::getName).collect(Collectors.joining(", "))); + exactMatchItemsData.keySet().stream().map(Item::getName).collect(Collectors.joining(", "))); + } + @Nullable + String locationContext = dialogContext != null ? dialogContext.locationItem() : null; + if (locationContext != null && itemsData.size() > 1) { + logger.debug("Filtering {} matched items based on location '{}'", itemsData.size(), locationContext); + Item matchByLocation = filterMatchedItemsByLocation(itemsData, locationContext); + if (matchByLocation != null) { + return List.of(matchByLocation); + } + } + if (locationContext != null && exactMatchItemsData.size() > 1) { + logger.debug("Filtering {} exact matched items based on location '{}'", exactMatchItemsData.size(), + locationContext); + Item matchByLocation = filterMatchedItemsByLocation(exactMatchItemsData, locationContext); + if (matchByLocation != null) { + return List.of(matchByLocation); + } } - return new ArrayList<>(items.size() != 1 && exactMatchItems.size() == 1 ? exactMatchItems : items); + return new ArrayList<>(itemsData.size() != 1 && exactMatchItemsData.size() == 1 ? 
exactMatchItemsData.keySet() + : itemsData.keySet()); } - private static void insertDiscardingMembers(Set items, Item item) { + @Nullable + private Item filterMatchedItemsByLocation(Map itemsData, String locationContext) { + var itemsFilteredByLocation = itemsData.entrySet().stream() + .filter((entry) -> entry.getValue().locationParentNames.contains(locationContext)).toList(); + if (itemsFilteredByLocation.size() != 1) { + return null; + } + logger.debug("Unique match by location found in '{}', taking prevalence", locationContext); + return itemsFilteredByLocation.get(0).getKey(); + } + + private static void insertDiscardingMembers(Map items, Item item, + ItemInterpretationMetadata interpretationMetadata) { String name = item.getName(); - boolean insert = items.stream().noneMatch(i -> name.startsWith(i.getName())); + boolean insert = items.keySet().stream().noneMatch(i -> name.startsWith(i.getName())); if (insert) { - items.removeIf((matchedItem) -> matchedItem.getName().startsWith(name)); - items.add(item); + items.keySet().removeIf((matchedItem) -> matchedItem.getName().startsWith(name)); + items.put(item, interpretationMetadata); } } @@ -919,4 +964,12 @@ String getGrammar() { JSGFGenerator generator = new JSGFGenerator(ResourceBundle.getBundle(LANGUAGE_SUPPORT, locale)); return generator.getGrammar(); } + + private static class ItemInterpretationMetadata { + final List>> pathToItem = new ArrayList<>(); + final List locationParentNames = new ArrayList<>(); + + ItemInterpretationMetadata() { + } + } } diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/HumanLanguageInterpreter.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/HumanLanguageInterpreter.java index 9a5364d1ffd..40cd87c7bfa 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/HumanLanguageInterpreter.java +++ 
b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/HumanLanguageInterpreter.java @@ -17,6 +17,7 @@ import org.eclipse.jdt.annotation.NonNullByDefault; import org.eclipse.jdt.annotation.Nullable; +import org.openhab.core.voice.DialogContext; /** * This is the interface that a human language text interpreter has to implement. @@ -50,6 +51,19 @@ public interface HumanLanguageInterpreter { */ String interpret(Locale locale, String text) throws InterpretationException; + /** + * Interprets a human language text fragment of a given {@link Locale} with optional access to the context of a + * dialog execution. + * + * @param locale language of the text (given by a {@link Locale}) + * @param text the text to interpret + * @return a human language response + */ + default String interpret(Locale locale, String text, @Nullable DialogContext dialogContext) + throws InterpretationException { + return interpret(locale, text); + } + /** * Gets the grammar of all commands of a given {@link Locale} of the interpreter * diff --git a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/Rule.java b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/Rule.java index f74e65467fd..97e20530fae 100644 --- a/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/Rule.java +++ b/bundles/org.openhab.core.voice/src/main/java/org/openhab/core/voice/text/Rule.java @@ -15,6 +15,8 @@ import java.util.ResourceBundle; import org.eclipse.jdt.annotation.NonNullByDefault; +import org.eclipse.jdt.annotation.Nullable; +import org.openhab.core.voice.DialogContext; /** * Represents an expression plus action code that will be executed after successful parsing. This class is immutable and @@ -43,12 +45,13 @@ public Rule(Expression expression) { * @param node the resulting AST node of the parse run. To be used as input. 
* @return */ - public abstract InterpretationResult interpretAST(ResourceBundle language, ASTNode node); + public abstract InterpretationResult interpretAST(ResourceBundle language, ASTNode node, + @Nullable DialogContext dialogContext); - InterpretationResult execute(ResourceBundle language, TokenList list) { + InterpretationResult execute(ResourceBundle language, TokenList list, @Nullable DialogContext dialogContext) { ASTNode node = expression.parse(language, list); if (node.isSuccess() && node.getRemainingTokens().eof()) { - return interpretAST(language, node); + return interpretAST(language, node, dialogContext); } return InterpretationResult.SYNTAX_ERROR; } diff --git a/bundles/org.openhab.core.voice/src/test/java/org/openhab/core/voice/internal/text/StandardInterpreterTest.java b/bundles/org.openhab.core.voice/src/test/java/org/openhab/core/voice/internal/text/StandardInterpreterTest.java index 708f084be84..910c7c5f777 100644 --- a/bundles/org.openhab.core.voice/src/test/java/org/openhab/core/voice/internal/text/StandardInterpreterTest.java +++ b/bundles/org.openhab.core.voice/src/test/java/org/openhab/core/voice/internal/text/StandardInterpreterTest.java @@ -29,6 +29,8 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.openhab.core.audio.AudioSink; +import org.openhab.core.audio.AudioSource; import org.openhab.core.events.EventPublisher; import org.openhab.core.items.GroupItem; import org.openhab.core.items.Item; @@ -39,6 +41,9 @@ import org.openhab.core.items.events.ItemEventFactory; import org.openhab.core.library.items.SwitchItem; import org.openhab.core.library.types.OnOffType; +import org.openhab.core.voice.DialogContext; +import org.openhab.core.voice.STTService; +import org.openhab.core.voice.TTSService; import org.openhab.core.voice.text.InterpretationException; /** @@ -55,6 +60,11 @@ public class StandardInterpreterTest { private @Mock @NonNullByDefault({}) ItemRegistry 
itemRegistryMock; private @Mock @NonNullByDefault({}) MetadataRegistry metadataRegistryMock; private @NonNullByDefault({}) StandardInterpreter standardInterpreter; + private @Mock @NonNullByDefault({}) STTService sttService; + private @Mock @NonNullByDefault({}) TTSService ttsService; + private @Mock @NonNullByDefault({}) AudioSource audioSource; + private @Mock @NonNullByDefault({}) AudioSink audioSink; + private static final String OK_RESPONSE = "Ok."; @BeforeEach @@ -94,6 +104,24 @@ public void noNameCollisionOnSingleExactMatchForGroups() throws InterpretationEx .post(ItemEventFactory.createCommandEvent(computerSwitchItem.getName(), OnOffType.OFF)); } + @Test + public void noNameCollisionWhenDialogContext() throws InterpretationException { + var locationItem = Mockito.spy(new GroupItem("livingroom")); + locationItem.setLabel("Living room"); + var computerItem = new SwitchItem("computer"); + computerItem.setLabel("Computer"); + var computerItem2 = new SwitchItem("computer2"); + computerItem2.setLabel("Computer"); + when(locationItem.getMembers()).thenReturn(Set.of(computerItem)); + var dialogContext = new DialogContext(null, null, sttService, ttsService, null, List.of(), audioSource, + audioSink, Locale.ENGLISH, "", locationItem.getName(), null, null); + List items = List.of(computerItem2, locationItem, computerItem); + when(itemRegistryMock.getItems()).thenReturn(items); + assertEquals(OK_RESPONSE, standardInterpreter.interpret(Locale.ENGLISH, "turn off computer", dialogContext)); + verify(eventPublisherMock, times(1)) + .post(ItemEventFactory.createCommandEvent(computerItem.getName(), OnOffType.OFF)); + } + @Test public void allowUseItemSynonyms() throws InterpretationException { var computerItem = new SwitchItem("computer");