diff --git a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
index 720e00b2f8353..90b9a31b9fcc2 100644
--- a/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
       #hal.executable.target<"metal-spirv", "metal-msl-fb", {
         iree.gpu.target = #iree_gpu.target>
+          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+          max_workgroup_counts = [65535, 65535, 65535]>>
       }>
     ]>
   ]
diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir
index c801f9bb513e2..9f3b46f28f090 100644
--- a/compiler/plugins/target/ROCM/test/target_device_features.mlir
+++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -8,7 +8,8 @@
 // GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
 // GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
 // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
+// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
 // GFX942-SAME: chip = >
 // GFX940: target = #iree_gpu.target
diff --git a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
--- a/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
         iree.gpu.target = #iree_gpu.target>
+          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+          max_workgroup_counts = [65535, 65535, 65535]>>
       }>
     ]>
   ]
diff --git a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
index 31f361b1ab5fe..9b2a6424c22d9 100644
--- a/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
+++ b/compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
@@ -7,7 +7,8 @@ module attributes {
      #hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
         iree.gpu.target = #iree_gpu.target>
+          max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
+          max_workgroup_counts = [65535, 65535, 65535]>>
       }>
     ]>
   ]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index af69895937c1e..1ebcfecb9bfc2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -289,6 +289,8 @@ def IREEGPU_TargetWgpAttr : AttrDef<IREEGPU_Dialect, "TargetWgp"> {
     "uint32_t":$max_thread_count_per_workgroup,
     // The maximal number of shared memory bytes we can allocate per workgroup.
     "uint32_t":$max_workgroup_memory_bytes,
+    // The maximum number of workgroups per X/Y/Z dimension in a dispatch.
+    "DenseI32ArrayAttr":$max_workgroup_counts,
 
     // An optional extra dict
     // This field allows to inject more features/limits not supported in the
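For orientation: TableGen generates a getMaxWorkgroupCounts() accessor for the parameter added above, so a codegen pass can clamp its proposed grid against the target limit. A minimal sketch, not part of this patch; `target` and `workgroupCounts` stand in for values the surrounding pass is assumed to already have:

    // Illustrative only: clamp a proposed workgroup count per dimension to
    // the target's advertised maximum. `target` is an IREE::GPU::TargetAttr.
    ArrayRef<int32_t> maxCounts =
        target.getWgp().getMaxWorkgroupCounts().asArrayRef();
    for (int i = 0; i < 3; ++i)
      workgroupCounts[i] = std::min<int64_t>(workgroupCounts[i], maxCounts[i]);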
+ "DenseI32ArrayAttr":$max_workgroup_counts, // An optional extra dict // This field allows to inject more features/limits not supported in the diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir index baa47b2be12ed..e8611005d71fb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir @@ -11,7 +11,8 @@ func.func @test_target_wgp() attributes { // CHECK-SAME: subgroup_size_choices = [32, 64], // CHECK-SAME: max_workgroup_sizes = [1024, 1024, 1024], // CHECK-SAME: max_thread_count_per_workgroup = 1024, - // CHECK-SAME: max_workgroup_memory_bytes = 65536> + // CHECK-SAME: max_workgroup_memory_bytes = 65536, + // CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]> wgp = #iree_gpu.target_wgp< compute = fp16|fp32|int8, storage = b16|b32, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, @@ -19,7 +20,8 @@ func.func @test_target_wgp() attributes { subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, - max_workgroup_memory_bytes = 65536 + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647] > } { return } @@ -37,7 +39,8 @@ func.func @test_target_wgp_none() attributes { subgroup_size_choices = [32], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, - max_workgroup_memory_bytes = 65536 + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647] > } { return } @@ -67,7 +70,8 @@ func.func @test_target() attributes { subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, - max_workgroup_memory_bytes = 65536>, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647]>, chip = > } { return } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index 993963c739df5..30c35a053fbfa 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -45,6 +45,7 @@ struct WgpDetails { std::array maxWorkgroupSizes; uint32_t maxThreadSize; uint32_t maxWorkgroupMemoryBytes; + std::array maxWorkgroupCounts; }; // Chip level feature/limit details @@ -106,7 +107,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, MMAOpsArrayAttr::get(context, mmaAttrs), DenseI32ArrayAttr::get(context, subgroupSizes), DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes), - wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DictionaryAttr{}); + wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, + DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts), + DictionaryAttr{}); TargetChipAttr targetChip; if (details.chip) @@ -118,6 +121,10 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, //===----------------------------------------------------------------------===// // Known AMD target details +// +// Note: the max workgroup size is given as signed int32 max because MLIR's +// `index` is signed and the workgroup ID is sign-extended, not zero-extended, +// to 64-bits. 
 
 const WgpDetails *getCDNA3WgpDetails() {
@@ -127,11 +134,17 @@ const WgpDetails *getCDNA3WgpDetails() {
     MMAIntrinsic::MFMA_I8_16x16x32_I32,
     MMAIntrinsic::MFMA_I8_32x32x16_I32,
   };
-  static const WgpDetails cdna3Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna3MMAOps), cdna3MMAOps,
-      {64, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails cdna3Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna3MMAOps),
+                                      cdna3MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna3Wgp;
 }
@@ -140,11 +153,17 @@ const WgpDetails *getCDNA2WgpDetails() {
     MMAIntrinsic::MFMA_F16_16x16x16_F32,
     MMAIntrinsic::MFMA_F16_32x32x8_F32,
   };
-  static const WgpDetails cdna2Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna2MMAOps), cdna2MMAOps,
-      {64, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails cdna2Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna2MMAOps),
+                                      cdna2MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna2Wgp;
 }
@@ -153,11 +172,17 @@ const WgpDetails *getCDNA1WgpDetails() {
     MMAIntrinsic::MFMA_F16_16x16x16_F32,
     MMAIntrinsic::MFMA_F16_32x32x8_F32,
   };
-  static const WgpDetails cdna1Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(cdna1MMAOps), cdna1MMAOps,
-      {64, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails cdna1Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(cdna1MMAOps),
+                                      cdna1MMAOps,
+                                      {64, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &cdna1Wgp;
 }
@@ -166,27 +191,39 @@ const WgpDetails *getRDNA3WgpDetails() {
     MMAIntrinsic::WMMA_F16_16x16x16_F32,
     MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails rdna3Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(rdna3MMAOps), rdna3MMAOps,
-      {32, 64}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails rdna3Wgp = {allComputeBits,
+                                      allStorageBits,
+                                      allSubgroupOps,
+                                      allDotProductOps,
+                                      ARRAY_SIZE(rdna3MMAOps),
+                                      rdna3MMAOps,
+                                      {32, 64},
+                                      {1024, 1024, 1024},
+                                      1024,
+                                      64 * 1024,
+                                      {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna3Wgp;
 }
 
 const WgpDetails *getRDNA2WgpDetails() {
   static const WgpDetails rdna2Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps,
-      /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
-      1024, 64 * 1024};
+      allComputeBits, allStorageBits,
+      allSubgroupOps, allDotProductOps,
+      /*mmaCount=*/0,
+      /*mmaOps=*/nullptr, {32, 64},
+      {1024, 1024, 1024}, 1024,
+      64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna2Wgp;
 }
 
 const WgpDetails *getRDNA1WgpDetails() {
   static const WgpDetails rdna1Wgp = {
-      allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
-      /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
-      1024, 64 * 1024};
+      allComputeBits, allStorageBits,
+      allSubgroupOps, DotProductOps::None,
+      /*mmaCount=*/0,
+      /*mmaOps=*/nullptr, {32, 64},
+      {1024, 1024, 1024}, 1024,
+      64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
   return &rdna1Wgp;
 }
@@ -279,7 +316,9 @@ std::optional<TargetDetails> getAppleTargetDetails() {
   static const WgpDetails wgp = {
       computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
       /*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 32},
-      {1024, 1024, 1024}, 1024, 32 * 1024};
+      {1024, 1024, 1024}, 1024, 32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
 
   return TargetDetails{&wgp, nullptr};
@@ -300,7 +339,9 @@ const WgpDetails *getValhallWgpDetails() {
   static const WgpDetails valhallWgp = {
       computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
       /*mmaCount=*/0, /*mmaOps=*/nullptr, {16, 16}, {512, 512, 512},
-      512, 32 * 1024};
+      512, 32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &valhallWgp;
 }
@@ -356,11 +397,17 @@ const WgpDetails *getAmpereWgpDetails() {
     MMAIntrinsic::WMMA_F16_16x16x16_F32,
     MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails ampereWgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
-      {32, 32}, {1024, 1024, 1024}, 1024,
-      163 * 1024};
+  static const WgpDetails ampereWgp = {allComputeBits,
+                                       allStorageBits,
+                                       allSubgroupOps,
+                                       allDotProductOps,
+                                       ARRAY_SIZE(mmaOps),
+                                       mmaOps,
+                                       {32, 32},
+                                       {1024, 1024, 1024},
+                                       1024,
+                                       163 * 1024,
+                                       {0x7fffffff, 0xffff, 0xffff}};
   return &ampereWgp;
 }
@@ -369,11 +416,17 @@ const WgpDetails *getTuringWgpDetails() {
     MMAIntrinsic::WMMA_F16_16x16x16_F32,
     MMAIntrinsic::WMMA_F16_16x16x16_F16,
   };
-  static const WgpDetails turingWgp = {
-      allComputeBits, allStorageBits, allSubgroupOps,
-      allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
-      {32, 32}, {1024, 1024, 1024}, 1024,
-      64 * 1024};
+  static const WgpDetails turingWgp = {allComputeBits,
+                                       allStorageBits,
+                                       allSubgroupOps,
+                                       allDotProductOps,
+                                       ARRAY_SIZE(mmaOps),
+                                       mmaOps,
+                                       {32, 32},
+                                       {1024, 1024, 1024},
+                                       1024,
+                                       64 * 1024,
+                                       {0x7fffffff, 0xffff, 0xffff}};
   return &turingWgp;
 }
@@ -386,7 +439,8 @@ const WgpDetails *getVoltaWgpDetails() {
   static const WgpDetails voltaWgp = {
       allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
       ARRAY_SIZE(mmaOps), mmaOps, {32, 32}, {1024, 1024, 1024},
-      1024, 96 * 1024};
+      1024, 96 * 1024,
+      {0x7fffffff, 0xffff, 0xffff}};
   // clang-format on
   return &voltaWgp;
 }
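The NVIDIA entries are deliberately asymmetric: {0x7fffffff, 0xffff, 0xffff} mirrors CUDA's documented grid limits of 2^31 - 1 blocks in X but only 65535 in Y and Z. A quick host-side check against the CUDA runtime (illustrative, not part of this patch):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      cudaDeviceProp props;
      if (cudaGetDeviceProperties(&props, /*device=*/0) != cudaSuccess)
        return 1;
      // Prints {2147483647, 65535, 65535} on current NVIDIA GPUs.
      std::printf("maxGridSize = {%d, %d, %d}\n", props.maxGridSize[0],
                  props.maxGridSize[1], props.maxGridSize[2]);
      return 0;
    }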
@@ -396,7 +450,8 @@ const WgpDetails *getPascalWgpDetails() {
   static const WgpDetails pascalWgp = {
       allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
       0, nullptr, // Pascal does not have tensor core support.
-      {32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024};
+      {32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024,
+      {0x7fffffff, 0xffff, 0xffff}};
   // clang-format on
   return &pascalWgp;
 }
@@ -477,7 +532,9 @@ const WgpDetails *getAdrenoWgpDetails() {
       computeBitwdiths, storageBitwidths, allSubgroupOps,
       allDotProductOps, /*mmaCount=*/0, /*mmaOps=*/nullptr,
       {64, 64}, {1024, 1024, 1024}, 1024,
-      32 * 1024};
+      32 * 1024,
+      // Note: These values have not been checked and may be higher
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &adrenoWgp;
 }
@@ -543,7 +600,8 @@ const WgpDetails *getAndroidBaseline2022WgpDetails() {
       computeBitwdiths, storageBitwidths, SubgroupOps::None,
       DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
       {64, 64}, {128, 128, 64}, 128,
-      16 * 1024};
+      16 * 1024,
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
   return &androidWgp;
 }
@@ -641,7 +699,8 @@ TargetAttr getWebGPUTargetDetails(MLIRContext *context) {
       computeBitwdiths, storageBitwidths, SubgroupOps::None,
       DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
       {32, 32}, {128, 128, 64}, 128,
-      16 * 1024};
+      16 * 1024,
+      {0xffff, 0xffff, 0xffff}};
   // clang-format on
 
   return createTargetAttr(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
index b40de6ba70e5e..a13a0c56e2356 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir
@@ -76,7 +76,8 @@ module {
     subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
 #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target}>
 module {
   func.func @matmul_256x256x256() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
index b1f809293db2a..c03730532c006 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_gpu_target.mlir
@@ -4,7 +4,8 @@ hal.executable @dispatch {
   hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
     iree.gpu.target = #iree_gpu.target, ],
-      subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>}>) {
+      subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+      max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>) {
   hal.executable.export public @dispatch ordinal(0) layout(#hal.pipeline.layout]>]>) {
   ^bb0(%arg0: !hal.device):
     %x, %y, %z = flow.dispatch.workgroup_count_from_slice
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
index 5761741f37870..88dfda915da3f 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
@@ -69,7 +69,8 @@ func.func @main1(%arg0: tensor<2x130x130x320xf16>, %arg1: tensor<3x3x320x4xf16>,
     subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [],
     subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
-    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
+    max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
+    max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
 #rocm_executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target, ukernels = "none"}>
 
 // CHECK-LABEL: func.func @main2(
diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir
index ef10fb7b7dbda..72c4771fec34e 100644
--- a/samples/custom_dispatch/vulkan/shaders/example.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -19,7 +19,8 @@
       compute = fp32|int32, storage = b32, subgroup = none, dot = none,
       mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 36912bb35df99..96ca9857d32b5 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -19,7 +19,8 @@
      compute = fp32|int32, storage = b32, subgroup = none, dot = none,
       mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
index b4885a03081d7..3951f5d9c3638 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -23,7 +23,8 @@
      compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none,
       mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
index 5bcdafe7fba1e..8e232069fa153 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform_spec.mlir
@@ -12,7 +12,8 @@
      compute = fp32|int32, storage = b32, subgroup = shuffle|arithmetic, dot = none,
       mma = [], subgroup_size_choices = [64, 64],
       max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
-      max_workgroup_memory_bytes = 16384>
+      max_workgroup_memory_bytes = 16384,
+      max_workgroup_counts = [65535, 65535, 65535]>
   >
 }>
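The Vulkan samples all advertise max_workgroup_counts = [65535, 65535, 65535] because 65535 is the minimum maxComputeWorkGroupCount a conforming Vulkan implementation must support in each dimension; concrete devices may report larger limits, which can be queried at runtime. An illustrative sketch, not part of this patch; the VkPhysicalDevice is assumed to come from the surrounding application:

    #include <cstdio>
    #include <vulkan/vulkan.h>

    void printWorkgroupCountLimits(VkPhysicalDevice physicalDevice) {
      VkPhysicalDeviceProperties props;
      vkGetPhysicalDeviceProperties(physicalDevice, &props);
      const uint32_t *counts = props.limits.maxComputeWorkGroupCount;
      // Guaranteed to be at least {65535, 65535, 65535} by the Vulkan spec.
      std::printf("maxComputeWorkGroupCount = {%u, %u, %u}\n",
                  counts[0], counts[1], counts[2]);
    }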
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 585bb25915345..c0ed4848e18c6 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -27,7 +27,7 @@
 #target = #iree_gpu.target>
+  max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384, max_workgroup_counts = [65535, 65535, 65535]>>
 
 module attributes {
   hal.device.targets = [