From e8b6066c1edbb070710d35631ce58d40d1e9657f Mon Sep 17 00:00:00 2001 From: miro Date: Sat, 20 Apr 2024 19:41:55 +0100 Subject: [PATCH] voice from Session --- ovos_plugin_manager/templates/tts.py | 91 +++++++++++++++++++--------- test/unittests/test_tts.py | 31 +++++++--- 2 files changed, 85 insertions(+), 37 deletions(-) diff --git a/ovos_plugin_manager/templates/tts.py b/ovos_plugin_manager/templates/tts.py index b7e2b078..abd8b920 100644 --- a/ovos_plugin_manager/templates/tts.py +++ b/ovos_plugin_manager/templates/tts.py @@ -21,7 +21,7 @@ from ovos_utils.file_utils import get_cache_directory from ovos_utils.file_utils import resolve_resource_file from ovos_utils.lang.visimes import VISIMES -from ovos_utils.log import LOG, deprecated +from ovos_utils.log import LOG, deprecated, log_deprecation from ovos_utils.metrics import Stopwatch from ovos_utils.process_utils import RuntimeRequirements @@ -37,10 +37,11 @@ class TTSContext: _caches = {} - def __init__(self, plugin_id: str, lang: str, voice: str): + def __init__(self, plugin_id: str, lang: str, voice: str, synth_kwargs: dict = None): self.plugin_id = plugin_id self.lang = lang self.voice = voice + self.synth_kwargs = synth_kwargs or {} @property def tts_id(self): @@ -90,6 +91,9 @@ class TTS: def __init__(self, lang=None, config=None, validator=None, audio_ext='wav', phonetic_spelling=True, ssml_tags=None): + if lang is not None: + log_deprecation("lang argument for TTS has been deprecated! it will be ignored, " + "pass lang to get_tts directly instead") self.log_timestamps = False self.config = config or get_plugin_config(config, "tts") @@ -117,6 +121,18 @@ def __init__(self, lang=None, config=None, validator=None, # only present for backwards compat reasons self.bus = None + self._plugin_id = "" # the plugin name + + @property + def plugin_id(self) -> str: + if not self._plugin_id: + from ovos_plugin_manager.tts import find_tts_plugins + for tts_id, clazz in find_tts_plugins().items(): + if isinstance(self, clazz): + self._plugin_id = tts_id + break + return self._plugin_id + # methods for individual plugins to override @classproperty def runtime_requirements(self): @@ -183,11 +199,18 @@ def handle_metric(self, metadata=None): # properties that reflect bus message session @property def voice(self): + voice = self.config.get("voice") or "default" message = dig_for_message() if message: - # TODO - get from tts_prefs in session - pass - return self.config.get("voice") or "default" + sess = SessionManager.get() + if sess.tts_preferences["plugin_id"] == self.plugin_id: + v = sess.tts_preferences["config"].get("voice") + if v: + voice = v + else: + # we got a request for a TTS plugin that isn't loaded! + LOG.error("ignoring TTS preferences in Session, plugin does not match!") + return voice @voice.setter def voice(self, val): @@ -432,18 +455,38 @@ def _get_visemes(self, phonemes, sentence, ctxt): LOG.debug(f"no mouth movements available! unknown visemes for {sentence}") return viseme - def _get_ctxt(self, kwargs=None): - kwargs = kwargs or {} + def _get_ctxt(self, kwargs=None) -> TTSContext: + """create a TTSContext from arbitrary kwargs passed to synth/execute methods + takes preferences from Session into account if a message is present + """ # get request specific synth params + kwargs = kwargs or {} message = kwargs.get("message") or dig_for_message() - lang = kwargs.get("lang") - voice = kwargs.get("voice") - if message and not lang: - sess = SessionManager.get(message) - lang = lang or sess.lang - return TTSContext(plugin_id=self.tts_name, # TODO this should be the OPM name at some point - lang=lang or self.lang, - voice=voice or self.voice) + + # update kwargs from session + if message: + sess = SessionManager.get() + if sess.tts_preferences["plugin_id"] == self.plugin_id: + for k, v in sess.tts_preferences["config"].items(): + if k not in kwargs: + kwargs[k] = v + else: + # we got a request for a TTS plugin that isn't loaded! + LOG.error("ignoring TTS preferences in Session, plugin does not match!") + + if "lang" not in kwargs: + kwargs["lang"] = sess.lang + + # filter kwargs accepted by this specific plugin + kwargs = {k: v for k, v in kwargs.items() + if k in inspect.signature(self.get_tts).parameters + and k not in ["sentence", "wav_file"]} + + LOG.debug(f"TTS kwargs: {kwargs}") + return TTSContext(plugin_id=self.plugin_id, + lang=kwargs.get("lang") or self.lang, + voice=kwargs.get("voice") or self.voice, + synth_kwargs=kwargs) def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): if preprocess: @@ -470,7 +513,7 @@ def _execute(self, sentence, ident, listen, preprocess=True, **kwargs): # synth -> queue for playback for sentence, l in chunks: # load from cache or synth + cache - audio_file, phonemes = self.synth(sentence, ctxt, **kwargs) + audio_file, phonemes = self.synth(sentence, ctxt) # get visemes/mouth movements viseme = self._get_visemes(phonemes, sentence, ctxt) @@ -490,7 +533,7 @@ def synth(self, sentence, ctxt: TTSContext = None, **kwargs): self.add_metric({"metric_type": "tts.synth.start"}) sentence_hash = hash_sentence(sentence) - # parse requested language for this TTS request + # parse kwargs for this TTS request ctxt = ctxt or self._get_ctxt(kwargs) cache = ctxt.get_cache(self.audio_ext, self.config) @@ -502,16 +545,8 @@ def synth(self, sentence, ctxt: TTSContext = None, **kwargs): # synth + cache audio = cache.define_audio_file(sentence_hash) - - # filter kwargs per plugin, different plugins expose different kwargs - # ovos -> lang + voice optional kwargs - # neon-core -> message - kwargs = {k: v for k, v in kwargs.items() - if k in inspect.signature(self.get_tts).parameters - and k not in ["sentence", "wav_file"]} - - # finally do the TTS synth - audio.path, phonemes = self.get_tts(sentence, str(audio), **kwargs) + audio.path, phonemes = self.get_tts(sentence, str(audio), + **ctxt.synth_kwargs) self.add_metric({"metric_type": "tts.synth.finished"}) # cache sentence + phonemes @@ -588,7 +623,6 @@ def __del__(self): self.shutdown() # below code is all deprecated and marked for removal in next stable release - # TODO - update version number in warnings @property @deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code", "0.1.0") @@ -1004,3 +1038,4 @@ def __new__(self, *args, **kwargs): return PlaybackThread(*args, **kwargs) except ImportError: raise ImportError("please install ovos-audio for playback handling") + diff --git a/test/unittests/test_tts.py b/test/unittests/test_tts.py index efa01bad..1a2fa263 100644 --- a/test/unittests/test_tts.py +++ b/test/unittests/test_tts.py @@ -267,16 +267,10 @@ def test_create(self, get_class): class TestTTSContext(unittest.TestCase): - def test_tts_context_init(self): - session_mock = MagicMock() - tts_context = TTSContext(session=session_mock) - self.assertEqual(tts_context.session, session_mock) - self.assertEqual(tts_context.lang, session_mock.lang) @patch("ovos_plugin_manager.templates.tts.TextToSpeechCache", autospec=True) def test_tts_context_get_cache(self, cache_mock): - session_mock = MagicMock() - tts_context = TTSContext(session=session_mock) + tts_context = TTSContext("plug", "voice", "lang") result = tts_context.get_cache() @@ -286,13 +280,13 @@ def test_tts_context_get_cache(self, cache_mock): class TestTTSCache(unittest.TestCase): def setUp(self): - self.tts_mock = TTS(lang="en-us", config={"some_config_key": "some_config_value"}) + self.tts_mock = TTS(config={"some_config_key": "some_config_value"}) self.tts_mock.stopwatch = MagicMock() self.tts_mock.queue = MagicMock() self.tts_mock.playback = MagicMock() @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") - @patch("ovos_plugin_manager.templates.tts.TTSContext", autospec=True) + @patch("ovos_plugin_manager.templates.tts.TTSContext") def test_tts_synth(self, tts_context_mock, hash_sentence_mock): tts_context_mock.get_cache.return_value = MagicMock() tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" @@ -314,9 +308,28 @@ def test_tts_synth_cache_enabled(self, hash_sentence_mock): tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value} sentence = "Hello world!" + self.tts_mock.enable_cache = True result = self.tts_mock.synth(sentence, tts_context_mock) tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) self.assertIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences) + + @patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash") + def test_tts_synth_cache_disabled(self, hash_sentence_mock): + tts_context_mock = MagicMock() + tts_context_mock.tts_id = "fake_tts_id" + tts_context_mock.get_cache.return_value = MagicMock() + tts_context_mock.get_cache.return_value.cached_sentences = {} + tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path" + tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value} + + sentence = "Hello world!" + self.tts_mock.enable_cache = False + result = self.tts_mock.synth(sentence, tts_context_mock) + + tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config) + tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash") + self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None)) + self.assertNotIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences)