Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fp8 (torchao)/fsdp2/torch_compile handlers and tests #20445

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion _notebooks
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems always being deleted 🤔 any ideas on how to keep it @lantiga

Submodule _notebooks deleted from b83fde
38 changes: 38 additions & 0 deletions examples/pytorch/custom_handler_fp8_fsdp1n2_compile/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# PyTorch Native FP8 Training with FSDP1/2 and Torch Compile using Custom Handler

This is an example of a ...

## Requirements

Install requirements by running

```bash
sh setup.sh
```

## Example

In this example we present

```bash
# # config the PYTHONPATH if needed
# export PYTHONPATH=/teamspace/studios/this_studio/pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile:$PYTHONPATH
cd pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile

# fsdp1 + fp8 + torch compile + gradient checkpointing + cpu offloading
python train.py --enable_fp8 --enable_torch_compile --enable_gradient_checkpointing --enable_cpu_offload

# fsdp2 + fp8 + torch compile + gradient checkpointing (the example does not implement fsdp2 cpu offloading currently)
python train.py --enable_fsdp2 --enable_fp8 --enable_torch_compile --enable_gradient_checkpointing
```

## Test the handlers

```bash
# # config the PYTHONPATH if needed
# export PYTHONPATH=/teamspace/studios/this_studio/pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile:$PYTHONPATH
cd pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile
pytest tests/*
```

> **Warning**
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# the script is modified based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py
import logging
import operator
from dataclasses import dataclass
from typing import Union

import torch
import torch.nn as nn
from lightning_utilities.core.imports import compare_version

log = logging.getLogger(__name__)


def is_sm89_or_later():
# Float8 is only supported on SM89 or later (H100+ GPUs)
return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


# check https://github.com/pytorch/ao/blob/main/torchao/float8/config.py for more config details
@dataclass
class FP8Config:
enable_fp8: bool = True
enable_amax_init: bool = False
scaling_type_input: str = "delayed"
scaling_type_weight: str = "delayed"
scaling_type_grad_output: str = "delayed"
enable_fsdp_float8_all_gather: bool = False
precompute_float8_dynamic_scale_for_fsdp: bool = False
pad_inner_dim: bool = True
emulate_fp8: bool = False # Set to True for testing without FP8 hardware
enable_torch_compile: bool = True
enable_pre_and_post_forward: bool = False


# Define a map for module filter functions based on model name
MODULE_FILTER_MAP = {
"llama": lambda mod, fqn: isinstance(mod, nn.Linear) and "mlp" in fqn and "lm_head" not in fqn,
"mixtral": lambda mod, fqn: isinstance(mod, nn.Linear)
and "block_sparse_moe" in fqn
and "block_sparse_moe.gate" not in fqn
and "lm_head" not in fqn,
"default": lambda mod, fqn: isinstance(mod, nn.Linear), # Default filter
}


class Float8TrainingHandler:
"""Handler for configuring models for FP8 training using torchao."""

def __init__(self, args: FP8Config, model_path: str, parallel_dims: dict[str, bool]):
"""Initializes the handler for FP8 training and configuration.

Args:
args (FP8Config): Configuration object for FP8 training, including settings for scaling, amax initialization, and torch compile.
model_path (str): The path to the model. Typically used for determining model-specific settings.
parallel_dims (Dict[str, bool]): Dictionary specifying parallelization settings, such as whether DP shard is enabled.

Example Usage:
fp8_config = FP8Config(
enable_fp8=True,
enable_amax_init=True,
scaling_type_input="delayed",
scaling_type_weight="delayed",
scaling_type_grad_output="delayed",
enable_fsdp_float8_all_gather=False,
precompute_float8_dynamic_scale_for_fsdp=False,
pad_inner_dim=True,
emulate_fp8=False, # Set to True for testing without FP8 hardware
enable_torch_compile=True,
enable_pre_and_post_forward=False,
)

parallel_dims = {"dp_shard_enabled": False}
handler = Float8TrainingHandler(fp8_config, "path/to/model", parallel_dims)

"""
self.model_path = model_path
self.args = args
self.parallel_dims = parallel_dims
self.compile = args.enable_torch_compile
self.enable_fp8 = args.enable_fp8

if not self.enable_fp8:
log.warning("Fp8 is disabled here")
return

if not is_sm89_or_later() and not args.emulate_fp8:
log.error("Failed to swap to Float8Linear because float8 is only supported on SM89 or later (H100+ GPUs)")
raise RuntimeError("Float8Linear operation is not supported on the current hardware.")

# Check if torchao is installed and version is >= 0.5.0
try:
compare_version("torchao", operator.ge, "0.6.1")
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
except ImportError as e:
log.error(str(e))
raise

# Configure Float8LinearConfig parameters from args
scaling_type_input = ScalingType(args.scaling_type_input)
scaling_type_weight = ScalingType(args.scaling_type_weight)
scaling_type_grad_output = ScalingType(args.scaling_type_grad_output)

enable_fsdp_float8_all_gather = (
parallel_dims.get("dp_shard_enabled", False) and args.enable_fsdp_float8_all_gather
)

