[TPU] Set the compute_dtype with XLAFSDP #18497

Merged · 6 commits · Sep 7, 2023
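For context, a minimal usage sketch of the behavior this PR enables: the dtype requested through the precision plugin (e.g. "bf16-true") is now forwarded to XLA FSDP as `compute_dtype` when the module is wrapped. The sketch assumes a TPU host with `torch_xla` installed; the model and the `train` function are placeholders, not part of this diff.

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import XLAFSDPStrategy


def train(fabric: Fabric) -> None:
    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
    # With this change, the strategy passes the precision plugin's desired dtype
    # (torch.bfloat16 for "bf16-true") to XlaFullyShardedDataParallel as `compute_dtype`.
    model = fabric.setup_module(model)


fabric = Fabric(accelerator="tpu", strategy=XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear}), precision="bf16-true")
fabric.launch(train)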
28 changes: 19 additions & 9 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -25,10 +25,10 @@

 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_AVAILABLE
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.plugins.environments import XLAEnvironment
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
-from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies import _StrategyRegistry, ParallelStrategy
 from lightning.fabric.strategies.fsdp import _apply_filter
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
@@ -80,7 +80,7 @@ def __init__(
         accelerator: Optional[Accelerator] = None,
         parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
-        precision: Optional[Precision] = None,
+        precision: Optional[XLAPrecision] = None,
         auto_wrap_policy: Optional[_POLICY] = None,
         activation_checkpointing_policy: Optional[_POLICY_SET] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
@@ -97,8 +97,8 @@ def __init__(
         self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = _XLAFSDPBackwardSyncControl()

-        kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
-        kwargs = _activation_checkpointing_kwargs(activation_checkpointing_policy, kwargs)
+        self._auto_wrap_policy = auto_wrap_policy
+        self._activation_checkpointing_policy = activation_checkpointing_policy
         self._fsdp_kwargs = kwargs
         self._state_dict_type = state_dict_type
         self._sequential_save = sequential_save
@@ -172,17 +172,16 @@ def setup_module_and_optimizers(
     def setup_module(self, module: Module) -> Module:
         from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP

-        if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in self._fsdp_kwargs:
+        kwargs = self._parse_fsdp_kwargs()
+        if any(isinstance(mod, XLAFSDP) for mod in module.modules()) and "auto_wrap_policy" in kwargs:
             rank_zero_warn(
                 "A XLAFSDP `auto_wrap_policy` is set, but at least one submodule is already wrapped."
                 " The policy will be ignored."
             )
-            del self._fsdp_kwargs["auto_wrap_policy"]
-
+            del kwargs["auto_wrap_policy"]
         # XLA FSDP requires that the root is wrapped, even if submodules are already wrapped
         if not isinstance(module, XLAFSDP):
-            module = XLAFSDP(module=module, **self._fsdp_kwargs)
-
+            module = XLAFSDP(module=module, **kwargs)
         return module

     def module_to_device(self, module: Module) -> None:
@@ -590,6 +589,17 @@ def load_checkpoint(
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("xla_fsdp", cls, description=cls.__class__.__name__)

+    def _parse_fsdp_kwargs(self) -> Dict:
+        # this needs to be delayed because `self.precision` isn't available at init
+        kwargs = self._fsdp_kwargs.copy()
+        precision = self.precision
+        if isinstance(precision, XLAPrecision):
+            # the `compute_dtype` will be passed to the `auto_wrapper_callable` automatically, so we don't need to pass
+            # it when creating it
+            kwargs.setdefault("compute_dtype", precision._desired_dtype)
+        kwargs = _auto_wrap_policy_kwargs(self._auto_wrap_policy, kwargs)
+        return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs)
+

 def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict:
     if policy is None:
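A note on the `kwargs.setdefault("compute_dtype", ...)` call in `_parse_fsdp_kwargs` above: a `compute_dtype` passed explicitly through the strategy's `**kwargs` wins over the one derived from the precision plugin. A standalone sketch of that merge behavior (illustrative only, not the Lightning implementation):

import torch


def merge_compute_dtype(fsdp_kwargs: dict, desired_dtype: torch.dtype) -> dict:
    # Mirror of the precedence rule: only fill in `compute_dtype` when the user
    # did not pass one explicitly.
    kwargs = fsdp_kwargs.copy()
    kwargs.setdefault("compute_dtype", desired_dtype)
    return kwargs


assert merge_compute_dtype({}, torch.bfloat16)["compute_dtype"] is torch.bfloat16
assert merge_compute_dtype({"compute_dtype": torch.float32}, torch.bfloat16)["compute_dtype"] is torch.float32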
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
@@ -92,6 +92,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
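The new `monkeypatch` line above extends the XLA availability mock to `lightning.fabric.plugins.io.xla`, so XLA classes can be constructed in tests without `torch_xla`. A hedged sketch of how the resulting `xla_available` fixture is typically consumed (the test name is hypothetical; the fixture name follows the existing conftest convention):

from lightning.fabric.plugins.io.xla import XLACheckpointIO


def test_xla_checkpoint_io_instantiation(xla_available):
    # With `_XLA_AVAILABLE` patched to True, the availability guard in __init__
    # no longer raises, even though torch_xla is not installed.
    assert XLACheckpointIO() is not None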
31 changes: 25 additions & 6 deletions tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -21,6 +21,7 @@
 from torch.optim import Adam

 from lightning.fabric.accelerators import XLAAccelerator
+from lightning.fabric.plugins import XLAPrecision
 from lightning.fabric.strategies import XLAFSDPStrategy
 from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl
 from tests_fabric.helpers.runif import RunIf
@@ -117,15 +118,33 @@ def test_xla_fsdp_policy(xla_available):
     assert strategy._fsdp_kwargs == {"foo": 1}

     strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear})
-    assert "auto_wrap_policy" in strategy._fsdp_kwargs
-    assert strategy._fsdp_kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
+    kwargs = strategy._parse_fsdp_kwargs()
+    assert set(kwargs) == {"auto_wrap_policy"}
+    assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"

     strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
-    assert "auto_wrapper_callable" in strategy._fsdp_kwargs
-    assert strategy._fsdp_kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+    kwargs = strategy._parse_fsdp_kwargs()
+    kwargs = strategy._parse_fsdp_kwargs()  # ensure it's idempotent
+    assert set(kwargs) == {"auto_wrapper_callable"}
+    assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper

+    strategy = XLAFSDPStrategy(
+        accelerator=Mock(),
+        auto_wrap_policy={torch.nn.Linear},
+        activation_checkpointing_policy={torch.nn.Linear},
+        precision=XLAPrecision("bf16-true"),
+    )
+    kwargs = strategy._parse_fsdp_kwargs()
+    assert set(kwargs) == {"auto_wrap_policy", "auto_wrapper_callable", "compute_dtype"}
+    assert kwargs["auto_wrap_policy"].func._mock_name == "transformer_auto_wrap_policy"
+    assert kwargs["auto_wrapper_callable"].func is _activation_checkpointing_auto_wrapper
+    assert kwargs["compute_dtype"] is torch.bfloat16
+    strategy.teardown()
+
+    strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
     with pytest.raises(ValueError, match="cannot set both"):
-        XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")
+        strategy._parse_fsdp_kwargs()

+    strategy = XLAFSDPStrategy(activation_checkpointing_policy="foo")
     with pytest.raises(TypeError, match="must be a set"):
-        XLAFSDPStrategy(activation_checkpointing_policy="foo")
+        strategy._parse_fsdp_kwargs()
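The `compute_dtype` assertion above follows from `XLAPrecision`'s precision-to-dtype mapping ("bf16-true" selects torch.bfloat16). A short, hedged check in the same fixture style; `_desired_dtype` is the private attribute the strategy reads in `_parse_fsdp_kwargs`:

import torch
from lightning.fabric.plugins import XLAPrecision


def test_xla_precision_desired_dtype(xla_available):
    # "bf16-true" runs parameters and compute in bfloat16, which the strategy then
    # forwards to XLAFSDP as `compute_dtype`.
    assert XLAPrecision("bf16-true")._desired_dtype is torch.bfloat16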