[PyTorch] cuda graph support (#575)
* FP8 cuda graphs

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>

* Fix numerics

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* exclude torch compile from numerics tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* More numerics fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix CI

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* rm fusion from unfused path

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
3 people authored Apr 12, 2024
1 parent 1b20f2d commit 73f8d90
Showing 35 changed files with 2,196 additions and 789 deletions.
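This commit wires CUDA graph capture and replay into Transformer Engine's PyTorch modules, including the FP8 path. As background for the diffs that follow, here is a minimal sketch of the stock PyTorch capture-and-replay pattern that graphed modules have to fit into; it uses only the upstream torch.cuda.CUDAGraph API and a plain nn.Linear, not any of the Transformer Engine modules touched by this commit.

import torch

# Plain PyTorch CUDA-graph capture/replay; illustrative background only.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
static_input = torch.randn(8, 1024, device="cuda", dtype=torch.float16)

# Warm up on a side stream so capture starts from a quiet allocator state.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture one forward pass into a graph, then replay it on new data by
# copying into the static input buffer.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

static_input.copy_(torch.randn_like(static_input))
graph.replay()  # static_output now holds the result for the new input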
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
@@ -41,4 +41,6 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.onnx_export

.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
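The documentation change above adds the new transformer_engine.pytorch.make_graphed_callables entry point. The sketch below shows how it might be called; the fp8_enabled and fp8_recipe keyword names are assumptions inferred from the commit description rather than confirmed against the final API, so treat the rendered docs as the authority on the exact signature. FP8 execution also requires a GPU with FP8 support.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(8, 1024, device="cuda")

# NOTE: fp8_enabled / fp8_recipe are assumed keyword names, not verified
# against this commit; check the generated API docs for the real signature.
graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    fp8_enabled=True,
    fp8_recipe=recipe.DelayedScaling(),
)

out = graphed_layer(sample_input)  # forward replays the captured graph
out.sum().backward()               # backward replays the captured backward graph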
5 changes: 3 additions & 2 deletions qa/L0_pytorch_unittest/test.sh
@@ -9,9 +9,10 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
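The CI script now also runs the new tests/pytorch/test_cuda_graphs.py with torch.compile and nondeterministic algorithms disabled. That test file is not shown in this excerpt; purely as an illustration of the kind of graphed-versus-eager comparison such a numerics test might make, a rough sketch follows. The helper name, module choice, and tolerances here are invented for illustration and are not the actual test contents.

import torch
import transformer_engine.pytorch as te

def check_graphed_matches_eager(hidden: int = 256, batch: int = 4) -> None:
    """Illustrative check: a graphed TE layer should match its eager execution."""
    torch.manual_seed(0)
    layer = te.Linear(hidden, hidden).cuda()
    sample = torch.randn(batch, hidden, device="cuda")
    graphed = te.make_graphed_callables(layer, (sample,))

    x = torch.randn(batch, hidden, device="cuda")
    eager_out = layer(x)
    graphed_out = graphed(x)
    torch.testing.assert_close(graphed_out, eager_out, rtol=1e-3, atol=1e-3)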
60 changes: 48 additions & 12 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -5,7 +5,6 @@
import functools
from importlib.metadata import version
import os
import math
from typing import Any, Dict, List, Tuple, Union

from pkg_resources import packaging
@@ -28,15 +27,9 @@
fused_attn_bwd,
fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import (
_set_cuda_rng_state,
CudaRNGStatesTracker,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
TransformerEngineBaseModule,
_prepare_backward,
)
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
@@ -58,10 +51,18 @@

_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))


def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()


@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
@@ -71,6 +72,7 @@ def _cudnn_version() -> Tuple[int, int, int]:
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)


class ModelConfig:
def __init__(
self,
@@ -103,6 +105,7 @@ def __init__(
self.num_layers = num_layers
self.bias_shape = bias_shape


def _is_fused_attention_supported(
config: ModelConfig,
dtype: torch.dtype,
@@ -151,24 +154,28 @@ def _is_fused_attention_supported(
return True, backends
return False, backends


@functools.cache
def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2")


@functools.cache
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")


@functools.cache
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")


def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
@@ -184,6 +191,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False
return True


def _is_unfused_attention_supported(config: ModelConfig) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type):
@@ -192,6 +200,7 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool:
return False
return True


model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
@@ -200,11 +209,13 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool:
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
}


param_types = [torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]


def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
@@ -216,6 +227,7 @@ def get_swa(seq_q, seq_kv, w=None):
ml = ~ ml
return w, ml


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -313,6 +325,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -321,6 +334,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False)


model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
@@ -337,6 +351,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@@ -345,6 +360,7 @@ def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)


model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
@@ -373,6 +389,7 @@ def test_dpa_mask(dtype, model_configs, model):
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@@ -381,6 +398,7 @@ def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)


model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
@@ -398,6 +416,7 @@ def test_dpa_bias(dtype, model_configs, model):
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@@ -413,6 +432,8 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}


@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@@ -428,6 +449,8 @@ def test_dpa_sliding_window(dtype, model_configs, model):
"alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
}


@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@@ -436,13 +459,15 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)


qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
# will add tests for thd layouts later when the support is available in fused attention
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]


model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
@@ -455,6 +480,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@@ -464,6 +490,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)


def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
@@ -646,6 +673,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:

return out, (inp[0].grad, inp[1].grad, inp[2].grad)


model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
@@ -658,6 +686,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -742,6 +771,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -755,6 +785,7 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)


@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -780,6 +811,7 @@ def find_factors(x):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)


def _run_transformer_layer(
dtype: torch.dtype,
config: ModelConfig,
@@ -912,8 +944,10 @@ def _run_transformer_layer(
"fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
}

param_types_fp8 = [torch.float16]


@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@@ -946,6 +980,7 @@ def test_dpa_fp8(dtype, model):
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)


def _run_dpa_fp8(dtype, config, backend):
"""Run FusedAttention FP8 backend, i.e.
fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
@@ -989,6 +1024,7 @@ def _run_dpa_fp8(dtype, config, backend):
dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).transpose(0,1).contiguous())


def _run_dpa_fp8_ref(dtype, config, backend):
"""Run UnfusedDotProductAttention as a reference, i.e.
plain PyTorch implementation in FP16 and inputs/outputs
@@ -1188,8 +1224,7 @@ def forward(
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:

with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
with torch.cuda.nvtx.range("_DPA"):
(
inputmat_t,
qkv_weight_t_fp8,
@@ -1298,6 +1333,7 @@ def backward(
None,
None)


class DPA_FP8(TransformerEngineBaseModule):
def __init__(
self,
Diffs for the remaining 32 changed files are not rendered in this excerpt.