[TTS] add typehint for g2pw #2390

Merged
1 commit merged on Sep 16, 2022
2 changes: 1 addition & 1 deletion paddlespeech/t2s/frontend/g2pw/__init__.py
@@ -1 +1 @@
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
from .onnx_api import G2PWOnnxConverter
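
The change above only swaps an absolute import for a relative one inside the package; the public entry point is untouched. A quick sanity check (assuming paddlespeech is installed):

from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter  # still resolves after the change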
66 changes: 34 additions & 32 deletions paddlespeech/t2s/frontend/g2pw/dataset.py
@@ -15,6 +15,10 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np

from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
@@ -23,22 +27,17 @@


def prepare_onnx_input(tokenizer,
labels,
char2phonemes,
chars,
texts,
query_ids,
phonemes=None,
pos_tags=None,
use_mask=False,
use_char_phoneme=False,
use_pos=False,
window_size=None,
max_len=512):
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool=False,
window_size: int=None,
max_len: int=512) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(window_size,
texts, query_ids)

truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids)
input_ids = []
token_type_ids = []
attention_masks = []
@@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer,
query_id = (truncated_query_ids if window_size else query_ids)[idx]

try:
tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
tokens, text2token, token2text = tokenize_and_map(
tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}

text, query_id, tokens, text2token, token2text = _truncate(
max_len, text, query_id, tokens, text2token, token2text)
max_len=max_len,
text=text,
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)

processed_tokens = ['[CLS]'] + tokens + ['[SEP]']

@@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer,
return outputs
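
With the annotated, keyword-style signature, a call to prepare_onnx_input reads roughly as below. This is a sketch only: tokenizer, labels, char2phonemes, and chars stand in for the objects G2PWOnnxConverter builds at init time, and the text/query values are made up.

onnx_input = prepare_onnx_input(
    tokenizer=tokenizer,          # a BertTokenizer-style tokenizer (placeholder)
    labels=labels,                # phoneme labels from get_phoneme_labels()
    char2phonemes=char2phonemes,  # char -> list of label indices
    chars=chars,                  # sorted polyphonic characters
    texts=['下雨天'],              # one text per query
    query_ids=[0],                # index of the character to disambiguate
    use_mask=True,
    window_size=None,
    max_len=512)                  # returns a Dict[str, np.array] ready for session.run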


def _truncate_texts(window_size, texts, query_ids):
def _truncate_texts(window_size: int, texts: List[str],
query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
@@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids):
return truncated_texts, truncated_query_ids


def _truncate(max_len, text, query_id, tokens, text2token, token2text):
def _truncate(max_len: int,
text: str,
query_id: int,
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
@@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])


def prepare_data(sent_path, lb_path=None):
raw_texts = open(sent_path).read().rstrip().split('\n')
query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
if lb_path is None:
return texts, query_ids
else:
phonemes = open(lb_path).read().rstrip().split('\n')
return texts, query_ids, phonemes


def get_phoneme_labels(polyphonic_chars):
def get_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
@@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
return labels, char2phonemes


def get_char_phoneme_labels(polyphonic_chars):
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
char2phonemes = {}
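
As a toy illustration of the annotated return types (the values are made up; the real pairs are loaded from the model's polyphonic-character table):

pairs = [['行', 'hang2'], ['行', 'xing2'], ['乐', 'le4']]
labels, char2phonemes = get_phoneme_labels(polyphonic_chars=pairs)
# labels        -> ['hang2', 'le4', 'xing2']   (sorted, deduplicated phonemes)
# char2phonemes -> {'行': [0, 2], '乐': [1]}    (indices into labels)

# get_char_phoneme_labels keys the labels by 'char phoneme' pairs instead:
# labels -> ['乐 le4', '行 hang2', '行 xing2']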
50 changes: 30 additions & 20 deletions paddlespeech/t2s/frontend/g2pw/onnx_api.py
@@ -17,6 +17,10 @@
"""
import json
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np
import onnxruntime
@@ -37,7 +41,8 @@
model_version = '1.1'


def predict(session, onnx_input, labels):
def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
all_confidences = []
probs = session.run([], {
@@ -61,10 +66,10 @@ def predict(session, onnx_input, labels):

class G2PWOnnxConverter:
def __init__(self,
model_dir=MODEL_HOME,
style='bopomofo',
model_source=None,
enable_non_tradional_chinese=False):
model_dir: os.PathLike=MODEL_HOME,
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
uncompress_path = download_and_decompress(
g2pw_onnx_models['G2PWModel'][model_version], model_dir)

@@ -76,7 +81,8 @@ def __init__(self,
os.path.join(uncompress_path, 'g2pW.onnx'),
sess_options=sess_options)
self.config = load_config(
os.path.join(uncompress_path, 'config.py'), use_default=True)
config_path=os.path.join(uncompress_path, 'config.py'),
use_default=True)

self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
@@ -103,9 +109,9 @@ def __init__(self,
.strip().split('\n')
]
self.labels, self.char2phonemes = get_char_phoneme_labels(
self.polyphonic_chars
polyphonic_chars=self.polyphonic_chars
) if self.config.use_char_phoneme else get_phoneme_labels(
self.polyphonic_chars)
polyphonic_chars=self.polyphonic_chars)

self.chars = sorted(list(self.char2phonemes.keys()))

@@ -146,7 +152,7 @@ def __init__(self,
if self.enable_opencc:
self.cc = OpenCC('s2tw')

def _convert_bopomofo_to_pinyin(self, bopomofo):
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
assert tone in '12345'
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
@@ -156,7 +162,7 @@ def _convert_bopomofo_to_pinyin(self, bopomofo):
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
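
The helper splits off the trailing tone digit and looks up the remaining bopomofo body in bopomofo_convert_dict; for example (a hypothetical dictionary entry; unknown bodies print a warning and return None):

# self._convert_bopomofo_to_pinyin('ㄇㄚ1') -> 'ma1'   (if the dict maps 'ㄇㄚ' -> 'ma')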

def __call__(self, sentences):
def __call__(self, sentences: List[str]) -> List[List[str]]:
if isinstance(sentences, str):
sentences = [sentences]

@@ -169,23 +175,25 @@ def __call__(self, sentences):
sentences = translated_sentences

texts, query_ids, sent_ids, partial_results = self._prepare_data(
sentences)
sentences=sentences)
if len(texts) == 0:
# none of the sentences contain polyphonic characters
return partial_results

onnx_input = prepare_onnx_input(
self.tokenizer,
self.labels,
self.char2phonemes,
self.chars,
texts,
query_ids,
tokenizer=self.tokenizer,
labels=self.labels,
char2phonemes=self.char2phonemes,
chars=self.chars,
texts=texts,
query_ids=query_ids,
use_mask=self.config.use_mask,
use_char_phoneme=self.config.use_char_phoneme,
window_size=None)

preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
preds, confidences = predict(
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]

@@ -195,7 +203,9 @@ def __call__(self, sentences):

return results

def _prepare_data(self, sentences):
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works better for Simplified Chinese than for Traditional Chinese
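
Putting the annotated interface together, typical use looks like the sketch below (it assumes the packaged G2PWModel can be downloaded into MODEL_HOME on first run):

from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter

g2pw = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
results = g2pw(['上校退了休'])  # List[str] in -> List[List[str]] out
# results[0] holds one pinyin string per input character; polyphonic
# characters come from the ONNX model, the rest fall back to pypinyin.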
11 changes: 6 additions & 5 deletions paddlespeech/t2s/frontend/g2pw/utils.py
@@ -15,10 +15,11 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import re


def wordize_and_map(text):
def wordize_and_map(text: str):
words = []
index_map_from_text_to_word = []
index_map_from_word_to_text = []
@@ -54,8 +55,8 @@ def wordize_and_map(text):
return words, index_map_from_text_to_word, index_map_from_word_to_text


def tokenize_and_map(tokenizer, text):
words, text2word, word2text = wordize_and_map(text)
def tokenize_and_map(tokenizer, text: str):
words, text2word, word2text = wordize_and_map(text=text)

tokens = []
index_map_from_token_to_text = []
@@ -82,7 +83,7 @@ def tokenize_and_map(tokenizer, text):
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
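
The two index maps are positional inverses: text2token[i] is the token index covering character i, and token2text[j] is the (start, end) character span of token j. A sketch of the expected shape (the exact split depends on the tokenizer, so the values are illustrative):

# tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text='abc 的')
# tokens     -> ['abc', '的']        (if 'abc' survives as a single piece)
# text2token -> [0, 0, 0, None, 1]   (whitespace maps to no token)
# token2text -> [(0, 3), (4, 5)]     (end-exclusive character spans)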


def _load_config(config_path):
def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
config = importlib.util.module_from_spec(spec)
@@ -130,7 +131,7 @@ def _load_config(config_path):
}


def load_config(config_path, use_default=False):
def load_config(config_path: os.PathLike, use_default: bool=False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():
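
load_config imports the downloaded config.py as a module; with use_default=True, attributes missing from the file are backfilled from default_config_dict. A minimal sketch (the path is illustrative):

config = load_config(config_path='G2PWModel/config.py', use_default=True)
print(config.model_source, config.use_mask)  # file values, defaults fill any gaps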