Skip to content

Commit

Permalink
Address review
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Nov 20, 2024
1 parent 13cf4fd commit 9fdfc9c
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace {

#define S(v) StringAttr::get(ctx, (v))

// Returns ["out0", "out1", ..., "out<rank-1>"].
// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
SmallVector<StringAttr> ret;
for (int i = 0; i < rank; i++) {
Expand Down Expand Up @@ -74,8 +74,8 @@ void assertIsRegisterLayout(const LinearLayout &layout) {
// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
// permute(shape, order).
LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
MLIRContext *ctx = inDimName.getContext();
auto rank = shape.size();
Expand Down Expand Up @@ -305,7 +305,8 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
ArrayRef(orderedDimNames).take_front(2));
assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder());
ctaLayout *=
identityStandardND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder());

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}
Expand Down Expand Up @@ -342,8 +343,9 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
// Since the warpOrder needs to be M-major, we need to transpose the out
// dimensions AND transpose the order
assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
ctaLayout *=
identityStandardND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}
Expand Down Expand Up @@ -558,7 +560,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

// And each warp takes the same register and lane sub-layout. So multiply with
// an identity layout for the warp.
LinearLayout warpLayout = identityND(S("warp"), getWarpsPerCTA(), order);
LinearLayout warpLayout =
identityStandardND(S("warp"), getWarpsPerCTA(), order);
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand Down Expand Up @@ -668,7 +671,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
}

LinearLayout warpLayout = identityND(kWarp, warpsPerCTA, warpOrder);
LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);

LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
warpLayout.transposeOuts(outDimNames);
Expand Down Expand Up @@ -749,7 +752,8 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

// And each warp takes the same register and lane sub-layout. So multiply with
// an identity layout for the warp.
LinearLayout warpLayout = identityND(S("warp"), getWarpsPerCTA(), order);
LinearLayout warpLayout =
identityStandardND(S("warp"), getWarpsPerCTA(), order);
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
Expand All @@ -765,9 +769,9 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

const auto &order = getOrder();
LinearLayout ctaLayout =
identityND(S("register"), getSizePerThread(), order) *
identityND(S("lane"), getThreadsPerWarp(), order) *
identityND(S("warp"), getWarpsPerCTA(), order);
identityStandardND(S("register"), getSizePerThread(), order) *
identityStandardND(S("lane"), getThreadsPerWarp(), order) *
identityStandardND(S("warp"), getWarpsPerCTA(), order);

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}
Expand Down Expand Up @@ -1151,7 +1155,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(

// Expand the `warp` dimension according to warpsPerCTA.
auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
layout *= identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));

// Expand the `register` dimension so the size of columns matches `n`.
Expand Down Expand Up @@ -1190,7 +1194,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);

// Expand the `warp` dimension according to warpsPerCTA.
layout *= identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
.transposeOuts(llvm::to_vector(layout.getOutDimNames()));
auto ret =
combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
Expand Down

0 comments on commit 9fdfc9c

Please sign in to comment.