[mlir][linalg] Relax scalable vectorization restrictions #117991
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

Currently, the Linalg vectorizer disallows non-trailing parallel dimensions to be scalable, e.g., `vector_sizes [[8], 1]` (*), for cases like:

```mlir
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
```

This restriction exists to avoid generating "scalable" arrays of aggregates, which LLVM does not support (multi-dim vectors are lowered into arrays of aggregates at the LLVM level).

This patch relaxes that restriction when the trailing parallel vector dimension is `1`, e.g., for `vector_sizes [[8], 1]`. Such cases are safe since trailing unit dimensions can be collapsed. This relaxation is necessary to support scalable vectorization for tensor.pack, where inner tile sizes are `[8]` (scalable) and `1` (scalar).

(*) Transform Dialect notation

Full diff: https://github.com/llvm/llvm-project/pull/117991.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 06bb6c0fb1cac9..f3fffbef67dc71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2022,26 +2022,35 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
// it matches one of the supported cases:
- // 1. exactly 1 dim is scalable and that's the _last_ parallel dim
- // 2. exactly 2 dims are scalable and those are the _last two adjacent_
- // parallel dims
- // 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
+ // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
+ // (*).
+ // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
+ // parallel dims.
+ // 3. Exactly 1 reduction dim is scalable and that's the last (innermost) dim.
// The 2nd restriction above means that only Matmul-like Ops are supported
// when 2 dims are scalable, e.g. :
// * iterators = [parallel, parallel, reduction]
// * scalable flags = [true, true, false]
+ //
+ // (*) Non-unit dims get folded away in practice.
+ // TODO: Relax these conditions as good motivating examples are identified.
- // Find the first scalable flag
- bool seenParalell = false;
+ // Find the first scalable flag, and ...
+ bool seenNonUnitParallel = false;
auto iterators = linalgOp.getIteratorTypesArray();
SmallVector<bool> scalableFlags(inputScalableVecDims);
- while (!scalableFlags.back()) {
- seenParalell |= (iterators.back() == utils::IteratorType::parallel);
+ int64_t idx = scalableFlags.size() - 1;
+ while (!scalableFlags[idx]) {
+ bool isNonUnitDim = (inputVectorSizes[idx] != 1);
+ seenNonUnitParallel |=
+ (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
iterators.pop_back();
scalableFlags.pop_back();
+ idx--;
}
+ // ... analyze the corresponding iterator.
switch (iterators.back()) {
case utils::IteratorType::reduction: {
// Check 3. above is met.
@@ -2059,7 +2068,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
}
case utils::IteratorType::parallel: {
// Check 1. and 2. above are met.
- if (seenParalell) {
+ if (seenNonUnitParallel) {
LDBG("Inner parallel dim not requested for scalable "
"vectorization\n");
return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 68bac72a1465d0..227829238a3d79 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -122,22 +122,27 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
+// NOTE: Often, non-trailing scalable sizes are problematic - there are no
+// "scalable" arrays of vectors at the LLVM level (multi-dim vectors are
+// decomposed into arrays of aggregates). However, the trailing dim in this
+// case is 1 and that can be folded away later.
+
+func.func @vectorize_dynamic_fill_leading_scalable(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: func.func @vectorize_dynamic_fill
+// CHECK-LABEL: func.func @vectorize_dynamic_fill_leading_scalable
// CHECK: %[[DIM0:.*]] = tensor.dim
// CHECK: %[[DIM1:.*]] = tensor.dim
-// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1>
-// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32>
-// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x1xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<[8]x1xf32>
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<[8]x1xf32>, tensor<?x?xf32> } : vector<[8]x1xi1>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [8, [16]] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [[8], 1] : !transform.any_op
transform.yield
}
}
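For readers skimming the diff, the heart of the change is the backward scan over the scalable flags. Below is a minimal, self-contained C++ sketch of that logic under simplified assumptions; the function name `isSupportedScalableConfig`, the plain enum, and the std containers are illustrative stand-ins, not the actual MLIR API, and the op-specific checks from the real precondition are omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class IteratorType { parallel, reduction };

// Walk the dims from the back until the first scalable one, remembering
// whether any *non-unit* parallel dim was skipped. Skipped unit parallel
// dims are harmless because they can be folded away later.
bool isSupportedScalableConfig(std::vector<IteratorType> iterators,
                               std::vector<bool> scalableFlags,
                               const std::vector<int64_t> &vectorSizes) {
  assert(iterators.size() == scalableFlags.size() &&
         scalableFlags.size() == vectorSizes.size());

  bool seenNonUnitParallel = false;
  int64_t idx = static_cast<int64_t>(scalableFlags.size()) - 1;
  while (idx >= 0 && !scalableFlags[idx]) {
    bool isNonUnitDim = vectorSizes[idx] != 1;
    seenNonUnitParallel |=
        (iterators[idx] == IteratorType::parallel && isNonUnitDim);
    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }
  if (idx < 0)
    return false; // no scalable dim was requested at all

  // Case 3: a scalable reduction dim is only accepted if it is the innermost
  // dim (nothing was popped). The real code also restricts which ops qualify.
  if (iterators.back() == IteratorType::reduction)
    return idx == static_cast<int64_t>(vectorSizes.size()) - 1;

  // Cases 1 and 2: a scalable parallel dim is accepted as long as no
  // *non-unit* parallel dim was skipped after it; trailing unit dims are OK.
  return !seenNonUnitParallel;
}
```

As an illustration, iterators = {parallel, parallel}, scalable flags = {true, false}, and vector sizes = {8, 1} now pass the check, which corresponds to the new `vector_sizes [[8], 1]` test above; with sizes {8, 16} the skipped trailing dim is non-unit, so the check still fails as before.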
✅ With the latest revision this PR passed the C/C++ code formatter.
687cdfd to cb8fd0c (compare)
LG. Minor comments
// Find the first scalable flag
bool seenParalell = false;
// Find the first scalable flag, and ...
The `and ...` got me. Thought something was missing :) Needed?
Not at all, bad habit :)
iterators.pop_back();
scalableFlags.pop_back();
idx--; |
nit: `idx--` --> `--idx`
Minor tweaks based on Diego's suggestions