Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rampup batch size support for Megatron GPT #6424

Merged
merged 18 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ model:
# gradient accumulation will be done automatically based on data_parallel_size
micro_batch_size: 4 # limited by GPU memory
global_batch_size: 8 # will use more micro batches to reach global batch size
rampup_batch_size: null # Should be a list of 3 values: [<start_batch_size>, <batch_size_increment>, <ramup_samples>]
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism
virtual_pipeline_model_parallel_size: null # interleaved pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
micro_batch_size=cfg.get('micro_batch_size'),
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size'),
use_fp8=cfg.get('fp8', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from nemo.utils import logging

try:
import apex.transformer.pipeline_parallel.utils
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
Expand All @@ -63,6 +64,7 @@
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.pipeline_parallel.utils import get_num_microbatches

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
Expand Down Expand Up @@ -446,15 +448,22 @@ def training_step(self, dataloader_iter, batch_idx):
'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1,
)

consumed_samples = self.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
# TODO: make sure compute_consumed_samples works for pipeline parallelism
self.log(
'consumed_samples',
self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
prog_bar=True,
rank_zero_only=True,
batch_size=1,
'consumed_samples', consumed_samples, prog_bar=True, rank_zero_only=True, batch_size=1,
)

if self.cfg.get('rampup_batch_size', None):
micro_batch_size = self.cfg.get('micro_batch_size', 1)
current_global_batch_size = get_num_microbatches() * micro_batch_size
self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1)

num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
num_microbatch_calculator.update(
consumed_samples=consumed_samples, consistency_check=True,
)

return loss_mean

def backward(self, *args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank=None,
micro_batch_size=None,
global_batch_size=None,
rampup_batch_size=None,
use_fp8=False,
seed=1234,
apex_transformer_log_level=30,
Expand Down Expand Up @@ -115,7 +116,7 @@ def initialize_model_parallel_for_nemo(
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=None,
rampup_batch_size=rampup_batch_size,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatches):
Expand Down
158 changes: 158 additions & 0 deletions tests/collections/nlp/test_rampup_bath_size.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in the file name, bath -> batch

Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from omegaconf import DictConfig
from pytorch_lightning import Trainer

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy

DEVICE_CAPABILITY = None
if torch.cuda.is_available():
DEVICE_CAPABILITY = torch.cuda.get_device_capability()


@pytest.fixture()
def model_cfg(test_data_dir):

model_cfg = {
'precision': 16,
'micro_batch_size': 4,
'global_batch_size': 8,
'rampup_batch_size': [4, 2, 100],
'tensor_model_parallel_size': 1,
'pipeline_model_parallel_size': 1,
'resume_from_checkpoint': None,
'encoder_seq_length': 512,
'max_position_embeddings': 512,
'num_layers': 1,
'hidden_size': 128,
'ffn_hidden_size': 512,
'num_attention_heads': 2,
'init_method_std': 0.02,
'hidden_dropout': 0.1,
'kv_channels': None,
'apply_query_key_layer_scaling': True,
'layernorm_epsilon': 1e-5,
'make_vocab_size_divisible_by': 128,
'pre_process': True,
'post_process': True,
'persist_layer_norm': True,
'gradient_as_bucket_view': True,
'tokenizer': {
'library': 'megatron',
'type': 'GPT2BPETokenizer',
'model': None,
'vocab_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/vocab.json'),
'merge_file': os.path.join(test_data_dir, 'nlp/gpt_vocab_merges/merges.txt'),
'delimiter': None,
},
'native_amp_init_scale': 4294967296,
'native_amp_growth_interval': 1000,
'hysteresis': 2,
'fp32_residual_connection': False,
'fp16_lm_cross_entropy': False,
'megatron_amp_O2': False,
'seed': 1234,
'use_cpu_initialization': False,
'onnx_safe': False,
'apex_transformer_log_level': 30,
'activations_checkpoint_method': None,
'activations_checkpoint_num_layers': 1,
'data': {
'data_prefix': '???',
'index_mapping_dir': None,
'data_impl': 'mmap',
'splits_string': '900,50,50',
'seq_length': 512,
'skip_warmup': True,
'num_workers': 2,
'dataloader_type': 'single',
'reset_position_ids': False,
'reset_attention_mask': False,
'eod_mask_loss': False,
},
'optim': {
'name': 'fused_adam',
'lr': 2e-4,
'weight_decay': 0.01,
'betas': [0.9, 0.98],
'sched': {'name': 'CosineAnnealing', 'warmup_steps': 500, 'constant_steps': 50000, 'min_lr': '2e-5'},
},
}
return model_cfg


@pytest.fixture()
def trainer_cfg():

trainer_cfg = {
'devices': 1,
'num_nodes': 1,
'accelerator': 'gpu',
'precision': 16,
'logger': False,
'enable_checkpointing': False,
'replace_sampler_ddp': False,
'max_epochs': 1000,
'max_steps': 100000,
'log_every_n_steps': 10,
'val_check_interval': 100,
'limit_val_batches': 50,
'limit_test_batches': 500,
'accumulate_grad_batches': 1,
'gradient_clip_val': 1.0,
}

return trainer_cfg


@pytest.fixture()
def gpt_model(model_cfg, trainer_cfg):
strategy = NLPDDPStrategy()
trainer = Trainer(strategy=strategy, **trainer_cfg)
cfg = DictConfig(model_cfg)

try:
model = MegatronGPTModel(cfg, trainer)
return model
except Exception as exception:
return exception


@pytest.fixture()
def rampup_batch_size():
return [4, 2, 100]


@pytest.fixture()
def rampup_batch_size_invalid():
return "Microbatch calculator already initialized."


@pytest.mark.run_only_on('GPU')
class TestRampupBatchSize:
@pytest.mark.unit
def test_rampup_bs(self, gpt_model, rampup_batch_size):
assert gpt_model.cfg.rampup_batch_size == rampup_batch_size

@pytest.mark.unit
def test_invalid_rampup_bs(self, gpt_model, rampup_batch_size_invalid):
assert str(gpt_model) == rampup_batch_size_invalid