Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Add CUDA graph tests with FP8 weight caching #869

Merged
merged 6 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions tests/pytorch/test_cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def __init__(self, hidden_size, nheads, kv, seq_len):

modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]

optimizers = [torch.optim.SGD, torch.optim.Adam]

all_boolean = [True, False]

dtypes = [torch.float32, torch.float16]
Expand Down Expand Up @@ -123,7 +121,6 @@ def _test_cuda_graphs(
fp8_params: bool,
fp8_weight_caching: bool,
module: str,
optimizer: torch.optim.Optimizer,
graph_mode: str,
) -> List[torch.Tensor]:
"""Helper function for test."""
Expand Down Expand Up @@ -170,6 +167,11 @@ def _test_cuda_graphs(
config.h, config.h, device="cuda", params_dtype=dtype
) for _ in range(num_layers)]

# Initialize gradient buffers.
for module in modules:
for param in module.parameters():
param.grad = torch.empty_like(param)

# Generate model and wrap API to return graphed version.
if graph_mode == "full":
# Graph entire model at once.
Expand Down Expand Up @@ -199,7 +201,7 @@ def _test_cuda_graphs(

# Loss function and optimizer.
if not dpa:
optimizer = optimizer(model.parameters(), lr=0.001)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In dropping the adam optimizer from the test parameterization, are we losing any test coverage w.r.t. cuda graphs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test doesn't really test convergence, but just checks that results match exactly with and without CUDA graphs. SGD and Adam are implemented similarly, so we just need to test one to make sure that the CUDA graph infrastructure is working correctly.


# Launch.
for train_step in range(3):
timmoon10 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -212,7 +214,7 @@ def _test_cuda_graphs(
if fp8_weight_caching:
kwargs["is_first_microbatch"] = (grad_accumulation_step == 0)
output = model(*inputs, **kwargs)
(output * grad_output).sum().backward()
output.backward(grad_output)
if not dpa:
optimizer.step()

Expand All @@ -222,12 +224,11 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 10])
@pytest.mark.parametrize("num_layers", [1, 3])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
@pytest.mark.parametrize("module", modules)
@pytest.mark.parametrize("optimizer", optimizers)
def test_gpt_make_graphed_callables(
dtype: torch.dtype,
bs: int,
Expand All @@ -237,7 +238,6 @@ def test_gpt_make_graphed_callables(
fp8_params: bool,
fp8_weight_caching: bool,
module: str,
optimizer: torch.optim.Optimizer,
) -> None:
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
Expand All @@ -259,7 +259,6 @@ def test_gpt_make_graphed_callables(
fp8_params=fp8_params,
fp8_weight_caching=fp8_weight_caching,
module=module,
optimizer=optimizer,
)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
Expand Down
5 changes: 0 additions & 5 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,6 @@ def forward_func(*args, **kwargs):
else:
torch.cuda.set_rng_state(original_rng_states)

# Reset FP8 gradients.
for module in modules:
for p in module.parameters():
p.grad = None

# Restore FP8 state.
restore_fp8_tensors(modules, saved_fp8_tensors)

Expand Down
Loading