[PyTorch] Add CUDA graph tests with FP8 weight caching #869

Merged: 6 commits merged into NVIDIA:main on Jun 3, 2024

Conversation

timmoon10
Collaborator

Description

Our CUDA graph infrastructure (#575) supports FP8 weight caching when training with multiple gradient accumulation steps. While adding tests for this functionality (see #820 (comment)), I ran into subtle correctness issues because te.make_graphed_callables resets grad buffers after capturing graphs, so the gradient buffer filled in the backward pass is different from the gradient buffer used by the optimizer. Note that we didn't detect this before because Megatron-LM and NeMo explicitly manage gradient buffers, e.g. with a distributed optimizer.

This PR modifies the CUDA graph tests to initialize grads before calling make_graphed_callables, and it changes make_graphed_callables so that it no longer resets grads after graph capture.
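For reference, the intended usage pattern now looks roughly like this (a minimal sketch, not the test code: the module size, sample input, and omission of FP8-specific arguments are illustrative assumptions):

import torch
import transformer_engine.pytorch as te

torch.set_default_device('cuda')

# Allocate the gradient buffers *before* graph capture, so the buffers
# captured by the graph are the same ones the optimizer later reads and zeros.
model = te.Linear(128, 128)
for param in model.parameters():
    param.grad = torch.zeros_like(param)

# Capture CUDA graphs; with this change, grads are no longer reset afterwards.
x = torch.randn((32, 128), requires_grad=True)
model = te.make_graphed_callables(model, (x,))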

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

Changes

Please list the changes introduced in this PR:

  • Add CUDA graph tests for FP8 weight caching
  • Do not reset parameter gradients after graph capture

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 requested a review from ksivaman May 25, 2024 01:15
@timmoon10 timmoon10 added the bug label May 25, 2024
@timmoon10
Collaborator Author

timmoon10 commented May 25, 2024

Note that users must also initialize grads before calling torch.cuda.make_graphed_callables, or else they'll run into similar correctness issues. Consider the following example:

import torch
torch.set_default_device('cuda')

# Construct linear module
model = torch.nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1)
    # model.weight.grad = torch.empty_like(model.weight)  # Uncomment to fix bug

# Capture CUDA graph
x = torch.ones((1, 1), requires_grad=True)
model = torch.cuda.make_graphed_callables(model, (x,))

# Training steps
for step in range(3):
    if model.weight.grad is not None:
        model.weight.grad.zero_()
    x = torch.ones((1, 1), requires_grad=True)
    y = model(x)
    y.backward(torch.ones((1, 1)))
    print(f"{step=}, {model.weight.grad.item()=}")

I expect the weight gradient to always be 1. However:

step=0, model.weight.grad.item()=1.0
step=1, model.weight.grad.item()=2.0
step=2, model.weight.grad.item()=2.0

@timmoon10
Collaborator Author

/te-ci pytorch


@nvcforster nvcforster left a comment


I left a question about coverage, otherwise LGTM.

tests/pytorch/test_cuda_graphs.py
@@ -199,7 +201,7 @@ def _test_cuda_graphs(

     # Loss function and optimizer.
     if not dpa:
-        optimizer = optimizer(model.parameters(), lr=0.001)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)


In dropping the Adam optimizer from the test parameterization, are we losing any test coverage w.r.t. CUDA graphs?

Collaborator Author


This test doesn't really test convergence, but just checks that results match exactly with and without CUDA graphs. SGD and Adam are implemented similarly, so we just need to test one to make sure that the CUDA graph infrastructure is working correctly.
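
For context, the pass criterion is conceptually along these lines (a hypothetical sketch, not the actual test code; names are illustrative):

import torch

def assert_graphed_matches_eager(graphed_model, eager_model):
    # Require bitwise-identical parameters after identical training steps
    # with and without CUDA graphs; this is not a convergence test, so a
    # single optimizer (SGD) is enough to exercise the graph machinery.
    for p_graphed, p_eager in zip(graphed_model.parameters(),
                                  eager_model.parameters()):
        assert torch.equal(p_graphed, p_eager)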

@timmoon10 timmoon10 merged commit 868c7d3 into NVIDIA:main Jun 3, 2024
9 checks passed
@timmoon10 timmoon10 deleted the debug-cuda-graph-microbatching branch June 3, 2024 18:36