Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix STT and TTS configuration handling #187

Merged
merged 7 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions ovos_plugin_manager/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
PluginTypes, PluginConfigTypes
from ovos_config import Configuration
from ovos_plugin_manager.utils.config import get_valid_plugin_configs, \
sort_plugin_configs
from ovos_utils.log import LOG
sort_plugin_configs, get_plugin_config
from ovos_utils.log import LOG, log_deprecation
from ovos_plugin_manager.templates.stt import STT, StreamingSTT, StreamThread


Expand Down Expand Up @@ -88,14 +88,15 @@ def get_stt_supported_langs() -> dict:
return get_plugin_supported_languages(PluginTypes.STT)


def get_stt_config(config: dict = None) -> dict:
def get_stt_config(config: dict = None, module: str = None) -> dict:
"""
Get relevant configuration for factory methods
@param config: global Configuration OR plugin class-specific configuration
@param module: STT module to get configuration for
@return: plugin class-specific configuration
"""
from ovos_plugin_manager.utils.config import get_plugin_config
stt_config = get_plugin_config(config, "stt")
stt_config = get_plugin_config(config, "stt", module)
assert stt_config.get('lang') is not None, "expected lang but got None"
return stt_config

Expand Down Expand Up @@ -133,7 +134,7 @@ def get_class(config=None):
"module": <engine_name>
}
"""
config = config or get_stt_config()
config = get_stt_config(config)
stt_module = config["module"]
if stt_module in OVOSSTTFactory.MAPPINGS:
stt_module = OVOSSTTFactory.MAPPINGS[stt_module]
Expand All @@ -150,12 +151,15 @@ def create(config=None):
"module": <engine_name>
}
"""
config = get_stt_config(config)
plugin = config["module"]
plugin_config = config.get(plugin) or {}
stt_config = get_stt_config(config)
plugin = stt_config.get("module", "dummy")
if plugin in OVOSSTTFactory.MAPPINGS:
log_deprecation("Module mappings will be deprecated", "0.1.0")
plugin = OVOSSTTFactory.MAPPINGS[plugin]
stt_config = get_stt_config(config, plugin)
try:
clazz = OVOSSTTFactory.get_class(config)
return clazz(plugin_config)
clazz = OVOSSTTFactory.get_class(stt_config)
return clazz(stt_config)
except Exception:
LOG.exception('The selected STT plugin could not be loaded!')
raise
30 changes: 20 additions & 10 deletions ovos_plugin_manager/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ovos_plugin_manager.utils import PluginTypes, normalize_lang, \
PluginConfigTypes
from ovos_plugin_manager.utils.config import get_valid_plugin_configs, \
sort_plugin_configs
from ovos_utils.log import LOG
sort_plugin_configs, get_plugin_config
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.xdg_utils import xdg_data_home
from hashlib import md5

Expand Down Expand Up @@ -90,18 +90,20 @@ def get_tts_supported_langs():
return get_plugin_supported_languages(PluginTypes.TTS)


def get_tts_config(config: dict = None) -> dict:
def get_tts_config(config: dict = None, module: str = None) -> dict:
"""
Get relevant configuration for factory methods
@param config: global Configuration OR plugin class-specific configuration
@param module: TTS module to get configuration for
@return: plugin class-specific configuration
"""
from ovos_plugin_manager.utils.config import get_plugin_config
return get_plugin_config(config, 'tts')
return get_plugin_config(config, 'tts', module)


def get_voice_id(plugin_name, lang, tts_config):
tts_hash = md5(json.dumps(tts_config, sort_keys=True).encode("utf-8")).hexdigest()
tts_hash = md5(json.dumps(tts_config,
sort_keys=True).encode("utf-8")).hexdigest()
return f"{plugin_name}_{lang}_{tts_hash}"


