[LAYOUTS] Unify the implementation of getShapePerCTATile #5183

Merged 5 commits on Nov 20, 2024

include/triton/Dialect/TritonGPU/IR/Dialect.h (1 addition & 3 deletions)

@@ -116,9 +116,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
  * (3) In the implementation of emitIndices, ShapePerCTATile will
  *     be replicated or wrapped to fit ShapePerCTA.
  */
-SmallVector<unsigned>
-getShapePerCTATile(Attribute layout,
-                   ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
+SmallVector<unsigned> getShapePerCTATile(Attribute layout);
 
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);
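
Note (not part of the diff): the implementations removed in this PR either ignored the `tensorShape` argument or only threaded it through to the operand variants, which are removed as well, so call-site migration is mechanical. A minimal sketch of a hypothetical call site, before and after:

```cpp
// Old two-argument form, removed by this PR (the shape argument was unused):
//   auto tile = triton::gpu::getShapePerCTATile(layout, tensorTy.getShape());
// New single-argument form: the tile shape is a property of the layout alone.
auto tile = triton::gpu::getShapePerCTATile(layout);
```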

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td (16 deletions)

@@ -502,11 +502,6 @@ We call each individual tile "rep".
"SmallVector<unsigned>",
"getCTASplitNum">,

InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA",
"SmallVector<unsigned>",
"getShapePerCTATile",
(ins "ArrayRef<int64_t>":$tensorShape)>,

InterfaceMethod<"Gets the number of contiguous elements per thread.",
"SmallVector<unsigned>",
"getContigPerThread">,
@@ -565,7 +560,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
     SmallVector<unsigned> getThreadOrder() const;
 
     SmallVector<unsigned> getSizePerThread() const;
-    SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;
 
     std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
   }];
@@ -765,13 +759,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
"bool",
"supportReduction">,

InterfaceMethod<"Return shape per CTA.",
"SmallVector<unsigned>",
"getShapePerCTATileForOperand",
(ins "ArrayRef<int64_t>":$tensorShape,
"int":$kWidth,
"int":$opIdx)>,

InterfaceMethod<"Return size per thread for dot operands.",
"SmallVector<unsigned>",
"getSizePerThreadForOperand",
@@ -900,7 +887,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     return true;
   }
   SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-  SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
   unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
   SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
   SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
@@ -1008,7 +994,6 @@ Row | warp 0 warp 2
     return true;
   }
   SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-  SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
   unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
   SmallVector<int64_t> getElemsPerInstrForOperands() const;
   SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
@@ -1140,7 +1125,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     return false;
   };
   SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-  SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
 
   SmallVector<unsigned> getContigPerThread() {
     assert(isAmpere() || isHopper());

lib/Analysis/Allocation.cpp (2 additions & 4 deletions)

@@ -41,10 +41,8 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,

   auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
   auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
-  auto srcShapePerCTATile =
-      gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
-  auto dstShapePerCTATile =
-      gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
+  auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
+  auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);
 
   assert(srcTy.getRank() == dstTy.getRank() &&
          "src and dst must have the same rank");
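
Note (context, not part of the diff): `getRepShapeForCvt` sizes the shared-memory scratch buffer for a layout conversion, so the tile shapes bound how much of the tensor a single exchange round touches. A sketch of how the values computed above are typically combined, assuming the rest of the function follows the usual clamp-and-max pattern:

```cpp
// Per dimension, the scratch region must cover the larger of the two tiles,
// each clamped to the per-CTA tensor shape (sketch; names follow the
// surrounding function, with `rank` the shared rank of src and dst).
SmallVector<unsigned> repShape(rank);
for (unsigned d = 0; d < rank; ++d) {
  repShape[d] =
      std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
               std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
}
```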

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp (2 additions & 2 deletions)

@@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion
     SmallVector<unsigned> outNumCTAsEachRep(rank);
     SmallVector<unsigned> inNumCTAs(rank);
     SmallVector<unsigned> outNumCTAs(rank);
-    auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
-    auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
+    auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
+    auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
     auto shapePerCTA = getShapePerCTA(srcLayout, shape);
 
     for (unsigned d = 0; d < rank; ++d) {
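
Note (illustrative numbers, not from the PR): these tile shapes drive the replication counts derived in the per-dimension loop above. For a 128x64 tensor whose CTA tile is [32, 32], the tile must be repeated ceil(128/32) x ceil(64/32) = 4 x 2 times to cover the tensor, and each repetition corresponds to one round of the shared-memory exchange performed by this conversion.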

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp (1 addition & 1 deletion)

@@ -421,7 +421,7 @@ struct ReduceOpConversion
       auto resultIndices = emitIndices(loc, rewriter, targetInfo,
                                        resultLayout, resultTy, true);
       auto resultShape = resultTy.getShape();
-      auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
+      auto resultCTATile = getShapePerCTATile(resultLayout);
       assert(resultIndices.size() == resultElems);
 
       SmallVector<Value> resultVals(resultElems);

lib/Dialect/TritonGPU/IR/Dialect.cpp (19 additions & 119 deletions)

@@ -201,12 +201,25 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
 
-SmallVector<unsigned> getShapePerCTATile(Attribute layout,
-                                         ArrayRef<int64_t> tensorShape) {
+SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getShapePerCTATile(tensorShape);
+    auto sizePerThread = distributedLayout.getSizePerThread();
+    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
+    // ThreadsPerWarp does not align with this function for slice layout
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
+      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
+      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+    }
+    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
+    assert(sizePerThread.size() == threadsPerWarp.size() &&
+           sizePerThread.size() == warpsPerCTA.size());
+    SmallVector<unsigned> shape;
+    for (auto [size, thread, warp] :
+         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
+      shape.push_back(size * thread * warp);
+    }
+    return shape;
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
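
Note (sanity check with illustrative values, not part of the diff): for a blocked layout with `sizePerThread = [1, 4]`, `threadsPerWarp = [8, 4]`, and `warpsPerCTA = [2, 2]`, the unified formula yields `[1*8*2, 4*4*2] = [16, 32]`. For a slice layout, `threadsPerWarp` is taken from the parent with the sliced dimension erased, so slicing dim 0 of that layout gives `[4*4*2] = [32]`, assuming its `sizePerThread` and `warpsPerCTA` likewise just drop the sliced dimension.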
@@ -678,14 +691,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
-SmallVector<unsigned>
-BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape;
-  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
-    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
-                    getWarpsPerCTA()[d]);
-  return shape;
-}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -787,12 +792,6 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
-SmallVector<unsigned>
-SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
-  shape.erase(shape.begin() + getDim());
-  return shape;
-}
 
 //
 
@@ -979,9 +978,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = ::getSizePerThread(blockedLayout);
+    auto sizePerThread = blockedLayout.getSizePerThread();
 
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1043,19 +1042,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
                                /*kMajor*/ true);
 }
-SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
-    ArrayRef<int64_t> tensorShape) const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
-    return parentMmaLayout.getShapePerCTATileForOperand(
-        tensorShape, getKWidth(), getOpIdx());
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
-        "supported yet");
-  }
-}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1562,16 +1548,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 // TODO: there is a lot of common code with MmaEncoding here
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-  shapePerCTATile[rank - 1] *= getMDim();
-  shapePerCTATile[rank - 2] *= getNDim();
-  return shapePerCTATile;
-}
-
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1715,43 +1691,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  assert(getMDim() == 32 || getMDim() == 16);
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], 32};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {32, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
-}
-
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-
-  auto mnkDim = getMNKDimPerInstr();
-  shapePerCTATile[rank - 2] *= mnkDim[0];
-  shapePerCTATile[rank - 1] *= mnkDim[1];
-  return shapePerCTATile;
-}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -1816,21 +1759,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  auto parentShapePerCTA = getShapePerCTATile(shape);
-  auto rank = shape.size();
-  assert(rank == 2);
-  if (opIdx == 0) {
-    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
-  } else if (opIdx == 1) {
-    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-}
-
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1949,24 +1877,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
 
-SmallVector<unsigned>
-NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  if (isAmpere()) {
-    auto warpsPerCTA = getWarpsPerCTA();
-    auto rank = warpsPerCTA.size();
-    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
-                                          warpsPerCTA.end());
-    shapePerCTATile[rank - 1] *= 8;
-    shapePerCTATile[rank - 2] *= 16;
-    return shapePerCTATile;
-  }
-  if (isHopper()) {
-    auto instrShape = getInstrShape();
-    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
-  }
-  llvm::report_fatal_error("Unexpected MMA layout version found");
-}
-
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
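
Note (not part of the diff): the removed Ampere branch is subsumed by the generic product. For MMAv2 the accumulator layout has `sizePerThread = [2, 2]` and `threadsPerWarp = [8, 4]` (the standard m16n8 accumulator tiling, stated here as a check rather than taken from this diff), so `sizePerThread * threadsPerWarp = [2*8, 2*4] = [16, 8]` per warp, reproducing the old hard-coded `*= 16` / `*= 8` factors once multiplied by `warpsPerCTA`.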
@@ -2007,16 +1917,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
 
-SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
-    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
-  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
-  auto shapePerCTATile = getShapePerCTATile(shape);
-  auto rank = shapePerCTATile.size();
-  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
-  // 4 threads * 2 subtiles
-  shapePerCTATile[kDim] = kWidth * 2 * 4;
-  return shapePerCTATile;
-}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp (1 addition & 2 deletions)

@@ -78,8 +78,7 @@ LogicalResult ExtractSliceOp::verify() {
   }
 
   auto srcShape = srcTy.getShape();
-  auto shapePerCTATile =
-      mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
+  auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
   shapePerCTATile[0] =
       std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
   shapePerCTATile[1] =
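
Note (illustrative numbers, not from the PR): the clamp matters when the layout's tile is larger than the tensor itself. A 16x16 source in a layout whose tile is [32, 32] is clamped to [16, 16], so the verifier reasons only about the populated region; this behavior is unchanged, only the call that produces `shapePerCTATile` lost its shape argument.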

@@ -70,7 +70,7 @@ struct ExtractSliceOpConversion
     auto order = triton::gpu::getOrder(srcLayout);
 
     // Calculate valid total number of workers in each dimension
-    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
+    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
     shapePerCTATile[0] =
         std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
     shapePerCTATile[1] =

@@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
   Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);

@@ -17,7 +17,6 @@ using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
 using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;

@@ -46,7 +46,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
   Value warpSize = i32_val(32);
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp (2 deletions)

@@ -12,8 +12,6 @@ using namespace mlir;

 using mlir::LLVM::getWrappedMultiDimOffset;
 using ::mlir::LLVM::linearize;
-using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
 namespace {
 // declare vprintf(i8*, i8*) as external function
 LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) {