Commit d535720

switch from apex to torch
Signed-off-by: Jason <[email protected]>
blisc committed Feb 7, 2020
1 parent 675b0fe commit d535720
Showing 1 changed file with 95 additions and 35 deletions.
130 changes: 95 additions & 35 deletions nemo/backends/pytorch/actions.py
@@ -5,6 +5,7 @@
import json
import os
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import List, Optional

@@ -26,8 +27,8 @@

# these imports will happen on as-needed basis
amp = None
convert_syncbn = None
create_syncbn_process_group = None
# convert_syncbn = None
# create_syncbn_process_group = None
LARC = None
FusedLAMB = None
FusedAdam = None
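
These module-level None placeholders support the lazy-import idiom noted in the comment above: apex is only imported when an optimization level or distributed run actually needs it. A minimal sketch of the idiom (the get_amp helper is illustrative, not part of this file):

import importlib

amp = None  # placeholder; apex may not be installed at all

def get_amp():
    """Import apex.amp on first use and cache it in the module-level global."""
    global amp
    if amp is None:
        amp = importlib.import_module('apex.amp')
    return amp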
@@ -59,16 +60,16 @@ def __init__(
global amp
amp = importlib.import_module('apex.amp')
if local_rank is not None:
global convert_syncbn
global create_syncbn_process_group
# global convert_syncbn
# global create_syncbn_process_group
global LARC
global FusedLAMB
global FusedAdam
global FusedNovoGrad
parallel = importlib.import_module('apex.parallel')
apex_optimizer = importlib.import_module('apex.optimizers')
convert_syncbn = parallel.convert_syncbn_model
create_syncbn_process_group = parallel.create_syncbn_process_group
# convert_syncbn = parallel.convert_syncbn_model
# create_syncbn_process_group = parallel.create_syncbn_process_group
LARC = parallel.LARC
FusedLAMB = apex_optimizer.FusedLAMB
FusedAdam = apex_optimizer.FusedAdam
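
The apex helpers convert_syncbn_model and create_syncbn_process_group are commented out because recent PyTorch ships a native converter. A minimal sketch of the torch-native replacement, assuming torch.distributed has already been initialized:

import torch
import torch.nn as nn

# Hypothetical model; any module containing BatchNorm layers works.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Replaces apex.parallel.convert_syncbn_model: swaps every BatchNorm*
# layer for SyncBatchNorm. process_group=None means the whole world.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)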
@@ -377,7 +378,7 @@ def __initialize_amp(
return optimizer

def __nm_graph_forward_pass(
self, call_chain, registered_tensors, mode=ModelMode.train, disable_allreduce=False, use_cache=False,
self, call_chain, registered_tensors, mode=ModelMode.train, use_cache=False,
):
for ind in range(1, len(call_chain)):
if use_cache:
@@ -397,12 +398,12 @@ def __nm_graph_forward_pass(
m_id = call_chain[ind][0].unique_instance_id
pmodule = self.module_reference_table[m_id][1]

if self._local_rank is not None:
if isinstance(pmodule, DDP):
if disable_allreduce:
pmodule.disable_allreduce()
else:
pmodule.enable_allreduce()
# if self._local_rank is not None:
# if isinstance(pmodule, DDP):
# if disable_allreduce:
# pmodule.disable_allreduce()
# else:
# pmodule.enable_allreduce()

if mode == ModelMode.train:
# if module.is_trainable():
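
The apex-style enable_allreduce()/disable_allreduce() toggles disappear because torch's DistributedDataParallel exposes the same control as the no_sync() context manager, which this commit adopts below. A sketch of the usual gradient-accumulation pattern under no_sync(), assuming ddp_model is a DDP-wrapped module and micro_batches is a list of (input, target) pairs:

import torch.nn.functional as F

def accumulate_and_step(ddp_model, optimizer, micro_batches):
    """Accumulate gradients locally, allreducing only on the last micro-batch."""
    for i, (x, y) in enumerate(micro_batches):
        loss = F.mse_loss(ddp_model(x), y) / len(micro_batches)
        if i < len(micro_batches) - 1:
            with ddp_model.no_sync():  # skip gradient allreduce for this backward
                loss.backward()
        else:
            loss.backward()  # gradients from all micro-batches are reduced here
    optimizer.step()
    optimizer.zero_grad()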
@@ -1064,6 +1065,11 @@ def train(
gradient_predivide=False,
amp_max_loss_scale=2.0 ** 24,
):
if gradient_predivide:
logging.error(
"gradient_predivide is currently disabled, and is under consideration for removal in future versions. "
"If this functionality is needed, please raise a github issue."
)
if not optimization_params:
optimization_params = {}
num_epochs = optimization_params.get("num_epochs", None)
@@ -1205,23 +1211,42 @@ def train(
key = call_chain[i][0].unique_instance_id
pmodule = self.module_reference_table[key][1]
if not isinstance(pmodule, DDP) and isinstance(pmodule, torch.nn.Module):
gpf = 1
if gradient_predivide:
gpf = dist.get_world_size()
pmodule = DDP(pmodule, gradient_predivide_factor=gpf)

# Convert batchnorm modules to synced if applicable
if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
world_size = dist.get_world_size()
if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
raise ValueError(
f"Synchronized batch norm group size"
f" ({synced_batchnorm_groupsize}) must be 0"
f" or divide total number of GPUs"
f" ({world_size})."
# gpf = 1
# if gradient_predivide:
# gpf = dist.get_world_size()
# pmodule = DDP(pmodule, gradient_predivide_factor=gpf) # Old Apex Method

# Per PyTorch docs, convert batch norm to SyncBatchNorm prior to wrapping with DDP
if synced_batchnorm:
world_size = dist.get_world_size()
sync_batchnorm_group = None
if synced_batchnorm_groupsize > 0:
if world_size % synced_batchnorm_groupsize != 0:
raise ValueError(
f"Synchronized batch norm group size ({synced_batchnorm_groupsize}) must be 0"
f" or divide total number of GPUs ({world_size})."
)
sync_batchnorm_group = torch.distributed.new_group(synced_batchnorm_groupsize)
pmodule = nn.SyncBatchNorm.convert_sync_batchnorm(
pmodule, process_group=sync_batchnorm_group
)
process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
pmodule = convert_syncbn(pmodule, process_group=process_group)

# Disable broadcast_buffers so that DDP does not synchronize batch norm buffers on the forward
# pass; SyncBatchNorm already keeps the statistics consistent
pmodule = DDP(pmodule, device_ids=[self.local_rank], broadcast_buffers=False)

# # Convert batchnorm modules to synced if applicable
# if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
# world_size = dist.get_world_size()
# if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
# raise ValueError(
# f"Synchronized batch norm group size"
# f" ({synced_batchnorm_groupsize}) must be 0"
# f" or divide total number of GPUs"
# f" ({world_size})."
# )
# process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
# pmodule = convert_syncbn(pmodule, process_group=process_group)

self.module_reference_table[key] = (
self.module_reference_table[key][0],
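
One caveat with the new code path above: torch.distributed.new_group expects a list of ranks as its first argument, not a group size, and it is a collective that every process must call for every subgroup. A hedged sketch of carving the world into fixed-size sync-BN groups (the helper name is illustrative):

import torch.distributed as dist

def make_syncbn_group(group_size):
    """Partition ranks into contiguous subgroups and return the caller's group."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    my_group = None
    for start in range(0, world_size, group_size):
        ranks = list(range(start, start + group_size))
        # new_group is collective: every rank must execute every call.
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            my_group = group
    return my_group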
@@ -1300,9 +1325,7 @@ def train(
}
disable_allreduce = batch_counter < (batches_per_step - 1)
self.__nm_graph_forward_pass(
call_chain=curr_call_chain,
registered_tensors=registered_tensors,
disable_allreduce=disable_allreduce,
call_chain=curr_call_chain, registered_tensors=registered_tensors,
)

curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
@@ -1323,19 +1346,27 @@ def train(
if nan:
continue
if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce,) as scaled_loss:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
if torch.isnan(scaled_loss).any() or torch.isinf(scaled_loss).any():
if stop_on_nan_loss:
raise ValueError('Loss is NaN or inf - exiting')
logging.warning('WARNING: Loss is NaN or inf')
curr_optimizer.zero_grad()
continue
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
if disable_allreduce:
with self.no_sync(curr_call_chain):
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
else:
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
# no AMP optimizations needed
else:
# multi-GPU, float32
if self._local_rank is not None:
final_loss.backward(bps_scale.to(final_loss.get_device()))
if disable_allreduce:
with self.no_sync(curr_call_chain):
final_loss.backward(bps_scale.to(final_loss.get_device()))
else:
final_loss.backward(bps_scale.to(final_loss.get_device()))
# single device (CPU or GPU)
else:
# Fix (workaround?) to enable backpropagation of gradients on CPUs.
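
The delay_unscale flag has to track the allreduce decision: on intermediate accumulation steps the gradients stay scaled and local, and only the final backward unscales and syncs. A sketch of the combined pattern, assuming amp has been initialized and last_step marks the final micro-batch:

# last_step: True on the micro-batch that should sync gradients and step.
with amp.scale_loss(loss, optimizer, delay_unscale=not last_step) as scaled_loss:
    if not last_step:
        with ddp_model.no_sync():       # keep scaled gradients local
            scaled_loss.backward()
    else:
        scaled_loss.backward()          # allreduce; amp unscales afterwards
if last_step:
    optimizer.step()
    optimizer.zero_grad()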
@@ -1430,3 +1461,32 @@ def infer(
use_cache=use_cache,
offload_to_cpu=offload_to_cpu,
)

@contextmanager
def no_sync(self, call_chain):
    """
    Wrapper context manager around PyTorch DDP's no_sync(), since PyTorch requires ALL wrapped DDP
    modules to be inside the no_sync() context manager for gradient accumulation.
    """
    modules = []
    for ind in range(1, len(call_chain)):
        m_id = call_chain[ind][0].unique_instance_id
        module = self.module_reference_table[m_id][1]
        if isinstance(module, DDP):
            modules.append(module)

    @contextmanager
    def recursive_yield(list_of_modules):
        # Base case: every DDP module's no_sync() has been entered; run the body.
        if not list_of_modules:
            yield
            return
        mod = list_of_modules.pop()
        # Enter this module's no_sync(), then recurse so the remaining modules
        # are also inside no_sync() before the wrapped code runs.
        with mod.no_sync():
            with recursive_yield(list_of_modules):
                yield

    with recursive_yield(modules):
        yield
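
A flatter alternative to the recursion, sketched here under the same modules list, is contextlib.ExitStack, which enters an arbitrary number of context managers at once:

from contextlib import ExitStack, contextmanager

@contextmanager
def no_sync_all(modules):
    """Enter no_sync() on every DDP module, then run the wrapped body."""
    with ExitStack() as stack:
        for module in modules:
            stack.enter_context(module.no_sync())
        yield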