enable_amax_init = args.enable_amax_init
self.config = Float8LinearConfig(
enable_amax_init=enable_amax_init,
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
cast_config_input=CastConfig(scaling_type=scaling_type_input),
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
enable_pre_and_post_forward=args.enable_pre_and_post_forward,
pad_inner_dim=args.pad_inner_dim,
emulate=args.emulate_fp8,
)

# For precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = enable_fsdp_float8_all_gather and args.precompute_float8_dynamic_scale_for_fsdp

# For sync_float8_amax_and_scale_history
self.delayed_scaling = (
scaling_type_input == ScalingType.DELAYED
or scaling_type_weight == ScalingType.DELAYED
or scaling_type_grad_output == ScalingType.DELAYED
)
self._sync_float8_amax_and_scale_history = None

log.info("Float8 training active")

def convert_to_float8_training(self, model: nn.Module, module_filter_fn: callable = None):
"""Converts the linear layers of `model` to `Float8Linear` based on a module filter function. Mutates the model
in place.

Args:
model (nn.Module): The model whose layers should be converted.
module_filter_fn (callable, optional): A function to filter which modules should be replaced.
Defaults to a model-specific filter based on `model_path`.

"""
if not self.enable_fp8:
log.warning("FP8 is disabled, so layers will not be replaced.")
return

log.warning("Enabling FP8 Training")

# Use the provided filter function or select from the map
if module_filter_fn is None:
model_path_lower = self.model_path.lower()
module_filter_fn = next(
(fn for key, fn in MODULE_FILTER_MAP.items() if key in model_path_lower),
MODULE_FILTER_MAP["default"], # Default filter if no match is found
)

from torchao.float8 import convert_to_float8_training

convert_to_float8_training(
model,
config=self.config,
module_filter_fn=module_filter_fn,
)
log.info(
f"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather={self.config.enable_fsdp_float8_all_gather}"
)

def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, list[nn.Module]]):
if not self.enable_fp8 or not self.precompute_scale:
return

from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

models = [model] if isinstance(model, nn.Module) else model
for m in models:
precompute_float8_dynamic_scale_for_fsdp(m)

def sync_float8_amax_and_scale_history(self, model: Union[nn.Module, list[nn.Module]]):
if not self.enable_fp8 or not self.delayed_scaling:
return

from torchao.float8 import sync_float8_amax_and_scale_history

# Cache the compiled function if necessary
if self._sync_float8_amax_and_scale_history is None:
if self.compile:
self._sync_float8_amax_and_scale_history = torch.compile(sync_float8_amax_and_scale_history)
else:
self._sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history

models = [model] if isinstance(model, nn.Module) else model
for m in models:
self._sync_float8_amax_and_scale_history(m)
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
import operator
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from lightning_utilities.core.imports import compare_version

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh

log = logging.getLogger(__name__)


@dataclass
class FSDP2Config:
enable_cpu_offload: bool = False
enable_gradient_checkpointing: bool = False


class FSDP2Handler:
"""Handler for wrapping the model layers with FSDP2.

Args:
args (FSDP2Config): Configuration for FSDP2, including options for CPU offload and gradient checkpointing.
device_mesh (DeviceMesh): Device mesh configuration for FSDP2 parallelism.

Attributes:
args (FSDP2Config): Stores the FSDP2 configuration.
device_mesh (DeviceMesh): Stores the device mesh configuration.

"""

def __init__(self, args: FSDP2Config, device_mesh: "DeviceMesh"):
self.args = args
self.device_mesh = device_mesh

# Check PyTorch version for FSDP2 support (currently we require PyTorch >= 2.6.0)
try:
compare_version("torch", operator.ge, "2.6.0")
except RuntimeError as e:
log.error(str(e))
raise

# Import necessary FSDP modules
try:
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
fully_shard,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)

self.fully_shard = fully_shard
self.checkpoint_wrapper = checkpoint_wrapper
self.MixedPrecisionPolicy = MixedPrecisionPolicy
self.CPUOffloadPolicy = CPUOffloadPolicy
except ImportError as e:
log.error(f"Failed to import FSDP modules: {e}")
raise

def wrap_model(self, model: nn.Module):
"""Wraps the model layers with FSDP configurations.

Args:
model (nn.Module): The model to wrap.

Returns:
nn.Module: The wrapped model.

"""
dp_mesh = self.device_mesh["data_parallel"]
assert dp_mesh.size() > 1, "FSDP requires at least two devices."

fsdp_policy = {
"mesh": dp_mesh,
"mp_policy": self.MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
),
}
if self.args.enable_cpu_offload:
fsdp_policy["offload_policy"] = self.CPUOffloadPolicy()

for layer_id, module in enumerate(model.model.layers):
reshard_after_forward = layer_id < len(model.model.layers) - 1
if self.args.enable_gradient_checkpointing:
module = self.checkpoint_wrapper(module)
self.fully_shard(
module,
**fsdp_policy,
reshard_after_forward=reshard_after_forward,
)
model.model.layers[layer_id] = module

self.fully_shard(model, **fsdp_policy)
return model
Loading