Skip to content

Commit

Permalink
[BACKEND] Replace isMmaToDotShortcut with linear layout based logic (
Browse files Browse the repository at this point in the history
…triton-lang#4951)

This PR removes the legacy `isMmaToDotShortcut` and its associated shortcut conversion.
  • Loading branch information
Jokeren authored Oct 28, 2024
1 parent 3889f3f commit 1d5fdfe
Show file tree
Hide file tree
Showing 15 changed files with 242 additions and 200 deletions.
2 changes: 0 additions & 2 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@ namespace gpu {
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType);

SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

Type getElementType(Value value);

class MultipleOperandsRange
Expand Down Expand Up @@ -187,8 +179,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
subOperands = unpackI32(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
Expand All @@ -215,7 +207,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
Expand Down
61 changes: 61 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,67 @@ inline Value getStructFromSharedMemoryObject(Location loc,
return llvmStruct;
}

// For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of tensors with dot operand encodings in i32 to
// facilitate instructions such as `ldmatrix`.
//
// TODO: Confirm whether the problem still exists.
inline bool requiresI32Conversion(Type type) {
  // Returns true only for ranked tensors carrying a dot-operand encoding
  // whose parent is an NVIDIA MMA encoding with version < 3; those are the
  // values stored as packed i32 words (see the note above on the NVPTX
  // sub-word packing workaround).
  auto tensorTy = dyn_cast<RankedTensorType>(type);
  if (!tensorTy)
    return false;
  auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
  if (!dotOpEnc)
    return false;
  // Collapse the original negated-condition branch into a direct boolean.
  auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
  return parent && parent.getVersionMajor() < 3;
}

inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
                                   Type type, RewriterBase &rewriter,
                                   Location loc,
                                   const LLVMTypeConverter *typeConverter) {
  // Bundle groups of sub-word elements into i32 words; a pass-through for
  // any type that does not need the i32 packing workaround.
  if (!requiresI32Conversion(type))
    return inValues;
  Type eltTy =
      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

  // Each i32 word holds `lanes` consecutive sub-word elements.
  const int lanes = 32 / eltTy.getIntOrFloatBitWidth();
  auto laneVecTy = vec_ty(eltTy, lanes);
  SmallVector<Value> packed;
  for (size_t base = 0; base < inValues.size(); base += lanes) {
    // Build a small vector of `lanes` elements, then reinterpret it as i32.
    Value word = undef(laneVecTy);
    for (int lane = 0; lane < lanes; ++lane)
      word = insert_element(word, inValues[base + lane], i32_val(lane));
    packed.push_back(bitcast(word, i32_ty));
  }
  return packed;
}

inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
                                     Type type, RewriterBase &rewriter,
                                     Location loc,
                                     const LLVMTypeConverter *typeConverter) {
  // Inverse of `packI32s`: split each packed i32 word back into its
  // sub-word elements. A pass-through for types that were never packed.
  if (!requiresI32Conversion(type))
    return inValues;
  Type eltTy =
      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

  // Hoist the loop-invariant element count and vector type; the original
  // recomputed both for every packed word.
  const int elemsPerWord = 32 / eltTy.getIntOrFloatBitWidth();
  auto vecTy = vec_ty(eltTy, elemsPerWord);
  SmallVector<Value> outValues;
  outValues.reserve(inValues.size() * elemsPerWord);
  for (Value v : inValues) {
    auto vec = bitcast(v, vecTy);
    for (int i = 0; i < elemsPerWord; i++)
      outValues.push_back(extract_element(vec, i32_val(i)));
  }
  return outValues;
}

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
Expand Down
22 changes: 4 additions & 18 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,14 +731,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
// subsumed by the linear-layout checks.
// TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
// checks.
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!isMmaToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}

Expand All @@ -749,20 +749,6 @@ bool atomicNeedsSharedMemory(Value value) {
return true;
}

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
  // The MMA-v3 case is handled by a dedicated matcher.
  if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
    return true;
  // Otherwise the shortcut applies only to
  //   dot_op<opIdx=0, parent=#mma> = #mma
  // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>.
  auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
  if (!mmaLayout)
    return false;
  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
  if (!dotOperandLayout)
    return false;
  if (srcTy.getElementType().isF32())
    return false;
  return mmaLayout.getVersionMajor() == 2 &&
         mmaLayout.getWarpsPerCTA()[1] == 1 &&
         dotOperandLayout.getOpIdx() == 0 &&
         dotOperandLayout.getParent() == mmaLayout;
}

namespace {

/// A data structure similar to SetVector but maintains
Expand Down
43 changes: 22 additions & 21 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
} else {
// Cast 5. The two layouts are equivalent. We should probably remove
// these in RemoveLayoutConversion.
rewriter.replaceOp(op, adaptor.getSrc());
auto dstCvt = requiresI32Conversion(dstTy);
auto srcCvt = requiresI32Conversion(srcTy);
if (dstCvt || srcCvt) {
auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter);
inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(),
getTypeConverter());
inVals =
packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter());
auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals,
rewriter, op.getType());
rewriter.replaceOp(op, res);
} else {
rewriter.replaceOp(op, adaptor.getSrc());
}
return success();
}
}
Expand All @@ -342,9 +355,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
StringAttr kRegister = str_attr("register");
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
SmallVector<Value> outVals(numRegs);
for (int i = 0; i < outVals.size(); i++) {
for (int i = 0; i < numRegs; i++) {
// Remove free masks from the register index
// For example, if idx = 0b00111, and masks = 0b00100, then we get
// 0b00011. It means that register 7 (0b111) has the same value as
Expand All @@ -355,6 +371,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
: idx;
outVals[i] = inVals[srcIdx];
}
outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
Expand Down Expand Up @@ -386,9 +403,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (auto nvidiaMma =
dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
return false;
}
if (useLegacyMMAConversion) {
return false;
}
Expand All @@ -398,6 +412,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
return largeKWidth && nvidiaMma.isAmpere();
}
return false;
}
if (isa<BlockedEncodingAttr>(layout)) {
return true;
Expand Down Expand Up @@ -439,6 +454,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
}
}
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());

