Skip to content

Commit

Permalink
Improve based on review #1
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Dec 14, 2023
1 parent 818127c commit 5ce3108
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions xla/service/gpu/tests/gemm_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5596,15 +5596,15 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
ENTRY test {
x = f8e4m3fn[16,32] parameter(0)
y = f8e4m3fn[32,16] parameter(1)
x_f32 = bf16[16,32] convert(x)
y_f32 = bf16[32,16] convert(y)
x_bf16 = bf16[16,32] convert(x)
y_bf16 = bf16[32,16] convert(y)
x_scale = bf16[] parameter(2)
y_scale = bf16[] parameter(3)
bias = bf16[16] parameter(4)
x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
x_unscaled = bf16[16,32] multiply(x_f32, x_scale_bcast)
y_unscaled = bf16[32,16] multiply(y_f32, y_scale_bcast)
x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
dot1 = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
b_bcast = bf16[16,16] broadcast(bias), dimensions={1}
dot = bf16[16,16] add(dot1, b_bcast)
Expand Down Expand Up @@ -5671,14 +5671,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest,
ENTRY test {
x = f8e4m3fn[16,32] parameter(0)
y = f8e4m3fn[32,16] parameter(1)
x_f32 = bf16[16,32] convert(x)
y_f32 = bf16[32,16] convert(y)
x_bf16 = bf16[16,32] convert(x)
y_bf16 = bf16[32,16] convert(y)
x_scale = bf16[] parameter(2)
y_scale = bf16[] parameter(3)
x_scale_bcast = bf16[16,32] broadcast(x_scale), dimensions={}
y_scale_bcast = bf16[32,16] broadcast(y_scale), dimensions={}
x_unscaled = bf16[16,32] multiply(x_f32, x_scale_bcast)
y_unscaled = bf16[32,16] multiply(y_f32, y_scale_bcast)
x_unscaled = bf16[16,32] multiply(x_bf16, x_scale_bcast)
y_unscaled = bf16[32,16] multiply(y_bf16, y_scale_bcast)
dot = bf16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0}
mul.0 = bf16[16,16] multiply(dot, dot)
mul.1 = bf16[16,16] multiply(dot, mul.0)
Expand Down

0 comments on commit 5ce3108

Please sign in to comment.