Add AMD Rowwise FP8 Matmul
Summary: This diff extends the `f8f8bf16_rowwise` gemm operation to AMD through a new CK kernel. The new kernel requires new stride support that is only available in developer branches of CK, so we must rely on the `ai_codesign/gen_ai` CK repo. I also extend the fp8 benchmarking suite to include rowwise measurements. I'll soon add detailed benchmarking results, but the quick summary is that performance looks quite good: typically in line with tensorwise quantization and sometimes faster, presumably due to using the latest and greatest CK pipelines.

Differential Revision: D57600068
jwfromm authored and facebook-github-bot committed May 21, 2024
1 parent 581fcec commit 352a9af
Showing 3 changed files with 287 additions and 36 deletions.
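
For context before the diff: the benchmark compares a tensorwise path, `torch.ops.fbgemm.f8f8bf16_tensorwise(QA, QB, a_scale * b_scale)`, against the new rowwise path added here. Below is a minimal sketch of the rowwise call pattern (not part of the commit; it mirrors the `fp8_row_quantize` helper added in the diff and assumes a ROCm PyTorch build with `float8_e4m3fnuz` support plus an fbgemm_gpu `gen_ai` experimental build that registers the op):

```python
# Sketch only. Assumes a ROCm PyTorch build with float8_e4m3fnuz support and an
# fbgemm_gpu gen_ai experimental build that registers torch.ops.fbgemm.f8f8bf16_rowwise.
import torch

E4M3_MAX_POS: float = torch.finfo(torch.float8_e4m3fnuz).max
EPS = 1e-12


def row_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per row (dim 0) instead of one scale for the whole tensor.
    row_max = torch.clamp(torch.abs(x).amax(dim=1), min=EPS)
    scale = E4M3_MAX_POS / row_max
    xq = torch.clamp(x * scale[:, None], min=-E4M3_MAX_POS, max=E4M3_MAX_POS).to(
        torch.float8_e4m3fnuz
    )
    # The gemm op expects the inverse (dequantization) scale.
    return xq, scale.to(torch.float32).reciprocal()


M, N, K = 128, 64, 256
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # row-major [N, K], as in the benchmark

qa, a_scale = row_quantize(a)  # a_scale: shape [M]
qb, b_scale = row_quantize(b)  # b_scale: shape [N]

# The tensorwise op takes a single fused scalar scale; the rowwise op takes both per-row
# vectors so the kernel can dequantize each output element with a_scale[m] * b_scale[n].
out = torch.ops.fbgemm.f8f8bf16_rowwise(qa, qb, a_scale, b_scale)  # bf16 output, [M, N]
```
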
110 changes: 90 additions & 20 deletions fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
@@ -16,6 +16,7 @@
import triton # @manual=//triton:triton

E4M3_MAX_POS: float = torch.finfo(torch.float8_e4m3fnuz).max
FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS = 1e-12


@@ -31,6 +32,23 @@ def set_amd_env_vars() -> None:
os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30"


@torch.no_grad()
def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Quantize an input tensor and return the fp8 tensor and its inverse scale.
x_row_max = torch.max(torch.abs(x), dim=1).values
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
scale = E4M3_MAX_POS / torch.clamp(x_row_max, EPS)
if x.dtype is torch.float16:
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`.
scale = torch.clamp(scale, max=FP16_MAX_POS)
# pyre-fixme[16]: Item `float` of `typing.Union[float, torch._tensor.Tensor]` has no attribute `__getitem__`.
xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
torch.float8_e4m3fnuz
)
# pyre-fixme[16]: Item `float` of `typing.Union[float, torch._tensor.Tensor]` has no attribute `__getitem__`.
return xq, scale.to(torch.float32).reciprocal()


def fp8_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
f8_max = torch.tensor(E4M3_MAX_POS, device=x.device)
x_amax = torch.max(torch.abs(x))
@@ -65,11 +83,22 @@ def forward(
return output


class CKMatmul(torch.nn.Module):
class CKTensorMatmul(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
out = torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale)
return torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale)


class CKRowMatmul(torch.nn.Module):
def forward(
self,
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
) -> torch.Tensor:
out = torch.ops.fbgemm.f8f8bf16_rowwise(a, b, a_scale, b_scale)
return out


@@ -82,13 +111,18 @@ def evaluate_impl(
baseline_func: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
],
ck_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> Tuple[float, float, float, float, float]:
ck_tensor_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
ck_row_func: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
],
) -> Tuple[float, float, float, float, float, float, float]:
print(f"Evaluating {M=}, {N=}, {K=}")
A = torch.randn(M, K).to(dtype=torch.bfloat16, device="cuda")
QA, a_scale = fp8_quantize(A)
B = torch.randn(N, K).to(dtype=torch.bfloat16, device="cuda")
QB, b_scale = fp8_quantize(B)
QA_row, a_scale_row = fp8_row_quantize(A)
QB_row, b_scale_row = fp8_row_quantize(B)

# Check accuracy.
out_ref = fp_func(A.to(torch.float32), B.t().to(torch.float32))
@@ -97,9 +131,13 @@ def evaluate_impl(
baseline_sim = torch.mean(torch.pow(torch.abs(baseline_out - out_ref), 2))
print(f"Baseline accuracy: {baseline_sim}")

ck_out = ck_func(QA, QB, a_scale * b_scale)
ck_sim = torch.mean(torch.pow(torch.abs(ck_out - out_ref), 2))
print(f"CK accuracy: {ck_sim}")
ck_tensor_out = ck_tensor_func(QA, QB, a_scale * b_scale)
ck_tensor_sim = torch.mean(torch.pow(torch.abs(ck_tensor_out - out_ref), 2))
print(f"CK tensorwise accuracy: {ck_tensor_sim}")

ck_row_out = ck_row_func(QA_row, QB_row, a_scale_row, b_scale_row)
ck_row_sim = torch.mean(torch.pow(torch.abs(ck_row_out - out_ref), 2))
print(f"CK rowwise accuracy: {ck_row_sim}")

# Benchmark runtimes.
ms_ref: float = triton.testing.do_bench(lambda: fp_func(A, B.t()))
@@ -110,23 +148,45 @@ def evaluate_impl(
)
print(f"Baseline runtime: {ms_baseline} ms")

ms_ck: float = triton.testing.do_bench(lambda: ck_func(QA, QB, a_scale * b_scale))
print(f"CK runtime: {ms_ck} ms")
ms_tensor_ck: float = triton.testing.do_bench(
lambda: ck_tensor_func(QA, QB, a_scale * b_scale)
)
print(f"CK tensorwise runtime: {ms_tensor_ck} ms")

return float(baseline_sim.item()), float(ck_sim.item()), ms_baseline, ms_ck, ms_ref
ms_row_ck: float = triton.testing.do_bench(
lambda: ck_row_func(QA_row, QB_row, a_scale_row, b_scale_row)
)
print(f"CK rowwise runtime: {ms_row_ck} ms")

return (
float(baseline_sim.item()),
float(ck_tensor_sim.item()),
float(ck_row_sim.item()),
ms_baseline,
ms_tensor_ck,
ms_row_ck,
ms_ref,
)


def main(args: Any) -> None:
if args.enable_amd_env_vars:
set_amd_env_vars()

with torch.no_grad():
ck_mod = CKMatmul()
ck_tensor_mod = CKTensorMatmul()
ck_row_mod = CKRowMatmul()
baseline_mod = BaselineMatmul()
bf16_mod = FPMatMul()
if args.torch_compile_mode:
ck_mod = torch.compile(
ck_mod,
ck_tensor_mod = torch.compile(
ck_tensor_mod,
dynamic=False,
backend="inductor",
mode=args.torch_compile_mode,
)
ck_row_mod = torch.compile(
ck_row_mod,
dynamic=False,
backend="inductor",
mode=args.torch_compile_mode,
@@ -147,25 +207,35 @@ def main(args: Any) -> None:
benchmark_results = []

# Test over a bunch of shapes.
M = [13312, 16384, 16032, 2304, 2048]
N = [4096, 2304, 13312, 8192]
K = [16384, 6656, 2304, 2048, 13312]
M = [128, 2048, 2304, 13312, 16032, 16384]
N = [128, 2304, 4096, 8192, 13312]
K = [128, 2048, 2304, 6656, 13312, 16384]

for m in M:
for n in N:
for k in K:
baseline_sim, ck_sim, ms_baseline, ms_ck, ms_bf16 = evaluate_impl(
m, n, k, bf16_mod, baseline_mod, ck_mod
(
baseline_sim,
ck_tensor_sim,
ck_row_sim,
ms_baseline,
ms_tensor_ck,
ms_row_ck,
ms_bf16,
) = evaluate_impl(
m, n, k, bf16_mod, baseline_mod, ck_tensor_mod, ck_row_mod
)
benchmark_results.append(
{
"M": m,
"N": n,
"K": k,
"baseline_sim": baseline_sim,
"ck_sim": ck_sim,
"ck_tensor_sim": ck_tensor_sim,
"ck_row_sim": ck_row_sim,
"ms_baseline": ms_baseline,
"ms_ck": ms_ck,
"ms_tensor_ck": ms_tensor_ck,
"ms_row_ck": ms_row_ck,
"ms_bf16": ms_bf16,
}
)
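
`evaluate_impl` above reports the MSE of the baseline, tensorwise, and rowwise outputs against a float32 reference. As a toy, standalone illustration of the kind of input where rowwise and tensorwise scaling diverge (not part of the commit; it only assumes a PyTorch build with `float8_e4m3fnuz` support and runs on CPU):

```python
# Toy illustration: e4m3 has limited dynamic range, so a single global scale can
# push small-magnitude rows into the subnormal/underflow region, while per-row
# scales keep every row near the top of the representable range.
import torch

FP8 = torch.float8_e4m3fnuz
FP8_MAX = torch.finfo(FP8).max

torch.manual_seed(0)
# Two rows whose magnitudes differ by ~1e5.
x = torch.cat([torch.randn(1, 512), 1e5 * torch.randn(1, 512)])


def fake_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Quantize to fp8 with the given scale, then dequantize back to float32.
    q = torch.clamp(x * scale, -FP8_MAX, FP8_MAX).to(FP8)
    return q.float() / scale


x_tensorwise = fake_quant(x, FP8_MAX / x.abs().max())                   # one global scale
x_rowwise = fake_quant(x, FP8_MAX / x.abs().amax(dim=1, keepdim=True))  # per-row scales

for name, xq in [("tensorwise", x_tensorwise), ("rowwise", x_rowwise)]:
    # Relative RMS error on the small-magnitude row.
    err = ((x[0] - xq[0]) ** 2).mean().sqrt() / x[0].pow(2).mean().sqrt()
    print(f"{name} relative error on the small-magnitude row: {err.item():.3f}")
```
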
