From 72199c627a91f45e33d8ea8e21d5ae6335fbb4f0 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Fri, 13 Dec 2024 23:13:28 +0000 Subject: [PATCH] [rc/3.2.x] LLVM bump for gfx950 target support (#5417) This PR brings in required LLVM bumps and additional targets for gfx950 support. - https://github.com/triton-lang/triton/pull/5040 - https://github.com/triton-lang/triton/pull/5064 - https://github.com/triton-lang/triton/pull/5180 - https://github.com/triton-lang/triton/pull/5242 - https://github.com/triton-lang/triton/pull/5392 Note this PR reverts the last two PRs to only focus on the LLVM upgrade - #5347 - #5191 --------- Co-authored-by: peterbell10 Co-authored-by: Hongtao Yu Co-authored-by: Lei Zhang Co-authored-by: Jungwook Park --- cmake/llvm-hash.txt | 2 +- include/triton/Dialect/Triton/IR/TritonOps.td | 4 - lib/Dialect/Triton/IR/Ops.cpp | 16 - lib/Dialect/Triton/Transforms/Combine.td | 4 +- test/Conversion/amd/tritongpu_to_llvm.mlir | 108 ----- .../include/Dialect/TritonAMDGPU/IR/Dialect.h | 1 - .../include/TritonAMDGPUToLLVM/TargetUtils.h | 11 - .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 41 +- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 184 +------- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 123 +----- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 18 +- .../lib/TritonAMDGPUTransforms/CMakeLists.txt | 1 - .../ReorderInstructions.cpp | 396 ++++++++---------- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 12 +- 14 files changed, 240 insertions(+), 681 deletions(-) diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 0952ab984cc9..50d024794663 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -1f20eee6dc367bd202895e3eedb03974a628ef16 +86b69c31642e98f8357df62c09d118ad1da4e16a diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index c39c408d9330..283dd9165918 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -727,10 +727,6 @@ def TT_ReduceOp: TT_Op<"reduce", llvm::SmallVector getInputTypes(); llvm::SmallVector getElementTypes(); unsigned getNumOperands(); - - // Returns the CombineOp iff this ReduceOp's region contains only - // one CombineOp other than the return, or nullptr if not applicable. 
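  // For example (sketch, mirroring the lit tests elsewhere in this patch), a
  // region of the form
  //   ^bb0(%a: f32, %b: f32):
  //     %c = arith.maxnumf %a, %b : f32
  //     tt.reduce.return %c : f32
  // counts as a single combiner, provided the reduce has exactly one operand
  // and one result and the combiner consumes both block arguments.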
- ::mlir::Operation *getSingleCombiner(); }]; } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index e77e2d5c8691..ffea5f3c67a6 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -503,22 +503,6 @@ llvm::SmallVector ReduceOp::getElementTypes() { return getElementTypesImpl(this->getOperands()); } -::mlir::Operation *ReduceOp::getSingleCombiner() { - if (getNumOperands() != 1 || getNumResults() != 1) - return nullptr; - Block *block = &(*getCombineOp().begin()); - Operation *yield = block->getTerminator(); - Operation *reduceOp = yield->getOperand(0).getDefiningOp(); - if (!reduceOp || reduceOp->getNumOperands() != 2 || - reduceOp->getNumResults() != 1) - return nullptr; - if (reduceOp->getOperand(0) != block->getArgument(0) || - reduceOp->getOperand(1) != block->getArgument(1)) - return nullptr; - - return reduceOp; -} - unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } //-- ScanOp -- diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 28b871983bfe..5a2fcecfa949 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat< [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm), + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), @@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat< [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm), + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 72e02a4ef46e..ef6733845721 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -62,111 +62,3 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } - -// ----- - -#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_f16 - tt.func @atomic_add_f16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> - %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> - // CHECK: llvm.cond_br - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> - tt.return - } -} - -// ----- - 
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_bf16 - tt.func @atomic_add_bf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { - %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> - %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> - %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> - // CHECK: llvm.cond_br - // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> - %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> - tt.return - } -} - -// ----- - -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: reduce_dpp_max - tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) { - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 280, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 276, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 274, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 273, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 322, 10, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK-NEXT: rocdl.update.dpp - // CHECK-SAME: with 323, 15, 15, true : f32 - // CHECK-NEXT: llvm.intr.maxnum - - // CHECK: llvm.amdgcn.readlane - %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<64xf32, #blocked3>) -> f32 - tt.return - } -} - -// ----- - -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: reduce_xor_max - tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) { - // CHECK: rocdl.ds_swizzle - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 280, 15, 12, false : i32 - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 264, 15, 3, false : i32 - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 276, 15, 10, false : i32 - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 260, 15, 5, false : i32 - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 78, 15, 15, false : i32 - // CHECK: llvm.intr.maxnum - - // CHECK: rocdl.update.dpp - // CHECK-SAME: with 177, 15, 15, false : i32 - %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<32xf32, #blocked4>) -> f32 - tt.return - } -} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index 6dbb0435e20c..a7395f86dc50 100644 --- 
a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -30,7 +30,6 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Dialect/Triton/IR/Traits.h" - // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" // clang-format on diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h index 9e174d545dd9..a49e442d3984 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h @@ -19,17 +19,6 @@ enum class ISAFamily { // Deduces the corresponding ISA family for the given target gfx |arch|. ISAFamily deduceISAFamily(llvm::StringRef arch); -// Here is a partial definition of DppCtrl enums. For the complete definition, -// please check: -// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939 -enum class DppCtrl : uint32_t { - QUAD_PERM_FIRST = 0, - ROW_SHL0 = 0x100, - ROW_SHR0 = 0x110, - BCAST15 = 0x142, - BCAST31 = 0x143 -}; - } // namespace mlir::triton::AMD #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 5265f631ad9e..a45efd4a7971 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -768,11 +768,7 @@ struct AtomicRMWOpConversion // tensor if (tensorTy) { auto valTy = cast(val.getType()); - Type elTy = valTy.getElementType(); - vec = std::min(vec, llvm::isa(elTy) && - elTy.getIntOrFloatBitWidth() == 16 - ? 2 - : 1); + vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); // mask numElems = tensorTy.getNumElements(); } @@ -787,22 +783,13 @@ struct AtomicRMWOpConversion auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; SmallVector resultVals(elemsPerThread); + const bool f16v2 = vec == 2 && valueElemTy.isF16(); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; - Value operand; - if (vec == 1) { - operand = valElements[i]; - } else { - operand = undef(vecTy); - for (size_t ii = 0; ii < vec; ++ii) - operand = - insert_element(vecTy, operand, valElements[i + ii], i32_val(ii)); - } - Value undefVal = undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. auto *curBlock = rewriter.getInsertionBlock(); @@ -819,11 +806,25 @@ struct AtomicRMWOpConversion auto maybeKind = matchAtomicOp(atomicRmwAttr); // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient // atomics for MI-* series of AMD GPU. - Value atom = - rewriter - .create(loc, *maybeKind, rmwPtr, operand, - atomicMemOrdering, StringRef("agent")) - .getResult(); + Value atom = rewriter + .create( + loc, *maybeKind, rmwPtr, valElements[i], + atomicMemOrdering, StringRef("agent")) + .getResult(); + + // NV for the f16v2 case generates one packed instruction. We have to + // create two separate instructions since LLVM::AtomicRMWOp doesn't + // support this. Can be optimized out with rocdl.raw.buffer.atomic. 
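      // Rough sketch of the IR this produces for one vec=2 pair (value names
      // are illustrative only; the ordering and syncscope come from
      // atomicMemOrdering and "agent" as passed above):
      //   %r0 = llvm.atomicrmw fadd %ptr0, %val0      ; first half
      //   %r1 = llvm.atomicrmw fadd %ptr1, %val1      ; second half
      //   %v  = insertelement <2 x half> undef, %r0, 0
      //   %v1 = insertelement <2 x half> %v, %r1, 1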
+ if (f16v2) { + Value atom2 = + rewriter + .create( + loc, *maybeKind, ptrElements[i + 1], valElements[i + 1], + atomicMemOrdering, StringRef("agent")) + .getResult(); + auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0)); + atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); + } if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 525361fee603..3a40d73c2a7c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -5,7 +5,6 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -using mlir::triton::AMD::DppCtrl; namespace mlir::triton::AMD { namespace { @@ -104,22 +103,22 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleXor(loc, rewriter, val, i); } Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleUp(loc, rewriter, val, i); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const { - return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily()); + return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } Value TargetInfo::programId(RewriterBase &rewriter, Location loc, @@ -127,184 +126,11 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc, return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } -// Cast and sext values into specific-length int to meet the requirements of -// instructions like UpdateDpp or readlane if necessary. -static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc, - Value &val, Type fromType, - unsigned toBits) { - unsigned originalBits = fromType.getIntOrFloatBitWidth(); - Type toType = fromType; - - if (!fromType.isIntOrIndex()) { - val = bitcast(val, int_ty(originalBits)); - toType = int_ty(originalBits); - } - - if (originalBits < toBits) { - val = sext(int_ty(toBits), val); - toType = int_ty(toBits); - } - - return toType; -} - -// Trunc the value to specific length and then cast it to given type if -// necessary. This function is typically used in conjunction with -// castToAndSExtInt. 
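// (Worked example: an f16 value routed through a 32-bit DPP or readlane op
// is widened by castToAndSExtInt as f16 -> bitcast i16 -> sext i32, and the
// result is restored here as i32 -> trunc i16 -> bitcast f16.)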
-static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc, - Value val, Type valType, - unsigned fromBits) { - unsigned originalBits = valType.getIntOrFloatBitWidth(); - Value toVal = val; - - if (originalBits < fromBits) { - toVal = trunc(int_ty(originalBits), toVal); - } - - if (!valType.isIntOrIndex()) { - toVal = bitcast(toVal, valType); - } - - return toVal; -} - bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { - if (numLaneToReduce != 64) - return false; - - if (auto family = getISAFamily(); - family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) { - return false; - } - - Operation *reduxOp = op.getSingleCombiner(); - if (!reduxOp) - return false; - - auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src, - uint32_t dppCtrl, int rowMask, - int bankMask) -> Value { - // DPP has limited support for data types, so here we need to - // cast non-integer types or integer types shorter than 32 bits - // to int32, except for fp32. - Type actualType = valType; - if (!valType.isF32()) { - actualType = castToAndSExtInt(rewriter, loc, src, valType, 32); - } - - Value dppResult = - rewriter - .create(loc, actualType, src, src, - rewriter.getI32IntegerAttr(dppCtrl), - rewriter.getI32IntegerAttr(rowMask), - rewriter.getI32IntegerAttr(bankMask), - rewriter.getBoolAttr(true)) - .getRes(); - - if (!valType.isF32()) { - src = truncAndCastFromInt(rewriter, loc, src, valType, 32); - dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32); - } - - IRMapping mapping; - mapping.map(reduxOp->getOperand(0), src); - mapping.map(reduxOp->getOperand(1), dppResult); - return rewriter.clone(*reduxOp, mapping)->getResult(0); - }; - - for (int i = 0; i < acc.size(); i++) { - Value buf; - auto valType = acc[i].getType(); - - /* - Here's the implementation of full-wavefront reduction using dpp. - https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ - - Each step has a v_mov_dpp instruction following the redux op. In - some cases, the lower-level compiler could merge them into single - instruction. For example, v_mov_dpp + max => v_max_dpp. - - For gfx9, we have 64 threads per warp. These 64 threads are arranged - into 4 rows, with each row being 16 threads. Each 16 threads are arranged - further into 4 banks, with each bank being 4 threads. Overall it's in a - (row, bank, thread) structure. When shuffling, we use row/bank mask to - indicate which row/bank to participate. Then modifier like row_shr and - row_bcast means exact data movement schemes. In the following - instructions, taking row 0 as an example: - - Step 1: Right shift for 8 lanes. - lane 8-15 = redux(lane 0-7, lane 8-15) - - Step 2: Right shift for 4 lanes. - lane 12-15 = redux(lane 8-11, lane 12-15) - - Step 3: Right shift for 2 lanes. - lane 14-15 = redux(lane 12-13, lane 14-15) - - Step 4: Right shift for 1 lane. - lane 15 = redux(lane 14, lane 15) - - Step 5: Broadcast lane 15 of each row to all the lanes of its next row. - lane 16-31 = redux(lane 15, lane 16-31) - - Step 6: Broadcast lane 31 to lane 32-63. - lane 32-63 = redux(lane 31, lane 32-63) - - Now the reduction result is stored in lane 63. - - Step 7: Read the reduction result from lane 63 and broadcast with - readlane. 
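    As a cross-check, the dpp control words asserted in the removed lit test
    follow directly from the removed DppCtrl enum (TargetUtils.h):
    ROW_SHR0 = 0x110, so row_shr:8/4/2/1 encode as 0x118/0x114/0x112/0x111 =
    280/276/274/273; BCAST15 = 0x142 = 322 (issued with row_mask 0xa) and
    BCAST31 = 0x143 = 323.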
- */ - - const int allRows = 0xf; - const int allBanks = 0xf; - - const uint32_t dppCtrlRowShr = static_cast(DppCtrl::ROW_SHR0); - - // row_shr:8 - buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr, - allRows, allBanks); - - // row_shr:4 - buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr, - allRows, allBanks); - - // row_shr:2 - buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr, - allRows, allBanks); - - // row_shr:1 - buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, - allRows, allBanks); - - // row_bcast:15 row_mask:0xa - buf = createDppReduxOpWithBoundCtrl( - valType, buf, static_cast(DppCtrl::BCAST15), 0xa, allBanks); - - // row_bcast:31 - buf = createDppReduxOpWithBoundCtrl(valType, buf, - static_cast(DppCtrl::BCAST31), - allRows, allBanks); - - // Similarly, we need to cast data types for readlane instruction. - Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16); - - // Get reduction result from lane 63 - std::string intrinsic = "llvm.amdgcn.readlane"; - Value result = - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, - ValueRange{buf, i32_val(63)}) - ->getResult(0); - - result = truncAndCastFromInt(rewriter, loc, result, valType, 16); - - acc[i] = result; - } - - return true; + return false; } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 0bd401f1993a..542b1ecbb7fb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,8 +8,6 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" -using mlir::triton::AMD::DppCtrl; -using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; @@ -73,9 +71,8 @@ Type castToVectorType(Type ty) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, RewriterBase &rewriter, - ISAFamily isaFamily, Value val, Value i, - int strideInt, ShflKind mode, Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -87,8 +84,7 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, if (bits < 32) val = sext(i32_ty, val); - val = - shuffleCommon(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); + val = shuffleCommon(loc, rewriter, val, i, strideInt, mode, clamp); if (bits < 32) val = trunc(int_ty(bits), val); @@ -102,10 +98,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = shuffleCommon(loc, rewriter, isaFamily, val0, i, strideInt, mode, - clamp); - val1 = shuffleCommon(loc, rewriter, isaFamily, val1, i, strideInt, mode, - clamp); + val0 = shuffleCommon(loc, rewriter, val0, i, strideInt, mode, clamp); + val1 = shuffleCommon(loc, rewriter, val1, i, strideInt, mode, clamp); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); @@ -140,83 +134,13 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value 
stride = i32_val(32); Value lineId = xor_(threadId, stride); return bpermute(lineId); - } else if (strideInt == 16) { - Value offset = i32_val(0x401F); - return rewriter.create(loc, valType, val, offset); } else { - if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) { - // DPP is only supportted for CDNA2 and CDNA3 right now, so we fallback - // to ds_swizzle for other archs. - // - // This map facilates the butterfly shuffle pattern for a stride less - // than 16. The pattern stride is the key of the map. - DenseMap masks{ - {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; - Value offset = i32_val(masks[strideInt]); - return rewriter.create(loc, valType, val, offset); - } - - auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src, - uint32_t dppCtrl, uint32_t rowMask, - uint32_t bankMask) { - return rewriter.create( - loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl), - rewriter.getI32IntegerAttr(rowMask), - rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false)); - }; - - const int allRows = 0xf; - const int allBanks = 0xf; - - switch (strideInt) { - case 1: { - // quad_perm: 1, 0, 3, 2 - uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); - std::array mask = {1, 0, 3, 2}; - for (int i = 0; i < mask.size(); i++) { - dppCtrl |= mask[i] << (i * 2); - } - return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, - allBanks); - } - case 2: { - // quad_perm: 2, 3, 0, 1 - uint32_t dppCtrl = static_cast(DppCtrl::QUAD_PERM_FIRST); - std::array mask = {2, 3, 0, 1}; - for (int i = 0; i < mask.size(); i++) { - dppCtrl |= mask[i] << (i * 2); - } - return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows, - allBanks); - } - case 4: { - // row_shr:4 bank_mask: 0xa - auto ret = createDppOpWithoutBoundCtrl( - val, val, 4 + static_cast(DppCtrl::ROW_SHR0), - allRows, 0xa) - .getRes(); - - // row_shl:4 bank_mask: 0x5 - return createDppOpWithoutBoundCtrl( - ret, val, 4 + static_cast(DppCtrl::ROW_SHL0), allRows, - 0x5); - } - case 8: { - // row_shr:8 bank_mask: 0xc - auto ret = createDppOpWithoutBoundCtrl( - val, val, 8 + static_cast(DppCtrl::ROW_SHR0), - allRows, 0xc) - .getRes(); - - // row_shl:8 bank_mask: 0x3 - return createDppOpWithoutBoundCtrl( - ret, val, 8 + static_cast(DppCtrl::ROW_SHL0), allRows, - 0x3); - } - default: - assert(false && - "bfly shfl with stride >= 16 should not be handled by dpp."); - } + // This map facilates the butterfly shuffle pattern for a stride less + // than 16. The pattern stride is the key of the map. 
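    // How these constants decode (sketch, assuming the standard bit-masked
    // ds_swizzle_b32 encoding: offset = xor_mask << 10 | or_mask << 5 |
    // and_mask, with new_lane = ((lane & and_mask) | or_mask) ^ xor_mask):
    //   0x401F -> and 0x1F, or 0, xor 16   (stride 16)
    //   0x201F -> and 0x1F, or 0, xor  8   (stride  8)
    //   0x101F -> and 0x1F, or 0, xor  4   (stride  4)
    //   0x081F -> and 0x1F, or 0, xor  2   (stride  2)
    //   0x041F -> and 0x1F, or 0, xor  1   (stride  1)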
+ DenseMap masks{ + {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}}; + Value offset = i32_val(masks[strideInt]); + return rewriter.create(loc, valType, val, offset); } break; case ShflKind::up: { @@ -234,27 +158,22 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, return Value(); } -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, - ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, - ShflKind::bfly, i32_val(0x1f)); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { + return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, + i32_val(0x1f)); } -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, - ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i, - ShflKind::up, i32_val(0x0)); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { + return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, + i32_val(0x0)); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, - ISAFamily isaFamily) { - return shuffleIdx(loc, rewriter, val, i32_val(i), isaFamily); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { + return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, - ISAFamily isaFamily) { - return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx, - i32_val(0x1f)); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { + return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); } Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index d150531848e3..123234fd4824 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -2,14 +2,12 @@ #define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_UTILITY_H #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" -#include "TritonAMDGPUToLLVM/TargetUtils.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" - namespace mlir::LLVM::AMD { const char predicatedLoad[] = "__predicated_load"; @@ -21,18 +19,10 @@ const char predicatedStoreCG[] = "__predicated_store_CG"; const char predicatedStoreCS[] = "__predicated_store_CS"; const char predicatedStoreWT[] = "__predicated_store_WT"; -Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); -Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); -Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, - mlir::triton::AMD::ISAFamily isaFamily = - mlir::triton::AMD::ISAFamily::Unknown); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value 
val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, int axis); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index c3a69a5f9a2a..7da8083cfb92 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -8,7 +8,6 @@ add_triton_library(TritonAMDGPUTransforms MfmaGroup.cpp DEPENDS - TritonAMDGPUIR TritonAMDGPUTransformsIncGen TritonGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 22349c50e308..e122f15fd901 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -5,28 +5,23 @@ #include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "llvm/ADT/STLExtras.h" +#include + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" using namespace mlir; namespace ttg = mlir::triton::gpu; - -//===----------------------------------------------------------------------===// -// Utility functions -//===----------------------------------------------------------------------===// - -// Return true if the given moduleOp contains a pure matmul problem; i.e., -// single dot in the main loop. -static bool isPureMatmulProblem(ModuleOp moduleOp) { - for (auto forOp : moduleOp.getOps()) { - int counter = 0; - forOp.walk([&counter](triton::DotOp dotOp) { ++counter; }); - if (counter != 1) - return false; - } - return true; +namespace tt = mlir::triton; + +static bool isLocalLoadOrDotLayoutConversion(Operation *op) { + if (isa(op)) + return true; + if (auto cvt = dyn_cast(op)) + return isa(cvt.getType().getEncoding()); + return false; } // Search through block to find earliest insertion point for move op. This can @@ -66,233 +61,194 @@ findEarlyInsertionPoint(Block *block, Operation *move) { return ipnt; } -// Return the first user in the same block of the given op. If the user is in a -// nested block then return the op owning the block. Return nullptr if not -// existing. -static Operation *getFirstUseInSameBlock(Operation *op) { - SmallVector usersInSameBlock; - for (auto user : op->getUsers()) { - if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) - usersInSameBlock.push_back(ancestor); - } - auto minOpIt = - llvm::min_element(usersInSameBlock, [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); - }); - return minOpIt != usersInSameBlock.end() ? *minOpIt : nullptr; -} - // Check if the operation opInsideLoop is inside any scf::ForOp and // opOutsideLoop is not inside the same loop. 
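// Usage sketch (names hypothetical): for
//   %q = tt.load ...                        // defined before the loop
//   scf.for ... { %a = ttg.local_alloc %q }
// isCrossLoopBoundary(allocOp, loadOp) is true, which is what the hoisting
// logic below looks for; it is false when both ops sit in the same loop body
// or when the first op is not inside any scf.for.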
-static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, - mlir::Operation *opOutsideLoop) { +bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { scf::ForOp parentForOp = opInsideLoop->getParentOfType(); return parentForOp && !parentForOp->isAncestor(opOutsideLoop); } -//===----------------------------------------------------------------------===// -// Reorder mechanisms -//===----------------------------------------------------------------------===// - -// Sink dot layout conversions into loops to decrease register pressure when -// possible. -static void sinkDotConversion(ModuleOp moduleOp) { - DenseMap opToMove; - moduleOp.walk([&](ttg::ConvertLayoutOp op) { - Attribute encoding = op.getType().getEncoding(); - if (!isa_and_nonnull(encoding)) - return; - if (!op->hasOneUse()) - return; - Operation *user = *op->getUsers().begin(); - if (user->getParentOfType() == - op->getParentOfType()) - return; - opToMove[op] = user; - }); - - for (auto &kv : opToMove) - kv.first->moveBefore(kv.second); -} +class TritonAMDGPUReorderInstructionsPass + : public TritonAMDGPUReorderInstructionsBase< + TritonAMDGPUReorderInstructionsPass> { +public: + TritonAMDGPUReorderInstructionsPass() = default; + + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); + } + auto minOpIt = std::min_element(users.begin(), users.end(), + [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; + } -// Adjust the placement of shared memory writes and reads to immediately follow -// the definition of their operands in case where shared memory write is in the -// loop but its operand is not. -// -// This is a heuristic driven by optimizing fused attention by hoisting Q tensor -// shared memory read/write operations outside of the loop, as Q is a loop -// invariant and can be loaded once before entering the loop. But it should be -// generally applicable. -// -// There are two possible patterns for this adjustment depending on whether the -// write to shared memory is performed using an optional `local_alloc` argument -// or a `local_store` instruction. -// -// 1) %1 = some_op ... (typically a load or an operation that scales the tensor -// after loading) -// %2 = local_alloc %1 -// %3 = local_load %2 -// -// 2) %1 = some_op ... -// %2 = local_alloc -// %3 = local_store %1, %2 -// %4 = local_load %2 -static void hoistLocalLoad(ModuleOp moduleOp) { - moduleOp.walk([&](ttg::LocalLoadOp localLoad) { - auto localAlloc = localLoad.getSrc().getDefiningOp(); - if (!localAlloc) - return; + void runOnOperation() override { + ModuleOp m = getOperation(); - // Case when localAlloc has operands - if (localAlloc->getNumOperands() == 1) { - if (!localAlloc->hasOneUse()) + // Sink shared memory loads and layout conversions into loops to decrease + // register pressure when possible. + DenseMap opToMove; + m.walk([&](Operation *op) { + if (!isLocalLoadOrDotLayoutConversion(op)) return; - - auto srcTensorOp = localAlloc.getSrc().getDefiningOp(); - // Check if localAlloc is in the loop but it's src tensor defining op is - // outside of it. 
- if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) + if (!op->hasOneUse()) + return; + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType() == + op->getParentOfType()) + return; + opToMove.insert({op, user}); + }); + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); + opToMove.clear(); + + // Adjust the placement of LDS writes and reads to immediately follow the + // definition of their operands in case where LDS write is in the + // loop but it's operand is not. This is a heuristic for optimizing fused + // attention by hoisting Q tensor LDS read/write operations outside of the + // loop, as Q is a loop invariant and can be loaded once before entering the + // loop. + // There are two possible patterns for this adjustment depending on + // whether the write to LDS is performed using an optional `local_alloc` + // argument or a `local_store` instruction. + // + // clang-format off + // + // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) + // %2 = local_alloc %1 + // %3 = local_load %2 + // + // 2) %1 = some_op ... + // %2 = local_alloc + // %3 = local_store %1, %2 + // %4 = local_load %2 + // + // clang-format on + m.walk([&](ttg::LocalLoadOp localLoad) { + auto localAlloc = localLoad.getSrc().getDefiningOp(); + if (!localAlloc) return; - localAlloc->moveAfter(srcTensorOp); - localLoad->moveAfter(localAlloc); - return; - } - - // Case when localAlloc has no operands - assert(localAlloc->getNumOperands() < 1); - auto allocVal = localAlloc->getResult(0); - - // Check if the localAlloc has exactly two uses (localStore and localLoad) - int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); - if (numUses != 2) - return; - - // localStore comes before localLoad in block. - Operation *localStore = getFirstUseInSameBlock(localAlloc); - if (!isa(localStore)) - return; - - auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); - // Check if localStore is in the loop but it's src tensor defining op is - // outside of it. - if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { - return; - } - - localAlloc->moveAfter(srcTensorOp); - localStore->moveAfter(localAlloc); - localLoad->moveAfter(localStore); - }); -} + // Case when localAlloc has operands + if (localAlloc->getNumOperands() == 1) { + if (!localAlloc->hasOneUse()) + return; -// Sink conversion after the last dealloc but before the first use in its block. -// This helps to avoid unnecessary shared memory allocation. -static void moveDownCoversion(ModuleOp moduleOp) { - SmallVector convertOps; - moduleOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); }); + auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); + // Check if localAlloc is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { + return; + } - for (auto op : convertOps) { - Operation *user = getFirstUseInSameBlock(op); - for (auto it = Block::iterator(op), ie = op->getBlock()->end(); - it != ie && &*it != user; ++it) - if (isa(&*it)) - op->moveAfter(&*it); - } -} + localAlloc->moveAfter(srcTensorOp); + localLoad->moveAfter(localAlloc); + return; + } -// Move transpositions just after their definition. 
-static void moveUpTranspose(ModuleOp moduleOp) { - SmallVector transOps; - moduleOp.walk([&](triton::TransOp op) { transOps.push_back(op); }); + // Case when localAlloc has no operands + assert(localAlloc->getNumOperands() < 1); + auto allocVal = localAlloc->getResult(0); - for (auto op : transOps) - if (Operation *argOp = op.getSrc().getDefiningOp()) - op->moveAfter(argOp); -} + // Check if the localAlloc has exactly two uses (localStore and localLoad) + int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); + if (numUses != 2) + return; -// Schedule global load and local store ops for better GEMM performance. -static void scheduleGlobalLoadLocalStore(ModuleOp m) { - SmallVector moveOps; - // Move global loads early to prefetch. This may increase register pressure - // but it enables issuing global loads early. - m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); - // Move local_stores early if dependence distance greater than one iteration. - // Best perf on GEMM when these precede global loads. - m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + // localStore comes before localLoad in block. + Operation *localStore = getFirstUse(localAlloc); + if (!isa(localStore)) + return; - for (auto op : llvm::reverse(moveOps)) { - // Gather use-def chain in block. - Block *block = op->getBlock(); - bool leadsToLoad = false; - SetVector backwardSet; + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { + return; + } - BackwardSliceOptions options; - options.omitBlockArguments = true; - options.inclusive = false; - options.filter = [&](Operation *defOp) -> bool { - Block *defBlock = defOp->getBlock(); - if (!block->findAncestorOpInBlock(*defOp)) - return false; - // Check for a `load` dependent path. - leadsToLoad |= isa(defOp); - // Only move ops residing in the same block. - return defBlock == block; - }; - mlir::getBackwardSlice(op, &backwardSet, options); - backwardSet.insert(op); + localAlloc->moveAfter(srcTensorOp); + localStore->moveAfter(localAlloc); + localLoad->moveAfter(localStore); + }); - // Don't move a local_store if its source is a load from - // the same iteration. - if (isa(op) && leadsToLoad) - continue; + // Sink conversion after the last dealloc but before the first use ancestor + // in its block. This helps to avoid unnecessary shared memory allocation. + m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + for (; &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); - auto ipoint = findEarlyInsertionPoint(block, op); - // Remove ops that already precede the insertion point. This is done - // before moves happen to avoid `Operation::isBeforeInBlock` N^2 - // complexity. + // Move transpositions just after their definition. + m.walk([&](triton::TransOp op) { + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); + }); - SmallVector dfg = backwardSet.takeVector(); - if (ipoint != block->end()) { - // Move ops to insertion point. - llvm::erase_if( - dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveAfter(block, ipoint); - } else { - // Move ops to block begin. - for (auto *dfgop : llvm::reverse(dfg)) - dfgop->moveBefore(block, block->begin()); + SmallVector moveOps; + // Move global loads early to prefetch. 
This may increase register pressure + // but it enables issuing global loads early. + m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if dependence distance greater than + // one iteration. + // Best perf on GEMM when these precede global loads. + m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather use-def chain in block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + SetVector backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + // Check for a `load` dependent path. + leadsToLoad |= isa(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); + + // Don't move a local_store if its source is a load from + // the same iteration. + if (isa(op) && leadsToLoad) + continue; + + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); + } } } -} - -//===----------------------------------------------------------------------===// -// Pass definition -//===----------------------------------------------------------------------===// - -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h" - -namespace { -struct TritonAMDGPUReorderInstructionsPass - : public TritonAMDGPUReorderInstructionsBase< - TritonAMDGPUReorderInstructionsPass> { - void runOnOperation() override { - ModuleOp m = getOperation(); - - hoistLocalLoad(m); - - sinkDotConversion(m); - moveDownCoversion(m); - - moveUpTranspose(m); - - if (isPureMatmulProblem(m)) - scheduleGlobalLoadLocalStore(m); - } }; -} // namespace std::unique_ptr mlir::createTritonAMDGPUReorderInstructionsPass() { return std::make_unique(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index d1cef15a354e..75f9354104b1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -93,12 +93,20 @@ static std::optional matchReduxKind(triton::ReduceOp op, int computeCapability) { if (computeCapability < 80) return std::nullopt; - Operation *reduceOp = op.getSingleCombiner(); - if (!reduceOp) + if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return std::nullopt; + Block *block = &(*op.getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) return std::nullopt; auto intType = dyn_cast(reduceOp->getResultTypes()[0]); if (!intType || intType.getWidth() > 32) return std::nullopt; + if (reduceOp->getOperand(0) != block->getArgument(0) || + 
      reduceOp->getOperand(1) != block->getArgument(1))
+    return std::nullopt;
   if (isa<arith::AddIOp>(reduceOp))
     return NVVM::ReduxKind::ADD;
   if (isa(reduceOp))