
[zero] revert PR #3166, it disabled grad clip for bf16 #3790

Merged: 29 commits merged into master from revert-3166 on Jul 3, 2023

Commits (29)
df1859d
zero++ tutorial PR (#3783)
HeyangQin Jun 21, 2023
d81a6ad
[Fix] _conv_flops_compute when padding is a str and stride=1 (#3169)
zhiruiluo Jun 21, 2023
a8c182a
fix interpolate flops compute (#3782)
cli99 Jun 22, 2023
c4c442f
use `Flops Profiler` to test `model.generate()` (#2515)
CaffreyR Jun 22, 2023
9bd7b24
revert PR #3166, it disabled grad clip for bf16
jeffra Jun 22, 2023
6075a29
ensure no loss scaling for non-fp16 dtypes
jeffra Jun 22, 2023
fc9e1ee
revert PR #3611 (#3786)
jeffra Jun 22, 2023
40045dc
bump to 0.9.6
jeffra Jun 22, 2023
710a59c
Merge branch 'master' into revert-3166
jeffra Jun 22, 2023
49a0a1b
ZeRO++ chinese blog (#3793)
HeyangQin Jun 23, 2023
2c62cb4
remove staging trigger (#3792)
jeffra Jun 23, 2023
4dc65f7
DeepSpeed-Triton for Inference (#3748)
stephen-youn Jun 23, 2023
e1119d8
ZeRO++ (#3784)
HeyangQin Jun 23, 2023
01b843a
adding zero++ to navigation panel of deepspeed.ai (#3796)
HeyangQin Jun 23, 2023
319b64e
Add ZeRO++ Japanese blog (#3797)
tohtana Jun 23, 2023
b4a2c0a
Bug Fixes for autotuner and flops profiler (#1880)
cli99 Jun 23, 2023
b7e1010
Missing strided copy for gated MLP (#3788)
cmikeh2 Jun 23, 2023
e5b1ead
Requires grad checking. (#3789)
jomayeri Jun 23, 2023
9c756cf
bump to 0.10.0
jeffra Jun 23, 2023
a204edc
Fix Bug in transform.cu (#3534)
rraminen Jun 23, 2023
f6e2e38
bug fix: triton importing error (#3799)
stephen-youn Jun 23, 2023
5c8bae0
Merge branch 'master' into revert-3166
jeffra Jun 23, 2023
928dc2c
Merge branch 'master' into revert-3166
jeffra Jun 23, 2023
c290d4c
Merge branch 'master' into revert-3166
tjruwase Jun 26, 2023
25e083a
Merge branch 'master' into revert-3166
loadams Jun 26, 2023
cafd818
Merge branch 'master' into revert-3166
tjruwase Jun 30, 2023
f3c44cc
Merge branch 'master' into revert-3166
tjruwase Jun 30, 2023
4854b5c
Merge branch 'master' into revert-3166
tjruwase Jul 3, 2023
a8ffc37
Merge branch 'master' into revert-3166
tjruwase Jul 3, 2023
22 changes: 12 additions & 10 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -480,6 +480,11 @@ def __init__(self,
                                             dynamic_loss_args=dynamic_loss_args)
         self.dynamic_loss_scale = self.loss_scaler.dynamic
 
+        if self.dtype != torch.float16:
+            # Only fp16 should use dynamic loss scaling
+            assert self.loss_scaler.cur_scale == 1.0
+            assert not self.dynamic_loss_scale
+
         see_memory_usage("Before initializing optimizer states", force=True)
         self.initialize_optimizer_states()
         see_memory_usage("After initializing optimizer states", force=True)
@@ -1669,21 +1674,19 @@ def step(self, closure=None):
             self.stop_timers(timer_names)
             return
 
-        # Step 1:- Calculate gradient norm using fp-16 grads
-        if self.dtype == torch.float16:
-            see_memory_usage('Before norm calculation')
-            scaled_global_grad_norm = self.scaled_global_norm()
-            self._global_grad_norm = scaled_global_grad_norm / prev_scale
-            see_memory_usage('After norm before optimizer')
+        # Step 1:- Calculate gradient norm using bit-16 grads
+        see_memory_usage('Before norm calculation')
+        scaled_global_grad_norm = self.scaled_global_norm()
+        self._global_grad_norm = scaled_global_grad_norm / prev_scale
+        see_memory_usage('After norm before optimizer')
 
         # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
             self.start_timers([OPTIMIZER_GRADIENTS])
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if self.cpu_offload:
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-                if self.dtype == torch.float16:
-                    self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
 
                 self.stop_timers([OPTIMIZER_GRADIENTS])
                 self.start_timers([OPTIMIZER_STEP])
@@ -1723,8 +1726,7 @@ def step(self, closure=None):
 
                 self.averaged_gradients[i] = None
 
-                if self.dtype == torch.float16:
-                    self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
 
                 self.stop_timers([OPTIMIZER_GRADIENTS])
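With the `if self.dtype == torch.float16` guards removed, the norm computation and `unscale_and_clip_grads` run for every dtype again, which restores gradient clipping for bf16. Since the loss scale is pinned at 1.0 for non-fp16 dtypes (per the new assertion), the "unscale" part is a no-op there and the call reduces to pure global-norm clipping. The sketch below illustrates that unscale-then-clip pattern; it is not DeepSpeed's exact implementation, and the function signature is simplified for illustration.

```python
import torch


def unscale_and_clip_grads(grad_partitions, scaled_global_norm, loss_scale=1.0, clip_grad=1.0):
    """Sketch of unscale-then-clip by global norm.

    combined_scale folds the loss scale and the clip coefficient into a single
    divisor, so one multiply both unscales and clips each gradient partition.
    """
    combined_scale = loss_scale
    if clip_grad > 0:
        # Norm of the *unscaled* gradients, compared against the clip threshold.
        clip = (scaled_global_norm / loss_scale) / clip_grad
        if clip > 1.0:
            combined_scale = clip * loss_scale
    for grad in grad_partitions:
        grad.mul_(1.0 / combined_scale)


# For bf16, loss_scale is 1.0, so this is plain global-norm clipping.
grads = [torch.full((4,), 3.0), torch.full((4,), 4.0)]
total_norm = torch.norm(torch.stack([g.norm() for g in grads]))  # sqrt(36 + 64) = 10
unscale_and_clip_grads(grads, total_norm, loss_scale=1.0, clip_grad=1.0)
print(grads[0])  # each gradient scaled by 1/10, so the global norm becomes 1.0
```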