Skip to content


This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Flow] Enable softmax-like fusion under aggressive fusion.
Browse files Browse the repository at this point in the history
Under aggressive fusion, drop the restriction that the consumer iteration
space must have the same dimensionality as the producer iteration
space. If not handled properly, this can typically lead to large vectors,
so the change is guarded by the
`--iree-flow-enable-aggressive-fusion` flag.

Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar committed Jun 27, 2024
1 parent 36aec8a commit 10c3216
Showing 2 changed files with 113 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -636,10 +636,12 @@ isFusableWithConsumer(OpOperand &fusedOperand,
// Check if the iteration spaces of the producer and consumer are the same.
// TODO(#12664): This is an unnecessary requirement, but we need a better
// config to tile the consumer with a larger iteration space.
auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
if (producerIterationSpace.size() < consumerIterationSpace.size()) {
return false;
if (!options.aggressiveFusion) {
auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
if (producerIterationSpace.size() < consumerIterationSpace.size()) {
return false;

// Under aggressive fusion assume that the dispatches are vectorized. In which
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions))" --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s

util.func public @pack_elementwise_fusion(%arg0 : tensor<?xf32>,
%arg1 : tensor<?x?xf32>) -> tensor<?x?x8x32xf32> {
@@ -640,3 +640,109 @@ util.func public @broadcasting_dequant_op(%arg0 : tensor<?x?xi8>,
// CHECK-SAME: ins(%[[GENERIC]],
// CHECK: flow.return %[[MATMUL]]
// CHECK: return %[[RETURN]]

// -----

// Chain of linalg.generic ops in which the final consumer iterates over more
// dimensions (4-D) than the reductions feeding it (3-D):
//   %2: element-wise f16 -> f32 extension of %arg0 (3-D, all parallel).
//   %5: sum-reduction of %2 along d2.
//   %6: element-wise division of %5 by %cst_0.
//   %7: reduction along d2 of the squared difference between %2 and %6.
//   %8: 4-D all-parallel op combining the expanded %arg0 with %6, %7 and the
//       expanded %arg1 / %arg2 operands.
// The FileCheck expectations that follow this function expect %5, %6, %7 and
// %8 to be placed in a single flow.dispatch.region, with %2 left outside it.
util.func public @softmax_like_fusion(%arg0: tensor<2x4096x640xf16>,
    %arg1: tensor<640xf16>, %arg2: tensor<640xf16>) -> tensor<2x4096x640x1xf16> {
  %expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
      output_shape [2, 4096, 640, 1] : tensor<2x4096x640xf16> into tensor<2x4096x640x1xf16>
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.100000e+01 : f32
  %cst_1 = arith.constant 4.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x4096x640xf32>
  %1 = tensor.empty() : tensor<2x4096x640x1xf16>
  // Widen the f16 input to f32.
  %2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<2x4096x640xf16>) outs(%0 : tensor<2x4096x640xf32>) {
    ^bb0(%in: f16, %out: f32):
      %9 = arith.extf %in : f16 to f32
      linalg.yield %9 : f32
    } -> tensor<2x4096x640xf32>
  %3 = tensor.empty() : tensor<2x4096xf32>
  %4 = linalg.fill ins(%cst : f32)
      outs(%3 : tensor<2x4096xf32>) -> tensor<2x4096xf32>
  // First reduction: sum along the innermost dimension (d2).
  %5 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%2 : tensor<2x4096x640xf32>) outs(%4 : tensor<2x4096xf32>) {
    ^bb0(%in: f32, %out: f32):
      %9 = arith.addf %in, %out : f32
      linalg.yield %9 : f32
    } -> tensor<2x4096xf32>
  // Scale the reduced values by the constant %cst_0.
  %6 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%5 : tensor<2x4096xf32>) outs(%3 : tensor<2x4096xf32>) {
    ^bb0(%in: f32, %out: f32):
      %9 = arith.divf %in, %cst_0 : f32
      linalg.yield %9 : f32
    } -> tensor<2x4096xf32>
  // Second reduction: accumulate (%2 - %6)^2 along d2.
  %7 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%2, %6 : tensor<2x4096x640xf32>, tensor<2x4096xf32>)
      outs(%4 : tensor<2x4096xf32>) {
    ^bb0(%in: f32, %in_4: f32, %out: f32):
      %9 = arith.subf %in, %in_4 : f32
      %10 = arith.mulf %9, %9 : f32
      %11 = arith.addf %10, %out : f32
      linalg.yield %11 : f32
    } -> tensor<2x4096xf32>
  %expanded_2 = tensor.expand_shape %arg1 [[0, 1]] output_shape [640, 1]
      : tensor<640xf16> into tensor<640x1xf16>
  %expanded_3 = tensor.expand_shape %arg2 [[0, 1]] output_shape [640, 1]
      : tensor<640xf16> into tensor<640x1xf16>
  // 4-D element-wise consumer: its iteration space has one more dimension than
  // the reductions above, which broadcast into it through (d0, d1).
  %8 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
                       affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%expanded, %6, %7, %expanded_2, %expanded_3
        : tensor<2x4096x640x1xf16>, tensor<2x4096xf32>, tensor<2x4096xf32>,
          tensor<640x1xf16>, tensor<640x1xf16>)
      outs(%1 : tensor<2x4096x640x1xf16>) {
    ^bb0(%in: f16, %in_4: f32, %in_5: f32, %in_6: f16, %in_7: f16, %out: f16):
      %9 = arith.divf %in_5, %cst_0 : f32
      %10 = arith.addf %9, %cst_1 : f32
      %11 = math.rsqrt %10 : f32
      %12 = arith.extf %in : f16 to f32
      %13 = arith.subf %12, %in_4 : f32
      %14 = arith.mulf %13, %11 : f32
      %15 = arith.extf %in_6 : f16 to f32
      %16 = arith.mulf %14, %15 : f32
      %17 = arith.extf %in_7 : f16 to f32
      %18 = arith.addf %16, %17 : f32
      %19 = arith.truncf %18 : f32 to f16
      linalg.yield %19 : f16
    } -> tensor<2x4096x640x1xf16>
  util.return %8 : tensor<2x4096x640x1xf16>
}
// CHECK-LABEL: func public @softmax_like_fusion(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x4096x640xf16>
// CHECK: %[[BITEXTEND:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK: %[[RESULT:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[GENERIC1]] :
// CHECK: %[[GENERIC3:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[BITEXTEND]], %[[GENERIC2]] :
// CHECK: %[[GENERIC4:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC2]], %[[GENERIC3]]
// CHECK: flow.return %[[GENERIC4]]
// CHECK: util.return %[[RESULT]]

0 comments on commit 10c3216

Please sign in to comment.