[LLVMGPU] Prefer non-scatter ops as the configuration root (#19581)
When a scatter is fused with another operation, we generally prefer to let
the producer determine the lowering config rather than the scatter.
qedawkins authored Jan 2, 2025
1 parent fa325c5 commit c203e6b
Showing 2 changed files with 42 additions and 4 deletions.
19 changes: 15 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -2506,10 +2506,11 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
 
   Operation *rootOperation = nullptr;
 
-  // Find the root operation. linalg.generic and linalg.fill are not root
-  // operations if there are other compute operations present.
+  // Find the root operation. linalg.generic, linalg.fill, and scatter are not
+  // root operations if there are other compute operations present.
   for (Operation *op : llvm::reverse(computeOps)) {
-    if (!isa<linalg::GenericOp, linalg::FillOp>(op)) {
+    if (!isa<linalg::GenericOp, linalg::FillOp, IREE::LinalgExt::ScatterOp>(
+            op)) {
       rootOperation = op;
       break;
     }
@@ -2522,9 +2523,19 @@
     }
   }
 
+  // Generic ops take priority over scatter and fill ops as the root op.
   if (!rootOperation) {
     for (Operation *op : llvm::reverse(computeOps)) {
-      if (isa<linalg::GenericOp, linalg::FillOp>(op)) {
+      if (isa<linalg::GenericOp>(op)) {
+        rootOperation = op;
+        break;
+      }
+    }
+  }
+
+  if (!rootOperation) {
+    for (Operation *op : llvm::reverse(computeOps)) {
+      if (isa<IREE::LinalgExt::ScatterOp, linalg::FillOp>(op)) {
         rootOperation = op;
         break;
       }
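
Taken together, the two hunks give root selection a three-tier priority: any compute op other than generic/fill/scatter wins first, then generic ops, and scatter or fill only as a last resort. Below is a minimal standalone sketch of that ordering; OpKind, selectRoot, and the op categories are hypothetical stand-ins for the real MLIR/IREE types, not the actual KernelConfig.cpp API.

#include <cstdio>
#include <vector>

// Illustrative op categories; in IREE these are isa<> checks on real op types.
enum class OpKind { ContractionOp, GenericOp, FillOp, ScatterOp };

// Walk the compute ops in reverse in three passes of decreasing priority,
// mirroring the loops in initGPULaunchConfig above.
const OpKind *selectRoot(const std::vector<OpKind> &computeOps) {
  // Pass 1: prefer any op that is not generic/fill/scatter.
  for (auto it = computeOps.rbegin(); it != computeOps.rend(); ++it)
    if (*it != OpKind::GenericOp && *it != OpKind::FillOp &&
        *it != OpKind::ScatterOp)
      return &*it;
  // Pass 2: generic ops take priority over scatter and fill.
  for (auto it = computeOps.rbegin(); it != computeOps.rend(); ++it)
    if (*it == OpKind::GenericOp)
      return &*it;
  // Pass 3: fall back to scatter or fill.
  for (auto it = computeOps.rbegin(); it != computeOps.rend(); ++it)
    if (*it == OpKind::ScatterOp || *it == OpKind::FillOp)
      return &*it;
  return nullptr;
}

int main() {
  // An elementwise generic fused with a scatter: pass 2 picks the generic,
  // so the scatter never determines the lowering config.
  std::vector<OpKind> ops = {OpKind::GenericOp, OpKind::ScatterOp};
  const OpKind *root = selectRoot(ops);
  std::printf("root is the generic op: %s\n",
              (root && *root == OpKind::GenericOp) ? "yes" : "no");
}

The second changed file (the remaining 27 additions) adds a test exercising the new priority on an elementwise add fused with a scatter: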
@@ -307,3 +307,30 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5
 // CHECK-SAME:   reduction = [0, 0, 0, 1]
 // CHECK-SAME:   subgroup = [0, 1, 8, 0]
 // CHECK-SAME:   workgroup = [1, 16, 512, 0]
+
+// -----
+
+func.func @elementwise_scatter(%arg0: tensor<3x2048x2048xf32>,
+                               %arg1: tensor<3x2048x2048xf32>,
+                               %arg2: tensor<3x1xi32>) -> tensor<3x2048x2048xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<3x2048x2048xf32>
+  %1 = linalg.add ins(%arg0, %arg1 : tensor<3x2048x2048xf32>, tensor<3x2048x2048xf32>)
+      outs(%0 : tensor<3x2048x2048xf32>) -> tensor<3x2048x2048xf32>
+  %2 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+      ins(%1, %arg2 : tensor<3x2048x2048xf32>, tensor<3x1xi32>) outs(%0 : tensor<3x2048x2048xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32):
+    iree_linalg_ext.yield %arg3 : f32
+  } -> tensor<3x2048x2048xf32>
+  return %2 : tensor<3x2048x2048xf32>
+}
+
+// CHECK-LABEL: func.func @elementwise_scatter
+// CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+// CHECK: linalg.add {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME:   thread = [1, 1, 4]
+// CHECK-SAME:   workgroup = [1, 1, 256]
+
+// Verify that the scatter does not get a lowering config.
+// CHECK: linalg_ext.scatter dimension_map
