Enable torch.compile compatibility for triton fp8 rowwise gemm (#2978)
Summary:
Pull Request resolved: #2978

X-link: facebookresearch/FBGEMM#74

This diff adds custom op wrappers around `quantize_fp8_row` and `matmul_fp8_row`. This should make them opaque to torch.compile and prevent issues where dynamo tries to trace Triton code that is meant to be precompiled. I also register fake kernels so that torch.compile can properly pass FakeTensors through the ops.

Reviewed By: henrylhtsang

Differential Revision: D61216580

fbshipit-source-id: c96cba96775656213c27f3fe36a20325376ca082
jwfromm authored and facebook-github-bot committed Aug 13, 2024
1 parent 425d1ac commit cdb290a
Showing 2 changed files with 76 additions and 7 deletions.
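
As background for the diff below, this is the general `torch.library.custom_op` + `register_fake` pattern the commit applies. It is a minimal, self-contained sketch: the op name `mylib::scale_rows` and its body are illustrative only and are not part of FBGEMM.

```python
from typing import Tuple

import torch


@torch.library.custom_op("mylib::scale_rows", mutates_args=())
def scale_rows(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # The "real" implementation. torch.compile treats the custom op as opaque,
    # so dynamo never traces into this body (or into any precompiled Triton
    # kernel it might call).
    scale = a.abs().amax(dim=1).clamp(min=1e-12)  # one scale per row, shape (M,)
    return a / scale.unsqueeze(1), scale


@scale_rows.register_fake
def _scale_rows_fake(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Shape/dtype-only fake kernel so torch.compile can propagate FakeTensors
    # through the op during tracing without running the real kernel.
    M, K = a.shape
    return (
        torch.empty((M, K), device=a.device, dtype=a.dtype),
        torch.empty((M,), device=a.device, dtype=a.dtype),
    )


@torch.compile(fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    # The custom op appears as a single node in the compiled graph.
    y, s = scale_rows(x)
    return y * s.unsqueeze(1)
```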
45 changes: 38 additions & 7 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -102,6 +102,7 @@ def _test_matmul_fp8_row(
             fp8_fast_accum: bool,
             use_bias: bool = False,
             transpose_input: bool = False,
+            compile: bool = False,
         ) -> None:
             M, N, K = shape
             a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
@@ -113,13 +114,42 @@ def _test_matmul_fp8_row(
                 torch.randn(N, dtype=torch.float32, device=device) if use_bias else None
             )
 
-            # Quantize inputs.
-            a_fp8, a_scale = quantize_fp8_row(a)
-            b_fp8, b_scale = quantize_fp8_row(b)
-
-            result = matmul_fp8_row(
-                a_fp8, b_fp8, a_scale, b_scale, bias=bias, fp8_fast_accum=fp8_fast_accum
-            )
+            # Test that we can compile the full fp8 matmul operation.
+            if compile:
+
+                @torch.compile(fullgraph=True)
+                def _quantize_matmul_fp8(
+                    a: torch.Tensor,
+                    b: torch.Tensor,
+                    bias: Optional[torch.Tensor],
+                    fp8_fast_accum: bool,
+                ) -> torch.Tensor:
+                    a_fp8, a_scale = quantize_fp8_row(a)
+                    b_fp8, b_scale = quantize_fp8_row(b)
+                    return matmul_fp8_row(
+                        a_fp8,
+                        b_fp8,
+                        a_scale,
+                        b_scale,
+                        bias=bias,
+                        fp8_fast_accum=fp8_fast_accum,
+                    )
+
+                result = _quantize_matmul_fp8(a, b, bias, fp8_fast_accum)
+            # Otherwise run normally.
+            else:
+                # Quantize inputs.
+                a_fp8, a_scale = quantize_fp8_row(a)
+                b_fp8, b_scale = quantize_fp8_row(b)
+
+                result = matmul_fp8_row(
+                    a_fp8,
+                    b_fp8,
+                    a_scale,
+                    b_scale,
+                    bias=bias,
+                    fp8_fast_accum=fp8_fast_accum,
+                )
             self.assertTrue(result.shape == (M, N))
 
             expected_result = a @ b.T
@@ -130,6 +160,7 @@ def _test_matmul_fp8_row(
             )
 
         _test_matmul_fp8_row((3, 4, 5), torch.device("cuda"), True)
+        _test_matmul_fp8_row((3, 4, 5), torch.device("cuda"), True, compile=True)
         _test_matmul_fp8_row(
             (5, 4, 5), torch.device("cuda"), True, transpose_input=True
         )
38 changes: 38 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -808,6 +808,7 @@ def _kernel_matmul_fp8_row_tma_persistent(
             acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
 
 
+@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
 def matmul_fp8_row(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1065,6 +1066,25 @@ def persistent_grid(META):
     return c.view(output_shape)
 
 
+@matmul_fp8_row.register_fake
+def matmul_fp8_row_meta(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    dot_out_dtype: Optional[torch.dtype] = None,
+    allow_tf32: bool = True,
+    fp8_fast_accum: bool = True,
+    imprecise_acc: bool = False,
+    tma_persistent: bool = False,
+) -> torch.Tensor:
+    """Shape function for torch compile."""
+    M, K = a.shape
+    N, K = b.shape
+    return torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+
+
 # pruned some unreasonable config
 def prune_configs_block(configs, named_args, **kwargs):
     configs = early_config_prune(configs, named_args, **kwargs)
@@ -1794,6 +1814,7 @@ def triton_quantize_fp8_row(
     return a_fp8.view(a_shape), a_scale
 
 
+@torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
 def quantize_fp8_row(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
@@ -1845,6 +1866,23 @@ def quantize_fp8_row(
     return a_fp8.view(a_shape), 1 / a_scale  # pyre-ignore
 
 
+@quantize_fp8_row.register_fake
+def quantize_fp8_row_meta(
+    a: Tensor,
+    scale_ub: Optional[Tensor] = None,
+    use_triton: bool = True,
+    output_device: Optional[torch.device] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Shape function for torch compile."""
+    if output_device is None:
+        output_device = a.device
+    M, K = a.shape
+    dtype = get_fp8_constants()[0]
+    fake_out = torch.empty((M, K), device=output_device, dtype=dtype)
+    fake_scale = torch.empty((M), device=output_device, dtype=torch.float32)
+    return fake_out, fake_scale
+
+
 @triton.autotune(
     configs=[
         Config({"BLOCK_SIZE": 512}),
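
With both ops registered as custom ops backed by fake kernels, a caller can compile the whole quantize-then-matmul path as one graph, mirroring the new test above. The following is a minimal usage sketch, not part of the commit: the import path is inferred from the file location `fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py`, and a CUDA device with FP8 support is assumed.

```python
import torch

# Import path inferred from the module's file location in this diff;
# adjust if the installed package lays things out differently.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)


@torch.compile(fullgraph=True)
def fp8_rowwise_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Both calls are opaque custom ops, so dynamo records single graph nodes
    # instead of tracing into the Triton kernels; the registered fake kernels
    # supply the output shapes/dtypes needed during tracing.
    a_fp8, a_scale = quantize_fp8_row(a)
    b_fp8, b_scale = quantize_fp8_row(b)
    return matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale)


a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
b = torch.randn(8, 32, dtype=torch.bfloat16, device="cuda")
out = fp8_rowwise_mm(a, b)  # bfloat16 output of shape (16, 8), i.e. a @ b.T
```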
