diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index a4504950f4c95..379db372f6c5c 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -25,10 +25,10 @@
 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_AVAILABLE
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
-from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies import _StrategyRegistry, ParallelStrategy
 from lightning.fabric.strategies.fsdp import _apply_filter
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
@@ -80,7 +80,7 @@ def __init__(
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
-        precision: Optional[Precision] = None,
+        precision: Optional[XLAPrecision] = None,
         auto_wrap_policy: Optional[_POLICY] = None,
         activation_checkpointing_policy: Optional[_POLICY_SET] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
@@ -97,8 +97,8 @@ def __init__(
         self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = _XLAFSDPBackwardSyncControl()
 
-        kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
-        kwargs = _activation_checkpointing_kwargs(activation_checkpointing_policy, kwargs)
+        self._auto_wrap_policy = auto_wrap_policy
+        self._activation_checkpointing_policy = activation_checkpointing_policy
         self._fsdp_kwargs = kwargs
         self._state_dict_type = state_dict_type
         self._sequential_save = sequential_save
@@ -172,17 +172,16 @@ def setup_module_and_optimizers(
     def setup_module(self, module: Module) -> Module:
         from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
 
-        if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in self._fsdp_kwargs:
+        kwargs = self._parse_fsdp_kwargs()
+        if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in kwargs:
             rank_zero_warn(
                 "A XLAFSDP `auto_wrap_policy` is set, but at least one submodule is already wrapped."
                 " The policy will be ignored."
             )
-            del self._fsdp_kwargs["auto_wrap_policy"]
-
+            del kwargs["auto_wrap_policy"]
         # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped
         if not isinstance(module, XLAFSDP):
-            module = XLAFSDP(module=module, **self._fsdp_kwargs)
-
+            module = XLAFSDP(module=module, **kwargs)
         return module
 
     def module_to_device(self, module: Module) -> None:
@@ -590,6 +589,17 @@ def load_checkpoint(
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("xla_fsdp", cls, description=cls.__class__.__name__)
 
+    def _parse_fsdp_kwargs(self) -> Dict:
+        # this needs to be delayed because `self.precision` isn't available at init
+        kwargs = self._fsdp_kwargs.copy()
+        precision = self.precision
+        if isinstance(precision, XLAPrecision):
+            # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass
+            # it when creating it
+            kwargs.setdefault("compute_dtype", precision._desired_dtype)
+        kwargs = _auto_wrap_policy_kwargs(self._auto_wrap_policy, kwargs)
+        return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs)
+
 
 def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict:
     if policy is None:
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index b3b7eca823877..63b4a2e0be1e6 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -92,6 +92,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index eba8c9183a3f4..7d9b5ba7cf8d9 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -21,6 +21,7 @@
 from torch.optim import Adam
 
 from lightning.fabric.accelerators import XLAAccelerator
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.strategies import XLAFSDPStrategy
 from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl
 from tests_fabric.helpers.runif import RunIf
@@ -117,15 +118,33 @@ def test_xla_fsdp_policy(xla_available):
     assert strategy._fsdp_kwargs == {"foo": 1}
 
     strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear})
-    assert "auto_wrap_policy" in strategy._fsdp_kwargs
-    assert strategy._fsdp_kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
+    kwargs = strategy._parse_fsdp_kwargs()
+    assert set(kwargs) == {"auto_wrap_policy"}
+    assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
 
     strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
-    assert "auto_wrapper_callable" in strategy._fsdp_kwargs
-    assert strategy._fsdp_kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+    kwargs = strategy._parse_fsdp_kwargs()
+    kwargs = strategy._parse_fsdp_kwargs()  # ensure it's idempotent
+    assert set(kwargs) == {"auto_wrapper_callable"}
+    assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
 
+    strategy = XLAFSDPStrategy(
+        accelerator=Mock(),
+        auto_wrap_policy={torch.nn.Linear},
+        activation_checkpointing_policy={torch.nn.Linear},
+        precision=XLAPrecision("bf16-true"),
+    )
+    kwargs = strategy._parse_fsdp_kwargs()
+    assert set(kwargs) == {"auto_wrap_policy", "auto_wrapper_callable", "compute_dtype"}
+    assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
+    assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+    assert kwargs["compute_dtype"] is torch.bfloat16
+    strategy.teardown()
+
+    strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
     with pytest.raises(ValueError, match="cannot set both"):
-        XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
+        strategy._parse_fsdp_kwargs()
 
+    strategy = XLAFSDPStrategy(activation_checkpointing_policy="foo")
     with pytest.raises(TypeError, match="must be a set"):
-        XLAFSDPStrategy(activation_checkpointing_policy="foo")
+        strategy._parse_fsdp_kwargs()
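
For context, a minimal usage sketch (not part of the diff) of the deferred-parsing path introduced above, mirroring the assertions added in `test_xla_fsdp_policy`. It assumes a working `torch_xla` installation on a TPU/XLA host; `_parse_fsdp_kwargs` is a private helper and is called here only to illustrate what `setup_module` will receive.

```python
import torch

from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.strategies import XLAFSDPStrategy

# Constructing the strategy no longer resolves the XLA FSDP kwargs eagerly;
# the policies and precision plugin are only stored at this point.
strategy = XLAFSDPStrategy(
    auto_wrap_policy={torch.nn.Linear},                 # becomes a transformer_auto_wrap_policy partial
    activation_checkpointing_policy={torch.nn.Linear},  # becomes an auto_wrapper_callable partial
    precision=XLAPrecision("bf16-true"),
)

# Parsing runs later (e.g. inside `setup_module`), once the precision plugin is
# attached, so `compute_dtype` can be derived from it instead of at __init__.
kwargs = strategy._parse_fsdp_kwargs()
assert kwargs["compute_dtype"] is torch.bfloat16
```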