[BACKEND] Fix uses of getOrder(DotOperand(Nvidia) and MMA(Nvidia)) (triton-lang#5055)

We use `getOrder` very liberally throughout the codebase, when we really
mean `getThreadOrder`. This is an issue when the input layout is a
`DotOperand(mma(opIdx=1))`, where the thread order and the matrix order
are opposite.

We found this to be an issue when a PR changed the `getOrder` of
`DotOperand(Hopper)` to an incorrect one and CI still passed! The issue
is that the LLVM lowering for wgmma and the LinearLayout machinery do
not use `getOrder`, but many other subsystems do, so many heuristics
would have been given an incorrect order and potentially been disabled.

This is particularly problematic for `DotOperand(opIdx=1)` on NVIDIA
hardware, where `getThreadOrder` and `getOrder` are different!
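
To make the distinction concrete: for the B operand (`opIdx=1`) of an
NVIDIA MMA, the matrix is laid out row-major but each thread's registers
run fastest along K. A minimal standalone sketch (plain C++ mirroring the
order convention, where `order[0]` is the fastest-varying dimension; not
the Triton API itself):

```cpp
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// order[0] is the fastest-varying dimension.
std::vector<unsigned> matrixOrder(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0); // row-major: [rank-1, ..., 0]
  if (!rowMajor)
    std::swap(order[0], order[1]); // column-major: dim rank-2 runs fastest
  return order;
}

int main() {
  unsigned rank = 2;
  // getOrder(DotOperand(opIdx=1)): how the B tile [K, N] is laid out -> [1, 0]
  auto order = matrixOrder(rank, /*rowMajor=*/true);
  // getThreadOrder(DotOperand(opIdx=1)): registers run fastest along K,
  // i.e. column-major for B -> [0, 1]
  auto threadOrder = matrixOrder(rank, /*rowMajor=*/false);
  std::printf("order [%u, %u] vs threadOrder [%u, %u]\n", order[0], order[1],
              threadOrder[0], threadOrder[1]);
}
```

A heuristic that asks `getOrder` for the dimension threads traverse
fastest would therefore pick N instead of K for the B operand.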

While doing so, we:
- Audited most (all?) of the calls to `getOrder(dotOperand)`; it turns
out that most of them really meant `getThreadOrder`
- Fixed the ordering methods of `SliceEncodingAttr` to be consistent
- Moved the implementation of `getWarpOrder` into the Attr classes, in
keeping with OOP

The test strategy was to add `llvm::report_fatal_error("Testing");`
within `getOrder(nvidiaMma)` and `getOrder(DotOperand(nvidiaMma))` and
to triage all the errors raised in CI.

(cherry picked from commit 38a11b8)
lezcano authored and jataylo committed Dec 12, 2024
1 parent 0446f1a commit 832c369
Showing 12 changed files with 63 additions and 64 deletions.
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();

-SmallVector<unsigned> getOrderWithAxisAtBeginning();
+SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

3 changes: 3 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -80,6 +80,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());

+assert(srcTy.getRank() == dstTy.getRank() &&
+       "src and dst must have the same rank");

unsigned rank = dstTy.getRank();
SmallVector<unsigned> repShape(rank);
for (unsigned d = 0; d < rank; ++d) {
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
-auto order = triton::gpu::getOrder(layout);
+auto order = triton::gpu::getThreadOrder(layout);
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
-auto order = triton::gpu::getOrder(layout);
+auto order = triton::gpu::getThreadOrder(layout);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
-auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
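
Why the fastest thread dimension matters here: the alignment is derived
from divisibility and contiguity along `order[0]`. A minimal sketch of
that arithmetic (semantics assumed from the surrounding pass, simplified):

```cpp
#include <algorithm>
#include <cstdio>

// Assumed inputs: maxMultipleBytes = known divisibility in bytes along
// order[0]; maxContig = contiguous elements per thread along order[0].
unsigned ptrAlignment(unsigned maxMultipleBytes, unsigned maxContig,
                      unsigned elemNumBits) {
  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
  return std::min(maxMultiple, maxContig);
}

int main() {
  // A 16-byte-divisible fp16 pointer with 8 contiguous elements aligns to 8.
  std::printf("%u\n", ptrAlignment(/*bytes=*/16, /*contig=*/8, /*bits=*/16));
}
```

With the matrix order instead of the thread order, `order[0]` can point
at a dimension along which per-thread data is not contiguous, so the
estimated alignment (and hence vectorization) could be wrong.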
10 changes: 5 additions & 5 deletions lib/Analysis/Utility.cpp
@@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
return axis;
}

-SmallVector<unsigned> getParentOrder(Attribute layout) {
+SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-return getParentOrder(sliceEncoding.getParent());
+return getParentThreadOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}
@@ -44,12 +44,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
-getParentOrder(getSrcLayout())[0];
+getParentThreadOrder(getSrcLayout())[0];
}

-SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
+SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
-auto order = getOrder(srcLayout);
+auto order = getThreadOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -276,7 +276,7 @@ struct ReduceOpConversion
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];

-auto smemOrder = helper.getOrderWithAxisAtBeginning();
+auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = it.second;
@@ -363,7 +363,7 @@
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
-auto smemOrder = helper.getOrderWithAxisAtBeginning();
+auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
7 changes: 3 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
int numContiguousValues = 1;
auto encoding = cast<BlockedEncodingAttr>(
cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
-int splitDim = encoding.getOrder().size() - 1;
-for (int i = 0; i < encoding.getOrder().size(); i++) {
-if (encoding.getOrder()[i] == splitDim)
+int splitDim = encoding.getThreadOrder().size() - 1;
+for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
+if (encoding.getThreadOrder()[i] == splitDim)
break;
numContiguousValues *= encoding.getSizePerThread()[i];
}
@@ -336,7 +336,6 @@ struct BroadcastOpConversion
unsigned rank = srcTy.getRank();
auto typeConverter = getTypeConverter();
assert(rank == resultTy.getRank());
-auto order = triton::gpu::getOrder(srcLayout);
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> srcVals = unpackLLElements(loc, src, rewriter);
71 changes: 33 additions & 38 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
}

-/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
+/* Utility function used by get.*Order methods of SliceEncodingAttr.
* Erase dim and decrease all values larger than dim by 1.
* Example: order = [0, 2, 4, 3, 1], dim = 2
* resOrder = [0, 3, 2, 1]
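
The erase-dim utility is easy to check in isolation; a minimal standalone
sketch (hypothetical plain-C++ version, outside the MLIR types) that
reproduces the example above:

```cpp
#include <cassert>
#include <vector>

// Erase `dim` from `order` and decrement every entry larger than `dim`,
// so the result is a valid permutation for a tensor of one fewer dimension.
std::vector<unsigned> eraseOrder(const std::vector<unsigned> &order,
                                 unsigned dim) {
  std::vector<unsigned> resOrder;
  for (unsigned d : order)
    if (d != dim)
      resOrder.push_back(d > dim ? d - 1 : d);
  return resOrder;
}

int main() {
  auto res = eraseOrder({0, 2, 4, 3, 1}, 2);
  assert((res == std::vector<unsigned>{0, 3, 2, 1}));
}
```

This is how the fixed `SliceEncodingAttr::getWarpOrder` and
`getThreadOrder` below derive their orders from the parent encoding.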
@@ -262,29 +262,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
-if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-return getWarpOrder(dotLayout.getParent());
-}
-}
-auto order = getOrder(layout);
-// FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
-// M-major. This is awkward. Since we can choose any warpOrder in Ampere, we
-// should probably choose M-major and change `LinearLayoutConversion.cpp` and
-// `MMAv2.cpp` to match.
-if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-if (mmaLayout.isHopper()) {
-// Hopper MMA instructions force warps to be column-major
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-return getMatrixOrder(order.size(), /*rowMajor*/ false);
-}
-} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-// It's quite weird to talk about warp order when the warps
-// are broadcast along the K dimension
-llvm::report_fatal_error(
-"DotOperandEncoding::getWarpOrder not implemented");
-}
-return order;
+if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+return distributedLayout.getWarpOrder();
+else
+llvm::report_fatal_error("Unimplemented usage of getWarpOrder");
+return {};
}

SmallVector<unsigned> getOrder(Attribute layout) {
@@ -293,7 +275,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
-// the elements in the MMAv2/V3 lowerings. We choose row-major
+// the output elements in the LLVM lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -318,15 +300,15 @@

llvm::report_fatal_error("Unimplemented usage of getOrder");
return {};
-};
+}

SmallVector<unsigned> getThreadOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getThreadOrder();
else
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
return {};
-};
+}

CTALayoutAttr getCTALayout(Attribute layout) {
if (auto distributedLayout =
@@ -769,7 +751,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
return warpsPerCTA;
}
SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
-return ::getWarpOrder(*this);
+auto parentWarpOrder = ::getWarpOrder(getParent());
+return eraseOrder(parentWarpOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
auto parent = getParent();
@@ -781,7 +764,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
return threadsPerWarp;
}
SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
-return ::getOrder(*this);
+auto parentThreadOrder = ::getThreadOrder(getParent());
+return eraseOrder(parentThreadOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
auto sizePerThread = ::getSizePerThread(getParent());
@@ -1049,7 +1033,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
return warps;
}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
-return ::getWarpOrder(*this);
+// FIXME(Lezcano): Preexisting. Do we want to have this path at all?
+if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+return ::getWarpOrder(getParent());
+}
+// It's quite weird to talk about warp order when the warps
+// are broadcast along the K dimension
+llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
+return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
@@ -1597,7 +1588,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
-return ::getWarpOrder(*this);
+return ::getOrder(*this);
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
auto order = ::getOrder(*this);
@@ -1766,7 +1757,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
-return ::getWarpOrder(*this);
+return ::getOrder(*this);
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
@@ -1890,7 +1881,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
-return ::getWarpOrder(*this);
+auto rank = getWarpsPerCTA().size();
+// Hopper (wgmma) uses column-major, as this is embedded in the instruction.
+// For Ampere we can choose either row-major or column-major;
+// we choose row-major, as the legacy path did.
+return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
}
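
As a sanity check of the choice above, a standalone rank-2 sketch
(hypothetical helper mirroring `getMatrixOrder`, not the Triton API):

```cpp
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// order[0] is the fastest dimension; column-major swaps the two matrix dims.
std::vector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);
  if (!rowMajor)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  auto ampere = getMatrixOrder(2, /*rowMajor=*/true);  // [1, 0]: N-major
  auto hopper = getMatrixOrder(2, /*rowMajor=*/false); // [0, 1]: M-major
  std::printf("Ampere warp order [%u, %u], Hopper warp order [%u, %u]\n",
              ampere[0], ampere[1], hopper[0], hopper[1]);
}
```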
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
auto rank = getWarpsPerCTA().size();
@@ -1914,10 +1909,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
"getThreadsPerWarp not implemented for unknown Mma version ");
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
-return ::getOrder(*this);
+auto rank = getWarpsPerCTA().size();
+return getMatrixOrder(rank, /*rowMajor*/ true);
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
-auto rank = ::getOrder(*this).size();
+auto rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
if (isAmpere()) {
res[rank - 2] = 2;
@@ -2158,11 +2154,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
if (opIdx == 0) {
sizePerThread[rank - 2] = 2;
sizePerThread[rank - 1] = 2 * kWidth;
-} else if (opIdx == 1) {
+} else {
+assert(opIdx == 1);
sizePerThread[rank - 2] = 2 * kWidth;
sizePerThread[rank - 1] = 1;
-} else {
-llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
return sizePerThread;
}
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -327,6 +327,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

+// TODO Make the getOrder of Hopper explicit here via an assert
MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
10 changes: 5 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass
return;
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
-auto srcOrder = triton::gpu::getOrder(srcEncoding);
-auto rank = srcOrder.size();
+auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
+auto rank = srcThreadOrder.size();
SmallVector<unsigned> sharedOrder;
if (rank == 3) {
// add all elements except the element that is zero
for (unsigned i = 0; i < rank; ++i)
-if (srcOrder[i] != 0)
-sharedOrder.emplace_back(srcOrder[i]);
+if (srcThreadOrder[i] != 0)
+sharedOrder.emplace_back(srcThreadOrder[i]);
sharedOrder.emplace_back(0);
} else {
-sharedOrder = srcOrder;
+sharedOrder = srcThreadOrder;
}
auto sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
7 changes: 4 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -39,15 +39,16 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
-auto order = triton::gpu::getOrder(layout);
+auto threadOrder = triton::gpu::getThreadOrder(layout);
+auto warpOrder = triton::gpu::getWarpOrder(layout);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
-delinearize(rewriter, loc, warpId, warpsPerCTA, order);
+delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
-delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])
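
Why using one `order` for both delinearizations was wrong: `delinearize`
consumes dimensions fastest-first, so warp ids must follow the warp order
and lane ids the thread order. A minimal scalar sketch of the assumed
semantics (not the actual helper):

```cpp
#include <cstdio>
#include <vector>

// Split a linear id into multi-dim coordinates; order[0] is fastest.
std::vector<unsigned> delinearize(unsigned linear,
                                  const std::vector<unsigned> &shape,
                                  const std::vector<unsigned> &order) {
  std::vector<unsigned> multiDim(shape.size());
  for (unsigned dim : order) {
    multiDim[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multiDim;
}

int main() {
  // Lane 5 in a 4x8 grid: order [1, 0] -> {0, 5}, order [0, 1] -> {1, 1}.
  auto a = delinearize(5, {4, 8}, {1, 0});
  auto b = delinearize(5, {4, 8}, {0, 1});
  std::printf("{%u,%u} vs {%u,%u}\n", a[0], a[1], b[0], b[1]);
}
```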
@@ -54,7 +54,7 @@ class DecomposeLocalLoadToDotOperand
type.getShape(), type.getElementType(),
triton::gpu::SharedEncodingAttr::get(
op.getContext(), dstDotOp, type.getShape(),
-triton::gpu::getOrder(parentEnc),
+triton::gpu::getThreadOrder(parentEnc),
triton::gpu::getCTALayout(parentEnc), type.getElementType()),
srcType.getMemorySpace());
auto tmp = rewriter.create<triton::gpu::LocalAllocOp>(
@@ -38,7 +38,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
-auto order = triton::gpu::getOrder(layout);
+auto threadOrder = triton::gpu::getThreadOrder(layout);
auto warpOrder = triton::gpu::getWarpOrder(layout);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
Value warpSize = i32_val(32);
@@ -47,7 +47,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
-delinearize(rewriter, loc, laneId, threadsPerWarp, order);
+delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])
Expand Down
