Fix save/load/resume from checkpoint for DeepSpeed Plugin #8397

Merged 75 commits into master from fix/ds_saving_2 on Aug 2, 2021
Commits (75)
24a3e50
wip
Jul 7, 2021
03a8769
Change trainer loading behaviour for validate/test/predict
Jul 9, 2021
a943e33
Fix
Jul 9, 2021
40a3446
Fix/add tests
Jul 9, 2021
8c24ffd
remove
Jul 9, 2021
1879be7
Cleanups
Jul 12, 2021
3162ff7
Space
Jul 12, 2021
6dd61d6
cleanups
Jul 12, 2021
5772e17
Merge branch 'master' into feat/ckpt_load
Jul 12, 2021
b072868
Add CHANGELOG.md
Jul 12, 2021
de2738d
Merge branch 'master' into feat/ckpt_load
Jul 12, 2021
6910e39
Merge branch 'feat/ckpt_load' into fix/ds_saving_2
Jul 12, 2021
bf5afe3
Fix
Jul 12, 2021
f2ee8b5
Move after setup
Jul 12, 2021
b7c24d9
Merge branch 'feat/ckpt_load' into fix/ds_saving_2
Jul 12, 2021
8659426
Cleanups on logic
Jul 12, 2021
84d20f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
9e367fd
Remve
Jul 12, 2021
6ea8b44
Merge branch 'feat/ckpt_load' into fix/ds_saving_2
Jul 12, 2021
3f8c3d3
Remve
Jul 12, 2021
b8ffc39
fix test
Jul 12, 2021
b02f35b
feedback
Jul 12, 2021
dbb03af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
1c7b9a1
Update pytorch_lightning/trainer/properties.py
Jul 12, 2021
444fb55
Feedback
Jul 12, 2021
4632bba
Same fix
Jul 12, 2021
e92b757
Same fix
Jul 12, 2021
66bea8e
Add test for behaviour, modify based on feedback
Jul 12, 2021
0139a19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
d48d916
Wording
Jul 12, 2021
100d73b
Apply suggestions from code review
Jul 12, 2021
f3f92a5
Cleanup docs
Jul 12, 2021
2849d0b
Update pytorch_lightning/trainer/trainer.py
Jul 12, 2021
f53c896
feedback
Jul 12, 2021
ebc713b
Fixes to test API
Jul 12, 2021
76e22c2
Add carlos description
Jul 12, 2021
9a62650
Merge branch 'feat/ckpt_load' into fix/ds_saving_2
Jul 12, 2021
0b46226
Fixes
Jul 13, 2021
7a85b44
Merge branch 'master' into fix/ds_saving_2
Jul 13, 2021
8042fb4
Changes
Jul 13, 2021
203fd49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 13, 2021
8d0f260
Try delaying
Jul 14, 2021
d4e2295
Merge branch 'master' into fix/ds_saving_2
Jul 14, 2021
28d7575
Fixes
Jul 27, 2021
32f73e4
Merge branch 'master' into fix/ds_saving_2
Jul 27, 2021
a3c6009
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 27, 2021
857a6aa
Merge branch 'master' into fix/ds_saving_2
Jul 27, 2021
e6c3bd1
Merge branch 'master' into fix/ds_saving_2
Jul 28, 2021
c51033a
fixes
Jul 28, 2021
4f5bd96
Add extra condition
Jul 28, 2021
e1fb2f0
Fix
Jul 28, 2021
77036a2
Fix
Jul 28, 2021
82e00be
Attempt to fix tests
Jul 28, 2021
57355aa
Add guard
Jul 28, 2021
3fc8f67
Fix test
Jul 29, 2021
6adb83d
Fix
Jul 29, 2021
607aef2
Add test
Jul 29, 2021
0c30656
Update pytorch_lightning/plugins/training_type/deepspeed.py
Jul 29, 2021
c9849e0
Fix description
Jul 29, 2021
0d3866c
Add test
Jul 29, 2021
fd7a168
Fix test
Jul 29, 2021
256b145
Refactors
Jul 29, 2021
c189595
add recursive
Jul 29, 2021
670810f
Merge branch 'master' into fix/ds_saving_2
Jul 30, 2021
64a4eba
Merge branch 'master' into fix/ds_saving_2
Aug 2, 2021
0d2ec03
Fix dupe
Aug 2, 2021
ef33d90
Merge branch 'master' into fix/ds_saving_2
Aug 2, 2021
5329c48
Force 0.4.3
Aug 2, 2021
95d1287
Address reviews
Aug 2, 2021
88ab306
Add todo
Aug 2, 2021
a15cd8d
Update pytorch_lightning/plugins/training_type/training_type_plugin.py
Aug 2, 2021
9365cc0
Apply suggestions from code review
Aug 2, 2021
5f994c4
Add asserts for properties, address reviews
Aug 2, 2021
cdf8c25
Fix description
Aug 2, 2021
c47abf2
Merge branch 'master' into fix/ds_saving_2
Aug 2, 2021
2 changes: 1 addition & 1 deletion .azure-pipelines/gpu-tests.yml
@@ -51,7 +51,7 @@ jobs:
- bash: |
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
pip install fairscale>=0.3.4
pip install "deepspeed>=0.4.0, !=0.4.4" # FIXME: bug with 0.4.4
pip install "deepspeed>=0.4.3, !=0.4.4" # FIXME: bug with 0.4.4
pip install . --requirement requirements/devel.txt
pip list
displayName: 'Install dependencies'
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -88,7 +88,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613))


-
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
[#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397),
[#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644),
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


## [1.4.0] - 2021-07-27
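For context on the changelog entry above, here is a minimal sketch of the save/resume flow this PR fixes. The toy `RandomDataset`/`LitModel` and the exact `Trainer` arguments are illustrative only (1.4-era API with `resume_from_checkpoint`), and it assumes a multi-GPU machine with DeepSpeed installed:

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(RandomDataset(), batch_size=8)

# First run: under ZeRO stage 3 the checkpoint is saved as a directory of per-rank shards.
trainer = pl.Trainer(gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3), max_epochs=1)
trainer.fit(LitModel(), train_loader)

# Resume: point a new Trainer at the saved checkpoint path (a directory for stage 3).
ckpt_path = trainer.checkpoint_callback.best_model_path
trainer = pl.Trainer(
    gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3), max_epochs=2, resume_from_checkpoint=ckpt_path
)
trainer.fit(LitModel(), train_loader)
```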
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -525,7 +525,7 @@ def save_function(self, value: Optional[Callable]) -> None:

def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
self._fs.rm(filepath)
self._fs.rm(filepath, recursive=True)
log.debug(f"Removed checkpoint: {filepath}")

def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
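The `recursive=True` change above matters because a ZeRO stage 3 checkpoint is a directory of shard files rather than a single file. A small standalone sketch of the same fsspec call (the path is hypothetical):

```python
import fsspec

fs = fsspec.filesystem("file")
# Under ZeRO stage 3 this ".ckpt" path is a directory containing per-rank shards.
ckpt_path = "lightning_logs/checkpoints/epoch=0-step=99.ckpt"  # hypothetical

if fs.exists(ckpt_path):
    # A non-recursive rm can fail on a non-empty directory; recursive=True removes the shards too.
    fs.rm(ckpt_path, recursive=True)
```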
138 changes: 97 additions & 41 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -29,13 +29,16 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, rank_zero_warn, WarningCache

warning_cache = WarningCache()

if _DEEPSPEED_AVAILABLE:
import deepspeed
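The newly imported `WarningCache` de-duplicates repeated warnings; roughly, it behaves like this (the message text is made up for illustration):

```python
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()

for step in range(3):
    # The same message is only emitted the first time this cache sees it.
    warning_cache.warn("ZeRO stage 3 saves a sharded checkpoint directory per worker")
```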
@@ -119,7 +122,7 @@ def __init__(
cpu_checkpointing: bool = False,
contiguous_memory_optimization: bool = False,
synchronize_checkpoint_boundary: bool = False,
save_full_weights: bool = True,
load_full_weights: bool = False,
cpu_offload: bool = False,
cpu_offload_params: bool = False,
cpu_offload_use_pin_memory: bool = False,
@@ -250,10 +253,9 @@ def __init__(

synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.

save_full_weights: Gathers weights across all processes before saving to disk
when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
rather than individual sharded weight files.
Disable to save sharded states individually.
load_full_weights: True when loading a single checkpoint file containing the model state dict
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
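To illustrate the renamed argument above (`save_full_weights` → `load_full_weights`): `load_full_weights=True` is for loading a single consolidated state-dict file instead of DeepSpeed's sharded checkpoint directory. A sketch reusing the toy `LitModel` and `train_loader` from the earlier example; the checkpoint path and GPU count are placeholders:

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = pl.Trainer(
    gpus=4,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3, load_full_weights=True),
    resume_from_checkpoint="consolidated_model.ckpt",  # a single file, not a sharded directory
)
# Fine-tuning from consolidated weights; optimizer/scheduler state cannot be restored
# from a single file (see the warning added further down in this diff).
trainer.fit(LitModel(), train_loader)
```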
@@ -313,7 +315,7 @@ def __init__(
deepspeed.utils.logging.logger.setLevel(logging_level)

self.remote_device = remote_device
self.save_full_weights = save_full_weights
self.load_full_weights = load_full_weights

# default FP16 parameters.
self.loss_scale = loss_scale
@@ -365,6 +367,10 @@ def _set_node_environment_variables(
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_RANK"] = str(self.local_rank)

@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return True

def pre_dispatch(self):
self.init_deepspeed()
self.barrier()
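The new `restore_checkpoint_after_pre_dispatch` property asks the Trainer to restore the checkpoint only after `pre_dispatch`, i.e. once `init_deepspeed()` has built the engine. Assuming the base `TrainingTypePlugin` exposes the same hook (the base-class side is not shown in this excerpt), a custom plugin could opt in like this:

```python
from pytorch_lightning.plugins import DDPPlugin


class DelayedRestoreDDPPlugin(DDPPlugin):
    """Hypothetical plugin that delays checkpoint restoration until after
    pre_dispatch, when the wrapped model actually exists."""

    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        return True
```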
@@ -657,43 +663,36 @@ def _create_default_config(
cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
return cfg

def _filepath_to_dir(self, filepath: str) -> str:
return os.path.dirname(filepath)

@property
def deepspeed_engine(self):
return self.model

@property
def _multi_device(self) -> bool:
return self.num_processes > 1 or self.num_nodes > 1

def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: The checkpoint state dictionary
filepath: write-target file's path
"""
if self.world_size > 1 and self.zero_stage_3:
if self.save_full_weights:
# todo: expose this as general function in deepspeed
state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
if self.is_global_zero:
# State dict keys will include reference to wrapper LightningDeepSpeedModule
# Delete `module` prefix before saving.
state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
checkpoint["state_dict"] = state_dict
return super().save_checkpoint(checkpoint, filepath)
return

# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
save_dir = self._filepath_to_dir(filepath)
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
else:
super().save_checkpoint(checkpoint, filepath)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
if self.save_full_weights or self.world_size == 1:
if self.zero_stage_3 and self._multi_device and self.is_global_zero:
# todo (sean): Add link to docs once docs are merged.
warning_cache.warn(
"When saving the DeepSpeed Stage 3 checkpoint, "
"each worker will save a shard of the checkpoint within a directory."
"If a single file is required after training, see <TODO> for instructions."
)
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.broadcast(checkpoint_path)
@@ -703,20 +702,77 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
from pytorch_lightning.trainer.states import TrainerFn

is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
save_dir = self._filepath_to_dir(checkpoint_path)

if self.zero_stage_3:
# TODO: Currently required as this call is missing within the deepspeed engine.
self.deepspeed_engine.optimizer._partition_all_parameters()

_, client_state = self.deepspeed_engine.load_checkpoint(
save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
)
if client_state is None:
raise MisconfigurationException(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
"or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`."
)
return client_state

@property
def lightning_restore_optimizer_and_schedulers(self) -> bool:
# managed by DeepSpeed
if self.load_full_weights and self.zero_stage_3 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
rank_zero_warn(
"A single checkpoint file has been given. This means optimizer states and "
"scheduler states can not be restored. If you'd like to restore these states, you must "
"provide a path to the originally saved DeepSpeed checkpoint."
)
return False

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
pass
if self.load_full_weights and self.zero_stage_3:
self.model_to_device()
self._restore_zero_state(checkpoint)

def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
"""
Overrides the normal load_state_dict behaviour in PyTorch to ensure
we gather parameters that may be sharded across processes before loading
the state dictionary when using ZeRO stage 3.
This is then automatically synced across processes.
Args:
ckpt: The ckpt file.
"""

def load(module: torch.nn.Module, prefix=""):

missing_keys = []
unexpected_keys = []
error_msgs = []
state_dict = ckpt["state_dict"]

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if self.is_global_zero:
module._load_from_state_dict(
state_dict=state_dict,
prefix=prefix,
local_metadata=local_metadata,
strict=True,
missing_keys=missing_keys,
unexpected_keys=unexpected_keys,
error_msgs=error_msgs,
)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")

load(self.lightning_module, prefix="")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
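The `<TODO>` in the stage 3 saving warning above refers to consolidating the sharded checkpoint into a single file after training. One possible route, not part of this PR, is DeepSpeed's own `zero_to_fp32` utilities (a sketch; paths are hypothetical):

```python
from deepspeed.utils.zero_to_fp32 import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_fp32_state_dict_from_zero_checkpoint,
)

# The ".ckpt" path saved by the plugin is a directory of shards under ZeRO stage 3.
ckpt_dir = "lightning_logs/checkpoints/epoch=1-step=999.ckpt"  # hypothetical

# Write a single fp32 state-dict file...
convert_zero_checkpoint_to_fp32_state_dict(ckpt_dir, "consolidated_weights.pt")

# ...or build the consolidated state dict in memory.
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
```

Note that these utilities produce a raw model state dict; depending on the Lightning version, it may need to be wrapped as `{"state_dict": ...}` before it can be consumed with `load_full_weights=True`.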
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -231,7 +231,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
Override to delay setting optimizers and schedulers till after dispatch.
This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
However this may break certain precision plugins such as APEX which require optimizers to be set.
Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.

Returns:
If True, delay setup optimizers till pre_dispatch, else call within setup.
"""
return False

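For reference, the flag documented above changes when optimizers are created. Simplified from the accelerator logic (function signatures here are illustrative, not the real API):

```python
def setup(accelerator, plugin, trainer):
    if not plugin.setup_optimizers_in_pre_dispatch:
        # Default: optimizers/schedulers are created during setup.
        accelerator.setup_optimizers(trainer)


def pre_dispatch(accelerator, plugin, trainer):
    plugin.pre_dispatch()
    if plugin.setup_optimizers_in_pre_dispatch:
        # Delayed: created only after the plugin has wrapped the model (e.g. the DeepSpeed engine).
        accelerator.setup_optimizers(trainer)
```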