updates to amp to support grad clip and grad accumulation (microsoft#290)

* updates to amp to support grad clip and grad accumulation
* zero grad using optimizer if in amp mode
jeffra authored Jul 23, 2020
1 parent 3cc96e1 commit eb74c3f
Showing 1 changed file with 19 additions and 6 deletions: deepspeed/pt/deepspeed_light.py
@@ -774,7 +774,12 @@ def backward(self, loss, allreduce_gradients=True):
         if self.zero_optimization():
             self.optimizer.backward(loss)
         elif self.amp_enabled():
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+            # AMP requires delaying unscale when inside gradient accumulation boundaries
+            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+            delay_unscale = not self.is_gradient_accumulation_boundary()
+            with amp.scale_loss(loss,
+                                self.optimizer,
+                                delay_unscale=delay_unscale) as scaled_loss:
                 scaled_loss.backward()
         elif self.fp16_enabled():
             self.optimizer.backward(loss)
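
For context, the apex pattern this new backward() path follows looks roughly like the sketch below. This is not from the commit: model, optimizer, data_loader, and accumulation_steps are illustrative placeholders, and apex is assumed to be installed and initialized via amp.initialize().

import torch
import torch.nn.functional as F
from apex import amp

def train(model, optimizer, data_loader, accumulation_steps=4):
    # Accumulate gradients over several micro-batches, delaying apex's
    # unscale until the boundary step, per the apex docs:
    # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
    model.train()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(data_loader):
        loss = F.cross_entropy(model(inputs), targets) / accumulation_steps
        boundary = (step + 1) % accumulation_steps == 0
        # delay_unscale=True keeps gradients in loss-scaled form between
        # boundaries so successive backward passes can accumulate into them;
        # only the boundary pass unscales before optimizer.step().
        with amp.scale_loss(loss, optimizer,
                            delay_unscale=not boundary) as scaled_loss:
            scaled_loss.backward()
        if boundary:
            optimizer.step()
            optimizer.zero_grad()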
@@ -828,14 +833,22 @@ def step(self):

         if self.is_gradient_accumulation_boundary():

-            if not self.fp16_enabled() and self.gradient_clipping() > 0.0:
-                self.clip_fp32_gradients()
+            if self.gradient_clipping() > 0.0:
+                if not self.fp16_enabled() and not self.amp_enabled():
+                    self.clip_fp32_gradients()
+                elif self.amp_enabled():
+                    # AMP's recommended way of doing clipping
+                    # https://nvidia.github.io/apex/advanced.html#gradient-clipping
+                    master_params = amp.master_params(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(parameters=master_params,
+                                                   max_norm=self.gradient_clipping())

             self.optimizer.step()

-            # zero grad in basic optimizer could be unreliable and may not exhibit
-            # the behaviour that we want
-            if not self.zero_optimization() and not self.fp16_enabled():
+            #zero grad in basic optimizer could be unreliable and may not exhibit
+            #the behaviour that we want
+            if not self.zero_optimization() and not self.fp16_enabled(
+            ) and not self.amp_enabled():
                 self.zero_grad()
             else:
                 self.optimizer.zero_grad()
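
The clipping and zero-grad behavior added to step() amounts to the following sketch, assuming an optimizer already wrapped by amp.initialize(); clip_and_step and max_grad_norm are hypothetical names used only for illustration.

import torch
from apex import amp

def clip_and_step(optimizer, max_grad_norm=1.0):
    # Clip apex's FP32 master gradients rather than model.parameters(),
    # per https://nvidia.github.io/apex/advanced.html#gradient-clipping
    torch.nn.utils.clip_grad_norm_(parameters=amp.master_params(optimizer),
                                   max_norm=max_grad_norm)
    optimizer.step()
    # As the commit message notes, in amp mode gradients are zeroed through
    # the optimizer rather than the module, since module-level zero_grad()
    # may not exhibit the behavior the engine wants.
    optimizer.zero_grad()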
