[codegen] Add max_workgroup_counts to TargetWgpAttr #17771

Merged · 1 commit · Aug 14, 2024
3 changes: 2 additions & 1 deletion compiler/plugins/target/MetalSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
max_workgroup_counts = [65535, 65535, 65535]>>
}>
]> : !hal.device
]
@@ -8,7 +8,8 @@
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
// GFX942-SAME: chip = <wgp_count = 304>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
3 changes: 2 additions & 1 deletion compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir
@@ -6,7 +6,8 @@ module attributes {
#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
max_workgroup_counts = [65535, 65535, 65535]>>
}>
]> : !hal.device
]
3 changes: 2 additions & 1 deletion compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir
@@ -7,7 +7,8 @@ module attributes {
#hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.0,cap:Shader,ext:SPV_KHR_storage_buffer_storage_class", wgp = <
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
max_workgroup_counts = [65535, 65535, 65535]>>
}>
]> : !hal.device
]
@@ -290,6 +290,8 @@ def IREEGPU_TargetWgpAttr : AttrDef<IREEGPU_Dialect, "TargetWgp"> {
"uint32_t":$max_thread_count_per_workgroup,
// The maximal number of shared memory bytes we can allocate per workgroup.
"uint32_t":$max_workgroup_memory_bytes,
// The maximum number of workgroups per X/Y/Z dimension in a dispatch.
"DenseI32ArrayAttr":$max_workgroup_counts,

// An optional extra dict
// This field allows to inject more features/limits not supported in the
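Editor's note: once this tablegen parameter lands, mlir-tblgen generates a matching accessor on TargetWgpAttr. A minimal sketch of reading it from C++ follows; the accessor name follows tblgen's usual convention and the include path is an assumption, so treat this as illustrative rather than as code from this PR.

// Illustrative sketch, not part of this PR. Assumes the tblgen-generated
// accessor `getMaxWorkgroupCounts()` and IREE's GPU dialect header path.
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "llvm/Support/raw_ostream.h"

static void printWorkgroupCountLimits(
    mlir::iree_compiler::IREE::GPU::TargetWgpAttr wgp) {
  // DenseI32ArrayAttr exposes its contents as an ArrayRef<int32_t>.
  llvm::ArrayRef<int32_t> counts = wgp.getMaxWorkgroupCounts().asArrayRef();
  llvm::errs() << "max workgroup counts: " << counts[0] << " x " << counts[1]
               << " x " << counts[2] << "\n";
}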
@@ -11,15 +11,17 @@ func.func @test_target_wgp() attributes {
// CHECK-SAME: subgroup_size_choices = [32, 64],
// CHECK-SAME: max_workgroup_sizes = [1024, 1024, 1024],
// CHECK-SAME: max_thread_count_per_workgroup = 1024,
// CHECK-SAME: max_workgroup_memory_bytes = 65536>
// CHECK-SAME: max_workgroup_memory_bytes = 65536,
// CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>
wgp = #iree_gpu.target_wgp<
compute = fp16|fp32|int8, storage = b16|b32,
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [<MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>],
subgroup_size_choices = [32, 64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
>
} { return }

@@ -37,7 +39,8 @@ func.func @test_target_wgp_none() attributes {
subgroup_size_choices = [32],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
>
} { return }

@@ -67,7 +70,8 @@ func.func @test_target() attributes {
subgroup_size_choices = [32, 64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536>,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
chip = <wgp_count = 304>
>
} { return }
@@ -45,6 +45,7 @@ struct WgpDetails {
std::array<int32_t, 3> maxWorkgroupSizes;
uint32_t maxThreadSize;
uint32_t maxWorkgroupMemoryBytes;
std::array<int32_t, 3> maxWorkgroupCounts;
};

// Chip level feature/limit details
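Editor's note on the struct change above: WgpDetails is a C++ aggregate, so the braced initializers later in this file fill members in declaration order, and a trailing member that is left out is value-initialized to zero. That is why every known-target table below gains an explicit count triple. A hedged, self-contained sketch with a simplified stand-in struct (not the real WgpDetails):

// Stand-in for the tail of WgpDetails; demonstrates why the new trailing
// member must be spelled out in each initializer.
#include <array>
#include <cstdint>

struct Limits {
  uint32_t maxThreadSize;
  uint32_t maxWorkgroupMemoryBytes;
  std::array<int32_t, 3> maxWorkgroupCounts;
};

// Omitting the trailing member value-initializes it: counts become {0, 0, 0},
// which would silently advertise a dispatch grid limit of zero.
static const Limits kStale = {1024, 64 * 1024};
// The updated initializers spell the limit out explicitly (int32 max here).
static const Limits kFixed = {1024, 64 * 1024,
                              {0x7fffffff, 0x7fffffff, 0x7fffffff}};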
@@ -106,7 +107,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
MMAOpsArrayAttr::get(context, mmaAttrs),
DenseI32ArrayAttr::get(context, subgroupSizes),
DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes),
wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DictionaryAttr{});
wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes,
DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts),
DictionaryAttr{});

TargetChipAttr targetChip;
if (details.chip)
@@ -118,6 +121,10 @@

//===----------------------------------------------------------------------===//
// Known AMD target details
//
// Note: the max workgroup count is given as signed int32 max because MLIR's
// `index` is signed and the workgroup ID is sign-extended, not zero-extended,
// to 64 bits.
//===----------------------------------------------------------------------===//

Review thread on this note:

Contributor: What's the limit reported by rocm-smi? (Although I think the number rocm-smi reports doesn't reflect reality, IIRC.) We can make the field int64 or optional if necessary.

Contributor Author: While rocm-smi does report 0xff_ff_ff_ff, that's a problem in MLIR, because we use `index` for the workgroup ID, which is signed. (I'm also not sure whether that number is in units of workgroups or workitems, so it might be 0xff_ff_ff_ff / work_group_size.)

Contributor: I'm thinking we may want to make the field optional and interpret a missing value as int32 max; it doesn't help much to explicitly print out int32 max.

Contributor Author: Hm, yeah, I'd be open to doing that.

Contributor: Let's do it as a follow-up later; not blocking.
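Editor's note: to make the sign-extension pitfall above concrete, here is a small standalone demo (illustrative only, not code from this PR). The hardware-reported limit of 0xff_ff_ff_ff fits in an unsigned 32-bit register, but once it passes through a signed 32-bit value and is sign-extended to a 64-bit `index`, it comes out as -1:

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t reported = 0xffffffffu;  // the limit rocm-smi reports
  // MLIR's `index` is signed and the workgroup ID is sign-extended to
  // 64 bits, so a 32-bit value with the top bit set arrives negative.
  int64_t asIndex = static_cast<int64_t>(static_cast<int32_t>(reported));
  std::printf("%lld\n", static_cast<long long>(asIndex));  // prints -1
  // Capping the advertised limit at int32 max keeps sign extension lossless.
  std::printf("%lld\n", static_cast<long long>(INT32_MAX));  // 2147483647
  return 0;
}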

const WgpDetails *getCDNA3WgpDetails() {
@@ -129,11 +136,17 @@ const WgpDetails *getCDNA3WgpDetails() {
MMAIntrinsic::MFMA_I32_16x16x32_I8,
MMAIntrinsic::MFMA_I32_32x32x16_I8,
};
static const WgpDetails cdna3Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(cdna3MMAOps), cdna3MMAOps,
{64, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails cdna3Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(cdna3MMAOps),
cdna3MMAOps,
{64, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &cdna3Wgp;
}

@@ -142,11 +155,17 @@ const WgpDetails *getCDNA2WgpDetails() {
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
};
static const WgpDetails cdna2Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(cdna2MMAOps), cdna2MMAOps,
{64, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails cdna2Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(cdna2MMAOps),
cdna2MMAOps,
{64, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &cdna2Wgp;
}

@@ -155,11 +174,17 @@ const WgpDetails *getCDNA1WgpDetails() {
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
};
static const WgpDetails cdna1Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(cdna1MMAOps), cdna1MMAOps,
{64, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails cdna1Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(cdna1MMAOps),
cdna1MMAOps,
{64, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &cdna1Wgp;
}

@@ -168,27 +193,39 @@ const WgpDetails *getRDNA3WgpDetails() {
MMAIntrinsic::WMMA_F32_16x16x16_F16,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
};
static const WgpDetails rdna3Wgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(rdna3MMAOps), rdna3MMAOps,
{32, 64}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails rdna3Wgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(rdna3MMAOps),
rdna3MMAOps,
{32, 64},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &rdna3Wgp;
}

const WgpDetails *getRDNA2WgpDetails() {
static const WgpDetails rdna2Wgp = {
allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
1024, 64 * 1024};
allComputeBits, allStorageBits,
allSubgroupOps, allDotProductOps,
/*mmaCount=*/0,
/*mmaOps=*/nullptr, {32, 64},
{1024, 1024, 1024}, 1024,
64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &rdna2Wgp;
}

const WgpDetails *getRDNA1WgpDetails() {
static const WgpDetails rdna1Wgp = {
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
1024, 64 * 1024};
allComputeBits, allStorageBits,
allSubgroupOps, DotProductOps::None,
/*mmaCount=*/0,
/*mmaOps=*/nullptr, {32, 64},
{1024, 1024, 1024}, 1024,
64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
return &rdna1Wgp;
}

@@ -281,7 +318,9 @@ std::optional<TargetDetails> getAppleTargetDetails() {
static const WgpDetails wgp = {
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 32},
{1024, 1024, 1024}, 1024, 32 * 1024};
{1024, 1024, 1024}, 1024, 32 * 1024,
// Note: These values have not been checked and may be higher
{0xffff, 0xffff, 0xffff}};
// clang-format on

return TargetDetails{&wgp, nullptr};
@@ -302,7 +341,9 @@ const WgpDetails *getValhallWgpDetails() {
static const WgpDetails valhallWgp = {
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
/*mmaCount=*/0, /*mmaOps=*/nullptr, {16, 16}, {512, 512, 512},
512, 32 * 1024};
512, 32 * 1024,
// Note: These values have not been checked and may be higher
{0xffff, 0xffff, 0xffff}};
// clang-format on
return &valhallWgp;
}
@@ -358,11 +399,17 @@ const WgpDetails *getAmpereWgpDetails() {
MMAIntrinsic::WMMA_F32_16x16x16_F16,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
};
static const WgpDetails ampereWgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
{32, 32}, {1024, 1024, 1024}, 1024,
163 * 1024};
static const WgpDetails ampereWgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(mmaOps),
mmaOps,
{32, 32},
{1024, 1024, 1024},
1024,
163 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
return &ampereWgp;
}

@@ -371,11 +418,17 @@ const WgpDetails *getTuringWgpDetails() {
MMAIntrinsic::WMMA_F32_16x16x16_F16,
MMAIntrinsic::WMMA_F16_16x16x16_F16,
};
static const WgpDetails turingWgp = {
allComputeBits, allStorageBits, allSubgroupOps,
allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
{32, 32}, {1024, 1024, 1024}, 1024,
64 * 1024};
static const WgpDetails turingWgp = {allComputeBits,
allStorageBits,
allSubgroupOps,
allDotProductOps,
ARRAY_SIZE(mmaOps),
mmaOps,
{32, 32},
{1024, 1024, 1024},
1024,
64 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
return &turingWgp;
}

@@ -388,7 +441,8 @@ const WgpDetails *getVoltaWgpDetails() {
static const WgpDetails voltaWgp = {
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
ARRAY_SIZE(mmaOps), mmaOps, {32, 32}, {1024, 1024, 1024},
1024, 96 * 1024};
1024, 96 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
// clang-format on
return &voltaWgp;
}
@@ -398,7 +452,8 @@ const WgpDetails *getPascalWgpDetails() {
static const WgpDetails pascalWgp = {
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
0, nullptr, // Pascal does not have tensor core support.
{32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024};
{32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024,
{0x7fffffff, 0xffff, 0xffff}};
// clang-format on
return &pascalWgp;
}
@@ -479,7 +534,9 @@ const WgpDetails *getAdrenoWgpDetails() {
computeBitwdiths, storageBitwidths, allSubgroupOps,
allDotProductOps, /*mmaCount=*/0, /*mmaOps=*/nullptr,
{64, 64}, {1024, 1024, 1024}, 1024,
32 * 1024};
32 * 1024,
// Note: These values have not been checked and may be higher
{0xffff, 0xffff, 0xffff}};
// clang-format on
return &adrenoWgp;
}
@@ -545,7 +602,8 @@ const WgpDetails *getAndroidBaseline2022WgpDetails() {
computeBitwdiths, storageBitwidths, SubgroupOps::None,
DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
{64, 64}, {128, 128, 64}, 128,
16 * 1024};
16 * 1024,
{0xffff, 0xffff, 0xffff}};
// clang-format on
return &androidWgp;
}
@@ -645,7 +703,8 @@ TargetAttr getWebGPUTargetDetails(MLIRContext *context) {
computeBitwdiths, storageBitwidths, SubgroupOps::None,
DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
{32, 32}, {128, 128, 64}, 128,
16 * 1024};
16 * 1024,
{0xffff, 0xffff, 0xffff}};
// clang-format on

return createTargetAttr(
@@ -93,7 +93,8 @@ func.func @conv_nhwc() {
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
mma = [],
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target}>
func.func @matmul_256x256x256() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
%cst = arith.constant 0.000000e+00 : f32
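Editor's note: the test above shows the attribute in its final spelled-out form. For intuition about how a consumer might use the new limit, here is a hedged sketch of a dispatch-grid legality check; the function name and shape are assumptions for illustration, not IREE API.

// Illustrative only: validate a dispatch grid against the target's
// max_workgroup_counts before emitting it.
#include <array>
#include <cstdint>

bool fitsWorkgroupCountLimits(const std::array<int64_t, 3> &grid,
                              const std::array<int32_t, 3> &maxCounts) {
  for (int dim = 0; dim < 3; ++dim)
    if (grid[dim] < 1 || grid[dim] > maxCounts[dim])
      return false;
  return true;
}

// Example: a 70000 x 1 x 1 grid exceeds a 65535-per-dimension WebGPU-style
// limit and would need to be re-tiled or linearized across dimensions.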