diff --git a/deepspeed/runtime/zero/stage1.py b/deepspeed/runtime/zero/stage1.py
index 9a5c0baa2ba0..d5c7616ff87e 100755
--- a/deepspeed/runtime/zero/stage1.py
+++ b/deepspeed/runtime/zero/stage1.py
@@ -947,9 +947,10 @@ def _partition_base_optimizer_state(self,
                                         state_key,
                                         all_partition_states,
                                         max_elems_per_comm):
-        partition_id = dist.get_rank(group=self.dp_process_group)
-        alignment = dist.get_world_size(group=self.dp_process_group)
+        if not torch.is_tensor(all_partition_states[0]):
+            return all_partition_states[0]
+
+        alignment = dist.get_world_size(group=self.dp_process_group)
         flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_partition_states,
             dp=dist.get_world_size(group=self.dp_process_group),
@@ -964,6 +965,7 @@ def _partition_base_optimizer_state(self,
             dp_process_group=self.dp_process_group
         )
 
+        partition_id = dist.get_rank(group=self.dp_process_group)
         return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
 
     # Compute the optimizer state partitions for the group by
@@ -1013,8 +1015,11 @@ def _restore_base_optimizer_state(self, state_dict_list):
         for group_idx, group in enumerate(self.optimizer.param_groups):
             for param_idx, param in enumerate(group['params']):
                 for key, saved in base_optimizer_group_states[group_idx].items():
-                    current = self.optimizer.state[param][key]
-                    current.data.copy_(saved[param_idx].data)
+                    if torch.is_tensor(self.optimizer.state[param][key]):
+                        current = self.optimizer.state[param][key]
+                        current.data.copy_(saved[param_idx].data)
+                    else:
+                        self.optimizer.state[param][key] = saved
 
     # Restore base optimizer fp32 weights from ZeRO fp16 weights
     def _restore_from_fp16_weights(self):
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index 360efa0aaf4e..b0c268341224 100755
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -101,6 +101,37 @@ def step(self, closure=None):
         return loss
 
 
+class HybridStateOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, lr=0.11072018):
+        defaults = dict(lr=lr)
+        super(HybridStateOptimizer, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(HybridStateOptimizer, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['integer_step'] = 0
+                    state['tensor_step'] = torch.zeros(1)
+
+                d_p = p.grad.data
+                p.data.add_(-group['lr'], d_p)
+                state['integer_step'] += 1
+                state['tensor_step'] += 1
+
+        return loss
+
+
 class PLD_SimpleModel(SimpleModel):
     def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index a5a2a63697fb..097a581d96d9 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -36,6 +36,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     compare_deepspeed_states(saved_model, loaded_model)
 
     for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
+        assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
        assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
 
     if not compare_optimizer:
@@ -43,20 +44,24 @@
     if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
         for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups,
                           loaded_model.optimizer.single_partition_of_fp32_groups):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
         for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups,
                                           loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
             for p0, p1 in zip(partition0, partition1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_Optimizer):
         for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat,
                           loaded_model.optimizer.fp32_groups_flat):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
         for params0, params1 in zip(saved_model.optimizer.fp32_groups,
                                     loaded_model.optimizer.fp32_groups):
             for p0, p1 in zip(params0, params1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
     elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
         pass
@@ -72,6 +77,7 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
                             loaded_optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
             if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
+                assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
                 assert torch.equal(s0, s1)
             else:
                 assert s0 == s1
@@ -100,18 +106,34 @@ def compare_lr_scheduler_states(saved_model, loaded_model):
         assert state0 == state1
 
 
+def create_deepspeed_model(args, model, base_optimizer):
+    if base_optimizer is None:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 model_parameters=model.parameters())
+    else:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 optimizer=base_optimizer)
+
+    return ds_model
+
+
 def checkpoint_correctness_verification(args,
-                                        model,
+                                        models,
                                         hidden_dim,
                                         tmpdir,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False,
                                         fp16=True,
-                                        train_batch=False):
+                                        train_batch=False,
+                                        base_optimizers=[None,
+                                                         None]):
     dtype = torch.half if fp16 else torch.float32
