Commit e4f9b09
t5 inference working, lm_vocab_size passed to dataset class to handle phoneme token indices correctly

Signed-off-by: Paarth Neekhara <[email protected]>
paarthneekhara committed Dec 13, 2023
1 parent 6e0d89f commit e4f9b09
Showing 3 changed files with 11 additions and 42 deletions.
@@ -8,15 +8,13 @@ trainer:
precision: 32
logger: False
enable_checkpointing: False
replace_sampler_ddp: False
use_distributed_sampler: False
max_epochs: 10000
max_steps: -1
log_every_n_steps: 10
val_check_interval: null
check_val_every_n_epoch: 3
gradient_clip_val: 1.0
resume_from_checkpoint: null
limit_val_batches: 25

exp_manager:
explicit_log_dir: null
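This config change tracks the PyTorch Lightning 2.x Trainer API: replace_sampler_ddp was renamed to use_distributed_sampler, and resume_from_checkpoint is no longer a Trainer argument (resumption is requested via ckpt_path on the run methods instead). A minimal sketch of the same rename on the Python side, illustrative only and not part of this commit:

from pytorch_lightning import Trainer

# Lightning 1.x (old flags):
#   trainer = Trainer(replace_sampler_ddp=False, resume_from_checkpoint="last.ckpt")

# Lightning 2.x: the sampler flag is renamed, and checkpoint resumption
# moves to the run call.
trainer = Trainer(use_distributed_sampler=False)
# trainer.fit(model, ckpt_path="last.ckpt")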
@@ -15,20 +15,9 @@
import torch
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment

from nemo.collections.nlp.models.language_modeling.megatron_t5_speechlm_model import MegatronT5SpeechLMModel

# from nemo.collections.nlp.models.language_modeling.megatron_t5_speechlm_pretrain_model import (
# MegatronT5SpeechLMModel,
# )
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
@@ -39,38 +28,19 @@ def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'

plugins = []
strategy = NLPDDPStrategy(
no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce
gradient_as_bucket_view=False,
find_unused_parameters=False,
)
if cfg.trainer.precision in [16, 'bf16']:
scaler = None
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 8),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
if megatron_amp_o2 and not with_distributed_adam:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
# MegatronTrainerBuilder compat checks
if "gradient_as_bucket_view" not in cfg.model:
with open_dict(cfg):
cfg.model.gradient_as_bucket_view=False

if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())

trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
trainer = MegatronTrainerBuilder(cfg).create_trainer()
exp_manager(trainer, cfg.exp_manager)

# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision

# load existing or init new soft prompt T5 model
checkpoint_path = cfg.get('checkpoint_path', None)
assert checkpoint_path is not None, "Please specify checkpoint_path in the config file"
model = MegatronT5SpeechLMModel.load_from_checkpoint(
@@ -80,6 +50,5 @@ def main(cfg) -> None:
model = model.cuda()
trainer.test(model)


if __name__ == '__main__':
main()
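With the manual strategy, grad-scaler, and precision-plugin setup removed, trainer construction for the eval script is delegated to NeMo's MegatronTrainerBuilder. A minimal sketch of the resulting flow; the build_trainer helper is purely illustrative, and cfg is assumed to be the Hydra/OmegaConf config with trainer, model, and exp_manager sections shown above:

from omegaconf import DictConfig

from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.utils.exp_manager import exp_manager

def build_trainer(cfg: DictConfig):
    # The builder derives the DDP strategy, precision plugins, and cluster
    # environment from cfg.trainer / cfg.model, replacing the hand-rolled
    # NLPDDPStrategy / GradScaler / PipelineMixedPrecisionPlugin block
    # deleted above.
    trainer = MegatronTrainerBuilder(cfg).create_trainer()
    exp_manager(trainer, cfg.exp_manager)
    return trainer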
@@ -1024,7 +1024,7 @@ def on_validation_epoch_end(self):
def test_step(self, batch, batch_idx):
return self.predict_step(batch, batch_idx)

def test_epoch_end(self, outputs):
def on_test_epoch_end(self, outputs):
"""
This might still be broken for lightning 2.0. to fix: see
https://github.com/NVIDIA/NeMo/blob/9bdf4d12276ee8f95a340cf2f7f340e9b5b74a7e/docs/source/starthere/migration-guide.rst
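The rename follows the Lightning 2.0 migration the docstring points to: test_epoch_end(outputs) was removed, and on_test_epoch_end takes no outputs argument, so per-step results have to be collected by the module itself. A generic sketch of the 2.x pattern, not the NeMo model's actual implementation:

import pytorch_lightning as pl

class ExampleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_step_outputs = []  # the module collects its own step outputs in 2.x

    def test_step(self, batch, batch_idx):
        out = self.predict_step(batch, batch_idx)
        self.test_step_outputs.append(out)
        return out

    def on_test_epoch_end(self):  # no `outputs` parameter in Lightning 2.x
        # aggregate self.test_step_outputs here, then free the memory
        self.test_step_outputs.clear()

As committed, the hook still declares an outputs parameter, which is why the docstring flags it as possibly broken under Lightning 2.0.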
@@ -1084,6 +1084,7 @@ def build_virtual_prompt_dataset(
use_attention_prior=self.cfg.data.get('use_attention_prior', False),
attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0),
cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0),
lm_vocab_size=self.lm_vocab_size,
)

rank = parallel_state.get_data_parallel_rank()
@@ -1143,6 +1144,7 @@ def build_virtual_prompt_tarred_dataset(
use_attention_prior=self.cfg.data.get('use_attention_prior', False),
attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0),
cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0),
lm_vocab_size=self.lm_vocab_size,
)
rank = parallel_state.get_data_parallel_rank()
world_size = parallel_state.get_data_parallel_world_size()
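Both dataset builders now forward self.lm_vocab_size, so the dataset can place phoneme token indices relative to the base LM vocabulary, as the commit message describes. A hypothetical illustration of that kind of bookkeeping; the helper below is not the NeMo dataset's actual code:

def to_shared_index(token_id: int, is_phoneme: bool, lm_vocab_size: int) -> int:
    # Text/BPE ids keep their original positions in [0, lm_vocab_size);
    # phoneme ids are shifted past the LM vocabulary so the two ranges
    # never collide in the combined index space.
    return lm_vocab_size + token_id if is_phoneme else token_id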
