From 66347e516e22f9159b86024071fb92f364ac4418 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Mon, 29 Jan 2024 00:30:19 -0800 Subject: [PATCH] [mlir][vector] Drop inner unit dims for transfer ops on dynamic shapes. (#79752) --- .../Vector/Transforms/VectorTransforms.cpp | 29 ++++++++------ ...tor-transfer-collapse-inner-most-dims.mlir | 40 +++++++++++++++++++ 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 12aa11e9e33f5e7..8363e73857e5c54 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1236,7 +1236,7 @@ class DropInnerMostUnitDimsTransferRead return failure(); auto srcType = dyn_cast(readOp.getSource().getType()); - if (!srcType || !srcType.hasStaticShape()) + if (!srcType) return failure(); if (!readOp.getPermutationMap().isMinorIdentity()) @@ -1260,19 +1260,21 @@ class DropInnerMostUnitDimsTransferRead targetType.getElementType()); auto loc = readOp.getLoc(); + SmallVector sizes = + memref::getMixedSizes(rewriter, loc, readOp.getSource()); + SmallVector offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), + rewriter.getIndexAttr(1)); MemRefType resultMemrefType = getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop); - SmallVector offsets(srcType.getRank(), 0); - SmallVector strides(srcType.getRank(), 1); - ArrayAttr inBoundsAttr = readOp.getInBounds() ? 
rewriter.getArrayAttr( readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) : ArrayAttr(); Value rankedReducedView = rewriter.create( - loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), - strides); + loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); Value result = rewriter.create( @@ -1318,7 +1320,7 @@ class DropInnerMostUnitDimsTransferWrite return failure(); auto srcType = dyn_cast(writeOp.getSource().getType()); - if (!srcType || !srcType.hasStaticShape()) + if (!srcType) return failure(); if (!writeOp.getPermutationMap().isMinorIdentity()) @@ -1341,20 +1343,23 @@ class DropInnerMostUnitDimsTransferWrite VectorType::get(targetType.getShape().drop_back(dimsToDrop), targetType.getElementType()); + Location loc = writeOp.getLoc(); + SmallVector sizes = + memref::getMixedSizes(rewriter, loc, writeOp.getSource()); + SmallVector offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), + rewriter.getIndexAttr(1)); MemRefType resultMemrefType = getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop); - SmallVector offsets(srcType.getRank(), 0); - SmallVector strides(srcType.getRank(), 1); ArrayAttr inBoundsAttr = writeOp.getInBounds() ? 
rewriter.getArrayAttr( writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) : ArrayAttr(); - Location loc = writeOp.getLoc(); Value rankedReducedView = rewriter.create( - loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(), - strides); + loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index d6d69c8af88508d..3984f17f9e8cdb7 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8, // ----- +func.func @contiguous_outer_dyn_inner_most_view(%in: memref>) -> vector<1x8x1xf32>{ + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref>, vector<1x8x1xf32> + return %0 : vector<1x8x1xf32> +} +// CHECK: func @contiguous_outer_dyn_inner_most_view( +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]] +// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1] +// CHECK-SAME: memref> to memref> +// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]] +// CHECK-SAME: memref>, vector<1x8xf32> +// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]] +// CHECK: return %[[RESULT]] + +// ----- + func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) { %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 @@ -119,6 +138,27 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: 
memref<1x512x16x1xf32, // ----- +func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { + %c0 = arith.constant 0 : index + vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0] + {in_bounds = [true, true, true, true]} + : vector<1x16x16x1xf32>, memref> + return +} +// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[DEST]], %[[C0]] +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1] +// CHECK-SAME: memref> to memref> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32> +// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]] +// CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]] + +// ----- + func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]