Commit
Merge branch 'main' into t5_lm_adaptation
michalivne authored Feb 14, 2022
2 parents b255e58 + 461a866 commit 99fabbb
Showing 26 changed files with 784 additions and 73 deletions.
5 changes: 2 additions & 3 deletions Jenkinsfile
@@ -2210,14 +2210,13 @@ pipeline {
           ~trainer.check_val_every_n_epoch'
       }
     }
     // TODO(Oktai15): update it in 1.8.0 version
     stage('FastPitch') {
       steps {
         sh 'python examples/tts/fastpitch.py \
-          --config-name fastpitch_align \
+          --config-name fastpitch_align_v1.05 \
           train_dataset=/home/TestData/an4_dataset/an4_train.json \
           validation_datasets=/home/TestData/an4_dataset/an4_val.json \
-          prior_folder=/home/TestData/an4_dataset/beta_priors \
+          sup_data_path=/home/TestData/an4_dataset/beta_priors \
           trainer.devices="[0]" \
           +trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \
           trainer.strategy=null \
@@ -0,0 +1,64 @@
name: megatron_t5_glue

trainer:
  gpus: 2
  num_nodes: 1
  accelerator: ddp
  precision: 16
  logger: False # logger provided by exp_manager
  checkpoint_callback: False
  replace_sampler_ddp: False
  max_epochs: 3
  max_steps: null # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
  val_check_interval: 300
  accumulate_grad_batches: 2
  gradient_clip_val: 1.0

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: megatron_t5_glue
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_acc
    save_top_k: 10
    mode: max
    always_save_nemo: False # TODO: add support
    filename: 'megatron_t5--{val_acc:.3f}-{step}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    save_best_model: True

model:
  restore_from_path: ??? # Path to a trained T5 .nemo file
  tensor_model_parallel_size: 1

  data:
    train_ds:
      task_name: 'mnli'
      file_path: ??? # Path to the TSV file for MNLI train ex: '/raid/Data/GLUE/MNLI/train.tsv'
      batch_size: 32
      shuffle: True
      num_workers: 8
      pin_memory: True
      max_seq_length: 512

    validation_ds:
      task_name: 'mnli'
      file_path: ??? # Path to the TSV file for MNLI dev ex: '/raid/Data/GLUE/MNLI/dev_matched.tsv'
      batch_size: 32
      shuffle: False
      num_workers: 8
      pin_memory: True
      max_seq_length: 512

  optim:
    name: fused_adam
    lr: 5e-6
    weight_decay: 0.0
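
For reference, here is a worked example of the consumed_samples formula noted next to max_steps above, using this config's values. This is a plain-Python sketch; it assumes data.train_ds.batch_size acts as the per-GPU micro batch size and that 2 GPUs on 1 node with tensor_model_parallel_size: 1 give a data-parallel size of 2.

micro_batch_size = 32          # data.train_ds.batch_size (assumed per-GPU micro batch)
data_parallel_size = 2         # gpus * num_nodes / tensor_model_parallel_size
accumulate_grad_batches = 2    # trainer.accumulate_grad_batches

def consumed_samples(global_step):
    # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
    return global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches

print(consumed_samples(100))   # 12800 samples seen after 100 optimizer steps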
80 changes: 80 additions & 0 deletions examples/nlp/language_modeling/megatron_t5_glue.py
@@ -0,0 +1,80 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPDDPPlugin
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager


@hydra_runner(config_path="conf", config_name="megatron_t5_config_finetune")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)]
    if cfg.trainer.precision == 16:
        scaler = GradScaler(
            init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
            growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
        )
        plugins.append(NativeMixedPrecisionPlugin(precision=16, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    resume_from_checkpoint = trainer.checkpoint_connector.resume_from_checkpoint_fit_path
    if resume_from_checkpoint is not None:
        # inject mp_rank into resume_from_checkpoint
        if cfg.model.tensor_model_parallel_size is not None and cfg.model.tensor_model_parallel_size > 1:
            mp_rank = compute_model_parallel_rank(trainer.local_rank, cfg.model.tensor_model_parallel_size)
            resume_from_checkpoint = Path(resume_from_checkpoint)
            resume_from_checkpoint = resume_from_checkpoint.parent.parent.joinpath(f'mp_rank_{mp_rank:02d}').joinpath(
                resume_from_checkpoint.name
            )
            resume_from_checkpoint = str(resume_from_checkpoint)
        logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer.checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronT5GLUEModel(cfg.model, trainer)
    trainer.fit(model)
    trainer.validate(model)


if __name__ == '__main__':
    main()
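
A note on the precision == 16 branch above: the GradScaler used there extends PyTorch's dynamic loss scaler, and NativeMixedPrecisionPlugin wires it into Lightning. The underlying mechanism can be sketched with stock torch.cuda.amp. This is a simplified standalone sketch, not the NeMo implementation; the model, optimizer, and random data are placeholders, and it assumes a CUDA device is available.

import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(init_scale=2 ** 32, growth_interval=1000)

for step in range(10):
    inputs = torch.randn(4, 8, device='cuda')
    targets = torch.randint(0, 2, (4,), device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales gradients; skips the step on inf/nan
    scaler.update()                 # grows or shrinks the scale dynamically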
3 changes: 3 additions & 0 deletions examples/tts/conf/fastpitch_align_v1.05.yaml
@@ -77,6 +77,7 @@ model:
       _target_: nemo.collections.tts.torch.g2ps.EnglishG2p
       phoneme_dict: ${phoneme_dict_path}
       heteronyms: ${heteronyms_path}
+      phoneme_probability: 0.5
 
   train_ds:
     dataset:
@@ -101,6 +102,7 @@ model:
       pitch_norm: true
       pitch_mean: ${model.pitch_mean}
       pitch_std: ${model.pitch_std}
+      use_beta_binomial_interpolator: true
 
     dataloader_params:
       drop_last: false
@@ -131,6 +133,7 @@ model:
       pitch_norm: true
       pitch_mean: ${model.pitch_mean}
       pitch_std: ${model.pitch_std}
+      use_beta_binomial_interpolator: true
 
     dataloader_params:
       drop_last: false
2 changes: 1 addition & 1 deletion examples/tts/conf/hifigan/hifigan_44100.yaml
@@ -21,7 +21,7 @@ train_n_segments: 16384
 train_max_duration: null
 train_min_duration: 0.75
 
-val_n_segments: 132096
+val_n_segments: 131072
 val_max_duration: null
 val_min_duration: 3
 
6 changes: 6 additions & 0 deletions nemo/collections/common/data/vocabs.py
@@ -19,6 +19,7 @@
 import time
 import unicodedata
 from builtins import str as unicode
+from contextlib import contextmanager
 from typing import List
 
 import nltk
@@ -375,3 +376,8 @@ def encode(self, text):
         ps = [space] + ps + [space]
 
         return [self._label2id[p] for p in ps]
+
+    @contextmanager
+    def set_phone_prob(self, prob=None):
+        # Do nothing, since this class does not support mixed g2p.
+        yield
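
The no-op context manager above keeps this vocab's interface aligned with tokenizers that do support mixed grapheme/phoneme output (cf. phoneme_probability in the FastPitch config earlier in this diff). A stateful variant might look like the following hypothetical sketch; the class and its attribute names are illustrative, not part of this commit.

from contextlib import contextmanager

class MixedG2PVocab:
    """Hypothetical vocab whose encode() emits phonemes with some probability."""

    def __init__(self, phone_prob=1.0):
        self.phone_prob = phone_prob

    @contextmanager
    def set_phone_prob(self, prob):
        # Temporarily override the phoneme probability, restoring it on exit.
        old = self.phone_prob
        self.phone_prob = prob
        try:
            yield
        finally:
            self.phone_prob = old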
16 changes: 16 additions & 0 deletions nemo/collections/common/metrics/classification_accuracy.py
@@ -150,3 +150,19 @@ def compute_topk_accuracy(correct_counts_k, total_counts_k):
         top_k_scores.append(correct_count / float(total_count))
 
     return top_k_scores
+
+
+class ExactStringMatchMetric(Metric):
+    def __init__(self, dist_sync_on_step=False):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, pred: str, target: str):
+        if pred == target:
+            self.correct += 1
+        self.total += 1
+
+    def compute(self):
+        return self.correct.float() / self.total
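
A minimal usage sketch for the new metric (assuming it is importable from the module this diff adds it to): a prediction counts as correct only when it matches the target string exactly, and the summed "correct"/"total" states make the result consistent across distributed workers.

from nemo.collections.common.metrics.classification_accuracy import ExactStringMatchMetric

metric = ExactStringMatchMetric()
metric.update("entailment", "entailment")  # exact match -> correct
metric.update("neutral", "contradiction")  # mismatch -> incorrect
print(metric.compute())                    # tensor(0.5000)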
56 changes: 56 additions & 0 deletions nemo/collections/nlp/data/glue_benchmark/data_processors.py
@@ -62,6 +62,12 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        return f"mrpc sentence1: {text_a} sentence2: {text_b}"
+
+    def label2string(self, label):
+        return "equivalent" if label == "1" else "not equivalent"
+
 
 class MnliProcessor(DataProcessor):
     """Processor for the MultiNLI data set (GLUE version)."""
@@ -91,6 +97,12 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        return f"mnli hypothesis: {text_a} premise: {text_b}"
+
+    def label2string(self, label):
+        return label
+
 
 class MnliMismatchedProcessor(MnliProcessor):
     """Processor for the MultiNLI Mismatched data set (GLUE version)."""
@@ -125,6 +137,13 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        assert text_b is None
+        return f"cola sentence: {text_a}"
+
+    def label2string(self, label):
+        return "acceptable" if label == "1" else "not acceptable"
+
 
 class Sst2Processor(DataProcessor):
     """Processor for the SST-2 data set (GLUE version)."""
@@ -153,6 +172,13 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        assert text_b is None
+        return f"sst2 sentence: {text_a}"
+
+    def label2string(self, label):
+        return "positive" if label == "1" else "negative"
+
 
 class StsbProcessor(DataProcessor):
     """Processor for the STS-B data set (GLUE version)."""
@@ -182,6 +208,12 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        return f"stsb sentence1: {text_a} sentence2: {text_b}"
+
+    def label2string(self, label):
+        return '%.1f' % float(label)
+
 
 class QqpProcessor(DataProcessor):
     """Processor for the QQP data set (GLUE version)."""
@@ -214,6 +246,12 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        return f"qqp question1: {text_a} question2: {text_b}"
+
+    def label2string(self, label):
+        return "duplicate" if label == "1" else "not_duplicate"
+
 
 class QnliProcessor(DataProcessor):
     """Processor for the QNLI data set (GLUE version)."""
@@ -243,6 +281,12 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        return f"qnli question: {text_a} sentence: {text_b}"
+
+    def label2string(self, label):
+        return label
+
 
 class RteProcessor(DataProcessor):
     """Processor for the RTE data set (GLUE version)."""
@@ -272,6 +316,12 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        return f"rte sentence1: {text_a} sentence2: {text_b}"
+
+    def label2string(self, label):
+        return label
+
 
 class WnliProcessor(DataProcessor):
     """Processor for the WNLI data set (GLUE version)."""
@@ -301,6 +351,12 @@ def _create_examples(self, lines, set_type):
             examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
         return examples
 
+    def get_t5_prompted_query(self, text_a, text_b):
+        raise NotImplementedError("NeMo-Megatron T5 does not support WNLI at the moment.")
+
+    def label2string(self, label):
+        raise NotImplementedError("NeMo-Megatron T5 does not support WNLI at the moment.")
+
 
 class InputExample(object):
     """A single training/test example for simple sequence classification.
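
Taken together, these additions let each GLUE processor render examples in T5's text-to-text format. A usage sketch (InputExample and MnliProcessor come from this file; the example sentences and labels are made up for illustration):

processor = MnliProcessor()
example = InputExample(
    guid="train-1",
    text_a="The new rights are nice enough.",
    text_b="Everyone really likes the newest benefits.",
    label="neutral",
)
source = processor.get_t5_prompted_query(example.text_a, example.text_b)
# -> "mnli hypothesis: The new rights are nice enough. premise: Everyone really likes the newest benefits."
target = processor.label2string(example.label)
# -> "neutral"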