Update conflict merges
Signed-off-by: MaximumEntropy <[email protected]>
MaximumEntropy committed Feb 16, 2022
2 parents 5da9724 + 2ebca22 commit acf4909
Showing 12 changed files with 129 additions and 19 deletions.
11 changes: 7 additions & 4 deletions examples/nlp/language_modeling/conf/megatron_t5_config.yaml
@@ -36,19 +36,20 @@ exp_manager:
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
filename: 'megatron_t5--{val_loss:.2f}-{step}-{consumed_samples}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}


model:
# model parallelism
micro_batch_size: 4
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
pipeline_model_parallel_size: 1 # T5 PP is not supported yet. Use 1 for now.

# model architecture
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler

megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting.

seq_length: 512
max_position_embeddings: ${.seq_length}
num_layers: 12
@@ -57,10 +58,12 @@ model:
num_attention_heads: 12
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
persist_layer_norm: True # Use of persistent fused layer norm kernel.
gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
encoder_arch: 'transformer'
decoder_arch: 'transformer'

@@ -90,7 +93,7 @@ model:
data:
# Path to data must be specified by the user.
# can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-t5_00_text_document,.5,/raid/data/pile/my-t5_01_text_document]",
# Or see example below:
# Or see example below:
# data_prefix:
# - .5
# - /raid/data/pile/my-t5_00_text_document
@@ -105,8 +108,8 @@ model:
num_workers: 0
dataloader_type: single # cyclic
masked_lm_prob: 0.15
short_seq_prob: 0.1
dataset_type: 't5'
short_seq_prob: 0.0
max_ngram_size: 10
mean_ngram_size: null
geometric_dist: True
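The keys touched above (megatron_amp_O2, attention_dropout, gradient_as_bucket_view, pipeline_model_parallel_size) are ordinary Hydra/OmegaConf fields, so they can be overridden from the command line with dot notation. Below is a minimal, hypothetical sketch of how such dotlist overrides resolve, using a stripped-down stand-in config rather than the full megatron_t5_config.yaml:

# Illustrative only: a stand-in config, not the full megatron_t5_config.yaml.
from omegaconf import OmegaConf

base = OmegaConf.create(
    {
        "trainer": {"precision": 16},
        "model": {
            "megatron_amp_O2": False,
            "attention_dropout": 0.1,
            "gradient_as_bucket_view": True,
        },
    }
)

# Roughly what "trainer.precision=bf16 model.megatron_amp_O2=True" does on the CLI.
overrides = OmegaConf.from_dotlist(["trainer.precision=bf16", "model.megatron_amp_O2=True"])
cfg = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(cfg))

The same dotted paths should work when launching megatron_t5_pretraining.py through its hydra_runner entry point.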
30 changes: 23 additions & 7 deletions examples/nlp/language_modeling/megatron_t5_pretraining.py
@@ -22,7 +22,7 @@
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPDDPPlugin
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, MegatronHalfPrecisionPlugin, NLPDDPPlugin
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager
@@ -33,13 +33,29 @@ def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes, find_unused_parameters=False)]
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
plugins = [
NLPDDPPlugin(
num_nodes=cfg.trainer.num_nodes,
no_ddp_communication_hook=(
megatron_amp_o2 and cfg.trainer.precision == 'bf16'
), # Only bf16 uses fp32_grad_accum.
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
plugins.append(NativeMixedPrecisionPlugin(precision=16, device='cuda', scaler=scaler))
]
if cfg.trainer.precision in [16, 'bf16']:
scaler = None
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
if megatron_amp_o2:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())
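Read as plain logic, the plugin setup above makes two independent choices: a GradScaler is only constructed for fp16 (bf16 runs without loss scaling), and megatron_amp_O2 selects MegatronHalfPrecisionPlugin over NativeMixedPrecisionPlugin. The helper below is a hypothetical, dependency-free restatement of that decision table; it returns plugin names as strings instead of building the real plugin objects:

# Hypothetical helper mirroring the precision handling in main() above.
def select_precision_plugin(precision, megatron_amp_o2):
    """Return (plugin_name, needs_grad_scaler) for a given trainer precision."""
    if precision not in (16, "bf16"):
        return None, False  # plain fp32: no mixed-precision plugin is appended
    needs_grad_scaler = precision == 16  # only fp16 uses a GradScaler
    if megatron_amp_o2:
        return "MegatronHalfPrecisionPlugin", needs_grad_scaler
    return "NativeMixedPrecisionPlugin", needs_grad_scaler

if __name__ == "__main__":
    for precision in (32, 16, "bf16"):
        for o2 in (False, True):
            print(precision, o2, select_precision_plugin(precision, o2))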
@@ -28,15 +28,18 @@
)
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.modules.common.megatron.clip_grads import clip_grad_norm_fp32
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import (
MegatronTokenLevelEncoderDecoderModule,
)
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler
from nemo.utils import AppState, logging

try:
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
@@ -80,6 +83,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
fp16_cross_entropy=cfg.get('fp16_lm_cross_entropy', False),
use_cpu_initialization=cfg.get('use_cpu_initialization', False),
hidden_dropout=cfg.get('hidden_dropout', 0.1),
attention_dropout=cfg.get('attention_dropout', 0.1),
precision=cfg.get('precision', 16),
fp32_residual_connection=cfg.get('fp32_residual_connection', False),
activations_checkpoint_method=cfg.get('activations_checkpoint_method', None),
@@ -91,6 +95,18 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
onnx_safe=cfg.get('onnx_safe', False),
)

self.setup_optimizer_param_groups()

self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)

if self.megatron_amp_o2:

# Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
self.enc_dec_model.cuda(torch.cuda.current_device())

# Model wrapper to convert both model and inputs to half precision
self.enc_dec_model = Float16Module(module=self.enc_dec_model, precision=cfg.precision)

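For orientation, the O2 branch above first moves the model to the current CUDA device and then wraps it so that parameters and incoming activations are cast to half precision. The toy wrapper below sketches that idea under the assumption that only floating-point tensor inputs should be converted; it is an illustration, not NeMo's actual Float16Module:

# Hypothetical stand-in for a half-precision module wrapper (not the real Float16Module).
import torch
import torch.nn as nn

class ToyHalfWrapper(nn.Module):
    def __init__(self, module: nn.Module, precision):
        super().__init__()
        if precision == 16:
            self.dtype = torch.half
        elif precision == "bf16":
            self.dtype = torch.bfloat16
        else:
            raise ValueError(f"precision {precision} is not supported; use 16 or 'bf16'.")
        self.module = module.to(self.dtype)

    def forward(self, *args):
        # Cast floating-point tensor inputs to the target dtype; token ids pass through unchanged.
        cast = tuple(
            a.to(self.dtype) if torch.is_tensor(a) and torch.is_floating_point(a) else a
            for a in args
        )
        return self.module(*cast)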
def _build_tokenizer(self):
"""
Default tokenizer is based on available nemo tokenizers.
@@ -142,6 +158,51 @@ def forward(

return ret_dict

def setup_optimizer_param_groups(self):
"""ModelPT override. Optimizer will get self._optimizer_param_groups"""
self._optimizer_param_groups = _get_params_for_weight_decay_optimization([self.enc_dec_model])

def configure_optimizers(self):
self.setup_optimization()

# Wrap the baseline optimizer with the optimizer class with master parameters
if self.megatron_amp_o2 and self._optimizer is not None:
if self.cfg.precision == 'bf16':
fp32_grad_accum = True
contiguous_grad_bucket = True
async_grad_allreduce = True

elif self.cfg.precision == 16:
fp32_grad_accum = False
# TODO: contiguous grad bucket for fp16 is also planned to be supported
contiguous_grad_bucket = False
async_grad_allreduce = False

self._optimizer = MainParamsOptimizerWrapper(
self._optimizer,
fp32_grad_accum=fp32_grad_accum,
contiguous_grad_bucket=contiguous_grad_bucket,
async_grad_allreduce=async_grad_allreduce,
)
assert self._trainer.max_steps is not None, "'max_steps' is missing in trainer config."
sched_config = self._cfg.optim.sched
sched_config['max_steps'] = self._trainer.max_steps
self._scheduler = prepare_lr_scheduler(
optimizer=self._optimizer, scheduler_config=sched_config, train_dataloader=self._train_dl
)

if self._scheduler is None:
return self._optimizer
else:
return [self._optimizer], [self._scheduler]

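The precision branch inside configure_optimizers reduces to a small flag table for MainParamsOptimizerWrapper: bf16 accumulates gradients in fp32 with a contiguous bucket and asynchronous all-reduce, while fp16 currently disables all three (see the TODO above). A hypothetical helper that spells out the same mapping:

# Hypothetical mapping from trainer precision to the MainParamsOptimizerWrapper flags above.
def o2_wrapper_flags(precision):
    if precision == "bf16":
        return {"fp32_grad_accum": True, "contiguous_grad_bucket": True, "async_grad_allreduce": True}
    if precision == 16:
        # Contiguous grad buckets for fp16 are not supported yet.
        return {"fp32_grad_accum": False, "contiguous_grad_bucket": False, "async_grad_allreduce": False}
    raise ValueError(f"megatron_amp_O2 expects precision 16 or 'bf16', got {precision!r}")

print(o2_wrapper_flags("bf16"))
print(o2_wrapper_flags(16))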
def get_parameters(self):
params = []
for param_group in self._optimizer_param_groups:
for param in param_group['params']:
params.append(param)
return params

def training_step(self, batch, batch_idx):
tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask = self.process_batch(batch)

@@ -307,8 +368,15 @@ def configure_gradient_clipping(self, *args, **kwargs):
if clip_val <= 0:
return

parameters = self.enc_dec_model.parameters()
clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
if self.megatron_amp_o2:
# grab fp32 master parameters for gradient clipping
parameters = self._optimizer.get_parameters()
else:
parameters = self.get_parameters()

grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)

self.log('grad_norm', grad_norm, rank_zero_only=True)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
request = batch
@@ -333,7 +401,6 @@ def decode(self, tokens_enc, enc_mask, num_tokens_to_generate):
predicted_tokens_dec = (
torch.LongTensor([self.tokenizer.bos_id] * tokens_enc.size(0)).unsqueeze(1).to(tokens_enc.device)
)

for _ in range(num_tokens_to_generate):
dec_mask = predicted_tokens_dec != self.tokenizer.pad_id
token_logits = itemgetter("token_logits")(
@@ -58,6 +58,7 @@ def get_decoder_model(
init_method_std=0.02,
use_cpu_initialization=False,
hidden_dropout=0.1,
attention_dropout=0.1,
precision=16,
fp32_residual_connection=False,
activations_checkpoint_method=None,
@@ -101,6 +102,7 @@ def get_decoder_model(
post_process=post_process,
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
precision=precision,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
@@ -57,6 +57,7 @@ def get_encoder_model(
init_method_std=0.02,
use_cpu_initialization=False,
hidden_dropout=0.1,
attention_dropout=0.1,
precision=16,
fp32_residual_connection=False,
activations_checkpoint_method=None,
@@ -100,6 +101,7 @@ def get_encoder_model(
post_process=post_process,
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
precision=precision,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
@@ -16,20 +16,19 @@

import numpy as np
import torch
from apex.transformer.utils import ensure_divisibility

from nemo.utils import AppState, logging

try:
from apex.transformer import tensor_parallel
from apex.transformer.log_util import set_logging_level
from apex.transformer.parallel_state import (
get_pipeline_model_parallel_rank,
set_pipeline_model_parallel_rank,
set_pipeline_model_parallel_world_size,
set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size,
)
from apex.transformer.log_util import set_logging_level
from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
from apex.transformer.utils import ensure_divisibility

@@ -48,6 +48,7 @@ def __init__(
use_cpu_initialization=False,
decoder_attn_mask_type=AttnMaskType.causal,
hidden_dropout=0.1,
attention_dropout=0.1,
precision=16,
fp32_residual_connection=False,
activations_checkpoint_method=None,
@@ -97,6 +98,7 @@ def __init__(
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
masked_softmax_fusion=masked_softmax_fusion,
@@ -48,6 +48,7 @@ def __init__(
use_cpu_initialization=False,
encoder_attn_mask_type=AttnMaskType.padding,
hidden_dropout=0.1,
attention_dropout=0.1,
precision=16,
fp32_residual_connection=False,
activations_checkpoint_method=None,
@@ -96,6 +97,7 @@ def __init__(
activations_checkpoint_num_layers=activations_checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
masked_softmax_fusion=masked_softmax_fusion,
5 changes: 4 additions & 1 deletion nemo/collections/nlp/modules/common/megatron/module.py
@@ -174,7 +174,10 @@ def float16_converter(val):
return val.bfloat16()

else:
raise Exception(f'{precision} is not supported. Float16Module supports ' 'only fp16 and bf16.')
raise Exception(
f'precision {precision} is not supported. Float16Module (megatron_amp_O2) supports '
'only fp16 and bf16.'
)

self.float16_converter = float16_converter

@@ -88,6 +88,7 @@ def __init__(
fp16_cross_entropy=False,
use_cpu_initialization=False,
hidden_dropout=0.1,
attention_dropout=0.1,
precision=16,
fp32_residual_connection=False,
activations_checkpoint_method=None,
@@ -145,6 +146,7 @@ def __init__(
init_method_std=init_method_std,
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
precision=precision,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
@@ -175,6 +177,7 @@ def __init__(
init_method_std=init_method_std,
use_cpu_initialization=use_cpu_initialization,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
precision=precision,
fp32_residual_connection=fp32_residual_connection,
activations_checkpoint_method=activations_checkpoint_method,
2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -623,6 +623,7 @@ def __init__(
activations_checkpoint_num_layers=1,
layernorm_epsilon=1e-5,
hidden_dropout=0.1,
attention_dropout=0.1,
use_cpu_initialization=False,
bias_gelu_fusion=True,
masked_softmax_fusion=True,
@@ -671,6 +672,7 @@ def build_layer(layer_number):
fp32_residual_connection=fp32_residual_connection,
layernorm_epsilon=layernorm_epsilon,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
use_cpu_initialization=use_cpu_initialization,
bias_gelu_fusion=bias_gelu_fusion,
masked_softmax_fusion=masked_softmax_fusion,
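The encoder, decoder, transformer, and token-level encoder-decoder files above all make the same mechanical change: attention_dropout is read from the model config with a default of 0.1 and passed down each constructor until it reaches the transformer layers. A condensed, hypothetical illustration of that threading pattern with toy classes:

# Toy classes illustrating how a config value is threaded through nested constructors.
class ToyTransformerLayer:
    def __init__(self, hidden_dropout=0.1, attention_dropout=0.1):
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

class ToyEncoder:
    def __init__(self, num_layers=2, hidden_dropout=0.1, attention_dropout=0.1):
        self.layers = [
            ToyTransformerLayer(hidden_dropout=hidden_dropout, attention_dropout=attention_dropout)
            for _ in range(num_layers)
        ]

cfg = {"hidden_dropout": 0.1, "attention_dropout": 0.1}
encoder = ToyEncoder(
    hidden_dropout=cfg.get("hidden_dropout", 0.1),
    attention_dropout=cfg.get("attention_dropout", 0.1),  # the kwarg this commit adds end to end
)
print(encoder.layers[0].attention_dropout)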
13 changes: 11 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -171,11 +171,20 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# Release strict state dict matching when using Megatron AMP-O2 to skip matching
# the extra prefix added by the half-precision module wrapper.
# TODO: Refactor this to be more generic.
model_key = None
model_attr = None
if hasattr(self.lightning_module, 'model'):
if isinstance(self.lightning_module.model, Float16Module):
model_key = 'model'
model_attr = self.lightning_module.model
elif hasattr(self.lightning_module, 'enc_dec_model'):
model_key = 'enc_dec_model'
model_attr = self.lightning_module.enc_dec_model
if model_key is not None:
if isinstance(model_attr, Float16Module):
new_state_dict = {}
for key in checkpoint['state_dict'].keys():
new_key = key.replace('model.', 'model.module.', 1)
new_key = key.replace(f'{model_key}.', f'{model_key}.module.', 1)
new_state_dict[new_key] = checkpoint['state_dict'][key]
checkpoint['state_dict'] = new_state_dict

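The remapping above only needs to account for the extra 'module.' level introduced by the Float16Module wrapper; which prefix is used ('model' or 'enc_dec_model') depends on the attribute the LightningModule exposes. A small sketch of the renaming on a fake checkpoint (keys are invented for illustration):

# Illustrative key remapping for a Float16Module-wrapped model; checkpoint keys are made up.
def remap_state_dict(state_dict, model_key):
    return {
        key.replace(f"{model_key}.", f"{model_key}.module.", 1): value
        for key, value in state_dict.items()
    }

fake_ckpt = {
    "enc_dec_model.encoder.layers.0.weight": 0,
    "enc_dec_model.decoder.layers.0.weight": 1,
}
print(remap_state_dict(fake_ckpt, "enc_dec_model"))
# -> keys become 'enc_dec_model.module.encoder.layers.0.weight', etc.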