[codegen] Add max_workgroup_counts to TargetWgpAttr
This commit adds a max_workgroup_counts field to the workgroup processor
information attribute and sets values for the known targets. Some of
these values may be underestimates, as I was not able to locate
information on the actual hardware limits.

This field is added so that we can annotate calls to workgroup.id and
workgroup.count with upper bounds, enabling range inference and
strength reduction.
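
As a rough illustration of the strength reduction this enables (a plain C++
sketch, not IREE code; the constants stand in for hypothetical per-target
limits), knowing an upper bound on the workgroup count lets linearized index
arithmetic stay in 32 bits instead of being widened:

#include <cstdint>

// Sketch only: with known bounds on the workgroup count and workgroup size,
// the linearized ID provably fits in 32 bits, so no widening to 64-bit
// `index` arithmetic is required.
uint32_t linearThreadId(uint32_t workgroupId, uint32_t threadId) {
  constexpr uint64_t kMaxWorkgroupCount = 65535;      // from max_workgroup_counts
  constexpr uint64_t kMaxThreadsPerWorkgroup = 1024;  // from max_thread_count_per_workgroup
  static_assert(kMaxWorkgroupCount * kMaxThreadsPerWorkgroup < (1ull << 32),
                "bounded product fits in 32 bits");
  return workgroupId * uint32_t(kMaxThreadsPerWorkgroup) + threadId;
}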

Note that in some cases (for instance, AMD) we give a
max_workgroup_counts value lower than what the hardware actually
supports, because a grid dimension greater than int32_max would be
sign-extended to a negative number when widened to the 64-bit `index`
type.
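
A minimal standalone example of that sign-extension pitfall (plain C++, not
IREE code): a grid dimension one past int32_max turns negative once it is
treated as a signed 32-bit value and widened to 64 bits.

#include <cstdint>
#include <iostream>

int main() {
  // A grid dimension just past int32_max (2^31)...
  uint32_t gridDim = 0x80000000u;
  // ...becomes negative if reinterpreted as signed i32 and then
  // sign-extended to a 64-bit `index`-like type.
  int64_t asIndex = static_cast<int64_t>(static_cast<int32_t>(gridDim));
  std::cout << asIndex << "\n";  // prints -2147483648
  return 0;
}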

(This PR is split out of #17707)

Signed-off-by: Krzysztof Drewniak <[email protected]>
krzysz00 committed Jul 12, 2024
1 parent 05dfe0b commit 3caef32
Showing 15 changed files with 136 additions and 60 deletions.
3 changes: 2 additions & 1 deletion compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
max_workgroup_counts = [65535, 65535, 65535]>>
}>
]>
]
@@ -8,7 +8,8 @@
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
// GFX942-SAME: chip = <wgp_count = 304>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
3 changes: 2 additions & 1 deletion compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
max_workgroup_counts = [65535, 65535, 65535]>>
}>
]>
]
3 changes: 2 additions & 1 deletion compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
@@ -7,7 +7,8 @@ module attributes {
#hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.0,cap:Shader,ext:SPV_KHR_storage_buffer_storage_class", wgp = <
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
max_workgroup_counts = [65535, 65535, 65535]>>
}>
]>
]
@@ -289,6 +289,8 @@ def IREEGPU_TargetWgpAttr : AttrDef<IREEGPU_Dialect, "TargetWgp"> {
"uint32_t":$max_thread_count_per_workgroup,
// The maximal number of shared memory bytes we can allocate per workgroup.
"uint32_t":$max_workgroup_memory_bytes,
// The maximum number of workgroups per X/Y/Z dimension in a dispatch.
"DenseI32ArrayAttr":$max_workgroup_counts,

// An optional extra dict
// This field allows to inject more features/limits not supported in the
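For context, a sketch of how a later pass might query the new field from a
target's wgp attribute. This is not part of the commit: the header path and
the getter name (getMaxWorkgroupCounts()) are assumptions inferred from
MLIR/IREE TableGen conventions, so treat the names as illustrative rather
than confirmed API.

// Sketch under assumptions: header path and getter name inferred from
// conventions, not taken from this commit.
#include <optional>
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"

// Returns the grid-size bound for one dispatch dimension, if known.
static std::optional<int32_t>
getWorkgroupCountBound(mlir::iree_compiler::IREE::GPU::TargetWgpAttr wgp,
                       unsigned dim) {
  llvm::ArrayRef<int32_t> counts = wgp.getMaxWorkgroupCounts().asArrayRef();
  if (dim >= counts.size())
    return std::nullopt;
  return counts[dim];
}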
@@ -11,15 +11,17 @@ func.func @test_target_wgp() attributes {
// CHECK-SAME: subgroup_size_choices = [32, 64],
// CHECK-SAME: max_workgroup_sizes = [1024, 1024, 1024],
// CHECK-SAME: max_thread_count_per_workgroup = 1024,
// CHECK-SAME: max_workgroup_memory_bytes = 65536>
// CHECK-SAME: max_workgroup_memory_bytes = 65536,
// CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>
wgp = #iree_gpu.target_wgp<
compute = fp16|fp32|int8, storage = b16|b32,
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>],
subgroup_size_choices = [32, 64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
>
} { return }

@@ -37,7 +39,8 @@ func.func @test_target_wgp_none() attributes {
subgroup_size_choices = [32],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
>
} { return }

@@ -67,7 +70,8 @@ func.func @test_target() attributes {
subgroup_size_choices = [32, 64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536>,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
chip = <wgp_count = 304>
>
} { return }
@@ -45,6 +45,7 @@ struct WgpDetails {
std::array<int32_t, 3> maxWorkgroupSizes;
uint32_t maxThreadSize;
uint32_t maxWorkgroupMemoryBytes;
std::array<int32_t, 3> maxWorkgroupCounts;
};

// Chip level feature/limit details
@@ -106,7 +107,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
MMAOpsArrayAttr::get(context, mmaAttrs),
DenseI32ArrayAttr::get(context, subgroupSizes),
DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes),
wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DictionaryAttr{});
wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes,
DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts),
DictionaryAttr{});

TargetChipAttr targetChip;
if (details.chip)
@@ -118,6 +121,10 @@

//===----------------------------------------------------------------------===//
// Known AMD target details
//
// Note: the max workgroup count is given as the signed int32 max because
// MLIR's `index` is signed and the workgroup ID is sign-extended, not
// zero-extended, to 64 bits.
//===----------------------------------------------------------------------===//

const WgpDetails *getCDNA3WgpDetails() {
@@ -127,11 +134,17 @@ const WgpDetails *getCDNA3WgpDetails() {
MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_I8_32x32x16_I32,
};
static const WgpDetails cdna3Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(cdna3MMAOps), cdna3MMAOps,
{64, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails cdna3Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(cdna3MMAOps),
cdna3MMAOps,
{64, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &cdna3Wgp;
}

@@ -140,11 +153,17 @@ const WgpDetails *getCDNA2WgpDetails() {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
};
static const WgpDetails cdna2Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(cdna2MMAOps), cdna2MMAOps,
{64, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails cdna2Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(cdna2MMAOps),
cdna2MMAOps,
{64, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &cdna2Wgp;
}

@@ -153,11 +172,17 @@ const WgpDetails *getCDNA1WgpDetails() {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
};
static const WgpDetails cdna1Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(cdna1MMAOps), cdna1MMAOps,
{64, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails cdna1Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(cdna1MMAOps),
cdna1MMAOps,
{64, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &cdna1Wgp;
}

@@ -166,27 +191,39 @@ const WgpDetails *getRDNA3WgpDetails() {
MMAIntrinsic::WMMA_F16_16x16x16_F32,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
};
static const WgpDetails rdna3Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(rdna3MMAOps), rdna3MMAOps,
{32, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails rdna3Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(rdna3MMAOps),
rdna3MMAOps,
{32, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &rdna3Wgp;
}

const WgpDetails *getRDNA2WgpDetails() {
static const WgpDetails rdna2Wgp = {
allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
1024, 64 * 1024};
allComputeBits, allStorageBits,
allSubgroupOps, allDotProductOps,
/*mmaCount=*/0,
/*mmaOps=*/nullptr, {32, 64},
{1024, 1024, 1024}, 1024,
64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &rdna2Wgp;
}

const WgpDetails *getRDNA1WgpDetails() {
static const WgpDetails rdna1Wgp = {
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
1024, 64 * 1024};
allComputeBits, allStorageBits,
allSubgroupOps, DotProductOps::None,
/*mmaCount=*/0,
/*mmaOps=*/nullptr, {32, 64},
{1024, 1024, 1024}, 1024,
64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &rdna1Wgp;
}

@@ -279,7 +316,9 @@ std::optional<TargetDetails> getAppleTargetDetails() {
static const WgpDetails wgp = {
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 32},
{1024, 1024, 1024}, 1024, 32 * 1024};
{1024, 1024, 1024}, 1024, 32 * 1024,
// Note: These values have not been checked and may be higher
{0xffff, 0xffff, 0xffff}};
// clang-format on

return TargetDetails{&wgp, nullptr};
@@ -300,7 +339,9 @@ const WgpDetails *getValhallWgpDetails() {
static const WgpDetails valhallWgp = {
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {16, 16}, {512, 512, 512},
512, 32 * 1024};
512, 32 * 1024,
// Note: These values have not been checked and may be higher
{0xffff, 0xffff, 0xffff}};
// clang-format on
return &valhallWgp;
}
@@ -356,11 +397,17 @@ const WgpDetails *getAmpereWgpDetails() {
MMAIntrinsic::WMMA_F16_16x16x16_F32,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
};
static const WgpDetails ampereWgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
{32, 32}, {1024, 1024, 1024}, 1024,
163 * 1024};
static const WgpDetails ampereWgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(mmaOps),
mmaOps,
{32, 32},
{1024, 1024, 1024},
1024,
163 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
return &ampereWgp;
}

@@ -369,11 +416,17 @@ const WgpDetails *getTuringWgpDetails() {
MMAIntrinsic::WMMA_F16_16x16x16_F32,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
};
static const WgpDetails turingWgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
{32, 32}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails turingWgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(mmaOps),
mmaOps,
{32, 32},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
return &turingWgp;
}

@@ -386,7 +439,8 @@ const WgpDetails *getVoltaWgpDetails() {
static const WgpDetails voltaWgp = {
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
ARRAY_SIZE(mmaOps), mmaOps, {32, 32}, {1024, 1024, 1024},
1024, 96 * 1024};
1024, 96 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
// clang-format on
return &voltaWgp;
}
@@ -396,7 +450,8 @@ const WgpDetails *getPascalWgpDetails() {
static const WgpDetails pascalWgp = {
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
0, nullptr, // Pascal does not have tensor core support.
{32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024};
{32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
// clang-format on
return &pascalWgp;
}
@@ -477,7 +532,9 @@ const WgpDetails *getAdrenoWgpDetails() {
computeBitwdiths, storageBitwidths, allSubgroupOps,
allDotProductOps, /*mmaCount=*/0, /*mmaOps=*/nullptr,
{64, 64}, {1024, 1024, 1024}, 1024,
32 * 1024};
32 * 1024,
// Note: These values have not been checked and may be higher
{0xffff, 0xffff, 0xffff}};
// clang-format on
return &adrenoWgp;
}
@@ -543,7 +600,8 @@ const WgpDetails *getAndroidBaseline2022WgpDetails() {
computeBitwdiths, storageBitwidths, SubgroupOps::None,
DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
{64, 64}, {128, 128, 64}, 128,
16 * 1024};
16 * 1024,
{0xffff, 0xffff, 0xffff}};
// clang-format on
return &androidWgp;
}
@@ -641,7 +699,8 @@ TargetAttr getWebGPUTargetDetails(MLIRContext *context) {
computeBitwdiths, storageBitwidths, SubgroupOps::None,
DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
{32, 32}, {128, 128, 64}, 128,
16 * 1024};
16 * 1024,
{0xffff, 0xffff, 0xffff}};
// clang-format on

return createTargetAttr(
@@ -76,7 +76,8 @@ module {
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [],
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target}>
module {
func.func @matmul_256x256x256() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
@@ -4,7 +4,8 @@ hal.executable @dispatch {
hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {
iree.gpu.target = #iree_gpu.target<arch = "rdna3", features = "spirv:v1.6,cap:Shader",
wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>],
subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>}>) {
subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>) {
hal.executable.export public @dispatch ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
@@ -69,7 +69,8 @@ func.func @main1(%arg0: tensor<2x130x130x320xf16>, %arg1: tensor<3x3x320x4xf16>,
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [<MFMA_F16_32x32x8_F32>],
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
#rocm_executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target, ukernels = "none"}>

// CHECK-LABEL: func.func @main2(
3 changes: 2 additions & 1 deletion samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -19,7 +19,8 @@
compute = fp32|int32, storage = b32, subgroup = none,
dot = none, mma = [], subgroup_size_choices = [64, 64],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128,
max_workgroup_memory_bytes = 16384>
max_workgroup_memory_bytes = 16384,
max_workgroup_counts = [65535, 65535, 65535]>
>
}>
