diff --git a/xla/service/gpu/gemm_rewriter.cc b/xla/service/gpu/gemm_rewriter.cc index 6d462a9cc20ae..24b8dcaa2de9f 100644 --- a/xla/service/gpu/gemm_rewriter.cc +++ b/xla/service/gpu/gemm_rewriter.cc @@ -595,10 +595,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (Match(instr, m::MultiplyAnyOrder( m::AnyOf( m::Slice(&slice_or_bitcast, - CublasLtMatmul(&existing_gemm)), + CublasLtMatmulMaybeF8(&existing_gemm)), m::Bitcast(&slice_or_bitcast, - CublasLtMatmul(&existing_gemm)), - CublasLtMatmul(&existing_gemm)), + CublasLtMatmulMaybeF8(&existing_gemm)), + CublasLtMatmulMaybeF8(&existing_gemm)), m::Op(&cdf).WithOneUser())) && Match(cdf, m::MultiplyAnyOrder( diff --git a/xla/service/gpu/tests/gemm_rewrite_test.cc b/xla/service/gpu/tests/gemm_rewrite_test.cc index fd4ae79ebf770..650265e470c3b 100644 --- a/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -5586,6 +5586,152 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, + ScaledABUnscaledDVectorBiasThenApproxGeluActivationF8) { +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#endif // CUDA_VERSION < 12000 + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = f8e4m3fn[16,32] parameter(0) + y = f8e4m3fn[32,16] parameter(1) + x_bf16 = bf16[16,32] convert(x) + y_bf16 = bf16[32,16] convert(y) + x_scale = bf16[] parameter(2) + y_scale = bf16[] parameter(3) + bias = bf16[16] parameter(4) + x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={} + y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={} + x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast) + y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast) + dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + b_bcast = bf16[16,16] broadcast(bias), dimensions={1} + dot = bf16[16,16] add(dot1, b_bcast) + mul.0 = bf16[16,16] multiply(dot, dot) + mul.1 = bf16[16,16] multiply(dot, mul.0) + const.0 = bf16[] constant(0.044715) + bcast.0 = bf16[16,16] broadcast(const.0), dimensions={} + mul.2 = bf16[16,16] multiply(mul.1, bcast.0) + add.0 = bf16[16,16] add(dot, mul.2) + const.1 = bf16[] constant(0.797884583) + bcast.1 = bf16[16,16] broadcast(const.1), dimensions={} + mul.3 = bf16[16,16] multiply(add.0, bcast.1) + tanh = bf16[16,16] tanh(mul.3) + const.2 = bf16[] constant(1) + bcast.2 = bf16[16,16] broadcast(const.2), dimensions={} + add.2 = bf16[16,16] add(tanh, bcast.2) + const.3 = bf16[] constant(0.5) + bcast.3 = bf16[16,16] broadcast(const.3), dimensions={} + mul.4 = bf16[16,16] multiply(add.2, bcast.3) + ROOT out = bf16[16,16] multiply(dot, mul.4) + } +)"; + + CheckFp8IfSupported(hlo_text); + RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(CudaHopperOrRocm()), + R"( +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: bf16[], y_scale: bf16[], bias: bf16[16]) -> bf16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2) +; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) +; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) +; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"BIAS_GELU" +; CHECK: } + )"); +} + +TEST_P(ParameterizedFp8GemmRewriteTest, + ScaledABUnscaledDApproxGeluActivationF8) { +#if GOOGLE_CUDA && CUDA_VERSION < 12000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#endif // CUDA_VERSION < 12000 + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = f8e4m3fn[16,32] parameter(0) + y = f8e4m3fn[32,16] parameter(1) + x_bf16 = bf16[16,32] convert(x) + y_bf16 = bf16[32,16] convert(y) + x_scale = bf16[] parameter(2) + y_scale = bf16[] parameter(3) + x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={} + y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={} + x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast) + y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast) + dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + mul.0 = bf16[16,16] multiply(dot, dot) + mul.1 = bf16[16,16] multiply(dot, mul.0) + const.0 = bf16[] constant(0.044715) + bcast.0 = bf16[16,16] broadcast(const.0), dimensions={} + mul.2 = bf16[16,16] multiply(mul.1, bcast.0) + add.0 = bf16[16,16] add(dot, mul.2) + const.1 = bf16[] constant(0.797884583) + bcast.1 = bf16[16,16] broadcast(const.1), dimensions={} + mul.3 = bf16[16,16] multiply(add.0, bcast.1) + tanh = bf16[16,16] tanh(mul.3) + const.2 = bf16[] constant(1) + bcast.2 = bf16[16,16] broadcast(const.2), dimensions={} + add.2 = bf16[16,16] add(tanh, bcast.2) + const.3 = bf16[] constant(0.5) + bcast.3 = bf16[16,16] broadcast(const.3), dimensions={} + mul.4 = bf16[16,16] multiply(add.2, bcast.3) + ROOT out = bf16[16,16] multiply(dot, mul.4) + } +)"; + + CheckFp8IfSupported(hlo_text); + RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(CudaHopperOrRocm()), + R"( +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], x_scale: bf16[], y_scale: bf16[]) -> bf16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[] parameter(2) +; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) +; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) +; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"GELU" +; CHECK: } + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { #if GOOGLE_CUDA && CUDA_VERSION < 12000 GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";