From aafb93b6dec3bbe2df6935acbdb5112dd8745ac0 Mon Sep 17 00:00:00 2001
From: shanmugamr1992
Date: Tue, 8 Nov 2022 17:34:38 -0800
Subject: [PATCH] bug fix

---
 .../conf/megatron_gpt_config.yaml             |  1 +
 .../language_modeling/megatron/gpt_dataset.py |  7 ++++-
 .../megatron/megatron_batch_samplers.py       |  7 +++--
 .../language_modeling/megatron_gpt_model.py   | 30 ++++++++++++-------
 4 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index 73297588e9cf..c76f3703301f 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -176,6 +176,7 @@ model:
     reset_attention_mask: False # Reset attention mask after end-of-document token
     eod_mask_loss: False # Mask loss for the end of document tokens
     validation_drop_last: True # Set to false if the last partial validation samples is to be consumed
+    pad_samples_to_global_batch_size: True # Set to True if you want to pad the last partial batch with -1's to equal global batch size
 
   # Nsys profiling options
   nsys_profile:
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
index 73ebebb4e14a..c42098c2af77 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
@@ -379,7 +379,12 @@ def __getitem__(self, idx):
         tokens[tokens == -1] = 0
         labels[labels == -1] = 0
 
-        if idx == -1:  # Negative index comes when we pad the last batch in MegatronPretrainingBatchSampler
+        # Negative index comes when we pad the last batch in MegatronPretrainingBatchSampler
+        # We make the loss_mask zero to mask out loss from these samples
+        if idx == -1:
+            logging.info(
+                'WARNING: Got -1 as item index. Masking loss from this sample'
+            )
             loss_mask = torch.zeros_like(loss_mask)
 
         return {
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py
index f6760d366d9c..e0f23184e5d8 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py
@@ -59,6 +59,7 @@ def __init__(
         data_parallel_rank: int,
         data_parallel_size: int,
         drop_last: bool,
+        pad_samples_to_global_batch_size=False,
     ) -> None:
         """Constructor of Megatron-LM style Batch Sampler.
 
@@ -94,6 +95,7 @@ def __init__(
         self.data_parallel_rank: int = data_parallel_rank
         self.data_parallel_size: int = data_parallel_size
         self.drop_last: bool = drop_last
+        self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
 
         self.update_global_batch_size(global_batch_size)
 
@@ -161,8 +163,9 @@ def __iter__(self):
         if len(batch) > 0 and not self.drop_last:
             # start_idx, end_idx = self.get_start_end_idx()
             indices = [batch[i] for i in range(self.data_parallel_rank, len(batch), self.data_parallel_size)]
-            num_pad = self._global_batch_size // self.data_parallel_size - len(indices)
-            indices = indices + [-1] * num_pad
+            if self.pad_samples_to_global_batch_size:
+                num_pad = self._global_batch_size // self.data_parallel_size - len(indices)
+                indices = indices + [-1] * num_pad
             yield indices
 
 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index ea577be256ba..6e6010a02fbd 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -484,12 +484,12 @@ def loss_func(output_tensor):
                 if validation_step and not self.cfg.data.get('validation_drop_last', True):
                     num_valid_samples_in_mb = int(loss_mask.sum() / loss_mask.numel() * loss_mask.shape[0])
                     loss_sum_for_mb = num_valid_samples_in_mb * loss_for_mb
-                    loss_sum_for_mb_all_gpu = torch.cat([loss_sum_for_mb.clone().detach().view(1)])
+                    loss_sum_and_mb_size_all_gpu = torch.cat([loss_sum_for_mb.clone().detach().view(1), torch.tensor([num_valid_samples_in_mb]).cuda().clone().detach()])
                     # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds)
                     torch.distributed.all_reduce(
-                        loss_sum_for_mb_all_gpu, group=parallel_state.get_data_parallel_group()
+                        loss_sum_and_mb_size_all_gpu, group=parallel_state.get_data_parallel_group()
                     )
-                    return loss_for_mb, {'loss_sum': loss_sum_for_mb_all_gpu}
+                    return loss_for_mb, {'loss_sum_and_mb_size': loss_sum_and_mb_size_all_gpu}
                 else:
                     reduced_loss = average_losses_across_data_parallel_group([loss_for_mb])
                     return loss_for_mb, {'avg': reduced_loss}
@@ -563,12 +563,12 @@ def validation_step(self, batch, batch_idx):
             else:
                 # Get the total loss since micro batches sizes are not uniform
                 loss_sum_tensors_list = [
-                    loss_sum['loss_sum']
+                    loss_sum['loss_sum_and_mb_size']
                     for loss_sum in losses_reduced_per_micro_batch
-                    if not loss_sum['loss_sum'].isnan()
+                    if loss_sum['loss_sum_and_mb_size'][1] > 0
                 ]
                 loss_sum = (
-                    torch.concat(loss_sum_tensors_list).sum() if len(loss_sum_tensors_list) > 0 else torch.tensor(0.0)
+                    torch.vstack(loss_sum_tensors_list).sum(axis=0) if len(loss_sum_tensors_list) > 0 else torch.tensor(0.0)
                 )
                 return loss_sum
         else:
@@ -582,9 +582,9 @@ def validation_epoch_end(self, outputs):
                 averaged_loss = torch.stack(outputs).mean()
             else:
                 # Compute the avg loss by total_loss across all samples / total number of samples
-                total_loss = torch.stack(outputs).sum()
-                avg_loss = total_loss / len(self._validation_ds)
-                averaged_loss = torch.tensor(avg_loss, dtype=torch.float32).cuda()
+                total_loss_and_total_samples = torch.vstack(outputs).sum(axis = 0)
+                avg_loss = total_loss_and_total_samples[0] / total_loss_and_total_samples[1]
+                averaged_loss = avg_loss.type(torch.float32).cuda()
         else:
             averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda()
 
@@ -661,7 +661,7 @@ def build_train_valid_test_datasets(self):
 
         return self._train_ds, self._validation_ds, self._test_ds
 
-    def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=None, drop_last=True):
+    def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False):
         """Buld dataloader given an input dataset."""
 
         logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
@@ -676,6 +676,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=
                 data_parallel_rank=parallel_state.get_data_parallel_rank(),
                 data_parallel_size=parallel_state.get_data_parallel_world_size(),
                 drop_last=drop_last,
+                pad_samples_to_global_batch_size=pad_samples_to_global_batch_size
             )
         elif self.cfg.data.dataloader_type == 'cyclic':
             batch_sampler = MegatronPretrainingRandomBatchSampler(
@@ -686,6 +687,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=
                 data_parallel_rank=parallel_state.get_data_parallel_rank(),
                 data_parallel_size=parallel_state.get_data_parallel_world_size(),
                 drop_last=self.cfg.get('drop_last', True),
+                pad_samples_to_global_batch_size=pad_samples_to_global_batch_size
             )
         else:
             raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
@@ -782,12 +784,18 @@ def setup_validation_data(self, cfg):
             logging.info(
                 f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}'
            )
+
             drop_last = True
             if not self.cfg.data.get('validation_drop_last', True):
                 logging.info(f'Drop last in validation dataset is set to False')
                 drop_last = False
+            pad_samples_to_global_batch_size = False
+            if self.cfg.data.get('pad_samples_to_global_batch_size', False):
+                logging.info('pad_samples_to_global_batch_size set to True')
+                pad_samples_to_global_batch_size = True
+
             self._validation_dl = self.build_pretraining_data_loader(
-                self._validation_ds, consumed_samples, "validation", drop_last
+                self._validation_ds, consumed_samples, "validation", drop_last, pad_samples_to_global_batch_size
             )
 
     def setup_test_data(self, cfg):
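
When validation_drop_last is False, each micro-batch now reports a pair [loss_sum, num_valid_samples] rather than a bare loss sum; padded samples (index -1) contribute zero to both entries because their loss mask is zeroed in GPTDataset.__getitem__, and validation_epoch_end divides the summed loss by the summed sample count instead of by len(self._validation_ds). Below is a minimal single-process sketch of that arithmetic; the helper name micro_batch_stats and the toy tensors are invented for illustration, and the real path additionally all-reduces the pair across data-parallel ranks and recovers the per-micro-batch sum from the mean loss.

import torch


def micro_batch_stats(loss_per_sample: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Return a 2-vector [loss_sum, num_valid_samples] for one micro-batch.
    # Padded samples carry an all-zero loss mask, so they add nothing to either entry.
    valid = loss_mask.sum(dim=-1) > 0          # per sample: does any token contribute to the loss?
    loss_sum = (loss_per_sample * valid).sum()
    return torch.tensor([float(loss_sum), float(valid.sum())])


# Two toy micro-batches of 4 samples each; the second one ends with 2 padded samples.
mb1 = micro_batch_stats(torch.tensor([2.0, 3.0, 1.0, 4.0]), torch.ones(4, 8))
mask2 = torch.ones(4, 8)
mask2[2:] = 0                                  # last two samples are padding
mb2 = micro_batch_stats(torch.tensor([5.0, 1.0, 7.7, 7.7]), mask2)

# Epoch-end reduction, mirroring validation_epoch_end: stack the per-micro-batch
# [loss_sum, num_samples] pairs, sum over micro-batches, then divide once.
totals = torch.vstack([mb1, mb2]).sum(axis=0)
avg_loss = totals[0] / totals[1]
print(avg_loss.item())                         # (2 + 3 + 1 + 4 + 5 + 1) / 6 = 2.666...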