[TTS] Replace IPA lambda arguments with locale string (#5298)
* [TTS] Replace IPA lambda arguments with locale string
* [TTS] Add locale validation
* Fixed typos
* Return punctuation as sorted list

Signed-off-by: Ryan <[email protected]>
rlangman authored Nov 3, 2022
1 parent adb1771 commit 3beba51
Showing 9 changed files with 212 additions and 39 deletions.
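In short, IPAG2P and IPATokenizer now take a locale string ("en-US", "de-DE", or "es-ES") in place of tokenization and preprocessing lambdas. A minimal sketch of the new call pattern, mirroring the tests in this commit (the dictionary path is hypothetical):

```python
from nemo_text_processing.g2p.modules import IPAG2P
from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import IPATokenizer

# The locale now selects the word tokenizer, text preprocessing, and default
# punctuation internally; previously these were passed in as lambdas.
g2p = IPAG2P(phoneme_dict="ipa_dict_de.txt", locale="de-DE")  # hypothetical dict path
tokenizer = IPATokenizer(g2p=g2p, locale="de-DE")

token_ids = tokenizer.encode("Hallo welt")
```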
2 changes: 0 additions & 2 deletions nemo/collections/common/tokenizers/text_to_speech/__init__.py
@@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.common.tokenizers.text_to_speech.tokenizer_wrapper import TextToSpeechTokenizer
32 changes: 26 additions & 6 deletions nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py
@@ -14,6 +14,9 @@


# fmt: off

SUPPORTED_LOCALES = ["en-US", "de-DE", "es-ES"]

DEFAULT_PUNCTUATION = (
',', '.', '!', '?', '-',
':', ';', '/', '"', '(',
@@ -62,31 +65,48 @@
'ʊ', 'ʌ', 'ʒ', '̃', 'θ'
)
}

# fmt: on


def validate_locale(locale):
if locale not in SUPPORTED_LOCALES:
raise ValueError(f"Unsupported locale '{locale}'. " f"Supported locales {SUPPORTED_LOCALES}")


def get_grapheme_character_set(locale):
if locale not in GRAPHEME_CHARACTER_SETS:
raise ValueError(f"Grapheme character set not found for locale {locale}")
raise ValueError(
f"Grapheme character set not found for locale '{locale}'. "
f"Supported locales {GRAPHEME_CHARACTER_SETS.keys()}"
)
char_set = set(GRAPHEME_CHARACTER_SETS[locale])
return char_set


def get_ipa_character_set(locale):
if locale not in IPA_CHARACTER_SETS:
raise ValueError(f"IPA character set not found for locale {locale}")
raise ValueError(
f"IPA character set not found for locale '{locale}'. " f"Supported locales {IPA_CHARACTER_SETS.keys()}"
)
char_set = set(IPA_CHARACTER_SETS[locale])
return char_set


def get_ipa_punctuation_list(locale):
punct_list = list(DEFAULT_PUNCTUATION)
if locale is None:
return sorted(list(DEFAULT_PUNCTUATION))

validate_locale(locale)

punct_set = set(DEFAULT_PUNCTUATION)
if locale in ["de-DE", "es-ES"]:
# https://en.wikipedia.org/wiki/Guillemet#Uses
punct_list.extend(['«', '»', '‹', '›'])
punct_set.update(['«', '»', '‹', '›'])
if locale == "de-DE":
punct_list.extend(['„', '“'])
punct_set.update(['„', '“'])
elif locale == "es-ES":
punct_list.extend(['¿', '¡'])
punct_set.update(['¿', '¡'])

punct_list = sorted(list(punct_set))
return punct_list
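For illustration, a short sketch of how the helpers above behave; the assertions follow directly from the code in this file:

```python
from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import (
    get_ipa_punctuation_list,
    validate_locale,
)

validate_locale("de-DE")  # passes silently; unknown locales raise ValueError

# Sorted union of DEFAULT_PUNCTUATION, the guillemets '«', '»', '‹', '›',
# and the German quotation marks '„', '“'.
punct = get_ipa_punctuation_list("de-DE")
assert punct == sorted(punct) and '«' in punct and '„' in punct

# With locale=None, only the sorted default punctuation is returned.
default_punct = get_ipa_punctuation_list(None)
```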
40 changes: 21 additions & 19 deletions nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py
@@ -23,10 +23,11 @@
chinese_text_preprocessing,
english_text_preprocessing,
german_text_preprocessing,
ipa_text_preprocessing,
spanish_text_preprocessing,
)

from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import get_ipa_punctuation_list
from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import get_ipa_punctuation_list, validate_locale
from nemo.utils import logging
from nemo.utils.decorators import experimental

@@ -500,17 +501,10 @@ def set_phone_prob(self, prob):

@experimental
class IPATokenizer(BaseTokenizer):
# fmt: off
PUNCT_LIST = ( # Derived from LJSpeech and "/" additionally
',', '.', '!', '?', '-',
':', ';', '/', '"', '(',
')', '[', ']', '{', '}',
)
# fmt: on

def __init__(
self,
g2p,
locale="en-US",
punct=True,
non_default_punct_list=None,
*,
@@ -521,11 +515,13 @@ def __init__(
sep='|', # To be able to distinguish between symbols
add_blank_at=None,
pad_with_space=False,
text_preprocessing_func=lambda text: english_text_preprocessing(text, lower=False),
):
"""General-purpose IPA-based tokenizer.
Args:
g2p: Grapheme to phoneme module, should be IPAG2P or some subclass thereof.
locale: Locale used to determine default text processing logic and punctuation.
Supports ["en-US", "de-DE", "es-ES"]. Defaults to "en-US".
Specify None if implementing custom logic for a new locale.
punct: Whether to reserve grapheme for basic punctuation or not.
non_default_punct_list: List of punctuation marks which will be used instead of the defaults, if any.
space: Space token as string.
@@ -534,12 +530,8 @@ def __init__(
oov: OOV token as string.
sep: Separation token as string.
add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None),
if None then no blank in labels.
pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer.
Basically, it replaces all non-unicode characters with unicode ones.
Note that the lower() function shouldn't be applied here, in case the text contains phonemes (it will be handled by g2p).
Defaults to English text preprocessing.
"""
if not hasattr(g2p, "symbols"):
logging.error(
@@ -549,6 +541,9 @@ def __init__(
)
raise ValueError("G2P modules passed into the IPATokenizer must have `symbols` defined.")

if locale is not None:
validate_locale(locale)

self.phoneme_probability = None
if hasattr(g2p, "phoneme_probability"):
self.phoneme_probability = g2p.phoneme_probability
@@ -561,8 +556,11 @@ def __init__(

if punct:
if non_default_punct_list is not None:
self.PUNCT_LIST = non_default_punct_list
tokens.update(self.PUNCT_LIST)
self.punct_list = non_default_punct_list
else:
self.punct_list = get_ipa_punctuation_list(locale)

tokens.update(self.punct_list)

# Sort to ensure that vocab is in the same order every time
tokens = sorted(list(tokens))
@@ -580,9 +578,13 @@ def __init__(
self.punct = punct
self.pad_with_space = pad_with_space

self.text_preprocessing_func = text_preprocessing_func
self.g2p = g2p

if locale == "en-US":
self.text_preprocessing_func = lambda text: english_text_preprocessing(text, lower=False)
else:
self.text_preprocessing_func = ipa_text_preprocessing

def encode(self, text):
"""See base class for more information."""

@@ -608,7 +610,7 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None):
elif p in tokens:
# Add next phoneme or char (if chars=True)
ps.append(p)
elif (p in self.PUNCT_LIST) and self.punct:
elif (p in self.punct_list) and self.punct:
# Add punct
ps.append(p)
elif p != space:
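As a usage note for the constructor above, a hedged sketch of overriding the locale-derived punctuation via non_default_punct_list (the phoneme dict contents are illustrative):

```python
from nemo_text_processing.g2p.modules import IPAG2P
from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import IPATokenizer

g2p = IPAG2P(phoneme_dict={"HELLO": ["həˈɫoʊ"]}, locale="en-US")

# punct=True reserves punctuation tokens; non_default_punct_list replaces the
# list that get_ipa_punctuation_list(locale) would otherwise supply.
tokenizer = IPATokenizer(
    g2p=g2p,
    locale="en-US",
    punct=True,
    non_default_punct_list=[',', '.', '!', '?'],
    pad_with_space=True,
)
```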
7 changes: 6 additions & 1 deletion nemo_text_processing/g2p/data/data_utils.py
@@ -115,7 +115,8 @@ def english_text_preprocessing(text, lower=True):

def _word_tokenize(words):
"""
Convert text (str) to List[Tuple[Union[str, List[str]], bool]] where every tuple denotes word representation and
flag whether to leave unchanged or not.
Word can be one of: a valid English word, any substring from | to | (an unchangeable word), or punctuation marks.
This function expects that unchangeable words are carefully divided by spaces (e.g. HH AH L OW).
Unchangeable words will be split by spaces and represented as List[str]; other cases are represented as str.
@@ -146,6 +147,10 @@ def ipa_word_tokenize(text):
return _word_tokenize(words)


def ipa_text_preprocessing(text):
return text.lower()


def german_text_preprocessing(text):
return text.lower()

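To make the tuple format above concrete, a hypothetical illustration; the exact flag values are an assumption based on the docstring (True meaning "leave unchanged"):

```python
from nemo_text_processing.g2p.data.data_utils import ipa_word_tokenize

# "|HH AH L OW|" is an unchangeable, pre-tokenized word: it comes back as a
# List[str], while ordinary words and punctuation come back as str.
words = ipa_word_tokenize("hello |HH AH L OW| world!")
# Assumed result shape:
# [('hello', False), (['HH', 'AH', 'L', 'OW'], True), ('world', False), ('!', False)]
```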
26 changes: 17 additions & 9 deletions nemo_text_processing/g2p/modules.py
@@ -23,8 +23,9 @@

import nltk
import torch
from nemo_text_processing.g2p.data.data_utils import english_word_tokenize
from nemo_text_processing.g2p.data.data_utils import english_word_tokenize, ipa_word_tokenize

from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import validate_locale
from nemo.utils import logging
from nemo.utils.decorators import experimental
from nemo.utils.get_rank import is_global_rank_zero
@@ -263,7 +264,7 @@ class IPAG2P(BaseG2p):
def __init__(
self,
phoneme_dict: Union[str, pathlib.Path, dict],
word_tokenize_func: Callable[[str], List[Tuple[Union[str, List[str]], bool]]] = english_word_tokenize,
locale: str = "en-US",
apply_to_oov_word: Optional[Callable[[str], str]] = None,
ignore_ambiguous_words: bool = True,
heteronyms: Optional[Union[str, pathlib.Path, List[str]]] = None,
@@ -281,13 +282,9 @@ def __init__(
phoneme_dict (str, Path, Dict): Path to a file in CMUdict format or an IPA dict object with CMUdict-like entries.
a dictionary file example: scripts/tts_dataset_files/ipa_cmudict-0.7b_nv22.06.txt;
a dictionary object example: {..., "WIRE": [["ˈ", "w", "a", "ɪ", "ɚ"], ["ˈ", "w", "a", "ɪ", "ɹ"]], ...}
word_tokenize_func: Function for tokenizing text to words.
It has to return List[Tuple[Union[str, List[str]], bool]] where every tuple denotes word
representation and flag whether to leave unchanged or not.
It is expected that unchangeable word representation will be represented as List[str], other
cases are represented as str.
It is useful to mark word as unchangeable which is already in phoneme representation.
Defaults to the English word tokenizer.
locale: Locale used to determine default tokenization logic.
Supports ["en-US", "de-DE", "es-ES"]. Defaults to "en-US".
Specify None if implementing custom logic for a new locale.
apply_to_oov_word: Function that will be applied to any word that is out of phoneme_dict (OOV).
ignore_ambiguous_words: Whether to skip handling words via phoneme_dict when they have ambiguous phoneme sequences.
Defaults to True.
@@ -307,6 +304,9 @@ def __init__(
self.phoneme_probability = phoneme_probability
self._rng = random.Random()

if locale is not None:
validate_locale(locale)

if not use_chars and self.phoneme_probability is not None:
self.use_chars = True
logging.warning(
@@ -348,6 +348,14 @@ def __init__(
"you may see unexpected deletions in your input."
)

# word_tokenize_func returns a List[Tuple[Union[str, List[str]], bool]] where every tuple denotes
# a word representation (a word or a list of tokens) and a flag indicating whether to process the word or
# leave it unchanged.
if locale == "en-US":
word_tokenize_func = english_word_tokenize
else:
word_tokenize_func = ipa_word_tokenize

super().__init__(
phoneme_dict=self.phoneme_dict,
word_tokenize_func=word_tokenize_func,
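A minimal sketch of constructing IPAG2P with an in-memory dictionary, as the phoneme_dict docstring above allows; entries are borrowed from the tests below, and calling the module directly assumes the BaseG2p call interface:

```python
from nemo_text_processing.g2p.modules import IPAG2P

# CMUdict-style in-memory IPA dictionary, as documented for phoneme_dict.
phoneme_dict = {"HELLO": ["həˈɫoʊ"], "WORLD": ["ˈwɝɫd"]}

# locale="en-US" selects english_word_tokenize; other supported locales
# (and locale=None, for custom logic) fall back to ipa_word_tokenize.
g2p = IPAG2P(phoneme_dict=phoneme_dict, locale="en-US")
phonemes = g2p("Hello world")  # assumes BaseG2p.__call__
```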
@@ -13,15 +13,27 @@
# limitations under the License.

import pytest
from nemo_text_processing.g2p.modules import IPAG2P

from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import (
EnglishCharsTokenizer,
GermanCharsTokenizer,
IPATokenizer,
SpanishCharsTokenizer,
)


class TestTTSTokenizers:
PHONEME_DICT_DE = {
"HALLO": ["hˈaloː"],
"WELT": ["vˈɛlt"],
}
PHONEME_DICT_EN = {"HELLO": ["həˈɫoʊ"], "WORLD": ["ˈwɝɫd"], "CAFE": ["kəˈfeɪ"]}
PHONEME_DICT_ES = {
"BUENOS": ["bwˈenos"],
"DÍAS": ["dˈias"],
}

@staticmethod
def _parse_text(tokenizer, text):
tokens = tokenizer.encode(text)
@@ -88,3 +100,60 @@ def test_spanish_chars_tokenizer(self):

assert chars == expected_output
assert len(tokens) == len(input_text)

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_ipa_tokenizer(self):
input_text = "Hello world!"
expected_output = " həˈɫoʊ ˈwɝɫd! "

g2p = IPAG2P(phoneme_dict=self.PHONEME_DICT_EN)

tokenizer = IPATokenizer(g2p=g2p, locale=None, pad_with_space=True)
chars, tokens = self._parse_text(tokenizer, input_text)

assert chars == expected_output

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_ipa_tokenizer_unsupported_locale(self):
g2p = IPAG2P(phoneme_dict=self.PHONEME_DICT_EN)
with pytest.raises(ValueError, match="Unsupported locale"):
IPATokenizer(g2p=g2p, locale="asdf")

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_ipa_tokenizer_de_de(self):
input_text = "Hallo welt"
expected_output = "hˈaloː vˈɛlt"

g2p = IPAG2P(phoneme_dict=self.PHONEME_DICT_DE, locale="de-DE")
tokenizer = IPATokenizer(g2p=g2p, locale="de-DE")
chars, tokens = self._parse_text(tokenizer, input_text)

assert chars == expected_output

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_ipa_tokenizer_en_us(self):
input_text = "Hello café."
expected_output = "həˈɫoʊ kəˈfeɪ."
g2p = IPAG2P(phoneme_dict=self.PHONEME_DICT_EN)

tokenizer = IPATokenizer(g2p=g2p, locale="en-US")
tokenizer.tokens.extend("CAFE")
chars, tokens = self._parse_text(tokenizer, input_text)

assert chars == expected_output

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_ipa_tokenizer_es_es(self):
input_text = "¡Buenos días!"
expected_output = "¡bwˈenos dˈias!"

g2p = IPAG2P(phoneme_dict=self.PHONEME_DICT_ES, locale="es-ES")
tokenizer = IPATokenizer(g2p=g2p, locale="es-ES")
chars, tokens = self._parse_text(tokenizer, input_text)

assert chars == expected_output
2 changes: 2 additions & 0 deletions tests/nemo_text_processing/g2p/phoneme_dict/test_dict_de.txt
@@ -0,0 +1,2 @@
HALLO hˈaloː
WELT vˈɛlt
2 changes: 2 additions & 0 deletions tests/nemo_text_processing/g2p/phoneme_dict/test_dict_es.txt
@@ -0,0 +1,2 @@
HOLA ˈola
MUNDO mˈundo