// Pretty sure this is the identity function ATM
// It'd be better to simply call `quotient({kBlock})` and
Expand All @@ -458,22 +474,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
}

// FIXME [Dot LL]
// We know it's just for largeKWidth case in Ampere
// In this case, we need to pack the outputs into i32
if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
auto concat = [&](Value a, Value b) {
return or_(zext(i32_ty, bitcast(a, i16_ty)),
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
};

SmallVector<Value> outVals32(outVals.size() / 2);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
}
outVals = outVals32;
}

outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
Expand Down
56 changes: 6 additions & 50 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,51 +103,6 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
llvm_unreachable("unimplemented code path");
}

SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
                             ConversionPatternRewriter &rewriter, Location loc,
                             const LLVMTypeConverter *typeConverter) {
  // Values are stored as packed i32 words only for tensors whose encoding
  // is a dot operand with an NVIDIA MMA parent; anything else passes
  // through untouched.
  auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
  if (!tensorTy)
    return inValues;
  auto dotEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
  if (!dotEnc || !isa<NvidiaMmaEncodingAttr>(dotEnc.getParent()))
    return inValues;
  SmallVector<Value> unpacked;
  for (Value packedWord : inValues) {
    // Reinterpret each i32 word as a vector of sub-word elements and pull
    // the elements out one by one.
    auto eltTy = typeConverter->convertType(tensorTy.getElementType());
    unsigned lanes = 32 / eltTy.getIntOrFloatBitWidth();
    auto laneVecTy = vec_ty(eltTy, lanes);
    auto laneVec = bitcast(packedWord, laneVecTy);
    for (unsigned lane = 0; lane < lanes; ++lane)
      unpacked.push_back(extract_element(laneVec, i32_val(lane)));
  }
  return unpacked;
}

SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
                           ConversionPatternRewriter &rewriter, Location loc,
                           const LLVMTypeConverter *typeConverter) {
  // Inverse of `unpackI32`: bundle groups of sub-word elements back into
  // i32 words, but only for dot-operand tensors with an NVIDIA MMA parent.
  auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
  if (!tensorTy)
    return inValues;
  auto dotEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
  if (!dotEnc || !isa<NvidiaMmaEncodingAttr>(dotEnc.getParent()))
    return inValues;
  auto eltTy = typeConverter->convertType(tensorTy.getElementType());
  // Each i32 word holds `lanes` consecutive sub-word elements.
  const int lanes = 32 / eltTy.getIntOrFloatBitWidth();
  auto laneVecTy = vec_ty(eltTy, lanes);
  SmallVector<Value> packed;
  for (size_t base = 0; base < inValues.size(); base += lanes) {
    // Build a small vector of `lanes` elements, then reinterpret it as i32.
    Value word = undef(laneVecTy);
    for (int lane = 0; lane < lanes; ++lane)
      word = insert_element(word, inValues[base + lane], i32_val(lane));
    packed.push_back(bitcast(word, i32_ty));
  }
  return packed;
}

int getNumElementsPerThreads(Type type,
const LLVMTypeConverter *typeConverter) {
int numElemsPerThread = 1;
Expand Down Expand Up @@ -500,7 +455,7 @@ struct ElementwiseInlineAsmOpConversion
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
unpackedOperands.push_back(
unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter()));
unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
}

int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
Expand Down Expand Up @@ -560,10 +515,11 @@ struct ElementwiseInlineAsmOpConversion
unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
/*ouType=*/op->getResult(i).getType());
}
auto packed = packI32(unpackedResults[i], op->getResult(i).getType(),
rewriter, loc, getTypeConverter());
outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter,
op->getResult(i).getType()));
auto dstTy = op->getResult(i).getType();
unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
getTypeConverter());
outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
rewriter, op->getResult(i).getType()));
}

rewriter.replaceOp(op, outs);
Expand Down
37 changes: 1 addition & 36 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,42 +184,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
SmallVector<Value> outVals = loadSharedToDistributed(
dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);

// FIXME [Dot LL]
// Ampere case
// In this case, we need to pack the outputs into i32
if (auto dotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding())) {
if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOp.getParent())) {
if (parent.isAmpere()) {
if (elemLlvmTy.isInteger(8)) {
auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
return or_(
or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
or_(shl(zext(i32_ty, a3), i32_val(16)),
shl(zext(i32_ty, a4), i32_val(24))));
};
SmallVector<Value> outVals32(outVals.size() / 4);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
outVals[4 * i + 2], outVals[4 * i + 3]);
}
outVals = outVals32;
} else {
assert(elemLlvmTy.isBF16() && "Unexpected element type");
auto concat = [&](Value a, Value b) {
return or_(zext(i32_ty, bitcast(a, i16_ty)),
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
};

SmallVector<Value> outVals32(outVals.size() / 2);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
}
outVals = outVals32;
}
}
}
}

outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter);
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

Expand Down
Loading

0 comments on commit 1d5fdfe

Please sign in to comment.