Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Nov 20, 2024
1 parent 9fdfc9c commit 6a13d3f
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,11 +780,12 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
unsigned kWidth, ArrayRef<unsigned> order,
ArrayRef<unsigned> repOrder) {
// Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
// Like LinearLayout::empty() but with a rank and an order
int rank = repOrder.size();
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout ctaLayout =
identityND(S("register"), trivialShape, repOrder, dimNames);
identityStandardND(S("register"), trivialShape, repOrder);

assert(rank >= 2);
auto inner = order[0];
Expand Down Expand Up @@ -838,11 +839,7 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
// The triton orders are defined on [dim0, dim1, ...], so we need to pass
// those dims. Then, for some reason, operator* requires the orders to match,
// so we need to reorder the outs to match.
// FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
// The order in triton assumes the standardDims, so it should
// use those.
ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
standardOutDimNames(ctx, rank))
ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), getWarpOrder())
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand All @@ -866,11 +863,8 @@ LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
// In other words, we need to broadcast along K
auto rank = mmaWarpOrder.size();
auto inner = isA ? rank - 1 : rank - 2;
auto outer = isA ? rank - 2 : rank - 1;
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout warpLayout =
identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);
LinearLayout warpLayout = LinearLayout::empty();

// We have to broadcast along the inner dimension
// For A, when moving along M we go from 0 to 2.
Expand Down

0 comments on commit 6a13d3f

Please sign in to comment.