[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 9, 2022
1 parent aafb93b commit 66d4e35
Showing 2 changed files with 18 additions and 11 deletions.
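The changes below are pure formatting. A way to reproduce this kind of reflow locally, assuming the hook behind it is black run with a long line length and without string normalization (an assumption about the hook setup; the hook configuration is not part of this commit):

# Hypothetical repro of the formatter's effect; black with line_length=119 and
# string_normalization=False is an assumption, not read from this commit.
import black

src = (
    "if idx == -1:\n"
    "    logging.info(\n"
    "        'WARNING: Got -1 as item index. Masking loss from this sample'\n"
    "    )\n"
)

# The call fits within the limit, so black collapses it onto one line,
# which is exactly what the first hunk below shows.
print(black.format_str(src, mode=black.Mode(line_length=119, string_normalization=False)))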
@@ -381,10 +381,8 @@ def __getitem__(self, idx):
 
         # Negative index comes when we pad the last batch in MegatronPretrainingBatchSampler
         # We make the loss_mask zero to mask out loss from these samples
-        if idx == -1:
-            logging.info(
-                'WARNING: Got -1 as item index. Masking loss from this sample'
-            )
+        if idx == -1:
+            logging.info('WARNING: Got -1 as item index. Masking loss from this sample')
             loss_mask = torch.zeros_like(loss_mask)
 
         return {
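The reflowed lines sit inside logic worth spelling out: when the last global batch has to be padded, the batch sampler hands the dataset a -1 index, and __getitem__ zeroes the loss mask for that item so the padded sample contributes nothing to the loss. Below is a minimal self-contained sketch of the idea; ToyDataset and pad_batch_indices are illustrative stand-ins, not the actual MegatronPretrainingBatchSampler API.

import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Illustrative dataset: index -1 marks a padding sample whose loss is masked out."""

    def __init__(self, num_samples=10, seq_len=8):
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        tokens = torch.zeros(self.seq_len, dtype=torch.long)
        loss_mask = torch.ones(self.seq_len)
        if idx == -1:
            # Padding sample appended to fill the last batch: mask out its loss entirely.
            loss_mask = torch.zeros_like(loss_mask)
        return {'tokens': tokens, 'loss_mask': loss_mask}


def pad_batch_indices(indices, batch_size):
    """Pad the final incomplete batch with -1 so every batch has the same size."""
    pad = (-len(indices)) % batch_size
    return indices + [-1] * pad


ds = ToyDataset()
last_batch = [ds[i] for i in pad_batch_indices(list(range(10)), batch_size=4)[-4:]]
print([item['loss_mask'].sum().item() for item in last_batch])  # [8.0, 8.0, 0.0, 0.0]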
@@ -484,7 +484,12 @@ def loss_func(output_tensor):
             if validation_step and not self.cfg.data.get('validation_drop_last', True):
                 num_valid_samples_in_mb = int(loss_mask.sum() / loss_mask.numel() * loss_mask.shape[0])
                 loss_sum_for_mb = num_valid_samples_in_mb * loss_for_mb
-                loss_sum_and_mb_size_all_gpu = torch.cat([loss_sum_for_mb.clone().detach().view(1), torch.tensor([num_valid_samples_in_mb]).cuda().clone().detach()])
+                loss_sum_and_mb_size_all_gpu = torch.cat(
+                    [
+                        loss_sum_for_mb.clone().detach().view(1),
+                        torch.tensor([num_valid_samples_in_mb]).cuda().clone().detach(),
+                    ]
+                )
                 # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds)
                 torch.distributed.all_reduce(
                     loss_sum_and_mb_size_all_gpu, group=parallel_state.get_data_parallel_group()
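Beyond the line wrapping, the logic here is the interesting part: the number of valid (non-padded) samples in the micro-batch is inferred from the loss mask, because padded samples carry an all-zero mask, and the pair [loss_sum, num_valid_samples] is packed into one tensor so a single all_reduce over the data-parallel group can sum both values at once. A small standalone sketch of the counting and packing step (the example numbers are made up; no distributed group is initialized, the all_reduce is only indicated in a comment):

import torch

# Loss mask for a micro-batch of 4 samples x 5 tokens; the last sample is padding,
# so its mask row is all zeros (as done in __getitem__ for idx == -1).
loss_mask = torch.ones(4, 5)
loss_mask[3] = 0.0
loss_for_mb = torch.tensor(2.5)  # mean loss over the valid tokens of this micro-batch

# Fraction of unmasked tokens times the batch dimension ~= number of valid samples.
num_valid_samples_in_mb = int(loss_mask.sum() / loss_mask.numel() * loss_mask.shape[0])
loss_sum_for_mb = num_valid_samples_in_mb * loss_for_mb

# Pack [loss_sum, sample_count] so one collective can aggregate both; in the real code
# torch.distributed.all_reduce sums this tensor across the data-parallel group.
loss_sum_and_mb_size = torch.cat(
    [loss_sum_for_mb.view(1), torch.tensor([float(num_valid_samples_in_mb)])]
)
print(num_valid_samples_in_mb, loss_sum_and_mb_size.tolist())  # 3 [7.5, 3.0]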
@@ -568,7 +573,9 @@ def validation_step(self, batch, batch_idx):
                 if loss_sum['loss_sum_and_mb_size'][1] > 0
             ]
             loss_sum = (
-                torch.vstack(loss_sum_tensors_list).sum(axis=0) if len(loss_sum_tensors_list) > 0 else torch.tensor(0.0)
+                torch.vstack(loss_sum_tensors_list).sum(axis=0)
+                if len(loss_sum_tensors_list) > 0
+                else torch.tensor(0.0)
             )
             return loss_sum
         else:
@@ -582,7 +589,7 @@ def validation_epoch_end(self, outputs):
                 averaged_loss = torch.stack(outputs).mean()
             else:
                 # Compute the avg loss by total_loss across all samples / total number of samples
-                total_loss_and_total_samples = torch.vstack(outputs).sum(axis = 0)
+                total_loss_and_total_samples = torch.vstack(outputs).sum(axis=0)
                 avg_loss = total_loss_and_total_samples[0] / total_loss_and_total_samples[1]
                 averaged_loss = avg_loss.type(torch.float32).cuda()
         else:
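At the end of validation the per-microbatch pairs are stacked and summed, and the exact average is the summed loss divided by the total number of valid samples, which is what makes the padded last batch harmless. A minimal sketch of that reduction, including the guard for steps that saw only padded samples (the numbers are made up; the fallback is adapted to keep the [sum, count] shape rather than the scalar used in validation_step):

import torch

# One [loss_sum, num_valid_samples] pair per validation micro-batch / step.
outputs = [
    torch.tensor([7.5, 3.0]),
    torch.tensor([10.0, 4.0]),
    torch.tensor([0.0, 0.0]),  # a fully padded micro-batch contributes nothing
]

# Drop entries with zero valid samples, as validation_step does, then sum the columns.
kept = [o for o in outputs if o[1] > 0]
total = torch.vstack(kept).sum(axis=0) if len(kept) > 0 else torch.tensor([0.0, 0.0])

# Exact mean: total loss over all valid samples / total number of valid samples.
avg_loss = total[0] / total[1]
print(avg_loss.item())  # 2.5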
@@ -661,7 +668,9 @@ def build_train_valid_test_datasets(self):
 
         return self._train_ds, self._validation_ds, self._test_ds
 
-    def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False):
+    def build_pretraining_data_loader(
+        self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False
+    ):
         """Buld dataloader given an input dataset."""
 
         logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
@@ -676,7 +685,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=
                 data_parallel_rank=parallel_state.get_data_parallel_rank(),
                 data_parallel_size=parallel_state.get_data_parallel_world_size(),
                 drop_last=drop_last,
-                pad_samples_to_global_batch_size=pad_samples_to_global_batch_size
+                pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
             )
         elif self.cfg.data.dataloader_type == 'cyclic':
             batch_sampler = MegatronPretrainingRandomBatchSampler(
@@ -687,7 +696,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, dataset_type=
                 data_parallel_rank=parallel_state.get_data_parallel_rank(),
                 data_parallel_size=parallel_state.get_data_parallel_world_size(),
                 drop_last=self.cfg.get('drop_last', True),
-                pad_samples_to_global_batch_size=pad_samples_to_global_batch_size
+                pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
             )
         else:
             raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"')
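The two branches above differ only in which Megatron batch sampler they construct; the pattern is the usual one of dispatching on a config value and handing the resulting batch sampler to a DataLoader. A stripped-down sketch of that pattern using stock PyTorch samplers (build_loader, dataloader_type and the toy dataset are illustrative; the real code builds MegatronPretrainingBatchSampler or MegatronPretrainingRandomBatchSampler with the data-parallel rank/size arguments shown in the hunks):

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler, TensorDataset


def build_loader(dataset, dataloader_type='single', batch_size=4, drop_last=True):
    """Pick a batch sampler from a config string, then wrap it in a DataLoader."""
    if dataloader_type == 'single':
        sampler = SequentialSampler(dataset)  # deterministic pass over the data
    elif dataloader_type == 'cyclic':
        sampler = RandomSampler(dataset)      # reshuffled each epoch
    else:
        raise ValueError('dataloader_type must be "single" or "cyclic"')
    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)
    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0, pin_memory=False)


ds = TensorDataset(torch.arange(10))
for batch, in build_loader(ds, dataloader_type='single', drop_last=False):
    print(batch.tolist())  # [0, 1, 2, 3] then [4, 5, 6, 7] then [8, 9]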
@@ -784,7 +793,7 @@ def setup_validation_data(self, cfg):
             logging.info(
                 f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}'
             )
-
+
             drop_last = True
             if not self.cfg.data.get('validation_drop_last', True):
                 logging.info(f'Drop last in validation dataset is set to False')
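The drop_last default shown in this last hunk is flipped by the validation_drop_last flag, looked up with a default so older configs keep working (cfg is an OmegaConf DictConfig, as elsewhere in NeMo). A small illustration of that flag handling; the coupling to pad_samples_to_global_batch_size is an assumption about how the flag is consumed, not code visible in this hunk:

from omegaconf import OmegaConf

cfg = OmegaConf.create({'data': {'validation_drop_last': False}})

drop_last = True
pad_samples_to_global_batch_size = False
if not cfg.data.get('validation_drop_last', True):
    # Keep every validation sample: don't drop the last incomplete batch,
    # pad it up to the global batch size instead (assumed coupling).
    drop_last = False
    pad_samples_to_global_batch_size = True

print(drop_last, pad_samples_to_global_batch_size)  # False True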
