updates to amp to support grad clip and grad accumulation #290

Merged
merged 10 commits on Jul 23, 2020
21 changes: 17 additions & 4 deletions deepspeed/pt/deepspeed_light.py
@@ -772,7 +772,12 @@ def backward(self, loss, allreduce_gradients=True):
         if self.zero_optimization():
             self.optimizer.backward(loss)
         elif self.amp_enabled():
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+            # AMP requires delaying unscale when inside gradient accumulation boundaries
+            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+            delay_unscale = not self.is_gradient_accumulation_boundary()
+            with amp.scale_loss(loss,
+                                self.optimizer,
+                                delay_unscale=delay_unscale) as scaled_loss:
                 scaled_loss.backward()
         elif self.fp16_enabled():
             self.optimizer.backward(loss)
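
For context on the `delay_unscale` change: with Apex AMP, gradients stay loss-scaled across accumulation micro-steps and are only unscaled on the step that will actually call `optimizer.step()`, as described in the linked Apex docs. Below is a minimal standalone sketch of that pattern, not DeepSpeed engine code; the toy model, data, and `accumulation_steps` are illustrative assumptions.

```python
import torch
from apex import amp  # NVIDIA Apex, assumed installed with CUDA support

# Toy model/optimizer purely for illustration.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

accumulation_steps = 4
for step in range(16):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randint(0, 4, (8,), device="cuda")
    loss = torch.nn.functional.cross_entropy(model(inputs), targets) / accumulation_steps

    # Mirror of the PR's logic: delay unscaling on non-boundary micro-steps,
    # i.e. delay_unscale = not is_gradient_accumulation_boundary().
    is_boundary = (step + 1) % accumulation_steps == 0
    with amp.scale_loss(loss, optimizer, delay_unscale=not is_boundary) as scaled_loss:
        scaled_loss.backward()

    if is_boundary:
        optimizer.step()
        optimizer.zero_grad()
```
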
@@ -826,14 +831,22 @@ def step(self):
         if self.is_gradient_accumulation_boundary():

-            if not self.fp16_enabled() and self.gradient_clipping() > 0.0:
-                self.clip_fp32_gradients()
+            if self.gradient_clipping() > 0.0:
+                if not self.fp16_enabled():
+                    self.clip_fp32_gradients()
+                elif self.amp_enabled():
+                    # AMP's recommended way of doing clipping
+                    # https://nvidia.github.io/apex/advanced.html#gradient-clipping
+                    master_params = amp.master_params(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(parameters=master_params,
+                                                   max_norm=self.gradient_clipping())

             self.optimizer.step()

             #zero grad in basic optimizer could be unreliable and may not exhibit
             #the behaviour that we want
-            if not self.zero_optimization() and not self.fp16_enabled():
+            if not self.zero_optimization() and not self.fp16_enabled(
+            ) and not self.amp_enabled():
                 self.zero_grad()
             else:
                 self.optimizer.zero_grad()
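
The clipping change follows Apex's documented recommendation: clip the FP32 master gradients that AMP maintains (via `amp.master_params(optimizer)`) rather than the model's FP16 parameters, and only after the gradients have been unscaled. A minimal standalone sketch of that pattern, again with an illustrative toy model rather than the DeepSpeed engine:

```python
import torch
from apex import amp  # NVIDIA Apex, assumed installed with CUDA support

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

max_grad_norm = 1.0  # stand-in for self.gradient_clipping()

inputs = torch.randn(8, 16, device="cuda")
targets = torch.randint(0, 4, (8,), device="cuda")
loss = torch.nn.functional.cross_entropy(model(inputs), targets)

# No accumulation here, so gradients are unscaled when the context exits.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# Clip the FP32 master gradients, not model.parameters(): the norm must be
# computed on the unscaled gradients the optimizer will actually apply.
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)

optimizer.step()
optimizer.zero_grad()
```

The final change in the hunk routes the AMP path through `self.optimizer.zero_grad()` instead of the engine-level `self.zero_grad()`, the same branch the FP16 optimizer already takes.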