Skip to content

Commit

Permalink
[gpu] Use clustered gpu.subgroup_reduce for nested layout distribution (
Browse files Browse the repository at this point in the history
#18515)

There is now support in MLIR for expressing a subgroup reduction
operation that operates on several "clusters" in parallel, so it is no
longer necessary to build a series of shuffles.

It has been verified that, at least when the upstream patterns are used,
the resulting sequence of shuffles is the same as the one the old code produced.

This commit also adds a new pass, ExpandGPUOps, which uses the upstream
patterns to expand these ops, and adds it to the LLVMGPU pass list.

Resolves #18142.

Signed-off-by: Andrea Faulds <[email protected]>
  • Loading branch information
andfau-amd authored Sep 23, 2024
1 parent 0d9c5a8 commit c0909a4
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 36 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ iree_compiler_cc_library(
name = "CommonGPUPasses",
srcs = [
"AMDGPUDistributeContract.cpp",
"ExpandGPUOps.cpp",
"GPUApplyTilingLevel.cpp",
"GPUCheckResourceUsage.cpp",
"GPUCombineValueBarriers.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ iree_cc_library(
"Passes.h"
SRCS
"AMDGPUDistributeContract.cpp"
"ExpandGPUOps.cpp"
"GPUApplyTilingLevel.cpp"
"GPUCheckResourceUsage.cpp"
"GPUCombineValueBarriers.cpp"
Expand Down
48 changes: 48 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/ExpandGPUOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-expand-gpu-ops"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_EXPANDGPUOPSPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

namespace {

/// Pass that expands high-level GPU ops, such as clustered
/// gpu.subgroup_reduce, by greedily applying the upstream MLIR GPU rewrite
/// patterns to the function it runs on.
///
/// The pass fails (after emitting an op error on the function) when no GPU
/// subgroup size can be determined for the function, and also when the greedy
/// pattern driver reports failure.
struct ExpandGPUOpsPass final : impl::ExpandGPUOpsPassBase<ExpandGPUOpsPass> {
  void runOnOperation() override {
    FunctionOpInterface funcOp = getOperation();
    MLIRContext *ctx = &getContext();

    // The shuffle lowering below is parameterized on the subgroup size, so
    // there is nothing sensible to do without one.
    std::optional<int> subgroupSize = getGPUSubgroupSize(funcOp);
    if (!subgroupSize) {
      funcOp->emitOpError("missing subgroup size");
      return signalPassFailure();
    }

    // Single source of truth for the bitwidth handed to both pattern sets, so
    // the break-down and the shuffle lowering cannot drift apart.
    constexpr unsigned kShuffleBitwidth = 32;

    RewritePatternSet patterns(ctx);
    // Registered with the higher benefit (2) so that, where both pattern sets
    // match, breaking down is preferred over lowering to shuffles (1).
    populateGpuBreakDownSubgroupReducePatterns(
        patterns, /*maxShuffleBitwidth=*/kShuffleBitwidth, PatternBenefit(2));
    populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
        patterns, *subgroupSize, /*shuffleBitwidth=*/kShuffleBitwidth,
        PatternBenefit(1));
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

} // namespace

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
Expand Down Expand Up @@ -512,10 +513,17 @@ struct DistributeMultiReduction final
Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);

// Reduce across all reduction dimensions 1-by-1.
for (unsigned i = 0; i < reductionMask.size(); ++i) {
for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
if (reductionMask[i]) {
extracted = doPackedThreadReductionOnDim(rewriter, layout, extracted,
kind, i);
int64_t offset = getShuffleOffset(layout, i);
int64_t width = getShuffleWidth(layout, i);
assert(offset <= std::numeric_limits<uint32_t>::max() &&
width <= std::numeric_limits<uint32_t>::max());

extracted = rewriter.create<gpu::SubgroupReduceOp>(
loc, extracted, combiningKindToAllReduce(kind),
/*uniform=*/false, /*cluster_size=*/width,
/*cluster_stride=*/offset);
}
}

Expand All @@ -525,25 +533,6 @@ struct DistributeMultiReduction final
return res;
}

/// Performs a shuffle-based reduction of `val` for dimension `dim` of
/// `layout`, combining values with the reduction `kind`.
///
/// The shuffle offset and width for `dim` are queried from `layout`. Each loop
/// step emits one XOR-mode gpu.shuffle of the running value and folds the
/// shuffled result back into it via makeArithReduction. Returns the value
/// produced by the last step, or `val` unchanged if the loop does not execute.
Value doPackedThreadReductionOnDim(RewriterBase &rewriter,
NestedLayoutAttr layout, Value val,
vector::CombiningKind kind,
int64_t dim) const {
Location loc = val.getLoc();
int64_t offset = getShuffleOffset(layout, dim);
int64_t width = getShuffleWidth(layout, dim);

// The shuffle operand `i` starts at `offset` and doubles each step, i.e. it
// takes the values offset, 2*offset, 4*offset, ... while below offset * width.
// NOTE(review): `i` is an `int` while `offset`/`width` are `int64_t`; this
// assumes the values fit in an int -- confirm.
for (int i = offset; i < offset * width; i <<= 1) {
auto shuffleOp = rewriter.create<gpu::ShuffleOp>(
loc, val, i, subgroupSize, gpu::ShuffleMode::XOR);
// Combine the value received from the shuffle with the current running value.
val =
makeArithReduction(rewriter, loc, kind, shuffleOp.getShuffleResult(),
val, nullptr, nullptr);
}

return val;
}

int64_t subgroupSize;
int64_t maxBitsPerShuffle;
};
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,12 @@ def VectorReductionToGPUPass :
];
}

