
Commit

fix bug related to stitching reduced grads across communication partitions (microsoft#318)

jeffra authored Sep 15, 2020
1 parent 91b4a93 commit 55ed105
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions deepspeed/runtime/zero/stage1.py
@@ -249,8 +249,7 @@ def __init__(self,
 
             # RS: divide up the sub-partitions and keep track of offsets for each param
             # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
-            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
-                params_not_local = self.get_all_sub_partition_info(
+            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
                 tensor_list=self.fp16_groups[i],
                 all_element_intervals=element_intervals,
                 local_rank=local_rank,
@@ -591,28 +590,20 @@ def reduce_scatter_gradients(self,
 
                 all_comm_partitions.append(single_comm_all_partitions)
 
-            for p in my_params:
-                partitions = param_partition_map[p]
-                parts = []
-                for part in partitions:
-                    params, offsets = partition_param_map[part]
-                    found = False
-                    for p_idx, _p in enumerate(params):
-                        if p.__hash__() == _p.__hash__():
-                            found = True
-                            if offsets[p_idx][0] is not None:
-                                my_part = part.narrow(0,
-                                                      offsets[p_idx][0],
-                                                      offsets[p_idx][1])
-                                parts.append(my_part)
-                    assert found
-                if p is not None:
-                    updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p])
-                    p.grad.copy_(updated_grad[0])
+            # stitch together all rank sub partitions for each comm idx
+            flat_comm_grads = []
+            for comm_idx, rank_partitions in enumerate(all_comm_partitions):
+                flat_comm_grads.append(torch.cat(rank_partitions))
+
+            flat_all_grads = torch.cat(flat_comm_grads)
+
+            # copy back reduced gradients but only those needed for this local rank
+            for param, updated_grad in zip(self.fp16_groups[i], _unflatten_dense_tensors(flat_all_grads, self.fp16_groups[i])):
+                if param in my_params:
+                    param.grad.copy_(updated_grad)
 
     def step(self, closure=None):
         # First compute norm for all group so we know if there is overflow
 
         self.overflow = self.overflow_checker.check()
 
         prev_scale = self.loss_scale
@@ -649,6 +640,7 @@ def step(self, closure=None):
             #)
 
             #TODO RS: can we safely use dtype of the first sub-partition? i think so
+            # create flat gradient partitions for parameters updated by this process
             local_grad_sub_partitions = self.get_flat_sub_partitions(
                 comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
                 comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
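For reference, here is a minimal standalone sketch (not the DeepSpeed implementation itself) of the stitching pattern the new reduce_scatter_gradients code relies on: concatenate the reduced rank sub-partitions for each communication partition, flatten everything into one buffer, then unflatten that buffer against the original parameter list and copy back only the gradients this rank owns. The function name stitch_and_copy_back and the toy tensors at the bottom are illustrative assumptions; all_comm_partitions, fp16_group, and my_params stand in for the corresponding structures in stage1.py.

import torch
from torch._utils import _unflatten_dense_tensors


def stitch_and_copy_back(all_comm_partitions, fp16_group, my_params):
    # stitch together the rank sub-partitions for each comm partition
    flat_comm_grads = [torch.cat(rank_partitions)
                       for rank_partitions in all_comm_partitions]

    # one flat buffer covering the whole flattened parameter group
    flat_all_grads = torch.cat(flat_comm_grads)

    # view the flat buffer with the original parameter shapes and copy back
    # only the gradients needed by this rank
    unflattened = _unflatten_dense_tensors(flat_all_grads, fp16_group)
    for param, updated_grad in zip(fp16_group, unflattened):
        if param in my_params:
            param.grad.copy_(updated_grad)


# toy usage: two parameters whose gradients span two communication partitions
p0 = torch.nn.Parameter(torch.zeros(4))
p0.grad = torch.zeros(4)
p1 = torch.nn.Parameter(torch.zeros(2))
p1.grad = torch.zeros(2)
all_comm_partitions = [[torch.ones(3)], [torch.full((3,), 2.0)]]  # one rank per comm partition
stitch_and_copy_back(all_comm_partitions, [p0, p1], [p0, p1])
print(p0.grad, p1.grad)  # tensor([1., 1., 1., 2.]) tensor([2., 2.])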
