Skip to content

Commit

Permalink
Merge branch 'feature_ptune' of github.com:NVIDIA/NeMo into feature_p…
Browse files Browse the repository at this point in the history
…tune
  • Loading branch information
yidong72 committed Jan 25, 2022
2 parents 1d76088 + b1005fe commit 593125e
Show file tree
Hide file tree
Showing 16 changed files with 1,313 additions and 5 deletions.
3 changes: 3 additions & 0 deletions docs/source/tts/intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ the end-to-end pipeline.
* - HiFiGAN
- :class:`Vocoder<nemo.collections.tts.models.base.Vocoder>`
- https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_hifigan
* - UnivNet
- :class:`Vocoder<nemo.collections.tts.models.base.Vocoder>`
- https://ngc.nvidia.com/catalog/models/nvidia:nemo:tts_en_lj_univnet
- GAN-based vocoder
* - FastPitch_HifiGan_E2E
- :class:`TextToWaveform<nemo.collections.tts.models.base.TextToWaveform>`
Expand Down
8 changes: 8 additions & 0 deletions examples/tts/conf/univnet/model/generator/c16.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# @package _group_
# UnivNet generator settings, 16-channel variant.
# c32.yaml in the same directory is identical except for channel_size: 32.
# `_target_` names the class these keys belong to; presumably consumed by
# Hydra-style instantiation — confirm against the model code.
_target_: nemo.collections.tts.modules.univnet_modules.Generator
noise_dim: 64
channel_size: 16
dilations: [1, 3, 9, 27]
# NOTE(review): 8 * 8 * 4 = 256, which equals preprocessor.n_window_stride in
# univnet.yaml — presumably the total upsampling factor per spectrogram frame; confirm.
strides: [8, 8, 4]
lrelu_slope: 0.2
kpnet_conv_size: 3
8 changes: 8 additions & 0 deletions examples/tts/conf/univnet/model/generator/c32.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# @package _group_
# UnivNet generator settings, 32-channel variant; this is the one selected by
# the `defaults` list in univnet.yaml (model/generator: c32).
# c16.yaml in the same directory is identical except for channel_size: 16.
_target_: nemo.collections.tts.modules.univnet_modules.Generator
noise_dim: 64
channel_size: 32
dilations: [1, 3, 9, 27]
# NOTE(review): 8 * 8 * 4 = 256, which equals preprocessor.n_window_stride in
# univnet.yaml — presumably the total upsampling factor per spectrogram frame; confirm.
strides: [8, 8, 4]
lrelu_slope: 0.2
kpnet_conv_size: 3
13 changes: 13 additions & 0 deletions examples/tts/conf/univnet/model/train_ds/train_ds.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# @package _group_
# Default training-data settings (selected by `model/train_ds: train_ds` in univnet.yaml).
dataset:
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
# Resolved from the top-level `train_dataset` key of univnet.yaml.
manifest_filepath: ${train_dataset}
max_duration: null
# NOTE(review): 0.75 s at the 22050 Hz sample_rate in univnet.yaml is ~16537 samples,
# just above n_segments — presumably so every clip can yield a full segment; confirm.
min_duration: 0.75
# 16384 = 64 * 256 (256 being preprocessor.n_window_stride in univnet.yaml).
n_segments: 16384
trim: false
dataloader_params:
drop_last: false
shuffle: true
batch_size: 32
num_workers: 4
11 changes: 11 additions & 0 deletions examples/tts/conf/univnet/model/train_ds/train_ds_finetune.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# @package _group_
# Alternative training-data settings that target MelAudioDataset instead of AudioDataset.
# Unlike train_ds.yaml, no max_duration or trim keys are set here.
# Not selected by the `defaults` list in univnet.yaml; presumably chosen by
# overriding model/train_ds — confirm.
dataset:
_target_: "nemo.collections.tts.data.datalayers.MelAudioDataset"
# Resolved from the top-level `train_dataset` key of univnet.yaml.
manifest_filepath: ${train_dataset}
min_duration: 0.75
# 16384 = 64 * 256 (256 being preprocessor.n_window_stride in univnet.yaml).
n_segments: 16384
dataloader_params:
drop_last: false
shuffle: true
batch_size: 32
num_workers: 4
13 changes: 13 additions & 0 deletions examples/tts/conf/univnet/model/validation_ds/val_ds.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# @package _group_
# Default validation-data settings (selected by `model/validation_ds: val_ds` in univnet.yaml).
dataset:
_target_: "nemo.collections.tts.data.datalayers.AudioDataset"
# Resolved from the top-level `validation_datasets` key of univnet.yaml.
manifest_filepath: ${validation_datasets}
# No duration filtering is configured for validation (both bounds are null).
max_duration: null
min_duration: null
# NOTE(review): -1 presumably means "do not cut into fixed-length segments"; confirm in AudioDataset.
n_segments: -1
trim: false
dataloader_params:
drop_last: false
# Validation order is kept fixed (no shuffling).
shuffle: false
batch_size: 16
num_workers: 1
11 changes: 11 additions & 0 deletions examples/tts/conf/univnet/model/validation_ds/val_ds_finetune.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# @package _group_
# Alternative validation-data settings that target MelAudioDataset instead of AudioDataset.
# Not selected by the `defaults` list in univnet.yaml; presumably chosen by
# overriding model/validation_ds — confirm.
dataset:
_target_: "nemo.collections.tts.data.datalayers.MelAudioDataset"
# Resolved from the top-level `validation_datasets` key of univnet.yaml.
manifest_filepath: ${validation_datasets}
# NOTE(review): 3 s at the 22050 Hz sample_rate in univnet.yaml is 66150 samples,
# just above n_segments — presumably so every clip can yield a full segment; confirm.
min_duration: 3
# 66048 = 258 * 256 (256 being preprocessor.n_window_stride in univnet.yaml).
n_segments: 66048
dataloader_params:
drop_last: false
shuffle: false
batch_size: 16
num_workers: 4
79 changes: 79 additions & 0 deletions examples/tts/conf/univnet/univnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Top-level configuration for the UnivNet training example; examples/tts/univnet.py
# loads it with config_path="conf/univnet", config_name="univnet".
name: "UnivNet"
# `???` is the OmegaConf mandatory-value marker: both manifest paths must be supplied by the user.
train_dataset: ???
validation_datasets: ???

