[mlir][linalg] Fix invalid IR in Linalg op fusion #74425
Linalg op fusion (`Linalg/Transforms/Fusion.cpp`) used to generate invalid fused producer ops:

```
error: 'linalg.conv_2d_nhwc_hwcf' op expected type of operand #2 ('tensor<1x8x16x4xf32>') to match type of corresponding result ('tensor<?x?x?x?xf32>')
note: see current operation:
%24 = "linalg.conv_2d_nhwc_hwcf"(%21, %22, %23) <{dilations = dense<1> : tensor<2xi64>, operandSegmentSizes = array<i32: 2, 1>, strides = dense<2> : tensor<2xi64>}> ({
^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
  %28 = "arith.mulf"(%arg9, %arg10) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  %29 = "arith.addf"(%arg11, %28) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  "linalg.yield"(%29) : (f32) -> ()
}) {linalg.memoized_indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]} : (tensor<1x?x?x3xf32>, tensor<3x3x3x4xf32>, tensor<1x8x16x4xf32>) -> tensor<?x?x?x?xf32>
```

This is a problem because the input IR to the greedy pattern rewriter during `-test-linalg-greedy-fusion` is invalid. This commit fixes tests such as `mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir` when verifying the IR after each pattern application (#74270).
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Full diff: https://github.com/llvm/llvm-project/pull/74425.diff

1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 11bd886c36e53..e48188fe516d3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -144,27 +144,17 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
       b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
       /**omitPartialTileCheck=*/false));
-  // Iterate over the results in order.
-  // Extract the subtensor type from the linearized range.
-  // Since we do not enforce any canonicalizations on the fly, this is always
-  // fully dynamic at construction time.
+  // Take result types from the tiled init operands.
+  MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
   SmallVector<Type, 4> resultTypes;
   resultTypes.reserve(producer->getNumResults());
-  for (Value operand : producer.getDpsInits()) {
-    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
-    if (!tensorType)
-      continue;
-    unsigned rank = tensorType.getRank();
-    SmallVector<int64_t, 4> staticOffsetsVector(
-        rank, ShapedType::kDynamic);
-    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamic);
-    SmallVector<int64_t, 4> staticStridesVector(
-        rank, ShapedType::kDynamic);
-    resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
-        tensorType, staticOffsetsVector, staticSizesVector,
-        staticStridesVector));
+  int64_t firstInitOperandIdx =
+      static_cast<OperandRange>(producerDpsInits).getBeginOperandIndex();
+  for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
+    resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
   }
+  // Clone the producer with new operands and result types.
   LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
   // Shift all IndexOp results by the tile offset.
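
To illustrate why the change makes the verifier error go away, here is a minimal, self-contained sketch that models the rule outside of MLIR. It is a toy model, not the MLIR API: `Shape`, `kDynamic`, `fullyDynamicResultTypes`, `resultTypesFromInits`, and `verifyResultsMatchInits` are hypothetical names introduced only for this example, and a tensor type is reduced to its shape. The sketch assumes, as the error message above states, that each result type must match the type of the corresponding init operand.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Toy stand-in for a ranked tensor type: only the shape is modelled.
// kDynamic is a hypothetical sentinel for a dynamic ("?") dimension.
constexpr std::int64_t kDynamic = std::numeric_limits<std::int64_t>::min();
using Shape = std::vector<std::int64_t>;

// Models the removed code: every result gets a fully dynamic shape with the
// same rank as the corresponding init operand.
std::vector<Shape> fullyDynamicResultTypes(const std::vector<Shape> &operands,
                                           std::size_t firstInitIdx,
                                           std::size_t numResults) {
  std::vector<Shape> results;
  results.reserve(numResults);
  for (std::size_t i = 0; i < numResults; ++i)
    results.push_back(Shape(operands[firstInitIdx + i].size(), kDynamic));
  return results;
}

// Models the added code: every result takes the type of the corresponding
// (already tiled) init operand.
std::vector<Shape> resultTypesFromInits(const std::vector<Shape> &operands,
                                        std::size_t firstInitIdx,
                                        std::size_t numResults) {
  std::vector<Shape> results;
  results.reserve(numResults);
  for (std::size_t i = 0; i < numResults; ++i)
    results.push_back(operands[firstInitIdx + i]);
  return results;
}

// Models the verifier rule from the error message: each result type must be
// identical to the type of the corresponding init operand.
bool verifyResultsMatchInits(const std::vector<Shape> &operands,
                             std::size_t firstInitIdx,
                             const std::vector<Shape> &results) {
  for (std::size_t i = 0; i < results.size(); ++i)
    if (results[i] != operands[firstInitIdx + i])
      return false;
  return true;
}
```

With the operand shapes from the error message (`1x?x?x3`, `3x3x3x4`, and the init `1x8x16x4` at operand index 2), the fully dynamic variant yields `?x?x?x?`, which does not match the init operand, while the variant that copies the init operand's type passes the check.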