[gpu] Use clustered gpu.subgroup_reduce for nested layout distribution
MLIR now supports expressing a subgroup reduction that operates on
several "clusters" in parallel, so it is no longer necessary to build a
series of shuffles by hand.

It has been verified that, at least with the upstream lowering patterns,
the resulting sequence of shuffles is identical to what the old code
produced.
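
As a sketch of the equivalence (cluster values taken from the first test
in this commit, on a 64-lane subgroup; SSA names are illustrative, not
from the code):

  %red = gpu.subgroup_reduce maximumf %val cluster(size = 4, stride = 16) : (f32) -> f32

expands, under the upstream lowering patterns, to the xor-shuffle ladder
the old code emitted directly:

  %c16 = arith.constant 16 : i32
  %c32 = arith.constant 32 : i32
  %c64 = arith.constant 64 : i32
  // First rung: exchange with the lane 16 away, then combine.
  %s0, %p0 = gpu.shuffle xor %val, %c16, %c64 : f32
  %m0 = arith.maximumf %val, %s0 : f32
  // Second rung: exchange with the lane 32 away, then combine.
  %s1, %p1 = gpu.shuffle xor %m0, %c32, %c64 : f32
  %red = arith.maximumf %m0, %s1 : f32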

Resolves #18142.

Signed-off-by: Andrea Faulds <[email protected]>
andfau-amd committed Sep 13, 2024
1 parent febe0ed commit 21c8d62
Showing 4 changed files with 20 additions and 34 deletions.
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
@@ -513,8 +514,15 @@ struct DistributeMultiReduction final
     // Reduce across all reduction dimensions 1-by-1.
     for (unsigned i = 0; i < reductionMask.size(); ++i) {
       if (reductionMask[i]) {
-        extracted = doPackedThreadReductionOnDim(rewriter, layout, extracted,
-                                                 kind, i);
+        int64_t offset = getShuffleOffset(layout, i);
+        int64_t width = getShuffleWidth(layout, i);
+        if (offset > UINT32_MAX || width > UINT32_MAX)
+          return failure();
+
+        extracted = rewriter.create<gpu::SubgroupReduceOp>(
+            loc, extracted, combiningKindToAllReduce(kind),
+            /*uniform=*/false, /*cluster_size=*/width,
+            /*cluster_stride=*/offset);
       }
     }
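
(Note on the new guard: the comparison against UINT32_MAX reflects that
the op's cluster_size and cluster_stride are 32-bit values, so the
pattern bails out rather than emit a cluster shape that cannot be
encoded.)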

@@ -524,25 +532,6 @@ struct DistributeMultiReduction final
     return res;
   }

-  Value doPackedThreadReductionOnDim(RewriterBase &rewriter,
-                                     NestedLayoutAttr layout, Value val,
-                                     vector::CombiningKind kind,
-                                     int64_t dim) const {
-    Location loc = val.getLoc();
-    int64_t offset = getShuffleOffset(layout, dim);
-    int64_t width = getShuffleWidth(layout, dim);
-
-    for (int i = offset; i < offset * width; i <<= 1) {
-      auto shuffleOp = rewriter.create<gpu::ShuffleOp>(
-          loc, val, i, subgroupSize, gpu::ShuffleMode::XOR);
-      val =
-          makeArithReduction(rewriter, loc, kind, shuffleOp.getShuffleResult(),
-                             val, nullptr, nullptr);
-    }
-
-    return val;
-  }
-
   int64_t subgroupSize;
   int64_t maxBitsPerShuffle;
 };
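
For the parameters exercised by the first test below (offset = 16,
width = 4, subgroupSize = 64), this deleted loop ran exactly twice:

  // i = 16: gpu.shuffle xor by 16 across 64 lanes, then combine
  // i = 32: gpu.shuffle xor by 32, then combine
  // i = 64: equals offset * width, so the loop exits

which is precisely the pair of shuffles that cluster(size = 4,
stride = 16) denotes in the new code.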
@@ -963,19 +963,13 @@ builtin.module attributes { transform.with_named_sequence } {
 }

 // CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
-// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
-// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
 // CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
 // CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
 // CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
 // Local reduction
 // CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
 // Global reduction
-// CHECK: gpu.shuffle xor %{{.*}}, %[[C16]], %[[C64]] : f32
-// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32
-// CHECK: gpu.shuffle xor %{{.*}}, %[[C16]], %[[C64]] : f32
-// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
 // Accumulator reduction
 // CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
 // CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
@@ -1012,11 +1006,9 @@
 }

 // CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
-// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : i32
 // Local reduction
 // CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
 // Global reduction
-// CHECK: gpu.shuffle xor %{{.*}}, %[[C32]], %[[C64]] : f32
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
 // Accumulator reduction
 // CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
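
Here the cluster is only two lanes at stride 32, so the upstream
lowering should reduce the op to the single xor shuffle the deleted
CHECK lines matched; a sketch with illustrative SSA names:

  %c32 = arith.constant 32 : i32
  %c64 = arith.constant 64 : i32
  %shfl, %valid = gpu.shuffle xor %v, %c32, %c64 : f32
  %max = arith.maximumf %v, %shfl : f32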
5 changes: 2 additions & 3 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -535,9 +535,8 @@ Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
   return identity;
 }

-/// Return a matching GPU reduction operations.
-static gpu::AllReduceOperation
-combiningKindToAllReduce(vector::CombiningKind kind) {
+/// Returns the matching GPU reduction operation.
+gpu::AllReduceOperation combiningKindToAllReduce(vector::CombiningKind kind) {
   switch (kind) {
 #define MAP_CASE(X) \
   case vector::CombiningKind::X: \
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -9,6 +9,7 @@

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
@@ -156,6 +157,11 @@
 /// Emit identity constant based on combiningKind and type.
 Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
                                 vector::CombiningKind kind, Type identityType);
+
+/// Returns the matching GPU reduction operation.
+mlir::gpu::AllReduceOperation
+combiningKindToAllReduce(vector::CombiningKind kind);
+
 //===----------------------------------------------------------------------===//
 // GPU CodeGen op filter
 //===----------------------------------------------------------------------===//