Skip to content

Commit

Permalink
voice from Session
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Apr 20, 2024
1 parent 3bcfaf9 commit e8b6066
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 37 deletions.
91 changes: 63 additions & 28 deletions ovos_plugin_manager/templates/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ovos_utils.file_utils import get_cache_directory
from ovos_utils.file_utils import resolve_resource_file
from ovos_utils.lang.visimes import VISIMES
from ovos_utils.log import LOG, deprecated
from ovos_utils.log import LOG, deprecated, log_deprecation
from ovos_utils.metrics import Stopwatch
from ovos_utils.process_utils import RuntimeRequirements

Expand All @@ -37,10 +37,11 @@
class TTSContext:
_caches = {}

def __init__(self, plugin_id: str, lang: str, voice: str):
def __init__(self, plugin_id: str, lang: str, voice: str, synth_kwargs: dict = None):
self.plugin_id = plugin_id
self.lang = lang
self.voice = voice
self.synth_kwargs = synth_kwargs or {}

@property
def tts_id(self):
Expand Down Expand Up @@ -90,6 +91,9 @@ class TTS:

def __init__(self, lang=None, config=None, validator=None,
audio_ext='wav', phonetic_spelling=True, ssml_tags=None):
if lang is not None:
log_deprecation("lang argument for TTS has been deprecated! it will be ignored, "
"pass lang to get_tts directly instead")
self.log_timestamps = False

