[BACKEND] Simplify and comment warp allocation logic in mmav2 (triton-lang#5041)

It's not entirely clear to me whether the previous logic was equivalent
to this one, as it was rather obtuse. I think the new one is optimal but
I'm happy to run benchmarks to make sure we don't regress.
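A quick sanity check of the register model stated in the new code (repM * 4 * repK + repN * 2 * repK + regsC, with regsC = repM * repN * 4): take, purely for illustration, repK = 1 and a fixed repM * repN = 16. The lopsided split (repM, repN) = (8, 2) costs 8 * 4 + 2 * 2 = 36 operand registers, while the balanced split (4, 4) costs 4 * 4 + 4 * 2 = 24; regsC = 64 is the same in both cases. Balancing repM and repN, breaking ties towards M because an lhs rep is twice as expensive as an rhs rep, is what the new loop does.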
lezcano authored and guacamoleo committed Nov 14, 2024
1 parent 1219b01 commit dc5286d
Showing 2 changed files with 29 additions and 23 deletions.
46 changes: 26 additions & 20 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -7,6 +7,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -77,28 +78,33 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
     }
   }

-  SmallVector<unsigned> ret(rank, 1);
-  SmallVector<int64_t> shapePerWarp(rank, 1);
-  shapePerWarp[rank - 1] = 8;
-  shapePerWarp[rank - 2] = 16;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
+  assert(rank == 2);
+  SmallVector<int64_t> shapePerWarp = {16, 8};
+  SmallVector<int64_t> warps = {1, 1};
+  // Compute repM and repN.
+  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
+                               ceil(shape[1], shapePerWarp[1])};
+  // The formula for the number of registers given the reps is
+  //   repM * 4 * repK + repN * 2 * repK + regsC
+  // where regsC = repM * repN * 4, which does not depend on the warp shape.
+  //
+  // As such, to minimize the register pressure, we need to balance
+  // repM and repN. We then untie towards M, as the lhs tile has 4 elements,
+  // and the rhs tile has just 2.
+  while (product(warps) < numWarps) {
+    if (reps[0] >= reps[1]) {
+      warps[0] *= 2;
+      // Too many warps for this mma (repM == repN == 1).
+      // We allocate the remaining warps to the left (arbitrary choice).
+      if (reps[0] != 1) {
+        reps[0] /= 2;
+      }
     } else {
-      ret[1] *= 2;
+      warps[1] *= 2;
+      reps[1] /= 2;
     }
-  } while (true);
-  return ret;
+  }
+  return {(unsigned)warps[0], (unsigned)warps[1]};
 }

 SmallVector<unsigned, 2>
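As a quick way to experiment with the new allocation outside the pass, the added loop can be reproduced in a few lines of standalone C++. The sketch below covers only the 2-D case handled by the new code and skips the chained-dot early-out earlier in warpsPerTileV2; ceilDiv stands in for the ceil helper pulled in from triton/Dialect/Triton/IR/Utility.h, and the shapes and warp counts in main are arbitrary examples, not values taken from the commit.

// Standalone sketch of the new mmav2 warp-allocation loop (2-D case only).
#include <array>
#include <cstdint>
#include <iostream>

// Stand-in for the ceil helper from triton/Dialect/Triton/IR/Utility.h.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// shape = {M, N}; numWarps is assumed to be a power of two, as in the pass.
static std::array<int64_t, 2> warpsPerTileV2(std::array<int64_t, 2> shape,
                                             int64_t numWarps) {
  const std::array<int64_t, 2> shapePerWarp = {16, 8}; // mma instrShape [16, 8]
  std::array<int64_t, 2> warps = {1, 1};
  // repM and repN: how many 16x8 tiles a single warp would have to cover.
  std::array<int64_t, 2> reps = {ceilDiv(shape[0], shapePerWarp[0]),
                                 ceilDiv(shape[1], shapePerWarp[1])};
  // Keep doubling the warp grid along the dimension that still has more
  // reps; this balances repM and repN and hence the register pressure.
  while (warps[0] * warps[1] < numWarps) {
    if (reps[0] >= reps[1]) {
      warps[0] *= 2;
      if (reps[0] != 1) // once both reps reach 1, extra warps pile up on M
        reps[0] /= 2;
    } else {
      warps[1] *= 2;
      reps[1] /= 2;
    }
  }
  return warps;
}

int main() {
  struct Case { int64_t m, n, numWarps; };
  // Example values only; not taken from the commit or its tests.
  const Case cases[] = {{128, 128, 8}, {128, 256, 32}, {32, 32, 8}};
  for (const auto &c : cases) {
    auto w = warpsPerTileV2({c.m, c.n}, c.numWarps);
    std::cout << "shape=[" << c.m << ", " << c.n << "] numWarps=" << c.numWarps
              << " -> warpsPerCTA=[" << w[0] << ", " << w[1] << "]\n";
  }
  return 0;
}

Built with any C++11 compiler, this prints warpsPerCTA = [2, 4], [4, 8], and [2, 4] for the three example cases, the kind of balanced splits the updated CHECK lines in the test file below now expect.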
6 changes: 3 additions & 3 deletions test/TritonGPU/accelerate-matmul.mlir
@@ -73,7 +73,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :

 // -----

-// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
+// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
@@ -93,7 +93,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :

 // -----

-// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
 // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>

 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 // -----

 // Verify that we use mmav2 when the k dim is too small for mmav3.
-// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}>
+// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: small_k_size
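To see how the new loop produces a layout like the [4, 8] above, take, purely for illustration, a 128x256 dot with triton_gpu.num-warps = 32 (the operand shapes of the small_k_size test body are elided above, so these numbers are an assumption). With instrShape = [16, 8], repM = ceil(128, 16) = 8 and repN = ceil(256, 8) = 32, and the loop proceeds as follows:

  warps = [1, 1], reps = [8, 32] -> more reps along N, double N: warps = [1, 2], reps = [8, 16]
  warps = [1, 2], reps = [8, 16] -> double N: warps = [1, 4], reps = [8, 8]
  warps = [1, 4], reps = [8, 8]  -> tie, double M: warps = [2, 4], reps = [4, 8]
  warps = [2, 4], reps = [4, 8]  -> double N: warps = [2, 8], reps = [4, 4]
  warps = [2, 8], reps = [4, 4]  -> tie, double M: warps = [4, 8], reps = [2, 4]

At this point product(warps) = 32 = numWarps, so warpsPerCTA = [4, 8], with repM and repN kept as balanced as possible throughout.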
