Make FabricMegatronMixedPrecision match MegatronMixedPrecision (#10835)
* Make FabricMegatronMixedPrecision match MegatronMixedPrecision

Signed-off-by: Marc Romeijn <[email protected]>

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

---------

Signed-off-by: Marc Romeijn <[email protected]>
Signed-off-by: marcromeyn <[email protected]>
Co-authored-by: marcromeyn <[email protected]>
2 people authored and Yashaswi Karnati committed Oct 20, 2024
1 parent d2fb283 commit 2481b88
Showing 1 changed file with 95 additions and 58 deletions.
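
The change brings the Fabric plugin's constructor in line with the keyword arguments that MegatronMixedPrecision already exposes, so the same precision settings can be passed on either the PyTorch Lightning or the Fabric path, as the diff below shows. A minimal sketch of that parity, assuming MegatronMixedPrecision accepts the same keywords (which is the premise of this commit); the values are illustrative, not defaults:

import torch

from nemo.lightning.fabric.plugins import FabricMegatronMixedPrecision
from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision

# The same precision settings can now be expressed through either plugin,
# since both are driven by the same set of keyword arguments.
common_kwargs = dict(
    precision="bf16-mixed",
    params_dtype=torch.bfloat16,    # dtype used for model parameters
    pipeline_dtype=torch.bfloat16,  # dtype used for pipeline-parallel communication
    grad_reduce_in_fp32=True,       # keep the gradient all-reduce in fp32
)

pl_precision = MegatronMixedPrecision(**common_kwargs)            # PyTorch Lightning side
fabric_precision = FabricMegatronMixedPrecision(**common_kwargs)  # Fabric side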
153 changes: 95 additions & 58 deletions nemo/lightning/fabric/plugins.py
@@ -1,38 +1,94 @@
 from contextlib import contextmanager
-from typing import Any, Generator, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Generator, Literal, TypeVar
 
 import torch
 from lightning_fabric.plugins.precision import MixedPrecision
-from lightning_fabric.utilities.types import Optimizable
 from torch import nn
 from torch.optim import Optimizer
 
 from nemo.lightning._strategy_lib import GradScaler
 from nemo.lightning.fabric.conversion import to_fabric
-from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision
+from nemo.lightning.pytorch.plugins.mixed_precision import (
+    DtypeConfig,
+    MegatronMixedPrecision,
+    get_optim_config,
+    update_config_with_dtype_overrides,
+)
+from nemo.utils.import_utils import safe_import
+
+if TYPE_CHECKING:
+    from megatron.core.model_parallel_config import ModelParallelConfig
 
 AnyT = TypeVar("AnyT")
+ConfigT = TypeVar("ConfigT", bound="ModelParallelConfig")
 
 
 class FabricMegatronMixedPrecision(MixedPrecision):
     def __init__(
         self,
-        precision: Literal["16-mixed", "bf16-mixed"] = "16-mixed",
-        amp_02: bool = True,
-        device="cuda",
-        scaler: Optional[Union[torch.cuda.amp.GradScaler, str]] = None,
+        precision: Literal["16-mixed", "bf16-mixed", "32"],
+        params_dtype: torch.dtype = None,
+        pipeline_dtype: torch.dtype = None,
+        autocast_dtype: torch.dtype = None,
+        autocast_enabled: bool = False,
+        grad_reduce_in_fp32: bool = True,
+        # fp8 related,
+        fp8: str = None,
+        fp8_margin: int = 0,
+        fp8_amax_history_len: int = 1,
+        fp8_amax_compute_algo: str = "most_recent",
+        fp8_wgrad: bool = True,
+        fp8_dot_product_attention: bool = False,
+        fp8_multi_head_attention: bool = False,
+        fp8_params: bool = False,
+        fp16_loss_scale: float = None,
+        fp16_initial_loss_scale: float = 4294967296,
+        fp16_min_loss_scale: float = 1.0,
+        fp16_loss_scale_window: int = 1000,
+        fp16_hysteresis: int = 2,
    ) -> None:
-        if precision == "bf16-mixed":
-            scaler = None
+        if isinstance(precision, int):
+            precision = str(precision)
+
+        fp8_param_gather = False
+        if fp8 is not None:
+            te_fp8, HAVE_TE = safe_import("transformer_engine.pytorch.fp8")
+            assert HAVE_TE, "FP8 precision requires transformer engine."
+            if fp8_params:
+                te_fp8.FP8GlobalStateManager.FP8_PARAMETERS = True
+                fp8_param_gather = True
+
+        dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
+        self.dtype_config = DtypeConfig(
+            fp32=precision in ['fp32', '32'],
+            fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
+            bf16=precision in ['bf16', 'bf16-mixed'],
+            params_dtype=params_dtype or torch.float32,
+            pipeline_dtype=pipeline_dtype or dtype,
+            autocast_dtype=autocast_dtype or dtype,
+            autocast_enabled=autocast_enabled,
+            grad_reduce_in_fp32=grad_reduce_in_fp32,
+            fp8=fp8,
+            fp8_margin=fp8_margin,
+            fp8_amax_history_len=fp8_amax_history_len,
+            fp8_amax_compute_algo=fp8_amax_compute_algo,
+            fp8_wgrad=fp8_wgrad,
+            fp8_dot_product_attention=fp8_dot_product_attention,
+            fp8_multi_head_attention=fp8_multi_head_attention,
+            fp8_param_gather=fp8_param_gather,
+            # fp16 loss scale
+            loss_scale=fp16_loss_scale,
+            initial_loss_scale=fp16_initial_loss_scale,
+            min_loss_scale=fp16_min_loss_scale,
+            loss_scale_window=fp16_loss_scale_window,
+            hysteresis=fp16_hysteresis,
+        )
+        if self.dtype_config.fp16:
+            self.precision = "16-mixed"
+        elif self.dtype_config.bf16:
+            self.precision = "bf16-mixed"
         else:
-            scaler = GradScaler(
-                init_scale=2**32,
-                growth_interval=1000,
-                hysteresis=2,
-            )
-
-        super().__init__(precision, device, scaler)
-        self.amp_02 = amp_02
+            self.precision = "32-true"
+        self.scaler = None
 
     def convert_input(self, data: AnyT) -> AnyT:
         """Convert model inputs (forward) to the floating point precision type of this plugin.
@@ -52,64 +108,45 @@ def convert_output(self, data: AnyT) -> AnyT:
         """
         return data
 
-    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
-        from nemo.core.optim import MainParamsOptimizerWrapper
-
-        return MainParamsOptimizerWrapper(
-            optimizer,
-            # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_base_model.py#L496
-            fp32_grad_accum=True,
-            contiguous_grad_bucket=True,
-        )
+    def convert_config(self, config: ConfigT) -> ConfigT:
+        """Convert the config to the precision type this plugin handles.
+
+        This is optional and depends on the precision limitations during optimization.
+
+        """
+        return update_config_with_dtype_overrides(self.dtype_config, config)
 
     def convert_module(self, module: nn.Module) -> nn.Module:
         """Convert the module parameters to the precision type this plugin handles.
 
         This is optional and depends on the precision limitations during optimization.
 
         """
-        if not hasattr(module, "module"):
-            return module
-
         from megatron.core.transformer.module import Float16Module
         from megatron.core.utils import get_model_config
 
-        if self.precision in ["16-mixed", "bf16-mixed"]:
+        if self.dtype_config.fp16 or self.dtype_config.bf16:
+            # Patch config options
             config = get_model_config(module.module)
-            config.fp16 = self.precision == "16-mixed"
-            config.bf16 = self.precision == "bf16-mixed"
-            if not isinstance(module.module, Float16Module):
+            config.fp16 = self.dtype_config.fp16
+            config.bf16 = self.dtype_config.bf16
+            if hasattr(module, 'module'):
                 module.module = Float16Module(config, module.module)
+            else:
+                module = Float16Module(config, module)
 
         return module
 
-    def optimizer_step(
-        self,
-        optimizer: Optimizable,
-        **kwargs: Any,
-    ) -> None:
-        from nemo.core.optim import MainParamsOptimizerWrapper
-
-        assert isinstance(
-            optimizer, MainParamsOptimizerWrapper
-        ), "MegatronHalfPrecisionPlugin supports only the optimizer with master parameters"
-
-        if self.scaler is None:
-            assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation"
-
-            # skip scaler logic, as bfloat16 does not require scaler
-            return super().optimizer_step(optimizer, **kwargs)
-
-        assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation"
-
-        # cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update
-        optimizer.copy_model_grads_to_main_grads()
-
-        # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
-        step_output = self.scaler.step(optimizer, **kwargs)
-        self.scaler.update()
-
-        return step_output
+    def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Convert the optimizer parameters to the precision type this plugin handles.
+
+        This is optional and depends on the precision limitations during optimization.
+
+        """
+        for optim_config in get_optim_config(optimizer):
+            assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch"
+            assert optim_config.fp16 == self.dtype_config.fp16, "FP16 model/optim config mismatch"
+        return optimizer
 
     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
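For orientation, a short usage sketch of the reworked constructor (hypothetical values; it assumes a NeMo installation in which these imports resolve). The resulting attribute values follow directly from the new __init__ above:

import torch

from nemo.lightning.fabric.plugins import FabricMegatronMixedPrecision

# "bf16-mixed": no GradScaler is created any more; the constructor records the
# target dtypes in a DtypeConfig and reports the plugin as bf16-mixed.
bf16 = FabricMegatronMixedPrecision(precision="bf16-mixed")
assert bf16.precision == "bf16-mixed"
assert bf16.scaler is None
assert bf16.dtype_config.bf16 and not bf16.dtype_config.fp16

# "32" keeps parameters in fp32 and maps to Fabric's "32-true" precision tag.
fp32 = FabricMegatronMixedPrecision(precision="32")
assert fp32.precision == "32-true"
assert fp32.dtype_config.params_dtype is torch.float32

Since self.scaler is always None after this change, fp16 loss scaling appears to be driven downstream from the fp16_* values stored in the DtypeConfig rather than by a Fabric-side GradScaler.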
