Skip to content

Commit

Permalink
[AMD] remove redundant LDS bypass checks (triton-lang#5002)
Browse files Browse the repository at this point in the history
This commit removes special cases for MFMA -> Dot Operand
LDS shortcuts. Now it is supported by common linear layout
infrastructure.

No tests are added, mfma-shortcut.mlir already testing this.

(cherry picked from commit 69f656c)
  • Loading branch information
binarman authored and jataylo committed Dec 12, 2024
1 parent ca16eec commit 06569e5
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 74 deletions.
2 changes: 0 additions & 2 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

assert(!isMfmaToDotShortcut(srcTy, dstTy));
assert(cvtNeedsSharedMemory(srcTy, dstTy));

auto inOrd = gpu::getOrder(srcLayout);
auto outOrd = gpu::getOrder(dstLayout);
Expand Down
19 changes: 1 addition & 18 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return matrixDimsCompatible && bDimCompatible;
}

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
return false;
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
// improved. In addition, we can enable this shortcut for regular MFMA
// layout when opIdx == 1.
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}

// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
Expand Down Expand Up @@ -738,8 +722,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
}

bool atomicNeedsSharedMemory(Value value) {
Expand Down
51 changes: 0 additions & 51 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,64 +115,13 @@ struct LocalLoadOpConversion
}
};

struct ConvertLayoutOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
using ConvertOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<RankedTensorType>(dst.getType());
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

if (isa<AMDMfmaEncodingAttr>(srcLayout) &&
isa<DotOperandEncodingAttr>(dstLayout)) {
return lowerMfmaToDotOperand(op, adaptor, rewriter);
}
return failure();
}

private:
LogicalResult
lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
RankedTensorType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
if (isMfmaToDotShortcut(srcTy, dstTy)) {
// vecSize is an number of sequential elements stored by one thread
// - For MFMA encoding (encoding of the result tensor of dot
// operation) it is 4
// - For MFMA operand encoding it is
// dotOperandEncoding::kWidth,
// which is 4 in certain cases (e.g. fp16 and bfloat16 dtypes with kpack
// = 1)
//
// For cases where these two values are equal MFMA and MFMA operand
// layouts are the same.
auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
Value view =
packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy);
rewriter.replaceOp(op, view);
return success();
}
return failure();
}
};
} // namespace

namespace mlir::triton::AMD {
void populateConvertLayoutOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
}
} // namespace mlir::triton::AMD
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ struct DecomposeUnsupportedAMDConversions

triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,
isMfmaToDotShortcut);
auto isShortcut =
mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory));

triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut);

/* -------------------------------- */
// Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op`
Expand Down

0 comments on commit 06569e5

Please sign in to comment.