Skip to content

Commit

Permalink
[BACKEND] Fix DotOperand(Ampere) LinearLayoutConversion (triton-lang#5038)
Browse files Browse the repository at this point in the history

We also clean a bit `TritonGPU/IR/Dialect.cpp` using some auxiliary
functions to make the intentions a bit clearer.

We add a few asserts in the `LinearLayoutConversion` to make sure it's
clear why we do certain things here and there.

We also kill `getCvtOrder`, as it was not used anywhere
  • Loading branch information
lezcano authored Nov 2, 2024
1 parent 144c7dc commit 56584c4
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 80 deletions.
11 changes: 11 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ unsigned getNumWarpsPerCTA(Attribute layout);

unsigned getNumCTAs(Attribute layout);

// Return the order that represents that the batch is in row-major or
// column-major order for a batch of matrices of shape [*, m, n] with
// len(shape) == rank.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

// Return the order that represents that the dot operand is in kMajor
// (contiguous in the inner dimension) or it's contiguous on the outer
// dimension.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a no-op.
Expand Down
74 changes: 34 additions & 40 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,19 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
return resOrder;
}

SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  // Return the order that represents that the batch is in row-major or
  // column-major order for a batch of matrices of shape [*, m, n] with
  // len(shape) == rank. order[0] is the fastest-running dimension.
  assert(rank >= 2);
  SmallVector<unsigned> order;
  order.reserve(rank);
  // The two matrix dimensions come first: row-major means n (dim rank-1)
  // is the fastest-running dimension, column-major means m (dim rank-2) is.
  if (rowMajor) {
    order.push_back(rank - 1);
    order.push_back(rank - 2);
  } else {
    order.push_back(rank - 2);
    order.push_back(rank - 1);
  }
  // Batch dimensions are always slower-running, in decreasing order.
  for (unsigned dim = rank - 2; dim > 0; --dim)
    order.push_back(dim - 1);
  return order;
}

SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor) {
// kMajor: if true, the matrix is fastest-running on k,
Expand All @@ -244,15 +257,8 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
// batch (if rank == 3) is always the slowest running dimension
assert(rank == 2 || rank == 3);
assert(opIdx == 0 || opIdx == 1);
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
// If opIdx is 1 and kMajor is true, the order is [0, 1]
// (resp. [1, 2, 0] if rank == 3)
// Same if opIdx is 0 and kMajor is false
if (bool(opIdx) == kMajor) {
std::swap(order[0], order[1]);
}
return order;
auto rowMajor = bool(opIdx) != kMajor;
return getMatrixOrder(rank, rowMajor);
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
Expand All @@ -262,20 +268,21 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
}
}
auto order = getOrder(layout);
// FIXME: This mmaLayout if should just return
// getOrderForDotOperand(0, order.size(), kMajor=false)
// as mma has the same order as DotOperand(opIdx=0)
// FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
// M-major This is awkward. Since we can choose any warpOrder in Ampere, we
// should probably choose M-major and change `LinearLayoutConversion.cpp` and
// `MMAv2.cpp` to match.
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isHopper()) {
// Hopper MMA instructions force a warp order of [0, 1]. See docs:
// Hopper MMA instructions force warps to be column-major
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
auto it = std::find(order.begin(), order.end(), 0);
order.erase(it);
order.insert(order.begin(), 0);
return getMatrixOrder(order.size(), /*rowMajor*/ false);
}
} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
/*kMajor*/ false);
// It's quite weird to talk about warp order when the warps
// are broadcasted along the K dimension
llvm::report_fatal_error(
"DotOperandEncoding::getWarpOrder not implemented");
}
return order;
}
Expand All @@ -285,11 +292,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return llvm::to_vector(blockedLayout.getOrder());
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the elements in the MMAv2/V3 lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
return order;
return getMatrixOrder(rank, /*rowMajor*/ true);
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto rank = dotLayout.getWarpsPerCTA().size();
Expand Down Expand Up @@ -421,7 +428,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout))
warpsPerCTA = wmmaLayout.getWarpsPerCTA();
else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
return getNumWarpsPerCTA(dotLayout.getParent());
warpsPerCTA = dotLayout.getWarpsPerCTA();
else if (auto sharedLayout = dyn_cast<SharedEncodingAttr>(layout))
llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
else
Expand Down Expand Up @@ -2136,25 +2143,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTATile = getShapePerCTATile(shape);
auto rank = parentShapePerCTATile.size();
auto shapePerCTATile = getShapePerCTATile(shape);
auto rank = shapePerCTATile.size();
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
// 4 threads * 2 subtiles
unsigned kWidthTile = kWidth * 2 * 4;
if (opIdx == 0) {
if (rank == 2)
return {parentShapePerCTATile[rank - 2], kWidthTile};
else
return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
kWidthTile};
} else if (opIdx == 1) {
if (rank == 2)
return {kWidthTile, parentShapePerCTATile[rank - 1]};
else
return {parentShapePerCTATile[0], kWidthTile,
parentShapePerCTATile[rank - 1]};
} else {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
shapePerCTATile[kDim] = kWidth * 2 * 4;
return shapePerCTATile;
}
SmallVector<unsigned>
NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
Expand Down
104 changes: 78 additions & 26 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
return ret;
}

