[BACKEND] Minor Bugfixes for SharedToDotOperand MMAv3 (triton-lang#5030)
Two bugfixes following triton-lang#5009.

- When `BLOCK_M=64` and `num_warps > 4`, the warp order for a DotOpEncoded
tensor should be M-major rather than N-major, since WGMMA expects the 4
warps in each warp group to be stacked along the M dimension (see the
sketch after this list).
- `numRep` in `SharedToDotOperandMMAv2OrV3` should be computed from
`mmaBitwidth` rather than `bitwidth`; this was missed in a bad rebase
(see the arithmetic sketch after the diff).
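
A minimal standalone sketch of why the warp order matters, in plain C++ rather than the Triton API. The 4×2 warp arrangement is hypothetical, and `delinearize2D` assumes (as Triton's `delinearize` helper does, to my understanding) that `order[0]` is the fastest-varying dimension:

```cpp
#include <array>
#include <cstdio>

// Delinearize a linear warp id into 2-D coordinates; order[0] is treated
// as the fastest-varying dimension.
std::array<int, 2> delinearize2D(int linear, const std::array<int, 2> &shape,
                                 const std::array<int, 2> &order) {
  std::array<int, 2> multi{};
  for (int dim : order) {
    multi[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multi;
}

int main() {
  // Hypothetical BLOCK_M=64, num_warps=8 layout: 4 warps along M, 2 along N.
  const std::array<int, 2> warpsPerCTA{4, 2}; // {M, N}
  for (int w = 0; w < 8; ++w) {
    auto nMajor = delinearize2D(w, warpsPerCTA, {1, 0}); // N varies fastest
    auto mMajor = delinearize2D(w, warpsPerCTA, {0, 1}); // M varies fastest
    std::printf("warp %d: N-major -> (m=%d, n=%d), M-major -> (m=%d, n=%d)\n",
                w, nMajor[0], nMajor[1], mMajor[0], mMajor[1]);
  }
  // M-major: warps 0-3 (one warp group) get m = 0..3 with n = 0, i.e. the
  // group is stacked along M as WGMMA requires. N-major interleaves them
  // across both dimensions instead.
}
```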

@lezcano I ran into these bugs while testing the
[DotOp hoisting PR](triton-lang#5003) locally
after rebasing (they would normally be caught by `test_core.py`, but that
path was not yet enabled in the last PR). With these fixes applied, I was
able to validate successfully against PyTorch.
ggengnv authored Nov 4, 2024
1 parent 04d655e commit e82dfd9
Showing 1 changed file with 3 additions and 3 deletions.
```diff
@@ -659,15 +659,15 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,

   int kWidth = encoding.getKWidth();
   auto numRep = mmaLayout.getMMAv2OrV3RepForOperand(
-      shapePerCTA, bitwidth, kWidth, encoding.getOpIdx());
+      shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx());

   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
-  auto order = triton::gpu::getOrder(mmaLayout);
+  auto warpOrder = mmaLayout.getWarpOrder();
   Value warp = udiv(thread, i32_val(32));
   Value lane = urem(thread, i32_val(32));

   SmallVector<Value> multiDimWarpId =
-      delinearize(rewriter, loc, warp, warpsPerCTA, order);
+      delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder);
   Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0]));
   int warpsPerTile;
   Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16));
```
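
A rough arithmetic sketch of the second fix, assuming the MMAv2 convention that one mma instruction covers `4 * 64 / bitwidth` elements along K per warp; the concrete numbers here (fp8 in shared memory upcast to fp16 for the MMA, K = 128) are hypothetical:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed model: one mma instruction covers 4 * 64 / bitwidth elements
// along K per warp (the MMAv2 convention); exact formulas vary by version.
int instrK(int bitwidth) { return 4 * 64 / bitwidth; }

int main() {
  const int64_t shapeK = 128; // hypothetical K extent of the dot operand
  const int storageBits = 8;  // fp8 as stored in shared memory
  const int mmaBits = 16;     // fp16 as consumed by the MMA after the
                              // hoisted upcast (the triton-lang#5003 case)

  auto reps = [&](int bits) {
    return std::max<int64_t>(1, shapeK / instrK(bits));
  };
  // Using the storage bitwidth halves the repetition count along K.
  std::printf("numRepK from storage bitwidth: %lld (wrong)\n",
              (long long)reps(storageBits)); // 128 / 32 = 4
  std::printf("numRepK from mma bitwidth:     %lld (right)\n",
              (long long)reps(mmaBits)); // 128 / 16 = 8
}
```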
