Skip to content

Commit

Permalink
Allow DeepSpeed models to be initialized with optimizer=None (#469)
Browse files Browse the repository at this point in the history
Allow DeepSpeed models to be initialized with optimizer=None

Co-authored-by: Shaden Smith <[email protected]>
  • Loading branch information
gcooper-isi and Shaden Smith authored Jan 5, 2021
1 parent e6ac731 commit a9a83a6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
9 changes: 4 additions & 5 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,9 @@ def _is_supported_optimizer(self, optimizer_name):
# Validate configuration based on command line arguments
def _do_sanity_check(self):
if not self.client_optimizer:
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
assert self.client_model_parameters, \
'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())
if self.optimizer_name() is not None:
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())

if self.optimizer_name() == LAMB_OPTIMIZER:
assert self.dynamic_loss_scale(), \
Expand Down Expand Up @@ -1289,7 +1288,7 @@ def _load_checkpoint(self,

self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
if self.optimizer is not None and not self.zero_optimization():
if self.fp16_enabled():
self.optimizer.load_state_dict(
checkpoint['optimizer'],
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,34 @@ def _test_dist_init_true(args, model, hidden_dim):
model.step()

_test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)


def test_init_no_optimizer(tmpdir):

config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}}
config_path = create_config_from_dict(tmpdir, config_dict)

@distributed_test(world_size=1)
def _helper():
parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
args.deepscale_config = config_path
args.local_rank = 0

hidden_dim = 10

model = SimpleModel(hidden_dim=hidden_dim)

model, _, _, _ = deepspeed.initialize(args=args, model=model)
data_loader = random_dataloader(model=model,
total_samples=5,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
with pytest.raises(AssertionError):
model.backward(loss)
with pytest.raises(AssertionError):
model.step()

_helper()

0 comments on commit a9a83a6

Please sign in to comment.