From 1a0887ec98111def6a9863eb2e3c9f6a50e37bfb Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 15 May 2024 14:22:00 +0200 Subject: [PATCH 1/7] memory-optimized loading of full checkpoints into dist model --- .../fabric/strategies/model_parallel.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 9cd721f930d1b..6200be8092c73 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -429,7 +429,6 @@ def _load_checkpoint( StateDictOptions, get_model_state_dict, get_optimizer_state_dict, - set_model_state_dict, set_optimizer_state_dict, ) @@ -484,13 +483,8 @@ def _load_checkpoint( if not _TORCH_GREATER_EQUAL_2_4: raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.") - state_dict_options = StateDictOptions( - broadcast_from_rank0=True, # type: ignore[call-arg] - full_state_dict=True, - strict=strict, - ) checkpoint = torch.load(path, mmap=True, map_location="cpu") - set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options) + _load_raw_module_state(checkpoint.pop(module_key), module, world_size=self.world_size, strict=strict) requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict) @@ -535,8 +529,13 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_siz from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True) # type: ignore[call-arg] - set_model_state_dict(module, state_dict, options=state_dict_options) + state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, strict=strict) # type: ignore[call-arg] + + for submodule_name, submodule in module.named_modules(): + for param_name, _ in submodule.named_parameters(recurse=False): + full_param_name = f"{submodule_name}.{param_name}" + local_state_dict = {param_name: state_dict[full_param_name]} + set_model_state_dict(submodule, local_state_dict, options=state_dict_options) elif isinstance(module, FSDP): with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False): From d1e5bd159016e1e553016b4f0ac7d2ee31dc2910 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 15 May 2024 14:26:03 +0200 Subject: [PATCH 2/7] simplify --- src/lightning/fabric/strategies/model_parallel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 6200be8092c73..dc894805cd29d 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -484,7 +484,7 @@ def _load_checkpoint( raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.") checkpoint = torch.load(path, mmap=True, map_location="cpu") - _load_raw_module_state(checkpoint.pop(module_key), module, world_size=self.world_size, strict=strict) + _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict) requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict) @@ -519,7 +519,9 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int _load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict) -def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None: +def _load_raw_module_state( + state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True +) -> None: """Loads the state dict into the module by gathering all weights first and then and writing back to each shard.""" from torch.distributed.fsdp import FullyShardedDataParallel as FSDP From 637199996a755f91689769936a68b887ec654572 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 15 May 2024 08:51:12 -0400 Subject: [PATCH 3/7] handle buffers --- src/lightning/fabric/strategies/model_parallel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index dc894805cd29d..f830b53222e52 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import shutil from contextlib import ExitStack from datetime import timedelta @@ -534,7 +535,9 @@ def _load_raw_module_state( state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, strict=strict) # type: ignore[call-arg] for submodule_name, submodule in module.named_modules(): - for param_name, _ in submodule.named_parameters(recurse=False): + for param_name, _ in itertools.chain(submodule.named_buffers(recurse=False), submodule.named_parameters(recurse=False)): + if param_name in submodule._non_persistent_buffers_set: + continue full_param_name = f"{submodule_name}.{param_name}" local_state_dict = {param_name: state_dict[full_param_name]} set_model_state_dict(submodule, local_state_dict, options=state_dict_options) From 91664a1de56ee99ac7ea66d777bcac0c131ad586 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 May 2024 12:51:40 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/strategies/model_parallel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index f830b53222e52..4c959c068c993 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -535,7 +535,9 @@ def _load_raw_module_state( state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, strict=strict) # type: ignore[call-arg] for submodule_name, submodule in module.named_modules(): - for param_name, _ in itertools.chain(submodule.named_buffers(recurse=False), submodule.named_parameters(recurse=False)): + for param_name, _ in itertools.chain( + submodule.named_buffers(recurse=False), submodule.named_parameters(recurse=False) + ): if param_name in submodule._non_persistent_buffers_set: continue full_param_name = f"{submodule_name}.{param_name}" From 4f861f6428f5e79c114379d5ecf9bbad02f6d764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 15 May 2024 09:34:46 -0400 Subject: [PATCH 5/7] handle strict loading, buffers, and add test --- .../fabric/strategies/model_parallel.py | 35 ++++++++++---- .../test_model_parallel_integration.py | 46 ++++++++++++++++++- 2 files changed, 72 insertions(+), 9 deletions(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 4c959c068c993..2381cd4b58235 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -16,7 +16,7 @@ from contextlib import ExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -532,15 +532,23 @@ def _load_raw_module_state( from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, strict=strict) # type: ignore[call-arg] + state_dict_options = StateDictOptions( + broadcast_from_rank0=True, # type: ignore[call-arg] + full_state_dict=True, + strict=strict, # gets ignored at the moment + ) for submodule_name, submodule in module.named_modules(): - for param_name, _ in itertools.chain( - submodule.named_buffers(recurse=False), submodule.named_parameters(recurse=False) - ): - if param_name in submodule._non_persistent_buffers_set: - continue - full_param_name = f"{submodule_name}.{param_name}" + for param_name, _ in _named_parameters_and_buffers_to_load(submodule): + full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}" + if full_param_name not in state_dict: + # Note: PyTorch does not currently respect the `strict` setting in state_dict_options! + if not strict: + continue + raise KeyError( + f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint." + " To disable strict loading, set `strict=False`." + ) local_state_dict = {param_name: state_dict[full_param_name]} set_model_state_dict(submodule, local_state_dict, options=state_dict_options) @@ -549,3 +557,14 @@ def _load_raw_module_state( module.load_state_dict(state_dict, strict=strict) else: module.load_state_dict(state_dict, strict=strict) + + +def _named_parameters_and_buffers_to_load(module: Module) -> Generator: + """Returns parameters and buffers, with non-persistent buffers excluded.""" + for param_name, param in itertools.chain( + module.named_buffers(recurse=False), + module.named_parameters(recurse=False), + ): + if param_name in module._non_persistent_buffers_set: + continue + yield param_name, param diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index d864a9687ebb5..dac159afcf5bf 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -20,9 +20,10 @@ import torch.nn as nn import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.strategies import ModelParallelStrategy from lightning.fabric.utilities.load import _load_distributed_checkpoint from torch.utils.data import DataLoader, DistributedSampler +from copy import deepcopy +from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf @@ -675,3 +676,46 @@ def test_save_sharded_and_consolidate_and_load(tmp_path): state = {"model": model, "steps": 1} fabric.load(checkpoint_path_full, state) + + +@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) +def test_load_raw_module_state(): + from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel + from torch.distributed.device_mesh import init_device_mesh + + class CustomModel(nn.Module): + def __init__(self): + super().__init__() + self.parameter = nn.Parameter(torch.rand(2, 2)) + self.layer1 = nn.Linear(4, 4) + self.layer2 = nn.Linear(4, 4) + self.register_buffer("persistent_buffer", torch.rand(2), persistent=True) + self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False) + + fabric = Fabric(accelerator="cuda", devices=2) + fabric.launch() + fabric.seed_everything(0) + + with fabric.init_module(): + model = CustomModel() + + state_dict = deepcopy(model.state_dict()) + + with fabric.init_module(): + model = CustomModel() + + device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp", )) + plan = {"layer1": ColwiseParallel()} + parallelize_module(model, device_mesh, plan) + _load_raw_module_state(state_dict, model, strict=True) + + assert torch.equal(model.parameter, state_dict["parameter"]) + assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"]) + assert torch.equal(model.layer2.weight, state_dict["layer2.weight"]) + assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"]) + + state_dict.pop("parameter") + with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"): + _load_raw_module_state(state_dict, model, strict=True) + + _load_raw_module_state(state_dict, model, strict=False) From fcb7fdba3ee520d86630ecec62c4b9d8d8273b68 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 May 2024 13:35:22 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fabric/strategies/model_parallel.py | 4 ++-- .../test_model_parallel_integration.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 2381cd4b58235..4141ea454ca51 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -534,7 +534,7 @@ def _load_raw_module_state( state_dict_options = StateDictOptions( broadcast_from_rank0=True, # type: ignore[call-arg] - full_state_dict=True, + full_state_dict=True, strict=strict, # gets ignored at the moment ) @@ -562,7 +562,7 @@ def _load_raw_module_state( def _named_parameters_and_buffers_to_load(module: Module) -> Generator: """Returns parameters and buffers, with non-persistent buffers excluded.""" for param_name, param in itertools.chain( - module.named_buffers(recurse=False), + module.named_buffers(recurse=False), module.named_parameters(recurse=False), ): if param_name in module._non_persistent_buffers_set: diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index dac159afcf5bf..1f12822c69ee6 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from copy import deepcopy from pathlib import Path from unittest import mock @@ -20,10 +21,9 @@ import torch.nn as nn import torch.nn.functional as F from lightning.fabric import Fabric +from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state from lightning.fabric.utilities.load import _load_distributed_checkpoint from torch.utils.data import DataLoader, DistributedSampler -from copy import deepcopy -from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf @@ -680,8 +680,8 @@ def test_save_sharded_and_consolidate_and_load(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) def test_load_raw_module_state(): - from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module class CustomModel(nn.Module): def __init__(self): @@ -691,24 +691,24 @@ def __init__(self): self.layer2 = nn.Linear(4, 4) self.register_buffer("persistent_buffer", torch.rand(2), persistent=True) self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False) - + fabric = Fabric(accelerator="cuda", devices=2) fabric.launch() fabric.seed_everything(0) - + with fabric.init_module(): model = CustomModel() - + state_dict = deepcopy(model.state_dict()) with fabric.init_module(): model = CustomModel() - device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp", )) + device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",)) plan = {"layer1": ColwiseParallel()} parallelize_module(model, device_mesh, plan) _load_raw_module_state(state_dict, model, strict=True) - + assert torch.equal(model.parameter, state_dict["parameter"]) assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"]) assert torch.equal(model.layer2.weight, state_dict["layer2.weight"]) @@ -717,5 +717,5 @@ def __init__(self): state_dict.pop("parameter") with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"): _load_raw_module_state(state_dict, model, strict=True) - + _load_raw_module_state(state_dict, model, strict=False) From bd2843f6cbac769ad71b0b6404e411e9844ea9ce Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 15 May 2024 15:50:31 +0200 Subject: [PATCH 7/7] chlog --- src/lightning/fabric/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index c5b883320bf2e..b74d9c34ea546 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708)) -- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852)) +- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870)) ### Changed