-    ds_model, _, _, _ = deepspeed.initialize(args=args,
-                                             model=model,
-                                             model_parameters=model.parameters())
+    ds_model = create_deepspeed_model(args=args,
+                                      model=models[0],
+                                      base_optimizer=base_optimizers[0])
+
     data_loader = random_dataloader(model=ds_model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
@@ -125,7 +147,6 @@ def checkpoint_correctness_verification(args,
     else:
         for n, batch in enumerate(data_loader):
             loss = ds_model(batch[0], batch[1])
-            print(loss)
             ds_model.backward(loss)
             ds_model.step()
 
@@ -136,9 +157,9 @@ def checkpoint_correctness_verification(args,
 
     trained_model.save_checkpoint(save_folder, save_tag)
 
-    loaded_model, _, _, _ = deepspeed.initialize(args=args,
-                                                 model=model,
-                                                 model_parameters=model.parameters())
+    loaded_model = create_deepspeed_model(args=args,
+                                          model=models[1],
+                                          base_optimizer=base_optimizers[1])
 
     loaded_model.load_checkpoint(save_folder,
                                  save_tag,
@@ -191,25 +212,26 @@ def test_checkpoint_unfused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
     def _test_checkpoint_unfused_optimizer(args,
-                                           model,
+                                           models,
                                            hidden_dim,
                                            load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=True)
+
     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
@@ -236,22 +258,26 @@ def test_checkpoint_fused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
-    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_fused_optimizer(args,
+                                         models,
+                                         hidden_dim,
+                                         load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=True)
+
     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False)
@@ -293,18 +319,18 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
-    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_zero_optimizer(args, models, hidden_dim, load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_zero_optimizer(args=args,
-                                    model=model,
+                                    models=models,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
@@ -346,21 +372,21 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
     def _test_checkpoint_zero_no_optimizer(args,
-                                           model,
+                                           models,
                                            hidden_dim,
                                            load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)
 
     _test_checkpoint_zero_no_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
@@ -412,24 +438,24 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
     def _test_checkpoint_lr_scheduler(args,
-                                      model,
+                                      models,
                                       hidden_dim,
                                       load_optimizer_states,
                                       load_lr_scheduler_states):
         checkpoint_correctness_verification(
             args,
-            model,
-            hidden_dim,
-            tmpdir,
+            models=models,
+            hidden_dim=hidden_dim,
+            tmpdir=tmpdir,
             load_optimizer_states=load_optimizer_states,
             load_lr_scheduler_states=load_lr_scheduler_states)
 
     _test_checkpoint_lr_scheduler(args=args,
-                                  model=model,
+                                  models=models,
                                   hidden_dim=hidden_dim,
                                   load_optimizer_states=False,
                                   load_lr_scheduler_states=True)
@@ -478,24 +504,24 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
     def _test_checkpoint_no_lr_scheduler(args,
-                                         model,
+                                         models,
                                          hidden_dim,
                                          load_optimizer_states,
                                          load_lr_scheduler_states):
         checkpoint_correctness_verification(
             args,
-            model,
-            hidden_dim,
-            tmpdir,
+            models=models,
+            hidden_dim=hidden_dim,
+            tmpdir=tmpdir,
             load_optimizer_states=load_optimizer_states,
             load_lr_scheduler_states=load_lr_scheduler_states)
 
     _test_checkpoint_no_lr_scheduler(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False,
                                      load_lr_scheduler_states=False)
@@ -523,13 +549,17 @@ def test_checkpoint_fp32_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
 
     @distributed_test(world_size=[2])
-    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
-        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)
+    def _test_checkpoint_fp32_optimizer(args, models, hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            fp16=False)
 
-    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
+    _test_checkpoint_fp32_optimizer(args=args, models=models, hidden_dim=hidden_dim)
 
 
 @pytest.mark.parametrize("zero_stage", [0, 1])
@@ -571,10 +601,10 @@ def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
     @distributed_test(world_size=4)
     def _test(save_folder, num_stages):
         args = args_from_dict(tmpdir, config_dict)
-        model = LinearStackPipe(num_stages=num_stages)
+        models = [LinearStackPipe(num_stages=num_stages) for _ in range(2)]
         checkpoint_correctness_verification(args=args,
-                                            model=model,
-                                            hidden_dim=model.hidden_dim,
+                                            models=models,
+                                            hidden_dim=models[0].hidden_dim,
                                             tmpdir=save_folder,
                                             fp16=config_dict['fp16']['enabled'],
                                             load_optimizer_states=True,
@@ -635,3 +665,42 @@ def _test(base_topo, test_topo, save_folder):
             assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"
 
     _test(base_topo, test_topo, save_folder=tmpdir)
+
+
+@pytest.mark.parametrize('zero_stage', [1, 2])
+def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
+    config_dict = {
+        "train_micro_batch_size_per_gpu": 2,
+        "gradient_accumulation_steps": 2,
+        "steps_per_print": 1,
+        "zero_optimization": {
+            "stage": zero_stage
+        },
+        "zero_allow_untested_optimizer": True,
+        "fp16": {
+            "enabled": True,
+            "initial_scale_power": 8
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
+    optimizers = [HybridStateOptimizer(model.parameters()) for model in models]
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_zero_hybrid_optimizer_state(args,
+                                                     models,
+                                                     optimizers,
+                                                     hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            base_optimizers=optimizers,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            load_optimizer_states=True)
+
+    _test_checkpoint_zero_hybrid_optimizer_state(args=args,
+                                                 models=models,
+                                                 optimizers=optimizers,
+                                                 hidden_dim=hidden_dim)