diff --git a/ovos_plugin_manager/wakewords.py b/ovos_plugin_manager/wakewords.py index 49f1a8e4..15ffae55 100644 --- a/ovos_plugin_manager/wakewords.py +++ b/ovos_plugin_manager/wakewords.py @@ -1,6 +1,8 @@ import json import os from hashlib import md5 +from typing import Optional + from ovos_utils.log import LOG from ovos_utils.xdg_utils import xdg_data_home @@ -98,29 +100,64 @@ class OVOSWakeWordFactory: } @staticmethod - def get_class(hotword, config=None): - config = get_hotwords_config(config) - if hotword not in config: + def get_class(hotword: str, config: Optional[dict] = None) -> type: + """ + Get the plugin class for the specified hotword + @param hotword: string hotword to load + @param config: optional global configuration + @return: Uninitialized hotword class + """ + hotword_config = get_hotwords_config(config) + if hotword not in hotword_config: + LOG.warning(f"{hotword} not in {hotword_config}! " + f"Returning base HotWordEngine") return HotWordEngine - ww_module = config["module"] + ww_module = hotword_config[hotword]["module"] if ww_module in OVOSWakeWordFactory.MAPPINGS: ww_module = OVOSWakeWordFactory.MAPPINGS[ww_module] return load_wake_word_plugin(ww_module) @staticmethod - def load_module(module, hotword, config, lang, loop): - LOG.info(f'Loading "{hotword}" wake word via {module}') - clazz = OVOSWakeWordFactory.get_class(module, config) + def load_module(module: str, hotword: str, hotword_config: dict, + lang: str, loop=None) -> HotWordEngine: + """ + Get an initialized HotWordEngine using the specified module and hotword + @param module: hotword plugin to load (not parsed) + @param hotword: string hotword to load + @param hotword_config: configuration for the specified `hotword`. + Equivalent to Configuration()['hotwords'][hotword] + @param lang: BCP-47 language code of hotword + @param loop: Unused + @return: Initialized HotWordEngine + """ + # config here is config['hotwords'][module] + LOG.info(f'Loading "{hotword}" wake word via {module} with ' + f'config: {hotword_config}') + config = {"lang": lang, "hotwords": {hotword: hotword_config}} + clazz = OVOSWakeWordFactory.get_class(hotword, config) if clazz is None: - raise ImportError(f'Wake Word plugin {module} failed to load') - LOG.info(f'Loaded the Wake Word plugin {module}') - return clazz(hotword, config, lang=lang) + raise ImportError(f'Wake Word {hotword} with module {module} ' + f'failed to load') + LOG.info(f'Loaded the Wake Word {hotword} with module {module}') + return clazz(hotword, hotword_config, lang=lang) @classmethod - def create_hotword(cls, hotword="hey mycroft", config=None, - lang="en-us", loop=None): + def create_hotword(cls, hotword: str = "hey mycroft", + config: Optional[dict] = None, + lang: str = "en-us", loop=None) -> HotWordEngine: + """ + Get an initialized HotWordEngine by configured name + @param hotword: string hotword to load + @param config: optional global configuration + @param lang: BCP-47 language code of hotword + @param loop: Unused + @return: Initialized HotWordEngine + """ ww_configs = get_hotwords_config(config) - ww_config = ww_configs.get(hotword) or ww_configs.get("hey_mycroft") + if hotword not in ww_configs: + LOG.warning(f"replace ` ` in {hotword} with `_`") + hotword = hotword.replace(' ', '_') + ww_config = ww_configs.get(hotword) module = ww_config.get("module", "pocketsphinx") try: return cls.load_module(module, hotword, ww_config, lang, loop) diff --git a/test/unittests/test_wakewords.py b/test/unittests/test_wakewords.py new file mode 100644 index 00000000..ec6b4493 --- /dev/null +++ b/test/unittests/test_wakewords.py @@ -0,0 +1,87 @@ +import unittest +from unittest.mock import patch, Mock + +from ovos_plugin_manager import PluginTypes + +_TEST_CONFIG = { + "hotwords": { + "hey_neon": { + "module": "ovos-ww-plugin-vosk", + "listen": True, + "active": True + }, + "hey_mycroft": { + "module": "precise", + "listen": True, + "active": True + } + } +} + + +class TestWakeWordFactory(unittest.TestCase): + def test_create_hotword(self): + from ovos_plugin_manager.wakewords import OVOSWakeWordFactory + real_load_module = OVOSWakeWordFactory.load_module + mock_load = Mock() + OVOSWakeWordFactory.load_module = mock_load + + OVOSWakeWordFactory.create_hotword(config=_TEST_CONFIG) + mock_load.assert_called_once_with("precise", "hey_mycroft", + _TEST_CONFIG["hotwords"] + ['hey_mycroft'], "en-us", None) + + OVOSWakeWordFactory.create_hotword("hey_neon", _TEST_CONFIG) + mock_load.assert_called_with("ovos-ww-plugin-vosk", "hey_neon", + _TEST_CONFIG["hotwords"] + ['hey_neon'], "en-us", None) + OVOSWakeWordFactory.load_module = real_load_module + + @patch("ovos_plugin_manager.wakewords.load_plugin") + def test_get_class(self, load_plugin): + mock = Mock() + load_plugin.return_value = mock + from ovos_plugin_manager.wakewords import OVOSWakeWordFactory + # Test valid module + module = OVOSWakeWordFactory.get_class("hey_neon", _TEST_CONFIG) + load_plugin.assert_called_once_with("ovos-ww-plugin-vosk", + PluginTypes.WAKEWORD) + self.assertEqual(mock, module) + + # Test mapped module + load_plugin.reset_mock() + module = OVOSWakeWordFactory.get_class("hey_mycroft", _TEST_CONFIG) + load_plugin.assert_called_once_with("ovos-ww-plugin-precise", + PluginTypes.WAKEWORD) + self.assertEqual(mock, module) + + # Test invalid module + load_plugin.reset_mock() + module = OVOSWakeWordFactory.get_class("invalid_ww", _TEST_CONFIG) + load_plugin.assert_not_called() + from ovos_plugin_manager.templates.hotwords import HotWordEngine + self.assertEqual(module, HotWordEngine) + + def test_load_module(self): + from ovos_plugin_manager.wakewords import OVOSWakeWordFactory + real_get_class = OVOSWakeWordFactory.get_class + mock_get_class = Mock() + OVOSWakeWordFactory.get_class = mock_get_class + + # Test valid return + mock_return = Mock() + mock_get_class.return_value = mock_return + module = OVOSWakeWordFactory.load_module( + "precise", "hey_mycroft", _TEST_CONFIG['hotwords']['hey_mycroft'], + 'en-us') + mock_get_class.assert_called_once_with( + "hey_mycroft", {"lang": "en-us", "hotwords": { + "hey_mycroft": _TEST_CONFIG['hotwords']['hey_mycroft']}}) + self.assertEqual(module, mock_return()) + + # Test no return + mock_get_class.return_value = None + with self.assertRaises(ImportError): + OVOSWakeWordFactory.load_module("dummy", "test", {}, "en-us") + + OVOSWakeWordFactory.get_class = real_get_class