From 21119e31f2e5d517643749657403d4e43deb13d4 Mon Sep 17 00:00:00 2001
From: Kyle Wang
Date: Wed, 13 Nov 2024 18:36:42 -0800
Subject: [PATCH] [AMD] Support warp-level reduction with DPP (#5019)

This commit adds support for warp-level reduction with DPP instructions,
which can improve performance.

See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
---
 include/triton/Dialect/Triton/IR/TritonOps.td |   4 +
 lib/Dialect/Triton/IR/Ops.cpp                 |  16 ++
 test/Conversion/amd/tritongpu_to_llvm.mlir    |  76 ++++++++
 .../include/TritonAMDGPUToLLVM/TargetUtils.h  |  11 ++
 .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 184 +++++++++++++++++-
 .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp    | 123 ++++++++++--
 .../amd/lib/TritonAMDGPUToLLVM/Utility.h      |  18 +-
 .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp  |  12 +-
 8 files changed, 404 insertions(+), 40 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 7f59241bb4e4..43e4ac027105 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -731,6 +731,10 @@ def TT_ReduceOp: TT_Op<"reduce",
     llvm::SmallVector<RankedTensorType> getInputTypes();
     llvm::SmallVector<Type> getElementTypes();
     unsigned getNumOperands();
+
+    // Returns the CombineOp iff this ReduceOp's region contains only
+    // one CombineOp other than the return, or nullptr if not applicable.
+    ::mlir::Operation *getSingleCombiner();
   }];
 }
 
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 34013655f1cb..a5ef8a487e09 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() {
   return getElementTypesImpl(this->getOperands());
 }
 
+::mlir::Operation *ReduceOp::getSingleCombiner() {
+  if (getNumOperands() != 1 || getNumResults() != 1)
+    return nullptr;
+  Block *block = &(*getCombineOp().begin());
+  Operation *yield = block->getTerminator();
+  Operation *reduceOp = yield->getOperand(0).getDefiningOp();
+  if (!reduceOp || reduceOp->getNumOperands() != 2 ||
+      reduceOp->getNumResults() != 1)
+    return nullptr;
+  if (reduceOp->getOperand(0) != block->getArgument(0) ||
+      reduceOp->getOperand(1) != block->getArgument(1))
+    return nullptr;
+
+  return reduceOp;
+}
+
 unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }
 
 //-- ScanOp --
diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir
index 7ceae6b58eeb..d9a37b5c753f 100644
--- a/test/Conversion/amd/tritongpu_to_llvm.mlir
+++ b/test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -132,3 +132,79 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: reduce_dpp_max
+  tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 280, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 276, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 274, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 273, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 322, 10, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK-NEXT: rocdl.update.dpp
+    // CHECK-SAME: with 323, 15, 15, true : f32
+    // CHECK-NEXT: llvm.intr.maxnum
+
+    // CHECK: llvm.amdgcn.readlane
+    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.maxnumf %arg1, %arg2 : f32
+      tt.reduce.return %1 : f32
+    }) : (tensor<64xf32, #blocked3>) -> f32
+    tt.return
+  }
+}
+
+// -----
+
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: reduce_xor_max
+  tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
+    // CHECK: rocdl.ds_swizzle
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 280, 15, 12, false : i32
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 264, 15, 3, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 276, 15, 10, false : i32
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 260, 15, 5, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 78, 15, 15, false : i32
+    // CHECK: llvm.intr.maxnum
+
+    // CHECK: rocdl.update.dpp
+    // CHECK-SAME: with 177, 15, 15, false : i32
+    %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.maxnumf %arg1, %arg2 : f32
+      tt.reduce.return %1 : f32
+    }) : (tensor<32xf32, #blocked4>) -> f32
+    tt.return
+  }
+}
diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h
index a49e442d3984..9e174d545dd9 100644
--- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h
+++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h
@@ -19,6 +19,17 @@ enum class ISAFamily {
 // Deduces the corresponding ISA family for the given target gfx |arch|.
 ISAFamily deduceISAFamily(llvm::StringRef arch);
 
+// Here is a partial definition of DppCtrl enums. For the complete definition,
+// please check:
+// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939
+enum class DppCtrl : uint32_t {
+  QUAD_PERM_FIRST = 0,
+  ROW_SHL0 = 0x100,
+  ROW_SHR0 = 0x110,
+  BCAST15 = 0x142,
+  BCAST31 = 0x143
+};
+
 } // namespace mlir::triton::AMD
 
 #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
index 3a40d73c2a7c..525361fee603 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
@@ -5,6 +5,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
+using mlir::triton::AMD::DppCtrl;
 namespace mlir::triton::AMD {
 namespace {
 
@@ -103,22 +104,22 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
 
 Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val,
                              int i) const {
-  return LLVM::AMD::shuffleXor(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleXor(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val,
                             int i) const {
-  return LLVM::AMD::shuffleUp(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleUp(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                              int i) const {
-  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                              Value i) const {
-  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i);
+  return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
 }
 
 Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
@@ -126,11 +127,184 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
   return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis);
 }
 
+// Cast and sext values into a specific-length int to meet the requirements of
+// instructions like UpdateDpp or readlane if necessary.
+static inline Type castToAndSExtInt(RewriterBase &rewriter, Location loc,
+                                    Value &val, Type fromType,
+                                    unsigned toBits) {
+  unsigned originalBits = fromType.getIntOrFloatBitWidth();
+  Type toType = fromType;
+
+  if (!fromType.isIntOrIndex()) {
+    val = bitcast(val, int_ty(originalBits));
+    toType = int_ty(originalBits);
+  }
+
+  if (originalBits < toBits) {
+    val = sext(int_ty(toBits), val);
+    toType = int_ty(toBits);
+  }
+
+  return toType;
+}
+
+// Truncate the value to a specific length and then cast it to the given type
+// if necessary. This function is typically used in conjunction with
+// castToAndSExtInt.
+static inline Value truncAndCastFromInt(RewriterBase &rewriter, Location loc,
+                                        Value val, Type valType,
+                                        unsigned fromBits) {
+  unsigned originalBits = valType.getIntOrFloatBitWidth();
+  Value toVal = val;
+
+  if (originalBits < fromBits) {
+    toVal = trunc(int_ty(originalBits), toVal);
+  }
+
+  if (!valType.isIntOrIndex()) {
+    toVal = bitcast(toVal, valType);
+  }
+
+  return toVal;
+}
+
 bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             SmallVector<Value> &acc, triton::ReduceOp op,
                             unsigned numLaneToReduce,
                             unsigned interleave) const {
-  return false;
+  if (numLaneToReduce != 64)
+    return false;
+
+  if (auto family = getISAFamily();
+      family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) {
+    return false;
+  }
+
+  Operation *reduxOp = op.getSingleCombiner();
+  if (!reduxOp)
+    return false;
+
+  auto createDppReduxOpWithBoundCtrl = [&](Type valType, Value &src,
+                                           uint32_t dppCtrl, int rowMask,
+                                           int bankMask) -> Value {
+    // DPP has limited support for data types, so here we need to
+    // cast non-integer types or integer types shorter than 32 bits
+    // to int32, except for fp32.
+    Type actualType = valType;
+    if (!valType.isF32()) {
+      actualType = castToAndSExtInt(rewriter, loc, src, valType, 32);
+    }
+
+    Value dppResult =
+        rewriter
+            .create<ROCDL::DPPUpdateOp>(loc, actualType, src, src,
+                                        rewriter.getI32IntegerAttr(dppCtrl),
+                                        rewriter.getI32IntegerAttr(rowMask),
+                                        rewriter.getI32IntegerAttr(bankMask),
+                                        rewriter.getBoolAttr(true))
+            .getRes();
+
+    if (!valType.isF32()) {
+      src = truncAndCastFromInt(rewriter, loc, src, valType, 32);
+      dppResult = truncAndCastFromInt(rewriter, loc, dppResult, valType, 32);
+    }
+
+    IRMapping mapping;
+    mapping.map(reduxOp->getOperand(0), src);
+    mapping.map(reduxOp->getOperand(1), dppResult);
+    return rewriter.clone(*reduxOp, mapping)->getResult(0);
+  };
+
+  for (int i = 0; i < acc.size(); i++) {
+    Value buf;
+    auto valType = acc[i].getType();
+
+    /*
+    Here's the implementation of full-wavefront reduction using dpp.
+    https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
+
+    Each step has a v_mov_dpp instruction following the redux op. In
+    some cases, the lower-level compiler could merge them into a single
+    instruction. For example, v_mov_dpp + max => v_max_dpp.
+
+    For gfx9, we have 64 threads per warp. These 64 threads are arranged
+    into 4 rows, with each row being 16 threads. Each 16 threads are arranged
+    further into 4 banks, with each bank being 4 threads. Overall it's in a
+    (row, bank, thread) structure. When shuffling, we use the row/bank mask to
+    indicate which rows/banks participate. Modifiers like row_shr and
+    row_bcast specify the exact data movement scheme. In the following
+    instructions, taking row 0 as an example:
+
+    Step 1: Right shift for 8 lanes.
+            lane 8-15 = redux(lane 0-7, lane 8-15)
+
+    Step 2: Right shift for 4 lanes.
+            lane 12-15 = redux(lane 8-11, lane 12-15)
+
+    Step 3: Right shift for 2 lanes.
+            lane 14-15 = redux(lane 12-13, lane 14-15)
+
+    Step 4: Right shift for 1 lane.
+            lane 15 = redux(lane 14, lane 15)
+
+    Step 5: Broadcast lane 15 of each row to all the lanes of its next row.
+            lane 16-31 = redux(lane 15, lane 16-31)
+
+    Step 6: Broadcast lane 31 to lane 32-63.
+            lane 32-63 = redux(lane 31, lane 32-63)
+
+    Now the reduction result is stored in lane 63.
+
+    Step 7: Read the reduction result from lane 63 and broadcast with
+            readlane.
+ */ + + const int allRows = 0xf; + const int allBanks = 0xf; + + const uint32_t dppCtrlRowShr = static_cast(DppCtrl::ROW_SHR0); + + // row_shr:8 + buf = createDppReduxOpWithBoundCtrl(valType, acc[i], 8 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:4 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 4 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:2 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 2 + dppCtrlRowShr, + allRows, allBanks); + + // row_shr:1 + buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, + allRows, allBanks); + + // row_bcast:15 row_mask:0xa + buf = createDppReduxOpWithBoundCtrl( + valType, buf, static_cast(DppCtrl::BCAST15), 0xa, allBanks); + + // row_bcast:31 + buf = createDppReduxOpWithBoundCtrl(valType, buf, + static_cast(DppCtrl::BCAST31), + allRows, allBanks); + + // Similarly, we need to cast data types for readlane instruction. + Type actualType = castToAndSExtInt(rewriter, loc, buf, valType, 16); + + // Get reduction result from lane 63 + std::string intrinsic = "llvm.amdgcn.readlane"; + Value result = + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, actualType, + ValueRange{buf, i32_val(63)}) + ->getResult(0); + + result = truncAndCastFromInt(rewriter, loc, result, valType, 16); + + acc[i] = result; + } + + return true; } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 542b1ecbb7fb..0bd401f1993a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,6 +8,8 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +using mlir::triton::AMD::DppCtrl; +using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; @@ -71,8 +73,9 @@ Type castToVectorType(Type ty) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, - Value i, int strideInt, ShflKind mode, Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, + ISAFamily isaFamily, Value val, Value i, + int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -84,7 +87,8 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, if (bits < 32) val = sext(i32_ty, val); - val = shuffleCommon(loc, rewriter, val, i, strideInt, mode, clamp); + val = + shuffleCommon(loc, rewriter, isaFamily, val, i, strideInt, mode, clamp); if (bits < 32) val = trunc(int_ty(bits), val); @@ -98,8 +102,10 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); - val0 = shuffleCommon(loc, rewriter, val0, i, strideInt, mode, clamp); - val1 = shuffleCommon(loc, rewriter, val1, i, strideInt, mode, clamp); + val0 = shuffleCommon(loc, rewriter, isaFamily, val0, i, strideInt, mode, + clamp); + val1 = shuffleCommon(loc, rewriter, isaFamily, val1, i, strideInt, mode, + clamp); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); @@ -134,13 +140,83 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, 
       Value stride = i32_val(32);
       Value lineId = xor_(threadId, stride);
       return bpermute(lineId);
-    } else {
-      // This map facilates the butterfly shuffle pattern for a stride less
-      // than 16. The pattern stride is the key of the map.
-      DenseMap<short, unsigned int> masks{
-          {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
-      Value offset = i32_val(masks[strideInt]);
+    } else if (strideInt == 16) {
+      Value offset = i32_val(0x401F);
       return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
+    } else {
+      if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) {
+        // DPP is only supported for CDNA2 and CDNA3 right now, so we fall back
+        // to ds_swizzle for other archs.
+        //
+        // This map facilitates the butterfly shuffle pattern for a stride less
+        // than 16. The pattern stride is the key of the map.
+        DenseMap<short, unsigned int> masks{
+            {16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
+        Value offset = i32_val(masks[strideInt]);
+        return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
+      }
+
+      auto createDppOpWithoutBoundCtrl = [&](Value &old, Value &src,
+                                             uint32_t dppCtrl, uint32_t rowMask,
+                                             uint32_t bankMask) {
+        return rewriter.create<ROCDL::DPPUpdateOp>(
+            loc, valType, old, src, rewriter.getI32IntegerAttr(dppCtrl),
+            rewriter.getI32IntegerAttr(rowMask),
+            rewriter.getI32IntegerAttr(bankMask), rewriter.getBoolAttr(false));
+      };
+
+      const int allRows = 0xf;
+      const int allBanks = 0xf;
+
+      switch (strideInt) {
+      case 1: {
+        // quad_perm: 1, 0, 3, 2
+        uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST);
+        std::array mask = {1, 0, 3, 2};
+        for (int i = 0; i < mask.size(); i++) {
+          dppCtrl |= mask[i] << (i * 2);
+        }
+        return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows,
+                                           allBanks);
+      }
+      case 2: {
+        // quad_perm: 2, 3, 0, 1
+        uint32_t dppCtrl = static_cast<uint32_t>(DppCtrl::QUAD_PERM_FIRST);
+        std::array mask = {2, 3, 0, 1};
+        for (int i = 0; i < mask.size(); i++) {
+          dppCtrl |= mask[i] << (i * 2);
+        }
+        return createDppOpWithoutBoundCtrl(val, val, dppCtrl, allRows,
+                                           allBanks);
+      }
+      case 4: {
+        // row_shr:4 bank_mask: 0xa
+        auto ret = createDppOpWithoutBoundCtrl(
+                       val, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHR0),
+                       allRows, 0xa)
+                       .getRes();
+
+        // row_shl:4 bank_mask: 0x5
+        return createDppOpWithoutBoundCtrl(
+            ret, val, 4 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows,
+            0x5);
+      }
+      case 8: {
+        // row_shr:8 bank_mask: 0xc
+        auto ret = createDppOpWithoutBoundCtrl(
+                       val, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHR0),
+                       allRows, 0xc)
+                       .getRes();
+
+        // row_shl:8 bank_mask: 0x3
+        return createDppOpWithoutBoundCtrl(
+            ret, val, 8 + static_cast<uint32_t>(DppCtrl::ROW_SHL0), allRows,
+            0x3);
+      }
+      default:
+        assert(false &&
+               "bfly shfl with stride >= 16 should not be handled by dpp.");
+      }
     }
     break;
   case ShflKind::up: {
@@ -158,22 +234,27 @@ static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val,
   return Value();
 }
 
-Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) {
-  return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly,
-                       i32_val(0x1f));
+Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i,
+                 ISAFamily isaFamily) {
+  return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i,
+                       ShflKind::bfly, i32_val(0x1f));
 }
 
-Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) {
-  return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up,
-                       i32_val(0x0));
+Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i,
+                ISAFamily isaFamily) {
+  return shuffleCommon(loc, rewriter, isaFamily, val, i32_val(i), i,
+                       ShflKind::up, i32_val(0x0));
 }
 
-Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) {
-  return shuffleIdx(loc, rewriter, val, i32_val(i));
+Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i,
+                 ISAFamily isaFamily) {
+  return shuffleIdx(loc, rewriter, val, i32_val(i), isaFamily);
 }
 
-Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) {
-  return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f));
+Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i,
+                 ISAFamily isaFamily) {
+  return shuffleCommon(loc, rewriter, isaFamily, val, i, 0, ShflKind::idx,
+                       i32_val(0x1f));
 }
 
 Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
index 123234fd4824..d150531848e3 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h
@@ -2,12 +2,14 @@
 #define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_UTILITY_H
 
 #include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
+#include "TritonAMDGPUToLLVM/TargetUtils.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 
+
 namespace mlir::LLVM::AMD {
 
 const char predicatedLoad[] = "__predicated_load";
@@ -19,10 +21,18 @@ const char predicatedStoreCG[] = "__predicated_store_CG";
 const char predicatedStoreCS[] = "__predicated_store_CS";
 const char predicatedStoreWT[] = "__predicated_store_WT";
 
-Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i);
-Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i);
-Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i);
-Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i);
+Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i,
+                 mlir::triton::AMD::ISAFamily isaFamily =
+                     mlir::triton::AMD::ISAFamily::Unknown);
+Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i,
+                mlir::triton::AMD::ISAFamily isaFamily =
+                    mlir::triton::AMD::ISAFamily::Unknown);
+Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i,
+                 mlir::triton::AMD::ISAFamily isaFamily =
+                     mlir::triton::AMD::ISAFamily::Unknown);
+Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i,
+                 mlir::triton::AMD::ISAFamily isaFamily =
+                     mlir::triton::AMD::ISAFamily::Unknown);
 
 Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
                int axis);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
index 75f9354104b1..d1cef15a354e 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -93,20 +93,12 @@ static std::optional<NVVM::ReduxKind> matchReduxKind(triton::ReduceOp op,
                                                      int computeCapability) {
   if (computeCapability < 80)
     return std::nullopt;
-  if (op.getNumOperands() != 1 || op.getNumResults() != 1)
-    return std::nullopt;
-  Block *block = &(*op.getCombineOp().begin());
-  Operation *yield = block->getTerminator();
-  Operation *reduceOp = yield->getOperand(0).getDefiningOp();
-  if (!reduceOp || reduceOp->getNumOperands() != 2 ||
-      reduceOp->getNumResults() != 1)
+  Operation *reduceOp = op.getSingleCombiner();
+  if (!reduceOp)
     return std::nullopt;
   auto intType = dyn_cast<IntegerType>(reduceOp->getResultTypes()[0]);
   if (!intType || intType.getWidth() > 32)
     return std::nullopt;
-  if (reduceOp->getOperand(0) != block->getArgument(0) ||
-      reduceOp->getOperand(1) != block->getArgument(1))
-    return std::nullopt;
   if (isa<arith::AddIOp>(reduceOp))
     return NVVM::ReduxKind::ADD;
   if (isa<arith::AndIOp>(reduceOp))
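
To make the seven-step sequence described in the warpReduce comment easier to check, here is a small standalone C++ sketch (not part of the patch) that simulates the row_shr and row_bcast data movement on a 64-lane wavefront with a max combiner. The test data, the lane arithmetic, and the simplified treatment of lanes that have no DPP source are assumptions made for illustration; the real instructions also involve bound_ctrl and the exec mask, which the sketch ignores.

// Standalone illustration: CPU simulation of the DPP reduction steps.
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdio>

int main() {
  std::array<float, 64> lanes;
  for (int i = 0; i < 64; ++i)
    lanes[i] = static_cast<float>((i * 37) % 64); // arbitrary non-negative data
  const std::array<float, 64> init = lanes;

  // Steps 1-4: row_shr:8/4/2/1 within each row of 16 lanes, combined with max.
  for (int shift : {8, 4, 2, 1}) {
    std::array<float, 64> src = lanes;
    for (int l = 0; l < 64; ++l) {
      int row = l / 16, col = l % 16;
      if (col - shift >= 0) // lanes without a source keep their value here
        lanes[l] = std::max(lanes[l], src[row * 16 + (col - shift)]);
    }
  }

  // Step 5: row_bcast:15 with row_mask 0xa -- rows 1 and 3 combine with
  // lane 15 of the preceding row.
  std::array<float, 64> src = lanes;
  for (int row : {1, 3})
    for (int col = 0; col < 16; ++col)
      lanes[row * 16 + col] =
          std::max(lanes[row * 16 + col], src[row * 16 - 1]);

  // Step 6: row_bcast:31 -- lanes 32-63 combine with lane 31.
  src = lanes;
  for (int l = 32; l < 64; ++l)
    lanes[l] = std::max(lanes[l], src[31]);

  // Step 7: reading lane 63 now yields the reduction over all 64 lanes.
  float expected = *std::max_element(init.begin(), init.end());
  assert(lanes[63] == expected);
  std::printf("wavefront max = %g\n", lanes[63]);
  return 0;
}

Running the sketch asserts that lane 63 ends up holding the maximum over all 64 lanes, which is the value the patch reads back with the llvm.amdgcn.readlane intrinsic.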