[FSDPv1] Optimize memory usage for optimize_backward_concat=True #1186

Conversation

@chrisxcai commented Jun 10, 2024

Avoid the extra memory usage caused by torch.cat(): allocate the flattened fp32 grad tensor directly and accumulate each individual parameter's fp32 gradient into its specific slice of the flattened tensor.

Local test

Deterministic numerical test

baseline, optimize_backward_concat=False

NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=8192 --gpu_check_level=-1 --steps=5 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 5 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=False --mem_snapshot_max_entries=100000 --model.use_fp8=True

https://www.internalfb.com/intern/paste/P1404601998/

AVG loss: 10.8708400726318359

optimize_backward_concat=True

https://www.internalfb.com/intern/paste/P1404700768/

AVG loss: 10.8708400726318359 (matches baseline)

Memory usage

baseline, optimize_backward_concat=False

NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 3 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=False --mem_snapshot_max_entries=500000 --model.use_fp8=False

https://www.internalfb.com/intern/paste/P1404611094/

torch_cuda_max_reserved: 15.1GB

optimize_backward_concat=True, before optimization

NVTE_TORCH_COMPILE=0 NVTE_DISABLE_NVRTC=1 CRYPTOGRAPHY_OPENSSL_NO_LEGACY=1 PYTHONPATH=~/benchmark/fairscale_repos/fairscale/ torchrun --master_port 1024 --nproc_per_node=8 train.py --dump_dir /tmp/chriscai/xldumps --model_parallel_size 4 --seq_len=1024 --gpu_check_level=-1 --steps=10 --log_all_steps=True --profile_freq=10 --dump_profile_traces=True --model.n_layers=1 --reshard_after_forward=False --batch_size=128 --model.efficient_attn=flash --model.attn_bias_type=causal --model.layer_ckpt=none --model=llama3_kv8 --model.sequence_parallel=True --mem_snapshot_stop_step 3 --log_all_steps=True --log_freq=1 --model.use_te_layers=True --optim.use_fp32_copy_optim=True --use_microbatching=True --optimize_backward_concat=True --mem_snapshot_max_entries=500000 --model.use_fp8=False

https://www.internalfb.com/intern/paste/P1404620340/

torch_cuda_max_reserved: 17.4GB

optimize_backward_concat=True, after optimization

https://www.internalfb.com/intern/paste/P1404655599/

torch_cuda_max_reserved: 15.1GB (-13.2%)

E2E MAST

model = llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP = 4, CP = 1, seq_len = 1024

baseline, optimize_backward_concat=False
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-c52vf7


**tflops/s = ~382**

trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.1149070831916.json.gz&bucket=acadia

optimize_backward_concat=True, before optimization

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-pdtcx1d5

trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.24449323379469.json.gz&bucket=acadia

optimize_backward_concat=True, after optimization

https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-ghg1f57z

**tflops/s = ~440 (+15%)**

trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.17125783820625.json.gz&bucket=acadia

@facebook-github-bot added the CLA Signed label Jun 10, 2024
@chrisxcai requested a review from awgu June 10, 2024 00:31
@chrisxcai marked this pull request as ready for review June 10, 2024 00:36
@yuchenhao left a comment

Looks like both the trainer and ODS think there is no extra memory usage anymore. Thanks for root-causing and fixing the issue!

@@ -174,7 +174,7 @@ def __init__(
         self._require_backward_grad_sync = True
         # If optimize_backward_concat == True, used to accumulate the
         # fp32 gradients for the flattened parameters
-        self.fp32_grads = []
+        self.fp32_grads = None


How about renaming this to something like fp32_flat_grad to indicate the type and shape after the change?

@@ -1765,11 +1765,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

        if self.fp32_reduce_scatter:
            if self.optimize_backward_concat:
                # Flatten and concat the accumulated fp32 grads
                # and assign them to param.unsharded_main_grad
                param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
@awgu commented:

Was the issue that upon this torch.cat call, we have 2x fp32 unsharded gradient memory? (This would be a temporary spike since the source individual fp32 gradients are freed immediately upon L1772?)

@chrisxcai (author) replied:

Exactly, @awgu. Even though the memory is freed right afterwards, the torch.cat still triggers an extra allocation through the CUDA caching allocator, which increases the peak GPU memory (reflected in torch_cuda_max_reserved).

@chrisxcai merged commit b73fffe into ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard Jun 10, 2024
1 of 18 checks passed