Skip to content

Commit

Permalink
Pipeline warnings and checkpoint portability (microsoft#588)
Browse files Browse the repository at this point in the history
* Switch from deprecated allreduce interface.

* Make pipeline checkpoint files portable.
  • Loading branch information
Shaden Smith authored Dec 8, 2020
1 parent e8b126d commit 2f62697
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
9 changes: 7 additions & 2 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop

from .pipe.module import PipelineModule
from .utils import ensure_directory_exists
from ..ops.op_builder import UtilsBuilder
from ..ops.adam import DeepSpeedCPUAdam
Expand Down Expand Up @@ -1355,6 +1356,10 @@ def _load_checkpoint(self,
logger.info(f'rank: {self.global_rank} loading checkpoint: {load_path}')
checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

if isinstance(self.module, PipelineModule):
# Pipeline parallelism uses this to load its own checkpoint files.
self._curr_ckpt_path = os.path.join(load_dir, tag)

self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
Expand Down Expand Up @@ -1522,8 +1527,8 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
save_path = self._get_ckpt_name(save_dir, tag)
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns self._curr_save_path.
self._curr_save_path = os.path.dirname(save_path)
# then instead just returns None.
self._curr_ckpt_path = os.path.join(save_dir, tag)

state = {
'module':
Expand Down
21 changes: 12 additions & 9 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def __init__(self, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

# We schedule the all-reduces, so disable it in super().backward()
self.enable_backward_allreduce = False

# pipeline step for logging
self.log_batch_step_id = -1

Expand Down Expand Up @@ -546,7 +549,7 @@ def _exec_backward_pass(self, buffer_id):
# The last stage just runs backward on the loss using DeepSpeed's typical
# mechanisms.
if self.is_last_stage():
super().backward(self.loss, allreduce_gradients=False)
super().backward(self.loss)
self.mem_status('AFTER BWD')
return

Expand Down Expand Up @@ -1100,31 +1103,31 @@ def module_state_dict(self):
is ``save_state_dict()``.
Returns:
str: The directory path where the checkpoint was saved.
None
"""
assert isinstance(self.module, PipelineModule)
assert self._curr_save_path is not None, \
assert self._curr_ckpt_path is not None, \
"PipelineEngine expects module_state_dict() to be called from save_checkpoint()"

self.module.save_state_dict(self._curr_save_path)
return self._curr_save_path
self.module.save_state_dict(self._curr_ckpt_path)
return None

def load_module_state_dict(self, state_dict, strict=True):
"""Override hack to instead use a directory path.
This is important because pipeline models checkpoint by layer instead of rank.
If ``state_dict`` is not a ``str``, we revert to ``super()`` expecting a ``dict``.
If ``state_dict`` is not ``None`` or a ``str``, we revert to ``super()`` expecting a ``dict``.
Args:
state_dict (str): Path to the directory for checkpoint.
state_dict (str, None): unused
strict (bool, optional): Strict state loading. Defaults to True.
"""
if not isinstance(state_dict, str):
if (state_dict is not None) and (not isinstance(state_dict, str)):
super().load_module_state_dict(state_dict, strict)
return

self.module.load_state_dir(state_dict, strict=strict)
self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict)

# A map of PipeInstruction types to methods. Each method will be executed with the
# kwargs provided to the PipeInstruction from the scheduler.
Expand Down

0 comments on commit 2f62697

Please sign in to comment.