Skip to content

Commit

Permalink
Merge pull request #1 from anmolgupt/mblaz/validation-nan-fix
Browse files Browse the repository at this point in the history
Fix NaN handling during validation
  • Loading branch information
anmolgupt authored Jan 30, 2023
2 parents 1aa5490 + 4b61925 commit 49c4d7a
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,12 @@ def loss_func(output_tensor):
loss_for_ub = self.loss_func(loss_mask, output_tensor)
if validation_step and not self.cfg.data.get('validation_drop_last', True):
num_valid_tokens_in_ub = loss_mask.sum()
loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub
if loss_for_ub.isnan():
assert loss_mask.count_nonzero() == 0, 'Got NaN loss with non-empty input'
loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
else:
loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub

loss_sum_and_ub_size_all_gpu = torch.cat(
[
loss_sum_for_ub.clone().detach().view(1),
Expand Down

0 comments on commit 49c4d7a

Please sign in to comment.