diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index c9e23636142f..128ffa9fc46e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -73,6 +73,7 @@ iree_compiler_cc_library( "GPUPatterns.cpp", "GPUPipelining.cpp", "GPUPromoteMatmulOperands.cpp", + "GPUPropagateDispatchSizeBounds.cpp", "GPUReduceBankConflicts.cpp", "GPUReuseSharedMemoryAllocs.cpp", "GPUTensorAlloc.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index 2aeb9add5f0c..97d324042e2c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -71,6 +71,7 @@ iree_cc_library( "GPUPatterns.cpp" "GPUPipelining.cpp" "GPUPromoteMatmulOperands.cpp" + "GPUPropagateDispatchSizeBounds.cpp" "GPUReduceBankConflicts.cpp" "GPUReuseSharedMemoryAllocs.cpp" "GPUTensorAlloc.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp new file mode 100644 index 000000000000..43aa70be6919 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPropagateDispatchSizeBounds.cpp @@ -0,0 +1,103 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/GPU/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_GPUPROPAGATEDISPATCHSIZEBOUNDSPASS +#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" + +namespace { + +static void applyBounds(FunctionOpInterface funcOp, + ArrayRef workgroupSizes, + ArrayRef workgroupCounts) { + Builder b(funcOp->getContext()); + funcOp->walk([&](Operation *op) { + TypeSwitch(op) + .Case([&](gpu::ThreadIdOp tidOp) { + tidOp.setUpperBoundAttr(b.getIndexAttr( + workgroupSizes[static_cast(tidOp.getDimension())])); + }) + .Case([&](IREE::HAL::InterfaceWorkgroupSizeOp wgSizeOp) { + wgSizeOp.setUpperBoundAttr(b.getIndexAttr( + workgroupSizes[wgSizeOp.getDimension().getZExtValue()])); + }) + .Case([&](IREE::HAL::InterfaceWorkgroupIDOp wgIdOp) { + wgIdOp.setUpperBoundAttr(b.getIndexAttr( + workgroupCounts[wgIdOp.getDimension().getZExtValue()])); + }) + .Case([&](IREE::HAL::InterfaceWorkgroupCountOp wgCountOp) { + wgCountOp.setUpperBoundAttr(b.getIndexAttr( + workgroupCounts[wgCountOp.getDimension().getZExtValue()])); + }) + .Default([](Operation *) {}); + }); +} + +struct GPUPropagateDispatchSizeBoundsPass final + : impl::GPUPropagateDispatchSizeBoundsPassBase< + GPUPropagateDispatchSizeBoundsPass> { + using Base::Base; + + void runOnOperation() override { + FunctionOpInterface funcOp = getOperation(); + IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); + if (!target) { + funcOp.emitWarning("no known target attribute late in GPU codegen"); + return; + } + SmallVector workgroupSizes( + target.getWgp().getMaxWorkgroupSizes().asArrayRef()); + SmallVector workgroupCounts( + target.getWgp().getMaxWorkgroupCounts().asArrayRef()); + + std::optional> staticWorkgroupSize = + getWorkgroupSize(funcOp); + + // Late in codegen, we've reconciled the workgroup size onto the export op. + if (std::optional exportOp = + getEntryPoint(funcOp)) { + if (std::optional exportWorkgroupSize = + exportOp->getWorkgroupSize()) { + staticWorkgroupSize = + llvm::map_to_vector(exportWorkgroupSize->getAsRange(), + [](IntegerAttr a) { return a.getInt(); }); + } + } + + if (staticWorkgroupSize) { + // Target info with no workgroup sizes gives a 0-length array, hence no + // zip_equal. + for (auto [size, staticSize] : + llvm::zip(workgroupSizes, *staticWorkgroupSize)) { + size = staticSize; + } + } + SmallVector staticWorkgroupCounts = getStaticNumWorkgroups(funcOp); + assert(staticWorkgroupCounts.size() <= 3 && + "workgroup counts are 3D at most"); + for (auto [count, staticCount] : + llvm::zip(workgroupCounts, staticWorkgroupCounts)) { + if (staticCount != ShapedType::kDynamic) { + count = staticCount; + } + } + + applyBounds(funcOp, workgroupSizes, workgroupCounts); + } +}; +} // namespace + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td index 789130940477..b3fdd50d4d46 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td @@ -178,6 +178,11 @@ def GPUPromoteMatmulOperandsPass : ]; } +def GPUPropagateDispatchSizeBoundsPass : + InterfacePass<"iree-codegen-gpu-propagate-dispatch-size-bounds", "mlir::FunctionOpInterface"> { + let summary = "Pass to annotate workitem and workgroup IDs with known bounds"; +} + def GPUReduceBankConflictsPass : InterfacePass<"iree-codegen-gpu-reduce-bank-conflicts", "mlir::FunctionOpInterface"> { let summary = "Pass to try to reduce the number of bank conflicts by padding memref.alloc ops."; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 41afbb6559f3..dc8e6a181ccf 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -41,6 +41,7 @@ iree_lit_test_suite( "gpu_pad_operands.mlir", "gpu_pipeline.mlir", "gpu_promote_matmul_operands.mlir", + "gpu_propagate_dispatch_size_bounds.mlir", "gpu_reorder_workgroups_static.mlir", "gpu_reorder_workgroups.mlir", "gpu_reuse_shared_memory_allocs.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index ad86649ada78..4dc0f289d3d5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -37,6 +37,7 @@ iree_lit_test_suite( "gpu_pad_operands.mlir" "gpu_pipeline.mlir" "gpu_promote_matmul_operands.mlir" + "gpu_propagate_dispatch_size_bounds.mlir" "gpu_reorder_workgroups.mlir" "gpu_reorder_workgroups_static.mlir" "gpu_reuse_shared_memory_allocs.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir new file mode 100644 index 000000000000..f26f2c5dfe52 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_propagate_dispatch_size_bounds.mlir @@ -0,0 +1,122 @@ +// RUN: iree-opt %s --split-input-file \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-propagate-dispatch-size-bounds)))))" \ +// RUN: | FileCheck %s + +// Note: not the real target definition, missing types +#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target>}> +#pipeline_layout = #hal.pipeline.layout]> + +hal.executable private @static { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target) { + hal.executable.export public @static ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} { + ^bb0(%arg0: !hal.device): + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + hal.return %c32, %c8, %c1 : index, index, index + } + builtin.module { +// CHECK-LABEL: func.func @static + func.func @static() { +// CHECK: gpu.thread_id x upper_bound 64 +// CHECK: gpu.thread_id y upper_bound 2 +// CHECK: gpu.thread_id z upper_bound 1 + %thread_id_x = gpu.thread_id x + %thread_id_y = gpu.thread_id y + %thread_id_z = gpu.thread_id z + +// CHECK: hal.interface.workgroup.size[0] upper_bound 64 +// CHECK: hal.interface.workgroup.size[1] upper_bound 2 +// CHECK: hal.interface.workgroup.size[2] upper_bound 1 + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + +// CHECK: hal.interface.workgroup.id[0] upper_bound 32 +// CHECK: hal.interface.workgroup.id[1] upper_bound 8 +// CHECK: hal.interface.workgroup.id[2] upper_bound 1 + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + +// CHECK: hal.interface.workgroup.count[0] upper_bound 32 +// CHECK: hal.interface.workgroup.count[1] upper_bound 8 +// CHECK: hal.interface.workgroup.count[2] upper_bound 1 + %workgroup_conut_x = hal.interface.workgroup.count[0] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + + return + } + } + } +} + +// ----- + +#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", + {iree.gpu.target = #iree_gpu.target>}> +#pipeline_layout = #hal.pipeline.layout]> + +hal.executable private @dynamic { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target) { + hal.executable.export public @dynamic ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index): + %count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1] + %count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2] + %count_z = arith.constant 1 : index + hal.return %count_x, %count_y, %count_z : index, index, index + } + builtin.module { + func.func @dynamic() { +// CHECK: gpu.thread_id x upper_bound 1024 +// CHECK: gpu.thread_id y upper_bound 1024 +// CHECK: gpu.thread_id z upper_bound 1024 + %thread_id_x = gpu.thread_id x + %thread_id_y = gpu.thread_id y + %thread_id_z = gpu.thread_id z + +// CHECK: hal.interface.workgroup.size[0] upper_bound 1024 +// CHECK: hal.interface.workgroup.size[1] upper_bound 1024 +// CHECK: hal.interface.workgroup.size[2] upper_bound 1024 + %workgroup_size_x = hal.interface.workgroup.size[0] : index + %workgroup_size_y = hal.interface.workgroup.size[1] : index + %workgroup_size_z = hal.interface.workgroup.size[2] : index + +// CHECK: hal.interface.workgroup.id[0] upper_bound 2147483647 +// CHECK: hal.interface.workgroup.id[1] upper_bound 2147483647 +// CHECK: hal.interface.workgroup.id[2] upper_bound 1 + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %workgroup_id_y = hal.interface.workgroup.id[1] : index + %workgroup_id_z = hal.interface.workgroup.id[2] : index + +// CHECK: hal.interface.workgroup.count[0] upper_bound 2147483647 +// CHECK: hal.interface.workgroup.count[1] upper_bound 2147483647 +// CHECK: hal.interface.workgroup.count[2] upper_bound 1 + %workgroup_conut_x = hal.interface.workgroup.count[0] : index + %workgroup_count_y = hal.interface.workgroup.count[1] : index + %workgroup_count_z = hal.interface.workgroup.count[2] : index + + return + } + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index c056d44538bb..1441f959b0bb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -505,7 +505,10 @@ struct HALInterfaceWorkgroupOpsConverter final int32_t index = static_cast(op.getDimension().getSExtValue()); std::array dimAttr{gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}; - rewriter.replaceOpWithNewOp(op, op.getType(), dimAttr[index]); + NewOpTy newOp = + rewriter.replaceOpWithNewOp(op, op.getType(), dimAttr[index]); + if (IntegerAttr bound = op.getUpperBoundAttr()) + newOp.setUpperBoundAttr(bound); return success(); } }; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 53e49efbf66a..f8ebe1cc0069 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1067,7 +1067,13 @@ addLowerAndOptimizeAddressComputationPasses(FunctionLikeNest &funcPassManager) { .addPass(createCSEPass) // Hoist the resulting decompositions. .addPass(createIREELoopInvariantCodeMotionPass) - .addPass(createLowerAffinePass); + .addPass(affine::createAffineExpandIndexOpsPass) + .addPass(createLowerAffinePass) + .addPass(IREE::Util::createOptimizeIntArithmeticPass) + // Do another round of LICM now that we've lowered and optimized + // arithmetic + .addPass(createCSEPass) + .addPass(createIREELoopInvariantCodeMotionPass); } static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, @@ -1103,7 +1109,9 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, FunctionLikeNest funcPassManager(modulePassManager); funcPassManager.addPass(createFoldTensorExtractOpPass) .addPass(createLLVMGPUVectorLoweringPass) - .addPass(createExpandGPUOpsPass); + .addPass(createExpandGPUOpsPass) + // Expose workitem and workgroup counts to range inference later. + .addPass(createGPUPropagateDispatchSizeBoundsPass); // This pass needs to run before SCF -> CF. addLowerAndOptimizeAddressComputationPasses(funcPassManager); @@ -1130,9 +1138,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, .addPass(memref::createExpandStridedMetadataPass) .addPass(createEmulateNarrowTypePass) .addPass(affine::createAffineExpandIndexOpsPass) - .addPass(createLowerAffinePass) - .addPass(createCanonicalizerPass) - .addPass(createCSEPass); + .addPass(createLowerAffinePass); // Strip out the debug info for the kernel. modulePassManager.addPass(createStripDebugInfoPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir index 6c1c5e117016..ba6b5da7f1fa 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir @@ -40,7 +40,7 @@ // CHECK-DAG: %[[C8192:.*]] = llvm.mlir.constant(8192 : index) : i64 // // Match the interesting special registers. -// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32 +// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y range : i32 // CHECK-DAG: %[[TID_Y_EXT:.*]] = llvm.sext %[[TID_Y]] : i32 to i64 // CHECK-DAG: %[[LANEID:.*]] = nvvm.read.ptx.sreg.laneid range : i32 // CHECK-DAG: %[[LANEID_EXT:.*]] = llvm.sext %[[LANEID]] : i32 to i64 diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index ea0aa9f45116..511dbe785300 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -227,9 +227,11 @@ static void addMemRefLoweringPasses(OpPassManager &modulePassManager) { /// Adds passes to perform the final SPIR-V conversion. static void addSPIRVLoweringPasses(OpPassManager &modulePassManager) { FunctionLikeNest(modulePassManager) + .addPass(createGPUPropagateDispatchSizeBoundsPass) .addPass(createCanonicalizerPass) .addPass(createCSEPass) .addPass(createLowerAffinePass) + .addPass(IREE::Util::createOptimizeIntArithmeticPass) // Lower ApplyScale before the i64 Emulation Pass so that new 64-bit ops // are also emulated if not supported by the target. diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel index d9d6a92ef71c..3f80245bfc8c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel @@ -35,6 +35,7 @@ iree_td_library( "//compiler/src/iree/compiler/Dialect/Util/IR:td_files", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:InferIntRangeInterfaceTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:ViewLikeInterfaceTdFiles", @@ -81,6 +82,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferIntRangeInterface", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Parser", diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 837855157e90..846bcf0d38a2 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -45,6 +45,7 @@ iree_cc_library( MLIRFuncDialect MLIRFunctionInterfaces MLIRIR + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRMemRefDialect MLIRParser diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 7210d402598d..cb5bb411810a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" namespace mlir::iree_compiler::IREE::HAL { @@ -2084,24 +2085,59 @@ static void getAsmResultNamesForInterfaceWorkgroupOp( } } +// Minimum is the smallest possible result we could get. It's 0 for ID-like +// operations and 1 for count-like ones. +static void setResultRangesForInterfaceWorkgroupOp( + Value result, const std::optional &upperBound, + SetIntRangeFn setResultRanges, int64_t minimum) { + unsigned width = ConstantIntRanges::getStorageBitwidth(result.getType()); + if (!upperBound.has_value()) { + setResultRanges( + result, ConstantIntRanges::fromSigned(APInt(width, minimum), + APInt::getSignedMaxValue(width))); + return; + } + setResultRanges(result, + ConstantIntRanges::fromUnsigned(APInt(width, minimum), + *upperBound + minimum - 1)); +} + void InterfaceWorkgroupIDOp::getAsmResultNames( function_ref setNameFn) { getAsmResultNamesForInterfaceWorkgroupOp("workgroup_id_", getDimension(), getResult(), setNameFn); } +void InterfaceWorkgroupIDOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(), + setResultRanges, /*minimum=*/0); +} + void InterfaceWorkgroupCountOp::getAsmResultNames( function_ref setNameFn) { getAsmResultNamesForInterfaceWorkgroupOp("workgroup_count_", getDimension(), getResult(), setNameFn); } +void InterfaceWorkgroupCountOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(), + setResultRanges, /*minimum=*/1); +} + void InterfaceWorkgroupSizeOp::getAsmResultNames( function_ref setNameFn) { getAsmResultNamesForInterfaceWorkgroupOp("workgroup_size_", getDimension(), getResult(), setNameFn); } +void InterfaceWorkgroupSizeOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(), + setResultRanges, /*minimum=*/1); +} + //===----------------------------------------------------------------------===// // hal.fence.* //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 16f1eadfdffd..d51e430b57c7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -3029,9 +3029,28 @@ def OpGroupInterfaceOps : OpDocGroup { let opDocGroup = OpGroupInterfaceOps in { -def HAL_InterfaceWorkgroupIDOp : HAL_PureOp<"interface.workgroup.id", [ - DeclareOpInterfaceMethods, -]> { +class HAL_InterfaceWorkgroupOp traits = []> + : HAL_PureOp, + DeclareOpInterfaceMethods])> { + let arguments = (ins + IndexAttr:$dimension, + OptionalAttr:$upper_bound); + let results = (outs HAL_Dim:$result); + + let builders = [ + OpBuilder<(ins "unsigned":$dim), + [{ + build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim), ::mlir::IntegerAttr{}); + }]>, + ]; + + let assemblyFormat = [{ + `[` $dimension `]` (`upper_bound` $upper_bound^)? attr-dict `:` type($result) + }]; +} + +def HAL_InterfaceWorkgroupIDOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.id"> { let summary = [{returns the index of the current workgroup in the grid}]; let description = [{ The global workgroup ID of the current tile in the range of @@ -3046,25 +3065,9 @@ def HAL_InterfaceWorkgroupIDOp : HAL_PureOp<"interface.workgroup.id", [ %z = hal.interface.workgroup.id[2] : index ``` }]; - - let arguments = (ins IndexAttr:$dimension); - let results = (outs HAL_Dim:$result); - - let builders = [ - OpBuilder<(ins "unsigned":$dim), - [{ - build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim)); - }]>, - ]; - - let assemblyFormat = [{ - `[` $dimension `]` attr-dict `:` type($result) - }]; } -def HAL_InterfaceWorkgroupCountOp : HAL_PureOp<"interface.workgroup.count", [ - DeclareOpInterfaceMethods, -]> { +def HAL_InterfaceWorkgroupCountOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.count"> { let summary = [{returns the total workgroup count of the grid}]; let description = [{ The total number of workgroups along each dimension in the dispatch grid. @@ -3081,24 +3084,9 @@ def HAL_InterfaceWorkgroupCountOp : HAL_PureOp<"interface.workgroup.count", [ ``` }]; - let arguments = (ins IndexAttr:$dimension); - let results = (outs HAL_Dim:$result); - - let builders = [ - OpBuilder<(ins "unsigned":$dim), - [{ - build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim)); - }]>, - ]; - - let assemblyFormat = [{ - `[` $dimension `]` attr-dict `:` type($result) - }]; } -def HAL_InterfaceWorkgroupSizeOp : HAL_PureOp<"interface.workgroup.size", [ - DeclareOpInterfaceMethods, -]> { +def HAL_InterfaceWorkgroupSizeOp : HAL_InterfaceWorkgroupOp<"interface.workgroup.size"> { let summary = [{returns the size of each workgroup in invocations}]; let description = [{ The number of local invocations within the current workgroup along each @@ -3114,20 +3102,6 @@ def HAL_InterfaceWorkgroupSizeOp : HAL_PureOp<"interface.workgroup.size", [ %z = hal.interface.workgroup.size[2] : index ``` }]; - - let arguments = (ins IndexAttr:$dimension); - let results = (outs HAL_Dim:$result); - - let builders = [ - OpBuilder<(ins "unsigned":$dim), - [{ - build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim)); - }]>, - ]; - - let assemblyFormat = [{ - `[` $dimension `]` attr-dict `:` type($result) - }]; } def HAL_InterfaceConstantLoadOp : HAL_PureOp<"interface.constant.load"> { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 9f3bee7d529a..d830c078b4bb 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -514,7 +514,8 @@ struct ConvertDispatchWorkgroupInfoPattern final LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getResult().getType(), - op.getDimensionAttr()); + op.getDimensionAttr(), + /*upper_bound=*/nullptr); return success(); } };