Better NaN/inf loss handling for O0 (skip step across workers) (#637)
* Better NaN/inf loss handling for O0 (skip step across workers)

Signed-off-by: Jocelyn Huang <[email protected]>

* Add entry to changelog

Signed-off-by: Jocelyn Huang <[email protected]>

* Change NaN/inf all_reduce check to use MAX instead of default SUM

Signed-off-by: Jocelyn Huang <[email protected]>
redoctopus authored May 18, 2020
1 parent 99ef493 commit 97679e8
Showing 3 changed files with 17 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -79,6 +79,7 @@ To release a new version, please update the changelog as followed:
- ContextNet Encoder + Decoder Initial Support ([PR #630](https://github.com/NVIDIA/NeMo/pull/630)) - @titu1994

### Changed
- Syncs across workers at each step to check for NaN or inf loss. Terminates all workers if stop\_on\_nan\_loss is set (as before), lets Apex deal with it if apex.amp optimization level is O1 or higher, and skips the step across workers otherwise. ([PR #637](https://github.com/NVIDIA/NeMo/pull/637)) - @redoctopus

### Dependencies Update

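The behaviour described in that changelog entry boils down to one cross-worker check per step. The sketch below is a minimal standalone illustration of the pattern, not the NeMo code itself: `final_loss` and `on_gpu` stand in for the loss tensor and GPU-placement flag that the training loop in `actions.py` already has.

```python
import torch
import torch.distributed as dist


def loss_is_nan_or_inf(final_loss: torch.Tensor, on_gpu: bool) -> bool:
    """Return a NaN/inf verdict that is identical on every rank.

    The loss is cloned and all_reduced with MAX, so every rank tests the
    same reduced value and therefore takes the same branch (raise, warn,
    or skip the step), keeping the workers in sync.
    """
    checker = final_loss.clone()
    if on_gpu and dist.is_available() and dist.is_initialized():
        dist.all_reduce(checker, op=dist.ReduceOp.MAX)
    return bool(torch.isnan(checker).any() or torch.isinf(checker).any())
```

In the `actions.py` diff below, that verdict drives three branches: raise a `ValueError` when `stop_on_nan_loss` is set, log a warning and defer to apex.amp at O1 and above, and otherwise `continue`, skipping the update on all workers.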
59 changes: 15 additions & 44 deletions nemo/backends/pytorch/actions.py
@@ -1078,22 +1078,6 @@ def deployment_export(module, output: str, d_format: DeploymentFormat, input_exa
output_example=output_example,
)

def _check_nan_or_inf(self, placement_gpu, nan_or_inf, steps_per_nan_check=None):
# Note that nan_or_inf only gets set if stop_on_nan loss is True, or if using O0/not using apex.amp.
if not placement_gpu:
return
if steps_per_nan_check is None or self.step % steps_per_nan_check == 0:
world_size = dist.get_world_size()
# We use dtype=int because nccl backend doesn't support torch.bool
nan_inf_tensor = torch.tensor(nan_or_inf, dtype=int).cuda()
nan_inf_results = []
for _ in range(world_size):
nan_inf_results.append(torch.empty_like(nan_inf_tensor))
dist.all_gather(nan_inf_results, nan_inf_tensor)
for nan_inf in nan_inf_results:
if nan_inf:
raise ValueError('Terminating due to previous NaN or inf.')

def train(
self,
tensors_to_optimize=None,
@@ -1104,7 +1088,6 @@ def train(
lr_policy=None,
batches_per_step=None,
stop_on_nan_loss=False,
steps_per_nan_check=100,
synced_batchnorm=False,
synced_batchnorm_groupsize=0,
gradient_predivide=False,
@@ -1353,8 +1336,6 @@ def train(
# Do action start callbacks
self._perform_on_action_start(callbacks=callbacks)

nan_or_inf = False

# MAIN TRAINING LOOP
# iteration over epochs
while num_epochs is None or self.epoch_num < num_epochs:
@@ -1418,26 +1399,22 @@ def train(
curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
final_loss = 0
for tensor in curr_tensors_to_optimize:
if (
torch.isnan(registered_tensors[tensor.unique_name]).any()
or torch.isinf(registered_tensors[tensor.unique_name]).any()
):
if (
(stop_on_nan_loss)
or (self._optim_level not in AmpOptimizations)
or (self._optim_level == Optimization.mxprO0)
):
# Set flag here and terminate at next all_gather check.
nan_or_inf = True
logging.warning(
'Loss is NaN or inf at step %d, will terminate within the'
' next steps_per_nan_check steps',
self.step,
)
else:
logging.warning('Loss is NaN or inf, continuing training')
final_loss += registered_tensors[tensor.unique_name]

# Check for NaN/inf loss (across workers if applicable)
loss_nan_inf_checker = final_loss.clone()
if placement_gpu:
dist.all_reduce(loss_nan_inf_checker, torch.distributed.ReduceOp.MAX)
if torch.isnan(loss_nan_inf_checker).any() or torch.isinf(loss_nan_inf_checker).any():
if stop_on_nan_loss:
raise ValueError('Loss is NaN or inf - exiting')
if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
logging.warning('Loss is NaN or inf.')
else:
# Skip this step across workers if loss is NaN/inf and using fp32
logging.warning('Loss is NaN or inf. Skipping update.')
continue

if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
if disable_allreduce:
@@ -1460,15 +1437,12 @@ def train(
final_loss.backward(bps_scale.to(final_loss.get_device()))
# single device (CPU or GPU)
else:
# Fix (workaround?) enabling to backpropagate gradiens on CPUs.
# Fix (workaround?) enabling to backpropagate gradients on CPUs.
if final_loss.get_device() < 0:
final_loss.backward(bps_scale)
else:
final_loss.backward(bps_scale.to(final_loss.get_device()))

# Check if we should terminate due to NaN/inf on any workers.
self._check_nan_or_inf(placement_gpu, nan_or_inf, steps_per_nan_check=steps_per_nan_check)

batch_counter += 1

if batch_counter == batches_per_step:
Expand All @@ -1488,9 +1462,6 @@ def train(
self._perform_on_epoch_end(callbacks=callbacks)
self.epoch_num += 1

# Check again if we should stop on NaN/inf
self._check_nan_or_inf(placement_gpu, nan_or_inf)

self._perform_on_action_end(callbacks=callbacks)

def infer(
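An equivalent way to broadcast the decision, closer in spirit to the removed `_check_nan_or_inf` helper, is to reduce an integer flag rather than the loss tensor itself (NCCL cannot handle `torch.bool`, hence the integer dtype in both the old helper and this sketch). This is an illustrative variant, not code from the commit:

```python
import torch
import torch.distributed as dist


def any_rank_bad(local_bad: bool, device: torch.device) -> bool:
    """Share a per-rank NaN/inf flag so every rank takes the same branch.

    The flag travels as an int because NCCL does not support torch.bool.
    Reducing with MAX keeps the result a clean 0/1 regardless of how many
    ranks raised the flag, whereas SUM would grow with that count.
    """
    flag = torch.tensor(int(local_bad), dtype=torch.int32, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```

Either form gives every worker the same verdict before the backward pass and optimizer step run.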
5 changes: 1 addition & 4 deletions nemo/core/neural_factory.py
@@ -137,8 +137,7 @@ def train(
batch_size
stop_on_nan_loss: (default: False) If set to True, the training
will stop if loss=nan or inf. If set to False, the training
will continue. Note that if apex.amp is not used, or if
optimization level is O0, training will stop regardless.
will continue.
Returns:
None
@@ -573,7 +572,6 @@ def train(
lr_policy=None,
batches_per_step=None,
stop_on_nan_loss=False,
steps_per_nan_check=100,
synced_batchnorm=False,
synced_batchnorm_groupsize=0,
gradient_predivide=False,
Expand All @@ -591,7 +589,6 @@ def train(
lr_policy=lr_policy,
batches_per_step=batches_per_step,
stop_on_nan_loss=stop_on_nan_loss,
steps_per_nan_check=steps_per_nan_check,
synced_batchnorm=synced_batchnorm,
synced_batchnorm_groupsize=synced_batchnorm_groupsize,
gradient_predivide=gradient_predivide,
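From the caller's side the only API change is that `steps_per_nan_check` no longer exists; the check now runs every step, and `stop_on_nan_loss` keeps its meaning. A hedged usage fragment rather than a complete training script (the factory setup and the `train_loss` tensor are placeholders, not part of this commit):

```python
import nemo

nf = nemo.core.NeuralModuleFactory()
# ... build the neural graph; assume `train_loss` is the loss tensor it produces ...

nf.train(
    tensors_to_optimize=[train_loss],
    stop_on_nan_loss=False,  # skip the bad step on all workers instead of aborting
    # steps_per_nan_check has been removed; the NaN/inf sync now happens every step
)
```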
