Skip to content

Commit

Permalink
Deprecate Fused LayerNorm (#2475)
Browse files Browse the repository at this point in the history
Will be removed in v0.18.
  • Loading branch information
nik-mosaic authored Aug 28, 2023
1 parent a6c0b0f commit 23aaaeb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
8 changes: 8 additions & 0 deletions composer/algorithms/fused_layernorm/fused_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def apply_fused_layernorm(model: torch.nn.Module, optimizers: Union[torch.optim.
By fusing multiple kernel launches into one, this usually improves GPU utilization.
"""
warnings.warn(
DeprecationWarning(
'Fused LayerNorm has been deprecated and will be removed in Composer 0.18. Please switch to Low Precision LayerNorm.'
))
check_if_apex_installed()

# prepare the replacement policy and perform replacement
Expand Down Expand Up @@ -99,6 +103,10 @@ def no_op(self, *args): pass

def __init__(self):
    """Construct the stand-in module (FusedLayerNorm takes no arguments).

    Emits a :class:`DeprecationWarning` announcing removal in Composer 0.18,
    then verifies that NVIDIA Apex is installed before any surgery proceeds.
    """
    # Bind the message to a local first so the warn call stays short.
    deprecation_message = 'Fused LayerNorm has been deprecated and will be removed in Composer 0.18. Please switch to Low Precision LayerNorm.'
    warnings.warn(DeprecationWarning(deprecation_message))
    check_if_apex_installed()

def __repr__(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import logging
import textwrap
import warnings
from typing import Dict, Optional, Sequence, Type, Union

Expand Down Expand Up @@ -40,6 +41,14 @@ def apply_low_precision_layernorm(model,
# Prior to v1.13, torch.nn.LayerNorm is slow in bf16 precision.
# We use FusedLayerNorm as a fallback.
if version.parse(torch.__version__) < version.parse('1.13') and precision == Precision.AMP_BF16:
warnings.warn(
DeprecationWarning(
textwrap.dedent(
'You are using Low Precision LayerNorm on PyTorch < v.1.13 with bfloat16 precision. '
'In this scenario, we fall back to Fused LayerNorm. '
'Fused LayerNorm has been deprecated and will be removed in Composer 0.18. '
'Please upgrade your PyTorch version to >=v.1.13 to use Low Precision LayerNorm without the Fused LayerNorm fallback.'
)))
check_if_apex_installed()
policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
torch.nn.LayerNorm: _to_FusedLayerNorm
Expand Down

0 comments on commit 23aaaeb

Please sign in to comment.