Validate consistent ckpt tags across ranks (microsoft#667)
jeffra authored Jan 14, 2021
1 parent 981bc7d commit f032e56
Showing 5 changed files with 144 additions and 2 deletions.
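For orientation, a minimal sketch of how a user opts into the checkpoint-tag validation this commit adds. The "checkpoint" block mirrors the constants introduced below; the remaining keys are illustrative assumptions borrowed from the unit tests.

# Hedged example: a DeepSpeed config dict with the new "checkpoint" block.
# "tag_validation" accepts "WARN" (default), "IGNORE", or "FAIL", case-insensitively.
ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "checkpoint": {
        "tag_validation": "FAIL"
    }
}
# The dict is typically serialized to the JSON file passed via --deepspeed_config.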
2 changes: 1 addition & 1 deletion deepspeed/__init__.py
@@ -10,7 +10,7 @@
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.pipe.engine import PipelineEngine
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .utils import log_dist
24 changes: 24 additions & 0 deletions deepspeed/runtime/config.py
@@ -40,6 +40,10 @@
ADAM_W_MODE_PARAM = "adam_w_mode"


class DeepSpeedConfigError(Exception):
pass


def get_pld_enabled(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys():
return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
@@ -471,6 +475,21 @@ def get_tensorboard_job_name(param_dict):
return TENSORBOARD_JOB_NAME_DEFAULT


def get_checkpoint_params(param_dict):
return param_dict.get(CHECKPOINT, {})


def get_checkpoint_tag_validation_mode(checkpoint_params):
tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION,
CHECKPOINT_TAG_VALIDATION_DEFAULT)
tag_validation_mode = tag_validation_mode.upper()
if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES:
return tag_validation_mode
else:
raise DeepSpeedConfigError("Checkpoint config contains invalid tag_validation " \
f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")


'''Write deepspeed config files by modifying basic templates.
Can be used for quickly changing parameters via command line parameters.'''

@@ -627,6 +646,11 @@ def _initialize_params(self, param_dict):
self.pld_enabled = get_pld_enabled(param_dict)
self.pld_params = get_pld_params(param_dict)

checkpoint_params = get_checkpoint_params(param_dict)
validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE
self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL

def _batch_assertion(self):

train_batch = self.train_batch_size
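For illustration, a standalone sketch of the parsing rule the two new helpers above implement (a simplified re-statement rather than the DeepSpeed code itself; the constant values come from the constants.py hunk below):

# Sketch of get_checkpoint_params + get_checkpoint_tag_validation_mode behavior.
VALID_MODES = ["WARN", "IGNORE", "FAIL"]

def resolve_tag_validation_mode(param_dict):
    checkpoint_params = param_dict.get("checkpoint", {})            # missing block -> defaults
    mode = checkpoint_params.get("tag_validation", "WARN").upper()  # case-insensitive, default WARN
    if mode not in VALID_MODES:
        # DeepSpeed raises DeepSpeedConfigError here; ValueError stands in for the sketch.
        raise ValueError(f"invalid tag_validation value: {mode}")
    return mode

assert resolve_tag_validation_mode({}) == "WARN"
assert resolve_tag_validation_mode({"checkpoint": {"tag_validation": "fail"}}) == "FAIL"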
25 changes: 25 additions & 0 deletions deepspeed/runtime/constants.py
@@ -287,7 +287,9 @@
TENSORBOARD_JOB_NAME = "job_name"
TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"

#########################################
# Progressive Layer Drop (PLD)
#########################################
PROGRESSIVE_LAYER_DROP = "progressive_layer_drop"

# PLD enable signal
@@ -299,3 +301,26 @@

PLD_GAMMA = "gamma"
PLD_GAMMA_DEFAULT = 0.001


#########################################
# Validation modes
#########################################
class ValidationMode:
WARN = "WARN"
IGNORE = "IGNORE"
FAIL = "FAIL"


#########################################
# Checkpoint config params
#########################################
# "checkpoint": {tag_validation=["Ignore"|"Warn"|"Fail"]}
CHECKPOINT = "checkpoint"
CHECKPOINT_TAG_VALIDATION = "tag_validation"
CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN
CHECKPOINT_TAG_VALIDATION_MODES = [
ValidationMode.WARN,
ValidationMode.IGNORE,
ValidationMode.FAIL
]
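The three modes reduce to the two flags set in _initialize_params above; a short sketch of that mapping, derived directly from the diff:

# mode    -> validation runs? | behavior on mismatched tags
# IGNORE  -> no               | n/a
# WARN    -> yes              | logger.warning on each rank (default)
# FAIL    -> yes              | assertion error inside save_checkpoint
for mode in ("WARN", "IGNORE", "FAIL"):
    enabled = mode != "IGNORE"   # -> checkpoint_tag_validation_enabled
    fail = mode == "FAIL"        # -> checkpoint_tag_validation_fail
    print(mode, enabled, fail)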
30 changes: 29 additions & 1 deletion deepspeed/runtime/engine.py
@@ -5,6 +5,7 @@
import os
import torch
import warnings
import hashlib
import torch.distributed as dist

from torch.nn.modules import Module
@@ -213,6 +214,12 @@ def get_batch_info(self):
"""
return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps

def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled

def checkpoint_tag_validation_fail(self):
return self._config.checkpoint_tag_validation_fail

def elasticity_enabled(self):
return self._config.elasticity_enabled

@@ -1435,12 +1442,30 @@ def _get_all_zero_checkpoints(self, load_dir, tag):
)
return zero_optimizer_sd

def _checkpoint_tag_validation(self, tag):
if self.checkpoint_tag_validation_enabled():
s_hash = hashlib.sha1(tag.encode())
bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device)
max_bhash = bhash.clone()
min_bhash = bhash.clone()
dist.all_reduce(max_bhash, op=torch.distributed.ReduceOp.MAX)
dist.all_reduce(min_bhash, op=torch.distributed.ReduceOp.MIN)
valid = all(min_bhash == bhash) and all(max_bhash == bhash)
msg = f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " \
"all ranks. Including rank unique information in checkpoint tag could cause issues when " \
"restoring with different world sizes."
if self.checkpoint_tag_validation_fail():
assert valid, msg
elif not valid:
logger.warning(msg)

def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
r"""Save training checkpoint
Arguments:
save_dir: Required. Directory for saving the checkpoint
tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is used if not provided.
tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
used if not provided. Tag name must be the same across all ranks.
client_state: Optional. State dictionary used for saving required training states in the client code.
save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
"""
@@ -1454,6 +1479,9 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
if tag is None:
tag = f"global_step{self.global_steps}"

# Ensure checkpoint tag is consistent across ranks
self._checkpoint_tag_validation(tag)

if self.save_non_zero_checkpoint:
self._create_checkpoint_file(save_dir, tag, False)
self._save_checkpoint(save_dir, tag, client_state=client_state)
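To make the runtime behavior concrete, a hedged usage sketch. Here engine stands for the object returned by deepspeed.initialize and save_dir for any writable path; both are placeholders, and torch.distributed is assumed to be initialized already (see the unit tests below for a complete setup).

import torch

save_dir = "./checkpoints"  # placeholder path

# Same tag on every rank: the SHA1 hash of the tag all-reduces to identical
# min/max byte tensors, so validation passes silently.
engine.save_checkpoint(save_dir, tag=f"global_step{engine.global_steps}")

# Rank-unique tag: hashes differ across ranks. Under the default WARN mode each
# rank logs the warning from _checkpoint_tag_validation; under FAIL the assert
# fires before any checkpoint files are written.
engine.save_checkpoint(save_dir, tag=f"rank{torch.distributed.get_rank()}")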
65 changes: 65 additions & 0 deletions tests/unit/test_checkpointing.py
@@ -761,3 +761,68 @@ def _helper(args, model, hidden_dim):
model.load_checkpoint(tmpdir)

_helper(args=args, model=model, hidden_dim=hidden_dim)


@pytest.mark.parametrize('valid_mode', ["FAIL", "WARN", "IGNORE"])
def test_checkpoint_unique_tag(tmpdir, valid_mode):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"checkpoint": {
"tag_validation": valid_mode
}
}
hidden_dim = 10
args = args_from_dict(tmpdir, config_dict)

model = SimpleModel(hidden_dim, rank=args.local_rank)

@distributed_test(world_size=[2])
def _helper(args, model, hidden_dim):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
if valid_mode == "FAIL":
with pytest.raises(AssertionError):
model.save_checkpoint(save_dir=tmpdir,
tag=f"tag-{torch.distributed.get_rank()}")
else:
model.save_checkpoint(save_dir=tmpdir,
tag=f"tag-{torch.distributed.get_rank()}")

_helper(args=args, model=model, hidden_dim=hidden_dim)


def test_checkpoint_unknown_tag_validation(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"checkpoint": {
"tag_validation": "foo"
}
}
hidden_dim = 10
args = args_from_dict(tmpdir, config_dict)

model = SimpleModel(hidden_dim, rank=args.local_rank)

@distributed_test(world_size=[1])
def _helper(args, model, hidden_dim):
with pytest.raises(deepspeed.DeepSpeedConfigError):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())

_helper(args=args, model=model, hidden_dim=hidden_dim)
