diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py
index 3955648f924d..b364eea442ba 100644
--- a/nemo/backends/pytorch/actions.py
+++ b/nemo/backends/pytorch/actions.py
@@ -5,6 +5,7 @@
 import json
 import os
 from collections import defaultdict
+from contextlib import contextmanager
 from pathlib import Path
 from typing import List, Optional
 
@@ -26,8 +27,8 @@
 # these imports will happen on as-needed basis
 amp = None
-convert_syncbn = None
-create_syncbn_process_group = None
+# convert_syncbn = None
+# create_syncbn_process_group = None
 LARC = None
 FusedLAMB = None
 FusedAdam = None
 FusedNovoGrad = None
@@ -59,16 +60,16 @@ def __init__(
         global amp
         amp = importlib.import_module('apex.amp')
         if local_rank is not None:
-            global convert_syncbn
-            global create_syncbn_process_group
+            # global convert_syncbn
+            # global create_syncbn_process_group
             global LARC
             global FusedLAMB
             global FusedAdam
             global FusedNovoGrad
             parallel = importlib.import_module('apex.parallel')
             apex_optimizer = importlib.import_module('apex.optimizers')
-            convert_syncbn = parallel.convert_syncbn_model
-            create_syncbn_process_group = parallel.create_syncbn_process_group
+            # convert_syncbn = parallel.convert_syncbn_model
+            # create_syncbn_process_group = parallel.create_syncbn_process_group
             LARC = parallel.LARC
             FusedLAMB = apex_optimizer.FusedLAMB
             FusedAdam = apex_optimizer.FusedAdam
@@ -377,7 +378,7 @@ def __initialize_amp(
         return optimizer
 
     def __nm_graph_forward_pass(
-        self, call_chain, registered_tensors, mode=ModelMode.train, disable_allreduce=False, use_cache=False,
+        self, call_chain, registered_tensors, mode=ModelMode.train, use_cache=False,
     ):
         for ind in range(1, len(call_chain)):
             if use_cache:
@@ -397,12 +398,12 @@ def __nm_graph_forward_pass(
             m_id = call_chain[ind][0].unique_instance_id
             pmodule = self.module_reference_table[m_id][1]
 
-            if self._local_rank is not None:
-                if isinstance(pmodule, DDP):
-                    if disable_allreduce:
-                        pmodule.disable_allreduce()
-                    else:
-                        pmodule.enable_allreduce()
+            # if self._local_rank is not None:
+            #     if isinstance(pmodule, DDP):
+            #         if disable_allreduce:
+            #             pmodule.disable_allreduce()
+            #         else:
+            #             pmodule.enable_allreduce()
 
             if mode == ModelMode.train:
                 # if module.is_trainable():
@@ -1064,6 +1065,11 @@ def train(
         gradient_predivide=False,
         amp_max_loss_scale=2.0 ** 24,
     ):
+        if gradient_predivide:
+            logging.error(
+                "gradient_predivide is currently disabled and is under consideration for removal in future"
+                " versions. If this functionality is needed, please raise a GitHub issue."
+            )
         if not optimization_params:
             optimization_params = {}
         num_epochs = optimization_params.get("num_epochs", None)
@@ -1205,23 +1211,51 @@ def train(
                     key = call_chain[i][0].unique_instance_id
                     pmodule = self.module_reference_table[key][1]
                     if not isinstance(pmodule, DDP) and isinstance(pmodule, torch.nn.Module):
-                        gpf = 1
-                        if gradient_predivide:
-                            gpf = dist.get_world_size()
-                        pmodule = DDP(pmodule, gradient_predivide_factor=gpf)
-
-                        # Convert batchnorm modules to synced if applicable
-                        if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
-                            world_size = dist.get_world_size()
-                            if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
-                                raise ValueError(
-                                    f"Synchronized batch norm group size"
-                                    f" ({synced_batchnorm_groupsize}) must be 0"
-                                    f" or divide total number of GPUs"
-                                    f" ({world_size})."
+                        # gpf = 1
+                        # if gradient_predivide:
+                        #     gpf = dist.get_world_size()
+                        # pmodule = DDP(pmodule, gradient_predivide_factor=gpf)  # Old Apex method
+
+                        # Per the PyTorch docs, convert batch norm to SyncBatchNorm prior to DDP wrapping
+                        if synced_batchnorm:
+                            world_size = dist.get_world_size()
+                            sync_batchnorm_group = None
+                            if synced_batchnorm_groupsize > 0:
+                                if world_size % synced_batchnorm_groupsize != 0:
+                                    raise ValueError(
+                                        f"Synchronized batch norm group size ({synced_batchnorm_groupsize}) must be 0"
+                                        f" or divide total number of GPUs ({world_size})."
+                                    )
+                                # new_group() expects an explicit list of ranks and must be called by every
+                                # rank for every group; keep only the group this rank belongs to.
+                                rank = dist.get_rank()
+                                for group_ind in range(world_size // synced_batchnorm_groupsize):
+                                    group_ranks = list(
+                                        range(group_ind * synced_batchnorm_groupsize, (group_ind + 1) * synced_batchnorm_groupsize)
+                                    )
+                                    group = torch.distributed.new_group(ranks=group_ranks)
+                                    if rank in group_ranks:
+                                        sync_batchnorm_group = group
+                            pmodule = nn.SyncBatchNorm.convert_sync_batchnorm(
+                                pmodule, process_group=sync_batchnorm_group
                             )
-                            process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
-                            pmodule = convert_syncbn(pmodule, process_group=process_group)
+
+                        # Disable broadcast_buffers so batch norm buffers are not re-broadcast
+                        # (and thereby synchronized) on every forward pass
+                        pmodule = DDP(pmodule, device_ids=[self.local_rank], broadcast_buffers=False)
+
+                        # # Convert batchnorm modules to synced if applicable
+                        # if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
+                        #     world_size = dist.get_world_size()
+                        #     if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
+                        #         raise ValueError(
+                        #             f"Synchronized batch norm group size"
+                        #             f" ({synced_batchnorm_groupsize}) must be 0"
+                        #             f" or divide total number of GPUs"
+                        #             f" ({world_size})."
+                        #         )
+                        #     process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
+                        #     pmodule = convert_syncbn(pmodule, process_group=process_group)
 
                     self.module_reference_table[key] = (
                         self.module_reference_table[key][0],
@@ -1300,9 +1325,7 @@ def train(
                     }
                     disable_allreduce = batch_counter < (batches_per_step - 1)
                     self.__nm_graph_forward_pass(
-                        call_chain=curr_call_chain,
-                        registered_tensors=registered_tensors,
-                        disable_allreduce=disable_allreduce,
+                        call_chain=curr_call_chain, registered_tensors=registered_tensors,
                     )
 
                     curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
@@ -1323,19 +1346,27 @@ def train(
                     if nan:
                         continue
                     if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
-                        with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce,) as scaled_loss:
+                        with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
                             if torch.isnan(scaled_loss).any() or torch.isinf(scaled_loss).any():
                                 if stop_on_nan_loss:
                                     raise ValueError('Loss is NaN or inf -' ' exiting')
                                 logging.warning('WARNING: Loss is NaN or inf')
                                 curr_optimizer.zero_grad()
                                 continue
-                            scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
+                            if disable_allreduce:
+                                with self.no_sync(curr_call_chain):
+                                    scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
+                            else:
+                                scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
                     # no AMP optimizations needed
                     else:
                         # multi-GPU, float32
                         if self._local_rank is not None:
-                            final_loss.backward(bps_scale.to(final_loss.get_device()))
+                            if disable_allreduce:
+                                with self.no_sync(curr_call_chain):
+                                    final_loss.backward(bps_scale.to(final_loss.get_device()))
+                            else:
+                                final_loss.backward(bps_scale.to(final_loss.get_device()))
                         # single device (CPU or GPU)
                         else:
                             # Fix (workaround?) enabling to backpropagate gradiens on CPUs.
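For readers unfamiliar with the torch-native pattern the hunks above switch to, here is a minimal standalone sketch of the same flow (SyncBatchNorm conversion before DDP wrapping, grouped process groups, `broadcast_buffers=False`). The helper name `wrap_for_ddp` and its arguments are illustrative assumptions, not NeMo API:

```python
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: nn.Module, local_rank: int, group_size: int = 0) -> nn.Module:
    """Illustrative helper: assumes dist.init_process_group() has already run."""
    world_size = dist.get_world_size()
    process_group = None
    if group_size > 0:
        if world_size % group_size != 0:
            raise ValueError(f"group_size ({group_size}) must divide world size ({world_size})")
        # new_group() takes a list of ranks and must be called by every rank
        # for every group; each rank keeps only the group it belongs to.
        for g in range(world_size // group_size):
            ranks = list(range(g * group_size, (g + 1) * group_size))
            group = dist.new_group(ranks=ranks)
            if dist.get_rank() in ranks:
                process_group = group
    # Per the PyTorch docs, conversion must happen before DDP wrapping.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=process_group)
    # broadcast_buffers=False keeps DDP from re-broadcasting (i.e. overwriting)
    # batch norm buffers on every forward pass; SyncBatchNorm synchronizes itself.
    return DDP(model.cuda(local_rank), device_ids=[local_rank], broadcast_buffers=False)
```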
@@ -1430,3 +1461,30 @@ def infer(
             use_cache=use_cache,
             offload_to_cpu=offload_to_cpu,
         )
+
+    @contextmanager
+    def no_sync(self, call_chain):
+        """
+        Wrapper context manager around PyTorch DDP's no_sync(), needed because PyTorch requires ALL
+        DDP-wrapped modules in the call chain to be inside no_sync() for gradient accumulation to work.
+        """
+        modules = []
+        for ind in range(1, len(call_chain)):
+            m_id = call_chain[ind][0].unique_instance_id
+            module = self.module_reference_table[m_id][1]
+            if isinstance(module, DDP):
+                modules.append(module)
+
+        @contextmanager
+        def recursive_yield(list_of_modules):
+            # Base case: no_sync() has been entered on every DDP module
+            if not list_of_modules:
+                yield
+                return
+            mod = list_of_modules.pop()
+            with mod.no_sync():
+                with recursive_yield(list_of_modules):
+                    yield
+
+        with recursive_yield(modules):
+            yield
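As an aside, the hand-rolled `recursive_yield` above can also be expressed with `contextlib.ExitStack`, which enters any number of context managers and unwinds them in reverse order on exit; a sketch (the name `no_sync_all` is illustrative, not part of this patch):

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def no_sync_all(ddp_modules):
    """Enter no_sync() on every DDP-wrapped module at once."""
    with ExitStack() as stack:
        for module in ddp_modules:
            # Each no_sync() context is entered here and exited in
            # reverse order when the with-block unwinds.
            stack.enter_context(module.no_sync())
        yield
```

During accumulation micro-batches, `with no_sync_all(modules): loss.backward()` skips the gradient all-reduce; running the final micro-batch's backward outside the context then lets gradients synchronize once per effective batch.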