Skip to content

Commit

Permalink
Added new seen-speaker eval dataset
Browse files Browse the repository at this point in the history
Signed-off-by: Paarth Neekhara <[email protected]>
  • Loading branch information
paarthneekhara committed Jan 24, 2025
1 parent b185407 commit c9ccd9a
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions scripts/t5tts/infer_and_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@
'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS',
'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS',
},
'libri_seen_test_v2': {
'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/libri_seen_evalset_from_testclean_v2.json',
'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS',
'feature_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS',
},
'libri_unseen_val': {
'manifest_path' : '/home/pneekhara/2023/SimpleT5NeMo/manifests/dev_clean_withContextAudioPaths_evalset.json',
'audio_dir' : '/datap/misc/LibriTTSfromNemo/LibriTTS',
Expand Down Expand Up @@ -158,6 +163,11 @@ def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature,
del dataset_meta_for_dl[key]

dataset_meta = {dataset: dataset_meta_for_dl}
context_durration_min = model.cfg.get('context_duration_min', 5.0)
context_durration_max = model.cfg.get('context_duration_max', 5.0)
if context_durration_min < 5.0 and context_durration_max > 5.0:
context_durration_min = 5.0
context_durration_max = 5.0 # @pneekhara - For multiencoder models, I want fixed size contexts for fair eval. Not too important though.
test_dataset = T5TTSDataset(
dataset_meta=dataset_meta,
sample_rate=model_cfg.sample_rate,
Expand All @@ -178,8 +188,8 @@ def run_inference(hparams_file, checkpoint_file, datasets, out_dir, temperature,
load_16khz_audio=model.model_type == 'single_encoder_sv_tts',
use_text_conditioning_tokenizer=model.use_text_conditioning_encoder,
pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,
context_duration_min=model.cfg.get('context_duration_min', 5.0),
context_duration_max=model.cfg.get('context_duration_max', 5.0),
context_duration_min=context_durration_min,
context_duration_max=context_durration_max,
)
assert len(test_dataset) == len(manifest_records), "Dataset length and manifest length should be the same. Dataset length: {}, Manifest length: {}".format(len(test_dataset), len(manifest_records))
test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test')
Expand Down Expand Up @@ -263,7 +273,7 @@ def main():
parser.add_argument('--hparams_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_hparams.yaml,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_hparams.yaml")
parser.add_argument('--checkpoint_files', type=str, default="/datap/misc/continuouscheckpoints_ks3ks3/multiencoder_small_sp_ks3_epoch302.ckpt,/datap/misc/continuouscheckpoints_ks3ks3/decodercontext_small_sp_ks3Correct_epoch305.ckpt")
parser.add_argument('--codecmodel_path', type=str, default="/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo")
parser.add_argument('--datasets', type=str, default="libri_unseen_test,libri_val")
parser.add_argument('--datasets', type=str, default="libri_seen_test,libri_unseen_test")
parser.add_argument('--base_exp_dir', type=str, default="/datap/misc/eosmount4/AllKernselSize3/NewTransformer")
parser.add_argument('--draco_exp_dir', type=str, default="/lustre/fsw/llmservice_nemo_speechlm/users/pneekhara/gitrepos/experiments/NewT5TTS_FixedPosEmb/AllKernselSize3/NewTransformer")
parser.add_argument('--server_address', type=str, default="[email protected]")
Expand Down

0 comments on commit c9ccd9a

Please sign in to comment.