Skip to content

Commit

Permalink
Consolidate getOrder as "element order" and implement getRepOrder
Browse files Browse the repository at this point in the history
… for general and NVIDIA layouts (triton-lang#5089)

This partially reverts commit 38a11b8.
Supersedes triton-lang#5085

It also documents that we are implicitly choosing a way to tile a
full tensor depending on the layout. See
triton-lang#5085 (comment)
  • Loading branch information
lezcano authored and Luosuu committed Nov 13, 2024
1 parent dcef286 commit f2d7658
Show file tree
Hide file tree
Showing 16 changed files with 136 additions and 45 deletions.
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ReduceOpHelper {
// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();

SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
SmallVector<unsigned> getOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

Expand Down
7 changes: 4 additions & 3 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,15 +520,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
auto threadOrder = blockedLayout.getThreadOrder();
auto warpOrder = blockedLayout.getWarpOrder();
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
unsigned rank = shape.size();

// delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);

SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
Expand Down
5 changes: 2 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

// Returns the dimensions of the tensor from minor (fast-varying) to
// major (slow-varying). For blocked, mma, and dotOperand layouts,
// though the elements are in registers, the order refers to memory
// layout of the original tensor in global memory.
// major (slow-varying). For distributed layouts, this represents
// the order of the elements within a thread.
// For shared Layout, the order refers to which dimension of the original tensor
// is contiguous in shared memory.
SmallVector<unsigned> getOrder(Attribute layout);
Expand Down
11 changes: 11 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,16 @@ layout = [0 4 8 12]
[3 7 11 15]

For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding.

If the layout does not completely cover the tensor, we tile it until we cover the entire tensor.
We call each individual tile "rep".
}];

let methods = [
InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrder">,

// Interface for the meta information about the multiple thread hierarchy.
InterfaceMethod<"Get the shape of the CTAs per CGA.",
"SmallVector<unsigned>",
Expand Down Expand Up @@ -563,6 +570,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];

code extraDistributedDeclaration = extraBaseClassDeclaration # [{
SmallVector<unsigned> getRepOrder() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;
Expand Down Expand Up @@ -914,6 +922,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

SmallVector<unsigned> getContigPerThread() {
auto rank = getWarpsPerCTA().size();
Expand Down Expand Up @@ -1022,6 +1031,7 @@ Row | warp 0 warp 2
SmallVector<int64_t> getElemsPerInstrForOperands() const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
static SmallVector<unsigned> getMNKDimPerInstr();

SmallVector<unsigned> getContigPerThread() {
Expand Down Expand Up @@ -1217,6 +1227,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
int getMMAv1Vec(int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

bool supportReduction() const {
if (isAmpere() || isHopper()) {
Expand Down
8 changes: 2 additions & 6 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

assert(cvtNeedsSharedMemory(srcTy, dstTy));

// FIXME This is NOT entirely correct
// This should be getElemOrder, but we don't have such a method
// TODO Implement getElemOrder and make sure it's consistent with
// getContigPerThread
auto inOrd = gpu::getThreadOrder(srcLayout);
auto outOrd = gpu::getThreadOrder(dstLayout);
auto inOrd = gpu::getOrder(srcLayout);
auto outOrd = gpu::getOrder(dstLayout);
scratchConfig.order = outOrd;

unsigned srcContigPerThread =
Expand Down
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getThreadOrder(layout);
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
Expand All @@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getThreadOrder(layout);
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
Expand All @@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
Expand Down
10 changes: 5 additions & 5 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
return axis;
}

SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentThreadOrder(sliceEncoding.getParent());
return getParentOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}
Expand All @@ -44,12 +44,12 @@ SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentThreadOrder(getSrcLayout())[0];
getParentOrder(getSrcLayout())[0];
}

SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = getThreadOrder(srcLayout);
auto order = getOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ struct ReduceOpConversion
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];

auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = it.second;
Expand Down Expand Up @@ -409,7 +409,7 @@ struct ReduceOpConversion
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,10 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,

auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
auto order = triton::gpu::getOrder(srcEncoding);
auto threadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto warpOrder = triton::gpu::getWarpOrder(srcEncoding);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);

