
Implement support for the tt.dot_scaled operation on XPU #2633

Closed
anmyachev opened this issue Nov 5, 2024 · 2 comments · Fixed by #2679 or #2804
Assignees: etiotto
Labels: enhancement (New feature or request), tests: ut

Comments

anmyachev (Contributor) commented Nov 5, 2024

Triton has introduced a new operation tt.dot_scaled in triton-lang/triton#4795 which we need to support for Intel GPUs.
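
For context, a minimal Python sketch of how the new op is reached from the frontend. This is an illustration only: the exact tl.dot_scaled signature and the format strings are assumptions based on the upstream PR and may differ between Triton versions, and the pointer arithmetic is simplified to a single tile.

import triton
import triton.language as tl


@triton.jit
def scaled_dot_kernel(a_ptr, a_scale_ptr, b_ptr, c_ptr,
                      M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # A is fp4 (e2m1) packed two values per byte -> (M, K//2) i8,
    # with one e8m0 scale byte per 32 elements -> (M, K//32) i8.
    # B is a plain bf16 (K, N) tile.
    a = tl.load(a_ptr + tl.arange(0, M)[:, None] * (K // 2)
                + tl.arange(0, K // 2)[None, :])
    a_scale = tl.load(a_scale_ptr + tl.arange(0, M)[:, None] * (K // 32)
                      + tl.arange(0, K // 32)[None, :])
    b = tl.load(b_ptr + tl.arange(0, K)[:, None] * N + tl.arange(0, N)[None, :])
    # This call is what lowers to tt.dot_scaled in the Triton IR (signature assumed).
    c = tl.dot_scaled(a, a_scale, "e2m1", b, None, "bf16")
    tl.store(c_ptr + tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :], c)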

@anmyachev anmyachev added enhancement New feature or request tests: ut labels Nov 5, 2024
@etiotto etiotto self-assigned this Nov 11, 2024
@etiotto etiotto changed the title Enable test_scaled_dot on XPU Implement support for the tt.dot_scaled operation on XPU Nov 11, 2024
@etiotto etiotto linked a pull request Nov 11, 2024 that will close this issue
etiotto (Contributor) commented Nov 14, 2024

Given this kernel containing a tt.dot_scaled operation:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>

  tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
    tt.return %result : tensor<128x128xf32, #blocked>
  }

The tt.dot_scaled operation can be replaced by a tt.dot operation where operand %a is converted to a bf16 tensor via the triton_gpu.upcast_mxfp operation.

The code we need to generate is:

 
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>

  tt.func @dot_scaled(%a: tensor<128x32xi8, #blocked>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked2>) -> tensor<128x128xf32, #blocked2> {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked2>
    %0 = triton_gpu.convert_layout %cst : tensor<128x128xf32, #blocked2> -> tensor<128x128xf32, #mma>
    %1 = triton_gpu.convert_layout %a : tensor<128x32xi8, #blocked> -> tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
    %2 = triton_gpu.convert_layout %scale : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #blocked1>
    %3 = triton_gpu.upcast_mxfp %1, %2 fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<128x2xi8, #blocked1> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
    %4 = triton_gpu.convert_layout %b : tensor<64x128xbf16, #blocked2> -> tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
    %5 = tt.dot %3, %4, %0 : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma>
    %6 = triton_gpu.convert_layout %5 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked2>
    tt.return %6 : tensor<128x128xf32, #blocked2>
  }
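
For reference, a NumPy sketch of the arithmetic that triton_gpu.upcast_mxfp has to implement for the e2m1 case. This is an illustration only, not the actual lowering code; the nibble order and the e8m0 bias of 127 are assumptions taken from the OCP MX convention. Each i8 byte of %a packs two fp4 (e2m1) values and each i8 byte of %scale is a power-of-two exponent covering 32 unpacked elements, which is why the 128x32xi8 operand becomes a 128x64xbf16 tensor.

import numpy as np

# e2m1 (fp4): 1 sign, 2 exponent, 1 mantissa bit; these are its 8 magnitudes.
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def upcast_mxfp_e2m1(a_packed, scales):
    """a_packed: (M, K//2) uint8 with two fp4 values per byte;
    scales: (M, K//32) uint8 holding e8m0 exponents (bias 127, assumed)."""
    a_packed = a_packed.astype(np.uint8)
    lo, hi = a_packed & 0x0F, a_packed >> 4             # nibble order is an assumption
    nibbles = np.stack([lo, hi], axis=-1).reshape(a_packed.shape[0], -1)
    sign = np.where(nibbles & 0x8, -1.0, 1.0)
    vals = sign * E2M1_VALUES[nibbles & 0x7]            # decode e2m1 to float
    scale = np.exp2(scales.astype(np.float32) - 127.0)  # e8m0 scale: a power of two
    return vals * np.repeat(scale, 32, axis=1)          # one scale per 32 elements

# Shapes matching the IR above: A is 128x32 i8 (= 128x64 fp4), scales are 128x2 i8.
a = np.random.randint(0, 256, size=(128, 32), dtype=np.uint8)
s = np.random.randint(118, 132, size=(128, 2), dtype=np.uint8)
print(upcast_mxfp_e2m1(a, s).shape)  # (128, 64), matching tensor<128x64xbf16, ...>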

etiotto (Contributor) commented Dec 2, 2024

This is complete for the case where the scale is applied to the "A" operand of the tt.dot_scaled operation.

@etiotto etiotto closed this as completed Dec 2, 2024