Expand All @@ -110,7 +112,8 @@ def scan_voices():
for lang in get_tts_supported_langs():
VOICES_FOLDER = f"{xdg_data_home()}/OPM/voice_configs/{lang}"
os.makedirs(VOICES_FOLDER, exist_ok=True)
for plug, voices in get_tts_lang_configs(lang, include_dialects=True).items():
for plug, voices in get_tts_lang_configs(lang,
include_dialects=True).items():
for voice in voices:
voiceid = get_voice_id(plug, lang, voice)
if "meta" not in voice:
Expand Down Expand Up @@ -189,20 +192,27 @@ def create(config=None):
}
"""
tts_config = get_tts_config(config)
tts_lang = tts_config["lang"]
tts_module = tts_config.get('module', 'dummy')
if tts_module in OVOSTTSFactory.MAPPINGS:
NeonDaniel marked this conversation as resolved.
Show resolved Hide resolved
# The configured module maps to a valid plugin; get configuration
# again to make sure any module-specific config/overrides are loaded
log_deprecation("Module mappings will be deprecated", "0.1.0")
tts_module = OVOSTTSFactory.MAPPINGS[tts_module]
tts_config = get_tts_config(config, tts_module)
try:
clazz = OVOSTTSFactory.get_class(tts_config)
if clazz:
LOG.info(f'Found plugin {tts_module}')
tts = clazz(tts_lang, tts_config)
tts = clazz(lang=None, # explicitly read lang from config
config=tts_config)
tts.validator.validate()
LOG.info(f'Loaded plugin {tts_module}')
else:
raise FileNotFoundError("unknown plugin")
raise RuntimeError(f"unknown plugin: {tts_module}")
except Exception:
plugins = find_tts_plugins()
modules = ",".join(plugins.keys())
LOG.exception(f'The TTS plugin "{tts_module}" could not be loaded.\nAvailable modules: {modules}')
LOG.exception(f'The TTS plugin "{tts_module}" could not be loaded.'
f'\nAvailable modules: {modules}')
raise
return tts
1 change: 1 addition & 0 deletions ovos_plugin_manager/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def get_plugin_config(config: Optional[dict] = None, section: str = None,
continue
elif isinstance(val, dict):
continue
# Use section-scoped config as defaults (i.e. TTS.lang)
module_config.setdefault(key, val)
config = module_config
if section not in ["hotwords", "VAD", "listener", "gui"]:
Expand Down
71 changes: 66 additions & 5 deletions test/unittests/test_stt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from copy import copy

from unittest.mock import patch
from unittest.mock import patch, Mock
from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes


Expand Down Expand Up @@ -81,15 +81,76 @@ def test_get_supported_langs(self, get_supported_languages):
get_supported_languages.assert_called_once_with(self.PLUGIN_TYPE)

@patch("ovos_plugin_manager.utils.config.get_plugin_config")
def test_get_config(self, get_config):
def test_get_stt_config(self, get_config):
from ovos_plugin_manager.stt import get_stt_config
config = copy(self.TEST_CONFIG)
get_stt_config(self.TEST_CONFIG)
get_config.assert_called_once_with(self.TEST_CONFIG,
self.CONFIG_SECTION)
self.CONFIG_SECTION, None)
self.assertEqual(config, self.TEST_CONFIG)


class TestSTTFactory(unittest.TestCase):
from ovos_plugin_manager.stt import OVOSSTTFactory
# TODO
def test_mappings(self):
from ovos_plugin_manager.stt import OVOSSTTFactory
self.assertIsInstance(OVOSSTTFactory.MAPPINGS, dict)
for key in OVOSSTTFactory.MAPPINGS:
self.assertIsInstance(key, str)
self.assertIsInstance(OVOSSTTFactory.MAPPINGS[key], str)
self.assertNotEqual(key, OVOSSTTFactory.MAPPINGS[key])

@patch("ovos_plugin_manager.stt.load_stt_plugin")
def test_get_class(self, load_plugin):
from ovos_plugin_manager.stt import OVOSSTTFactory
global_config = {"stt": {"module": "dummy"}}
tts_config = {"module": "test-stt-plugin-test"}

# Test load plugin mapped global config
OVOSSTTFactory.get_class(global_config)
load_plugin.assert_called_with("ovos-stt-plugin-dummy")

# Test load plugin explicit STT config
OVOSSTTFactory.get_class(tts_config)
load_plugin.assert_called_with("test-stt-plugin-test")

@patch("ovos_plugin_manager.stt.OVOSSTTFactory.get_class")
def test_create(self, get_class):
from ovos_plugin_manager.stt import OVOSSTTFactory
plugin_class = Mock()
get_class.return_value = plugin_class

global_config = {"lang": "en-gb",
"stt": {"module": "dummy",
"ovos-stt-plugin-dummy": {"config": True,
"lang": "en-ca"}}}
stt_config = {"lang": "es-es",
"module": "test-stt-plugin-test"}

stt_config_2 = {"lang": "es-es",
"module": "test-stt-plugin-test",
"test-stt-plugin-test": {"config": True,
"lang": "es-mx"}}

# Test create with global config and lang override
plugin = OVOSSTTFactory.create(global_config)
expected_config = {"module": "ovos-stt-plugin-dummy",
"config": True,
"lang": "en-ca"}
get_class.assert_called_once_with(expected_config)
plugin_class.assert_called_once_with(expected_config)
self.assertEqual(plugin, plugin_class())

# Test create with STT config and no module config
plugin = OVOSSTTFactory.create(stt_config)
get_class.assert_called_with(stt_config)
plugin_class.assert_called_with(stt_config)
self.assertEqual(plugin, plugin_class())

# Test create with STT config with module-specific config
plugin = OVOSSTTFactory.create(stt_config_2)
expected_config = {"module": "test-stt-plugin-test",
"config": True, "lang": "es-mx"}
get_class.assert_called_with(expected_config)
plugin_class.assert_called_with(expected_config)
self.assertEqual(plugin, plugin_class())

71 changes: 65 additions & 6 deletions test/unittests/test_tts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import patch
from unittest.mock import patch, Mock
from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes
from ovos_plugin_manager.templates.tts import TTS

Expand Down Expand Up @@ -180,11 +180,11 @@ def test_get_supported_langs(self, get_supported_languages):
get_supported_languages.assert_called_once_with(self.PLUGIN_TYPE)

@patch("ovos_plugin_manager.utils.config.get_plugin_config")
def test_get_config(self, get_config):
def test_get_tts_config(self, get_config):
from ovos_plugin_manager.tts import get_tts_config
get_tts_config(self.TEST_CONFIG)
get_config.assert_called_once_with(self.TEST_CONFIG,
self.CONFIG_SECTION)
self.CONFIG_SECTION, None)

def test_get_voice_id(self):
from ovos_plugin_manager.tts import get_voice_id
Expand All @@ -200,6 +200,65 @@ def test_get_voices(self):


class TestTTSFactory(unittest.TestCase):
from ovos_plugin_manager.tts import OVOSTTSFactory
# TODO

def test_mappings(self):
from ovos_plugin_manager.tts import OVOSTTSFactory
self.assertIsInstance(OVOSTTSFactory.MAPPINGS, dict)
for key in OVOSTTSFactory.MAPPINGS:
self.assertIsInstance(key, str)
self.assertIsInstance(OVOSTTSFactory.MAPPINGS[key], str)
self.assertNotEqual(key, OVOSTTSFactory.MAPPINGS[key])

@patch("ovos_plugin_manager.tts.load_tts_plugin")
def test_get_class(self, load_plugin):
from ovos_plugin_manager.tts import OVOSTTSFactory
global_config = {"tts": {"module": "dummy"}}
tts_config = {"module": "test-tts-plugin-test"}

# Test load plugin mapped global config
OVOSTTSFactory.get_class(global_config)
load_plugin.assert_called_with("ovos-tts-plugin-dummy")

# Test load plugin explicit TTS config
OVOSTTSFactory.get_class(tts_config)
load_plugin.assert_called_with("test-tts-plugin-test")

@patch("ovos_plugin_manager.tts.OVOSTTSFactory.get_class")
def test_create(self, get_class):
from ovos_plugin_manager.tts import OVOSTTSFactory
plugin_class = Mock()
get_class.return_value = plugin_class

global_config = {"lang": "en-gb",
"tts": {"module": "dummy",
"ovos-tts-plugin-dummy": {"config": True,
"lang": "en-ca"}}}
tts_config = {"lang": "es-es",
"module": "test-tts-plugin-test"}

tts_config_2 = {"lang": "es-es",
"module": "test-tts-plugin-test",
"test-tts-plugin-test": {"config": True,
"lang": "es-mx"}}

# Test create with global config and lang override
plugin = OVOSTTSFactory.create(global_config)
expected_config = {"module": "ovos-tts-plugin-dummy",
"config": True,
"lang": "en-ca"}
get_class.assert_called_once_with(expected_config)
plugin_class.assert_called_once_with(lang=None, config=expected_config)
self.assertEqual(plugin, plugin_class())

# Test create with TTS config and no module config
plugin = OVOSTTSFactory.create(tts_config)
get_class.assert_called_with(tts_config)
plugin_class.assert_called_with(lang=None, config=tts_config)
self.assertEqual(plugin, plugin_class())

# Test create with TTS config with module-specific config
plugin = OVOSTTSFactory.create(tts_config_2)
expected_config = {"module": "test-tts-plugin-test",
"config": True, "lang": "es-mx"}
get_class.assert_called_with(expected_config)
plugin_class.assert_called_with(lang=None, config=expected_config)
self.assertEqual(plugin, plugin_class())
20 changes: 20 additions & 0 deletions test/unittests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,26 @@ def test_get_plugin_config(self, config):
self.assertEqual(seg_config, get_plugin_config(section="segmentation"))
self.assertEqual(gui_config, get_plugin_config(section="gui"))

# Test TTS config with plugin `lang` override
config = {
"lang": "en-us",
"tts": {
"module": "ovos_tts_plugin_espeakng",
"ovos_tts_plugin_espeakng": {
"lang": "de-de",
"voice": "german-mbrola-5",
"speed": "135",
"amplitude": "80",
"pitch": "20"
}
}
}
tts_config = get_plugin_config(config, "tts")
self.assertEqual(tts_config['lang'], 'de-de')
self.assertEqual(tts_config['module'], 'ovos_tts_plugin_espeakng')
self.assertEqual(tts_config['voice'], 'german-mbrola-5')
self.assertNotIn("ovos_tts_plugin_espeakng", tts_config)

self.assertEqual(_MOCK_CONFIG, start_config)

def test_get_valid_plugin_configs(self):
Expand Down