Support serialized checkpoint loading #9605

Merged (17 commits, Oct 20, 2021)
Changes from 6 commits
24 changes: 22 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -21,7 +21,7 @@
import time
from pathlib import Path
from time import sleep
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import __main__
import numpy as np
@@ -51,13 +51,14 @@
from pytorch_lightning.utilities.distributed import (
distributed_available,
init_ddp_connection,
rank_zero_info,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -127,6 +128,7 @@ def __init__(
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_has_called_call_children_scripts: bool = False
self._self_deleted_checkpoint_state_dict: bool = False
self.set_world_ranks()

@property
@@ -535,3 +537,21 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

self._self_deleted_checkpoint_state_dict = False

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
if "state_dict" not in checkpoint and self._self_deleted_checkpoint_state_dict:
return
self.lightning_module.load_state_dict(checkpoint["state_dict"])

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
rank_zero_info(f"DistributedDataParallel has {self.num_processes} processes. Serializing to avoid CPU OOMs.")
for current_worker in range(self.num_processes):
if self.local_rank == current_worker:
checkpoint = super().load_checkpoint(checkpoint_path)
del checkpoint["state_dict"]
self._self_deleted_checkpoint_state_dict = True
log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
self.barrier()
return checkpoint
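
The hunk above is the core of the change: DDP ranks take turns calling the parent `load_checkpoint`, so only one process materializes the full checkpoint in CPU memory at a time, and each rank drops the model `state_dict` from its in-memory copy once the weights have been applied. Below is a minimal standalone sketch of the same pattern written against plain `torch.distributed` rather than the plugin API; the function and variable names are illustrative, not taken from the PR.

    # Rough sketch of rank-serialized checkpoint loading (assumes the
    # default process group is already initialized, e.g. via DDP launch).
    import torch
    import torch.distributed as dist

    def load_checkpoint_serialized(model, checkpoint_path, num_processes, local_rank):
        checkpoint = None
        for current_worker in range(num_processes):
            if local_rank == current_worker:
                # only this rank holds the full checkpoint in CPU memory right now
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
                model.load_state_dict(checkpoint["state_dict"])
                # drop the largest part of the checkpoint before the next rank loads
                del checkpoint["state_dict"]
            dist.barrier()  # everyone waits for the current worker to finish
        return checkpoint
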
11 changes: 0 additions & 11 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -141,17 +141,6 @@ def wrap_policy(*args, **kwargs):
):
yield

def setup_environment(self) -> None:
super().setup_environment()
model_call_configure_sharded_model_hook = getattr(
self.lightning_module, "call_configure_sharded_model_hook", False
)
if not model_call_configure_sharded_model_hook:
# if model has not called configure sharded model, we reset
# the training type plugin's call_configure_sharded_model_hook
# to give trainer a chance to configure.
self.call_configure_sharded_model_hook = True

def configure_ddp(self) -> None:
if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us.
@@ -153,7 +153,10 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)
checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path)
self.lightning_module.on_load_checkpoint(checkpoint)
self.load_model_state_dict(checkpoint)
return checkpoint

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
self.lightning_module.load_state_dict(checkpoint["state_dict"])
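
Read together with the DDP override earlier in the diff, the base implementation above is the default path; the DDP plugin's `load_model_state_dict` returns early when its serialized `load_checkpoint` has already applied and deleted the `state_dict`, so callers can invoke it unconditionally. A condensed, simplified view of the two implementations (class names match the real plugins, but the bodies are trimmed to just these methods):

    class TrainingTypePlugin:  # base behaviour
        def load_model_state_dict(self, checkpoint):
            self.lightning_module.load_state_dict(checkpoint["state_dict"])

    class DDPPlugin(TrainingTypePlugin):  # override added by this PR
        def load_model_state_dict(self, checkpoint):
            # skip when the rank-serialized load_checkpoint already consumed the weights
            if "state_dict" not in checkpoint and self._self_deleted_checkpoint_state_dict:
                return
            super().load_model_state_dict(checkpoint)
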
12 changes: 1 addition & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -136,16 +136,10 @@ def restore_model(self) -> None:

model = self.trainer.lightning_module

# hook: give user access to checkpoint if needed.
model.on_load_checkpoint(self._loaded_checkpoint)

# call hpc specific hook
if self.hpc_resume_path is not None:
model.on_hpc_load(self._loaded_checkpoint)

# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

# reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing.
if not self.trainer.is_global_zero:
for module in self.trainer.lightning_module.modules():
@@ -154,12 +148,8 @@ def restore_model(self) -> None:

def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
self._load_and_validate_checkpoint(checkpoint_path)

def restore_training_state(self) -> None:
"""Restore the trainer state from the pre-loaded checkpoint.
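
The connector no longer calls the model hooks itself; loading (and validating) the checkpoint is enough because the training-type plugin's `load_checkpoint` now runs `on_load_checkpoint` and `load_model_state_dict`. Roughly, the weight-restore call chain after this PR looks like the sketch below; method names are taken from the diff, but the internals are simplified and `_load_and_validate_checkpoint` is assumed to delegate to the plugin's `load_checkpoint`.

    # CheckpointConnector.restore_model_weights(path)
    #   -> CheckpointConnector._load_and_validate_checkpoint(path)
    #        -> TrainingTypePlugin.load_checkpoint(path)
    #             -> CheckpointIO.load_checkpoint(path)              # read the file
    #             -> LightningModule.on_load_checkpoint(checkpoint)  # user hook
    #             -> TrainingTypePlugin.load_model_state_dict(checkpoint)
    #                  -> LightningModule.load_state_dict(checkpoint["state_dict"])
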
20 changes: 5 additions & 15 deletions pytorch_lightning/trainer/trainer.py
@@ -976,9 +976,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.data_connector.prepare_data()
self.callback_connector._attach_model_callbacks()

if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
self._load_checkpoint_weights()

# ----------------------------
# SET UP TRAINING
# ----------------------------
@@ -988,6 +985,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,

# check if we should delay restoring checkpoint till later
if not self.accelerator.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path:
self._load_checkpoint_weights()
self._restore_modules_and_callbacks()

self._call_configure_sharded_model() # allow user to setup in model sharded environment
@@ -1278,18 +1277,9 @@ def _call_setup_hook(self) -> None:
self.accelerator.barrier("post_setup")

def _call_configure_sharded_model(self) -> None:
# Call configure sharded model hook if accelerator requests. In some cases
# we will not call the hook; the hook has initialized the sharded model for example.

# used on the model if the user re-create a trainer with resume_from_checkpoint
model = self.lightning_module
model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
with self.accelerator.model_sharded_context():
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")
model.call_configure_sharded_model_hook = True
self.accelerator.call_configure_sharded_model_hook = False
with self.accelerator.model_sharded_context():
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn
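
With the `call_configure_sharded_model_hook` bookkeeping gone, the trainer now runs `configure_sharded_model` inside the accelerator's `model_sharded_context()` every time `_call_configure_sharded_model` is reached, so a module's hook may execute more than once (the FSDP test further down guards for exactly this). A hedged example of an idempotent hook, with class and layer names of my own choosing:

    import torch
    import pytorch_lightning as pl

    class MyShardedModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = None  # built lazily inside the sharded context

        def configure_sharded_model(self):
            # guard so repeated calls do not rebuild (and re-wrap) the layers
            if self.layer is None:
                self.layer = torch.nn.Sequential(
                    torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)
                )
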
55 changes: 0 additions & 55 deletions tests/accelerators/test_common.py
@@ -16,7 +16,6 @@

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin
from tests.accelerators.test_dp import CustomClassificationModelDP
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
@@ -77,57 +76,3 @@ def configure_sharded_model(self):
trainer.fit(model)

assert model.configure_sharded_model_called


class DummyModel(BoringModel):
def __init__(self):
super().__init__()
self.configure_sharded_model_called = False

def configure_sharded_model(self):
self.configure_sharded_model_called = True


def test_configure_sharded_model_false(tmpdir):
"""Ensure ``configure_sharded_model`` is not called, when turned off."""

class CustomPlugin(SingleDevicePlugin):
@property
def call_configure_sharded_model_hook(self) -> bool:
return False

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
plugins=CustomPlugin(device=torch.device("cpu")),
)
trainer.fit(model)

assert not model.configure_sharded_model_called


def test_accelerator_configure_sharded_model_called_once(tmpdir):
"""Ensure that the configure sharded model hook is called, and set to False after to ensure not called
again."""

model = DummyModel()
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
assert trainer.accelerator.call_configure_sharded_model_hook is True
trainer.fit(model)
assert trainer.accelerator.call_configure_sharded_model_hook is False


def test_configure_sharded_model_called_once(tmpdir):
"""Ensure ``configure_sharded_model`` is only called once."""

model = DummyModel()
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
trainer.fit(model)

assert model.configure_sharded_model_called
model.configure_sharded_model_called = False

assert not model.configure_sharded_model_called
17 changes: 5 additions & 12 deletions tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py
@@ -54,18 +54,16 @@ def setup(self, stage: str) -> None:
# when running stages like test, validate, and predict, we will skip setting up,
# will directly use the module itself unless we load from checkpoint
return
# resetting call_configure_sharded_model_hook attribute so that we could call
# configure sharded model
self.call_configure_sharded_model_hook = False
# for loading full state dict, we first need to create a new unwrapped model
# to load state dict and then wrapping
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

def configure_sharded_model(self) -> None:
for i, layer in enumerate(self.layer):
if i % 2 == 0:
self.layer[i] = wrap(layer)
self.layer = wrap(self.layer)
if not isinstance(self.layer, FullyShardedDataParallel):
for i, layer in enumerate(self.layer):
if i % 2 == 0:
self.layer[i] = wrap(layer)
self.layer = wrap(self.layer)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# when loading full state dict, we first need to create a new unwrapped model
@@ -131,13 +129,8 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)

model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook

model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

assert model_call_configure_sharded_model_hook
assert not trainer_accelerator_call_configure_sharded_model_hook
trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=TestFSDPModel)
2 changes: 1 addition & 1 deletion tests/plugins/test_deepspeed_plugin.py
@@ -603,7 +603,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
run_checkpoint_test(tmpdir)


@RunIf(min_gpus=1, deepspeed=True, special=False)
@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
"""Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the
optimizer state and scheduler states cannot be restored."""
1 change: 1 addition & 0 deletions tests/trainer/connectors/test_checkpoint_connector.py
@@ -74,6 +74,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
ckpt_path = trainer.checkpoint_callback.best_model_path
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
connector = trainer.checkpoint_connector
trainer.accelerator.connect(model)
connector.resume_start()
assert connector.resume_checkpoint_path == ckpt_path
assert connector._loaded_checkpoint