Fix the contiguous_param_buffer bug about bprop overlap and redundant copy after all-gather.
alpha0422 committed Apr 25, 2024
1 parent c562d3d commit 394f401
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions nemo/core/optim/distributed_adam.py
@@ -371,11 +371,13 @@ def init_param_buffer(self) -> None:
                 param._data = buffer_view.view(param.size())
             else:
                 # Preserve memory format for param here, i.e. NHWC tensors
-                param.data.set_(
-                    source=buffer_view,
-                    storage_offset=0,
-                    size=param.size(),
-                    stride=param.stride(),
+                # `param.data.set_()` failed to change storage.
+                # `param.set_()` invalidates bprop hook.
+                param.data = torch.as_strided(
+                    buffer_view,
+                    param.size(),
+                    param.stride(),
+                    storage_offset=buffer_view.storage_offset(),
                 )

     def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None:
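In the old code, `param.data.set_()` did not actually rebind the parameter to the buffer's storage (per the added comment, it "failed to change storage"), so the gathered values still had to be copied into the parameter after the all-gather, while calling `param.set_()` directly would invalidate the bprop hook used to overlap gradient communication with the backward pass. The fix instead assigns a strided alias of the buffer slice to `param.data`, which keeps the Parameter object, and therefore its hooks, intact. Below is a minimal, self-contained sketch of this aliasing pattern, using a made-up 10-element buffer, a 2x3 parameter, and a placeholder hook (none of which are NeMo's actual objects):

    import torch

    # Toy stand-ins (not NeMo's real layout): a 10-element contiguous buffer
    # and one 2x3 parameter whose data should live at storage offset 4.
    buffer = torch.zeros(10)
    buffer_view = buffer[4:10]                     # per-param slice of the buffer
    param = torch.nn.Parameter(torch.randn(2, 3))
    handle = param.register_hook(lambda g: g)      # placeholder for the bprop-overlap hook

    # Rebind the parameter's data to a strided alias of the buffer slice.
    # Assigning `.data` keeps the Parameter object (and its registered hooks)
    # intact, unlike replacing the parameter tensor itself.
    param.data = torch.as_strided(
        buffer_view,
        param.size(),
        param.stride(),
        storage_offset=buffer_view.storage_offset(),
    )

    # Writes into the buffer (e.g. the all-gather output) are now visible
    # through the parameter directly, with no copy back into `param`.
    buffer.fill_(1.0)
    assert bool((param == 1.0).all())

Note that `torch.as_strided` interprets `storage_offset` relative to the underlying storage rather than to `buffer_view` itself, which is why the commit passes `buffer_view.storage_offset()` in place of the old hard-coded 0.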
