Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TTS] Remove unused TTS eval function #5605

Merged
merged 1 commit into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions nemo/collections/tts/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@
import torch
from numba import jit, prange
from numpy import ndarray
from pesq import pesq
from pystoi import stoi

from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens
from nemo.utils import logging
Expand Down Expand Up @@ -480,38 +478,6 @@ def remove(conv_list):
return new_conv_list


def eval_tts_scores(
y_clean: ndarray, y_est: ndarray, T_ys: Sequence[int] = (0,), sampling_rate=22050
) -> Dict[str, float]:
"""
calculate metric using EvalModule. y can be a batch.
Args:
y_clean: real audio
y_est: estimated audio
T_ys: length of the non-zero parts of the histograms
sampling_rate: The used Sampling rate.

Returns:
A dictionary mapping scoring systems (string) to numerical scores.
1st entry: 'STOI'
2nd entry: 'PESQ'
"""

if y_clean.ndim == 1:
y_clean = y_clean[np.newaxis, ...]
y_est = y_est[np.newaxis, ...]
if T_ys == (0,):
T_ys = (y_clean.shape[1],) * y_clean.shape[0]

clean = y_clean[0, : T_ys[0]]
estimated = y_est[0, : T_ys[0]]
stoi_score = stoi(clean, estimated, sampling_rate, extended=False)
pesq_score = pesq(16000, np.asarray(clean), estimated, 'wb')
## fs was set 16,000, as pesq lib doesnt currently support felxible fs.

return {'STOI': stoi_score, 'PESQ': pesq_score}


def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
"""A function that takes predicted durations per encoded token, and repeats enc_out according to the duration.
NOTE: durations.shape[1] == enc_out.shape[1]
Expand Down
2 changes: 0 additions & 2 deletions requirements/requirements_tts.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,4 @@ librosa
matplotlib
nltk
pandas
pesq
pypinyin
pystoi