Commit a0ce675

more fixes

lezcano committed Nov 4, 2024
1 parent a643eba commit a0ce675
Showing 7 changed files with 48 additions and 47 deletions.
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector<unsigned> getScratchRepShape();
 
-  SmallVector<unsigned> getOrderWithAxisAtBeginning();
+  SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
 
   unsigned getScratchSizeInBytes();
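The renamed helper reorders a layout's thread order so the reduction axis comes first. A minimal standalone sketch of that behavior, using STL containers as stand-ins for SmallVector (the real body is in the Utility.cpp hunk below):

```cpp
// Illustrative sketch of what getThreadOrderWithAxisAtBeginning computes:
// remove the reduction axis from the thread order, reinsert it at the front.
#include <algorithm>
#include <cassert>
#include <vector>

std::vector<unsigned> threadOrderWithAxisAtBeginning(
    std::vector<unsigned> order, unsigned axis) {
  auto it = std::find(order.begin(), order.end(), axis);
  assert(it != order.end() && "axis must appear in the order");
  order.erase(it);                   // delete the axis from the order
  order.insert(order.begin(), axis); // put it back at the beginning
  return order;
}

int main() {
  // e.g. thread order [1, 0] with reduction axis 0 becomes [0, 1]
  assert((threadOrderWithAxisAtBeginning({1, 0}, 0) ==
          std::vector<unsigned>{0, 1}));
  assert((threadOrderWithAxisAtBeginning({2, 1, 0}, 1) ==
          std::vector<unsigned>{1, 2, 0}));
}
```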
3 changes: 3 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -46,6 +46,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto dstShapePerCTATile =
       gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
 
+  assert(srcTy.getRank() == dstTy.getRank() &&
+         "src and dst must have the same rank");
+
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
   for (unsigned d = 0; d < rank; ++d) {
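The new assert matters because getRepShapeForCvt indexes the source and destination shapes with the same loop counter. An illustrative sketch; only the rank check mirrors the diff, and the per-dimension combine rule here is hypothetical:

```cpp
// Sketch: the assert guards the loop below, which would read out of
// bounds if src and dst had different ranks. max() is a made-up rule.
#include <algorithm>
#include <cassert>
#include <vector>

std::vector<unsigned> repShapeSketch(const std::vector<unsigned> &srcShape,
                                     const std::vector<unsigned> &dstShape) {
  assert(srcShape.size() == dstShape.size() &&
         "src and dst must have the same rank");
  unsigned rank = dstShape.size();
  std::vector<unsigned> repShape(rank);
  for (unsigned d = 0; d < rank; ++d)
    repShape[d] = std::max(srcShape[d], dstShape[d]); // hypothetical rule
  return repShape;
}

int main() {
  assert((repShapeSketch({32, 64}, {64, 32}) ==
          std::vector<unsigned>{64, 64}));
}
```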
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
 
   // Here order should be ordered by contiguous first, so the first element
   // should have the largest contiguous.
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   unsigned align = getPtrAlignment(ptr);
 
   auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
   if (!axisInfo)
     return 1;
   auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
   auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+  auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
                                         << alignment);
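All three call sites read axis info along order[0], the fastest-varying dimension of the thread order, which is why they now use getThreadOrder. A sketch of the alignment arithmetic in getPtrAlignment with made-up numbers:

```cpp
// Sketch of the getPtrAlignment-style computation: convert the pointer's
// byte divisibility into elements and clamp by contiguity. All values
// here are assumed for illustration.
#include <algorithm>
#include <cstdio>

int main() {
  // Assumed axis info along order[0], the fastest thread-order dimension:
  unsigned maxMultipleBytes = 16; // pointer divisibility, in bytes
  unsigned maxContig = 8;         // contiguity, in elements
  unsigned elemNumBits = 32;      // e.g. f32

  unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
  unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
  unsigned alignment = std::min(maxMultiple, maxContig);
  std::printf("alignment = %u elements\n", alignment); // prints 4
}
```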
10 changes: 5 additions & 5 deletions lib/Analysis/Utility.cpp
@@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
   return axis;
 }
 
-SmallVector<unsigned> getParentOrder(Attribute layout) {
+SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-    return getParentOrder(sliceEncoding.getParent());
+    return getParentThreadOrder(sliceEncoding.getParent());
   }
   return getThreadOrder(layout);
 }
@@ -44,12 +44,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
 // TODO(jlebar): Move this class into namespace triton.
 bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
   return getParentAxis(getSrcLayout(), axis) ==
-         getParentOrder(getSrcLayout())[0];
+         getParentThreadOrder(getSrcLayout())[0];
 }
 
-SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
+SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
   auto srcLayout = getSrcLayout();
-  auto order = getOrder(srcLayout);
+  auto order = getThreadOrder(srcLayout);
   auto it = std::find(order.begin(), order.end(), axis);
   // delete the axis from order
   order.erase(it);
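getParentThreadOrder unwraps nested slice encodings until it reaches a non-slice parent. A simplified model of that recursion, with plain structs standing in for MLIR attributes:

```cpp
// Illustrative-only model of the recursion in getParentThreadOrder:
// keep unwrapping slices and return the first non-slice thread order.
#include <vector>

struct Layout {
  std::vector<unsigned> threadOrder;
  const Layout *sliceParent = nullptr; // non-null models a SliceEncodingAttr
};

std::vector<unsigned> parentThreadOrder(const Layout &layout) {
  if (layout.sliceParent)
    return parentThreadOrder(*layout.sliceParent);
  return layout.threadOrder;
}

int main() {
  Layout blocked{{1, 0}, nullptr};
  Layout slice{{0}, &blocked};
  // The slice reports its parent's thread order, [1, 0].
  return parentThreadOrder(slice) == std::vector<unsigned>{1, 0} ? 0 : 1;
}
```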
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -283,7 +283,7 @@ struct ReduceOpConversion
         getMultiDimWarpId(helper, warpId, loc, rewriter);
     Value warpIdAxis = multiDimWarpId[axis];
 
-    auto smemOrder = helper.getOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = it.second;
@@ -370,7 +370,7 @@ struct ReduceOpConversion
     Location loc = op.getLoc();
     auto srcLayout = helper.getSrcLayout();
     auto axis = op.getAxis();
-    auto smemOrder = helper.getOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
     SmallVector<Value> results(op.getNumOperands());
     for (unsigned i = 0; i < op.getNumOperands(); ++i) {
       auto elemTy = getElementType(op, i);
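Putting the reduction axis first in smemOrder makes elements along that axis contiguous in the scratch buffer. A sketch with an illustrative linearization helper (not Triton's actual emitter):

```cpp
// Sketch: linearize a multi-dim index under an order where order[0]
// varies fastest; with the reduction axis first, neighbours along that
// axis land in adjacent scratch slots.
#include <cassert>
#include <vector>

unsigned linearize(const std::vector<unsigned> &idx,
                   const std::vector<unsigned> &shape,
                   const std::vector<unsigned> &order) {
  unsigned lin = 0, stride = 1;
  for (unsigned dim : order) {
    lin += idx[dim] * stride;
    stride *= shape[dim];
  }
  return lin;
}

int main() {
  std::vector<unsigned> shape = {4, 8};     // 4 x 8 scratch tile
  std::vector<unsigned> axisFirst = {1, 0}; // reduce along axis 1
  // Adjacent indices along the reduction axis are adjacent in memory:
  assert(linearize({0, 4}, shape, axisFirst) ==
         linearize({0, 3}, shape, axisFirst) + 1);
}
```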
7 changes: 3 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
     int numContiguousValues = 1;
     auto encoding = cast<BlockedEncodingAttr>(
         cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
-    int splitDim = encoding.getOrder().size() - 1;
-    for (int i = 0; i < encoding.getOrder().size(); i++) {
-      if (encoding.getOrder()[i] == splitDim)
+    int splitDim = encoding.getThreadOrder().size() - 1;
+    for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
+      if (encoding.getThreadOrder()[i] == splitDim)
         break;
       numContiguousValues *= encoding.getSizePerThread()[i];
     }
@@ -336,7 +336,6 @@ struct BroadcastOpConversion
     unsigned rank = srcTy.getRank();
     auto typeConverter = getTypeConverter();
     assert(rank == resultTy.getRank());
-    auto order = triton::gpu::getOrder(srcLayout);
     auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
     auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
     SmallVector<Value> srcVals = unpackLLElements(loc, src, rewriter);
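The SplitOp loop above multiplies sizePerThread entries until the split dimension appears in the thread order. A standalone rerun of that loop with assumed layout values:

```cpp
// Same loop shape as the diff, with made-up layout properties
// (not a real BlockedEncodingAttr).
#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> threadOrder = {0, 1, 2};   // assumed thread order
  std::vector<unsigned> sizePerThread = {2, 2, 2}; // assumed sizePerThread

  int splitDim = threadOrder.size() - 1; // SplitOp splits the last dim
  int numContiguousValues = 1;
  for (unsigned i = 0; i < threadOrder.size(); i++) {
    if (threadOrder[i] == splitDim)
      break;
    numContiguousValues *= sizePerThread[i];
  }
  std::printf("numContiguousValues = %d\n", numContiguousValues); // 4
}
```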
63 changes: 31 additions & 32 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
   return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
 }
 
-/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
+/* Utility function used by get.*Order methods of SliceEncodingAttr.
  * Erase dim and decrease all values larger than dim by 1.
  * Example:    order = [0, 2, 4, 3, 1], dim = 2
  *          resOrder = [0, 3, 2, 1]
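A runnable sketch of the eraseOrder behavior this comment documents, with STL types instead of SmallVector:

```cpp
// Erase `dim` from `order` and decrease every entry larger than `dim` by 1.
#include <cassert>
#include <vector>

std::vector<unsigned> eraseOrderSketch(const std::vector<unsigned> &order,
                                       unsigned dim) {
  std::vector<unsigned> resOrder;
  for (unsigned d : order) {
    if (d == dim)
      continue;
    resOrder.push_back(d > dim ? d - 1 : d);
  }
  return resOrder;
}

int main() {
  // The example from the comment: order = [0, 2, 4, 3, 1], dim = 2
  assert((eraseOrderSketch({0, 2, 4, 3, 1}, 2) ==
          std::vector<unsigned>{0, 3, 2, 1}));
}
```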
@@ -262,24 +262,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
-  if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getWarpOrder(dotLayout.getParent());
-    }
-  }
-
-  auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout);
-  if (nvidiaMma && nvidiaMma.isHopper()) {
-    auto rank = nvidiaMma.getWarpsPerCTA().size();
-    return getMatrixOrder(rank, /*rowMajor*/ false);
-  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // It's quite weird to talk about warp order when that the warps
-    // are broadcasted along the K dimension
-    llvm::report_fatal_error(
-        "DotOperandEncoding::getWarpOrder not implemented");
-  }
-
-  return getOrder(layout);
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getWarpOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
 }
 
 SmallVector<unsigned> getOrder(Attribute layout) {
@@ -317,15 +304,15 @@ SmallVector<unsigned> getOrder(Attribute layout) {
 
   llvm::report_fatal_error("Unimplemented usage of getOrder");
   return {};
-};
+}
 
 SmallVector<unsigned> getThreadOrder(Attribute layout) {
   if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
     return distributedLayout.getThreadOrder();
   else
     llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
   return {};
-};
+}
 
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
@@ -768,7 +755,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
   return warpsPerCTA;
 }
 SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto parentWarpOrder = ::getWarpOrder(getParent());
+  return eraseOrder(parentWarpOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   auto parent = getParent();
@@ -780,7 +768,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   return threadsPerWarp;
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto parentThreadOrder = ::getThreadOrder(getParent());
+  return eraseOrder(parentThreadOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   auto sizePerThread = ::getSizePerThread(getParent());
@@ -1048,7 +1037,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
   return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  // FIXME(Lezcano): Preexisting. Do we want to have this path at all?
+  if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+    return ::getWarpOrder(getParent());
+  }
+  // It's quite weird to talk about warp order when the warps
+  // are broadcast along the K dimension
+  llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
+  return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
@@ -1596,7 +1592,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
   auto order = ::getOrder(*this);
@@ -1765,7 +1761,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
   return ::getOrder(*this);
@@ -1889,7 +1885,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  // Hopper (wgmma) uses column-major as this is embedded in the instruction.
+  // For Ampere we can choose either row-major or column-major;
+  // we choose row-major as the legacy path did so.
+  return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
   auto rank = getWarpsPerCTA().size();
@@ -1913,11 +1913,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
       "getThreadsPerWarp not implemented for unknown Mma version ");
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
-  auto rank = getThreadsPerWarp().size();
+  auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
-  auto rank = ::getOrder(*this).size();
+  auto rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   if (isAmpere()) {
     res[rank - 2] = 2;
@@ -2158,11 +2158,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   if (opIdx == 0) {
     sizePerThread[rank - 2] = 2;
     sizePerThread[rank - 1] = 2 * kWidth;
-  } else if (opIdx == 1) {
+  } else {
+    assert(opIdx == 1);
     sizePerThread[rank - 2] = 2 * kWidth;
     sizePerThread[rank - 1] = 1;
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
   return sizePerThread;
 }
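Several of the NvidiaMma changes funnel through getMatrixOrder. A hedged sketch of what getMatrixOrder(rank, rowMajor) plausibly returns, consistent with how this diff uses it; matrixOrderSketch is a stand-in, not the Triton implementation:

```cpp
// order[0] is the fastest dimension; the two matrix dims come first,
// then the remaining batch dims in decreasing order (assumed semantics).
#include <cassert>
#include <vector>

std::vector<unsigned> matrixOrderSketch(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order;
  if (rowMajor)
    order = {rank - 1, rank - 2};
  else
    order = {rank - 2, rank - 1};
  for (int d = static_cast<int>(rank) - 3; d >= 0; --d)
    order.push_back(static_cast<unsigned>(d));
  return order;
}

int main() {
  // Hopper (wgmma): column-major warp order.
  assert((matrixOrderSketch(2, /*rowMajor=*/false) ==
          std::vector<unsigned>{0, 1}));
  // Ampere keeps the legacy row-major order.
  assert((matrixOrderSketch(2, /*rowMajor=*/true) ==
          std::vector<unsigned>{1, 0}));
}
```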
