-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add fp8 (torchao)/fsdp2/torch_compile handlers and tests #20445
Draft
qingquansong
wants to merge
5
commits into
Lightning-AI:master
Choose a base branch
from
qingquansong:qsong/fp8
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
887199a
refactor and add new examples
qingquansong ec6f5a8
update example
qingquansong 289a0ff
update readme
qingquansong e6889ed
rm redundant
qingquansong 9a35627
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
38 changes: 38 additions & 0 deletions
38
examples/pytorch/custom_handler_fp8_fsdp1n2_compile/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# PyTorch Native FP8 Training with FSDP1/2 and Torch Compile using Custom Handler | ||
|
||
This is an example of a ... | ||
|
||
## Requirements | ||
|
||
Install requirements by running | ||
|
||
```bash | ||
sh setup.sh | ||
``` | ||
|
||
## Example | ||
|
||
In this example we present | ||
|
||
```bash | ||
# # config the PYTHONPATH if needed | ||
# export PYTHONPATH=/teamspace/studios/this_studio/pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile:$PYTHONPATH | ||
cd pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile | ||
|
||
# fsdp1 + fp8 + torch compile + gradient checkpointing + cpu offloading | ||
python train.py --enable_fp8 --enable_torch_compile --enable_gradient_checkpointing --enable_cpu_offload | ||
|
||
# fsdp2 + fp8 + torch compile + gradient checkpointing (the example does not implement fsdp2 cpu offloading currently) | ||
python train.py --enable_fsdp2 --enable_fp8 --enable_torch_compile --enable_gradient_checkpointing | ||
``` | ||
|
||
## Test the handlers | ||
|
||
```bash | ||
# # config the PYTHONPATH if needed | ||
# export PYTHONPATH=/teamspace/studios/this_studio/pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile:$PYTHONPATH | ||
cd pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile | ||
pytest tests/* | ||
``` | ||
|
||
> **Warning** |
Empty file.
192 changes: 192 additions & 0 deletions
192
examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fp8_training_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# the script is modified based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py | ||
import logging | ||
import operator | ||
from dataclasses import dataclass | ||
from typing import Union | ||
|
||
import torch | ||
import torch.nn as nn | ||
from lightning_utilities.core.imports import compare_version | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def is_sm89_or_later(): | ||
# Float8 is only supported on SM89 or later (H100+ GPUs) | ||
return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) | ||
|
||
|
||
# check https://github.com/pytorch/ao/blob/main/torchao/float8/config.py for more config details | ||
@dataclass | ||
class FP8Config: | ||
enable_fp8: bool = True | ||
enable_amax_init: bool = False | ||
scaling_type_input: str = "delayed" | ||
scaling_type_weight: str = "delayed" | ||
scaling_type_grad_output: str = "delayed" | ||
enable_fsdp_float8_all_gather: bool = False | ||
precompute_float8_dynamic_scale_for_fsdp: bool = False | ||
pad_inner_dim: bool = True | ||
emulate_fp8: bool = False # Set to True for testing without FP8 hardware | ||
enable_torch_compile: bool = True | ||
enable_pre_and_post_forward: bool = False | ||
|
||
|
||
# Define a map for module filter functions based on model name | ||
MODULE_FILTER_MAP = { | ||
"llama": lambda mod, fqn: isinstance(mod, nn.Linear) and "mlp" in fqn and "lm_head" not in fqn, | ||
"mixtral": lambda mod, fqn: isinstance(mod, nn.Linear) | ||
and "block_sparse_moe" in fqn | ||
and "block_sparse_moe.gate" not in fqn | ||
and "lm_head" not in fqn, | ||
"default": lambda mod, fqn: isinstance(mod, nn.Linear), # Default filter | ||
} | ||
|
||
|
||
class Float8TrainingHandler: | ||
"""Handler for configuring models for FP8 training using torchao.""" | ||
|
||
def __init__(self, args: FP8Config, model_path: str, parallel_dims: dict[str, bool]): | ||
"""Initializes the handler for FP8 training and configuration. | ||
|
||
Args: | ||
args (FP8Config): Configuration object for FP8 training, including settings for scaling, amax initialization, and torch compile. | ||
model_path (str): The path to the model. Typically used for determining model-specific settings. | ||
parallel_dims (Dict[str, bool]): Dictionary specifying parallelization settings, such as whether DP shard is enabled. | ||
|
||
Example Usage: | ||
fp8_config = FP8Config( | ||
enable_fp8=True, | ||
enable_amax_init=True, | ||
scaling_type_input="delayed", | ||
scaling_type_weight="delayed", | ||
scaling_type_grad_output="delayed", | ||
enable_fsdp_float8_all_gather=False, | ||
precompute_float8_dynamic_scale_for_fsdp=False, | ||
pad_inner_dim=True, | ||
emulate_fp8=False, # Set to True for testing without FP8 hardware | ||
enable_torch_compile=True, | ||
enable_pre_and_post_forward=False, | ||
) | ||
|
||
parallel_dims = {"dp_shard_enabled": False} | ||
handler = Float8TrainingHandler(fp8_config, "path/to/model", parallel_dims) | ||
|
||
""" | ||
self.model_path = model_path | ||
self.args = args | ||
self.parallel_dims = parallel_dims | ||
self.compile = args.enable_torch_compile | ||
self.enable_fp8 = args.enable_fp8 | ||
|
||
if not self.enable_fp8: | ||
log.warning("Fp8 is disabled here") | ||
return | ||
|
||
if not is_sm89_or_later() and not args.emulate_fp8: | ||
log.error("Failed to swap to Float8Linear because float8 is only supported on SM89 or later (H100+ GPUs)") | ||
raise RuntimeError("Float8Linear operation is not supported on the current hardware.") | ||
|
||
# Check if torchao is installed and version is >= 0.5.0 | ||
try: | ||
compare_version("torchao", operator.ge, "0.6.1") | ||
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType | ||
except ImportError as e: | ||
log.error(str(e)) | ||
raise | ||
|
||
# Configure Float8LinearConfig parameters from args | ||
scaling_type_input = ScalingType(args.scaling_type_input) | ||
scaling_type_weight = ScalingType(args.scaling_type_weight) | ||
scaling_type_grad_output = ScalingType(args.scaling_type_grad_output) | ||
|
||
enable_fsdp_float8_all_gather = ( | ||
parallel_dims.get("dp_shard_enabled", False) and args.enable_fsdp_float8_all_gather | ||
) | ||
|
||
enable_amax_init = args.enable_amax_init | ||
self.config = Float8LinearConfig( | ||
enable_amax_init=enable_amax_init, | ||
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, | ||
cast_config_input=CastConfig(scaling_type=scaling_type_input), | ||
cast_config_weight=CastConfig(scaling_type=scaling_type_weight), | ||
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), | ||
enable_pre_and_post_forward=args.enable_pre_and_post_forward, | ||
pad_inner_dim=args.pad_inner_dim, | ||
emulate=args.emulate_fp8, | ||
) | ||
|
||
# For precompute_float8_dynamic_scale_for_fsdp | ||
self.precompute_scale = enable_fsdp_float8_all_gather and args.precompute_float8_dynamic_scale_for_fsdp | ||
|
||
# For sync_float8_amax_and_scale_history | ||
self.delayed_scaling = ( | ||
scaling_type_input == ScalingType.DELAYED | ||
or scaling_type_weight == ScalingType.DELAYED | ||
or scaling_type_grad_output == ScalingType.DELAYED | ||
) | ||
self._sync_float8_amax_and_scale_history = None | ||
|
||
log.info("Float8 training active") | ||
|
||
def convert_to_float8_training(self, model: nn.Module, module_filter_fn: callable = None): | ||
"""Converts the linear layers of `model` to `Float8Linear` based on a module filter function. Mutates the model | ||
in place. | ||
|
||
Args: | ||
model (nn.Module): The model whose layers should be converted. | ||
module_filter_fn (callable, optional): A function to filter which modules should be replaced. | ||
Defaults to a model-specific filter based on `model_path`. | ||
|
||
""" | ||
if not self.enable_fp8: | ||
log.warning("FP8 is disabled, so layers will not be replaced.") | ||
return | ||
|
||
log.warning("Enabling FP8 Training") | ||
|
||
# Use the provided filter function or select from the map | ||
if module_filter_fn is None: | ||
model_path_lower = self.model_path.lower() | ||
module_filter_fn = next( | ||
(fn for key, fn in MODULE_FILTER_MAP.items() if key in model_path_lower), | ||
MODULE_FILTER_MAP["default"], # Default filter if no match is found | ||
) | ||
|
||
from torchao.float8 import convert_to_float8_training | ||
|
||
convert_to_float8_training( | ||
model, | ||
config=self.config, | ||
module_filter_fn=module_filter_fn, | ||
) | ||
log.info( | ||
f"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather={self.config.enable_fsdp_float8_all_gather}" | ||
) | ||
|
||
def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, list[nn.Module]]): | ||
if not self.enable_fp8 or not self.precompute_scale: | ||
return | ||
|
||
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp | ||
|
||
models = [model] if isinstance(model, nn.Module) else model | ||
for m in models: | ||
precompute_float8_dynamic_scale_for_fsdp(m) | ||
|
||
def sync_float8_amax_and_scale_history(self, model: Union[nn.Module, list[nn.Module]]): | ||
if not self.enable_fp8 or not self.delayed_scaling: | ||
return | ||
|
||
from torchao.float8 import sync_float8_amax_and_scale_history | ||
|
||
# Cache the compiled function if necessary | ||
if self._sync_float8_amax_and_scale_history is None: | ||
if self.compile: | ||
self._sync_float8_amax_and_scale_history = torch.compile(sync_float8_amax_and_scale_history) | ||
else: | ||
self._sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history | ||
|
||
models = [model] if isinstance(model, nn.Module) else model | ||
for m in models: | ||
self._sync_float8_amax_and_scale_history(m) |
100 changes: 100 additions & 0 deletions
100
examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fsdp2_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import logging | ||
import operator | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING | ||
|
||
import torch | ||
import torch.nn as nn | ||
from lightning_utilities.core.imports import compare_version | ||
|
||
if TYPE_CHECKING: | ||
from torch.distributed.device_mesh import DeviceMesh | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class FSDP2Config: | ||
enable_cpu_offload: bool = False | ||
enable_gradient_checkpointing: bool = False | ||
|
||
|
||
class FSDP2Handler: | ||
"""Handler for wrapping the model layers with FSDP2. | ||
|
||
Args: | ||
args (FSDP2Config): Configuration for FSDP2, including options for CPU offload and gradient checkpointing. | ||
device_mesh (DeviceMesh): Device mesh configuration for FSDP2 parallelism. | ||
|
||
Attributes: | ||
args (FSDP2Config): Stores the FSDP2 configuration. | ||
device_mesh (DeviceMesh): Stores the device mesh configuration. | ||
|
||
""" | ||
|
||
def __init__(self, args: FSDP2Config, device_mesh: "DeviceMesh"): | ||
self.args = args | ||
self.device_mesh = device_mesh | ||
|
||
# Check PyTorch version for FSDP2 support (currently we require PyTorch >= 2.6.0) | ||
try: | ||
compare_version("torch", operator.ge, "2.6.0") | ||
except RuntimeError as e: | ||
log.error(str(e)) | ||
raise | ||
|
||
# Import necessary FSDP modules | ||
try: | ||
from torch.distributed._composable.fsdp import ( | ||
CPUOffloadPolicy, | ||
MixedPrecisionPolicy, | ||
fully_shard, | ||
) | ||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | ||
checkpoint_wrapper, | ||
) | ||
|
||
self.fully_shard = fully_shard | ||
self.checkpoint_wrapper = checkpoint_wrapper | ||
self.MixedPrecisionPolicy = MixedPrecisionPolicy | ||
self.CPUOffloadPolicy = CPUOffloadPolicy | ||
except ImportError as e: | ||
log.error(f"Failed to import FSDP modules: {e}") | ||
raise | ||
|
||
def wrap_model(self, model: nn.Module): | ||
"""Wraps the model layers with FSDP configurations. | ||
|
||
Args: | ||
model (nn.Module): The model to wrap. | ||
|
||
Returns: | ||
nn.Module: The wrapped model. | ||
|
||
""" | ||
dp_mesh = self.device_mesh["data_parallel"] | ||
assert dp_mesh.size() > 1, "FSDP requires at least two devices." | ||
|
||
fsdp_policy = { | ||
"mesh": dp_mesh, | ||
"mp_policy": self.MixedPrecisionPolicy( | ||
param_dtype=torch.bfloat16, | ||
reduce_dtype=torch.float32, | ||
), | ||
} | ||
if self.args.enable_cpu_offload: | ||
fsdp_policy["offload_policy"] = self.CPUOffloadPolicy() | ||
|
||
for layer_id, module in enumerate(model.model.layers): | ||
reshard_after_forward = layer_id < len(model.model.layers) - 1 | ||
if self.args.enable_gradient_checkpointing: | ||
module = self.checkpoint_wrapper(module) | ||
self.fully_shard( | ||
module, | ||
**fsdp_policy, | ||
reshard_after_forward=reshard_after_forward, | ||
) | ||
model.model.layers[layer_id] = module | ||
|
||
self.fully_shard(model, **fsdp_policy) | ||
return model |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems always being deleted 🤔 any ideas on how to keep it @lantiga