[PyTorch] Add CUDA graph tests with FP8 weight caching #869
Conversation
Signed-off-by: Tim Moon <[email protected]>
Note that users must also initialize grads before calling `make_graphed_callables`:

```python
import torch

torch.set_default_device('cuda')

# Construct linear module
model = torch.nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1)
# model.weight.grad = torch.empty_like(model.weight)  # Uncomment to fix bug

# Capture CUDA graph
x = torch.ones((1, 1), requires_grad=True)
model = torch.cuda.make_graphed_callables(model, (x,))

# Training steps
for step in range(3):
    if model.weight.grad is not None:
        model.weight.grad.zero_()
    x = torch.ones((1, 1), requires_grad=True)
    y = model(x)
    y.backward(torch.ones((1, 1)))
    print(f"{step=}, {model.weight.grad.item()=}")
```

I expect the weight gradient to always be 1. However:
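For completeness, the workaround hinted at by the commented-out line is to materialize the grad buffers for all parameters before graph capture. A minimal sketch of that pattern on the same toy model (the `zeros_like` initialization is my choice here, not prescribed by the report):

```python
import torch

torch.set_default_device('cuda')

model = torch.nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1)

# Workaround: create every grad buffer up front so the captured backward pass
# accumulates into the same tensors that are read after training steps.
for param in model.parameters():
    param.grad = torch.zeros_like(param)

x = torch.ones((1, 1), requires_grad=True)
model = torch.cuda.make_graphed_callables(model, (x,))
```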
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
I left a question about coverage, otherwise LGTM.
```diff
@@ -199,7 +201,7 @@ def _test_cuda_graphs(

     # Loss function and optimizer.
     if not dpa:
-        optimizer = optimizer(model.parameters(), lr=0.001)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
```
In dropping the Adam optimizer from the test parameterization, are we losing any test coverage w.r.t. CUDA graphs?
This test doesn't really test convergence, but just checks that results match exactly with and without CUDA graphs. SGD and Adam are implemented similarly, so we just need to test one to make sure that the CUDA graph infrastructure is working correctly.
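To illustrate the strategy, the structure of such an exact-match check can be sketched roughly as follows. This is not the actual test code in this repo; the module, shapes, step count, and helper names are illustrative:

```python
import copy
import torch

def train(model, with_graph, steps=3):
    """Run a few SGD steps, optionally through torch.cuda.make_graphed_callables."""
    model = copy.deepcopy(model).cuda()
    # Materialize grad buffers up front so graph capture and the optimizer
    # share the same tensors (see the grad-initialization issue above).
    for param in model.parameters():
        param.grad = torch.zeros_like(param)
    if with_graph:
        sample = torch.zeros((4, 16), device="cuda", requires_grad=True)
        model = torch.cuda.make_graphed_callables(model, (sample,))
    optim = torch.optim.SGD(model.parameters(), lr=0.001)
    torch.manual_seed(0)  # identical input sequence for both runs
    for _ in range(steps):
        x = torch.randn((4, 16), device="cuda", requires_grad=True)
        model(x).sum().backward()
        optim.step()
        optim.zero_grad(set_to_none=False)  # zero in place, keep buffers alive
    return [p.detach().clone() for p in model.parameters()]

reference = torch.nn.Linear(16, 16)
params_eager = train(reference, with_graph=False)
params_graph = train(reference, with_graph=True)
for eager, graphed in zip(params_eager, params_graph):
    torch.testing.assert_close(eager, graphed, rtol=0, atol=0)  # expect exact match
```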
Description
Our CUDA graph infrastructure (#575) supports FP8 weight caching when training with multiple gradient accumulation steps. While adding tests for this functionality (see #820 (comment)), I ran into subtle correctness issues because `te.make_graphed_callables` resets grad buffers after capturing graphs, so the gradient buffer filled in the backward pass is different from the gradient buffer used by the optimizer. Note that we didn't detect this before because Megatron-LM and NeMo explicitly manage gradient buffers, e.g. with a distributed optimizer.

This PR modifies the CUDA graph tests to initialize grads before `make_graphed_callables`, and it avoids resetting grads within `make_graphed_callables`.
.Type of change
Changes
Please list the changes introduced in this PR:
Checklist: