From cd8acc26c3dd4d37a2b9aca458bd0d511d77df5c Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 15 May 2024 19:07:31 +0200
Subject: [PATCH] (3/n) Support 2D Parallelism - Efficient loading of full-state checkpoints (#19870)

* memory-optimized loading of full checkpoints into dist model

* simplify

* handle buffers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* handle strict loading, buffers, and add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chlog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/lightning/fabric/CHANGELOG.md             |  2 +-
 .../fabric/strategies/model_parallel.py       | 47 ++++++++++++++-----
 .../test_model_parallel_integration.py        | 46 +++++++++++++++++-
 3 files changed, 82 insertions(+), 13 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index c5b883320bf2e..b74d9c34ea546 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
 
-- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852))
+- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870))
 
 
 ### Changed
diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index 9cd721f930d1b..4141ea454ca51 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import shutil
 from contextlib import ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union
 
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -429,7 +430,6 @@ def _load_checkpoint(
         StateDictOptions,
         get_model_state_dict,
         get_optimizer_state_dict,
-        set_model_state_dict,
         set_optimizer_state_dict,
     )
 
@@ -484,13 +484,8 @@ def _load_checkpoint(
         if not _TORCH_GREATER_EQUAL_2_4:
             raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
 
-        state_dict_options = StateDictOptions(
-            broadcast_from_rank0=True,  # type: ignore[call-arg]
-            full_state_dict=True,
-            strict=strict,
-        )
         checkpoint = torch.load(path, mmap=True, map_location="cpu")
-        set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options)
+        _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
 
         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
@@ -525,7 +520,9 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int
     _load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict)
 
 
-def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
+def _load_raw_module_state(
+    state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True
+) -> None:
     """Loads the state dict into the module by gathering all weights first and then writing back to each shard."""
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
@@ -535,11 +532,39 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_siz
 
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
 
-        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True)  # type: ignore[call-arg]
-        set_model_state_dict(module, state_dict, options=state_dict_options)
+        state_dict_options = StateDictOptions(
+            broadcast_from_rank0=True,  # type: ignore[call-arg]
+            full_state_dict=True,
+            strict=strict,  # gets ignored at the moment
+        )
+
+        for submodule_name, submodule in module.named_modules():
+            for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
+                full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
+                if full_param_name not in state_dict:
+                    # Note: PyTorch does not currently respect the `strict` setting in state_dict_options!
+                    if not strict:
+                        continue
+                    raise KeyError(
+                        f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint."
+                        " To disable strict loading, set `strict=False`."
+                    )
+                local_state_dict = {param_name: state_dict[full_param_name]}
+                set_model_state_dict(submodule, local_state_dict, options=state_dict_options)
 
     elif isinstance(module, FSDP):
         with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
             module.load_state_dict(state_dict, strict=strict)
     else:
         module.load_state_dict(state_dict, strict=strict)
+
+
+def _named_parameters_and_buffers_to_load(module: Module) -> Generator:
+    """Returns parameters and buffers, with non-persistent buffers excluded."""
+    for param_name, param in itertools.chain(
+        module.named_buffers(recurse=False),
+        module.named_parameters(recurse=False),
+    ):
+        if param_name in module._non_persistent_buffers_set:
+            continue
+        yield param_name, param
diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py
index d864a9687ebb5..1f12822c69ee6 100644
--- a/tests/tests_fabric/strategies/test_model_parallel_integration.py
+++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from pathlib import Path
 from unittest import mock
 
@@ -20,7 +21,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from lightning.fabric import Fabric
-from lightning.fabric.strategies import ModelParallelStrategy
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state
 from lightning.fabric.utilities.load import _load_distributed_checkpoint
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -675,3 +676,46 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
 
     state = {"model": model, "steps": 1}
     fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_load_raw_module_state():
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.parameter = nn.Parameter(torch.rand(2, 2))
+            self.layer1 = nn.Linear(4, 4)
+            self.layer2 = nn.Linear(4, 4)
+            self.register_buffer("persistent_buffer", torch.rand(2), persistent=True)
+            self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False)
+
+    fabric = Fabric(accelerator="cuda", devices=2)
+    fabric.launch()
+    fabric.seed_everything(0)
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    state_dict = deepcopy(model.state_dict())
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",))
+    plan = {"layer1": ColwiseParallel()}
+    parallelize_module(model, device_mesh, plan)
+    _load_raw_module_state(state_dict, model, strict=True)
+
+    assert torch.equal(model.parameter, state_dict["parameter"])
+    assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"])
+    assert torch.equal(model.layer2.weight, state_dict["layer2.weight"])
+    assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"])
+
+    state_dict.pop("parameter")
+    with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"):
+        _load_raw_module_state(state_dict, model, strict=True)
+
+    _load_raw_module_state(state_dict, model, strict=False)
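
For readers who want the gist without walking the hunks above: the new loading path in `_load_raw_module_state` avoids materializing the entire full-state checkpoint on every rank by memory-mapping the checkpoint file and moving one tensor at a time into the distributed model. The following is a minimal, self-contained sketch of that pattern, not part of the patch; the function name `load_full_checkpoint_into_tp_model` and the `model`/`path` arguments are illustrative placeholders, and it assumes an initialized process group, a DTensor-parallelized module, and PyTorch >= 2.4.

    import itertools

    import torch
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
    from torch.nn import Module


    def load_full_checkpoint_into_tp_model(model: Module, path: str) -> None:
        # mmap=True keeps tensors on disk and pages them in lazily, so the full
        # state dict is never resident in CPU memory all at once.
        full_state_dict = torch.load(path, mmap=True, map_location="cpu")
        options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True)
        for submodule_name, submodule in model.named_modules():
            # Persistent buffers are part of the checkpoint too; non-persistent ones are not.
            named_tensors = itertools.chain(
                submodule.named_buffers(recurse=False), submodule.named_parameters(recurse=False)
            )
            for tensor_name, _ in named_tensors:
                if tensor_name in submodule._non_persistent_buffers_set:
                    continue
                key = f"{submodule_name}.{tensor_name}" if submodule_name else tensor_name
                if key not in full_state_dict:
                    continue  # the patch raises a KeyError here when strict=True
                # Loading one tensor at a time means at most one full-size tensor is
                # broadcast from rank 0 and sharded into the DTensor-parallel submodule.
                set_model_state_dict(submodule, {tensor_name: full_state_dict[key]}, options=options)

This mirrors what the new test `test_load_raw_module_state` exercises on two GPUs with a `ColwiseParallel` plan; the actual implementation in the diff additionally handles strict-loading errors and falls back to plain `load_state_dict` for FSDP1 and non-distributed modules.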