Expand All @@ -402,7 +402,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
multiDimLaneId[axis] = i32_val(0);
threadsPerWarp[axis] = 1;
Value laneIdParallel =
linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order);
linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder);
multiDimWarpId[axis] = i32_val(0);
warpsPerCTA[axis] = 1;
Value warpIdParallel =
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
int numContiguousValues = 1;
auto encoding = cast<BlockedEncodingAttr>(
cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
int splitDim = encoding.getThreadOrder().size() - 1;
for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
if (encoding.getThreadOrder()[i] == splitDim)
int splitDim = encoding.getOrder().size() - 1;
for (int i = 0; i < encoding.getOrder().size(); i++) {
if (encoding.getOrder()[i] == splitDim)
break;
numContiguousValues *= encoding.getSizePerThread()[i];
}
Expand Down
46 changes: 44 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
return getMatrixOrder(rank, rowMajor);
}

SmallVector<unsigned> getRepOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getRepOrder();
else
llvm::report_fatal_error("Unimplemented usage of getRepOrder");
return {};
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
return distributedLayout.getWarpOrder();
Expand All @@ -269,13 +277,13 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
return {};
}

// Returns the order of the elements in a layout from the fastest running
// dimension to the slowest
SmallVector<unsigned> getOrder(Attribute layout) {
if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
return llvm::to_vector(blockedLayout.getOrder());
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the output elements in the LLVM lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
Expand Down Expand Up @@ -643,6 +651,9 @@ unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here.
// But we need to have a consistent interface with e.g. SliceEncodingAttr, which
// computes some of these fields.
SmallVector<unsigned> BlockedEncodingAttr::getRepOrder() const {
return SmallVector<unsigned>(getOrder());
}
SmallVector<unsigned> BlockedEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
Expand Down Expand Up @@ -709,6 +720,10 @@ unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Type eltTy) const {
return product<unsigned>(getElemsPerThread(shape, eltTy));
}
SmallVector<unsigned> SliceEncodingAttr::getRepOrder() const {
auto parentRepOrder = ::getRepOrder(getParent());
return eraseOrder(parentRepOrder, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getCTASplitNum() const {
SmallVector<unsigned> res = ::getCTASplitNum(getParent());
res.erase(res.begin() + getDim());
Expand Down Expand Up @@ -1651,6 +1666,10 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
return {kDim, nDim};
}

SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder");
}

SmallVector<int64_t>
AMDMfmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
int kWidth, int opIdx) const {
Expand Down Expand Up @@ -1734,6 +1753,9 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
shapePerCTATile[rank - 1] *= mnkDim[1];
return shapePerCTATile;
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder");
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
Expand Down Expand Up @@ -1858,6 +1880,10 @@ bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; }

bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; }

SmallVector<unsigned> NvidiaMmaEncodingAttr::getRepOrder() const {
auto rank = getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
Expand Down Expand Up @@ -2011,6 +2037,13 @@ SmallVector<int> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
return 2 * getMMAv1Rep(opIdx)[opIdx];
}

SmallVector<unsigned>
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
}

SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
int opIdx) const {
Expand Down Expand Up @@ -2147,6 +2180,15 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
return mma.getRepOrderForOperand(getOpIdx());
}
llvm::report_fatal_error(
"getRepOrder not implemented for DotOperandEncodingAttr");
return {};
}

SmallVector<unsigned> DotOperandEncodingAttr::getThreadsPerWarp() const {
auto parent = getParent();
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
Expand Down
19 changes: 11 additions & 8 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// By using `reverse(dimNames)` below, we set the order to be row-major
assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));

auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder());
assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
Expand Down Expand Up @@ -327,7 +327,6 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

// TODO Make the getOrder of Hopper explicit here via an assert
MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
Expand Down Expand Up @@ -875,14 +874,18 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();
// A and B have kMajor order
assert(getOrder(dot) ==

// The A and B operands are tiled in a kMajor fashion
auto kMajorOrder = dot.getRepOrder();
assert(kMajorOrder ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder);
// This agrees with the order of the elements, which means that we can share
// the code below for both A and B without having to perform any swaps
assert(getOrder(dot) == kMajorOrder);

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
std::vector<std::vector<int32_t>> lanes;
int32_t i = 1;
Expand Down
10 changes: 5 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass
return;
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto rank = srcThreadOrder.size();
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
SmallVector<unsigned> sharedOrder;
if (rank == 3) {
// add all elements except the element that is zero
for (unsigned i = 0; i < rank; ++i)
if (srcThreadOrder[i] != 0)
sharedOrder.emplace_back(srcThreadOrder[i]);
if (srcOrder[i] != 0)
sharedOrder.emplace_back(srcOrder[i]);
sharedOrder.emplace_back(0);
} else {
sharedOrder = srcThreadOrder;
sharedOrder = srcOrder;
}
auto sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
Expand Down
Loading

0 comments on commit f2d7658

Please sign in to comment.