[FSDPv1] Optimize memory usage for optimize_backward_concat=True #1186

Merged
6 changes: 2 additions & 4 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1765,11 +1765,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

         if self.fp32_reduce_scatter:
             if self.optimize_backward_concat:
-                # Flatten and concat the accumulated fp32 grads
-                # and assign them to param.unsharded_main_grad
-                param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])

@awgu commented:
Was the issue that upon this torch.cat call, we have 2x fp32 unsharded gradient memory? (This would be a temporary spike since the source individual fp32 gradients are freed immediately upon L1772?)

Author replied:
Exactly, @awgu. Even though the memory is freed right afterwards, the torch.cat still triggers an extra allocation through the CUDA caching allocator, which increases the peak GPU memory (reflected in torch_cuda_max_reserved). A minimal sketch of the effect follows this diff.

+                param.unsharded_main_grad = self._fsdp_wrapped_module.fp32_grads
                 # Clean up accumulated grads between data batches
-                self._fsdp_wrapped_module.fp32_grads = []
+                self._fsdp_wrapped_module.fp32_grads = None
             else:
                 if getattr(param, "unsharded_main_grad", None) is None:
                     param.unsharded_main_grad = param.grad.to(torch.float32)
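
To make the exchange above concrete, here is a minimal sketch (not fairscale code; the helper names, tensor shapes, and device handling are illustrative) of the allocator behavior the author describes: torch.cat materializes a second full-size fp32 tensor while the per-parameter fp32 grads are still alive, whereas accumulating into a preallocated flat buffer never holds more than one extra parameter's worth of gradient at a time.

```python
import math
import torch

# Illustrative sketch only (not fairscale code); requires a CUDA device.
# Compares peak reserved memory of the old torch.cat pattern vs. accumulating
# into a preallocated flat fp32 buffer.

def peak_with_cat(shapes, device="cuda"):
    torch.cuda.reset_peak_memory_stats(device)
    # Per-parameter fp32 grads, as accumulated before this PR.
    grads = [torch.randn(s, dtype=torch.float32, device=device) for s in shapes]
    # torch.cat allocates a second full-size fp32 tensor while the per-parameter
    # grads are still alive -> roughly 2x fp32 gradient memory at this point.
    flat = torch.cat([g.flatten() for g in grads])
    grads = []  # freed only after the cat, so the spike has already happened
    return torch.cuda.max_memory_reserved(device)

def peak_with_flat_buffer(shapes, device="cuda"):
    torch.cuda.reset_peak_memory_stats(device)
    # Preallocated once, analogous to self.fp32_grads after this PR.
    flat = torch.zeros(sum(math.prod(s) for s in shapes), dtype=torch.float32, device=device)
    offset = 0
    for s in shapes:
        g = torch.randn(s, dtype=torch.float32, device=device)  # stand-in for one param's grad
        flat[offset:offset + g.numel()].add_(g.flatten())  # in-place, no full-size copy
        offset += g.numel()
    return torch.cuda.max_memory_reserved(device)

if __name__ == "__main__":
    shapes = [(4096, 4096)] * 8
    print("peak with torch.cat:        ", peak_with_cat(shapes))
    print("peak with preallocated flat:", peak_with_flat_buffer(shapes))
```

On a GPU the first number should come out roughly twice the second, which is the torch_cuda_max_reserved increase the author refers to.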
33 changes: 20 additions & 13 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -174,7 +174,7 @@ def __init__(
         self._require_backward_grad_sync = True
         # If optimize_backward_concat == True, used to accumulate the
         # fp32 gradients for the flattened parameters
-        self.fp32_grads = []
+        self.fp32_grads = None

A reviewer commented:
How about renaming this to something like fp32_flat_grad to indicate the type and shape after the change?


         # Handle param_list being None.
         if param_list is None:
@@ -382,12 +382,16 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = No
     def _grad_accumulation_hook(
         self,
         grad,
-        param_index,
+        start,
+        end,
     ):
-        if self.fp32_grads[param_index] is None:
-            self.fp32_grads[param_index] = grad.to(torch.float32)
-        else:
-            self.fp32_grads[param_index].add_(grad)
+        """
+        start: int, the starting index(inclusive) of the grad of this parameter in self.fp32_grads
+        end: int, the ending index(exclusive) of the grad of this parameter in self.fp32_grads
+        """
+
+        assert self.fp32_grads is not None
+        self.fp32_grads[start:end].add_(grad.flatten())
         return grad

     def _unflatten_params_as_views(self) -> None:
@@ -411,26 +415,29 @@ def _unflatten_params_as_views(self) -> None:
         ps = self.get_param_views()

         param_views = []
+        param_start = 0
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(p, '_fsdp_weight', True)
             setattr(m, n, p) # This will set as plain attr
             if self.optimize_backward_concat:
-                # The param_index of parameter p used to accumulate the correspnding
-                # gradients in self.fp32_grads
-                param_index = len(param_views)
                 # Register post backward hook to accumulate the gradients
                 # in self.fp32_grads
+                param_end = param_start + torch.numel(p)
                 p.register_hook(
                     functools.partial(
                         self._grad_accumulation_hook,
-                        param_index=param_index
+                        start=param_start,
+                        end=param_end,
                     )
                 )
+                param_start = param_end
             param_views.append(p)

-        if self.optimize_backward_concat and len(self.fp32_grads) == 0:
-            # Allocate self.fp32_grads at the beginning of each data batch's forward()
-            self.fp32_grads = [None] * len(param_views)
+        if self.optimize_backward_concat and self.fp32_grads is None:
+            # Allocate GPU memory for flattened fp32 grad accumulation
+            total_numels = sum([torch.numel(p) for p in param_views])
+            self.fp32_grads = torch.zeros(total_numels, dtype=torch.float32, device=torch.cuda.current_device())


         # Save param views for easy access if anyone still wants to access
         # parameters of the module.
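
For readers less familiar with this hook pattern, the following self-contained sketch (the toy module and helper name are illustrative, not fairscale API) shows what the diff sets up: one flat fp32 buffer sized to all parameters, plus one post-backward hook per parameter, bound to its start/end offsets via functools.partial, that accumulates the flattened grad into its slice in place.

```python
import functools
import torch
import torch.nn as nn

# Self-contained sketch of the hook pattern in the diff above; the toy module
# and the helper name are illustrative, not fairscale API.

def attach_flat_fp32_grad_accumulation(module: nn.Module) -> torch.Tensor:
    params = [p for p in module.parameters() if p.requires_grad]
    total_numel = sum(p.numel() for p in params)
    # Preallocated once, analogous to self.fp32_grads after this change.
    fp32_grads = torch.zeros(total_numel, dtype=torch.float32, device=params[0].device)

    def hook(grad, start, end):
        # In-place accumulation into the [start, end) slice; no per-parameter
        # fp32 tensors and no torch.cat at the end of backward.
        fp32_grads[start:end].add_(grad.flatten().to(torch.float32))
        return grad

    start = 0
    for p in params:
        end = start + p.numel()
        p.register_hook(functools.partial(hook, start=start, end=end))
        start = end
    return fp32_grads

model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
flat_fp32 = attach_flat_fp32_grad_accumulation(model)
model(torch.randn(3, 8)).sum().backward()
print(flat_fp32.shape)  # a single flat fp32 tensor holding all accumulated grads
```

Because each hook writes only its own disjoint slice, gradient accumulation across microbatches keeps add_-ing into the same buffer, and the post-backward hook in fully_sharded_data_parallel.py can assign the whole tensor to param.unsharded_main_grad without any concatenation.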