diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index d5eb81eb9f4a..f094ce963a5a 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -63,6 +63,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUStreamPipelineV2(); mlir::registerTritonAMDGPUCanonicalizePointers(); mlir::registerTritonAMDGPUConvertToBufferOps(); + mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); + mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); // TODO: register Triton & TritonGPU passes registry.insert + localStoreOpConversion = nullptr; +}; + void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, PatternBenefit benefit); -void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, - PatternBenefit benefit); +// The given callback is invoked at the end of a successful rewrite. The +// callback receives 1) the current source op, 2) the number of issued LLVM +// instructions and 3) their input types. Each MLIR backend can provide a +// callback and, thus, handle backend-specific behaviors. +void populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks = std::nullopt); void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 56a82d7cc0fb..a3d8fe9e64ba 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1366,11 +1366,11 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, Location loc, RewriterBase &rewriter, const TargetInfoBase &target); -void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, - Type elemLlvmTy, ArrayRef srcVals, - Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, - const TargetInfoBase &target); +void storeDistributedToShared( + MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index e2ed0228de8d..38fa1bd62343 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -15,12 +15,11 @@ using namespace mlir::triton::gpu; // blocked -> shared. // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. 
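// [Editorial illustration, not part of the patch] A minimal sketch of how a
// backend can supply the new BackendCallbacks hook declared above. It assumes
// localStoreOpConversion is a std::function-style callable taking the source
// triton::gpu::LocalStoreOp, the number of emitted LLVM store instructions,
// and their vector type; the AMD backend wires in storeOpConversionCallback
// the same way later in this patch.
mlir::triton::BackendCallbacks callbacks;
callbacks.localStoreOpConversion = [](triton::gpu::LocalStoreOp op,
                                      size_t llvmOpCount, mlir::Type vecTy) {
  // e.g., forward the ds_write count to an instruction-scheduling hint op.
};
mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo,
                                            patterns, benefit, callbacks);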
-void lowerDistributedToShared(Location loc, Value src, Value dst, - Value adaptorSrc, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &targetInfo) { +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); @@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst, auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo); + loc, rewriter, targetInfo, llvmOpCount); } struct LocalAllocOpConversion @@ -200,12 +199,15 @@ struct LocalStoreOpConversion public: using ConvertOpToLLVMPattern< triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + using BackendCallbackType = + decltype(BackendCallbacks::localStoreOpConversion); LocalStoreOpConversion(const LLVMTypeConverter &converter, const TargetInfoBase &targetInfo, + BackendCallbackType backendCallback, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(converter, benefit), - targetInfo(targetInfo) {} + targetInfo(targetInfo), backendCallback(backendCallback) {} LogicalResult matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, @@ -215,24 +217,36 @@ struct LocalStoreOpConversion getTypeConverter()->convertType(op.getDst().getType().getElementType()); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), adaptor.getSrc(), smemObj, getTypeConverter(), - rewriter, targetInfo); + rewriter, targetInfo, &llvmOpCount); + + if (backendCallback) + (backendCallback)(op, llvmOpCount.first, llvmOpCount.second); + rewriter.eraseOp(op); return success(); } private: const TargetInfoBase &targetInfo; + BackendCallbackType backendCallback; }; } // namespace void mlir::triton::populateMemoryOpToLLVMPattern( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { + RewritePatternSet &patterns, PatternBenefit benefit, + std::optional backendCallbacks) { patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); + + auto backendCall = + backendCallbacks ? 
backendCallbacks->localStoreOpConversion : nullptr; + patterns.add(typeConverter, targetInfo, backendCall, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e857dd36f6cb..67954e5daede 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target) { + const TargetInfoBase &target, + std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { @@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }); + if (!success) llvm::report_fatal_error("Failed to emit transfer from register to shared"); } diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir new file mode 100644 index 000000000000..400c219b6790 --- /dev/null +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -0,0 +1,103 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' 
-triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 + +module { + // INSERT_IGLP0-LABEL: @test_dot_op + // INSERT_IGLP1-LABEL: @test_dot_op + // INSTR_COUNT_NS1-LABEL: @test_dot_op + // INSTR_COUNT_NS2-LABEL: @test_dot_op + // LABELING_PS_1-LABEL: @test_dot_op + // LABELING_PS_2-LABEL: @test_dot_op + tt.func @test_dot_op(%lb : index, %ub : index, %step : index, + %A : !tt.ptr {tt.divisibility = 16 : i32}, + %B : !tt.ptr {tt.divisibility = 16 : i32}, + %C : !tt.ptr {tt.divisibility = 16 : i32}) { + // A ptrs + %a_ptr_splat = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr> + %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32> + %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> + %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32> -> tensor<128x32xi32> + %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + // B ptrs + %b_ptr_splat = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr> + %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32> -> tensor<32x128xi32> + %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + + %a_mask = arith.constant dense : tensor<128x32xi1> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16> + %b_mask = arith.constant dense : tensor<32x128xi1> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32> + + %a_off = arith.constant dense<4> : tensor<128x32xi32> + %b_off = arith.constant dense<4> : tensor<32x128xi32> + + %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32>) { + %a = tt.load %a_ptr : tensor<128x32x!tt.ptr> + %b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr> + + // INSERT_IGLP0: rocdl.iglp.opt 0 + // INSERT_IGLP1: rocdl.iglp.opt 1 + + // INSTR_COUNT_NS1: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none> + // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none> + // 
INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // INSTR_COUNT_NS2: amdgpu.instruction_sched_hint + // INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false + // INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false + // INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> + + // USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints] + // USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions. + + // LABELING_PS_1: scf.for + // LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_1: %[[REG1_OP0:.+]] = triton_gpu.convert_layout %[[REG0_OP0]] + // LABELING_PS_1: %[[REG1_OP1:.+]] = triton_gpu.convert_layout %[[REG0_OP1]] + // LABELING_PS_1: tt.dot %[[REG1_OP0]], %[[REG1_OP1]], {{.*}} + + // LABELING_PS_2: scf.for + // LABELING_PS_2: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>} + // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP0]], %{{.*}} {OpIdx = #amdgpu.OpIdx<0>} + // LABELING_PS_2: triton_gpu.local_store %[[REG0_OP1]], %{{.*}} {OpIdx = #amdgpu.OpIdx<1>} + + %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr>, tensor<128x32xi32> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr>, tensor<32x128xi32> + scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr>, tensor<32x128x!tt.ptr>, tensor<128x128xf32> + } + + // C ptrs + %c_ptr_splat = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr> + %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32> + %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> + %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32> -> tensor<128x128xi32> + %c_ptr = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr>, tensor<128x128xi32> + + tt.store %c_ptr, %loop#2 : tensor<128x128x!tt.ptr> + tt.return +} +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 390d1c83e61d..8669f5e04707 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -274,7 +274,7 @@ def make_llir(src, metadata, options): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) - amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant) + amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.num_stages, options.instruction_sched_variant) if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": passes.llvmir.add_di_scope(pm) amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, 
__HIP_FTZ) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 31a43acd2f89..c0aa08421bdd 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -32,4 +32,31 @@ class TritonAMDGPU_Attr traits = [], : AttrDef { } +def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "OpIdx"; + let summary = "An operand index attribute."; + let description = [{ + The attribute is a way to describe which input argument of the target + operation (e.g., `tt.dot`) the result of a given operation belongs to. + }]; + + let parameters = (ins "uint32_t":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "InstCounter"; + let summary = "An instruction counter attribute."; + let description = [{ + The attribute holds the number of issued LLVM instructions of a specific kind as well as + the data type. + }]; + + let parameters = (ins "uint32_t":$value, "Type":$type); + let assemblyFormat = "`<` params `>`"; +} + + #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index d5956cf7a33c..c0c18b07e907 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -35,6 +35,9 @@ def TritonAMDGPU_Dialect : Dialect { }]; let dependentDialects = []; + + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 538e31378fe8..68c50d48635b 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -57,7 +57,29 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { interleave for better instruction level parallelism. 
}]; - let assemblyFormat = [{attr-dict}]; + let arguments = (ins + TritonAMDGPU_InstCounter:$numDsReadsA, + TritonAMDGPU_InstCounter:$numDsReadsB, + TritonAMDGPU_InstCounter:$numDsWritesA, + TritonAMDGPU_InstCounter:$numDsWritesB, + TritonAMDGPU_InstCounter:$numGlobalLoadsA, + TritonAMDGPU_InstCounter:$numGlobalLoadsB, + BoolAttr:$isBufferLoadsAEnabled, + BoolAttr:$isBufferLoadsBEnabled, + TritonAMDGPU_InstCounter:$numMMAs + ); + + let builders = [ + OpBuilder<(ins), [{ + auto ctx = $_state.getContext(); + auto noneType = NoneType::get(ctx); + auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType); + build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr, + emptyAttr, emptyAttr, false, false, emptyAttr); + }]> + ]; + + let assemblyFormat = [{ attr-dict }]; } // diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index bd726bd845d2..4036cdecd1bd 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -36,9 +36,10 @@ createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(bool ftz); std::unique_ptr> -createInsertInstructionSchedHintsPass(); +createTritonAMDGPUInsertInstructionSchedHintsPass(); std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant); +createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages, + std::string variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index 9f4665aef217..0c1ccee76d77 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -59,20 +59,25 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul ]; } -def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; - let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; } -def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { +def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { let summary = "Lower instruction scheduling hints to LLVM intrinsics"; - let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")"; + let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*numStages=*/2, /*variant=*/\"\")"; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::ROCDL::ROCDLDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ + Option<"numStages", "num_stages", "int32_t", /*default*/"2", + "number of pipeline stages">, Option<"variant", "variant", "std::string", /*default*/"\"default\"", "instruction scheduling variant">, ]; diff --git 
a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 433e60be67f6..93345b0d6de4 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -13,7 +13,7 @@ def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; - let dependentDialects = []; + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; let options = [ Option<"numStages", "num_stages", diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index a82a77e9f57e..1e429fdc39a9 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -24,6 +24,9 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" + +#include "llvm/ADT/TypeSwitch.h" // clang-format off #include "Dialect/TritonAMDGPU/IR/Dialect.h" @@ -45,5 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { >(); } +#define GET_ATTRDEF_CLASSES +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + #define GET_OP_CLASSES #include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index b832d985bbe7..9043090802bf 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -336,6 +337,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int elemsPerLoad = numOfElems / loadsPerThread; assert(numOfElems % loadsPerThread == 0); + VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -346,7 +348,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset; loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; @@ -363,6 +364,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = mfmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index b60c86e1a3a5..1ca9e49745d6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -21,6 +21,7 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "SharedToDotOperandHelper.h" #include "Utility.h" @@ -212,6 +213,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int loadsPerThread = offsets.size() / (numRepNonK * numRepK); int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread; assert(numElemsPerThreadPerRep % loadsPerThread == 0); + auto loadVecTy = vec_ty(elemTy, elemsPerLoad); for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), @@ -221,7 +223,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep); Value valVec = undef(vecTy); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { - auto loadVecTy = vec_ty(elemTy, elemsPerLoad); Value loadOffset = offsets[nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId]; loadOffset = add(loadOffset, batchOffset); @@ -237,6 +238,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } } + for (auto op : tensor.getUsers()) { + if (auto localLoadOp = llvm::dyn_cast(op)) { + const size_t numDsReadsCount = + repB * numRepNonK * numRepK * loadsPerThread; + setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy); + } + } + MLIRContext *ctx = wmmaLayout.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(loadedValues.size(), loadedValues[0].getType())); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index 204d54894d3b..1eed112c30c0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -21,9 +21,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "TritonAMDGPUTransforms/MfmaGroup.h" #include "Utility.h" - #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" using namespace mlir; @@ -261,6 +261,14 @@ struct DotOpMFMAConversionHelper { Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + Type elemtTy = elemTyA; + const size_t mmaCount = + numRepB * numRepM * numRepN * numRepK * kWidth / kBase; + setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(), + maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(), + elemtTy); + rewriter.replaceOp(op, res); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 5a003f768833..0042cf89e93b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -22,6 +22,7 @@ */ #include "../PatternTritonGPUOpToLLVM.h" +#include "../TritonAMDGPUToLLVM/SchedInstructions.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -325,6 +326,10 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, Type structTy = LLVM::LLVMStructType::getLiteral( wmmaLayout.getContext(), SmallVector(fc.size(), dstElemTy)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); + + const size_t mmaCount = numRepB * numRepM * numRepN * numRepK; + setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy); + rewriter.replaceOp(op, res); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 5265f631ad9e..ef0ef5e59132 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,6 +1,7 @@ #include "BufferOpsEmitter.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -276,6 +277,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto cacheMod = op.getCache(); SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; @@ -286,7 +288,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, assert(wordNElems * nWords * numVecs == numElems); Value pred = mask ? 
maskElems[vecStart] : int_val(1, 1); - auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); @@ -309,6 +310,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } @@ -391,6 +395,10 @@ struct BufferLoadOpConversion Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, rewriter, llvmResultStructTy); + + const int numVecs = numElems / vec; + setNumGeneratedGlobalLoads(op, numVecs, vecTy); + rewriter.replaceOp(op, {resultStruct}); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index 9bed87961966..62ef7a164337 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -1,87 +1,157 @@ +#include "SchedInstructions.h" #include "TritonAMDGPUToLLVM/Passes.h" - +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" -#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir::triton { -#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS -#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPUINSERTINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_TRITONAMDGPULOWERINSTRUCTIONSCHEDHINTS #include "TritonAMDGPUToLLVM/Passes.h.inc" } // namespace mlir::triton +#undef DEBUG_TYPE +#define DEBUG_TYPE "lower-insert-instruction-sched-hints" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + using namespace mlir; -namespace { +// TODO: The following passes/algorithms are applicable only for a single +// `tt.dot` op in a `scf.for` block -i.e., a single schedule hint op per block. +// Note, we need to relax this assumption in the future and extend the current +// implementation. -// The bitmask that encodes kinds of the instructions from AMD ISA. -// The bitmask is used for providing instruction scheduling hints. 
-enum InstructionKindMask { - NONE = 0x0000000, - ALL_ALU = 0x00000001, - VALU = 0x00000002, - SALU = 0x00000004, - MFMA = 0x00000008, - ALL_VMEM = 0x00000010, - VMEM_READ = 0x00000020, - VMEM_WRITE = 0x00000040, - ALL_DS = 0x00000080, - DS_READ = 0x00000100, - DS_WRITE = 0x00000200 -}; +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType) { + auto *ctx = op->getContext(); + auto mmaType = RankedTensorType::get({m, n, k}, elementType); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + schedHint.setNumMMAsAttr(counterAttr); + }); +} + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->template getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + const bool isBufferLoadOp = + std::is_same_v; + if (opIdxAttr.getValue() == 0) { + schedHint.setNumGlobalLoadsAAttr(counterAttr); + schedHint.setIsBufferLoadsAEnabled(isBufferLoadOp); + } else { + schedHint.setNumGlobalLoadsBAttr(counterAttr); + schedHint.setIsBufferLoadsBEnabled(isBufferLoadOp); + } + } + }); +} +template void setNumGeneratedGlobalLoads(triton::amdgpu::BufferLoadOp op, + size_t globalLoadsCount, Type type); +template void setNumGeneratedGlobalLoads(triton::LoadOp op, + size_t globalLoadsCount, Type type); + +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount, + Type type) { + auto *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + const size_t opIdx = dotOperandLayout.getOpIdx(); + assert(opIdx < 2); + if (opIdx == 0) + schedHint.setNumDsReadsAAttr(counterAttr); + else + schedHint.setNumDsReadsBAttr(counterAttr); + }); +} + +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, + size_t localStoreOpCount, Type type) { + MLIRContext *ctx = op->getContext(); + auto counterAttr = + triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type); + + op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) { + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + assert(opIdxAttr.getValue() < 2); + if (opIdxAttr.getValue() == 0) + schedHint.setNumDsWritesAAttr(counterAttr); + else + schedHint.setNumDsWritesBAttr(counterAttr); + } + }); +} + +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp) { + triton::DotOp dotOp = nullptr; + size_t dotCounter = 0; + forOp->walk( + [&dotOp, &dotCounter](triton::DotOp op) { dotOp = op, ++dotCounter; }); + + return (dotCounter == 1) ? dotOp : nullptr; +} +} // namespace mlir::triton + +namespace { // Create an intrinsic to control how different instruction kinds should // interleave for better ILP. 
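// [Editorial note, illustrative] The rewritten helpers below build ROCDL ops
// (rocdl.sched.group.barrier, rocdl.sched.barrier, rocdl.iglp.opt) instead of
// hand-rolling LLVM intrinsic calls; those ops lower to the same intrinsics
// the removed code emitted directly, e.g. (values shown are examples only):
//   llvm.amdgcn.sched.group.barrier(mask=0x8 /*MFMA*/, size=1, groupId=0)
//   llvm.amdgcn.sched.barrier(mask=0x0 /*NONE*/)
//   llvm.amdgcn.iglp.opt(0)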
void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, - InstructionKindMask maskValue, int sizeValue, - int groupIdValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.group.barrier"; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - Value size = - LLVM::createConstantI32(loc, rewriter, static_cast(sizeValue)); - Value groupId = LLVM::createConstantI32(loc, rewriter, - static_cast(groupIdValue)); - - LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, TypeRange{}, - ValueRange{mask, size, groupId}); + mlir::amdgpu::sched_barrier_opt_enum maskValue, + int sizeValue, int groupIdValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + IntegerAttr size = + rewriter.getI32IntegerAttr(static_cast(sizeValue)); + IntegerAttr groupId = + rewriter.getI32IntegerAttr(static_cast(groupIdValue)); + rewriter.create(loc, mask, size, groupId); } // Insert intrinsic that controls the types of instructions that may be -// allowed to cross the intrinsic during instruction scheduling +// allowed to cross the intrinsic during instruction scheduling. Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, - int64_t maskValue) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.sched.barrier"; - LLVM::FastmathFlagsAttr defaultFlags{}; - - Value mask = - LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{mask}); + mlir::amdgpu::sched_barrier_opt_enum maskValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + return rewriter.create(loc, mask); } // Insert an experimental intrinsic for instruction group level parallelism. // The intrinsic takes a value that specifies the strategy. Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { - MLIRContext *ctx = rewriter.getContext(); - const char *intrinsicName = "llvm.amdgcn.iglp.opt"; - LLVM::FastmathFlagsAttr defaultFlags{}; - Value iglpValue = - LLVM::createConstantI32(loc, rewriter, static_cast(value)); - return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, - TypeRange{}, ValueRange{iglpValue}); + IntegerAttr iglpValue = + rewriter.getI32IntegerAttr(static_cast(value)); + return rewriter.create(loc, iglpValue); } struct InstructionSchedHintsRewriter : public OpRewritePattern { - InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) - : OpRewritePattern(ctx) { + InstructionSchedHintsRewriter(MLIRContext *ctx, int32_t numStages, + std::string variant) + : OpRewritePattern(ctx), numStages(numStages) { std::transform(variant.begin(), variant.end(), variant.begin(), [](unsigned char c) { return std::tolower(c); }); @@ -89,20 +159,162 @@ struct InstructionSchedHintsRewriter .Case("default", SchedulingType::NONE) .Case("iglp0", SchedulingType::IGLP0) .Case("iglp1", SchedulingType::IGLP1) + .Case("ck_v3", SchedulingType::CK_V3) .Default(SchedulingType::UNKNOWN); + + if (this->numStages < 2) { + this->schedulingType = SchedulingType::NONE; + LDBG("ignoring instruction scheduling due to a very low num. " + "stages value. 
Must be >= 2"); + } } - enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN }; + enum class SchedulingType : uint32_t { + NONE = 0, + IGLP0, + IGLP1, + CK_V3, + UNKNOWN + }; + + // This is the implementation of the CK's V3 pipelining (see + // see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). + // This scheduling requires 1x register and 1x LDS buffers combined with the + // local (LDS to registers) and global (HBM to registers) data prefetching. + // see: + // include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.h + void + createCKV3Schedule(PatternRewriter &rewriter, Location loc, + triton::amdgpu::InstructionSchedHint schedHint) const { + + if (!(schedHint.getIsBufferLoadsAEnabled() && + schedHint.getIsBufferLoadsBEnabled())) { + LDBG("Skipping instruction scheduling because `ck_v3` " + "scheduling can be used only with `buffer_load` instructions."); + return; + } + + const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue(); + const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue(); + + const uint32_t numDsWriteInstA = schedHint.getNumDsWritesA().getValue(); + const uint32_t numDsWriteInstB = schedHint.getNumDsWritesB().getValue(); + + const uint32_t numBufferLoadInstA = + schedHint.getNumGlobalLoadsA().getValue(); + const uint32_t numBufferLoadInstB = + schedHint.getNumGlobalLoadsB().getValue(); + + if (numBufferLoadInstA == 0) + schedHint.emitError("buffer load count for tile A must be initialized"); + + if (numBufferLoadInstB == 0) + schedHint.emitError("buffer load count for tile B must be initialized"); + + const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue(); + + auto mfmaType = cast(schedHint.getNumMMAs().getType()); + const uint32_t nPerXDL = mfmaType.getShape()[1]; + const uint32_t mfmaCycle = nPerXDL == 16 ? 16 : 32; + + auto dsReadsAType = cast(schedHint.getNumDsReadsA().getType()); + auto dsReadsBType = cast(schedHint.getNumDsReadsB().getType()); + + const uint32_t dsReadAIssueCycle = dsReadsAType.getShape()[0] == 16 ? 8 : 4; + const uint32_t dsReadBIssueCycle = dsReadsBType.getShape()[0] == 16 ? 
8 : 4; + + const auto dsReadAMfmaRate = + (mfmaCycle - 4 + 2 * dsReadAIssueCycle - 1) / (2 * dsReadAIssueCycle); + const auto dsReadBMfmaRate = + (mfmaCycle - 4 + 2 * dsReadBIssueCycle - 1) / (2 * dsReadBIssueCycle); + + const auto numDsreadAMfma = + (numDsReadInstA + dsReadAMfmaRate - 1) / dsReadAMfmaRate; + const auto numDsreadBMfma = + (numDsReadInstB + dsReadBMfmaRate - 1) / dsReadBMfmaRate; + + // stage 1 + const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma); + const auto numMfmaPerIssue = + numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB); + + const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA; + const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB; + + for (size_t i = 0; i < numBufferLoadInstA; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMfmaPerIssue - numDswritePerIssueA, 0); + } + + for (size_t i = 0; i < numBufferLoadInstB; ++i) { + for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_write, + 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + 1, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0); + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, + numMfmaPerIssue - numDswritePerIssueB, 0); + } + + // stage 2 + for (size_t i = 0; i < numDsreadAMfma; ++i) { + if ((numDsReadInstA - (i + 1) * dsReadAMfmaRate) >= dsReadAMfmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadAMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstA - (numDsreadAMfma - 1) * dsReadAMfmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + + for (size_t i = 0; i < numDsreadBMfma; ++i) { + if ((numDsReadInstB - (i + 1) * dsReadBMfmaRate) >= dsReadBMfmaRate) { + createSchedGroupBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::ds_read, + dsReadBMfmaRate, 0); + } else { + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::ds_read, + numDsReadInstB - (numDsreadBMfma - 1) * dsReadBMfmaRate, 0); + } + createSchedGroupBarrier( + rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma, 1, 0); + } + } LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { + if (this->schedulingType == SchedulingType::NONE) { + rewriter.eraseOp(instructionSchedHint); + return success(); + } if (this->schedulingType == SchedulingType::UNKNOWN) { - llvm::dbgs() - << "[" << getDebugName() << "]: " - << "unknown instruction scheduling variant has been provided\n"; - return mlir::failure(); + instructionSchedHint.emitError( + "unknown instruction scheduling variant has been provided"); + return failure(); } // The switch controls whether instructions are allowed to cross the 
basic @@ -110,13 +322,15 @@ struct InstructionSchedHintsRewriter // not supposed to be used together with IGLP OPT according to the AMDGPU // backend documentation. const bool limitSchedulingRange = - !(schedulingType == SchedulingType::IGLP0 || + !(schedulingType == SchedulingType::NONE || + schedulingType == SchedulingType::IGLP0 || schedulingType == SchedulingType::IGLP1); Location loc = instructionSchedHint->getLoc(); Block *block = instructionSchedHint->getBlock(); if (limitSchedulingRange) { rewriter.setInsertionPointToStart(block); - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); } rewriter.setInsertionPoint(block, std::prev(block->end())); @@ -128,6 +342,10 @@ struct InstructionSchedHintsRewriter createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1); break; } + case SchedulingType::CK_V3: { + createCKV3Schedule(rewriter, loc, instructionSchedHint); + break; + } case SchedulingType::NONE: [[fallthrough]]; default: { @@ -136,21 +354,25 @@ struct InstructionSchedHintsRewriter } if (limitSchedulingRange) - createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + createSchedBarrier(rewriter, loc, + mlir::amdgpu::sched_barrier_opt_enum::none); rewriter.eraseOp(instructionSchedHint); - return mlir::success(); + return success(); } private: + int32_t numStages; SchedulingType schedulingType; }; -struct LowerInstructionSchedHints - : public triton::impl::LowerInstructionSchedHintsBase< - LowerInstructionSchedHints> { +struct TritonAMDGPULowerInstructionSchedHints + : public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase< + TritonAMDGPULowerInstructionSchedHints> { - explicit LowerInstructionSchedHints(std::string variant) { + explicit TritonAMDGPULowerInstructionSchedHints(int32_t numStages, + std::string variant) { + this->numStages = numStages; this->variant = variant; } @@ -161,29 +383,40 @@ struct LowerInstructionSchedHints ConversionTarget target(*ctx); target.addLegalDialect(); target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); RewritePatternSet patterns(ctx); - patterns.add(ctx, this->variant); + + patterns.add(ctx, this->numStages, + + this->variant); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); } } }; -struct InsertInstructionSchedHints - : public triton::impl::InsertInstructionSchedHintsBase< - InsertInstructionSchedHints> { +struct TritonAMDGPUInsertInstructionSchedHints + : public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase< + TritonAMDGPUInsertInstructionSchedHints> { + void runOnOperation() override { MLIRContext *ctx = &getContext(); ModuleOp mod = getOperation(); - mod->walk([ctx](triton::DotOp dot) { - if (dyn_cast(dot->getParentOp())) { - mlir::OpBuilder rewriter(ctx); - rewriter.setInsertionPointAfter(dot); - rewriter.create(dot->getLoc()); + mod.walk([this, ctx](scf::ForOp forOp) { + // Note, instruction schedule barriers are inserted only in the case of + // a single `tt.dot` op in a `scf::ForOp` scope in the current + // implementation. 
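// [Editorial illustration] For a qualifying loop, the builder below places
// the hint op directly after the dot, roughly (types elided):
//
//   scf.for ... {
//     %c = tt.dot %a, %b, %acc ...
//     amdgpu.instruction_sched_hint  // counters default to <0, none> here and
//                                    // are filled in during lowering to LLVM
//     scf.yield ...
//   }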
+ if (auto dotOp = getSingleDotOpIfExists(forOp)) { + OpBuilder rewriter(ctx); + rewriter.setInsertionPointAfter(dotOp); + rewriter.create(dotOp->getLoc()); } }); } @@ -192,12 +425,14 @@ struct InsertInstructionSchedHints namespace mlir::triton { std::unique_ptr> -createLowerInstructionSchedHintsPass(std::string variant) { - return std::make_unique(variant); +createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages, + std::string variant) { + return std::make_unique(numStages, + variant); } std::unique_ptr> -createInsertInstructionSchedHintsPass() { - return std::make_unique(); +createTritonAMDGPUInsertInstructionSchedHintsPass() { + return std::make_unique(); } } // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h new file mode 100644 index 000000000000..45985fe808f2 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H +#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_SCHED_INSTRUCTIONS_H + +#include "mlir/IR/Types.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// The following functions are used to collect and set side-channel information +// during to LLVM conversion/lowering to facilitate instruction scheduling +// controls. +namespace mlir::triton { +void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n, + unsigned k, Type elementType); + +template +void setNumGeneratedGlobalLoads(LoadOpType op, size_t globalLoadsCount, + Type type); +void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount, + Type type); +void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount, + Type type); +triton::DotOp getSingleDotOpIfExists(scf::ForOp forOp); +} // namespace mlir::triton + +#endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index d227bb6c6a4b..f99cd50b0d27 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,6 +1,7 @@ #include "TritonAMDGPUToLLVM/Passes.h" #include "PatternTritonGPUOpToLLVM.h" +#include "SchedInstructions.h" #include "TargetInfo.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -20,6 +21,7 @@ #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -72,8 +74,9 @@ struct ConvertTritonAMDGPUToLLVM } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override { @@ -193,8 +196,12 @@ struct ConvertTritonAMDGPUToLLVM commonBenefit); populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns, commonBenefit); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, commonBenefit); + + mlir::triton::BackendCallbacks callbacks; + callbacks.localStoreOpConversion = storeOpConversionCallback; + + 
mlir::triton::populateMemoryOpToLLVMPattern( + typeConverter, targetInfo, patterns, commonBenefit, callbacks); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, commonBenefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index f1d922041fcf..e66a2feb57fe 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -177,8 +177,21 @@ struct ConvertTritonLoadToBufferLoad Value maybeMask{}; if (op.getMask() && !isZeroConst(op.getMask())) maybeMask = op.getMask(); - rewriter.replaceOpWithNewOp( - op, op.getType(), basePtr, tensorOffset, maybeMask, maybeOther); + + auto bufferLoadOp = rewriter.create( + op->getLoc(), op.getType(), basePtr, tensorOffset, maybeMask, + maybeOther); + + // Propagate `OpIdxAttr` if the currently processed `tt.LoadOp` was + // labeled it. The attribute needs to be preserved for custom instruction + // scheduling. + if (auto opIdxAttr = op->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + bufferLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), + opIdxAttr); + } + rewriter.replaceOp(op, bufferLoadOp); + return success(); } LDBG("Failed to convert: " << op); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index deb566a8b1b5..3b4935026c3f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -1,6 +1,8 @@ #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -168,6 +170,15 @@ void StreamPipeliner::createStreamCopy( result = select->getResults(); } + // If the currently processed `LoadOp` is labeled with an index regarding + // to which `DotOp` operand the corresponding data belongs to, then label the + // expanded `LocalStoreOp` with the same index. This is required for + // instruction scheduling hints to correctly count the emitted `ds_write` + // instructions for each GEMM tile. + if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) { + storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); + } + loadOp->replaceAllUsesWith(result); // Prefetch load ahead of the dot stage if is used by the dot. @@ -685,6 +696,41 @@ bool StreamPipeliner::pipelineLoop() { } namespace { +// Go through a single use chain to get the result of the target op after all +// unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. 
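// [Editorial illustration] For example, starting from the first tt.dot operand
// below, the helper steps back through the unary convert_layout and returns
// the defining tt.load (assuming the chain contains only single-operand ops):
//
//   %0 = tt.load %a_ptr ...                  // <- op returned by the walk
//   %1 = triton_gpu.convert_layout %0 ...    //    unary producer, skipped
//   %2 = tt.dot %1, %b, %acc ...             //    walk starts at operand %1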
+template Operation *passPrevUnaryOps(Value value) { + auto getNextUnaryOps = [](Value value) -> Operation * { + if (auto defOp = value.getDefiningOp()) { + if ((defOp->getNumOperands() == 1) || llvm::dyn_cast(defOp)) + return defOp; + } + return nullptr; + }; + + auto unaryOp = getNextUnaryOps(value); + while (unaryOp) { + if (llvm::dyn_cast(unaryOp)) + return unaryOp; + unaryOp = getNextUnaryOps(unaryOp->getOperand(0)); + } + return nullptr; +} + +// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand +// index. Note, this is a part of the instruction scheduling routine. Currently, +// we support `forOp`s which contain only a single `tt.DotOp` in the bodies. +void labelLoadOpsForTritonDot(scf::ForOp forOp) { + mlir::MLIRContext *ctx = forOp->getContext(); + if (auto dotOp = triton::getSingleDotOpIfExists(forOp)) { + for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) { + if (auto loadOp = passPrevUnaryOps(dotOperand)) { + auto opIdxAttr = triton::amdgpu::OpIdxAttr::get(ctx, opIdx); + loadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), opIdxAttr); + } + } + } +} + struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { PipelinePass() = default; PipelinePass(int32_t numStages) { this->numStages = numStages; } @@ -692,6 +738,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { void runOnOperation() override { SmallVector loops; getOperation()->walk([&](scf::ForOp forOp) { + labelLoadOpsForTritonDot(forOp); // Bail out for loops with num_stage <= 1. if (getNumStagesOrDefault(forOp) > 1) loops.push_back(forOp); diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index f97676aafe36..d30be6959839 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -45,11 +45,12 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { pm.addPass(createConvertBuiltinFuncToLLVMPass(ftz)); }); m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) { - pm.addPass(createInsertInstructionSchedHintsPass()); + pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass()); }); m.def("lower_instruction_sched_hints", - [](mlir::PassManager &pm, std::string variant) { - pm.addPass(createLowerInstructionSchedHintsPass(variant)); + [](mlir::PassManager &pm, int32_t numStages, std::string variant) { + pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(numStages, + variant)); }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, const std::string &arch) {
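// [Editorial note, illustrative] With the renamed passes and the new
// num_stages option, the lowering pass can be exercised standalone roughly as
// in the test RUN lines above, for example:
//
//   triton-opt input.mlir \
//     -triton-amdgpu-insert-instruction-sched-hints \
//     -triton-amdgpu-lower-insert-instruction-sched-hints='num_stages=2 variant=ck_v3'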