// TODO Have order be a mandatory argument of standardOutDimNames.
SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
                                        const SmallVector<unsigned> &order) {
  // Reindex `names` so that element k of the result is names[order[k]].
  assert(names.size() == order.size());
  SmallVector<StringAttr> permuted;
  permuted.reserve(order.size());
  for (size_t k = 0; k != order.size(); ++k)
    permuted.push_back(names[order[k]]);
  return permuted;
}

void assertIsRegisterLayout(const LinearLayout &layout) {
assert(layout.getNumInDims() > 0);
MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
Expand Down Expand Up @@ -281,15 +292,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// By using `reverse(dimNames)` below, we set the order to be row-major
assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

ctaLayout *= identityND(
S("warp"), mma.getWarpsPerCTA(),
llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))), dimNames);
ArrayRef(orderedDimNames).take_front(2));
assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant with the order of the out dims.
ctaLayout *=
identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}
Expand Down Expand Up @@ -322,10 +337,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
S("register"), S("dim1"));

// Expand the `warp` dimension according to warpsPerCTA.
//
// It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
// this really does seem to be correct.
// The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
// Since the warpOrder needs to be M-major, we need to transpose the out
// dimensions AND transpose the order
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant. The order is already given by the order of the
// out dims, and if it has an order, it shouldn't change the
// order of the out dims.
assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
{S("dim0"), S("dim1")})
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
Expand Down Expand Up @@ -843,18 +862,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
// TODO,BE. Implement ampereMMA in terms of this one
// Note that, even though MMAv2 looks similar to this layout, they are just
// the same at a register and lane level. The warps treatment is different!
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;

assert(mma.isAmpere());
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
// A and B have kMajor order
assert(getOrder(dot) ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
Expand All @@ -881,24 +906,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
}
registers.push_back({i, 0});

if (!isA) {
for (auto &r : registers) {
std::swap(r[0], r[1]);
LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
ArrayRef(kMajorDims).take_front(2));

// Let warpsPerCTAMma = {2, 2}, then
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
// assume warpOrder = {0, 1}
// Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
// the C is owned as per the following layout:
// C: 0 | 1
// - | -
// 2 | 3
// In order to be able to compute C, we need the following warp tiling of
// A and B:
// A: 0 1 | 0 1 B: 0 2 | 1 3
// - - | - - - - | - -
// 2 3 | 2 3 0 2 | 1 3
// In particular, for A and B we need to broadcast along K

assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
auto warpsPerCTAMma = mma.getWarpsPerCTA();
std::vector<std::vector<int32_t>> warps;
if (isA) {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, 0});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, i});
}
} else {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, i});
}
for (auto &l : lanes) {
std::swap(l[0], l[1]);
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, 0});
}
}
if (rank == 3) {
for (auto &w : warps) {
w.push_back(0);
}
}

