
Commit

bug fix
shanmugamr1992 committed Nov 9, 2022
1 parent 1910d93 commit aafb93b
Showing 4 changed files with 31 additions and 14 deletions.
@@ -176,6 +176,7 @@ model:
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
validation_drop_last: True # Set to false if the last partial validation samples is to be consumed
+ pad_samples_to_global_batch_size: True # Set to True if you want to pad the last partial batch with -1's to equal global batch size

# Nsys profiling options
nsys_profile:
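
This new flag only matters for the last, partial validation batch: when the remaining samples cannot fill a global batch, the index list is topped up with -1 entries. A tiny illustration of the arithmetic, with made-up sizes (the variable names below are illustrative, not NeMo internals):

# Illustration only: 100 validation samples, global batch size 8.
num_samples = 100
global_batch_size = 8

leftover = num_samples % global_batch_size    # 4 real samples left for the final batch
num_pad = global_batch_size - leftover        # 4 pad entries, each with index -1
last_batch = list(range(num_samples - leftover, num_samples)) + [-1] * num_pad
print(last_batch)  # [96, 97, 98, 99, -1, -1, -1, -1]
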
@@ -379,7 +379,12 @@ def __getitem__(self, idx):
tokens[tokens == -1] = 0
labels[labels == -1] = 0

- if idx == -1: # Negative index comes when we pad the last batch in MegatronPretrainingBatchSampler
+ # Negative index comes when we pad the last batch in MegatronPretrainingBatchSampler
+ # We make the loss_mask zero to mask out loss from these samples
+ if idx == -1:
+ logging.info(
+ 'WARNING: Got -1 as item index. Masking loss from this sample'
+ )
loss_mask = torch.zeros_like(loss_mask)

return {
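
For reference, a minimal toy version of the dataset-side handling (not the actual GPTDataset code): a padded index still yields a well-formed sample, but its loss mask is zeroed so the sample contributes nothing to the validation loss.

import torch

def getitem_sketch(idx, seq_len=8):
    """Toy __getitem__ mirroring the padded-sample handling in the diff."""
    tokens = torch.zeros(seq_len, dtype=torch.long)
    labels = torch.zeros(seq_len, dtype=torch.long)
    loss_mask = torch.ones(seq_len)
    if idx == -1:
        # Pad sample injected by the batch sampler: keep the tensor shapes,
        # but zero the loss mask so it cannot affect the reported metric.
        loss_mask = torch.zeros_like(loss_mask)
    return {'tokens': tokens, 'labels': labels, 'loss_mask': loss_mask}

assert getitem_sketch(-1)['loss_mask'].sum() == 0
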
@@ -59,6 +59,7 @@ def __init__(
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool,
+ pad_samples_to_global_batch_size=False,
) -> None:
"""Constructor of Megatron-LM style Batch Sampler.
@@ -94,6 +95,7 @@ def __init__(
self.data_parallel_rank: int = data_parallel_rank
self.data_parallel_size: int = data_parallel_size
self.drop_last: bool = drop_last
+ self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size

self.update_global_batch_size(global_batch_size)

@@ -161,8 +163,9 @@ def __iter__(self):
if len(batch) > 0 and not self.drop_last:
# start_idx, end_idx = self.get_start_end_idx()
indices = [batch[i] for i in range(self.data_parallel_rank, len(batch), self.data_parallel_size)]
- num_pad = self._global_batch_size // self.data_parallel_size - len(indices)
- indices = indices + [-1] * num_pad
+ if self.pad_samples_to_global_batch_size:
+ num_pad = self._global_batch_size // self.data_parallel_size - len(indices)
+ indices = indices + [-1] * num_pad
yield indices


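Condensed from the hunk above, a self-contained sketch of what the final partial batch now looks like per data-parallel rank (the real sampler also handles micro-batch boundaries and consumed-sample bookkeeping):

def last_batch_indices(batch, data_parallel_rank, data_parallel_size,
                       global_batch_size, pad_samples_to_global_batch_size):
    """Slice the leftover sample indices for one rank and optionally pad with -1."""
    indices = [batch[i] for i in range(data_parallel_rank, len(batch), data_parallel_size)]
    if pad_samples_to_global_batch_size:
        num_pad = global_batch_size // data_parallel_size - len(indices)
        indices = indices + [-1] * num_pad
    return indices

# Rank 1 of 2, five leftover samples, global batch size 8:
print(last_batch_indices([0, 1, 2, 3, 4], 1, 2, 8, True))  # [1, 3, -1, -1]
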
30 changes: 19 additions & 11 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -484,12 +484,12 @@ def loss_func(output_tensor):
if validation_step and not self.cfg.data.get('validation_drop_last', True):
num_valid_samples_in_mb = int(loss_mask.sum() / loss_mask.numel() * loss_mask.shape[0])
loss_sum_for_mb = num_valid_samples_in_mb * loss_for_mb
- loss_sum_for_mb_all_gpu = torch.cat([loss_sum_for_mb.clone().detach().view(1)])
+ loss_sum_and_mb_size_all_gpu = torch.cat([loss_sum_for_mb.clone().detach().view(1), torch.tensor([num_valid_samples_in_mb]).cuda().clone().detach()])
# Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds)
torch.distributed.all_reduce(
- loss_sum_for_mb_all_gpu, group=parallel_state.get_data_parallel_group()
+ loss_sum_and_mb_size_all_gpu, group=parallel_state.get_data_parallel_group()
)
- return loss_for_mb, {'loss_sum': loss_sum_for_mb_all_gpu}
+ return loss_for_mb, {'loss_sum_and_mb_size': loss_sum_and_mb_size_all_gpu}
else:
reduced_loss = average_losses_across_data_parallel_group([loss_for_mb])
return loss_for_mb, {'avg': reduced_loss}
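
The key change in loss_func is that the micro-batch's summed loss and its count of valid (non-pad) samples are packed into one two-element tensor, so a single all-reduce over the data-parallel group synchronizes both values. A single-process sketch of the packing, with the .cuda() calls and the actual torch.distributed.all_reduce left out because they need an initialized process group:

import torch

loss_for_mb = torch.tensor(2.5)  # mean loss over the valid samples in this micro-batch
loss_mask = torch.tensor([[1., 1., 1., 1.],
                          [0., 0., 0., 0.]])  # second sample is a -1 pad, fully masked

num_valid_samples_in_mb = int(loss_mask.sum() / loss_mask.numel() * loss_mask.shape[0])  # 1
loss_sum_for_mb = num_valid_samples_in_mb * loss_for_mb

# Pack [summed loss, valid-sample count] so one collective carries both numbers.
loss_sum_and_mb_size = torch.cat(
    [loss_sum_for_mb.view(1), torch.tensor([float(num_valid_samples_in_mb)])]
)
# In the model this tensor is then all-reduced across the data-parallel group.
print(loss_sum_and_mb_size)  # tensor([2.5000, 1.0000])
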
@@ -563,12 +563,12 @@ def validation_step(self, batch, batch_idx):
else:
# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list = [
- loss_sum['loss_sum']
+ loss_sum['loss_sum_and_mb_size']
for loss_sum in losses_reduced_per_micro_batch
- if not loss_sum['loss_sum'].isnan()
+ if loss_sum['loss_sum_and_mb_size'][1] > 0
]
loss_sum = (
- torch.concat(loss_sum_tensors_list).sum() if len(loss_sum_tensors_list) > 0 else torch.tensor(0.0)
+ torch.vstack(loss_sum_tensors_list).sum(axis=0) if len(loss_sum_tensors_list) > 0 else torch.tensor(0.0)
)
return loss_sum
else:
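
Downstream, validation_step keeps only the micro-batches that contained at least one valid sample (count greater than zero, instead of the old NaN check) and stacks the [loss_sum, num_samples] pairs. A small sketch of that reduction, assuming each entry is the two-element tensor produced above:

import torch

# One [loss_sum, num_valid_samples] pair per micro-batch; the last micro-batch
# held only pad samples, so its count is zero.
per_microbatch = [torch.tensor([10.0, 4.0]),
                  torch.tensor([7.5, 3.0]),
                  torch.tensor([0.0, 0.0])]

kept = [t for t in per_microbatch if t[1] > 0]  # drop pad-only micro-batches
loss_sum = torch.vstack(kept).sum(axis=0) if kept else torch.tensor(0.0)
print(loss_sum)  # tensor([17.5000, 7.0000])
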
@@ -582,9 +582,9 @@ def validation_epoch_end(self, outputs):
averaged_loss = torch.stack(outputs).mean()
else:
# Compute the avg loss by total_loss across all samples / total number of samples
- total_loss = torch.stack(outputs).sum()
- avg_loss = total_loss / len(self._validation_ds)
- averaged_loss = torch.tensor(avg_loss, dtype=torch.float32).cuda()
+ total_loss_and_total_samples = torch.vstack(outputs).sum(axis = 0)
+ avg_loss = total_loss_and_total_samples[0] / total_loss_and_total_samples[1]
+ averaged_loss = avg_loss.type(torch.float32).cuda()
else:
averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda()
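
validation_epoch_end then stacks the per-batch [loss_sum, num_samples] pairs once more and divides total loss by the total number of real (non-pad) samples, rather than by len(self._validation_ds). In sketch form:

import torch

# [loss_sum, num_valid_samples] returned by validation_step for each batch
outputs = [torch.tensor([17.5, 7.0]), torch.tensor([12.0, 5.0])]

total_loss, total_samples = torch.vstack(outputs).sum(axis=0)
avg_loss = total_loss / total_samples
print(avg_loss)  # tensor(2.4583)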

@@ -661,7 +661,7 @@ def build_train_valid_test_datasets(self):

return self._train_ds, self._validation_ds, self._test_ds

- def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=None, drop_last=True):
+ def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False):
"""Buld dataloader given an input dataset."""

logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
@@ -676,6 +676,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=
data_parallel_rank=parallel_state.get_data_parallel_rank(),
data_parallel_size=parallel_state.get_data_parallel_world_size(),
drop_last=drop_last,
+ pad_samples_to_global_batch_size=pad_samples_to_global_batch_size
)
elif self.cfg.data.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomBatchSampler(
@@ -686,6 +687,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=
data_parallel_rank=parallel_state.get_data_parallel_rank(),
data_parallel_size=parallel_state.get_data_parallel_world_size(),
drop_last=self.cfg.get('drop_last', True),
+ pad_samples_to_global_batch_size=pad_samples_to_global_batch_size
)
else:
raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
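
build_pretraining_data_loader simply threads the new flag into whichever Megatron batch sampler the config selects; the variable name batch_sampler in the hunk suggests the resulting index lists are consumed by a PyTorch DataLoader via its batch_sampler argument. A toy end-to-end sketch of that wiring (ToyDataset and the hard-coded index list are illustrative only, not the NeMo classes):

from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 5
    def __getitem__(self, idx):
        # Mirror the GPT dataset behaviour: a -1 pad index yields a zero loss mask.
        return {'idx': idx, 'loss_mask': 0.0 if idx == -1 else 1.0}

# Stand-in for the batch sampler: one padded partial batch of global size 8.
padded_batches = [[0, 1, 2, 3, 4, -1, -1, -1]]
loader = DataLoader(ToyDataset(), batch_sampler=padded_batches)
for batch in loader:
    print(batch['loss_mask'])  # the three pad entries carry a zero loss mask
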
@@ -782,12 +784,18 @@ def setup_validation_data(self, cfg):
logging.info(
f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}'
)

drop_last = True
if not self.cfg.data.get('validation_drop_last', True):
logging.info(f'Drop last in validation dataset is set to False')
drop_last = False
+ pad_samples_to_global_batch_size = False
+ if self.cfg.data.get('pad_samples_to_global_batch_size', False):
+ logging.info('pad_samples_to_global_batch_size set to True')
+ pad_samples_to_global_batch_size = True

self._validation_dl = self.build_pretraining_data_loader(
- self._validation_ds, consumed_samples, "validation", drop_last
+ self._validation_ds, consumed_samples, "validation", drop_last, pad_samples_to_global_batch_size
)

def setup_test_data(self, cfg):
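
Putting it together, an exact validation loss on a dataset that does not divide evenly into global batches needs both knobs: validation_drop_last: False so the partial batch is consumed, and pad_samples_to_global_batch_size: True so every data-parallel rank still receives a full batch. A sketch of how those options are read, using an illustrative OmegaConf fragment (the real NeMo config contains many more keys):

from omegaconf import OmegaConf

# Illustrative config fragment, not the full megatron_gpt_config.
cfg = OmegaConf.create({
    'data': {
        'validation_drop_last': False,             # consume the last partial validation batch
        'pad_samples_to_global_batch_size': True,  # pad it with -1 index entries
    }
})

drop_last = cfg.data.get('validation_drop_last', True)
pad_samples_to_global_batch_size = cfg.data.get('pad_samples_to_global_batch_size', False)
print(drop_last, pad_samples_to_global_batch_size)  # False True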
