From 22cef73c5e53a12f2f9f7d7a663bdb37483ab883 Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Mon, 15 Apr 2024 19:05:49 +0000
Subject: [PATCH] Don't use autograd hook for bwd reduction

Signed-off-by: Kirthi Shankar Sivamani
---
 transformer_engine/pytorch/fp8.py             | 21 -------------------
 .../pytorch/module/layernorm_linear.py        | 12 +++++------
 .../pytorch/module/layernorm_mlp.py           | 13 +++++-------
 transformer_engine/pytorch/module/linear.py   | 13 +++++-------
 4 files changed, 15 insertions(+), 44 deletions(-)

diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index e821bfe11d..d06443efb6 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -81,8 +81,6 @@ class FP8GlobalStateManager:
     fp8_tensors_recompute_buffer = []
     fp8_available = None
     reason_for_no_fp8 = ""
-    multi_grad_hook_tensors = []
-    bwd_amax_update_hook_registered = False
     autocast_arguments = {}
     autocast_to_fp8_params = {}
     fp8_param_to_autocast = {}
@@ -106,8 +104,6 @@ def reset(cls) -> None:
         cls.fp8_tensors_recompute_buffer = []
         cls.fp8_available = None
         cls.reason_for_no_fp8 = ""
-        cls.multi_grad_hook_tensors = []
-        cls.bwd_amax_update_hook_registered = False
         cls.autocast_arguments = {}
         cls.autocast_to_fp8_params = {}
         cls.fp8_param_to_autocast = {}
@@ -370,16 +366,6 @@ def reduce_and_update_fp8_tensors(
             _amax_and_scale_update(
                 amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe)
 
-    @classmethod
-    def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor):
-        """Add tensor to list for multi grad hook."""
-        cls.multi_grad_hook_tensors.append(tensor)
-
-    @classmethod
-    def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument
-        """Executes at the end of backward pass."""
-        cls.reduce_and_update_fp8_tensors(forward=False)
-
     @classmethod
     def get_unique_autocast_key(
         cls,
@@ -407,13 +393,6 @@ def fp8_autocast_enter(
             autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
             cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
 
-        if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
-            if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0:
-                # This hook does not fire for graphed modules.
-                torch.autograd.graph.register_multi_grad_hook(
-                    tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction)
-                cls.bwd_amax_update_hook_registered = True
-
         cls.FP8_ENABLED = enabled
         cls.FP8_CALIBRATION = calibrating
         cls.FP8_RECIPE = fp8_recipe
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index ffa14bc157..5df4950276 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -40,6 +40,7 @@
 )
 from ..constants import GemmParallelModes, dist_group_type, TE_DType
 from ..jit import no_torch_dynamo
+from ..graph import is_graph_capturing
 from ._common import _apply_normalization, _noop_cat
 from ..float8_tensor import Float8Tensor
 
@@ -89,7 +90,6 @@ def forward(
         ub_overlap_rs_dgrad: bool,
         ub_overlap_ag: bool,
         ub_name: str,
-        dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -328,6 +328,7 @@ def forward(
             ctx.requires_dgrad = inp.requires_grad
             ctx.normalization = normalization
             ctx.primary_weights_in_fp8 = primary_weights_in_fp8
+            ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
 
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
@@ -660,6 +661,9 @@ def backward(
         else:
             wgrad = None
 
+        if ctx.is_first_module and not is_graph_capturing():
+            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+
         return (
             dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
@@ -696,7 +700,6 @@ def backward(
             None,
             None,
             None,
-            None,
         )
 
 
@@ -1001,10 +1004,6 @@ def __init__(
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
 
-        # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
-        self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
-        FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         warnings.warn(
@@ -1176,7 +1175,6 @@ def forward(
             self.ub_overlap_rs_dgrad,
             self.ub_overlap_ag,
             self.ub_name,
-            self.dummy_tensor,
         )
         out = fwd_fn(*args)
 
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index e143cf6659..6efb72b8db 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -49,7 +49,7 @@
 from ..constants import dist_group_type, TE_DType
 from ..jit import no_torch_dynamo
 
-
+from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
 from ._common import _apply_normalization
 
@@ -121,7 +120,6 @@ def forward(
         ub_overlap_rs: bool,
         ub_overlap_ag: bool,
         gemm_gelu_fusion: bool,
-        dummy_tensor: torch.Tensor, # pylint: disable=unused-argument,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -545,6 +544,7 @@ def forward(
             ctx.requires_dgrad = inp.requires_grad
             ctx.normalization = normalization
             ctx.primary_weights_in_fp8 = primary_weights_in_fp8
+            ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
 
         # Row Parallel Linear
         if ub_overlap_rs:
@@ -1121,6 +1121,9 @@ def backward(
         else:
             fc2_wgrad = None
 
+        if ctx.is_first_module and not is_graph_capturing():
+            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+
         return (
             dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
@@ -1165,7 +1168,6 @@ def backward(
             None,
             None,
             None,
-            None,
         )
 
 
@@ -1429,10 +1431,6 @@ def __init__(
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
 
-        # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
-        self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
-        FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         warnings.warn(
@@ -1588,7 +1586,6 @@ def forward(
             self.ub_overlap_rs,
             self.ub_overlap_ag,
             self.gemm_gelu_fusion,
-            self.dummy_tensor,
         )
         out = fwd_fn(*args)
 
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 4baf2d5965..3c055270b0 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -43,7 +43,7 @@
 )
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
-
+from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
 
 __all__ = ["Linear"]
@@ -81,7 +81,6 @@ def forward(
         ub_overlap_rs: bool,
         ub_overlap_ag: bool,
         ub_name: str,
-        dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -321,6 +320,7 @@ def forward(
             ctx.tp_size = tp_size
             ctx.requires_dgrad = inp.requires_grad
             ctx.primary_weights_in_fp8 = primary_weights_in_fp8
+            ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
 
         # Row Parallel Linear
         if ub_overlap_rs:
@@ -530,6 +530,9 @@ def backward(
         else:
             wgrad = None
 
+        if ctx.is_first_module and not is_graph_capturing():
+            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+
         return (
             wgrad,
             None,
@@ -555,7 +558,6 @@ def backward(
             None,
             None,
             None,
-            None,
         )
 
 
@@ -798,10 +800,6 @@ def __init__(
         else:
             self.gemm_bias_unfused_add = False
 
-        # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
-        self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
-        FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
-
     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
 
@@ -941,7 +939,6 @@ def forward(
             self.ub_overlap_rs,
             self.ub_overlap_ag,
             self.ub_name,
-            self.dummy_tensor,
         )
         out = linear_fn(*args)
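
Illustrative sketch (not part of the patch): the minimal torch.autograd.Function below mirrors the pattern this change adopts, assuming that FP8GlobalStateManager.is_first_fp8_module() returns True only for the first FP8 module of an iteration and that is_graph_capturing() reports whether a CUDA graph is being captured. The _Toy* names and _is_graph_capturing are hypothetical stand-ins, not Transformer Engine APIs. The idea: the forward pass records whether a module is the first FP8 module, and that module's backward, which autograd reaches last, performs the backward amax reduction directly, replacing the torch.autograd.graph.register_multi_grad_hook over per-module dummy tensors, which (per the removed comment) never fired for graphed modules.

# Minimal sketch of the "reduce in the first module's backward" pattern.
# Only torch is required; _Toy* and _is_graph_capturing are stand-ins, not TE APIs.

import torch


def _is_graph_capturing() -> bool:
    # Stand-in for transformer_engine.pytorch.graph.is_graph_capturing, approximated
    # here by CUDA stream-capture status (returns False on CPU-only setups).
    return torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()


class _ToyGlobalState:
    """Stand-in for FP8GlobalStateManager: tracks the first module of an iteration."""

    _first_module_seen = False

    @classmethod
    def is_first_module(cls) -> bool:
        first = not cls._first_module_seen
        cls._first_module_seen = True
        return first

    @classmethod
    def reduce_and_update(cls) -> None:
        # In TE this is reduce_and_update_fp8_tensors(forward=False): reduce the
        # backward amax histories and recompute scales. Here we only reset the flag.
        cls._first_module_seen = False


class _ToyLinearFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        # Record at forward time whether this is the first module of the iteration,
        # i.e. the one whose backward autograd will run last.
        ctx.is_first_module = _ToyGlobalState.is_first_module()
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        dgrad = grad_out @ weight
        wgrad = grad_out.t() @ inp
        # The first module's backward is the last backward of the step, so the
        # reduction runs here instead of in a multi-grad autograd hook. It is
        # skipped while a CUDA graph is being captured.
        if ctx.is_first_module and not _is_graph_capturing():
            _ToyGlobalState.reduce_and_update()
        return dgrad, wgrad


if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    _ToyLinearFunc.apply(x, w).sum().backward()

Tying the reduction to the first module's backward preserves the old end-of-backward timing without registering any autograd hook, and the is_graph_capturing() guard keeps the call out of captured CUDA graphs.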