From 7e5cc1f7925f53e86b3a7cc5f56d04d635b41775 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak
Date: Fri, 28 Jun 2024 19:23:38 +0000
Subject: [PATCH] [codegen] Add max_workgroup_counts to TargetWgpAttr

This commit adds a max_workgroup_counts field to the workgroup
processor information attribute and sets its value for each of the
known targets. Some of these limits may be underestimates, as I was not
able to locate authoritative information for every target.

This field is added so that we can annotate workgroup.id and
workgroup.count ops with upper bounds, enabling range inference and
strength reduction.

Note that in some cases (for instance, AMD) we give a
max_workgroup_counts value lower than what the hardware actually
supports, because a grid dimension greater than int32_max would be
sign-extended to a negative number when widened to the 64-bit `index`
type.
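For illustration, a workgroup processor attribute carrying the new
field looks like the following (a sketch adapted from the updated
target_attrs.mlir test in this patch; values are target-dependent and
the mma list is elided):

  wgp = #iree_gpu.target_wgp<
    compute = fp16|fp32|int8, storage = b16|b32,
    subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [],
    subgroup_size_choices = [32, 64],
    max_workgroup_sizes = [1024, 1024, 1024],
    max_thread_count_per_workgroup = 1024,
    max_workgroup_memory_bytes = 65536,
    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>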
(This PR is split out of #17707)

Signed-off-by: Krzysztof Drewniak
---
 .../target/MetalSPIRV/test/smoketest.mlir          |   3 +-
 .../ROCM/test/target_device_features.mlir          |   3 +-
 .../target/VulkanSPIRV/test/smoketest.mlir         |   3 +-
 .../target/WebGPUSPIRV/test/smoketest.mlir         |   3 +-
 .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td         |   2 +
 .../Dialect/GPU/IR/test/target_attrs.mlir          |  12 +-
 .../Dialect/GPU/TargetUtils/KnownTargets.cpp       | 147 ++++++++++++------
 .../test/ROCDL/config_vector_distribute.mlir       |   3 +-
 .../SPIRV/test/convert_gpu_target.mlir             |   3 +-
 .../Common/test/pad_to_intrinsics_mfma.mlir        |   3 +-
 .../vulkan/shaders/example.mlir                    |   3 +-
 .../vulkan/shaders/example_inline.mlir             |   3 +-
 .../vulkan/shaders/example_transform.mlir          |   3 +-
 .../shaders/example_transform_spec.mlir            |   3 +-
 samples/transform_dialect/example_module.mlir      |   2 +-
 15 files changed, 136 insertions(+), 60 deletions(-)

diff --git a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
index 720e00b2f8353..90b9a31b9fcc2 100644
--- a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
     #hal.executable.target<"metal-spirv", "metal-msl-fb", {
       iree.gpu.target = #iree_gpu.target>
-        max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+        max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+        max_workgroup_counts = [65535, 65535, 65535]>>
     }>
   ]>
 ]
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index c801f9bb513e2..9f3b46f28f090 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -8,7 +8,8 @@
 // GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
 // GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
+// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
 // GFX942-SAME: chip = >
 
 // GFX940: target = #iree_gpu.target>
diff --git a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
--- a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
-        max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+        max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+        max_workgroup_counts = [65535, 65535, 65535]>>
     }>
   ]>
 ]
diff --git a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
index 31f361b1ab5fe..9b2a6424c22d9 100644
--- a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
@@ -7,7 +7,8 @@ module attributes {
     #hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
      iree.gpu.target = #iree_gpu.target>
-        max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+        max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+        max_workgroup_counts = [65535, 65535, 65535]>>
     }>
   ]>
 ]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index af69895937c1e..1ebcfecb9bfc2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -289,6 +289,8 @@ def IREEGPU_TargetWgpAttr : AttrDef {
     "uint32_t":$max_thread_count_per_workgroup,
     // The maximal number of shared memory bytes we can allocate per workgroup.
     "uint32_t":$max_workgroup_memory_bytes,
+    // The maximum number of workgroups per X/Y/Z dimension in a dispatch.
+    "DenseI32ArrayAttr":$max_workgroup_counts,
 
     // An optional extra dict
     // This field allows to inject more features/limits not supported in the
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
index baa47b2be12ed..e8611005d71fb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir
@@ -11,7 +11,8 @@ func.func @test_target_wgp() attributes {
   // CHECK-SAME: subgroup_size_choices = [32, 64],
   // CHECK-SAME: max_workgroup_sizes = [1024, 1024, 1024],
   // CHECK-SAME: max_thread_count_per_workgroup = 1024,
-  // CHECK-SAME: max_workgroup_memory_bytes = 65536>
+  // CHECK-SAME: max_workgroup_memory_bytes = 65536,
+  // CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>
   wgp = #iree_gpu.target_wgp<
     compute = fp16|fp32|int8, storage = b16|b32,
     subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
@@ -19,7 +20,8 @@ func.func @test_target_wgp() attributes {
     subgroup_size_choices = [32, 64],
     max_workgroup_sizes = [1024, 1024, 1024],
     max_thread_count_per_workgroup = 1024,
-    max_workgroup_memory_bytes = 65536
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]
   >
 } { return }
@@ -37,7 +39,8 @@ func.func @test_target_wgp_none() attributes {
     subgroup_size_choices = [32],
     max_workgroup_sizes = [1024, 1024, 1024],
     max_thread_count_per_workgroup = 1024,
-    max_workgroup_memory_bytes = 65536
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]
   >
 } { return }
@@ -67,7 +70,8 @@ func.func @test_target() attributes {
     subgroup_size_choices = [32, 64],
     max_workgroup_sizes = [1024, 1024, 1024],
     max_thread_count_per_workgroup = 1024,
-    max_workgroup_memory_bytes = 65536>,
+    max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
     chip = >
 } { return }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index 993963c739df5..30c35a053fbfa 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -45,6 +45,7 @@ struct WgpDetails {
   std::array<int32_t, 3> maxWorkgroupSizes;
   uint32_t maxThreadSize;
   uint32_t maxWorkgroupMemoryBytes;
+  std::array<int32_t, 3> maxWorkgroupCounts;
 };
 
 // Chip level feature/limit details
@@ -106,7 +107,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
       MMAOpsArrayAttr::get(context, mmaAttrs),
       DenseI32ArrayAttr::get(context, subgroupSizes),
       DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes),
-      wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DictionaryAttr{});
+      wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes,
+      DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts),
+      DictionaryAttr{});
 
   TargetChipAttr targetChip;
   if (details.chip)
@@ -118,6 +121,10 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
 //===----------------------------------------------------------------------===//
 // Known AMD target details
+//
+// Note: the max workgroup count is given as signed int32 max because MLIR's
+// `index` is signed and the workgroup ID is sign-extended, not zero-extended,
+// to 64 bits.
 //===----------------------------------------------------------------------===//
 
@@ -127,11 +134,17 @@ const WgpDetails *getCDNA3WgpDetails() {
       MMAIntrinsic::MFMA_I8_16x16x32_I32,
       MMAIntrinsic::MFMA_I8_32x32x16_I32,
   };
-  static const WgpDetails cdna3Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna3MMAOps), cdna3MMAOps,
-      {64, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails cdna3Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna3MMAOps),
+                                      cdna3MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna3Wgp;
 }
@@ -140,11 +153,17 @@ const WgpDetails *getCDNA2WgpDetails() {
       MMAIntrinsic::MFMA_F16_16x16x16_F32,
       MMAIntrinsic::MFMA_F16_32x32x8_F32,
   };
-  static const WgpDetails cdna2Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna2MMAOps), cdna2MMAOps,
-      {64, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails cdna2Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna2MMAOps),
+                                      cdna2MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna2Wgp;
 }
@@ -153,11 +172,17 @@ const WgpDetails *getCDNA1WgpDetails() {
       MMAIntrinsic::MFMA_F16_16x16x16_F32,
       MMAIntrinsic::MFMA_F16_32x32x8_F32,
   };
-  static const WgpDetails cdna1Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna1MMAOps), cdna1MMAOps,
-      {64, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails cdna1Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna1MMAOps),
+                                      cdna1MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna1Wgp;
 }
@@ -166,27 +191,39 @@ const WgpDetails *getRDNA3WgpDetails() {
      MMAIntrinsic::WMMA_F16_16x16x16_F32,
      MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails rdna3Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(rdna3MMAOps), rdna3MMAOps,
-      {32, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails rdna3Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(rdna3MMAOps),
+                                      rdna3MMAOps,
+                                      {32, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna3Wgp;
 }
 
 const WgpDetails *getRDNA2WgpDetails() {
   static const WgpDetails rdna2Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps,
-      /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
-      1024, 64 * 1024};
+      allComputeBits, allStorageBits,
+      allSubgroupOps, allDotProductOps,
+      /*mmaCount=*/0,
+      /*mmaOps=*/nullptr, {32, 64},
+      {1024, 1024, 1024}, 1024,
+      64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna2Wgp;
 }
 
 const WgpDetails *getRDNA1WgpDetails() {
   static const WgpDetails rdna1Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
-      /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
-      1024, 64 * 1024};
+      allComputeBits, allStorageBits,
+      allSubgroupOps, DotProductOps::None,
+      /*mmaCount=*/0,
+      /*mmaOps=*/nullptr, {32, 64},
+      {1024, 1024, 1024}, 1024,
+      64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna1Wgp;
 }
@@ -279,7 +316,9 @@ std::optional<TargetDetails> getAppleTargetDetails() {
   static const WgpDetails wgp = {
       computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
       /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 32},
-      {1024, 1024, 1024}, 1024, 32 * 1024};
+      {1024, 1024, 1024}, 1024, 32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
 
   return TargetDetails{&wgp, nullptr};
@@ -300,7 +339,9 @@ const WgpDetails *getValhallWgpDetails() {
   static const WgpDetails valhallWgp = {
       computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
       /*mmaCount=*/0, /*mmaOps=*/nullptr, {16, 16}, {512, 512, 512},
-      512, 32 * 1024};
+      512, 32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &valhallWgp;
 }
@@ -356,11 +397,17 @@ const WgpDetails *getAmpereWgpDetails() {
      MMAIntrinsic::WMMA_F16_16x16x16_F32,
      MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails ampereWgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
-      {32, 32}, {1024, 1024, 1024}, 1024,
-      163 * 1024};
+  static const WgpDetails ampereWgp = {allComputeBits,
+                                       allStorageBits,
+                                       allSubgroupOps,
+                                       allDotProductOps,
+                                       ARRAY_SIZE(mmaOps),
+                                       mmaOps,
+                                       {32, 32},
+                                       {1024, 1024, 1024},
+                                       1024,
+                                       163 * 1024,
+                                       {0x7fffffff, 0xffff, 0xffff}};
   return &ampereWgp;
 }
@@ -369,11 +416,17 @@ const WgpDetails *getTuringWgpDetails() {
      MMAIntrinsic::WMMA_F16_16x16x16_F32,
      MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails turingWgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
-      {32, 32}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails turingWgp = {allComputeBits,
+                                       allStorageBits,
+                                       allSubgroupOps,
+                                       allDotProductOps,
+                                       ARRAY_SIZE(mmaOps),
+                                       mmaOps,
+                                       {32, 32},
+                                       {1024, 1024, 1024},
+                                       1024,
+                                       64 * 1024,
+                                       {0x7fffffff, 0xffff, 0xffff}};
   return &turingWgp;
 }
@@ -386,7 +439,8 @@ const WgpDetails *getVoltaWgpDetails() {
   static const WgpDetails voltaWgp = {
       allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
       ARRAY_SIZE(mmaOps), mmaOps, {32, 32}, {1024, 1024, 1024},
-      1024, 96 * 1024};
+      1024, 96 * 1024,
+      {0x7fffffff, 0xffff, 0xffff}};
   // clang-format on
   return &voltaWgp;
 }
@@ -396,7 +450,8 @@ const WgpDetails *getPascalWgpDetails() {
   static const WgpDetails pascalWgp = {
       allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
       0, nullptr, // Pascal does not have tensor core support.
-      {32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024};
+      {32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024,
+      {0x7fffffff, 0xffff, 0xffff}};
   // clang-format on
   return &pascalWgp;
 }
@@ -477,7 +532,9 @@ const WgpDetails *getAdrenoWgpDetails() {
       computeBitwdiths, storageBitwidths, allSubgroupOps,
       allDotProductOps, /*mmaCount=*/0, /*mmaOps=*/nullptr,
       {64, 64}, {1024, 1024, 1024}, 1024,
-      32 * 1024};
+      32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &adrenoWgp;
 }
@@ -543,7 +600,8 @@ const WgpDetails *getAndroidBaseline2022WgpDetails() {
       computeBitwdiths, storageBitwidths, SubgroupOps::None,
       DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
       {64, 64}, {128, 128, 64}, 128,
-      16 * 1024};
+      16 * 1024,
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &androidWgp;
 }
@@ -641,7 +699,8 @@ TargetAttr getWebGPUTargetDetails(MLIRContext *context) {
       computeBitwdiths, storageBitwidths, SubgroupOps::None,
       DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
       {32, 32}, {128, 128, 64}, 128,
-      16 * 1024};
+      16 * 1024,
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
 
   return createTargetAttr(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index b40de6ba70e5e..a13a0c56e2356 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -76,7 +76,8 @@ module {
     subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
 #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target}>
 module {
   func.func @matmul_256x256x256() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
index b1f809293db2a..c03730532c006 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
@@ -4,7 +4,8 @@ hal.executable @dispatch {
   hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
     iree.gpu.target = #iree_gpu.target, ],
-    subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>}>) {
+    subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>) {
   hal.executable.export public @dispatch ordinal(0) layout(#hal.pipeline.layout]>]>) {
   ^bb0(%arg0: !hal.device):
     %x, %y, %z = flow.dispatch.workgroup_count_from_slice
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
index 5761741f37870..88dfda915da3f 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
@@ -69,7 +69,8 @@ func.func @main1(%arg0: tensor<2x130x130x320xf16>, %arg1: tensor<3x3x320x4xf16>,
     subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
 #rocm_executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target, ukernels = "none"}>
 
 // CHECK-LABEL: func.func @main2(
diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir
index ef10fb7b7dbda..72c4771fec34e 100644
--- a/samples/custom_dispatch/vulkan/shaders/example.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -19,7 +19,8 @@
         compute = fp32|int32, storage = b32, subgroup = none, dot = none,
         mma = [], subgroup_size_choices = [64, 64], max_workgroup_sizes = [128, 128, 64],
         max_thread_count_per_workgroup = 128,
-        max_workgroup_memory_bytes = 16384>
+        max_workgroup_memory_bytes = 16384,
+        max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 36912bb35df99..96ca9857d32b5 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -19,7 +19,8 @@
        compute = fp32|int32, storage = b32, subgroup = none, dot = none,
        mma = [], subgroup_size_choices = [64, 64], max_workgroup_sizes = [128, 128, 64],
        max_thread_count_per_workgroup = 128,
-       max_workgroup_memory_bytes = 16384>
+       max_workgroup_memory_bytes = 16384,
+       max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
index b4885a03081d7..3951f5d9c3638 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -23,7 +23,8 @@
        compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none,
        mma = [], subgroup_size_choices = [64, 64], max_workgroup_sizes = [128, 128, 64],
        max_thread_count_per_workgroup = 128,
-       max_workgroup_memory_bytes = 16384>
+       max_workgroup_memory_bytes = 16384,
+       max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
index 5bcdafe7fba1e..8e232069fa153 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
@@ -12,7 +12,8 @@
        compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none,
        mma = [], subgroup_size_choices = [64, 64], max_workgroup_sizes = [128, 128, 64],
        max_thread_count_per_workgroup = 128,
-       max_workgroup_memory_bytes = 16384>
+       max_workgroup_memory_bytes = 16384,
+       max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 585bb25915345..c0ed4848e18c6 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -27,7 +27,7 @@
 #target = #iree_gpu.target>
-  max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
+  max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384, max_workgroup_counts = [65535, 65535, 65535]>>
 
 module attributes {
   hal.device.targets = [