Skip to content

Commit

Permalink
Support non-tensor state in checkpoint (#548)
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Nov 21, 2020
1 parent 0178e6c commit 6021b70
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 58 deletions.
13 changes: 9 additions & 4 deletions deepspeed/runtime/zero/stage1.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,9 +947,10 @@ def _partition_base_optimizer_state(self,
state_key,
all_partition_states,
max_elems_per_comm):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
if not torch.is_tensor(all_partition_states[0]):
return all_partition_states[0]

alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_partition_states,
dp=dist.get_world_size(group=self.dp_process_group),
Expand All @@ -964,6 +965,7 @@ def _partition_base_optimizer_state(self,
dp_process_group=self.dp_process_group
)

partition_id = dist.get_rank(group=self.dp_process_group)
return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]

# Compute the optimizer state partitions for the group by
Expand Down Expand Up @@ -1013,8 +1015,11 @@ def _restore_base_optimizer_state(self, state_dict_list):
for group_idx, group in enumerate(self.optimizer.param_groups):
for param_idx, param in enumerate(group['params']):
for key, saved in base_optimizer_group_states[group_idx].items():
current = self.optimizer.state[param][key]
current.data.copy_(saved[param_idx].data)
if torch.is_tensor(self.optimizer.state[param][key]):
current = self.optimizer.state[param][key]
current.data.copy_(saved[param_idx].data)
else:
self.optimizer.state[param][key] = saved

# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/simple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,37 @@ def step(self, closure=None):
return loss


class HybridStateOptimizer(torch.optim.Optimizer):
def __init__(self, params, lr=0.11072018):
defaults = dict(lr=lr)
super(HybridStateOptimizer, self).__init__(params, defaults)

def __setstate__(self, state):
super(HybridStateOptimizer, self).__setstate__(state)

def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue

state = self.state[p]
if len(state) == 0:
state['integer_step'] = 0
state['tensor_step'] = torch.zeros(1)

d_p = p.grad.data
p.data.add_(-group['lr'], d_p)
state['integer_step'] += 1
state['tensor_step'] += 1

return loss


class PLD_SimpleModel(SimpleModel):
def __init__(self, hidden_dim, empty_grad=False, rank=0):
super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
Expand Down
Loading

0 comments on commit 6021b70

Please sign in to comment.