self.config = config or get_plugin_config(config, "tts")
Expand Down Expand Up @@ -117,6 +121,18 @@ def __init__(self, lang=None, config=None, validator=None,
# only present for backwards compat reasons
self.bus = None

self._plugin_id = "" # the plugin name

@property
def plugin_id(self) -> str:
    """Return the plugin identifier this TTS instance was registered under.

    The identifier is resolved lazily. While ``self._plugin_id`` is falsy,
    the mapping returned by ``find_tts_plugins()`` is scanned and the id of
    the first class this object is an instance of is stored in
    ``self._plugin_id``. When nothing matches, the stored value is left
    untouched and returned as is.
    """
    # already resolved (or explicitly assigned): nothing to look up
    if self._plugin_id:
        return self._plugin_id

    # function-scope import, kept exactly where the original had it
    # NOTE(review): presumably avoids a circular import - confirm
    from ovos_plugin_manager.tts import find_tts_plugins

    # lazy generator: stops at the first matching class, like the
    # original loop + break
    resolved = next(
        (name for name, plugin_class in find_tts_plugins().items()
         if isinstance(self, plugin_class)),
        None,
    )
    if resolved is not None:
        self._plugin_id = resolved
    return self._plugin_id

# methods for individual plugins to override
@classproperty
def runtime_requirements(self):
Expand Down Expand Up @@ -183,11 +199,18 @@ def handle_metric(self, metadata=None):
# properties that reflect bus message session
@property
def voice(self):
voice = self.config.get("voice") or "default"
message = dig_for_message()
if message:
# TODO - get from tts_prefs in session
pass
return self.config.get("voice") or "default"
sess = SessionManager.get()
if sess.tts_preferences["plugin_id"] == self.plugin_id:
v = sess.tts_preferences["config"].get("voice")
if v:
voice = v
else:
# we got a request for a TTS plugin that isn't loaded!
LOG.error("ignoring TTS preferences in Session, plugin does not match!")
return voice

@voice.setter
def voice(self, val):
Expand Down Expand Up @@ -432,18 +455,38 @@ def _get_visemes(self, phonemes, sentence, ctxt):
LOG.debug(f"no mouth movements available! unknown visemes for {sentence}")
return viseme

def _get_ctxt(self, kwargs=None):
kwargs = kwargs or {}
def _get_ctxt(self, kwargs=None) -> TTSContext:
"""create a TTSContext from arbitrary kwargs passed to synth/execute methods
takes preferences from Session into account if a message is present
"""
# get request specific synth params
kwargs = kwargs or {}
message = kwargs.get("message") or dig_for_message()
lang = kwargs.get("lang")
voice = kwargs.get("voice")
if message and not lang:
sess = SessionManager.get(message)
lang = lang or sess.lang
return TTSContext(plugin_id=self.tts_name, # TODO this should be the OPM name at some point
lang=lang or self.lang,
voice=voice or self.voice)

# update kwargs from session
if message:
sess = SessionManager.get()
if sess.tts_preferences["plugin_id"] == self.plugin_id:
for k, v in sess.tts_preferences["config"].items():
if k not in kwargs:
kwargs[k] = v
else:
# we got a request for a TTS plugin that isn't loaded!
LOG.error("ignoring TTS preferences in Session, plugin does not match!")

if "lang" not in kwargs:
kwargs["lang"] = sess.lang

# filter kwargs accepted by this specific plugin
kwargs = {k: v for k, v in kwargs.items()
if k in inspect.signature(self.get_tts).parameters
and k not in ["sentence", "wav_file"]}

LOG.debug(f"TTS kwargs: {kwargs}")
return TTSContext(plugin_id=self.plugin_id,
lang=kwargs.get("lang") or self.lang,
voice=kwargs.get("voice") or self.voice,
synth_kwargs=kwargs)

def _execute(self, sentence, ident, listen, preprocess=True, **kwargs):
if preprocess:
Expand All @@ -470,7 +513,7 @@ def _execute(self, sentence, ident, listen, preprocess=True, **kwargs):
# synth -> queue for playback
for sentence, l in chunks:
# load from cache or synth + cache
audio_file, phonemes = self.synth(sentence, ctxt, **kwargs)
audio_file, phonemes = self.synth(sentence, ctxt)

# get visemes/mouth movements
viseme = self._get_visemes(phonemes, sentence, ctxt)
Expand All @@ -490,7 +533,7 @@ def synth(self, sentence, ctxt: TTSContext = None, **kwargs):
self.add_metric({"metric_type": "tts.synth.start"})
sentence_hash = hash_sentence(sentence)

# parse requested language for this TTS request
# parse kwargs for this TTS request
ctxt = ctxt or self._get_ctxt(kwargs)
cache = ctxt.get_cache(self.audio_ext, self.config)

Expand All @@ -502,16 +545,8 @@ def synth(self, sentence, ctxt: TTSContext = None, **kwargs):

# synth + cache
audio = cache.define_audio_file(sentence_hash)

# filter kwargs per plugin, different plugins expose different kwargs
# ovos -> lang + voice optional kwargs
# neon-core -> message
kwargs = {k: v for k, v in kwargs.items()
if k in inspect.signature(self.get_tts).parameters
and k not in ["sentence", "wav_file"]}

# finally do the TTS synth
audio.path, phonemes = self.get_tts(sentence, str(audio), **kwargs)
audio.path, phonemes = self.get_tts(sentence, str(audio),
**ctxt.synth_kwargs)
self.add_metric({"metric_type": "tts.synth.finished"})

# cache sentence + phonemes
Expand Down Expand Up @@ -588,7 +623,6 @@ def __del__(self):
self.shutdown()

# below code is all deprecated and marked for removal in next stable release
# TODO - update version number in warnings
@property
@deprecated("self.enclosure has been deprecated, use EnclosureAPI directly decoupled from the plugin code",
"0.1.0")
Expand Down Expand Up @@ -1004,3 +1038,4 @@ def __new__(self, *args, **kwargs):
return PlaybackThread(*args, **kwargs)
except ImportError:
raise ImportError("please install ovos-audio for playback handling")

31 changes: 22 additions & 9 deletions test/unittests/test_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,10 @@ def test_create(self, get_class):


class TestTTSContext(unittest.TestCase):
def test_tts_context_init(self):
session_mock = MagicMock()
tts_context = TTSContext(session=session_mock)
self.assertEqual(tts_context.session, session_mock)
self.assertEqual(tts_context.lang, session_mock.lang)

@patch("ovos_plugin_manager.templates.tts.TextToSpeechCache", autospec=True)
def test_tts_context_get_cache(self, cache_mock):
session_mock = MagicMock()
tts_context = TTSContext(session=session_mock)
tts_context = TTSContext("plug", "voice", "lang")

result = tts_context.get_cache()

Expand All @@ -286,13 +280,13 @@ def test_tts_context_get_cache(self, cache_mock):

class TestTTSCache(unittest.TestCase):
def setUp(self):
self.tts_mock = TTS(lang="en-us", config={"some_config_key": "some_config_value"})
self.tts_mock = TTS(config={"some_config_key": "some_config_value"})
self.tts_mock.stopwatch = MagicMock()
self.tts_mock.queue = MagicMock()
self.tts_mock.playback = MagicMock()

@patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash")
@patch("ovos_plugin_manager.templates.tts.TTSContext", autospec=True)
@patch("ovos_plugin_manager.templates.tts.TTSContext")
def test_tts_synth(self, tts_context_mock, hash_sentence_mock):
tts_context_mock.get_cache.return_value = MagicMock()
tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path"
Expand All @@ -314,9 +308,28 @@ def test_tts_synth_cache_enabled(self, hash_sentence_mock):
tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value}

sentence = "Hello world!"
self.tts_mock.enable_cache = True
result = self.tts_mock.synth(sentence, tts_context_mock)

tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config)
tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash")
self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None))
self.assertIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences)

@patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash")
def test_tts_synth_cache_disabled(self, hash_sentence_mock):
    # With enable_cache switched off, synth() must still request the cache
    # and an audio file for the hashed sentence, but must not record the
    # sentence hash in cached_sentences.

    # stand-in for the cache object handed out by the context
    cache = MagicMock()
    cache.cached_sentences = {}
    cache.define_audio_file.return_value.path = "fake_audio_path"

    # context that always returns the cache above
    tts_context_mock = MagicMock()
    tts_context_mock.tts_id = "fake_tts_id"
    tts_context_mock.get_cache.return_value = cache
    tts_context_mock._caches = {tts_context_mock.tts_id: cache}

    self.tts_mock.enable_cache = False
    result = self.tts_mock.synth("Hello world!", tts_context_mock)

    tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config)
    cache.define_audio_file.assert_called_once_with("fake_hash")
    expected_audio = cache.define_audio_file.return_value
    self.assertEqual(result, (expected_audio, None))
    self.assertNotIn("fake_hash", cache.cached_sentences)

0 comments on commit e8b6066

Please sign in to comment.