(3/n) Support 2D Parallelism - Efficient loading of full-state checkpoints #19870

Merged · 7 commits · May 15, 2024
Changes from all commits
src/lightning/fabric/CHANGELOG.md (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))

-- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852))
+- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870))


 ### Changed
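For orientation, and not part of the diff itself: this PR series lets `fabric.load` stream a plain single-file checkpoint into a model sharded by `ModelParallelStrategy`. Below is a minimal sketch of that flow; the `parallelize_fn` callback signature, the `"tensor_parallel"` mesh dimension name, and the checkpoint path are assumptions based on the Fabric documentation rather than taken from this diff.

```python
import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 4)
        self.layer2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer2(torch.relu(self.layer1(x)))


def parallelize(model, device_mesh):
    # Shard one linear layer column-wise across the tensor-parallel mesh dimension (assumed dim name).
    parallelize_module(model, device_mesh["tensor_parallel"], {"layer1": ColwiseParallel()})
    return model


if __name__ == "__main__":
    fabric = Fabric(accelerator="cuda", devices=2, strategy=ModelParallelStrategy(parallelize))
    fabric.launch()

    with fabric.init_module():
        model = FeedForward()
    model = fabric.setup(model)

    # Loading a regular single-file checkpoint (e.g. produced by `torch.save`) into the
    # sharded model: weights are broadcast from rank 0 and written back to the local shards.
    fabric.load("path/to/full_checkpoint.ckpt", {"model": model})
```

The loading call at the end exercises the code path touched by the changes to `_load_checkpoint` and `_load_raw_module_state` below.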
src/lightning/fabric/strategies/model_parallel.py (47 changes: 36 additions & 11 deletions)

@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import shutil
 from contextlib import ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -429,7 +430,6 @@ def _load_checkpoint(
         StateDictOptions,
         get_model_state_dict,
         get_optimizer_state_dict,
-        set_model_state_dict,
         set_optimizer_state_dict,
     )

@@ -484,13 +484,8 @@ def _load_checkpoint(
         if not _TORCH_GREATER_EQUAL_2_4:
             raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")

-        state_dict_options = StateDictOptions(
-            broadcast_from_rank0=True,  # type: ignore[call-arg]
-            full_state_dict=True,
-            strict=strict,
-        )
         checkpoint = torch.load(path, mmap=True, map_location="cpu")
-        set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options)
+        _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
@@ -525,7 +520,9 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int
     _load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict)


-def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
+def _load_raw_module_state(
+    state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True
+) -> None:
     """Loads the state dict into the module by gathering all weights first and then writing back to each shard."""
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

@@ -535,11 +532,39 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_siz

         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

-        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True)  # type: ignore[call-arg]
-        set_model_state_dict(module, state_dict, options=state_dict_options)
+        state_dict_options = StateDictOptions(
+            broadcast_from_rank0=True,  # type: ignore[call-arg]
+            full_state_dict=True,
+            strict=strict,  # gets ignored at the moment
+        )
+
+        for submodule_name, submodule in module.named_modules():
+            for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
+                full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
+                if full_param_name not in state_dict:
+                    # Note: PyTorch does not currently respect the `strict` setting in state_dict_options!
+                    if not strict:
+                        continue
+                    raise KeyError(
+                        f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint."
+                        " To disable strict loading, set `strict=False`."
+                    )
+                local_state_dict = {param_name: state_dict[full_param_name]}
+                set_model_state_dict(submodule, local_state_dict, options=state_dict_options)

     elif isinstance(module, FSDP):
         with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
             module.load_state_dict(state_dict, strict=strict)
     else:
         module.load_state_dict(state_dict, strict=strict)
+
+
+def _named_parameters_and_buffers_to_load(module: Module) -> Generator:
+    """Returns parameters and buffers, with non-persistent buffers excluded."""
+    for param_name, param in itertools.chain(
+        module.named_buffers(recurse=False),
+        module.named_parameters(recurse=False),
+    ):
+        if param_name in module._non_persistent_buffers_set:
+            continue
+        yield param_name, param
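To make the behavior of the new helper concrete, here is a small self-contained sketch; the `Block` module is invented for illustration, and the helper is re-declared locally so the snippet runs on its own.

```python
import itertools

import torch
import torch.nn as nn


def _named_parameters_and_buffers_to_load(module: nn.Module):
    # Same logic as the helper added above: direct parameters plus persistent buffers only.
    for name, tensor in itertools.chain(module.named_buffers(recurse=False), module.named_parameters(recurse=False)):
        if name in module._non_persistent_buffers_set:
            continue
        yield name, tensor


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2))
        self.register_buffer("running_stats", torch.zeros(2), persistent=True)
        self.register_buffer("scratch", torch.zeros(2), persistent=False)


block = Block()
print([name for name, _ in _named_parameters_and_buffers_to_load(block)])
# -> ['running_stats', 'weight']; 'scratch' is skipped because it is a non-persistent
# buffer and therefore never appears in a checkpoint's state dict.
```

The loading loop above then prefixes each yielded name with its submodule path, so a parent module holding `self.block = Block()` would look up "block.weight" and "block.running_stats" in the full state dict.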
Tests for `ModelParallelStrategy` (file path not shown in this view)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from pathlib import Path
 from unittest import mock

@@ -20,7 +21,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from lightning.fabric import Fabric
-from lightning.fabric.strategies import ModelParallelStrategy
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state
 from lightning.fabric.utilities.load import _load_distributed_checkpoint
 from torch.utils.data import DataLoader, DistributedSampler

@@ -675,3 +676,46 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):

     state = {"model": model, "steps": 1}
     fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_load_raw_module_state():
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.parameter = nn.Parameter(torch.rand(2, 2))
+            self.layer1 = nn.Linear(4, 4)
+            self.layer2 = nn.Linear(4, 4)
+            self.register_buffer("persistent_buffer", torch.rand(2), persistent=True)
+            self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False)
+
+    fabric = Fabric(accelerator="cuda", devices=2)
+    fabric.launch()
+    fabric.seed_everything(0)
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    state_dict = deepcopy(model.state_dict())
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",))
+    plan = {"layer1": ColwiseParallel()}
+    parallelize_module(model, device_mesh, plan)
+    _load_raw_module_state(state_dict, model, strict=True)
+
+    assert torch.equal(model.parameter, state_dict["parameter"])
+    assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"])
+    assert torch.equal(model.layer2.weight, state_dict["layer2.weight"])
+    assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"])
+
+    state_dict.pop("parameter")
+    with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"):
+        _load_raw_module_state(state_dict, model, strict=True)
+
+    _load_raw_module_state(state_dict, model, strict=False)
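As a closing note on terminology: a "full-state" checkpoint in this PR means a single file holding the complete, unsharded state (for example written by `torch.save`), as opposed to a directory of per-rank shards produced with `torch.distributed.checkpoint`. A tiny sketch of the single-file case follows; the path and model are placeholders.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)

# A full-state checkpoint: one file, complete (unsharded) tensors.
torch.save({"model": model.state_dict()}, "full_checkpoint.pt")

# The loading path changed in this PR memory-maps such a file on CPU first,
# mirroring the `torch.load(path, mmap=True, map_location="cpu")` call in `_load_checkpoint`.
checkpoint = torch.load("full_checkpoint.pt", mmap=True, map_location="cpu")
model.load_state_dict(checkpoint["model"])
```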