Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support serialized checkpoint loading #9605

Merged
merged 17 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import time
from pathlib import Path
from time import sleep
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import __main__
import numpy as np
Expand Down Expand Up @@ -50,10 +50,16 @@
)
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import (
init_ddp_connection,
rank_zero_info,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
Expand Down Expand Up @@ -122,6 +128,7 @@ def __init__(
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_has_called_call_children_scripts: bool = False
self._self_deleted_checkpoint_state_dict: bool = False
self.set_world_ranks()

@property
Expand Down Expand Up @@ -530,3 +537,21 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

self._self_deleted_checkpoint_state_dict = False

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
if "state_dict" not in checkpoint and self._self_deleted_checkpoint_state_dict:
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved
return
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self.lightning_module.load_state_dict(checkpoint["state_dict"])

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
rank_zero_info(f"DistributedDataParallel has {self.num_processes} processes. Serializing to avoid CPU OOMs.")
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved
for current_worker in range(self.num_processes):
if self.local_rank == current_worker:
checkpoint = super().load_checkpoint(checkpoint_path)
del checkpoint["state_dict"]
self._self_deleted_checkpoint_state_dict = True
log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
self.barrier()
return checkpoint
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)
checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path)
self.lightning_module.on_load_checkpoint(checkpoint)
self.load_model_state_dict(checkpoint)
jjenniferdai marked this conversation as resolved.
Show resolved Hide resolved
return checkpoint

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
self.lightning_module.load_state_dict(checkpoint["state_dict"])
Expand Down
12 changes: 1 addition & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,10 @@ def restore_model(self) -> None:

model = self.trainer.lightning_module

# hook: give user access to checkpoint if needed.
model.on_load_checkpoint(self._loaded_checkpoint)

# call hpc specific hook
if self.hpc_resume_path is not None:
model.on_hpc_load(self._loaded_checkpoint)

# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

# reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing.
if not self.trainer.is_global_zero:
for module in self.trainer.lightning_module.modules():
Expand All @@ -156,12 +150,8 @@ def restore_model(self) -> None:

def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
if checkpoint_path is not None:
checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
self._load_and_validate_checkpoint(checkpoint_path)

def restore_training_state(self) -> None:
"""Restore the trainer state from the pre-loaded checkpoint.
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,9 +984,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.data_connector.prepare_data()
self.callback_connector._attach_model_callbacks()

if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
self._load_checkpoint_weights()

# ----------------------------
# SET UP TRAINING
# ----------------------------
Expand All @@ -996,6 +993,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,

# check if we should delay restoring checkpoint till later
if not self.accelerator.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path:
self._load_checkpoint_weights()
self._restore_modules_and_callbacks()

self._call_configure_sharded_model() # allow user to setup in model sharded environment
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/connectors/test_checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
ckpt_path = trainer.checkpoint_callback.best_model_path
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
connector = trainer.checkpoint_connector
trainer.accelerator.connect(model)
connector.resume_start()
assert connector.resume_checkpoint_path == ckpt_path
assert connector._loaded_checkpoint
Expand Down