Consolidate getOrder as "element order" and implement getRepOrder for general and NVIDIA layouts #5089

Merged · 3 commits · Nov 7, 2024
Changes from 2 commits
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();

SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
SmallVector<unsigned> getOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

7 changes: 4 additions & 3 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -466,15 +466,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
auto threadOrder = blockedLayout.getThreadOrder();
auto warpOrder = blockedLayout.getWarpOrder();
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
unsigned rank = shape.size();

// delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);

SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
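The hunk above swaps the single `order` for separate `threadOrder` and `warpOrder` when delinearizing the lane and warp ids. As a rough illustration of what that delinearization computes, here is a plain-integer sketch (the real `delinearize` emits MLIR `Value` arithmetic); the 8x4 warp shape and `{1, 0}` order are invented example values, not taken from the PR.

```cpp
// Plain-integer sketch of the delinearization performed in the lowering.
// Assumes `order` lists dimensions from fastest- to slowest-varying, as
// getThreadOrder()/getWarpOrder() do.
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<unsigned> delinearizeSketch(unsigned linearId,
                                        const std::vector<unsigned> &shape,
                                        const std::vector<unsigned> &order) {
  assert(shape.size() == order.size());
  std::vector<unsigned> multiDim(shape.size());
  // Peel off the fastest-varying dimension first, then move outward.
  for (unsigned dim : order) {
    multiDim[dim] = linearId % shape[dim];
    linearId /= shape[dim];
  }
  return multiDim;
}

int main() {
  // A warp of 32 threads laid out as 8x4, with dim 1 fastest (order = {1, 0}).
  std::vector<unsigned> threadsPerWarp = {8, 4};
  std::vector<unsigned> threadOrder = {1, 0};
  auto id5 = delinearizeSketch(/*laneId=*/5, threadsPerWarp, threadOrder);
  std::printf("lane 5 -> (%u, %u)\n", id5[0], id5[1]); // lane 5 -> (1, 1)
}
```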
5 changes: 2 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -76,9 +76,8 @@ SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

// Returns the dimensions of the tensor from minor (fast-varying) to
// major (slow-varying). For blocked, mma, and dotOperand layouts,
// though the elements are in registers, the order refers to memory
// layout of the original tensor in global memory.
// major (slow-varying). For distributed layouts, this represents
// the order of the elements within a thread.
// For shared Layout, the order refers to which dimension of the original tensor
// is contiguous in shared memory.
SmallVector<unsigned> getOrder(Attribute layout);
11 changes: 11 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -474,9 +474,16 @@ layout = [0 4 8 12]
[3 7 11 15]

For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding.

If the layout does not completely cover the tensor, we tile it until it covers the entire tensor.
We call each individual tile a "rep".
}];

let methods = [
InterfaceMethod<"Get the order of the reps (tiles of this layout that tile the whole tensor), with the fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrder">,

// Interface for the meta information about the multiple thread hierarchy.
InterfaceMethod<"Get the shape of the CTAs per CGA.",
"SmallVector<unsigned>",
@@ -563,6 +570,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];

code extraDistributedDeclaration = extraBaseClassDeclaration # [{
SmallVector<unsigned> getRepOrder() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;
@@ -914,6 +922,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

SmallVector<unsigned> getContigPerThread() {
auto rank = getWarpsPerCTA().size();
@@ -1022,6 +1031,7 @@ Row | warp 0 warp 2
SmallVector<int64_t> getElemsPerInstrForOperands() const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
static SmallVector<unsigned> getMNKDimPerInstr();

SmallVector<unsigned> getContigPerThread() {
@@ -1217,6 +1227,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
int getMMAv1Vec(int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

bool supportReduction() const {
if (isAmpere() || isHopper()) {
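The new `getRepOrder`/`getRepOrderForOperand` hooks describe how reps (tiles of the layout) repeat to cover the tensor, fastest axis first. Below is a minimal sketch of the counting involved, using invented tensor and tile shapes; it is illustrative only, not Triton code.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Invented numbers: a 128x64 tensor tiled by a 32x16 rep.
  std::vector<unsigned> tensorShape = {128, 64};
  std::vector<unsigned> shapePerTile = {32, 16};
  std::vector<unsigned> repOrder = {1, 0}; // reps advance along dim 1 first

  unsigned numReps = 1;
  std::vector<unsigned> repsPerDim(tensorShape.size());
  for (unsigned d = 0; d < tensorShape.size(); ++d) {
    // Ceil-divide so a partially covered edge still gets its own rep.
    repsPerDim[d] = (tensorShape[d] + shapePerTile[d] - 1) / shapePerTile[d];
    numReps *= repsPerDim[d];
  }
  // 4 x 4 = 16 reps; repOrder says the 4 reps along dim 1 are visited
  // before moving to the next row of reps along dim 0.
  std::printf("repsPerDim = {%u, %u}, numReps = %u\n", repsPerDim[0],
              repsPerDim[1], numReps);
}
```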
8 changes: 2 additions & 6 deletions lib/Analysis/Allocation.cpp
@@ -84,12 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

assert(cvtNeedsSharedMemory(srcTy, dstTy));

// FIXME This is NOT entirely correct
// This should be getElemOrder, but we don't have such a method
// TODO Implement getElemOrder and make sure it's consistent with
// getContigPerThread
auto inOrd = gpu::getThreadOrder(srcLayout);
auto outOrd = gpu::getThreadOrder(dstLayout);
auto inOrd = gpu::getOrder(srcLayout);
auto outOrd = gpu::getOrder(dstLayout);
scratchConfig.order = outOrd;

unsigned srcContigPerThread =
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {

// Here the order should put the most contiguous dimension first, so the
// first element has the largest contiguity.
auto order = triton::gpu::getThreadOrder(layout);
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getThreadOrder(layout);
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
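These call sites now take `order[0]` from `getOrder`, i.e. the fastest-varying element dimension, and derive contiguity and alignment from the axis info along that axis. The exact alignment formula is truncated in the hunk; the following is only a plausible standalone sketch, assuming alignment is capped by both the contiguity (in elements) and the byte divisibility along `order[0]`.

```cpp
#include <algorithm>
#include <cstdio>

// Hedged sketch (not the exact Triton formula): alignment in elements is
// bounded both by how many elements are contiguous along order[0] and by
// the byte-divisibility of the pointer divided by the element size.
unsigned alignmentSketch(unsigned maxMultipleBytes, unsigned maxContig,
                         unsigned elemNumBits) {
  unsigned elemNumBytes = std::max(elemNumBits / 8u, 1u);
  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
  return std::min(maxMultiple, maxContig);
}

int main() {
  // E.g. a pointer divisible by 16 bytes with 8 contiguous fp16 elements.
  std::printf("%u\n", alignmentSketch(/*maxMultipleBytes=*/16,
                                      /*maxContig=*/8,
                                      /*elemNumBits=*/16)); // prints 8
}
```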
10 changes: 5 additions & 5 deletions lib/Analysis/Utility.cpp
@@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
return axis;
}

SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentThreadOrder(sliceEncoding.getParent());
return getParentOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}
@@ -44,12 +44,12 @@ SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentThreadOrder(getSrcLayout())[0];
getParentOrder(getSrcLayout())[0];
}

SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = getThreadOrder(srcLayout);
auto order = getOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
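The renamed `getOrderWithAxisAtBeginning` (hunk above) simply moves the reduction axis to the front of the element order so it becomes the fastest-varying dimension of the scratch layout. The same logic as a standalone sketch, with an invented 3-D order:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Sketch of the reordering done by getOrderWithAxisAtBeginning(): take the
// element order of the source layout and move the reduction axis to the
// front so it becomes the fastest-varying dimension of the scratch buffer.
std::vector<unsigned> orderWithAxisAtBeginning(std::vector<unsigned> order,
                                               unsigned axis) {
  auto it = std::find(order.begin(), order.end(), axis);
  order.erase(it);                   // drop the axis from wherever it was
  order.insert(order.begin(), axis); // and re-insert it at the front
  return order;
}

int main() {
  // Invented example: order {2, 0, 1} with a reduction along axis 1.
  auto smemOrder = orderWithAxisAtBeginning({2, 0, 1}, 1);
  std::printf("{%u, %u, %u}\n", smemOrder[0], smemOrder[1], smemOrder[2]);
  // prints {1, 2, 0}
}
```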
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -322,7 +322,7 @@ struct ReduceOpConversion
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];

auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = it.second;
@@ -409,7 +409,7 @@ struct ReduceOpConversion
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
@@ -389,10 +389,10 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,

auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
auto order = triton::gpu::getOrder(srcEncoding);
auto threadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto warpOrder = triton::gpu::getWarpOrder(srcEncoding);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);

@@ -402,7 +402,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
multiDimLaneId[axis] = i32_val(0);
threadsPerWarp[axis] = 1;
Value laneIdParallel =
linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order);
linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder);
multiDimWarpId[axis] = i32_val(0);
warpsPerCTA[axis] = 1;
Value warpIdParallel =
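The scan lowering zeroes out the scanned axis and then re-linearizes the remaining coordinates with `threadOrder`. A plain-integer sketch of that `linearize` step (the real helper builds MLIR `Value`s; the shape, order, and ids below are invented):

```cpp
#include <cstdio>
#include <vector>

// Fold a multi-dimensional id back into a single index, walking `order`
// from the fastest to the slowest dimension. Zeroing the scan axis first
// (as the hunk does) yields the lane id within the non-scanned dims.
unsigned linearizeSketch(const std::vector<unsigned> &multiDim,
                         const std::vector<unsigned> &shape,
                         const std::vector<unsigned> &order) {
  unsigned linear = 0;
  // Accumulate from the slowest dimension down to the fastest.
  for (auto it = order.rbegin(); it != order.rend(); ++it)
    linear = linear * shape[*it] + multiDim[*it];
  return linear;
}

int main() {
  // Invented example: 8x4 threads, order {1, 0}, thread at (3, 2).
  std::vector<unsigned> shape = {8, 4};
  std::vector<unsigned> order = {1, 0};
  std::printf("%u\n", linearizeSketch({3, 2}, shape, order)); // prints 14
}
```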
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
int numContiguousValues = 1;
auto encoding = cast<BlockedEncodingAttr>(
cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
int splitDim = encoding.getThreadOrder().size() - 1;
for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
if (encoding.getThreadOrder()[i] == splitDim)
int splitDim = encoding.getOrder().size() - 1;
for (int i = 0; i < encoding.getOrder().size(); i++) {
if (encoding.getOrder()[i] == splitDim)
break;
numContiguousValues *= encoding.getSizePerThread()[i];
}
46 changes: 44 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -261,6 +261,14 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
return getMatrixOrder(rank, rowMajor);
}

SmallVector<unsigned> getRepOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getRepOrder();
else
llvm::report_fatal_error("Unimplemented usage of getRepOrder");
return {};
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getWarpOrder();
@@ -269,13 +277,13 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
return {};
}

// Returns the order of the elements in a layout from the fastest running
// dimension to the slowest
SmallVector<unsigned> getOrder(Attribute layout) {
if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
return llvm::to_vector(blockedLayout.getOrder());
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the output elements in the LLVM lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -643,6 +651,9 @@ unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here.
// But we need to have a consistent interface with e.g. SliceEncodingAttr, which
// computes some of these fields.
SmallVector<unsigned> BlockedEncodingAttr::getRepOrder() const {
return SmallVector<unsigned>(getOrder());
}
SmallVector<unsigned> BlockedEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
@@ -709,6 +720,10 @@ unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
return product<unsigned>(getElemsPerThread(shape, eltTy));
}
SmallVector<unsigned> SliceEncodingAttr::getRepOrder() const {
auto parentRepOrder = ::getRepOrder(getParent());
return eraseOrder(parentRepOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getCTASplitNum() const {
SmallVector<unsigned> res = ::getCTASplitNum(getParent());
res.erase(res.begin() + getDim());
@@ -1651,6 +1666,10 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
return {kDim, nDim};
}

SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder");
}

SmallVector<int64_t>
AMDMfmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
int kWidth, int opIdx) const {
@@ -1734,6 +1753,9 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
shapePerCTATile[rank - 1] *= mnkDim[1];
return shapePerCTATile;
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder");
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
@@ -1858,6 +1880,10 @@ bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; }

bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; }

SmallVector<unsigned> NvidiaMmaEncodingAttr::getRepOrder() const {
auto rank = getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
@@ -2011,6 +2037,13 @@ SmallVector<int> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
return 2 * getMMAv1Rep(opIdx)[opIdx];
}

SmallVector<unsigned>
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
}

SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
int opIdx) const {
@@ -2147,6 +2180,15 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
return mma.getRepOrderForOperand(getOpIdx());
}
llvm::report_fatal_error(
"getRepOrder not implemented for DotOperandEncodingAttr");
return {};
}

SmallVector<unsigned> DotOperandEncodingAttr::getThreadsPerWarp() const {
auto parent = getParent();
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
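`getRepOrderForOperand` returns a k-major order for the NVIDIA MMA operands, i.e. the K dimension is the fastest rep axis for both A and B. A hypothetical sketch of what that contract means (not the Triton implementation), assuming A is shaped `[..., M, K]` and B is shaped `[..., K, N]`:

```cpp
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// "k-major" means the K dimension is the fastest rep axis for the operand.
std::vector<unsigned> kMajorOrderSketch(int opIdx, unsigned rank) {
  // Start from row-major: last dim fastest, leading (batch) dims slowest.
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u);
  // For operand B, K is dim rank-2, so swap it to the front of the order.
  if (opIdx == 1)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  auto a = kMajorOrderSketch(/*opIdx=*/0, /*rank=*/2); // {1, 0}: K fastest
  auto b = kMajorOrderSketch(/*opIdx=*/1, /*rank=*/2); // {0, 1}: K fastest
  std::printf("A: {%u, %u}  B: {%u, %u}\n", a[0], a[1], b[0], b[1]);
}
```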
19 changes: 11 additions & 8 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -292,9 +292,9 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// By using `reverse(dimNames)` below, we set the order to be row-major
assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));

auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder());
assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
@@ -327,7 +327,6 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

// TODO Make the getOrder of Hopper explicit here via an assert
MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
@@ -875,14 +874,18 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();
// A and B have kMajor order
assert(getOrder(dot) ==

// The A and B operands are tiled in a kMajor fashion
auto kMajorOrder = dot.getRepOrder();
assert(kMajorOrder ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder);
// This agrees with the order of the elements, which means that we can share
// the code below for both A and B without having to perform any swaps
assert(getOrder(dot) == kMajorOrder);

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
std::vector<std::vector<int32_t>> lanes;
int32_t i = 1;
10 changes: 5 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass
return;
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto rank = srcThreadOrder.size();
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
SmallVector<unsigned> sharedOrder;
if (rank == 3) {
// add all elements except the element that is zero
for (unsigned i = 0; i < rank; ++i)
if (srcThreadOrder[i] != 0)
sharedOrder.emplace_back(srcThreadOrder[i]);
if (srcOrder[i] != 0)
sharedOrder.emplace_back(srcOrder[i]);
sharedOrder.emplace_back(0);
} else {
sharedOrder = srcThreadOrder;
sharedOrder = srcOrder;
}
auto sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
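For the rank-3 case above, the pass keeps the source element order but moves dimension 0 to the end of the shared-memory order, so it becomes the slowest-varying dimension. The same logic as a standalone sketch with an invented order:

```cpp
#include <cstdio>
#include <vector>

// Standalone sketch of the rank-3 special case in the hunk above: keep the
// source order, but push dimension 0 to the back of the shared order.
std::vector<unsigned> sharedOrderSketch(const std::vector<unsigned> &srcOrder) {
  std::vector<unsigned> sharedOrder;
  if (srcOrder.size() == 3) {
    for (unsigned d : srcOrder)
      if (d != 0)
        sharedOrder.push_back(d); // add all dims except dimension zero
    sharedOrder.push_back(0);     // then append dimension zero last
  } else {
    sharedOrder = srcOrder;
  }
  return sharedOrder;
}

int main() {
  // Invented example: a rank-3 order {0, 2, 1} becomes {2, 1, 0}.
  auto order = sharedOrderSketch({0, 2, 1});
  std::printf("{%u, %u, %u}\n", order[0], order[1], order[2]);
}
```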