Add convert_module to FSDP #20323

Merged · 12 commits · Dec 10, 2024
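This PR adds a `convert_module` override to the FSDP precision plugins in both `lightning.fabric` and `lightning.pytorch`. For `*-true` precision settings the module is cast to the target dtype up front, which brings memory usage in line with plain PyTorch FSDP; `*-mixed` settings are left unchanged. Illustrative usage sketches follow the diffs below.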
6 changes: 6 additions & 0 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -74,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
        }
        self._desired_input_dtype = precision_to_type[self.precision]

    @override
    def convert_module(self, module: Module) -> Module:
        if "true" in self.precision:
            return module.to(dtype=self._desired_input_dtype)
        return module

    @property
    def mixed_precision_config(self) -> "TorchMixedPrecision":
        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
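For reference, a minimal sketch of how the new hook behaves, assuming `FSDPPrecision` is imported from the fabric module path shown above; the layer size is arbitrary and only for illustration:

```python
import torch

from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

# "*-true" precisions cast the module eagerly, so float32 copies are not kept around.
precision = FSDPPrecision(precision="bf16-true")
module = torch.nn.Linear(8, 8)             # parameters start out in float32
module = precision.convert_module(module)  # cast because "true" is in "bf16-true"
assert module.weight.dtype == torch.bfloat16
```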
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
- Fixed PyTorch Lightning FSDP using more memory than native PyTorch FSDP ([#20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))


## [2.3.0] - 2024-06-13
7 changes: 7 additions & 0 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -17,6 +17,7 @@
import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from typing_extensions import get_args, override

import lightning.pytorch as pl
@@ -73,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
        }
        self._desired_input_dtype = precision_to_type[self.precision]

    @override
    def convert_module(self, module: Module) -> Module:
        if "true" in self.precision:
            return module.to(dtype=self._desired_input_dtype)
        return module

    @override
    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
        # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
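As a sketch of the distinction the hook draws (same assumptions as above, using the `lightning.pytorch` plugin path from this diff): `*-mixed` settings leave the module in float32 and rely on the `MixedPrecision` config handed to FSDP, while `*-true` settings cast the module itself.

```python
import torch

from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision

mixed = FSDPPrecision(precision="16-mixed")
module = mixed.convert_module(torch.nn.Linear(8, 8))
assert module.weight.dtype == torch.float32  # no-op: "true" not in "16-mixed"

true_precision = FSDPPrecision(precision="16-true")
module = true_precision.convert_module(torch.nn.Linear(8, 8))
assert module.weight.dtype == torch.float16  # cast eagerly by the new hook
```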
18 changes: 18 additions & 0 deletions tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -127,3 +127,21 @@ def test_invalid_precision_with_fsdp_precision():

    with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"):
        FSDPPrecision(precision="64-true")


@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("bf16-mixed", torch.float32),
        ("16-mixed", torch.float32),
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.float16),
    ],
)
def test_convert_module(precision, expected_dtype):
    precision = FSDPPrecision(precision=precision)
    module = torch.nn.Linear(2, 2)
    assert module.weight.dtype == module.bias.dtype == torch.float32
    module = precision.convert_module(module)
    assert module.weight.dtype == module.bias.dtype == expected_dtype
18 changes: 18 additions & 0 deletions tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -40,6 +40,24 @@ def test_fsdp_precision_config(precision, expected):
    assert config.reduce_dtype == expected[2]


@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("bf16-mixed", torch.float32),
        ("16-mixed", torch.float32),
        ("bf16-true", torch.bfloat16),
        ("16-true", torch.float16),
    ],
)
def test_convert_module(precision, expected_dtype):
    precision = FSDPPrecision(precision=precision)
    module = torch.nn.Linear(2, 2)
    assert module.weight.dtype == module.bias.dtype == torch.float32
    module = precision.convert_module(module)
    assert module.weight.dtype == module.bias.dtype == expected_dtype


def test_fsdp_precision_default_scaler():
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

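Finally, a hedged end-to-end sketch of the user-facing effect in Fabric (the model and sizes are placeholders, and the exact point at which the strategy invokes `convert_module` is an assumption based on this PR's intent): with a `*-true` precision and the FSDP strategy, parameters should already be in the reduced dtype when sharding happens, matching plain PyTorch FSDP memory usage.

```python
import torch

from lightning.fabric import Fabric

# Hypothetical setup: 2 CUDA devices, FSDP strategy, "true" half precision.
fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="bf16-true")
fabric.launch()

model = torch.nn.Linear(1024, 1024)
model = fabric.setup(model)  # presumably calls convert_module before FSDP wrapping
```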