-
Notifications
You must be signed in to change notification settings - Fork 285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDPv1] Optimize memory usage for optimize_backward_concat=True #1186
Merged
chrisxcai
merged 6 commits into
ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard
from
chriscai_ngoyal_changes_for_pp_fp8_jiecaoyu_free_fp16_shard_1
Jun 10, 2024
+23
−19
Merged
Changes from 4 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -174,7 +174,7 @@ def __init__( | |
self._require_backward_grad_sync = True | ||
# If optimize_backward_concat == True, used to accumulate the | ||
# fp32 gradients for the flattened parameters | ||
self.fp32_grads = [] | ||
self.fp32_grads = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about renaming this to something like |
||
|
||
# Handle param_list being None. | ||
if param_list is None: | ||
|
@@ -382,12 +382,16 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = No | |
def _grad_accumulation_hook( | ||
self, | ||
grad, | ||
param_index, | ||
start, | ||
end, | ||
): | ||
if self.fp32_grads[param_index] is None: | ||
self.fp32_grads[param_index] = grad.to(torch.float32) | ||
else: | ||
self.fp32_grads[param_index].add_(grad) | ||
""" | ||
start: int, the starting index(inclusive) of the grad of this parameter in self.fp32_grads | ||
end: int, the ending index(exclusive) of the grad of this parameter in self.fp32_grads | ||
""" | ||
|
||
assert self.fp32_grads is not None | ||
self.fp32_grads[start:end].add_(grad.flatten()) | ||
return grad | ||
|
||
def _unflatten_params_as_views(self) -> None: | ||
|
@@ -411,26 +415,29 @@ def _unflatten_params_as_views(self) -> None: | |
ps = self.get_param_views() | ||
|
||
param_views = [] | ||
param_start = 0 | ||
for (_, m, n), p in zip(self._param_infos, ps): | ||
setattr(p, '_fsdp_weight', True) | ||
setattr(m, n, p) # This will set as plain attr | ||
if self.optimize_backward_concat: | ||
# The param_index of parameter p used to accumulate the correspnding | ||
# gradients in self.fp32_grads | ||
param_index = len(param_views) | ||
# Register post backward hook to accumulate the gradients | ||
# in self.fp32_grads | ||
param_end = param_start + torch.numel(p) | ||
p.register_hook( | ||
functools.partial( | ||
self._grad_accumulation_hook, | ||
param_index=param_index | ||
start=param_start, | ||
end=param_end, | ||
) | ||
) | ||
param_start = param_end | ||
param_views.append(p) | ||
|
||
if self.optimize_backward_concat and len(self.fp32_grads) == 0: | ||
# Allocate self.fp32_grads at the beginning of each data batch's forward() | ||
self.fp32_grads = [None] * len(param_views) | ||
if self.optimize_backward_concat and self.fp32_grads is None: | ||
# Allocate GPU memory for flattened fp32 grad accumulation | ||
total_numels = sum([torch.numel(p) for p in param_views]) | ||
self.fp32_grads = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device()) | ||
|
||
|
||
# Save param views for easy access if anyone still wants to access | ||
# parameters of the module. | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was the issue that upon this
torch.cat
call, we have 2x fp32 unsharded gradient memory? (This would be a temporary spike since the source individual fp32 gradients are freed immediately upon L1772?)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exactly @awgu, even though the memory is freed right afterwards, it still triggers the memory allocation via CUDA Caching Allocator, which will increase the peak GPU memory(reflected via torch_cuda_max_reserved)