
memory-optimized loading of full checkpoints into dist model
awaelchli committed May 15, 2024
1 parent 9455871 commit 1a0887e
Showing 1 changed file with 8 additions and 9 deletions.
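
The change keeps host memory bounded when loading a full, single-file (non-distributed) checkpoint into a distributed model. Two ingredients are visible in the diff below: the checkpoint is opened with `torch.load(path, mmap=True, map_location="cpu")`, so tensor data stays on disk until touched, and the state dict is then applied to the model one parameter at a time instead of through a single whole-model `set_model_state_dict` call. A minimal sketch of the mmap ingredient (the file name and key are placeholders, not from the commit):

    import torch

    # mmap=True maps the tensor storages in the file instead of reading the
    # whole checkpoint into RAM; pages are faulted in lazily on first access.
    state_dict = torch.load("checkpoint.pt", mmap=True, map_location="cpu")

    # Only the pages backing this one tensor are actually read from disk here.
    weight = state_dict["transformer.h.0.attn.weight"]  # placeholder key
    print(weight.shape)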
src/lightning/fabric/strategies/model_parallel.py (8 additions, 9 deletions)
@@ -429,7 +429,6 @@ def _load_checkpoint(
             StateDictOptions,
             get_model_state_dict,
             get_optimizer_state_dict,
-            set_model_state_dict,
             set_optimizer_state_dict,
         )

@@ -484,13 +483,8 @@ def _load_checkpoint(
         if not _TORCH_GREATER_EQUAL_2_4:
             raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
 
-        state_dict_options = StateDictOptions(
-            broadcast_from_rank0=True,  # type: ignore[call-arg]
-            full_state_dict=True,
-            strict=strict,
-        )
         checkpoint = torch.load(path, mmap=True, map_location="cpu")
-        set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options)
+        _load_raw_module_state(checkpoint.pop(module_key), module, world_size=self.world_size, strict=strict)
 
         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
@@ -535,8 +529,13 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True) -> None:
 
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
 
-        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True)  # type: ignore[call-arg]
-        set_model_state_dict(module, state_dict, options=state_dict_options)
+        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, strict=strict)  # type: ignore[call-arg]
+
+        for submodule_name, submodule in module.named_modules():
+            for param_name, _ in submodule.named_parameters(recurse=False):
+                full_param_name = f"{submodule_name}.{param_name}"
+                local_state_dict = {param_name: state_dict[full_param_name]}
+                set_model_state_dict(submodule, local_state_dict, options=state_dict_options)
 
     elif isinstance(module, FSDP):
         with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
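
Putting the pieces together, the new `_load_raw_module_state` path amounts to the following pattern, shown here as a self-contained sketch rather than Lightning's actual API: `load_full_checkpoint` is an illustrative name, and the sketch assumes PyTorch >= 2.4, an initialized process group, and a module whose parameters are already distributed (e.g. as DTensors):

    from typing import Any, Dict

    import torch
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
    from torch.nn import Module

    def load_full_checkpoint(path: str, module: Module, strict: bool = True) -> None:
        # Memory-map the single-file checkpoint; nothing is read eagerly.
        state_dict: Dict[str, Any] = torch.load(path, mmap=True, map_location="cpu")

        options = StateDictOptions(
            broadcast_from_rank0=True,  # rank 0 supplies full tensors; other ranks receive their shards
            full_state_dict=True,
            strict=strict,
        )

        # Apply the checkpoint one parameter at a time: each call materializes
        # and broadcasts a single tensor, so peak host memory is bounded by the
        # largest parameter rather than by the whole model.
        for submodule_name, submodule in module.named_modules():
            for param_name, _ in submodule.named_parameters(recurse=False):
                full_param_name = f"{submodule_name}.{param_name}"
                local_state_dict = {param_name: state_dict[full_param_name]}
                set_model_state_dict(submodule, local_state_dict, options=options)

Compared with the removed single `set_model_state_dict(module, ...)` call, which required the entire full state dict to be resident at once, the per-submodule loop trades more (but cheap) collective calls for a roughly parameter-sized peak allocation.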
