Skip to content

Commit

Permalink
[BACKEND] Improve detection of register to register conversion (triton-lang#4991)
Browse files Browse the repository at this point in the history

Specifically, it fixes problems when `srcLayout` and `dstLayout` have
different numbers of registers but the same number of non-free registers.
We solved the problem by padding free registers to either `srcLayout` or
`dstLayout`, but this can be improved by fixing the `invertAndCompose`
function.
  • Loading branch information
Jokeren authored Oct 25, 2024
1 parent d31ccfe commit 15c5e55
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 24 deletions.
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

Expand Down
7 changes: 7 additions & 0 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,13 @@ class LinearLayout {
// (i.e. every input bit affects the output).
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;

// Increase an input dimension without affecting the output dimension. The
// added free variables are mapped to 0, ensuring that the new input
// dimensions correspond directly to the existing output space. The function
// errors out if `newInDimSize` is less than the current size or the new size
// is not a power of 2.
LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;

std::string toString() const;

friend bool operator==(LinearLayout lhs, LinearLayout rhs);
Expand Down
42 changes: 40 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (blockedLayout == nullptr || dotOperandLayout == nullptr)
Expand Down Expand Up @@ -655,8 +655,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (!(srcLayout.has_value() && dstLayout.has_value()))
return std::nullopt;
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
auto numSrcRegs = srcLayout->getInDimSize(kRegister);
auto numDstRegs = dstLayout->getInDimSize(kRegister);
// The `invertAndCompose` function will generate a layout that is injective
// by assigning new output dimensions to free variables. For instance,
// consider a scenario where `srcLayout` has a free variable in the lane
// dimension, while `dstLayout` has two free variables in the lane
// dimension and also a larger number of registers.
// The injective form of `srcLayout` will add only a single additional row
// to the transformation matrix, whereas the injective form of `dstLayout`
// will add two additional rows. This discrepancy causes misleading results
// because the matrices end up with a different number of rows.
//
// Take `dstLayout ⋅ srcLayout^-1` as an example:
//
// - `injective(dstLayout)`: [n, m] → [n + 2, m]
// - `injective(srcLayout)`: [n, m] → [n + 1, m]
// - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
// - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
// 1] → [n + 2, n + 1]
//
// Here, the `(n + 1)`-th row added by `dstLayout` represents the free
// variable in registers, and the `(n + 2)`-th row represents the free
// variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
// represents the free variable in lanes. As a result, the `(n + 1)`-th row
// in two layouts do not correspond to the same free variable.
//
// To address this issue, we pad the free variables in `srcLayout` and
// `dstLayout` to ensure they have the same number of registers. This
// guarantees that the resulting matrices have the same number of rows,
// ensuring consistency in the composition process.
auto numRegs = std::max(numSrcRegs, numDstRegs);
auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
// comp describes the layout function to create dst from src.
LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
LinearLayout comp =
dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
for (auto dim : dims) {
Expand Down
53 changes: 32 additions & 21 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,60 +288,71 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return rewriter.notifyMatchFailure(
op, "NYI. srcTy and/or dstTy don't implement LLs yet");
}
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());

StringAttr kBlock = str_attr("block");
StringAttr kWarp = str_attr("warp");
StringAttr kLane = str_attr("lane");
StringAttr kRegister = str_attr("register");

assert(to_vector(conversion->getInDimNames()) ==
to_vector(conversion->getOutDimNames()));
auto dims = conversion->getInDimNames();
if (llvm::is_contained(dims, str_attr("block"))) {
if (llvm::is_contained(dims, kBlock)) {
// Case 1: Transfer between values in different CTAs.
// This requires moving values through distributed shared memory.
return rewriter.notifyMatchFailure(
op, "NYI: Transfer between different CTAs");
} else if (llvm::is_contained(dims, str_attr("warp"))) {
} else if (llvm::is_contained(dims, kWarp)) {
// Case 2: Transfer between values in the same CTA, in which case we move
// values through shared memory.
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
} else if (llvm::is_contained(dims, str_attr("lane"))) {
} else if (llvm::is_contained(dims, kLane)) {
// Case 3. Transfer between values in the same warp, in which case we try
// to move values using warp shuffles, though if the pattern is
// complicated enough we may fall back to using shared memory
// TODO(Keren): implement warp shuffle instead of using the general
// approach that uses shared memory
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
} else if (llvm::is_contained(dims, str_attr("register"))) {
} else if (llvm::is_contained(dims, kRegister) ||
dstLayout.getInDimSize(kRegister) !=
srcLayout.getInDimSize(kRegister)) {
// Case 4. Transfer between values in the same thread, in which case we
// simply reorder the elements of adaptor.getSrc().
return transferWithinThread(op, *conversion, adaptor, rewriter);
return transferWithinThread(
op, dstLayout.getFreeVariableMasks()[kRegister],
dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
} else {
// The two layouts are equivalent. We should probably remove these in
// RemoveLayoutConversion.
// Cast 5. The two layouts are equivalent. We should probably remove
// these in RemoveLayoutConversion.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
}

LogicalResult
transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
OpAdaptor adaptor,
transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
const LinearLayout &conversion, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
StringAttr kRegister = str_attr("register");
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> outVals;
outVals.resize(conversion.getInDimSize(kRegister));
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
SmallVector<Value> outVals(numRegs);
for (int i = 0; i < outVals.size(); i++) {
// Remove free masks from the register index
// For example, if idx = 0b00111, and masks = 0b00100, then we get
// 0b00011. It means that register 7 (0b111) has the same value as
// register 3 (0b011).
auto idx = i & (~regMasks);
auto srcIdx = conversion.hasInDim(kRegister)
? conversion.apply({{kRegister, idx}}).begin()->second
: idx;
outVals[i] = inVals[srcIdx];
}
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
Expand Down
15 changes: 15 additions & 0 deletions lib/Tools/LinearLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const {
return true;
}

// Grow input dimension `inDim` to `newInDimSize` by appending free variables.
// A free variable is a basis vector that is all zeros, so the added input
// bits never affect any output: the layout's image is unchanged.
// Asserts that `inDim` exists, that `newInDimSize` is a power of 2, and that
// the dimension is not being shrunk.
LinearLayout LinearLayout::resize(StringAttr inDim,
                                  int32_t newInDimSize) const {
  BasesT bases = getBases();
  assert(bases.contains(inDim) && "inDim not in layout");
  assert(llvm::isPowerOf2_32(newInDimSize) &&
         "newInDimSize must be a power of 2");
  assert(newInDimSize >= getInDimSize(inDim) &&
         "newInDimSize must be >= old size");
  // Each extra power of two contributes exactly one new basis vector.
  const int32_t numFreeVariables =
      llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim);
  const std::vector<int32_t> zeroBasis(getNumOutDims(), 0);
  for (int32_t k = 0; k < numFreeVariables; ++k)
    bases[inDim].push_back(zeroBasis);
  return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames()));
}

