From 1bca666ab33c7728a6a5535ebc64e6174441a7db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Fri, 9 Aug 2024 18:04:25 +0100 Subject: [PATCH] [mlir][vector][test] Split tests from vector-transfer-flatten.mlir (#102584) Move tests that exercise DropUnitDimFromElementwiseOps and DropUnitDimsFromTransposeOp to a dedicated file. While these patterns are collected under populateFlattenVectorTransferPatterns (and are tested via -test-vector-transfer-flatten-patterns), they can actually be tested without the xfer Ops, and hence the split. Note, this is mostly just moving tests from one file to another. The only real change is the removal of the following check-lines: ```mlir // CHECK-128B-NOT: memref.collapse_shape ``` These were added specifically to check the "flattening" logic (which introduces `memref.collapse_shape`). However, these tests were never meant to test that logic (in fact, that's the reason I am moving them to a different file) and hence are being removed as copy&paste errors. I also removed the following TODO: ```mlir /// TODO: Potential duplication with tests from: /// * "vector-dropleadunitdim-transforms.mlir" /// * "vector-transfer-drop-unit-dims-patterns.mlir" ``` I've checked what patterns are triggered in those test files and neither DropUnitDimFromElementwiseOps nor DropUnitDimsFromTransposeOp does. --- .../drop-unit-dims-with-shape-cast.mlir | 209 ++++++++++++++++ .../Vector/vector-transfer-flatten.mlir | 236 ------------------ 2 files changed, 209 insertions(+), 236 deletions(-) create mode 100644 mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir new file mode 100644 index 000000000000000..af3fc924c1dbe76 --- /dev/null +++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir @@ -0,0 +1,209 @@ +// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s + +///---------------------------------------------------------------------------------------- +/// [Pattern: DropUnitDimFromElementwiseOps] +///---------------------------------------------------------------------------------------- + +func.func @fold_unit_dim_add_basic(%vec : vector<1x8xi32>) -> vector<1x8xi32> { + %res = arith.addi %vec, %vec : vector<1x8xi32> + return %res : vector<1x8xi32> +} +// CHECK-LABEL: func.func @fold_unit_dim_add_basic( +// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi32>) -> vector<1x8xi32> { +// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32> +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32> +// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32> +// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32> +// CHECK: return %[[VAL_4]] : vector<1x8xi32> + +// ----- + +func.func @fold_unit_dim_add_leading_and_trailing(%vec : vector<1x8x1xi32>) -> vector<1x8x1xi32> { + %res = arith.addi %vec, %vec : vector<1x8x1xi32> + return %res : vector<1x8x1xi32> +} +// CHECK-LABEL: func.func @fold_unit_dim_add_leading_and_trailing( +// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8x1xi32>) -> vector<1x8x1xi32> { +// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32> +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32> +// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32> +// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32> +// CHECK: return %[[VAL_4]] : vector<1x8x1xi32> + +// ----- + +func.func @fold_unit_dim_add(%vec_0 : vector<8x1xi32>, + %vec_1 : vector<1x8xi32>) -> vector<8xi32> { + %sc_vec_0 = vector.shape_cast %vec_0 : vector<8x1xi32> to vector<1x8xi32> + %add = arith.addi %sc_vec_0, %vec_1 : vector<1x8xi32> + %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32> + return %res : vector<8xi32> +} + +// CHECK-LABEL: func.func @fold_unit_dim_add( +// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> { +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32> +// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32> +// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32> +// CHECK: return %[[VAL_4]] : vector<8xi32> + +// ----- + +func.func @fold_unit_dim_mulf(%vec_0 : vector<8x[2]x1xf32>, + %vec_1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> { + %sc_vec_0 = vector.shape_cast %vec_0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32> + %add = arith.mulf %sc_vec_0, %vec_1 : vector<1x8x[2]xf32> + %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32> + return %res : vector<8x[2]xf32> +} + +// CHECK-LABEL: func.func @fold_unit_dim_mulf( +// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> { +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32> +// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32> +// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32> +// CHECK: return %[[VAL_4]] : vector<8x[2]xf32> + +// ----- + +func.func @fold_unit_dim_sitofp(%vec : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> { + %sc_vec_0 = vector.shape_cast %vec : vector<8x[2]x1xi8> to vector<1x8x[2]xi8> + %add = arith.sitofp %sc_vec_0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32> + %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32> + return %res : vector<8x[2]xf32> +} + +// CHECK-LABEL: func.func @fold_unit_dim_sitofp( +// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> { +// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8> +// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32> +// CHECK: return %[[VAL_2]] : vector<8x[2]xf32> + +// ----- + +// All shape casts are folded away + +func.func @fold_unit_dims_entirely(%vec_0 : vector<8xi32>, + %vec_1 : vector<8xi32>, + %vec_2 : vector<8xi32>) -> vector<8xi32> { + %sc_vec_0 = vector.shape_cast %vec_0 : vector<8xi32> to vector<1x8xi32> + %sc_vec_1 = vector.shape_cast %vec_1 : vector<8xi32> to vector<1x8xi32> + %sc_vec_2 = vector.shape_cast %vec_2 : vector<8xi32> to vector<1x8xi32> + %mul = arith.muli %sc_vec_0, %sc_vec_1 : vector<1x8xi32> + %add = arith.addi %mul, %sc_vec_2 : vector<1x8xi32> + %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32> + return %res : vector<8xi32> +} + +// CHECK-LABEL: func.func @fold_unit_dims_entirely( +// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi32>, %[[VAL_1:.*]]: vector<8xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> { +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32> +// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32> +// CHECK: return %[[VAL_4]] : vector<8xi32> + +// ----- + +func.func @fold_inner_unit_dim(%vec_0 : vector<8x1x3xf128>, + %vec_1 : vector<1x8x3xf128>) -> vector<8x3xf128> { + %sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x3xf128> to vector<8x1x3xf128> + %mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x3xf128> + %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128> + return %res : vector<8x3xf128> +} + +// CHECK-LABEL: func.func @fold_inner_unit_dim( +// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> { +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128> +// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128> +// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128> +// CHECK: return %[[VAL_4]] : vector<8x3xf128> + +// ----- + +func.func @fold_inner_unit_dim_scalable(%vec_0 : vector<8x1x[1]x3xf128>, + %vec_1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> { + %sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128> + %mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x[1]x3xf128> + %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128> + return %res : vector<8x[1]x3xf128> +} + +// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable( +// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> { +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128> +// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128> +// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128> +// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128> + +// ----- + +func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> { + %0 = arith.mulf %vec, %vec : vector<1x1xf32> + %res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32> + return %res : vector<1xf32> +} + +// CHECK-LABEL: func.func @fold_all_unit_dims( +// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32> +// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32> +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32> +// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32> +// CHECK: return %[[VAL_3]] : vector<1xf32> + +///---------------------------------------------------------------------------------------- +/// [Pattern: DropUnitDimsFromTransposeOp] +///---------------------------------------------------------------------------------------- + +func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> { + %res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> + return %res : vector<[4]x1x1x4xf32> +} + +// CHECK-LABEL: func.func @transpose_with_internal_unit_dims( +// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>) +// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> +// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> +// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32> + +// ----- + +func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32> +{ + %res = vector.transpose %vec, [4, 1, 3, 2, 0] : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32> + return %res: vector<1x1x4x2x[1]xf32> +} + +// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims( +// CHECK-SAME: %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>) +// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32> +// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32> +// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32> + +// ----- + +func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> { + %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32> + return %res : vector<1x1x1xf32> +} +// The `vec` is returned because there are other flattening patterns that fold +// vector.shape_cast ops away. +// CHECK-LABEL: func.func @transpose_with_all_unit_dims +// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]] +// CHECK-NEXT: return %[[VEC]] + +// ----- + +func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> { + %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> + return %res : vector<4x3x2xf32> +} + +// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims +// CHECK-NOT: vector.shape_cast diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 85afdf7a7dc7718..e840dc6bbf224c7 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -538,186 +538,6 @@ func.func @transfer_write_non_contiguous_src( // ----- -///---------------------------------------------------------------------------------------- -/// [Pattern: DropUnitDimFromElementwiseOps] -/// -/// TODO: Move to a dedicated file - there's no "flattening" in the following tests -/// TODO: Potential duplication with tests from: -/// * "vector-dropleadunitdim-transforms.mlir" -/// * "vector-transfer-drop-unit-dims-patterns.mlir" -///---------------------------------------------------------------------------------------- - -func.func @fold_unit_dim_add_basic(%vec : vector<1x8xi32>) -> vector<1x8xi32> { - %res = arith.addi %vec, %vec : vector<1x8xi32> - return %res : vector<1x8xi32> -} -// CHECK-LABEL: func.func @fold_unit_dim_add_basic( -// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi32>) -> vector<1x8xi32> { -// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32> -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32> -// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32> -// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32> -// CHECK: return %[[VAL_4]] : vector<1x8xi32> - -// CHECK-128B-LABEL: func @fold_unit_dim_add_basic( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @fold_unit_dim_add_leading_and_trailing(%vec : vector<1x8x1xi32>) -> vector<1x8x1xi32> { - %res = arith.addi %vec, %vec : vector<1x8x1xi32> - return %res : vector<1x8x1xi32> -} -// CHECK-LABEL: func.func @fold_unit_dim_add_leading_and_trailing( -// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8x1xi32>) -> vector<1x8x1xi32> { -// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32> -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32> -// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32> -// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32> -// CHECK: return %[[VAL_4]] : vector<1x8x1xi32> - -// CHECK-128B-LABEL: func @fold_unit_dim_add_leading_and_trailing( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @fold_unit_dim_add(%vec_0 : vector<8x1xi32>, - %vec_1 : vector<1x8xi32>) -> vector<8xi32> { - %sc_vec_0 = vector.shape_cast %vec_0 : vector<8x1xi32> to vector<1x8xi32> - %add = arith.addi %sc_vec_0, %vec_1 : vector<1x8xi32> - %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32> - return %res : vector<8xi32> -} - -// CHECK-LABEL: func.func @fold_unit_dim_add( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>, -// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> { -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32> -// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32> -// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32> -// CHECK: return %[[VAL_4]] : vector<8xi32> - -// CHECK-128B-LABEL: func @fold_unit_dim_add( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @fold_unit_dim_mulf(%vec_0 : vector<8x[2]x1xf32>, - %vec_1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> { - %sc_vec_0 = vector.shape_cast %vec_0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32> - %add = arith.mulf %sc_vec_0, %vec_1 : vector<1x8x[2]xf32> - %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32> - return %res : vector<8x[2]xf32> -} - -// CHECK-LABEL: func.func @fold_unit_dim_mulf( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> { -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32> -// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32> -// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32> -// CHECK: return %[[VAL_4]] : vector<8x[2]xf32> - -// CHECK-128B-LABEL: func @fold_unit_dim_mulf( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @fold_unit_dim_sitofp(%vec : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> { - %sc_vec_0 = vector.shape_cast %vec : vector<8x[2]x1xi8> to vector<1x8x[2]xi8> - %add = arith.sitofp %sc_vec_0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32> - %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32> - return %res : vector<8x[2]xf32> -} - -// CHECK-LABEL: func.func @fold_unit_dim_sitofp( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> { -// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8> -// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32> -// CHECK: return %[[VAL_2]] : vector<8x[2]xf32> - -// CHECK-128B-LABEL: func @fold_unit_dim_sitofp( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -// All shape casts are folded away - -func.func @fold_unit_dims_entirely(%vec_0 : vector<8xi32>, - %vec_1 : vector<8xi32>, - %vec_2 : vector<8xi32>) -> vector<8xi32> { - %sc_vec_0 = vector.shape_cast %vec_0 : vector<8xi32> to vector<1x8xi32> - %sc_vec_1 = vector.shape_cast %vec_1 : vector<8xi32> to vector<1x8xi32> - %sc_vec_2 = vector.shape_cast %vec_2 : vector<8xi32> to vector<1x8xi32> - %mul = arith.muli %sc_vec_0, %sc_vec_1 : vector<1x8xi32> - %add = arith.addi %mul, %sc_vec_2 : vector<1x8xi32> - %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32> - return %res : vector<8xi32> -} - -// CHECK-LABEL: func.func @fold_unit_dims_entirely( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi32>, %[[VAL_1:.*]]: vector<8xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> { -// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32> -// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32> -// CHECK: return %[[VAL_4]] : vector<8xi32> - -// CHECK-128B-LABEL: func @fold_unit_dims_entirely( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @fold_inner_unit_dim(%vec_0 : vector<8x1x3xf128>, - %vec_1 : vector<1x8x3xf128>) -> vector<8x3xf128> { - %sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x3xf128> to vector<8x1x3xf128> - %mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x3xf128> - %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128> - return %res : vector<8x3xf128> -} - -// CHECK-LABEL: func.func @fold_inner_unit_dim( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>, -// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> { -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128> -// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128> -// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128> -// CHECK: return %[[VAL_4]] : vector<8x3xf128> - -// ----- - -func.func @fold_inner_unit_dim_scalable(%vec_0 : vector<8x1x[1]x3xf128>, - %vec_1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> { - %sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128> - %mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x[1]x3xf128> - %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128> - return %res : vector<8x[1]x3xf128> -} - -// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable( -// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>, -// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> { -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128> -// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128> -// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128> -// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128> - -// ----- - -func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> { - %0 = arith.mulf %vec, %vec : vector<1x1xf32> - %res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32> - return %res : vector<1xf32> -} - -// CHECK-LABEL: func.func @fold_all_unit_dims( -// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32> -// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32> -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32> -// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32> -// CHECK: return %[[VAL_3]] : vector<1xf32> - -// ----- - func.func @negative_out_of_bound_transfer_read( %mem : memref>) -> vector<5x4x3x2xi8> { %c0 = arith.constant 0 : index @@ -740,59 +560,3 @@ func.func @negative_out_of_bound_transfer_write( } // CHECK: func.func @negative_out_of_bound_transfer_write // CHECK-NOT: memref.collapse_shape - -// ----- - -///---------------------------------------------------------------------------------------- -/// [Pattern: DropUnitDimsFromTransposeOp] -/// TODO: Move to a dedicated file - there's no "flattening" in the following tests -///---------------------------------------------------------------------------------------- - -func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> { - %res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> - return %res : vector<[4]x1x1x4xf32> -} - -// CHECK-LABEL: func.func @transpose_with_internal_unit_dims( -// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>) -// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> -// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> -// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> -// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32> - -// ----- - -func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32> -{ - %res = vector.transpose %vec, [4, 1, 3, 2, 0] : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32> - return %res: vector<1x1x4x2x[1]xf32> -} - -// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims( -// CHECK-SAME: %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>) -// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32> -// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32> -// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32> -// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32> - -// ----- - -func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> { - %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32> - return %res : vector<1x1x1xf32> -} -// The `vec` is returned because there are other flattening patterns that fold -// vector.shape_cast ops away. -// CHECK-LABEL: func.func @transpose_with_all_unit_dims -// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]] -// CHECK-NEXT: return %[[VEC]] - -// ----- - -func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> { - %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> - return %res : vector<4x3x2xf32> -} - -// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims -// CHECK-NOT: vector.shape_cast