PR #7751: Support cublasLt Fp8 Approx Gelu epilogue fusion.
Imported from GitHub PR #7751

Because fast accumulation is enabled in forward mode, the cublasLt FP8 GEMM with a GELU epilogue can run efficiently as a single fused kernel. Compared against the XLA-generated GELU kernel on H100, the fused path shows a modest improvement for a [8192, 4096] x [4096, 16384] matmul followed by GELU:

Execution time for matmul using cublasLt plus GELU (XLA): 1.28 ms
Execution time for fused matmul_gelu using cublasLt: 1.25 ms
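
For reference, the tanh-based GELU approximation that the epilogue computes (and that the HLO graphs in the tests below construct elementwise, with 0.797884583 ≈ sqrt(2/π)) is:

gelu(x) ≈ 0.5 * x * (1 + tanh(0.797884583 * (x + 0.044715 * x^3)))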
Copybara import of the project:

--
e8abce3 by Shu Wang <[email protected]>:

Support cublasLt Fp8 Approx Gelu epilogue fusion.

--
818127c by shuw <[email protected]>:

Remove F32 check

--
5ce3108 by shuw <[email protected]>:

Improve based on review #1

Merging this change closes #7751

COPYBARA_INTEGRATE_REVIEW=#7751 from wenscarl:cublaslt_fp8_gelu 5ce3108
PiperOrigin-RevId: 591236441
wenscarl authored and copybara-github committed Dec 15, 2023
1 parent 3f62ba1 commit 2724718
Showing 2 changed files with 149 additions and 3 deletions.
6 changes: 3 additions & 3 deletions xla/service/gpu/gemm_rewriter.cc
@@ -595,10 +595,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
 if (Match(instr, m::MultiplyAnyOrder(
                      m::AnyOf<HloInstruction>(
                          m::Slice(&slice_or_bitcast,
-                                  CublasLtMatmul(&existing_gemm)),
+                                  CublasLtMatmulMaybeF8(&existing_gemm)),
                          m::Bitcast(&slice_or_bitcast,
-                                    CublasLtMatmul(&existing_gemm)),
-                         CublasLtMatmul(&existing_gemm)),
+                                    CublasLtMatmulMaybeF8(&existing_gemm)),
+                         CublasLtMatmulMaybeF8(&existing_gemm)),
                          m::Op(&cdf).WithOneUser())) &&
     Match(cdf,
           m::MultiplyAnyOrder(
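
The only change to the rewrite pattern is that the approximate-GELU matcher now also accepts FP8 cublasLt custom calls, so the epilogue fusion applies to __cublas$lt$matmul$f8 as well as __cublas$lt$matmul. A minimal sketch of such a combined matcher, assuming a helper built on m::CustomCall with both call targets (the actual helper in gemm_rewriter.cc may differ):

// Illustrative sketch only: matches a cublasLt matmul custom call whose
// target is either the regular or the FP8 variant, binding it to *instr.
auto CublasLtMatmulMaybeF8(HloInstruction **instr) {
  return m::CustomCall(instr, {kCublasLtMatmulCallTarget,
                               kCublasLtMatmulF8CallTarget});
}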
146 changes: 146 additions & 0 deletions xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -5586,6 +5586,152 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) {
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest,
ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) {
#if GOOGLE_CUDA && CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
#endif // CUDA_VERSION < 12000
const char* hlo_text = R"(
HloModule test
ENTRY test {
x = f8e4m3fn[16,32] parameter(0)
y = f8e4m3fn[32,16] parameter(1)
x_bf16 = bf16[16,32] convert(x)
y_bf16 = bf16[32,16] convert(y)
x_scale = bf16[] parameter(2)
y_scale = bf16[] parameter(3)
bias = bf16[16] parameter(4)
x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
b_bcast = bf16[16,16] broadcast(bias), dimensions={1}
dot = bf16[16,16] add(dot1, b_bcast)
mul.0 = bf16[16,16] multiply(dot, dot)
mul.1 = bf16[16,16] multiply(dot, mul.0)
const.0 = bf16[] constant(0.044715)
bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
add.0 = bf16[16,16] add(dot, mul.2)
const.1 = bf16[] constant(0.797884583)
bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
mul.3 = bf16[16,16] multiply(add.0, bcast.1)
tanh = bf16[16,16] tanh(mul.3)
const.2 = bf16[] constant(1)
bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
add.2 = bf16[16,16] add(tanh, bcast.2)
const.3 = bf16[] constant(0.5)
bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
mul.4 = bf16[16,16] multiply(add.2, bcast.3)
ROOT out = bf16[16,16] multiply(dot, mul.4)
}
)";

CheckFp8IfSupported(hlo_text);
RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(CudaHopperOrRocm()),
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: bf16[], y_scale: bf16[], bias: bf16[16]) -> bf16[16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2)
; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config={
; CHECK-DAG: "alpha_real":1
; CHECK-DAG: "alpha_imag":0
; CHECK-DAG: "beta":0
; CHECK-DAG: "dot_dimension_numbers":{
; CHECK-DAG: "lhs_contracting_dimensions":["1"]
; CHECK-DAG: "rhs_contracting_dimensions":["1"]
; CHECK-DAG: "lhs_batch_dimensions":[]
; CHECK-DAG: "rhs_batch_dimensions":[]
; CHECK-DAG: }
; CHECK-DAG: "precision_config":{
; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
; CHECK-DAG: }
; CHECK-DAG: "epilogue":"BIAS_GELU"
; CHECK: }
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest,
ScaledABUnscaledDApproxGeluActivationF8) {
#if GOOGLE_CUDA && CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
#endif // CUDA_VERSION < 12000
const char* hlo_text = R"(
HloModule test
ENTRY test {
x = f8e4m3fn[16,32] parameter(0)
y = f8e4m3fn[32,16] parameter(1)
x_bf16 = bf16[16,32] convert(x)
y_bf16 = bf16[32,16] convert(y)
x_scale = bf16[] parameter(2)
y_scale = bf16[] parameter(3)
x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
mul.0 = bf16[16,16] multiply(dot, dot)
mul.1 = bf16[16,16] multiply(dot, mul.0)
const.0 = bf16[] constant(0.044715)
bcast.0 = bf16[16,16] broadcast(const.0), dimensions={}
mul.2 = bf16[16,16] multiply(mul.1, bcast.0)
add.0 = bf16[16,16] add(dot, mul.2)
const.1 = bf16[] constant(0.797884583)
bcast.1 = bf16[16,16] broadcast(const.1), dimensions={}
mul.3 = bf16[16,16] multiply(add.0, bcast.1)
tanh = bf16[16,16] tanh(mul.3)
const.2 = bf16[] constant(1)
bcast.2 = bf16[16,16] broadcast(const.2), dimensions={}
add.2 = bf16[16,16] add(tanh, bcast.2)
const.3 = bf16[] constant(0.5)
bcast.3 = bf16[16,16] broadcast(const.3), dimensions={}
mul.4 = bf16[16,16] multiply(add.2, bcast.3)
ROOT out = bf16[16,16] multiply(dot, mul.4)
}
)";

CheckFp8IfSupported(hlo_text);
RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(CudaHopperOrRocm()),
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: bf16[], y_scale: bf16[]) -> bf16[16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0)
; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2)
; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]])
; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3)
; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]])
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config={
; CHECK-DAG: "alpha_real":1
; CHECK-DAG: "alpha_imag":0
; CHECK-DAG: "beta":0
; CHECK-DAG: "dot_dimension_numbers":{
; CHECK-DAG: "lhs_contracting_dimensions":["1"]
; CHECK-DAG: "rhs_contracting_dimensions":["1"]
; CHECK-DAG: "lhs_batch_dimensions":[]
; CHECK-DAG: "rhs_batch_dimensions":[]
; CHECK-DAG: }
; CHECK-DAG: "precision_config":{
; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
; CHECK-DAG: }
; CHECK-DAG: "epilogue":"GELU"
; CHECK: }
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) {
#if GOOGLE_CUDA && CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
