Increase default threshold of TileLargeTensor pass (#19671)
The current threshold of 64 causes a performance regression in the TileAndFuse
pipeline for GEMMs, because it generates code that cannot be optimized by the
prefetch shared memory pass. With a threshold of 256 we should not hit this
issue.

Signed-off-by: Nirvedh Meshram <[email protected]>
nirvedhmeshram authored Jan 13, 2025
1 parent 2452b22 commit 3978ce6
Showing 3 changed files with 18 additions and 19 deletions.
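For context, here is a minimal sketch of what the new default means in practice, mirroring the updated lit test below; the function is illustrative and not taken from the repository. The pass tiles operations whose static size exceeds max-vector-size so that all remaining ops are smaller.

// Illustrative input: a 64x512 elementwise add (64 x 512 elements, well above the threshold).
func.func @example_add(%a: tensor<64x512xf32>, %b: tensor<64x512xf32>, %init: tensor<64x512xf32>) -> tensor<64x512xf32> {
  %0 = linalg.add ins(%a, %b : tensor<64x512xf32>, tensor<64x512xf32>)
                  outs(%init : tensor<64x512xf32>) -> tensor<64x512xf32>
  return %0 : tensor<64x512xf32>
}
// With the new default of 256, the pass wraps this in scf.for loops whose body
// operates on tensor<1x256xf32> slices (per the updated CHECK lines below);
// under the old default of 64 the same op was cut into tensor<1x64xf32> slices.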
compiler/src/iree/compiler/Codegen/Common/Passes.td (2 changes: 1 addition & 1 deletion)
@@ -654,7 +654,7 @@ def TileLargeTensorsPass :
   ];
   let options = [
     Option<"maxVectorSize", "max-vector-size", "int64_t",
-           /*default=*/"64",
+           /*default=*/"256",
            "Maximum static size to tile to (i.e. all remaining ops will be smaller)">,
   ];
 }
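Downstream users who need a different limit can still set the option explicitly. A minimal sketch follows; the textual pass name iree-codegen-tile-large-tensors is an assumption (only the max-vector-size option key is visible in this diff), so treat the command as illustrative:

// Hypothetical override of the new default via MLIR's textual pipeline syntax;
// the pass name is assumed, the max-vector-size key comes from the Option above.
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-tile-large-tensors{max-vector-size=128}))" %s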
@@ -3,22 +3,22 @@
 // RUN: FileCheck %s
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> {
+func.func @simple_generic(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>, %5: tensor<64x512xf32>) -> tensor<64x512xf32> {
   %6 = linalg.generic {
     indexing_maps = [#map, #map, #map],
     iterator_types = ["parallel", "parallel"]
-  } ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) {
+  } ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%5 : tensor<64x512xf32>) {
   ^bb0(%in: f32, %in_0: f32, %out: f32):
     %7 = arith.addf %in, %in_0 : f32
     linalg.yield %7 : f32
-  } -> tensor<64x256xf32>
-  return %6 : tensor<64x256xf32>
+  } -> tensor<64x512xf32>
+  return %6 : tensor<64x512xf32>
 }
 
 // CHECK-LABEL: func.func @simple_generic
 // CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1
-// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64
-// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>)
+// CHECK: scf.for %{{.*}} = %c0 to %c512 step %c256
+// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x256xf32>)
 
 // -----
 
@@ -65,21 +65,21 @@ func.func @in_nested_region(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: te
 
 // -----
 
-func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>) {
-  %add_empty = tensor.empty() : tensor<64x256xf32>
+func.func @multiple_use_tilable_op(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>) -> (tensor<64x512xf32>, tensor<512x64xf32>) {
+  %add_empty = tensor.empty() : tensor<64x512xf32>
   %6 = linalg.add
-    ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>)
-    outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32>
-  %transpose_empty = tensor.empty() : tensor<256x64xf32>
+    ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>)
+    outs(%add_empty : tensor<64x512xf32>) -> tensor<64x512xf32>
+  %transpose_empty = tensor.empty() : tensor<512x64xf32>
   %7 = linalg.transpose
-    ins(%6 : tensor<64x256xf32>)
-    outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0]
-  return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32>
+    ins(%6 : tensor<64x512xf32>)
+    outs(%transpose_empty : tensor<512x64xf32>) permutation = [1, 0]
+  return %6, %7 : tensor<64x512xf32>, tensor<512x64xf32>
 }
 
 // CHECK-LABEL: func.func @multiple_use_tilable_op
 // CHECK: %[[ADD_TILING:.+]] = scf.for
-// CHECK: linalg.add {{.*}} -> tensor<1x64xf32>
+// CHECK: linalg.add {{.*}} -> tensor<1x256xf32>
 // CHECK: %[[T_TILING:.+]] = scf.for
 // CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32>
 // CHECK: linalg.transpose ins(%[[FUSED_ADD]]
@@ -1013,9 +1013,8 @@ hal.executable public @main {
 // CHECK: scf.yield %[[REDUCE]]
 
 // CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1
-// CHECK: scf.for
-// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32>
-// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>
+// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32>
+// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type<storage_buffer>>
 
 // -----
 
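For reference, the updated CHECK lines above are consistent with the new threshold: the trailing elementwise computation covers 9 x 9 = 81 elements, which is above the old limit of 64 (so it was previously tiled and vectorized row by row as vector<9xf32>) but below the new limit of 256, so it is presumably left untiled and vectorized as a single vector<9x9xf32>, with the final vector.transfer_write storing a vector<9x9xi8> into the i8 output buffer.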
