From 2f50164d7193c1b219fc25e411ca4c4f34a7561f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 6 Sep 2023 19:50:31 +0200
Subject: [PATCH 1/5] [TPU] Set the compute_dtype with XLAFSDP

---
 src/lightning/fabric/strategies/xla_fsdp.py   | 22 ++++++++++----
 tests/tests_fabric/conftest.py                |  1 +
 .../tests_fabric/strategies/test_xla_fsdp.py  | 30 +++++++++++++++----
 3 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index a4504950f4c95..ac0b4c198a18a 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -25,10 +25,10 @@

 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_AVAILABLE
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
-from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies import _StrategyRegistry, ParallelStrategy
 from lightning.fabric.strategies.fsdp import _apply_filter
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
@@ -80,7 +80,7 @@ def __init__(
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
-        precision: Optional[Precision] = None,
+        precision: Optional[XLAPrecision] = None,
         auto_wrap_policy: Optional[_POLICY] = None,
         activation_checkpointing_policy: Optional[_POLICY_SET] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
@@ -97,8 +97,8 @@ def __init__(
         self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = _XLAFSDPBackwardSyncControl()

-        kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
-        kwargs = _activation_checkpointing_kwargs(activation_checkpointing_policy, kwargs)
+        self._auto_wrap_policy = auto_wrap_policy
+        self._activation_checkpointing_policy = activation_checkpointing_policy
         self._fsdp_kwargs = kwargs
         self._state_dict_type = state_dict_type
         self._sequential_save = sequential_save
@@ -181,7 +181,8 @@ def setup_module(self, module: Module) -> Module:

         # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped
         if not isinstance(module, XLAFSDP):
-            module = XLAFSDP(module=module, **self._fsdp_kwargs)
+            kwargs = self._parse_fsdp_kwargs()
+            module = XLAFSDP(module=module, **kwargs)

         return module
@@ -590,6 +591,17 @@ def load_checkpoint(
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("xla_fsdp", cls, description=cls.__class__.__name__)

+    def _parse_fsdp_kwargs(self) -> Dict:
+        # this needs to be delayed because `self.precision` isn't available at init
+        kwargs = self._fsdp_kwargs
+        precision = self.precision
+        if isinstance(precision, XLAPrecision):
+            # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass
+            # it when creating it
+            kwargs.setdefault("compute_dtype", precision._desired_dtype)
+        kwargs = _auto_wrap_policy_kwargs(self._auto_wrap_policy, kwargs)
+        return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs)
+

 def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict:
     if policy is None:
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index b3b7eca823877..63b4a2e0be1e6 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -92,6 +92,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index eba8c9183a3f4..51721a7e05c44 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -21,6 +21,7 @@
 from torch.optim import Adam

 from lightning.fabric.accelerators import XLAAccelerator
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.strategies import XLAFSDPStrategy
 from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl
 from tests_fabric.helpers.runif import RunIf
@@ -117,15 +118,32 @@ def test_xla_fsdp_policy(xla_available):
     assert strategy._fsdp_kwargs == {"foo": 1}

     strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear})
-    assert "auto_wrap_policy" in strategy._fsdp_kwargs
-    assert strategy._fsdp_kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
+    kwargs = strategy._parse_fsdp_kwargs()
+    assert set(kwargs) == {"auto_wrap_policy"}
+    assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"

     strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
-    assert "auto_wrapper_callable" in strategy._fsdp_kwargs
-    assert strategy._fsdp_kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+    kwargs = strategy._parse_fsdp_kwargs()
+    assert set(kwargs) == {"auto_wrapper_callable"}
+    assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper

+    strategy = XLAFSDPStrategy(
+        accelerator=Mock(),
+        auto_wrap_policy={torch.nn.Linear},
+        activation_checkpointing_policy={torch.nn.Linear},
+        precision=XLAPrecision("bf16-true"),
+    )
+    kwargs = strategy._parse_fsdp_kwargs()
+    assert set(kwargs) == {"auto_wrap_policy", "auto_wrapper_callable", "compute_dtype"}
+    assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
+    assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+    assert kwargs["compute_dtype"] is torch.bfloat16
+    strategy.teardown()
+
+    strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
     with pytest.raises(ValueError, match="cannot set both"):
-        XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
+        strategy._parse_fsdp_kwargs()

+    strategy = XLAFSDPStrategy(activation_checkpointing_policy="foo")
     with pytest.raises(TypeError, match="must be a set"):
-        XLAFSDPStrategy(activation_checkpointing_policy="foo")
+        strategy._parse_fsdp_kwargs()

From 72c603646487fa0c29627efb57f4248a48c17b68 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 6 Sep 2023 20:41:51 +0200
Subject: [PATCH 2/5] Fix

---
 src/lightning/fabric/strategies/xla_fsdp.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index ac0b4c198a18a..acc675b7e931c 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -172,18 +172,16 @@ def setup_module_and_optimizers(
     def setup_module(self, module: Module) -> Module:
         from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP

-        if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in self._fsdp_kwargs:
+        kwargs = self._parse_fsdp_kwargs()
+        if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in kwargs:
             rank_zero_warn(
                 "A XLAFSDP `auto_wrap_policy` is set, but at least one submodule is already wrapped."
                 " The policy will be ignored."
             )
-            del self._fsdp_kwargs["auto_wrap_policy"]
-
+            del kwargs["auto_wrap_policy"]
         # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped
         if not isinstance(module, XLAFSDP):
-            kwargs = self._parse_fsdp_kwargs()
             module = XLAFSDP(module=module, **kwargs)
-
         return module

     def module_to_device(self, module: Module) -> None:

From 978b8c229d64aeb18d63fae116bc0288ab441e7e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 6 Sep 2023 20:52:45 +0200
Subject: [PATCH 3/5] Idempotent

---
 src/lightning/fabric/strategies/xla_fsdp.py    | 3 ++-
 tests/tests_fabric/strategies/test_xla_fsdp.py | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index acc675b7e931c..dc353fa6b56d2 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import io
 from contextlib import contextmanager, nullcontext
+from copy import copy
 from functools import partial
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
@@ -591,7 +592,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:

     def _parse_fsdp_kwargs(self) -> Dict:
         # this needs to be delayed because `self.precision` isn't available at init
-        kwargs = self._fsdp_kwargs
+        kwargs = copy(self._fsdp_kwargs)
         precision = self.precision
         if isinstance(precision, XLAPrecision):
             # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index 51721a7e05c44..04a49bd8cb0d1 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -124,6 +124,7 @@ def test_xla_fsdp_policy(xla_available):

     strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
     kwargs = strategy._parse_fsdp_kwargs()
+    kwargs = strategy._parse_fsdp_kwargs()  # ensure its idempotent
     assert set(kwargs) == {"auto_wrapper_callable"}
     assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper

From bdf5a9cd0c11a92cf66ec4d265512a4006e4952d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 6 Sep 2023 20:52:58 +0200
Subject: [PATCH 4/5] typo

---
 tests/tests_fabric/strategies/test_xla_fsdp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index 04a49bd8cb0d1..7d9b5ba7cf8d9 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -124,7 +124,7 @@ def test_xla_fsdp_policy(xla_available):

     strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
     kwargs = strategy._parse_fsdp_kwargs()
-    kwargs = strategy._parse_fsdp_kwargs()  # ensure its idempotent
+    kwargs = strategy._parse_fsdp_kwargs()  # ensure it's idempotent
     assert set(kwargs) == {"auto_wrapper_callable"}
     assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper

From 268be2e5ebbbf9e59b7769b915930d326226eb03 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 6 Sep 2023 23:35:55 +0200
Subject: [PATCH 5/5] nit

---
 src/lightning/fabric/strategies/xla_fsdp.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index dc353fa6b56d2..379db372f6c5c 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import io
 from contextlib import contextmanager, nullcontext
-from copy import copy
 from functools import partial
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
@@ -592,7 +591,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:

     def _parse_fsdp_kwargs(self) -> Dict:
         # this needs to be delayed because `self.precision` isn't available at init
-        kwargs = copy(self._fsdp_kwargs)
+        kwargs = self._fsdp_kwargs.copy()
         precision = self.precision
         if isinstance(precision, XLAPrecision):
             # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass
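
For context, a minimal sketch of the behavior this series introduces, mirroring the updated unit test rather than defining any new API. It assumes an environment where `torch_xla` is importable (the test suite mocks XLA availability instead), since `_parse_fsdp_kwargs` resolves the wrap policies through torch_xla helpers:

import torch

from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.strategies import XLAFSDPStrategy

# Policies and precision are stored at __init__ and only combined lazily in
# _parse_fsdp_kwargs(), once `self.precision` is available, so the strategy can
# inject `compute_dtype` into the kwargs forwarded to XlaFullyShardedDataParallel.
strategy = XLAFSDPStrategy(
    auto_wrap_policy={torch.nn.Linear},
    precision=XLAPrecision("bf16-true"),
)
kwargs = strategy._parse_fsdp_kwargs()
assert kwargs["compute_dtype"] is torch.bfloat16  # resolved from the XLAPrecision plugin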