(3/n) Support 2D Parallelism - Efficient loading of full-state checkpoints (#19870)

* memory-optimized loading of full checkpoints into dist model

* simplify

* handle buffers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* handle strict loading, buffers, and add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chlog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
awaelchli and pre-commit-ci[bot] authored May 15, 2024
1 parent 9455871 commit cd8acc2
Showing 3 changed files with 82 additions and 13 deletions.
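
For context, the scenario this change targets is taking an ordinary single-file ("full") checkpoint and loading it into a model that has been sharded by `ModelParallelStrategy`. A minimal usage sketch of that flow follows. It assumes one node with 2 GPUs and torch >= 2.4; the `parallelize_fn` argument and the `"tensor_parallel"` mesh dimension name follow the Fabric documentation for this strategy and, together with the toy model and plan, are illustrative assumptions rather than part of this diff.

import torch.nn as nn
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 8)
        self.layer2 = nn.Linear(8, 8)


def parallelize(model, device_mesh):
    # Placeholder tensor-parallel plan: shard layer1 column-wise, layer2 row-wise.
    plan = {"layer1": ColwiseParallel(), "layer2": RowwiseParallel()}
    parallelize_module(model, device_mesh["tensor_parallel"], plan)
    return model


fabric = Fabric(accelerator="cuda", devices=2, strategy=ModelParallelStrategy(parallelize_fn=parallelize))
fabric.launch()
model = fabric.setup(ToyModel())  # parameters become DTensors sharded across the mesh

# "full.ckpt" stands for any consolidated single-file checkpoint containing a "model" entry,
# e.g. a sharded checkpoint merged into one file as in the test further below. With this
# commit it is mmap-loaded on CPU and broadcast from rank 0 instead of being fully
# materialized on every rank.
fabric.load("full.ckpt", {"model": model})
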
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))

- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852))
- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870))


### Changed
47 changes: 36 additions & 11 deletions src/lightning/fabric/strategies/model_parallel.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import shutil
from contextlib import ExitStack
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -429,7 +430,6 @@ def _load_checkpoint(
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)

@@ -484,13 +484,8 @@ def _load_checkpoint(
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")

state_dict_options = StateDictOptions(
broadcast_from_rank0=True, # type: ignore[call-arg]
full_state_dict=True,
strict=strict,
)
checkpoint = torch.load(path, mmap=True, map_location="cpu")
set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options)
_load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
_validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
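
In the new code path above, every rank still mmap-loads the full checkpoint on CPU, but the module weights are now routed through `_load_raw_module_state` (defined further down in this diff) instead of a single `set_model_state_dict` call over the whole model. A minimal sketch of the underlying PyTorch pattern, assuming an initialized process group, an already DTensor-parallelized `module`, and torch >= 2.4 (the file name and the "model" key are placeholders):

import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

# mmap=True keeps the checkpoint tensors backed by the file on disk; map_location="cpu"
# avoids allocating the whole state dict in GPU memory on every rank.
checkpoint = torch.load("full.ckpt", mmap=True, map_location="cpu")

options = StateDictOptions(
    full_state_dict=True,       # the checkpoint is consolidated, not sharded
    broadcast_from_rank0=True,  # rank 0 provides the values; other ranks receive them via broadcast
)
# Re-shards each full tensor into the module's distributed (DTensor) parameters.
set_model_state_dict(module, checkpoint["model"], options=options)
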
@@ -525,7 +520,9 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int
_load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict)


def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
def _load_raw_module_state(
state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True
) -> None:
"""Loads the state dict into the module by gathering all weights first and then and writing back to each shard."""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

@@ -535,11 +532,39 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_siz

from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True) # type: ignore[call-arg]
set_model_state_dict(module, state_dict, options=state_dict_options)
state_dict_options = StateDictOptions(
broadcast_from_rank0=True, # type: ignore[call-arg]
full_state_dict=True,
strict=strict, # gets ignored at the moment
)

for submodule_name, submodule in module.named_modules():
for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
if full_param_name not in state_dict:
# Note: PyTorch does not currently respect the `strict` setting in state_dict_options!
if not strict:
continue
raise KeyError(
f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint."
" To disable strict loading, set `strict=False`."
)
local_state_dict = {param_name: state_dict[full_param_name]}
set_model_state_dict(submodule, local_state_dict, options=state_dict_options)

elif isinstance(module, FSDP):
with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
module.load_state_dict(state_dict, strict=strict)
else:
module.load_state_dict(state_dict, strict=strict)


def _named_parameters_and_buffers_to_load(module: Module) -> Generator:
"""Returns parameters and buffers, with non-persistent buffers excluded."""
for param_name, param in itertools.chain(
module.named_buffers(recurse=False),
module.named_parameters(recurse=False),
):
if param_name in module._non_persistent_buffers_set:
continue
yield param_name, param
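
Two details in the new `_load_raw_module_state` are worth spelling out: weights are assigned one submodule at a time, so at most one layer's full tensors need to be resident before being re-sharded into DTensors, and non-persistent buffers are skipped because `state_dict()` never includes them, so a checkpoint cannot be expected to contain them even under strict loading. A small self-contained illustration of the buffer point (the module and buffer names are made up for the example):

import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.register_buffer("running_stat", torch.zeros(4), persistent=True)
        self.register_buffer("scratch", torch.zeros(4), persistent=False)


print(sorted(Demo().state_dict().keys()))
# ['layer.bias', 'layer.weight', 'running_stat'] -- 'scratch' never appears in a state dict,
# which is why the loader above must not require it from the checkpoint, even when strict=True.
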
46 changes: 45 additions & 1 deletion tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
from unittest import mock

@@ -20,7 +21,7 @@
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from torch.utils.data import DataLoader, DistributedSampler

@@ -675,3 +676,46 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):

state = {"model": model, "steps": 1}
fabric.load(checkpoint_path_full, state)


@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_load_raw_module_state():
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

class CustomModel(nn.Module):
def __init__(self):
super().__init__()
self.parameter = nn.Parameter(torch.rand(2, 2))
self.layer1 = nn.Linear(4, 4)
self.layer2 = nn.Linear(4, 4)
self.register_buffer("persistent_buffer", torch.rand(2), persistent=True)
self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False)

fabric = Fabric(accelerator="cuda", devices=2)
fabric.launch()
fabric.seed_everything(0)

with fabric.init_module():
model = CustomModel()

state_dict = deepcopy(model.state_dict())

with fabric.init_module():
model = CustomModel()

device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",))
plan = {"layer1": ColwiseParallel()}
parallelize_module(model, device_mesh, plan)
_load_raw_module_state(state_dict, model, strict=True)

assert torch.equal(model.parameter, state_dict["parameter"])
assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"])
assert torch.equal(model.layer2.weight, state_dict["layer2.weight"])
assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"])

state_dict.pop("parameter")
with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"):
_load_raw_module_state(state_dict, model, strict=True)

_load_raw_module_state(state_dict, model, strict=False)
