[PyTorch] Don't use autograd hook for bwd reduction #781

Merged · 1 commit · Apr 15, 2024
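In short: instead of registering a `torch.autograd.graph.register_multi_grad_hook` over per-module dummy tensors when entering `fp8_autocast`, each FP8 module now records at forward time whether it is the first FP8 module of the iteration (`FP8GlobalStateManager.is_first_fp8_module()`) and, since that module's backward runs last, performs the backward amax reduction there, skipping it during CUDA graph capture. Below is a minimal sketch of the new pattern; the toy matmul and the import paths (taken from this diff) are illustrative only, not the actual `_Linear`/`_LayerNormLinear` implementation.

```python
# Minimal sketch of the new per-module reduction pattern (illustration only;
# the real logic lives in _Linear/_LayerNormLinear/_LayerNormMLP below).
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.graph import is_graph_capturing


class _SketchFP8Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        # Recorded at forward time: is this the first FP8 module of the
        # iteration? In backward order it will then be the last to run.
        ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        dgrad = grad_out @ weight
        wgrad = grad_out.t() @ inp
        # Last backward of the iteration: reduce amaxes and update scales,
        # unless a CUDA graph is being captured (the graphed path handles this).
        if ctx.is_first_module and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
        return dgrad, wgrad
```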
21 changes: 0 additions & 21 deletions transformer_engine/pytorch/fp8.py
@@ -81,8 +81,6 @@ class FP8GlobalStateManager:
fp8_tensors_recompute_buffer = []
fp8_available = None
reason_for_no_fp8 = ""
multi_grad_hook_tensors = []
bwd_amax_update_hook_registered = False
autocast_arguments = {}
autocast_to_fp8_params = {}
fp8_param_to_autocast = {}
@@ -106,8 +104,6 @@ def reset(cls) -> None:
cls.fp8_tensors_recompute_buffer = []
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.multi_grad_hook_tensors = []
cls.bwd_amax_update_hook_registered = False
cls.autocast_arguments = {}
cls.autocast_to_fp8_params = {}
cls.fp8_param_to_autocast = {}
@@ -370,16 +366,6 @@ def reduce_and_update_fp8_tensors(
_amax_and_scale_update(
amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe)

@classmethod
def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor):
"""Add tensor to list for multi grad hook."""
cls.multi_grad_hook_tensors.append(tensor)

@classmethod
def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument
"""Executes at the end of backward pass."""
cls.reduce_and_update_fp8_tensors(forward=False)

@classmethod
def get_unique_autocast_key(
cls,
@@ -407,13 +393,6 @@ def fp8_autocast_enter(
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)

if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0:
# This hook does not fire for graphed modules.
torch.autograd.graph.register_multi_grad_hook(
tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction)
cls.bwd_amax_update_hook_registered = True

cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = fp8_recipe
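For context on what was deleted here: each module used to create a one-element dummy tensor with `requires_grad=True`, pass it through its forward as an unused argument, and a single `torch.autograd.graph.register_multi_grad_hook` registered over all such tensors fired once their gradients had all been produced, i.e. at the end of the backward pass. A standalone, plain-PyTorch illustration of that API (toy tensors, not Transformer Engine code):

```python
# Plain-PyTorch illustration of the mechanism being removed:
# register_multi_grad_hook fires once gradients for all watched tensors
# have been computed in a given backward pass.
import torch

a = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

def end_of_backward(grads):
    # grads is a tuple of the computed gradients (or None for tensors
    # that did not participate in this backward).
    print("all watched grads ready:", [g is not None for g in grads])

handle = torch.autograd.graph.register_multi_grad_hook((a, b), end_of_backward)

loss = (a * 2 + b * 3).sum()
loss.backward()   # prints: all watched grads ready: [True, True]
handle.remove()
```

As the removed comment notes, this hook does not fire for graphed modules, which is part of the motivation for moving the trigger into each module's own backward.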
12 changes: 5 additions & 7 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -40,6 +40,7 @@
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor

@@ -89,7 +90,6 @@ def forward(
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
@@ -328,6 +328,7 @@ def forward(
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
@@ -660,6 +661,9 @@ def backward(
else:
wgrad = None

if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
@@ -696,7 +700,6 @@ def backward(
None,
None,
None,
None,
)


@@ -1001,10 +1004,6 @@ def __init__(
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
@@ -1176,7 +1175,6 @@ def forward(
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = fwd_fn(*args)

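The ordering assumption behind the new `ctx.is_first_module` flag: autograd unwinds modules in reverse, so the module that runs first in forward is the last whose backward executes, which makes its backward a single safe place for the end-of-backward reduction. A toy check of that ordering in plain PyTorch (the `Probe` function here is hypothetical, for illustration only):

```python
# Toy demonstration that the first module's backward runs last.
import torch

order = []

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, name):
        ctx.name = name
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        order.append(ctx.name)
        return grad, None  # no gradient for the name argument

x = torch.randn(4, requires_grad=True)
y = Probe.apply(Probe.apply(x, "first"), "second")
y.sum().backward()
print(order)  # ['second', 'first']: the first module's backward fires last
```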
13 changes: 5 additions & 8 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -49,7 +49,7 @@

from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo

from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization

@@ -121,7 +121,6 @@ def forward(
ub_overlap_rs: bool,
ub_overlap_ag: bool,
gemm_gelu_fusion: bool,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
@@ -545,6 +544,7 @@ def forward(
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

# Row Parallel Linear
if ub_overlap_rs:
@@ -1121,6 +1121,9 @@ def backward(
else:
fc2_wgrad = None

if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
@@ -1165,7 +1168,6 @@ def backward(
None,
None,
None,
None,
)


@@ -1429,10 +1431,6 @@ def __init__(
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
@@ -1588,7 +1586,6 @@ def forward(
self.ub_overlap_rs,
self.ub_overlap_ag,
self.gemm_gelu_fusion,
self.dummy_tensor,
)
out = fwd_fn(*args)

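The `not is_graph_capturing()` guard mirrors the limitation noted in the deleted hook ("does not fire for graphed modules"): the eager amax reduction and scale update should not be issued while a CUDA graph is being recorded, since the graphed path manages those updates separately. A hedged sketch of the same intent using PyTorch's stream-capture query rather than TE's own flag (the real code checks `is_graph_capturing()`):

```python
# Sketch only: TE checks its own is_graph_capturing() flag; this version
# shows the equivalent intent with PyTorch's stream-capture query.
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

def maybe_reduce_bwd_amaxes() -> None:
    capturing = torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
    if not capturing:
        # Safe to reduce amaxes and update FP8 scales eagerly.
        FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
```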
13 changes: 5 additions & 8 deletions transformer_engine/pytorch/module/linear.py
@@ -43,7 +43,7 @@
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo

from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor

__all__ = ["Linear"]
@@ -81,7 +81,6 @@ def forward(
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
@@ -321,6 +320,7 @@ def forward(
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

# Row Parallel Linear
if ub_overlap_rs:
@@ -530,6 +530,9 @@ def backward(
else:
wgrad = None

if ctx.is_first_module and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

return (
wgrad,
None,
@@ -555,7 +558,6 @@
None,
None,
None,
None,
)


@@ -798,10 +800,6 @@ def __init__(
else:
self.gemm_bias_unfused_add = False

# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)

def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)

@@ -941,7 +939,6 @@ def forward(
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = linear_fn(*args)

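From the user's point of view the API is unchanged; only the trigger for the backward reduction moves. A usage sketch (assumes an installed Transformer Engine and an FP8-capable GPU):

```python
# End-to-end sketch: with this change, calling backward() on a model run
# under fp8_autocast triggers the backward amax reduction from inside the
# backward of the first FP8 module, with no hook registration involved.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.Linear(1024, 1024).cuda()
inp = torch.randn(32, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = model(inp)

# Backward runs outside the autocast context manager; the reduction and
# scale update for the backward pass happen here.
out.sum().backward()
```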