Skip to content

Commit

Permalink
Simplify identityND
Browse files Browse the repository at this point in the history
The auxiliary function `identityND` used to take an `order` parameter,
that comes from triton, and a set of dimensions. Now, the order in
triton is defined wrt. `dim0..dim<rank-1>`, so the dimension arg was
redundant. This was quite confusing.

We see that in all the uses of `identityND`, we would pass the canonical
dimensions, except in one case, which we simply remove as it was not
necessary.

We remove the dims arg and simply return a layout with output dims
`dim0..dim<rank-1>`.
  • Loading branch information
lezcano committed Nov 20, 2024
1 parent 7bce361 commit 7741dfb
Showing 1 changed file with 27 additions and 33 deletions.
60 changes: 27 additions & 33 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace {

#define S(v) StringAttr::get(ctx, (v))

// Returns ["out0", "out1", ..., "out<rank-1>"].
// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
SmallVector<StringAttr> ret;
for (int i = 0; i < rank; i++) {
Expand Down Expand Up @@ -71,14 +71,18 @@ void assertIsRegisterLayout(const LinearLayout &layout) {
expectedOuts.end()));
}

// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of
// size product(shape) and then reshaping to permute(shape, order).
LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order,
ArrayRef<StringAttr> outDimNames) {
// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
// permute(shape, order).
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());

MLIRContext *ctx = inDimName.getContext();
auto rank = shape.size();

// The order in triton is written wrt. [dim0, dim1, ...].
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

LinearLayout ret = LinearLayout::empty();
for (int i = 0; i < shape.size(); i++) {
// Start with the most-minor dimension, which is order[0].
Expand Down Expand Up @@ -491,7 +495,7 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
// And each warp takes the same register and lane sub-layout. So multiply with
// an identity layout for the warp.
LinearLayout warpLayout =
identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
identityStandardND(S("warp"), getWarpsPerCTA(), order);
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand Down Expand Up @@ -601,8 +605,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
}

LinearLayout warpLayout =
identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);

LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
warpLayout.transposeOuts(outDimNames);
Expand Down Expand Up @@ -684,7 +687,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
// And each warp takes the same register and lane sub-layout. So multiply with
// an identity layout for the warp.
LinearLayout warpLayout =
identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
identityStandardND(S("warp"), getWarpsPerCTA(), order);
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand All @@ -700,9 +703,9 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

const auto &order = getOrder();
LinearLayout ctaLayout =
identityND(S("register"), getSizePerThread(), order, outDimNames) *
identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) *
identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
identityStandardND(S("register"), getSizePerThread(), order) *
identityStandardND(S("lane"), getThreadsPerWarp(), order) *
identityStandardND(S("warp"), getWarpsPerCTA(), order);

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}
Expand All @@ -711,11 +714,12 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
unsigned kWidth, ArrayRef<unsigned> order,
ArrayRef<unsigned> repOrder) {
// Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
// Like LinearLayout::empty() but with a rank and an order
int rank = repOrder.size();
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout ctaLayout =
identityND(S("register"), trivialShape, repOrder, dimNames);
identityStandardND(S("register"), trivialShape, repOrder);

assert(rank >= 2);
auto inner = order[0];
Expand Down Expand Up @@ -769,11 +773,7 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
// The triton orders are defined on [dim0, dim1, ...], so we need to pass
// those dims. Then, for some reason, operator* requires the orders to match,
// so we need to reorder the outs to match.
// FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
// The order in triton assumes the standardDims, so it should
// use those.
ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
standardOutDimNames(ctx, rank))
ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), getWarpOrder())
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand All @@ -797,11 +797,8 @@ LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
// In other words, we need to broadcast along K
auto rank = mmaWarpOrder.size();
auto inner = isA ? rank - 1 : rank - 2;
auto outer = isA ? rank - 2 : rank - 1;
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout warpLayout =
identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);
LinearLayout warpLayout = LinearLayout::empty();

// We have to broadcast along the inner dimension
// For A, when moving along M we go from 0 to 2.
Expand Down Expand Up @@ -1086,9 +1083,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(

// Expand the `warp` dimension according to warpsPerCTA.
auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
layout *=
identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));

// Expand the `register` dimension so the size of columns matches `n`.
int n = mma.getInstrShape()[1];
Expand Down Expand Up @@ -1126,9 +1122,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);

// Expand the `warp` dimension according to warpsPerCTA.
layout *=
identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
auto ret =
combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
Expand All @@ -1138,9 +1133,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
ret = ensureLayoutNotSmallerThan(ret, namedTensorShape);
ret = ensureLayoutNotLargerThan(ret, namedTensorShape);
return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
.reshapeOuts({{S("offset"), ret.getTotalOutDimSize()},
{S("iteration"), 1}}) *
identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")});
.reshapeOuts(
{{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
}

} // anonymous namespace
Expand Down

0 comments on commit 7741dfb

Please sign in to comment.