From eb74c3f1df1475a75bb8c18dd5def83ff849c2a5 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 23 Jul 2020 11:49:57 -0700
Subject: [PATCH] updates to amp to support grad clip and grad accumulation
 (#290)

* updates to amp to support grad clip and grad accumulation

* zero grad using optimizer if in amp mode
---
 deepspeed/pt/deepspeed_light.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 1291d4c08115..1b46e6ac6d29 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -774,7 +774,12 @@ def backward(self, loss, allreduce_gradients=True):
         if self.zero_optimization():
             self.optimizer.backward(loss)
         elif self.amp_enabled():
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+            # AMP requires delaying unscale when inside gradient accumulation boundaries
+            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+            delay_unscale = not self.is_gradient_accumulation_boundary()
+            with amp.scale_loss(loss,
+                                self.optimizer,
+                                delay_unscale=delay_unscale) as scaled_loss:
                 scaled_loss.backward()
         elif self.fp16_enabled():
             self.optimizer.backward(loss)
@@ -828,14 +833,22 @@ def step(self):
 
         if self.is_gradient_accumulation_boundary():
 
-            if not self.fp16_enabled() and self.gradient_clipping() > 0.0:
-                self.clip_fp32_gradients()
+            if self.gradient_clipping() > 0.0:
+                if not self.fp16_enabled() and not self.amp_enabled():
+                    self.clip_fp32_gradients()
+                elif self.amp_enabled():
+                    # AMP's recommended way of doing clipping
+                    # https://nvidia.github.io/apex/advanced.html#gradient-clipping
+                    master_params = amp.master_params(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(parameters=master_params,
+                                                   max_norm=self.gradient_clipping())
 
             self.optimizer.step()
 
-            # zero grad in basic optimizer could be unreliable and may not exhibit
-            # the behaviour that we want
-            if not self.zero_optimization() and not self.fp16_enabled():
+            #zero grad in basic optimizer could be unreliable and may not exhibit
+            #the behaviour that we want
+            if not self.zero_optimization() and not self.fp16_enabled(
+            ) and not self.amp_enabled():
                 self.zero_grad()
             else:
                 self.optimizer.zero_grad()
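
Note: for readers unfamiliar with the apex AMP idioms this patch wires into DeepSpeed, the
standalone sketch below shows the same three pieces in isolation: delayed unscaling inside a
gradient accumulation window, gradient clipping on AMP's fp32 master params at the boundary,
and zeroing gradients through the optimizer. The toy model, data, accumulation_steps, and
max_grad_norm values are hypothetical stand-ins and are not part of DeepSpeed; only the
amp.scale_loss(..., delay_unscale=...), amp.master_params(), and
torch.nn.utils.clip_grad_norm_() calls mirror what the patch does.

    import torch
    from apex import amp  # NVIDIA apex; requires a CUDA build of PyTorch

    # Hypothetical toy model/optimizer standing in for a real training setup.
    model = torch.nn.Linear(16, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    accumulation_steps = 4   # hypothetical gradient accumulation window
    max_grad_norm = 1.0      # hypothetical clipping threshold

    for step in range(16):
        inputs = torch.randn(8, 16, device="cuda")
        targets = torch.randn(8, 1, device="cuda")
        loss = torch.nn.functional.mse_loss(model(inputs), targets) / accumulation_steps

        # The last micro-batch of each window is the accumulation boundary.
        boundary = (step + 1) % accumulation_steps == 0

        # Delay unscaling inside the window, mirroring
        # delay_unscale = not self.is_gradient_accumulation_boundary() in the patch.
        with amp.scale_loss(loss, optimizer, delay_unscale=not boundary) as scaled_loss:
            scaled_loss.backward()

        if boundary:
            # Clip on apex's fp32 master params, as the patch does in amp mode.
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           max_norm=max_grad_norm)
            optimizer.step()
            # With amp enabled the patch zeroes grads through the optimizer.
            optimizer.zero_grad()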