[Flow] Improve slice raising to handle dynamic and unit dims #14845

Closed · 3 commits
@@ -130,6 +130,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:ValueBoundsOpInterface",
],
)

@@ -111,6 +111,7 @@ iree_cc_library(
MLIRTransformDialectTransforms
MLIRTransformUtils
MLIRTransforms
MLIRValueBoundsOpInterface
iree::compiler::Dialect::Flow::Conversion::TensorToFlow
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
@@ -11,16 +11,23 @@
#include "iree-dialects/Transforms/TransformMatchers.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using transform_ext::StructuredOpMatcher;

#define DEBUG_TYPE "iree-raise-special-ops"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir {
namespace iree_compiler {
namespace IREE {
@@ -145,14 +152,12 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) {
if (!llvm::hasSingleElement(extractOps)) {
return failure();
}
LDBG("Attempting to raise extracting generic to elementwise: " << linalgOp);

tensor::ExtractOp extractOp = *extractOps.begin();
Value source = extractOp.getTensor();
Value result = linalgOp.getResult(0);

// Raise the tensor.extract op to an input.
SmallVector<AffineExpr> exprs;
Expand All @@ -167,29 +172,31 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) {
// Restrict to cases where the constant is 0. This is because handling
// constants other than 0 in the indexing map may cause problems in the
// lowering pipeline later.
if (constantIndex.getLimitedValue() != 0) {
LDBG(" non-zero constant index -> FAIL");
return failure();
}
exprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
continue;
}
// 2. The indexing value is a linalg.index.
if (auto indexOp = indexValue.getDefiningOp<linalg::IndexOp>()) {
// Make sure that for this index, the sizes of the input and output
// match. We need this to keep the op elementwise.
// TODO: This restriction can be relaxed by adding an extract_slice op
// on the `source` tensor. This is not the same as raising the whole
// operation to an extract_slice, as there can be permutations and
// projections involved.
FailureOr<bool> dimsEqual = ValueBoundsConstraintSet::areEqual(
source, result, idx, indexOp.getDim());
if (failed(dimsEqual) || !*dimsEqual) {
LDBG(" Dimension sizes at index " << idx << " and "
<< indexOp.getDim() << " are not provably equal -> FAIL");
return failure();
}
exprs.push_back(
getAffineDimExpr(indexOp.getDim(), rewriter.getContext()));
continue;
}
LDBG(" Dimension size at index "
<< idx << " not indexed by linalg.index op -> FAIL");
return failure();
}
AffineMap indexingMap = AffineMap::get(
@@ -232,6 +239,8 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) {
linalgOp.getIteratorTypesAttr(), linalgOp.getDocAttr(),
linalgOp.getLibraryCallAttr(), bodyBuilder);

LDBG(" Successfully raised to elementwise linalg: " << newLinalgOp);

return newLinalgOp;
}
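
The switch from a static shape-equality check to ValueBoundsConstraintSet::areEqual means this raising can now also fire when the source and result dimensions are dynamic but provably equal. A minimal sketch of such an input, assuming the raising behaves as the code above suggests (the function name and shapes are illustrative, not from this PR's tests):

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @dynamic_extract_to_input(%A : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// The output is allocated from the source's sizes, so the value-bounds
// analysis can prove each pair of dimensions equal despite being dynamic.
%d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %A, %c1 : tensor<?x?xf32>
%empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%empty : tensor<?x?xf32>) {
^bb0(%out: f32):
%i0 = linalg.index 0 : index
%i1 = linalg.index 1 : index
%extracted = tensor.extract %A[%i0, %i1] : tensor<?x?xf32>
linalg.yield %extracted : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}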

@@ -242,19 +251,27 @@ static FailureOr<tensor::ExtractSliceOp>
tryRaiseToExtractSlice(AffineMap inputIndexingMap, AffineMap outputIndexingMap,
Value input, Value output, linalg::GenericOp linalgOp,
RewriterBase &rewriter) {
// Output rank cannot exceed input rank.
if (outputIndexingMap.getNumResults() > inputIndexingMap.getNumResults()) {
LDBG(" Output rank exceeds input rank, not a slice -> FAIL");
return failure();
}
// Output map should be identity.
if (!outputIndexingMap.isIdentity()) {
LDBG(" Output map not identity -> FAIL");
return failure();
}
// All iterator types must be parallel.
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) {
LDBG(" Has reduction iterators -> FAIL");
return failure();
}

auto outType = dyn_cast<RankedTensorType>(output.getType());
if (!outType) {
return failure();
}

ArrayRef<int64_t> outShape = outType.getShape();

// Try to match each output dimension to an input dimension, in order.
@@ -266,42 +283,49 @@ tryRaiseToExtractSlice(AffineMap inputIndexingMap, AffineMap outputIndexingMap,
IntegerAttr zero = rewriter.getI64IntegerAttr(0);
IntegerAttr one = rewriter.getI64IntegerAttr(1);
unsigned currOutDim = 0;
unsigned leadOutDim = 0;
for (auto [idx, expr] : llvm::enumerate(inputIndexingMap.getResults())) {
// Constant accesses can either be rank reducing or an access into a unit
// dim. This is tracked by counting the number of unit output dimensions
// between non-unit ones.
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
IntegerAttr constIdx = rewriter.getI64IntegerAttr(constExpr.getValue());
offsets.push_back(constIdx);
sizes.push_back(one);
if (leadOutDim < outShape.size() && outShape[leadOutDim] == 1) {
++leadOutDim;
}
continue;
}

Review comment (Contributor):
This is actually limiting the cases covered? The constant access does not have to be for dimensions that are unit dims. You could have a constant access where the source of the unit dim could be anything. Then the constant (or any SSA value, for that matter, that is not related to the Linalg op loops) becomes the offset, and the size is 1.

I'd like to take a bit of a further look into this.

Reply (Contributor Author):
I think the comment here is a bit misleading. It does have to be either a unit dim or rank reducing on the output, though, due to the constraints of a linalg op. If it is not one of those, then it is a broadcast and thus can't be raised, but I think that should be captured properly. That said, I might not have thought it through properly and will come back to it later. I'm also starting to think that this intermediate linalg approach is nice for the cases that work, but there are certain classes of slices (exactly when we have an arbitrary computation of an index/size for slicing) that warrant a direct lowering. I'd like to see real-world use cases for those before adding raisings for them, though.

I'll add a test case for broadcast to make sure it isn't erroneously raised, though.
// Check if the input dimension matches the current output dimension.
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
int dimPos = dimExpr.getPosition();
if (dimPos >= currOutDim && dimPos <= leadOutDim) {
offsets.push_back(zero);
// Get the dim size from the output tensor.
sizes.push_back(
tensor::getMixedSize(rewriter, linalgOp.getLoc(), output, dimPos));
currOutDim = dimPos + 1;
leadOutDim = currOutDim;
continue;
}
}
// Unknown access, fail.
LDBG(" Unknown access type along index " << idx << " -> FAIL");
return failure();
}

// Fail if not all output dimensions matched an input dimension.
if (currOutDim != outputIndexingMap.getNumResults()) {
LDBG(" Not all output dimensions match an input dimension -> FAIL");
return failure();
}

// We only support dim exprs or constant exprs in the input map, so strides
// are always 1.
SmallVector<OpFoldResult> strides(inputIndexingMap.getNumResults(), one);

LDBG(" Lowering to slice -> SUCCESS");
return rewriter.create<tensor::ExtractSliceOp>(
linalgOp.getLoc(), outType, input, offsets, sizes, strides);
}
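
The review thread above asks about constant accesses that are neither unit dims nor rank reducing on the output. The closest concrete shape is a broadcast, which must not be raised to a slice. A hypothetical example of the broadcast test the author mentions (this sketch is not the PR's actual test case):

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @broadcast_not_a_slice(%A : tensor<64xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
^bb0(%out: f32):
// Only the second loop index feeds the extract, so dimension d0 of the
// output broadcasts the source rather than slicing it.
%i1 = linalg.index 1 : index
%extracted = tensor.extract %A[%i1] : tensor<64xf32>
linalg.yield %extracted : f32
} -> tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}

If the extract were raised to an input with map (d0, d1) -> (d1), the output map would have more results than the input map, so the rank check at the top of tryRaiseToExtractSlice rejects it.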
@@ -240,9 +240,8 @@ func.func @generic_fill(%arg0: tensor<?x?xf32>) -> tensor<1x1x?x?xf32> {
// -----

#map = affine_map<(d0) -> (d0)>
func.func @test_rank_reduce(%A : tensor<1x1x5120xf32>, %B : tensor<5120xf32>) -> tensor<5120xf32> {
%c0 = arith.constant 0 : index
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%B : tensor<5120xf32>) {
^bb0(%out: f32):
%12 = linalg.index 0 : index
@@ -252,35 +251,84 @@ func.func @test(%A : tensor<1x1x5120xf32>, %B : tensor<5120xf32>) -> tensor<5120
return %0 : tensor<5120xf32>
}

// CHECK-LABEL: func @test_rank_reduce
// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [1, 1, 5120] [1, 1, 1]
// CHECK-SAME: tensor<1x1x5120xf32> to tensor<5120xf32>
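
Walking @test_rank_reduce through the matcher above gives an illustrative trace (hand-derived, not pass output). The raised input map is (d0) -> (0, 0, d0) against the output tensor<5120xf32>:

// expr 0 (const 0): outShape[leadOutDim = 0] is 5120, not 1, so leadOutDim
//                   stays 0; rank-reducing access: offset 0, size 1.
// expr 1 (const 0): same, offset 0, size 1.
// expr 2 (d0):      dimPos 0 lies in [currOutDim = 0, leadOutDim = 0], so
//                   offset 0, size 5120; currOutDim = leadOutDim = 1.
// All output dims matched, yielding the [0, 0, 0] [1, 1, 5120] [1, 1, 1]
// slice checked above.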

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @test_slice_middle(%A : tensor<64x64x64xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
%c0 = arith.constant 0 : index
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
^bb0(%out: f32):
%i1 = linalg.index 0 : index
%i2 = linalg.index 1 : index
%extracted = tensor.extract %A[%i1, %c0, %i2] : tensor<64x64x64xf32>
linalg.yield %extracted : f32
} -> tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}

// CHECK-LABEL: func @test_slice_middle
// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [64, 1, 64] [1, 1, 1]
// CHECK-SAME: tensor<64x64x64xf32> to tensor<64x64xf32>

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @test_unit_identity_slice(%A : tensor<1x1x64xf32>, %B : tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
%c0 = arith.constant 0 : index
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%B : tensor<1x1x64xf32>) {
^bb0(%out: f32):
%i2 = linalg.index 2 : index
%extracted = tensor.extract %A[%c0, %c0, %i2] : tensor<1x1x64xf32>
linalg.yield %extracted : f32
} -> tensor<1x1x64xf32>
return %0 : tensor<1x1x64xf32>
}

// CHECK-LABEL: func @test_unit_identity_slice
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x64xf32>,
// CHECK: return %[[ARG0]]
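
The unit-dim path takes a different route through the same loop; a hand-derived trace for @test_unit_identity_slice, assuming the algorithm above. The raised input map is (d0, d1, d2) -> (0, 0, d2) against the output tensor<1x1x64xf32>:

// expr 0 (const 0): outShape[leadOutDim = 0] is 1, so this is a unit
//                   access: offset 0, size 1; leadOutDim advances to 1.
// expr 1 (const 0): outShape[1] is 1, so offset 0, size 1; leadOutDim = 2.
// expr 2 (d2):      dimPos 2 lies in [currOutDim = 0, leadOutDim = 2], so
//                   offset 0, size 64; currOutDim = leadOutDim = 3.
// The resulting identity extract_slice [0, 0, 0] [1, 1, 64] [1, 1, 1] folds
// away, which is why the test checks for a plain return of the argument.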

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @test_dynamic_slice(%A : tensor<1x128x?xf32>) -> tensor<128x?xf32> {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %A, %c2 : tensor<1x128x?xf32>
%empty = tensor.empty(%dim) : tensor<128x?xf32>
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%empty : tensor<128x?xf32>) {
^bb0(%out: f32):
%i0 = linalg.index 0 : index
%i1 = linalg.index 1 : index
%extracted = tensor.extract %A[%c0, %i0, %i1] : tensor<1x128x?xf32>
linalg.yield %extracted : f32
} -> tensor<128x?xf32>
return %0 : tensor<128x?xf32>
}

// CHECK-LABEL: func @test_dynamic_slice
// CHECK: %[[DIM:.+]] = tensor.dim {{.*}} : tensor<1x128x?xf32>
// CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [1, 128, %[[DIM]]] [1, 1, 1]
// CHECK-SAME: tensor<1x128x?xf32> to tensor<128x?xf32>

// -----

// This currently should not be raised as the operation does not remain
// elementwise after raising the tensor.extract to input.
#map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @test_non_slice
func.func @test_non_slice(%A : tensor<128x128x128xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
%c0 = arith.constant 0 : index
// CHECK: linalg.generic
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
^bb0(%out: f32):
%i1 = linalg.index 0 : index
%i2 = linalg.index 1 : index
%extracted = tensor.extract %A[%i1, %c0, %i2] : tensor<128x128x128xf32>
linalg.yield %extracted : f32
} -> tensor<64x64xf32>
return %0 : tensor<64x64xf32>