[BACKEND] Get rid of unpack/pack I32 (triton-lang#5044)
- Removed functions related to unpacking and packing I32 values.
- Updated utilities to handle conversion of mxfp4 values without
packing/unpacking I32.
- Moved the register value ordering logic from the element-wise operation
lowering to the dot operation lowering.
- Used linear layout to handle conversions between almost all distributed
layouts.
- Cleaned up data loading and MMA computation involving `repN`, `repK`,
and `repM`.

(cherry picked from commit 1cf7b1b)
Jokeren authored and jataylo committed Dec 13, 2024
1 parent 5287a68 commit 376fe7e
Showing 20 changed files with 299 additions and 538 deletions.
@@ -15,9 +15,6 @@ namespace mlir::triton {

namespace gpu {

SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType);

Type getElementType(Value value);

class MultipleOperandsRange
@@ -179,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
@@ -201,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
it += curr.size();
}
if (op->getNumOperands() > 0) {
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
73 changes: 8 additions & 65 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -396,10 +396,14 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
// MXFP utilities
// -----------------------------------------------------------------------

// Convert one int8, which contains 2 packed mxfp4 values, into 2 standalone
// bf16 values and return them as a pair of (high 4 bits, low 4 bits).
std::pair<Value, Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter,
Location loc, Value v);
// Convert each value, which is an int8 containing 2 packed mxfp4 values,
// into 2 standalone bf16 values
SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values);

// Scale an mxfp4 value by a given scale.
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);

} // namespace LLVM

/* ------------------------------------ */
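To make the new mxfp4 helpers above concrete, here is a minimal standalone sketch (plain C++, not part of this diff) of decoding one int8 that packs two mxfp4 (E2M1) values into two bf16 bit patterns; the low-nibble-first ordering and the function names are assumptions chosen for illustration.

#include <cstdint>
#include <cstring>
#include <utility>

// Decode a 4-bit E2M1 (mxfp4) value to a bf16 bit pattern. E2M1 has 1 sign
// bit, 2 exponent bits, and 1 mantissa bit; its magnitudes are
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
inline uint16_t e2m1ToBf16Bits(uint8_t nibble) {
  static const float kLut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float f = (nibble & 0x8) ? -kLut[nibble & 0x7] : kLut[nibble & 0x7];
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Truncating fp32 to bf16 is exact for every representable E2M1 value.
  return static_cast<uint16_t>(bits >> 16);
}

// Unpack one int8 holding two mxfp4 values (assumed low nibble first) into
// two standalone bf16 bit patterns, mirroring the per-value contract of
// convertMxfp4x2ToBf16x2.
inline std::pair<uint16_t, uint16_t> decodeMxfp4x2(uint8_t packed) {
  return {e2m1ToBf16Bits(packed & 0xF), e2m1ToBf16Bits(packed >> 4)};
}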
@@ -1397,67 +1401,6 @@ inline Value getStructFromSharedMemoryObject(Location loc,
return llvmStruct;
}

// For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of tensors with dot operand encodings in i32 to
// facilitate instructions such as `ldmatrix`.
//
// TODO: Confirm if the problem is still there.
inline bool requiresI32Conversion(Type type) {
auto tensorTy = dyn_cast<RankedTensorType>(type);
if (!tensorTy)
return false;
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!dotOpEnc)
return false;
auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
if (!(parent && parent.getVersionMajor() < 3))
return false;
return true;
}

inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
auto vecTy = vec_ty(eltTy, vecWidth);
for (int i = 0; i < inValues.size(); i += vecWidth) {
Value vec = undef(vecTy);
for (int j = 0; j < vecWidth; j++) {
vec = insert_element(vec, inValues[i + j], i32_val(j));
}
outValues.push_back(bitcast(vec, i32_ty));
}
return outValues;
}

inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
for (auto v : inValues) {
auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
auto vec = bitcast(v, vecTy);
for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
outValues.push_back(extract_element(vec, i32_val(i)));
}
}
return outValues;
}
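
The two helpers removed above are a register-level reinterpretation and its inverse. A minimal standalone sketch (plain C++, not MLIR, names chosen for illustration) of the same round trip for 16-bit lanes:

#include <cstdint>
#include <cstring>
#include <vector>

// Pack 16-bit lanes two per i32 word, mirroring packI32s' vector bitcast.
std::vector<uint32_t> packWords(const std::vector<uint16_t> &lanes) {
  std::vector<uint32_t> words;
  for (size_t i = 0; i + 1 < lanes.size(); i += 2) {
    uint32_t w;
    std::memcpy(&w, &lanes[i], sizeof(w)); // bitcast <2 x i16> -> i32
    words.push_back(w);
  }
  return words;
}

// Split each i32 word back into its 16-bit lanes, mirroring unpackI32s.
std::vector<uint16_t> unpackWords(const std::vector<uint32_t> &words) {
  std::vector<uint16_t> lanes;
  for (uint32_t w : words) {
    uint16_t pair[2];
    std::memcpy(pair, &w, sizeof(w));
    lanes.push_back(pair[0]);
    lanes.push_back(pair[1]);
  }
  return lanes;
}

Since unpackWords(packWords(x)) gives back x, dropping both sides is a pure simplification once no consumer of the lowering expects i32-packed registers.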

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
6 changes: 2 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1199,8 +1199,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
bool isAmpere() const;
bool isHopper() const;

unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;

// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;

@@ -1217,8 +1215,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
SmallVector<int> getMMAv1Rep(int opIdx) const;
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
int getMMAv1Vec(int opIdx) const;
SmallVector<int64_t> getMMAv2OrV3RepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int opIdx) const;

bool supportReduction() const {
if (isAmpere() || isHopper()) {
27 changes: 4 additions & 23 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -328,20 +328,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
} else {
// Cast 5. The two layouts are equivalent. We should probably remove
// these in RemoveLayoutConversion.
auto dstCvt = requiresI32Conversion(dstTy);
auto srcCvt = requiresI32Conversion(srcTy);
if (dstCvt || srcCvt) {
auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter);
inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(),
getTypeConverter());
inVals =
packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter());
auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals,
rewriter, op.getType());
rewriter.replaceOp(op, res);
} else {
rewriter.replaceOp(op, adaptor.getSrc());
}
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
}
@@ -358,7 +345,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
SmallVector<Value> outVals(numRegs);
for (int i = 0; i < numRegs; i++) {
// Remove free masks from the register index
@@ -371,7 +357,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
: idx;
outVals[i] = inVals[srcIdx];
}
outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
@@ -406,11 +391,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (useLegacyMMAConversion) {
return false;
}
// FIXME [Dot LL]
// Enabling LL path for buggy kWidth path
bool largeKWidth =
dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
return largeKWidth && nvidiaMma.isAmpere();
if (nvidiaMma.isAmpere()) {
return true;
}
}
return false;
}
@@ -454,7 +437,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
}
}
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());

// Pretty sure this is the identity function ATM
// It'd be better to simply call `quotient({kBlock})` and
@@ -474,7 +456,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
}

outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
@@ -90,6 +90,10 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
auto dstDotOp =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (srcBlocked && dstDotOp) {
auto dotParent = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent());
if (dotParent && dotParent.isAmpere()) {
return;
}
Attribute sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
auto tmpType = MemDescType::get(
140 changes: 7 additions & 133 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -11,138 +11,23 @@ using namespace mlir::triton::gpu;

namespace mlir::triton::gpu {

namespace {

bool isDotOpTensorAndPacked(Type srcTy) {
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
if (!tensorTy)
return false;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!encoding)
return false;
auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
// By code convention, values for Hopper's dotOp-encoded tensors are not
// packed
if (!parentEnc || parentEnc.isHopper())
return false;
return true;
}

} // namespace

Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = dyn_cast<RankedTensorType>(type))
return tensorType.getElementType();
return type;
}
// MMA encoding has a different order depending on the element's bit width;
// reorder if we're in this case.
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType) {
auto inTensorTy = dyn_cast<RankedTensorType>(inType);
auto ouTensorTy = dyn_cast<RankedTensorType>(ouType);
if (!inTensorTy || !ouTensorTy)
return values;
auto inEncoding = dyn_cast<DotOperandEncodingAttr>(inTensorTy.getEncoding());
auto ouEncoding = dyn_cast<DotOperandEncodingAttr>(ouTensorTy.getEncoding());
assert(inEncoding == ouEncoding);
if (!inEncoding)
return values;
// If the parent of the dot operand is in block encoding, we don't need to
// reorder elements
auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
if (!parentEncoding || parentEncoding.isHopper())
return values;
size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
// Register layout conversion:
//
// [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
// [2, 3], [6, 7] [2], [3], [6], [7]
//
// Original access order:
//
// [0, 1], [2, 3], [4, 5], [6, 7]
//
// Transformed access order:
//
// [0], [2], [1], [3], [4], [6], [5], [7]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 7]);
}
return ret;
}
if (inBitWidth == 8 && ouBitWidth == 16) {
// Register layout conversion:
//
// [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
// [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15]
//
// Original access order:
//
// [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
//
// Transformed access order:
//
// [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 12]);
ret.push_back(values[i + 13]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
}
return ret;
}
llvm_unreachable("unimplemented code path");
}
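
Aside (illustrative only, not repository code): the 16-bit-to-32-bit case above is a fixed permutation applied per group of eight registers, and this commit moves that ordering concern into the dot operation lowering. The pattern itself is simply:

#include <array>
#include <cstddef>
#include <vector>

// Apply the documented 16-bit -> 32-bit access-order permutation to each
// group of eight values: [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 2, 1, 3, 4, 6, 5, 7].
template <typename T>
std::vector<T> reorder16To32(const std::vector<T> &vals) {
  static constexpr std::array<size_t, 8> kPerm = {0, 2, 1, 3, 4, 6, 5, 7};
  std::vector<T> out;
  out.reserve(vals.size());
  for (size_t base = 0; base + 8 <= vals.size(); base += 8)
    for (size_t p : kPerm)
      out.push_back(vals[base + p]);
  return out;
}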

int getNumElementsPerThreads(Type type,
const LLVMTypeConverter *typeConverter) {
int numElemsPerThread = 1;
auto tensorTy = dyn_cast<RankedTensorType>(type);
if (!tensorTy)
return numElemsPerThread;
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (structType) {
numElemsPerThread = structType.getBody().size();
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (structType)
numElemsPerThread = structType.getBody().size();
}
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
return numElemsPerThread;
auto eltType = tensorTy.getElementType();
assert(eltType.getIntOrFloatBitWidth() <= 32 &&
"Only support element type with bit width <= 32 in dot operand mma "
"layout");
// dot operand data are packed into i32 elements so use the following formula
// to get the number of elements per thread.
return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread;
return numElemsPerThread;
}

} // namespace mlir::triton::gpu
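
A quick sanity check of the element-count change above (illustration under assumed numbers, not repository code): with fp16 lanes previously packed two per i32, an LLVM struct of four i32 members reported (32 / 16) * 4 = 8 elements per thread; with packing removed, the struct simply carries eight f16 members, so the plain struct size yields the same count.

#include <cassert>

int main() {
  const int eltBitWidth = 16;       // assumed fp16 dot-operand element type
  const int packedStructSize = 4;   // assumed: four i32 members when packed
  const int unpackedStructSize = 8; // assumed: eight f16 members when unpacked
  int oldCount = (32 / eltBitWidth) * packedStructSize; // removed formula
  int newCount = unpackedStructSize;                    // new: just the struct size
  assert(oldCount == newCount); // the logical element count is unchanged
  return 0;
}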
@@ -473,8 +358,7 @@ struct ElementwiseInlineAsmOpConversion
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
unpackedOperands.push_back(
unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
unpackedOperands.push_back(subOperands);
}

int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
@@ -527,16 +411,6 @@ struct ElementwiseInlineAsmOpConversion
// Reorder and pack the results.
SmallVector<Value> outs;
for (int i = 0; i < unpackedResults.size(); i++) {
// We reordered all the inputs so they match operand 0. Reorder the
// outputs accordingly.
if (op->getNumOperands() > 0) {
unpackedResults[i] = reorderValues(
unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
/*ouType=*/op->getResult(i).getType());
}
auto dstTy = op->getResult(i).getType();
unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
getTypeConverter());
outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
rewriter, op->getResult(i).getType()));
}