std::string LinearLayout::toString() const {
// Start with a newline because we print out a bulleted list; it doesn't
// make sense for the first line of this list to be on the same line as
Expand Down
74 changes: 74 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
// MMAv2 -> dot-operand conversion within a single warp: the CHECK-NOT lines
// verify the lowering performs a pure register reshuffle with no shared-memory
// store/load round trip.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_mmav2_dot_reg
tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

// MMAv3 -> MMAv3 conversion that only changes instrShape (64 -> 128 wide):
// expected to lower without any shared-memory traffic.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_0
tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

// Same as convert_layout_mmav3_mmav3_0 but in the opposite direction
// (128-wide -> 64-wide instrShape); still no shared-memory round trip.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_1
tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

// Small 16x16 tensor: the two layouts differ in register count, exercising
// the free-register padding path; no shared memory should be used.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_2
tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
tt.return
}
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

// Degenerate 1x64 tensor (single row): stresses layouts with free variables;
// the conversion must still stay register-to-register.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout_mmav3_mmav3_3
tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
// CHECK-NOT: st.shared
// CHECK-NOT: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
Expand Down
33 changes: 33 additions & 0 deletions unittest/Tools/LinearLayoutTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,39 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) {
ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value());
}

// LinearLayout::resize behavior:
//  - growing an input dim appends all-zero (free-variable) basis vectors,
//  - resizing to the current size is a no-op,
//  - only the requested input dim is affected.
TEST_F(LinearLayoutTest, Resize) {
  auto init = LinearLayout(
      {
          {S("in0"), {{0, 1}, {0, 2}}},
          {S("in1"), {{1, 0}, {2, 0}}},
          {S("in2"), {}},
      },
      {S("dim0"), S("dim1")});

  // Growing in0 from size 4 to 8 appends one zero basis vector to in0.
  auto grownIn0 = LinearLayout(
      {
          {S("in0"), {{0, 1}, {0, 2}, {0, 0}}},
          {S("in1"), {{1, 0}, {2, 0}}},
          {S("in2"), {}},
      },
      {S("dim0"), S("dim1")});
  EXPECT_EQ(init.resize(S("in0"), 8), grownIn0);

  // Resizing to the existing size leaves the layout unchanged.
  auto unchanged = LinearLayout(
      {
          {S("in0"), {{0, 1}, {0, 2}}},
          {S("in1"), {{1, 0}, {2, 0}}},
          {S("in2"), {}},
      },
      {S("dim0"), S("dim1")});
  EXPECT_EQ(init.resize(S("in0"), 4), unchanged);

  // Growing in1 appends the zero basis vector to in1, not in0.
  auto grownIn1 = LinearLayout(
      {
          {S("in0"), {{0, 1}, {0, 2}}},
          {S("in1"), {{1, 0}, {2, 0}, {0, 0}}},
          {S("in2"), {}},
      },
      {S("dim0"), S("dim1")});
  EXPECT_EQ(init.resize(S("in1"), 8), grownIn1);
}

} // anonymous namespace
} // namespace mlir::triton

Expand Down

0 comments on commit 15c5e55

Please sign in to comment.