Deprecate client ability to disable gradient reduction (microsoft#552)
Co-authored-by: Jeff Rasley <[email protected]>
tjruwase and jeffra authored Nov 24, 2020
1 parent 1ef5cd2 commit 6e65c2c
Showing 2 changed files with 12 additions and 1 deletion.
7 changes: 6 additions & 1 deletion deepspeed/runtime/engine.py
@@ -878,6 +878,11 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):
             allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
         """
 
+        if not allreduce_gradients:
+            logger.warning(
+                f'Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed'
+            )
+
         # scale loss w.r.t. gradient accumulation if needed
         if self.gradient_accumulation_steps() > 1:
             loss = self._scale_loss(loss.float())
@@ -931,7 +936,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):
             self.timers('backward_allreduce_microstep').start()
             self.timers('backward_allreduce').start()
 
-        if allreduce_gradients and self.enable_backward_allreduce:
+        if self.enable_backward_allreduce:
             self.allreduce_gradients()
 
         if self.wall_clock_breakdown():
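For client code, the practical effect is that passing allreduce_gradients=False to backward() no longer disables reduction; it only emits the warning added above, and the all-reduce still runs whenever engine.enable_backward_allreduce is true. A minimal sketch of the new behavior follows; the model, config, and data are illustrative placeholders (not part of this commit), and the script assumes it is launched under the deepspeed launcher or an otherwise initialized distributed environment.

import torch
import deepspeed

# Placeholder model and config, purely for illustration.
model = torch.nn.Linear(8, 1)
ds_config = {
    "train_batch_size": 4,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config_params=ds_config)

x = torch.randn(4, 8).to(engine.device)
loss = engine(x).sum()

# Before this commit, allreduce_gradients=False skipped gradient averaging.
# Now it only logs the deprecation warning; reduction is governed solely by
# engine.enable_backward_allreduce.
engine.backward(loss, allreduce_gradients=False)
engine.step()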
6 changes: 6 additions & 0 deletions deepspeed/runtime/zero/stage2.py
@@ -955,6 +955,12 @@ def reduce_ipg_grads(self):

         with torch.cuda.stream(stream):
             for _, param, param_id in self.params_in_ipg_bucket:
+
+                assert self.params_already_reduced[param_id] == False, \
+                    f"The parameter {param_id} has already been reduced. \
+                    Gradient computed twice for this partition. \
+                    Multiple gradient reduction is currently not supported"
+
                 self.params_already_reduced[param_id] = True
 
                 if not self.is_param_in_current_partition[param_id]:
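The new assertion guards ZeRO stage 2 against reducing the same parameter's gradient twice within one bucket, which would otherwise silently corrupt the averaged gradients. The logic reduces to a per-parameter "already reduced" flag that is checked before being set. A standalone sketch of that guard is shown below; the names are hypothetical and this is not the DeepSpeed implementation.

# Standalone sketch of the guard, with hypothetical names.
params_already_reduced = {0: False, 1: False}   # one flag per param id in the bucket

def reduce_param_grad(param_id):
    # Mirrors the assert added above: a second reduction of the same
    # parameter within this partition is treated as an error.
    assert params_already_reduced[param_id] == False, \
        f"The parameter {param_id} has already been reduced. " \
        f"Gradient computed twice for this partition. " \
        f"Multiple gradient reduction is currently not supported"
    params_already_reduced[param_id] = True
    # ... the actual gradient reduction (e.g. an all-reduce) would happen here ...

reduce_param_grad(0)    # first reduction of parameter 0 succeeds
reduce_param_grad(1)    # a different parameter is also fine
# reduce_param_grad(0)  # would raise AssertionError: gradient reduced twice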
