Skip to content

Commit

Permalink
Consolidate getOrder as "element order" and document implicit "tile order"
Browse files Browse the repository at this point in the history

This partially reverts commit 38a11b8.

It also documents that we are implicitly choosing a way to tile a
full tensor depending on the layout. See #5085 (comment)
  • Loading branch information
lezcano committed Nov 7, 2024
1 parent 1cf7b1b commit c726428
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 46 deletions.
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ReduceOpHelper {
// The shape of the shared memory space needed for the reduction.
SmallVector<unsigned> getScratchRepShape();

SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
SmallVector<unsigned> getOrderWithAxisAtBeginning();

unsigned getScratchSizeInBytes();

Expand Down
7 changes: 4 additions & 3 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,15 +466,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
auto threadOrder = blockedLayout.getThreadOrder();
auto warpOrder = blockedLayout.getWarpOrder();
auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape);
unsigned rank = shape.size();

// delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);

SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
Expand Down
5 changes: 2 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

// Returns the dimensions of the tensor from minor (fast-varying) to
// major (slow-varying). For blocked, mma, and dotOperand layouts,
// though the elements are in registers, the order refers to memory
// layout of the original tensor in global memory.
// major (slow-varying). For distributed layouts, this represents
// the order of the elements within a thread.
// For shared Layout, the order refers to which dimension of the original tensor
// is contiguous in shared memory.
SmallVector<unsigned> getOrder(Attribute layout);
Expand Down
8 changes: 8 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,15 @@ layout = [0 4 8 12]
[3 7 11 15]

For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding.

If the layout does not completely cover the tensor, we tile it until we cover the entire tensor.
The order of this tiling is not defined explicitly in the language at the time of this writing.
In most cases it is the same as the order of the elements, but there are exceptions (see the FIXME below).
}];
// FIXME: See the sentence above and the comment in
// https://github.com/triton-lang/triton/pull/5089#discussion_r1832243393
// We need to add a getTilingOrder method to this class that returns the order used
// to tile the tensor when the layout is not large enough.

let methods = [
// Interface for the meta information about the multiple thread hierarchy.
Expand Down
8 changes: 2 additions & 6 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

assert(cvtNeedsSharedMemory(srcTy, dstTy));

// FIXME This is NOT entirely correct
// This should be getElemOrder, but we don't have such a method
// TODO Implement getElemOrder and make sure it's consistent with
// getContigPerThread
auto inOrd = gpu::getThreadOrder(srcLayout);
auto outOrd = gpu::getThreadOrder(dstLayout);
auto inOrd = gpu::getOrder(srcLayout);
auto outOrd = gpu::getOrder(dstLayout);
scratchConfig.order = outOrd;

unsigned srcContigPerThread =
Expand Down
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getThreadOrder(layout);
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
Expand All @@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getThreadOrder(layout);
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
Expand All @@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
Expand Down
10 changes: 5 additions & 5 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
return axis;
}

SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentThreadOrder(sliceEncoding.getParent());
return getParentOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}
Expand All @@ -44,12 +44,12 @@ SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentThreadOrder(getSrcLayout())[0];
getParentOrder(getSrcLayout())[0];
}

SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = getThreadOrder(srcLayout);
auto order = getOrder(srcLayout);
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ struct ReduceOpConversion
getMultiDimWarpId(helper, warpId, loc, rewriter);
Value warpIdAxis = multiDimWarpId[axis];

auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> &acc = it.second;
Expand Down Expand Up @@ -409,7 +409,7 @@ struct ReduceOpConversion
Location loc = op.getLoc();
auto srcLayout = helper.getSrcLayout();
auto axis = op.getAxis();
auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
auto smemOrder = helper.getOrderWithAxisAtBeginning();
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,10 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,

auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
auto order = triton::gpu::getOrder(srcEncoding);
auto threadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto warpOrder = triton::gpu::getWarpOrder(srcEncoding);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);

Expand All @@ -402,7 +402,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
multiDimLaneId[axis] = i32_val(0);
threadsPerWarp[axis] = 1;
Value laneIdParallel =
linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order);
linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder);
multiDimWarpId[axis] = i32_val(0);
warpsPerCTA[axis] = 1;
Value warpIdParallel =
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
int numContiguousValues = 1;
auto encoding = cast<BlockedEncodingAttr>(
cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
int splitDim = encoding.getThreadOrder().size() - 1;
for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
if (encoding.getThreadOrder()[i] == splitDim)
int splitDim = encoding.getOrder().size() - 1;
for (int i = 0; i < encoding.getOrder().size(); i++) {
if (encoding.getOrder()[i] == splitDim)
break;
numContiguousValues *= encoding.getSizePerThread()[i];
}
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,13 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
return {};
}

// Returns the order of the elements in a layout from the fastest running
// dimension to the slowest
SmallVector<unsigned> getOrder(Attribute layout) {
if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
return llvm::to_vector(blockedLayout.getOrder());
}
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
// Order doesn't really matter. We just have to be consistent when unpacking
// the output elements in the LLVM lowerings. We choose row-major
auto distributedLayout = cast<DistributedEncodingTrait>(layout);
auto rank = distributedLayout.getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
Expand Down
20 changes: 11 additions & 9 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
// By using `reverse(dimNames)` below, we set the order to be row-major
assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
// Mma is tiled in a row-major fashion
auto orderedDimNames =
permuteDimNames(dimNames, getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
Expand Down Expand Up @@ -327,7 +327,6 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

// TODO Make the getOrder of Hopper explicit here via an assert
MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
Expand Down Expand Up @@ -875,14 +874,17 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();
// A and B have kMajor order
assert(getOrder(dot) ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

// The A and B operands are tiled in a kMajor fashion
auto kMajorOrder =
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true);

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder);
// This agrees with the order of the elements, which means that we can share
// the code below for both A and B without having to perform any swaps
assert(getOrder(dot) == kMajorOrder);

// Implement A. For B transpose in the end
std::vector<std::vector<int32_t>> registers;
std::vector<std::vector<int32_t>> lanes;
int32_t i = 1;
Expand Down
10 changes: 5 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass
return;
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding);
auto rank = srcThreadOrder.size();
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
SmallVector<unsigned> sharedOrder;
if (rank == 3) {
// add all elements except the element that is zero
for (unsigned i = 0; i < rank; ++i)
if (srcThreadOrder[i] != 0)
sharedOrder.emplace_back(srcThreadOrder[i]);
if (srcOrder[i] != 0)
sharedOrder.emplace_back(srcOrder[i]);
sharedOrder.emplace_back(0);
} else {
sharedOrder = srcThreadOrder;
sharedOrder = srcOrder;
}
auto sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class DecomposeLocalLoadToDotOperand
type.getShape(), type.getElementType(),
triton::gpu::SharedEncodingAttr::get(
op.getContext(), dstDotOp, type.getShape(),
triton::gpu::getThreadOrder(parentEnc),
triton::gpu::getOrder(parentEnc),
triton::gpu::getCTALayout(parentEnc), type.getElementType()),
srcType.getMemorySpace());
auto tmp = rewriter.create<triton::gpu::LocalAllocOp>(
Expand Down

0 comments on commit c726428

Please sign in to comment.