Skip to content

Commit

Permalink
Add support for converting rank-reducing subtensor ops to Flow dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhanW authored Apr 19, 2021
1 parent 03f455b commit 840f4c3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
26 changes: 23 additions & 3 deletions iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,29 @@ struct SubTensorToTensorSlice : OpRewritePattern<SubTensorOp> {
Value source = subTensorOp.source();
// Sizes of the *source* tensor: same as the slice sizes except the outermost
// dimension, which is the source's own (possibly dynamic) extent.
SmallVector<Value, 4> sourceSizesVals = sizesVals;
// createOrFold collapses the DimOp to a constant when dim 0 is static.
sourceSizesVals[0] = rewriter.createOrFold<memref::DimOp>(loc, source, 0);
// NOTE(review): the next three lines are the PRE-change side of this diff
// (the scrape dropped the '-' markers, so removed and added code appear
// concatenated); they are superseded by the block that follows.
rewriter.replaceOpWithNewOp<TensorSliceOp>(
subTensorOp, subTensorOp.getType(), subTensorOp.source(),
sourceSizesVals, offsetVals, sizesVals, sizesVals);

// Different from SubTensor op, a TensorSliceOp does not have
// rank-reducing behavior.
// Infer the full-rank (non-rank-reduced) result type so the slice can be
// created with it; if it differs from subTensorOp's declared type, the
// subtensor was rank-reducing and a reshape must restore the reduced rank.
Type type = SubTensorOp::inferResultType(subTensorOp.getSourceType(),
offsets, sizes, strides);
Value tensorSliceOp = rewriter.create<TensorSliceOp>(
loc, type, subTensorOp.source(), sourceSizesVals, offsetVals, sizesVals,
sizesVals);

if (type == subTensorOp.getType()) {
// Not rank-reducing subtensor, can replace with it directly.
rewriter.replaceOp(subTensorOp, tensorSliceOp);
} else {
// Rank-reducing subtensor, need a reshape op.
// NOTE(review): resultDynSizes is declared but never used — dead local.
SmallVector<Value, 4> sourceDynSizes, resultDynSizes;
auto sourceType = tensorSliceOp.getType().cast<RankedTensorType>();
for (auto i : llvm::seq<unsigned>(0, sourceType.getNumDynamicDims())) {
// NOTE(review): getDynamicDimIndex(i) returns the *position* of the
// i-th dynamic dimension, not its runtime extent; materializing it as a
// ConstantIndexOp looks wrong — a memref::DimOp on tensorSliceOp at
// that index would yield the actual size. The accompanying test never
// exercises this (its slice type 1x513 is fully static, so this loop
// body never runs) — confirm against a dynamically-shaped case.
sourceDynSizes.push_back(rewriter.create<ConstantIndexOp>(
loc, sourceType.getDynamicDimIndex(i)));
}
rewriter.replaceOpWithNewOp<TensorReshapeOp>(
subTensorOp, subTensorOp.getType(), tensorSliceOp, sourceDynSizes);
}
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,24 @@ func @subtensor_convert(%arg0 : tensor<?x24x48xf32>, %arg1 : index) ->
// CHECK-DAG: %[[UNMODIFIED2:.+]] = subtensor %[[SLICE2]][0, 0, 0] [%[[D1]], 12, 24] [1, 2, 2]
// CHECK-DAG: %[[UNMODIFIED3:.+]] = subtensor %[[ARG0]][0, %[[ARG1]], 0]
// CHECK: return %[[UNMODIFIED1]], %[[UNMODIFIED2]], %[[UNMODIFIED3]]

// -----

// Checks that a rank-reducing subtensor (tensor<2x513xi32> -> tensor<513xi32>)
// converts to a non-rank-reducing flow.tensor.slice producing the full-rank
// tensor<1x513xi32>, followed by a flow.tensor.reshape that drops the unit
// dimension to recover the declared tensor<513xi32> result type.
func @rank_reducing_subtensor(%arg0: tensor<2x513xi32>, %arg1: index,
%arg2: index) -> tensor<513xi32> {
%0 = subtensor %arg0[%arg1, %arg2] [1, 513] [1, 1] : tensor<2x513xi32> to tensor<513xi32>
return %0 : tensor<513xi32>
}
// CHECK-LABEL: func @rank_reducing_subtensor
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C513:.+]] = constant 513 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]]
// CHECK-SAME: [%[[ARG1]], %[[ARG2]] for %[[C1]], %[[C513]]]
// CHECK-SAME: : tensor<2x513xi32>{%[[C2]], %[[C513]]}
// CHECK-SAME: -> tensor<1x513xi32>{%[[C1]], %[[C513]]}
// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %[[SLICE]] : tensor<1x513xi32> -> tensor<513xi32>
// CHECK: return %[[RESHAPE]] : tensor<513xi32>

0 comments on commit 840f4c3

Please sign in to comment.