diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 94b0d7e48488..572b8af36f13 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -878,6 +878,11 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):
             allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
         """
 
+        if not allreduce_gradients:
+            logger.warning(
+                f'Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed'
+            )
+
         # scale loss w.r.t. gradient accumulation if needed
         if self.gradient_accumulation_steps() > 1:
             loss = self._scale_loss(loss.float())
@@ -931,7 +936,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):
             self.timers('backward_allreduce_microstep').start()
             self.timers('backward_allreduce').start()
 
-        if allreduce_gradients and self.enable_backward_allreduce:
+        if self.enable_backward_allreduce:
             self.allreduce_gradients()
 
         if self.wall_clock_breakdown():
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 4fcbab596821..f6fa523fc8c0 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -955,6 +955,12 @@ def reduce_ipg_grads(self):
 
         with torch.cuda.stream(stream):
             for _, param, param_id in self.params_in_ipg_bucket:
+
+                assert self.params_already_reduced[param_id] == False, \
+                    f"The parameter {param_id} has already been reduced. \
+                    Gradient computed twice for this partition. \
+                    Multiple gradient reduction is currently not supported"
+
                 self.params_already_reduced[param_id] = True
 
                 if not self.is_param_in_current_partition[param_id]:
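
For context, a minimal, self-contained Python sketch of the behavior the diff above introduces, under assumed names: `SketchEngine`, `_DummyLoss`, and the dict-based bookkeeping are hypothetical stand-ins, not DeepSpeed's real `DeepSpeedEngine` or ZeRO stage-2 optimizer. It only illustrates that passing `allreduce_gradients=False` now logs a deprecation warning while gradient averaging is gated solely by `enable_backward_allreduce`, and that reducing the same parameter twice within a bucket is rejected by an assert.

# Hypothetical, simplified sketch (assumed names, not DeepSpeed source): mirrors
# the two behaviors above so they can be exercised without a distributed setup.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class SketchEngine:
    """Reduced stand-in for an engine, covering only the code paths touched here."""

    def __init__(self, enable_backward_allreduce=True):
        self.enable_backward_allreduce = enable_backward_allreduce
        self.params_already_reduced = {}  # param_id -> bool, like the stage-2 bookkeeping

    def backward(self, loss, allreduce_gradients=True):
        # The argument is kept only as a deprecation shim; it no longer changes behavior.
        if not allreduce_gradients:
            logger.warning('Argument `allreduce_gradients` is deprecated, ignored, '
                           'and will soon be removed')
        loss.backward()
        # Gradient averaging is gated only by the engine config after this change.
        if self.enable_backward_allreduce:
            self.allreduce_gradients()

    def allreduce_gradients(self):
        pass  # a real engine averages gradients across data-parallel ranks here

    def reduce_ipg_grads(self, param_ids_in_bucket):
        # Mirrors the new assert: reducing the same parameter twice is an error.
        for param_id in param_ids_in_bucket:
            assert not self.params_already_reduced.get(param_id, False), \
                f"The parameter {param_id} has already been reduced. " \
                "Multiple gradient reduction is currently not supported"
            self.params_already_reduced[param_id] = True


class _DummyLoss:
    def backward(self):
        pass  # placeholder for a real autograd backward pass


engine = SketchEngine()
engine.backward(_DummyLoss(), allreduce_gradients=False)  # emits the deprecation warning
engine.reduce_ipg_grads([0, 1, 2])  # reducing any of these ids again would trip the assert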