Commit
Revert "Support serialized checkpoint loading (#9605)" (#10057)
This reverts commit f0e6f1b. The reverted change (#9605) staggered checkpoint loading across a node's DDP processes, one local rank at a time, to avoid CPU out-of-memory errors; reverting restores simultaneous loading on all ranks and moves the `on_load_checkpoint` hook call back into the checkpoint connector.
jjenniferdai authored Oct 21, 2021
1 parent aa15404 commit 2d9db21
Showing 6 changed files with 10 additions and 39 deletions.
36 changes: 3 additions & 33 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -21,7 +21,7 @@
 import time
 from pathlib import Path
 from time import sleep
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import __main__
 import numpy as np
@@ -51,16 +51,10 @@
 )
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
-from pytorch_lightning.utilities.distributed import (
-    init_ddp_connection,
-    rank_zero_info,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
     from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -129,7 +123,6 @@ def __init__(
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self._rank_0_has_called_call_children_scripts: bool = False
-        self._has_loaded_state_dict: bool = False
         self.set_world_ranks()
 
     @property
@@ -540,26 +533,3 @@ def teardown(self) -> None:
             self.lightning_module.cpu()
             # clean up memory
             torch.cuda.empty_cache()
-
-        self._has_loaded_state_dict = False
-
-    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        if "state_dict" not in checkpoint and self._has_loaded_state_dict:
-            return
-        self.lightning_module.load_state_dict(checkpoint["state_dict"])
-
-    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
-        rank_zero_info(
-            f"DistributedDataParallel has {self.num_processes} processes. "
-            "Serializing checkpoint loading to avoid CPU OOMs."
-        )
-        for current_worker in range(self.num_processes):
-            if self.local_rank == current_worker:
-                checkpoint = super().load_checkpoint(checkpoint_path)
-                self.lightning_module.on_load_checkpoint(checkpoint)
-                self.load_model_state_dict(checkpoint)
-                del checkpoint["state_dict"]
-                self._has_loaded_state_dict = True
-                log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
-            self.barrier()
-        return checkpoint
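
For context, the removed `load_checkpoint` override staggered deserialization across a node's processes so that only one rank reads the checkpoint file at a time, bounding peak host memory. A minimal standalone sketch of that pattern in plain torch.distributed, outside Lightning (the helper name `load_checkpoint_staggered`, its arguments, and the "state_dict" key layout are illustrative assumptions):

import torch
import torch.distributed as dist

def load_checkpoint_staggered(
    checkpoint_path: str,
    model: torch.nn.Module,
    local_rank: int,
    local_world_size: int,
) -> dict:
    """Deserialize a checkpoint one local rank at a time to cap peak CPU memory.

    Assumes the default process group is already initialized, e.g. via
    dist.init_process_group(). Illustrative sketch, not Lightning's API.
    """
    checkpoint = None
    for turn in range(local_world_size):
        if local_rank == turn:
            # Only one process on this node holds a full copy in host memory.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["state_dict"])
            # Drop the weights once applied so the returned dict stays small.
            del checkpoint["state_dict"]
        dist.barrier()  # everyone waits before the next rank starts reading
    return checkpoint

The trade-off is latency for memory: with N processes per node, loading is serialized N ways, which is why the reverted code logged "Serializing checkpoint loading to avoid CPU OOMs."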
2 changes: 0 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -781,8 +781,6 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
         return False
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        if "state_dict" not in checkpoint and self._has_loaded_state_dict:
-            return
         # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
         if self.load_full_weights and self.zero_stage_3:
             self.model_to_device()
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -195,7 +195,6 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         return self.checkpoint_io.load_checkpoint(checkpoint_path)
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        self.lightning_module.on_load_checkpoint(checkpoint)
         self.lightning_module.load_state_dict(checkpoint["state_dict"])
 
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -141,6 +141,9 @@ def restore_model(self) -> None:
 
         model = self.trainer.lightning_module
 
+        # hook: give user access to checkpoint if needed.
+        model.on_load_checkpoint(self._loaded_checkpoint)
+
         # call hpc specific hook
         if self.hpc_resume_path is not None:
             model.on_hpc_load(self._loaded_checkpoint)
@@ -160,6 +163,7 @@ def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
         if checkpoint_path is not None:
             checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
 
+        self.trainer.lightning_module.on_load_checkpoint(checkpoint)
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
 
     def restore_training_state(self) -> None:
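
These two additions restore the earlier flow in which the LightningModule's `on_load_checkpoint` hook fires from the checkpoint connector while the full checkpoint dict is available, rather than from inside the plugin's `load_checkpoint`. On the user side the hook is a standard override; a minimal sketch (the "my_custom_state" key is an illustrative assumption):

from typing import Any, Dict

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Persist extra state alongside the model weights.
        checkpoint["my_custom_state"] = getattr(self, "my_custom_state", None)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Receives the full checkpoint dict during restore; read back
        # whatever on_save_checkpoint stored.
        self.my_custom_state = checkpoint.get("my_custom_state")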
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -1029,6 +1029,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self.data_connector.prepare_data()
         self.callback_connector._attach_model_callbacks()
 
+        if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
+            self._load_checkpoint_weights()
+
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
@@ -1038,8 +1041,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
 
         # check if we should delay restoring checkpoint till later
         if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
-            if self._ckpt_path:
-                self._load_checkpoint_weights()
             self.checkpoint_connector.resume_start()
             self._restore_modules_and_callbacks()
 
1 change: 0 additions & 1 deletion tests/trainer/connectors/test_checkpoint_connector.py
@@ -78,7 +78,6 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
     ckpt_path = trainer.checkpoint_callback.best_model_path
     trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
     connector = trainer.checkpoint_connector
-    trainer.accelerator.connect(model)
     connector.resume_start()
     assert connector.resume_checkpoint_path == ckpt_path
     assert connector._loaded_checkpoint
