From 25c1a418a4873164f86d5fed6c858cf1ecd68be2 Mon Sep 17 00:00:00 2001 From: Andrea Faulds Date: Wed, 18 Sep 2024 22:16:08 +0200 Subject: [PATCH] [gpu] Use clustered gpu.subgroup_reduce for nested layout distribution There is now support in MLIR for expressing a subgroup reduction operation that operates on several "clusters" in parallel, so it is no longer necessary to build a series of shuffles. It has been verified that, at least if the upstream patterns are used, the resulting sequence of shuffles is the same as the old code. This commit also adds a new pass, ExpandGPUOps, which uses the upstream patterns to expand these ops, and adds it to the LLVMGPU pass list. Resolves #18142. Signed-off-by: Andrea Faulds --- .../compiler/Codegen/Common/GPU/BUILD.bazel | 1 + .../Codegen/Common/GPU/CMakeLists.txt | 1 + .../Codegen/Common/GPU/ExpandGPUOps.cpp | 48 +++++++++++++++++++ .../GPUNestedLayoutDistributionPatterns.cpp | 33 +++++-------- .../compiler/Codegen/Common/GPU/Passes.td | 8 ++++ ...gpu_nested_layout_vector_distribution.mlir | 12 +---- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 3 +- .../iree/compiler/Codegen/Utils/GPUUtils.cpp | 5 +- .../iree/compiler/Codegen/Utils/GPUUtils.h | 6 +++ 9 files changed, 81 insertions(+), 36 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index fdfd562910845..cee420a2e52dc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -50,6 +50,7 @@ iree_compiler_cc_library( name = "CommonGPUPasses", srcs = [ "AMDGPUDistributeContract.cpp", + "ExpandGPUOps.cpp", "GPUApplyTilingLevel.cpp", "GPUCheckResourceUsage.cpp", "GPUCombineValueBarriers.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index d73a119fda3f5..0fb55c2287992 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -48,6 +48,7 @@ iree_cc_library( "Passes.h" SRCS "AMDGPUDistributeContract.cpp" + "ExpandGPUOps.cpp" "GPUApplyTilingLevel.cpp" "GPUCheckResourceUsage.cpp" "GPUCombineValueBarriers.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp new file mode 100644 index 0000000000000..cd657e95bf0c6 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp @@ -0,0 +1,48 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/GPU/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-codegen-expand-gpu-ops" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_EXPANDGPUOPSPASS +#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" + +namespace { + +struct ExpandGPUOpsPass final : impl::ExpandGPUOpsPassBase { + void runOnOperation() override { + FunctionOpInterface funcOp = getOperation(); + MLIRContext *ctx = &getContext(); + + std::optional subgroupSize = getGPUSubgroupSize(funcOp); + if (!subgroupSize) { + funcOp->emitOpError("missing subgroup size"); + return signalPassFailure(); + } + + RewritePatternSet patterns(ctx); + populateGpuBreakDownSubgroupReducePatterns( + patterns, /* maxShuffleBitwidth=*/32, PatternBenefit(2)); + populateGpuLowerClusteredSubgroupReduceToShufflePatterns( + patterns, *subgroupSize, /* shuffleBitwidth=*/32, PatternBenefit(1)); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + }; +}; + +} // namespace + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index 260f7c24b07c4..c9b8f3ad72c21 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/Transforms/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" @@ -511,10 +512,17 @@ struct DistributeMultiReduction final Value extracted = rewriter.create(loc, flat, i); // Reduce across all reduction dimensions 1-by-1. - for (unsigned i = 0; i < reductionMask.size(); ++i) { + for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) { if (reductionMask[i]) { - extracted = doPackedThreadReductionOnDim(rewriter, layout, extracted, - kind, i); + int64_t offset = getShuffleOffset(layout, i); + int64_t width = getShuffleWidth(layout, i); + assert(offset <= std::numeric_limits::max() && + width <= std::numeric_limits::max()); + + extracted = rewriter.create( + loc, extracted, combiningKindToAllReduce(kind), + /*uniform=*/false, /*cluster_size=*/width, + /*cluster_stride=*/offset); } } @@ -524,25 +532,6 @@ struct DistributeMultiReduction final return res; } - Value doPackedThreadReductionOnDim(RewriterBase &rewriter, - NestedLayoutAttr layout, Value val, - vector::CombiningKind kind, - int64_t dim) const { - Location loc = val.getLoc(); - int64_t offset = getShuffleOffset(layout, dim); - int64_t width = getShuffleWidth(layout, dim); - - for (int i = offset; i < offset * width; i <<= 1) { - auto shuffleOp = rewriter.create( - loc, val, i, subgroupSize, gpu::ShuffleMode::XOR); - val = - makeArithReduction(rewriter, loc, kind, shuffleOp.getShuffleResult(), - val, nullptr, nullptr); - } - - return val; - } - int64_t subgroupSize; int64_t maxBitsPerShuffle; }; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td index 09c4138971591..1c125d3f79126 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td @@ -263,4 +263,12 @@ def VectorReductionToGPUPass : ]; } +def ExpandGPUOpsPass : + InterfacePass<"iree-codegen-expand-gpu-ops", "mlir::FunctionOpInterface"> { + let summary = "Expands high-level GPU ops, such as clustered gpu.subgroup_reduce."; + let dependentDialects = [ + "::mlir::gpu::GPUDialect" + ]; +} + #endif // IREE_CODEGEN_COMMON_GPU_PASSES diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir index 9876b0086c654..3b9b016dde634 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir @@ -963,19 +963,13 @@ builtin.module attributes { transform.with_named_sequence } { } // CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1 -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32 -// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32 -// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32 // CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32> // CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32> // CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32> // Local reduction // CHECK: vector.multi_reduction , %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32> // Global reduction -// CHECK: gpu.shuffle xor %{{.*}}, %[[C16]], %[[C64]] : f32 -// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32 -// CHECK: gpu.shuffle xor %{{.*}}, %[[C16]], %[[C64]] : f32 -// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32 +// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 // Accumulator reduction // CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32> // CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32> @@ -1012,11 +1006,9 @@ builtin.module attributes { transform.with_named_sequence } { } // CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1 -// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32 -// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32 // Local reduction // CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32> // Global reduction -// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32 +// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32 // Accumulator reduction // CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index ec74f3541d920..68d07febaaa26 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1040,7 +1040,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, FunctionLikeNest funcPassManager(modulePassManager); funcPassManager.addPass(createFoldTensorExtractOpPass) - .addPass(createLLVMGPUVectorLoweringPass); + .addPass(createLLVMGPUVectorLoweringPass) + .addPass(createExpandGPUOpsPass); // This pass needs to run before SCF -> CF. addLowerAndOptimizeAddressComputationPasses(funcPassManager); diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index 041784ae6b5e0..5eb4519ec8cee 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -535,9 +535,8 @@ Value getCombiningIdentityValue(Location loc, OpBuilder &builder, return identity; } -/// Return a matching GPU reduction operations. -static gpu::AllReduceOperation -combiningKindToAllReduce(vector::CombiningKind kind) { +/// Returns the matching GPU reduction operation. +gpu::AllReduceOperation combiningKindToAllReduce(vector::CombiningKind kind) { switch (kind) { #define MAP_CASE(X) \ case vector::CombiningKind::X: \ diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h index e089b0005e228..cdbc297cb4c12 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h @@ -9,6 +9,7 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" @@ -156,6 +157,11 @@ Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput, /// Emit identity constant based on combiningKind and type. Value getCombiningIdentityValue(Location loc, OpBuilder &builder, vector::CombiningKind kind, Type identityType); + +/// Returns the matching GPU reduction operation. +mlir::gpu::AllReduceOperation +combiningKindToAllReduce(vector::CombiningKind kind); + //===----------------------------------------------------------------------===// // GPU CodeGen op filter //===----------------------------------------------------------------------===//