[FSDPv1] Optimize memory usage for optimize_backward_concat=True #1186
Conversation
Looks like both the trainer and ods think there is no extra memory usage any more. Thanks for root causing and fixing the issue!
@@ -174,7 +174,7 @@ def __init__(
         self._require_backward_grad_sync = True
         # If optimize_backward_concat == True, used to accumulate the
         # fp32 gradients for the flattened parameters
-        self.fp32_grads = []
+        self.fp32_grads = None
How about renaming this to something like fp32_flat_grad to indicate the type and shape after the change?
@@ -1765,11 +1765,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

            if self.fp32_reduce_scatter:
                if self.optimize_backward_concat:
                    # Flatten and concat the accumulated fp32 grads
                    # and assign them to param.unsharded_main_grad
                    param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
Was the issue that upon this torch.cat call, we have 2x fp32 unsharded gradient memory? (This would be a temporary spike, since the source individual fp32 gradients are freed immediately upon L1772?)
Exactly @awgu. Even though the memory is freed right afterwards, the cat still triggers a memory allocation via the CUDA caching allocator, which increases the peak GPU memory (reflected in torch_cuda_max_reserved).
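A minimal sketch (not from this PR, and assuming a CUDA device is available) of the allocator behavior described above: the torch.cat output has to coexist with the source fp32 gradients for a moment, so peak reserved memory, as reported by torch.cuda.max_memory_reserved, grows by roughly one extra full-size fp32 buffer even though the result is freed right away.

```python
import torch

assert torch.cuda.is_available(), "this sketch measures CUDA caching-allocator behavior"

# Stand-ins for the accumulated per-parameter fp32 gradients (shapes are made up).
fp32_grads = [torch.randn(4096, 4096, dtype=torch.float32, device="cuda") for _ in range(4)]

torch.cuda.reset_peak_memory_stats()
base = torch.cuda.max_memory_reserved()

# The cat output must be allocated while the source grads are still alive,
# so the allocator reserves roughly one extra full-size fp32 buffer at peak.
flat = torch.cat([g.flatten() for g in fp32_grads])
del flat  # freed immediately, but the peak has already been recorded

print("extra bytes reserved during cat:", torch.cuda.max_memory_reserved() - base)
```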
Merged commit b73fffe into ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard
Avoid the extra memory usage caused by concat(): directly allocate the flattened fp32 gradient tensor and accumulate each individual parameter's fp32 gradient into its specific slice within that flattened tensor.
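To make the scheme concrete, here is a small, hypothetical sketch (the class and method names are illustrative, not fairscale's API): a single flattened fp32 buffer is allocated lazily, and each parameter's gradient is accumulated into its own slice, so there are no per-parameter fp32 copies and no torch.cat.

```python
import torch


class FlatFP32GradAccumulator:
    """Accumulates per-parameter grads into slices of one flat fp32 buffer."""

    def __init__(self, params):
        # Record each parameter's (offset, numel) within the flat buffer.
        self.slices = []
        offset = 0
        for p in params:
            self.slices.append((offset, p.numel()))
            offset += p.numel()
        self.total_numel = offset
        self.fp32_flat_grad = None  # allocated lazily, like fp32_grads = None above

    def accumulate(self, param_index, grad):
        if self.fp32_flat_grad is None:
            self.fp32_flat_grad = torch.zeros(
                self.total_numel, dtype=torch.float32, device=grad.device
            )
        offset, numel = self.slices[param_index]
        # Accumulate directly into the slice: no per-parameter fp32 copies
        # and no torch.cat, hence no transient 2x fp32 allocation.
        self.fp32_flat_grad.narrow(0, offset, numel).add_(grad.flatten())


# Usage with dummy parameters and low-precision grads:
params = [torch.nn.Parameter(torch.zeros(16, 16)) for _ in range(3)]
acc = FlatFP32GradAccumulator(params)
for step in range(2):  # gradient accumulation across two microbatches
    for i, p in enumerate(params):
        acc.accumulate(i, torch.randn(p.shape, dtype=torch.bfloat16))
print(acc.fp32_flat_grad.shape)  # torch.Size([768])
```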
Local test

Deterministic numerical test
- baseline, optimize_backward_concat=False: https://www.internalfb.com/intern/paste/P1404601998/
- optimize_backward_concat=True: https://www.internalfb.com/intern/paste/P1404700768/

Memory usage
- baseline, optimize_backward_concat=False: https://www.internalfb.com/intern/paste/P1404611094/
- optimize_backward_concat=True, before optimization: https://www.internalfb.com/intern/paste/P1404620340/
- optimize_backward_concat=True, after optimization: https://www.internalfb.com/intern/paste/P1404655599/
E2E MAST

model = llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP = 4, CP = 1, seq_len = 1024

baseline, optimize_backward_concat=False
- run: https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-c52vf7
- **tflops/s = ~382**
- trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.1149070831916.json.gz&bucket=acadia

optimize_backward_concat=True, before optimization
- run: https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-pdtcx1d5
- trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.24449323379469.json.gz&bucket=acadia

optimize_backward_concat=True, after optimization
- run: https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-ghg1f57z
- **tflops/s = ~440 (+15%)**
- trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.17125783820625.json.gz&bucket=acadia