LinearLayout ctaLayout(
{{S("register"), registers}, {S("lane"), lanes}},
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

auto order = dot.getCTAOrder();
assert(order[0] == rank - 1 && order[1] == rank - 2);
ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
}

std::optional<LinearLayout>
Expand All @@ -907,7 +959,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
return mfmaDotToLinearLayout(*this, shape);
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) {
if (mma.isAmpere()) {
return ampereDotToLinearLayout(shape, *this);
}
}
Expand Down
16 changes: 6 additions & 10 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,15 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
}

if (dot.getOpIdx() == 1) {
// there are kWidth * 2 elems packed as bf16x2
int elemsInTile = dot.getKWidth();
// n0 and n1 are unrolled in the legacy path
// Unrolling n1 makes some sense, but unrolling n0 makes absolutely no
// sense IMO
// n0 is unrolled in the legacy path, which makes no sense
n0 *= 2;
n1 *= 2;
for (auto b = 0; b < batch; ++b)
for (auto j = 0; j < n1 / elemsInTile; ++j)
for (auto i = 0; i < n0; ++i)
for (auto k = 0; k < elemsInTile; ++k) {
vals[{b, i, elemsInTile * j + k}] = elems[offset++];
}
for (auto i = 0; i < n0; ++i)
for (auto j = 0; j < n1; ++j) {
vals[{b, i, 2 * j}] = elems[offset++];
vals[{b, i, 2 * j + 1}] = elems[offset++];
}
return vals;
}
}
Expand Down
41 changes: 37 additions & 4 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,14 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
{2, 0},
{4, 0},
{32, 0},
{64, 0},
{0, 8},
{0, 16},
{0, 32},
{64, 0}}},
{0, 32}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
{{0, 0}, {0, 0}},
},
{S("block"), {}},
},
Expand All @@ -582,13 +582,46 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{
S("warp"),
{},
{{0, 0}, {0, 0}},
},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}

// Exercises the Ampere DotOperand -> LinearLayout conversion when the warps
// are split along both matrix dimensions (warpsPerCTA = {2, 2}) with
// kWidth = 8. Note the zero basis {0, 0} appearing in every "warp" input
// dimension: for a dot operand one of the warp dimensions is broadcast
// (presumably along K — see the ampereDotToLinearLayout change in this
// commit), so that warp bit does not select new elements.
TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) {
// Operand A (opIdx = 0), shape 32x64: the second warp basis {16, 0} splits
// rows across warps; the first warp basis is the broadcast {0, 0}.
EXPECT_EQ(
toLinearLayout({32, 64}, dotMMAv2(0, 8, {2, 2})),
LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{0, 0}, {16, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
// Operand B (opIdx = 1), shape 64x16: here the first warp basis {0, 8}
// splits columns across warps and the second is the broadcast {0, 0}.
EXPECT_EQ(
toLinearLayout({64, 16}, dotMMAv2(1, 8, {2, 2})),
LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{S("warp"), {{0, 8}, {0, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
// Same operand A layout on a larger shape (64x128): the extra elements are
// covered by additional register bases ({0, 64} and {32, 0}); the lane and
// warp bases are unchanged.
EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(0, 8, {2, 2})),
LinearLayout(
{{S("register"),
{{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}},
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{0, 0}, {16, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
// Same operand B layout on a larger shape (128x32): again only the
// register bases grow ({64, 0} and {0, 16}).
EXPECT_EQ(
toLinearLayout({128, 32}, dotMMAv2(1, 8, {2, 2})),
LinearLayout(
{{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}},
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
{S("warp"), {{0, 8}, {0, 0}}},
{S("block"), {}}},
{S("dim0"), S("dim1")}));
}

TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
/*isTransposed=*/false);
Expand Down

0 comments on commit 56584c4

Please sign in to comment.