Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow DeepSpeed models to be initialized with optimizer=None #469

Merged
merged 6 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,9 @@ def _is_supported_optimizer(self, optimizer_name):
# Validate configuration based on command line arguments
def _do_sanity_check(self):
if not self.client_optimizer:
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
assert self.client_model_parameters, \
'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())
if self.optimizer_name() is not None:
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())

if self.optimizer_name() == LAMB_OPTIMIZER:
assert self.dynamic_loss_scale(), \
Expand Down Expand Up @@ -1289,7 +1288,7 @@ def _load_checkpoint(self,

self.load_module_state_dict(state_dict=checkpoint['module'],
strict=load_module_strict)
if not self.zero_optimization():
if self.optimizer is not None and not self.zero_optimization():
if self.fp16_enabled():
self.optimizer.load_state_dict(
checkpoint['optimizer'],
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,34 @@ def _test_dist_init_true(args, model, hidden_dim):
model.step()

_test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)


def test_init_no_optimizer(tmpdir):

config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}}
config_path = create_config_from_dict(tmpdir, config_dict)

@distributed_test(world_size=1)
def _helper():
parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
args.deepscale_config = config_path
args.local_rank = 0

hidden_dim = 10

model = SimpleModel(hidden_dim=hidden_dim)

model, _, _, _ = deepspeed.initialize(args=args, model=model)
data_loader = random_dataloader(model=model,
total_samples=5,
hidden_dim=hidden_dim,
device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
with pytest.raises(AssertionError):
model.backward(loss)
with pytest.raises(AssertionError):
model.step()

_helper()