# Config-group selections. The conf/univnet directory also provides c16,
# train_ds_finetune and val_ds_finetune as alternatives.
defaults:
- model/generator: c32
- model/train_ds: train_ds
- model/validation_ds: val_ds

model:
discriminator:
# NOTE(review): "mpd" / "mrd" presumably stand for the multi-period and
# multi-resolution discriminators — confirm against the model code.
mpd:
periods: [2,3,5,7,11]
kernel_size: 5
stride: 3
use_spectral_norm: false
lrelu_slope: 0.2
mrd:
resolutions: [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]] # (filter_length, hop_length, win_length)
use_spectral_norm: false
lrelu_slope: 0.2
# Feature extraction settings for FilterbankFeatures. n_window_stride (256) equals
# the product of the generator strides [8, 8, 4] in model/generator/*.yaml.
preprocessor:
_target_: nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures
dither: 0.0
frame_splicing: 1
nfilt: 80
highfreq: 8000
log: true
log_zero_guard_type: clamp
log_zero_guard_value: 1e-05
lowfreq: 0
mag_power: 1.0
n_fft: 1024
n_window_size: 1024
n_window_stride: 256
normalize: null
pad_to: 0
pad_value: -11.52
preemph: null
sample_rate: 22050
window: hann
use_grads: false
exact_pad: true

optim:
_target_: torch.optim.AdamW
lr: 0.0001
betas: [0.5, 0.9]

# Referenced by trainer.max_steps through the ${model.max_steps} interpolation.
max_steps: 1000000
stft_lamb: 2.5
denoise_strength: 0.0025

# Passed as keyword arguments to pl.Trainer by examples/tts/univnet.py.
trainer:
gpus: -1 # number of gpus
max_steps: ${model.max_steps}
num_nodes: 1
accelerator: ddp
accumulate_grad_batches: 1
checkpoint_callback: False # Provided by exp_manager
logger: False # Provided by exp_manager
flush_logs_every_n_steps: 200
log_every_n_steps: 100
# NOTE(review): the checkpoint callback below monitors "val_loss", which is
# presumably only refreshed when validation runs (every 10 epochs) — confirm.
check_val_every_n_epoch: 10

# Handed to exp_manager() by examples/tts/univnet.py.
exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true

create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null

create_checkpoint_callback: True
checkpoint_callback_params:
monitor: "val_loss"
mode: "min"
35 changes: 35 additions & 0 deletions examples/tts/univnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl

from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.collections.tts.models import UnivNetModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf/univnet", config_name="univnet")
def main(cfg):
    """Train a UnivNet model from the composed configuration.

    Args:
        cfg: Configuration object providing ``trainer`` and ``model``
            sections and, optionally, an ``exp_manager`` section.
    """
    trainer = pl.Trainer(**cfg.trainer)

    # A missing ``exp_manager`` section is forwarded as None.
    exp_manager_cfg = cfg.get("exp_manager", None)
    exp_manager(trainer, exp_manager_cfg)

    univnet = UnivNetModel(cfg=cfg.model, trainer=trainer)

    # Learning-rate and epoch-time logging are attached once exp_manager has run.
    extra_callbacks = [pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()]
    trainer.callbacks.extend(extra_callbacks)

    trainer.fit(univnet)


# Script entry point: run training only when this file is executed directly.
if __name__ == '__main__':
    # main() is called without arguments; the hydra_runner decorator is expected
    # to supply `cfg`, hence the lint suppression on the call.
    main()  # noqa pylint: disable=no-value-for-parameter
2 changes: 2 additions & 0 deletions nemo/collections/tts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from nemo.collections.tts.models.uniglow import UniGlowModel
from nemo.collections.tts.models.waveglow import WaveGlowModel
from nemo.collections.tts.models.mixer_tts import MixerTTSModel
from nemo.collections.tts.models.univnet import UnivNetModel
except ModuleNotFoundError:
pass

Expand All @@ -55,4 +56,5 @@
"FastSpeech2HifiGanE2EModel",
"AlignerModel",
"MixerTTSModel",
"UnivNetModel",
]
Loading

0 comments on commit 593125e

Please sign in to comment.