Skip to content

Commit

Permalink
[mlir][vector] Drop inner unit dims for transfer ops on dynamic shape…
Browse files Browse the repository at this point in the history
…s. (llvm#79752)
  • Loading branch information
hanhanW authored Jan 29, 2024
1 parent 4a39d08 commit 66347e5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
29 changes: 17 additions & 12 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ class DropInnerMostUnitDimsTransferRead
return failure();

auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
if (!srcType || !srcType.hasStaticShape())
if (!srcType)
return failure();

if (!readOp.getPermutationMap().isMinorIdentity())
Expand All @@ -1260,19 +1260,21 @@ class DropInnerMostUnitDimsTransferRead
targetType.getElementType());

auto loc = readOp.getLoc();
SmallVector<OpFoldResult> sizes =
memref::getMixedSizes(rewriter, loc, readOp.getSource());
SmallVector<OpFoldResult> offsets(srcType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
MemRefType resultMemrefType =
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);

ArrayAttr inBoundsAttr =
readOp.getInBounds()
? rewriter.getArrayAttr(
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
: ArrayAttr();
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
strides);
loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
Value result = rewriter.create<vector::TransferReadOp>(
Expand Down Expand Up @@ -1318,7 +1320,7 @@ class DropInnerMostUnitDimsTransferWrite
return failure();

auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
if (!srcType || !srcType.hasStaticShape())
if (!srcType)
return failure();

if (!writeOp.getPermutationMap().isMinorIdentity())
Expand All @@ -1341,20 +1343,23 @@ class DropInnerMostUnitDimsTransferWrite
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType());

Location loc = writeOp.getLoc();
SmallVector<OpFoldResult> sizes =
memref::getMixedSizes(rewriter, loc, writeOp.getSource());
SmallVector<OpFoldResult> offsets(srcType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
MemRefType resultMemrefType =
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);
ArrayAttr inBoundsAttr =
writeOp.getInBounds()
? rewriter.getArrayAttr(
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
: ArrayAttr();

Location loc = writeOp.getLoc();
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
strides);
loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,

// -----

// Test: transfer_read from a memref whose outermost dimension is dynamic and
// whose innermost dimension is a unit dimension. The expectations below require
// that the read is rewritten into a rank-reducing memref.subview (with the
// dynamic size taken from memref.dim on the source), a transfer_read of
// vector<1x8xf32> from that subview, and a vector.shape_cast of the result
// before it is returned.
func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
return %0 : vector<1x8x1xf32>
}
// CHECK: func @contiguous_outer_dyn_inner_most_view(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
// CHECK-SAME: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
// CHECK-SAME: memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
// CHECK: return %[[RESULT]]

// -----

func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
Expand Down Expand Up @@ -119,6 +138,27 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,

// -----

// Test: transfer_write into a memref whose outermost dimension is dynamic and
// whose innermost dimension is a unit dimension. The expectations below require
// a rank-reducing memref.subview of the destination (with the dynamic size
// taken from memref.dim on the destination), a vector.shape_cast of the written
// vector to vector<1x16x16xf32>, and a transfer_write into the subview using
// the original leading indices.
func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
{in_bounds = [true, true, true, true]}
: vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
return
}
// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// The dynamic size must be queried from this test's own destination argument,
// so the pattern variable bound above (DEST) is used here.
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[DEST]], %[[C0]]
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
// CHECK-SAME: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
// CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]]

// -----

func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
Expand Down

0 comments on commit 66347e5

Please sign in to comment.