From a9a83a6fcfcf654d75017453fbb3a476000180ce Mon Sep 17 00:00:00 2001
From: gcooper-isi <42359489+gcooper-isi@users.noreply.github.com>
Date: Tue, 5 Jan 2021 13:14:29 -0500
Subject: [PATCH] Allow DeepSpeed models to be initialized with optimizer=None
 (#469)

Allow DeepSpeed models to be initialized with optimizer=None

Co-authored-by: Shaden Smith
---
 deepspeed/runtime/engine.py |  9 ++++-----
 tests/unit/test_config.py   | 31 +++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index a87a56cb5b9b..99db78ec6dc5 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -466,10 +466,9 @@ def _is_supported_optimizer(self, optimizer_name):
     # Validate configuration based on command line arguments
     def _do_sanity_check(self):
         if not self.client_optimizer:
-            assert self._is_supported_optimizer(self.optimizer_name()), \
-                '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
-            assert self.client_model_parameters, \
-                'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())
+            if self.optimizer_name() is not None:
+                assert self._is_supported_optimizer(self.optimizer_name()), \
+                    '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
 
         if self.optimizer_name() == LAMB_OPTIMIZER:
             assert self.dynamic_loss_scale(), \
@@ -1289,7 +1288,7 @@ def _load_checkpoint(self,
         self.load_module_state_dict(state_dict=checkpoint['module'],
                                     strict=load_module_strict)
 
-        if not self.zero_optimization():
+        if self.optimizer is not None and not self.zero_optimization():
             if self.fp16_enabled():
                 self.optimizer.load_state_dict(
                     checkpoint['optimizer'],
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index e5fe75b281e0..4cabefe71a33 100755
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -195,3 +195,34 @@ def _test_dist_init_true(args, model, hidden_dim):
         model.step()
 
     _test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_init_no_optimizer(tmpdir):
+
+    config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}}
+    config_path = create_config_from_dict(tmpdir, config_dict)
+
+    @distributed_test(world_size=1)
+    def _helper():
+        parser = argparse.ArgumentParser()
+        args = parser.parse_args(args='')
+        args.deepscale_config = config_path
+        args.local_rank = 0
+
+        hidden_dim = 10
+
+        model = SimpleModel(hidden_dim=hidden_dim)
+
+        model, _, _, _ = deepspeed.initialize(args=args, model=model)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            with pytest.raises(AssertionError):
+                model.backward(loss)
+            with pytest.raises(AssertionError):
+                model.step()
+
+    _helper()
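
Note for reviewers (not part of the patch): a minimal sketch of the behavior this change enables, based on the new test above. MyModel, inputs, and ds_config.json are hypothetical placeholders; the config is assumed to mirror the one the test builds with create_config_from_dict.

    import argparse
    import deepspeed

    # Build args the same way test_init_no_optimizer does; ds_config.json is
    # assumed to contain {"train_batch_size": 1, "fp16": {"enabled": true}}.
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepscale_config = 'ds_config.json'  # hypothetical config path
    args.local_rank = 0

    model = MyModel()  # any torch.nn.Module; hypothetical placeholder

    # With this patch, initialize() no longer requires an optimizer: when no
    # client optimizer is passed and the config defines none, optimizer_name()
    # is None and _do_sanity_check() skips the supported-optimizer assertion.
    engine, _, _, _ = deepspeed.initialize(args=args, model=model)

    # Forward passes work as usual (the test computes a loss this way).
    loss = engine(inputs, labels)

    # Without an optimizer, backward() and step() still raise AssertionError,
    # which is exactly what test_init_no_optimizer verifies with pytest.raises.

A checkpoint loaded into such an engine now also skips restoring optimizer state, since _load_checkpoint() guards on self.optimizer is not None before calling self.optimizer.load_state_dict().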