[BACKEND] Simplify and comment warp allocation logic in mmav2 (#5041)
It's not entirely clear to me whether the previous logic was equivalent
to this one, as it was rather obtuse. I think the new one is optimal but
I'm happy to run benchmarks to make sure we don't regress.
lezcano authored Nov 4, 2024
1 parent 73df068 commit 04d655e
Showing 2 changed files with 29 additions and 23 deletions.
46 changes: 26 additions & 20 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -7,6 +7,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -77,28 +78,33 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
    }
  }

-  SmallVector<unsigned> ret(rank, 1);
-  SmallVector<int64_t> shapePerWarp(rank, 1);
-  shapePerWarp[rank - 1] = 8;
-  shapePerWarp[rank - 2] = 16;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
+  assert(rank == 2);
+  SmallVector<int64_t> shapePerWarp = {16, 8};
+  SmallVector<int64_t> warps = {1, 1};
+  // Compute repM and repN
+  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
+                               ceil(shape[1], shapePerWarp[1])};
+  // The formula for the number of registers given the reps is
+  //   repM * 4 * repK + repN * 2 * repK + regsC
+  // where regsC = repM * repN * 4, which does not depend on the warp shape.
+  //
+  // As such, to minimize the register pressure, we need to balance
+  // repM and repN. We then untie towards M, as the lhs tile has 4 elements,
+  // and the rhs tile has just 2.
+  while (product(warps) < numWarps) {
+    if (reps[0] >= reps[1]) {
+      warps[0] *= 2;
+      // Once repM == repN == 1, there are too many warps for this mma.
+      // We allocate the remaining warps to the left (arbitrary choice).
+      if (reps[0] != 1) {
+        reps[0] /= 2;
+      }
    } else {
-      ret[1] *= 2;
+      warps[1] *= 2;
+      reps[1] /= 2;
    }
-  } while (true);
-  return ret;
+  }
+  return {(unsigned)warps[0], (unsigned)warps[1]};
}

SmallVector<unsigned, 2>
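For reference, here is a minimal, self-contained C++ sketch of the new heuristic, reconstructed outside the pass so it can be compiled and experimented with on its own. It is not part of the commit: ceilDiv and product below are stand-ins for the ceil and product helpers that AccelerateMatmul.cpp pulls in from triton/Dialect/Triton/IR/Utility.h, the {16, 8} instruction shape mirrors the mmav2 tile above, and the shapes in main are illustrative only.

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-ins for the ceil/product helpers from triton/Dialect/Triton/IR/Utility.h.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

static int64_t product(const std::vector<int64_t> &v) {
  int64_t p = 1;
  for (int64_t x : v)
    p *= x;
  return p;
}

// Same balancing loop as the new warpsPerTileV2: keep doubling the warp count
// along the dimension that still has more repetitions, so repM and repN stay
// balanced and the per-warp register count
//   repM * 4 * repK + repN * 2 * repK + 4 * repM * repN
// stays low. Ties go to M because an lhs tile holds 4 values vs. 2 for rhs.
static std::vector<int64_t> warpsPerTile(int64_t M, int64_t N, int numWarps) {
  const int64_t instrM = 16, instrN = 8; // mmav2 instruction shape
  std::vector<int64_t> warps = {1, 1};
  std::vector<int64_t> reps = {ceilDiv(M, instrM), ceilDiv(N, instrN)};
  while (product(warps) < numWarps) {
    if (reps[0] >= reps[1]) {
      warps[0] *= 2;
      if (reps[0] != 1) // once repM == repN == 1, extra warps pile up on M
        reps[0] /= 2;
    } else {
      warps[1] *= 2;
      reps[1] /= 2;
    }
  }
  return warps;
}

int main() {
  // Illustrative shapes only; they are not taken from the commit or the tests.
  const std::vector<std::array<int64_t, 3>> cases = {
      {128, 128, 8}, {256, 64, 8}, {16, 16, 4}};
  for (const auto &c : cases) {
    std::vector<int64_t> w = warpsPerTile(c[0], c[1], (int)c[2]);
    std::printf("%lld x %lld with %lld warps -> warpsPerCTA = [%lld, %lld]\n",
                (long long)c[0], (long long)c[1], (long long)c[2],
                (long long)w[0], (long long)w[1]);
  }
  return 0;
}

The updated warpsPerCTA values in the CHECK lines below reflect what this balancing loop now chooses for the dot shapes used in those tests.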
6 changes: 3 additions & 3 deletions test/TritonGPU/accelerate-matmul.mlir
@@ -73,7 +73,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :

// -----

-// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
+// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
@@ -93,7 +93,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :

// -----

-// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
// CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
// -----

// Verify that we use mmav2 when the k dim is too small for mmav3.
-// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}>
+// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: small_k_size
