From 1a0887ec98111def6a9863eb2e3c9f6a50e37bfb Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 15 May 2024 14:22:00 +0200
Subject: [PATCH] memory-optimized loading of full checkpoints into dist model

---
 .../fabric/strategies/model_parallel.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index 9cd721f930d1b..6200be8092c73 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -429,7 +429,6 @@ def _load_checkpoint(
             StateDictOptions,
             get_model_state_dict,
             get_optimizer_state_dict,
-            set_model_state_dict,
             set_optimizer_state_dict,
         )
 
@@ -484,13 +483,8 @@ def _load_checkpoint(
             if not _TORCH_GREATER_EQUAL_2_4:
                 raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
 
-            state_dict_options = StateDictOptions(
-                broadcast_from_rank0=True,  # type: ignore[call-arg]
-                full_state_dict=True,
-                strict=strict,
-            )
             checkpoint = torch.load(path, mmap=True, map_location="cpu")
-            set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options)
+            _load_raw_module_state(checkpoint.pop(module_key), module, world_size=self.world_size, strict=strict)
 
             requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
             _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
@@ -535,8 +529,13 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_siz
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
 
-        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True)  # type: ignore[call-arg]
-        set_model_state_dict(module, state_dict, options=state_dict_options)
+        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, strict=strict)  # type: ignore[call-arg]
+
+        for submodule_name, submodule in module.named_modules():
+            for param_name, _ in submodule.named_parameters(recurse=False):
+                full_param_name = f"{submodule_name}.{param_name}"
+                local_state_dict = {param_name: state_dict[full_param_name]}
+                set_model_state_dict(submodule, local_state_dict, options=state_dict_options)
 
     elif isinstance(module, FSDP):
         with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
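
Note (not part of the patch): the following is a minimal standalone sketch of the idea the patch implements, i.e. mmap the full checkpoint and push it into the distributed module one submodule at a time instead of one set_model_state_dict call over the whole model. It assumes PyTorch >= 2.4, an initialized process group, and placeholder names ("model", "full_checkpoint.pt") that are not taken from the patch; strict=False is used here because each call only covers a submodule's own parameters.

    import torch
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


    def load_full_checkpoint_low_memory(model: torch.nn.Module, path: str) -> None:
        # mmap keeps the checkpoint file on disk; tensor data is paged in lazily
        # instead of materializing the entire file in CPU memory up front.
        full_state_dict = torch.load(path, mmap=True, map_location="cpu")

        # broadcast_from_rank0=True: only rank 0 needs real values; the other
        # ranks receive their shards via broadcast and never hold full tensors.
        # strict=False since each call below covers only a slice of the state.
        options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, strict=False)

        # Apply the checkpoint one submodule at a time, so at most one
        # submodule's full parameters are resident in memory at any moment.
        for submodule_name, submodule in model.named_modules():
            local_state_dict = {}
            for param_name, _ in submodule.named_parameters(recurse=False):
                full_name = f"{submodule_name}.{param_name}" if submodule_name else param_name
                local_state_dict[param_name] = full_state_dict[full_name]
            if local_state_dict:
                set_model_state_dict(submodule, local_state_dict, options=options)


    # Example usage (hypothetical): load_full_checkpoint_low_memory(model, "full_checkpoint.pt")

Compared with a single set_model_state_dict(model, full_state_dict, ...) call, peak host memory drops from roughly the size of the whole model to roughly the size of its largest submodule.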