From 774906b23f2d5d5a5c6796b4fc8494baab647042 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 15 Jul 2020 16:22:35 +0000
Subject: [PATCH 1/7] updates to amp to support grad clip and grad accumulation

---
 deepspeed/pt/deepspeed_light.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 4ec715da00fe..1029bc954a13 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -772,7 +772,12 @@ def backward(self, loss, allreduce_gradients=True):
         if self.zero_optimization():
             self.optimizer.backward(loss)
         elif self.amp_enabled():
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+            # AMP requires delaying unscale when inside gradient accumulation boundaries
+            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+            delay_unscale = not self.is_gradient_accumulation_boundary()
+            with amp.scale_loss(loss,
+                                self.optimizer,
+                                delay_unscale=delay_unscale) as scaled_loss:
                 scaled_loss.backward()
         elif self.fp16_enabled():
             self.optimizer.backward(loss)
@@ -826,8 +831,13 @@ def step(self):
 
         if self.is_gradient_accumulation_boundary():
 
-            if not self.fp16_enabled() and self.gradient_clipping() > 0.0:
-                self.clip_fp32_gradients()
+            if self.gradient_clipping > 0.0:
+                if not self.fp16_enabled():
+                    self.clip_fp32_gradients()
+                elif self.amp_enabled():
+                    torch.nn.utils.clip_grad_norm_(parameters=amp.master_params(
+                        self.optimizer),
+                                                   max_norm=self.gradient_clipping())
 
             self.optimizer.step()
 

From 9231c432ae727bd91e4166ef44fd3c3bf1c4aece Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 15 Jul 2020 16:25:04 +0000
Subject: [PATCH 2/7] formatting

---
 deepspeed/pt/deepspeed_light.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 1029bc954a13..e236d955a3e9 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -835,8 +835,8 @@ def step(self):
                 if not self.fp16_enabled():
                     self.clip_fp32_gradients()
                 elif self.amp_enabled():
-                    torch.nn.utils.clip_grad_norm_(parameters=amp.master_params(
-                        self.optimizer),
+                    master_params = amp.master_params(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                                    max_norm=self.gradient_clipping())
 
             self.optimizer.step()

From 0b3ab8eac18804e3bf104f4066e76d720a24a7ae Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 15 Jul 2020 16:26:18 +0000
Subject: [PATCH 3/7] update grad clip with link

---
 deepspeed/pt/deepspeed_light.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index e236d955a3e9..7ae28008922d 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -835,6 +835,7 @@ def step(self):
                 if not self.fp16_enabled():
                     self.clip_fp32_gradients()
                 elif self.amp_enabled():
+                    # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                     master_params = amp.master_params(self.optimizer)
                     torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                                    max_norm=self.gradient_clipping())

From ec11c0f9a054f319789533be1b8bc40467ce5383 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 15 Jul 2020 16:41:52 +0000
Subject: [PATCH 4/7] fix syntax error

---
 deepspeed/pt/deepspeed_light.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 7ae28008922d..c8336334221a 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -831,7 +831,7 @@ def step(self):
 
         if self.is_gradient_accumulation_boundary():
 
-            if self.gradient_clipping > 0.0:
+            if self.gradient_clipping() > 0.0:
                 if not self.fp16_enabled():
                     self.clip_fp32_gradients()
                 elif self.amp_enabled():

From b39b1e37402649fd639382521692d4743ff05597 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 15 Jul 2020 16:53:49 +0000
Subject: [PATCH 5/7] comment

---
 deepspeed/pt/deepspeed_light.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index c8336334221a..846b1287b901 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -835,6 +835,7 @@ def step(self):
                 if not self.fp16_enabled():
                     self.clip_fp32_gradients()
                 elif self.amp_enabled():
+                    # AMP's recommended way of doing clipping
                     # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                     master_params = amp.master_params(self.optimizer)
                     torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                                    max_norm=self.gradient_clipping())

From 1c48ec222ac49c4378e7043e0fa1ff5d3fa91460 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 16 Jul 2020 21:43:44 +0000
Subject: [PATCH 6/7] zero grad using optimizer if in amp mode

---
 deepspeed/pt/deepspeed_light.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 678eff3a00bf..e14b14094191 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -845,7 +845,8 @@ def step(self):
 
             #zero grad in basic optimizer could be unreliable and may not exhibit
             #the behaviour that we want
-            if not self.zero_optimization() and not self.fp16_enabled():
+            if not self.zero_optimization() and not self.fp16_enabled(
+            ) and not self.amp_enabled():
                 self.zero_grad()
             else:
                 self.optimizer.zero_grad()

From e313de8c10522649cd7b4e91a3367bff220cd899 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 22 Jul 2020 09:10:24 -0700
Subject: [PATCH 7/7] fix bug

---
 deepspeed/pt/deepspeed_light.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index ff9a4f36b169..522ec72f99f5 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -832,7 +832,7 @@ def step(self):
         if self.is_gradient_accumulation_boundary():
 
             if self.gradient_clipping() > 0.0:
-                if not self.fp16_enabled():
+                if not self.fp16_enabled() and not self.amp_enabled():
                     self.clip_fp32_gradients()
                 elif self.amp_enabled():
                     # AMP's recommended way of doing clipping
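
Taken together, these patches make the engine follow apex AMP's documented pattern for gradient accumulation and clipping. Below is a minimal standalone sketch of that pattern outside DeepSpeed, illustrative only and not part of the patches: the model, optimizer, synthetic data, accumulation_steps, and max_grad_norm are assumed placeholders standing in for the engine's state and its gradient_clipping() config value.

# Illustrative sketch (not part of the patches): apex AMP with gradient
# accumulation and clipping, mirroring the backward()/step() changes above.
# Assumes apex is installed and a CUDA device is available.
import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

accumulation_steps = 4  # stands in for gradient_accumulation_steps()
max_grad_norm = 1.0     # stands in for gradient_clipping()
data = [(torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda()) for _ in range(16)]

for step, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps
    boundary = (step + 1) % accumulation_steps == 0

    # Patch 1: delay unscaling while still inside the accumulation window.
    with amp.scale_loss(loss, optimizer,
                        delay_unscale=not boundary) as scaled_loss:
        scaled_loss.backward()

    if boundary:
        # Patches 1-5, 7: clip the FP32 master gradients, apex's recommended way.
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                       max_norm=max_grad_norm)
        optimizer.step()
        # Patch 6: zero gradients through the optimizer when AMP is enabled.
        optimizer.zero_grad()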