// Pass definition for `iree-codegen-expand-gpu-ops`: an interface pass that
// runs on operations implementing mlir::FunctionOpInterface and expands
// high-level GPU ops such as clustered gpu.subgroup_reduce. The GPU dialect is
// declared as a dependent dialect.
def ExpandGPUOpsPass :
InterfacePass<"iree-codegen-expand-gpu-ops", "mlir::FunctionOpInterface"> {
let summary = "Expands high-level GPU ops, such as clustered gpu.subgroup_reduce.";
let dependentDialects = [
"::mlir::gpu::GPUDialect"
];
}

#endif // IREE_CODEGEN_COMMON_GPU_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -963,19 +963,13 @@ builtin.module attributes { transform.with_named_sequence } {
}

// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
// Local reduction
// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
// Global reduction
// CHECK: gpu.shuffle xor %{{.*}}, %[[C16]], %[[C64]] : f32
// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32
// CHECK: gpu.shuffle xor %{{.*}}, %[[C16]], %[[C64]] : f32
// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32
// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
// Accumulator reduction
// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
Expand Down Expand Up @@ -1012,11 +1006,9 @@ builtin.module attributes { transform.with_named_sequence } {
}

// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
// Local reduction
// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
// Global reduction
// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32
// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
// Accumulator reduction
// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,

FunctionLikeNest funcPassManager(modulePassManager);
funcPassManager.addPass(createFoldTensorExtractOpPass)
.addPass(createLLVMGPUVectorLoweringPass);
.addPass(createLLVMGPUVectorLoweringPass)
.addPass(createExpandGPUOpsPass);

// This pass needs to run before SCF -> CF.
addLowerAndOptimizeAddressComputationPasses(funcPassManager);
Expand Down
5 changes: 2 additions & 3 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,8 @@ Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
return identity;
}

/// Return a matching GPU reduction operations.
static gpu::AllReduceOperation
combiningKindToAllReduce(vector::CombiningKind kind) {
/// Returns the matching GPU reduction operation.
gpu::AllReduceOperation combiningKindToAllReduce(vector::CombiningKind kind) {
switch (kind) {
#define MAP_CASE(X) \
case vector::CombiningKind::X: \
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
Expand Down Expand Up @@ -156,6 +157,11 @@ Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput,
/// Emit identity constant based on combiningKind and type.
Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
vector::CombiningKind kind, Type identityType);

/// Returns the matching GPU reduction operation.
mlir::gpu::AllReduceOperation
combiningKindToAllReduce(vector::CombiningKind kind);

//===----------------------------------------------------------------------===//
// GPU CodeGen op filter
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit c0909a4

Please sign in to comment.