[BACKEND] Fix uses of getOrder(DotOperand(Nvidia) and MMA(Nvidia)) #5055

Merged · 7 commits · Nov 5, 2024
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();

SmallVector<unsigned> getOrderWithAxisAtBeginning();
SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

3 changes: 3 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -46,6 +46,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
auto dstShapePerCTATile =
gpu::getShapePerCTATile(dstLayout, dstTy.getShape());

assert(srcTy.getRank() == dstTy.getRank() &&
"src and dst must have the same rank");

unsigned rank = dstTy.getRank();
SmallVector<unsigned> repShape(rank);
for (unsigned d = 0; d < rank; ++d) {
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
auto order = triton::gpu::getThreadOrder(layout);
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
auto order = triton::gpu::getThreadOrder(layout);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
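
The switch from getOrder to getThreadOrder in these three helpers reflects that contiguity, alignment, and mask constancy are per-thread properties: what matters is the dimension along which a single thread's consecutive elements advance, i.e. threadOrder[0]. For blocked layouts the two orders coincide, but for MMA and dot-operand layouts they can differ. Below is a minimal standalone sketch (plain C++, not the Triton API; the orders and numbers are invented for illustration) of how reading through the wrong order changes what the analysis sees:

#include <array>
#include <cstdio>

// An "order" lists tensor dimensions from fastest- to slowest-varying.
// Alignment/contiguity analysis only cares about the dimension a single
// thread walks fastest, i.e. threadOrder[0], which need not match the
// layout's overall element order.
int main() {
  // Hypothetical per-dimension contiguity as AxisInfo might report it.
  std::array<unsigned, 2> contiguity = {1, 8};

  std::array<unsigned, 2> order = {0, 1};       // overall element order (made up)
  std::array<unsigned, 2> threadOrder = {1, 0}; // how one thread's registers advance

  // Reading contiguity through the wrong order looks at dim 0 (= 1) instead
  // of the per-thread fast dim 1 (= 8), so the derived vector width differs.
  std::printf("via getOrder:       %u\n", contiguity[order[0]]);
  std::printf("via getThreadOrder: %u\n", contiguity[threadOrder[0]]);
  return 0;
}
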
10 changes: 5 additions & 5 deletions lib/Analysis/Utility.cpp
@@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
return axis;
}

SmallVector<unsigned> getParentOrder(Attribute layout) {
SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentOrder(sliceEncoding.getParent());
return getParentThreadOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}
@@ -44,12 +44,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentOrder(getSrcLayout())[0];
getParentThreadOrder(getSrcLayout())[0];
}

SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = getOrder(srcLayout);
auto order = getThreadOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
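
The renamed helper now starts from the layout's thread order rather than getOrder; it erases the reduction axis and, as the name says, puts it back at the front so the reduced axis becomes the fastest-varying dimension of the scratch buffer (the re-insertion is below the fold of this hunk). A small standalone sketch of that reordering, assuming that behavior (plain C++, not the Triton helper itself):

#include <algorithm>
#include <cstdio>
#include <vector>

// Sketch: move `axis` to the front of `order`, keeping the relative order of
// the other dimensions. E.g. order {2, 0, 1}, axis 1 -> {1, 2, 0}.
std::vector<unsigned> orderWithAxisAtBeginning(std::vector<unsigned> order, unsigned axis) {
  order.erase(std::find(order.begin(), order.end(), axis));
  order.insert(order.begin(), axis);
  return order;
}

int main() {
  auto order = orderWithAxisAtBeginning({2, 0, 1}, /*axis=*/1);
  std::printf("{%u, %u, %u}\n", order[0], order[1], order[2]); // {1, 2, 0}
  return 0;
}
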
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -283,7 +283,7 @@ struct ReduceOpConversion
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];

auto smemOrder = helper.getOrderWithAxisAtBeginning();
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = it.second;
@@ -370,7 +370,7 @@ struct ReduceOpConversion
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
7 changes: 3 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
int numContiguousValues = 1;
auto encoding = cast<BlockedEncodingAttr>(
cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
int splitDim = encoding.getOrder().size() - 1;
for (int i = 0; i < encoding.getOrder().size(); i++) {
if (encoding.getOrder()[i] == splitDim)
int splitDim = encoding.getThreadOrder().size() - 1;
for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
if (encoding.getThreadOrder()[i] == splitDim)
break;
numContiguousValues *= encoding.getSizePerThread()[i];
}
@@ -336,7 +336,6 @@ struct BroadcastOpConversion
unsigned rank = srcTy.getRank();
auto typeConverter = getTypeConverter();
assert(rank == resultTy.getRank());
auto order = triton::gpu::getOrder(srcLayout);
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> srcVals = unpackLLElements(loc, src, rewriter);
71 changes: 33 additions & 38 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
}

/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
/* Utility function used by get.*Order methods of SliceEncodingAttr.
* Erase dim and decrease all values larger than dim by 1.
* Example: order = [0, 2, 4, 3, 1], dim = 2
* resOrder = [0, 3, 2, 1]
@@ -262,29 +262,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
return getWarpOrder(dotLayout.getParent());
}
}
auto order = getOrder(layout);
// FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
// M-major This is awkward. Since we can choose any warpOrder in Ampere, we
// should probably choose M-major and change `LinearLayoutConversion.cpp` and
// `MMAv2.cpp` to match.
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isHopper()) {
// Hopper MMA instructions force warps to be column-major
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
return getMatrixOrder(order.size(), /*rowMajor*/ false);
}
} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
// It's quite weird to talk about warp order when that the warps
// are broadcasted along the K dimension
llvm::report_fatal_error(
"DotOperandEncoding::getWarpOrder not implemented");
}
return order;
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getWarpOrder();
else
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
return {};
}

SmallVector<unsigned> getOrder(Attribute layout) {
@@ -293,7 +275,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the elements in the MMAv2/V3 lowerings. We choose row-major
// the output elements in the LLVM lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -318,15 +300,15 @@

llvm::report_fatal_error("Unimplemented usage of getOrder");
return {};
};
}

SmallVector<unsigned> getThreadOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getThreadOrder();
else
llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
return {};
};
}

CTALayoutAttr getCTALayout(Attribute layout) {
if (auto distributedLayout =
@@ -769,7 +751,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
return warpsPerCTA;
}
SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
auto parentWarpOrder = ::getWarpOrder(getParent());
return eraseOrder(parentWarpOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
auto parent = getParent();
@@ -781,7 +764,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
return threadsPerWarp;
}
SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
auto parentThreadOrder = ::getThreadOrder(getParent());
return eraseOrder(parentThreadOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
auto sizePerThread = ::getSizePerThread(getParent());
@@ -1049,7 +1033,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
return warps;
}
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
// FIXME(Lezcano): Preexisting. Do we want to have this path at all?
if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
return ::getWarpOrder(getParent());
}
// It's quite weird to talk about warp order when that the warps
// are broadcasted along the K dimension
llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
@@ -1597,7 +1588,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
return ::getOrder(*this);
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
auto order = ::getOrder(*this);
@@ -1766,7 +1757,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
return ::getOrder(*this);
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
@@ -1890,7 +1881,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
return SmallVector<unsigned>(getWarpsPerCTA__());
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
auto rank = getWarpsPerCTA().size();
// Hopper (wgmma) uses column-major as this is embeded in the instruction
// For Ampere we can choose either row-major or column-major.
// We choose row-major as the legacy path did so
return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
auto rank = getWarpsPerCTA().size();
@@ -1914,10 +1909,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
"getThreadsPerWarp not implemented for unknown Mma version ");
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
auto rank = getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
auto rank = ::getOrder(*this).size();
auto rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
if (isAmpere()) {
res[rank - 2] = 2;
@@ -2158,11 +2154,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
if (opIdx == 0) {
sizePerThread[rank - 2] = 2;
sizePerThread[rank - 1] = 2 * kWidth;
} else if (opIdx == 1) {
} else {
assert(opIdx == 1);
sizePerThread[rank - 2] = 2 * kWidth;
sizePerThread[rank - 1] = 1;
} else {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
return sizePerThread;
}
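
The new NvidiaMmaEncodingAttr::getWarpOrder above delegates to getMatrixOrder: column-major for Hopper, where wgmma fixes the warp arrangement, and row-major for Ampere, matching the legacy path. A rough standalone sketch of what a getMatrixOrder-style helper would return (plain C++; the name and exact semantics here are my assumption, not the Triton implementation):

#include <cstdio>
#include <numeric>
#include <vector>

// Sketch of a getMatrixOrder-like helper: dimensions are listed from
// fastest- to slowest-varying. Row-major means the innermost (N) dimension
// varies fastest; column-major swaps the two minor dimensions.
std::vector<unsigned> matrixOrder(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0); // rank 3 -> {2, 1, 0}
  if (!rowMajor)
    std::swap(order[0], order[1]);            // rank 3 -> {1, 2, 0}
  return order;
}

int main() {
  for (bool rowMajor : {true, false}) {
    auto o = matrixOrder(/*rank=*/2, rowMajor);
    std::printf("%s -> {%u, %u}\n",
                rowMajor ? "Ampere, row-major" : "Hopper, column-major",
                o[0], o[1]);
  }
  return 0;
}
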
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -327,6 +327,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

// TODO Make the getOrder of Hopper explicit here via an assert
MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
10 changes: 5 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass
return;
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto rank = srcThreadOrder.size();
SmallVector<unsigned> sharedOrder;
if (rank == 3) {
// add all elements except the element that is zero
for (unsigned i = 0; i < rank; ++i)
if (srcOrder[i] != 0)
sharedOrder.emplace_back(srcOrder[i]);
if (srcThreadOrder[i] != 0)
sharedOrder.emplace_back(srcThreadOrder[i]);
sharedOrder.emplace_back(0);
} else {
sharedOrder = srcOrder;
sharedOrder = srcThreadOrder;
}
auto sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
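
For rank-3 tensors the pass keeps the thread order but pushes dimension 0 to the back when building the shared encoding's order, as the loop above shows. A tiny standalone sketch of that reordering with an invented input (plain C++, mirroring the loop rather than calling Triton):

#include <cstdio>
#include <vector>

// Sketch: move dim 0 to the back of the order, keeping the relative order of
// the remaining dims. E.g. srcThreadOrder {0, 2, 1} -> sharedOrder {2, 1, 0}.
std::vector<unsigned> sharedOrderForRank3(const std::vector<unsigned> &srcThreadOrder) {
  std::vector<unsigned> sharedOrder;
  for (unsigned d : srcThreadOrder)
    if (d != 0)
      sharedOrder.push_back(d);
  sharedOrder.push_back(0);
  return sharedOrder;
}

int main() {
  auto order = sharedOrderForRank3({0, 2, 1});
  std::printf("{%u, %u, %u}\n", order[0], order[1], order[2]); // {2, 1, 0}
  return 0;
}
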
1 change: 0 additions & 1 deletion python/test/unit/language/test_core.py
@@ -1764,7 +1764,6 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field)

if constant_field == "value":
print(output, ref)
assert torch.all(output == ref)
else:
assert torch.all(output == 0)
@@ -40,15 +40,16 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto threadOrder = triton::gpu::getThreadOrder(layout);
auto warpOrder = triton::gpu::getWarpOrder(layout);
Review comment (Contributor):
What if you call getWarpOrder(dot operand layout) here?

Review comment (Contributor):
IMO it seems just fine to still have getWarpOrder defined for dot operand layouts. But since it causes confusion for you, I will add a condition for dot operand layouts here.

Review comment (Contributor, Author):
In that case it will break hard. That is good, because delinearize would incidentally not work as expected: before, it would create incorrect results; now it at least breaks hard. Note that there is no combination of warpsPerCTA and warpOrder that can represent the order of DotOperands, so I think it's better to break hard.

auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])
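
The review thread above is about how the warp id gets decomposed: delinearize splits a linear id into one coordinate per dimension following the given sizes and order, so using one generic order for both the warp and the lane decomposition (as the old code did) silently yields wrong coordinates whenever the warp and thread orders differ. A standalone sketch of the decomposition (plain C++, not Triton's delinearize helper; the sizes are invented):

#include <cstdio>
#include <vector>

// Sketch: delinearize `linear` into one index per dimension. `order` lists
// dimensions from fastest- to slowest-varying, `sizes[d]` is the extent of
// dimension d (e.g. warpsPerCTA).
std::vector<unsigned> delinearize(unsigned linear, const std::vector<unsigned> &sizes,
                                  const std::vector<unsigned> &order) {
  std::vector<unsigned> multiDim(sizes.size());
  for (unsigned d : order) { // peel off the fastest dimension first
    multiDim[d] = linear % sizes[d];
    linear /= sizes[d];
  }
  return multiDim;
}

int main() {
  std::vector<unsigned> warpsPerCTA = {2, 4};
  // warpId 5 decomposes differently depending on which dimension is fastest.
  for (auto order : {std::vector<unsigned>{1, 0}, std::vector<unsigned>{0, 1}}) {
    auto id = delinearize(5, warpsPerCTA, order);
    std::printf("order {%u, %u} -> warp id (%u, %u)\n", order[0], order[1], id[0], id[1]);
  }
  return 0;
}
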
@@ -54,7 +54,7 @@ class DecomposeLocalLoadToDotOperand
type.getShape(), type.getElementType(),
triton::gpu::SharedEncodingAttr::get(
op.getContext(), dstDotOp, type.getShape(),
triton::gpu::getOrder(parentEnc),
triton::gpu::getThreadOrder(parentEnc),
triton::gpu::getCTALayout(parentEnc), type.getElementType()),
srcType.getMemorySpace());
auto tmp = rewriter.create<triton::gpu::LocalAllocOp>(
@@ -38,7 +38,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto threadOrder = triton::gpu::getThreadOrder(layout);
auto warpOrder = triton::gpu::getWarpOrder(layout);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
Value warpSize = i32_val(32);
@@ -47,7 +47,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTATile[dim])