[mlir][linalg] Relax scalable vectorization restrictions #117991

Merged: 2 commits into llvm:main on Nov 29, 2024

Conversation

banach-space (Contributor) commented:

Currently, the Linalg vectorizer disallows non-trailing parallel
dimensions to be scalable, e.g., vector_sizes [[8], 1] (*), for cases
like:

```mlir
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
```

This restriction exists to avoid generating "scalable" arrays of
aggregates, which LLVM does not support (multi-dim vectors are lowered
into arrays of aggregates at the LLVM level).

This patch relaxes that restriction when the trailing parallel vector
dimension is 1, e.g., for vector_sizes [[8], 1]. Such cases are safe
since trailing unit dimensions can be collapsed. This relaxation is
necessary to support scalable vectorization for tensor.pack, where inner
tile sizes are [8] (scalable) and 1 (scalar).

(*) Transform Dialect notation
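
To illustrate the "trailing unit dimensions can be collapsed" point, here is a hypothetical snippet (not part of the patch, and assuming `vector.shape_cast` accepts this scalable cast) showing the unit dim being dropped before lowering:

```mlir
// Hypothetical example: the trailing unit dim of a scalable 2-D vector can
// be collapsed, so no "scalable array of aggregates" ever reaches the LLVM
// level.
%v = vector.broadcast %f : f32 to vector<[8]x1xf32>
%collapsed = vector.shape_cast %v : vector<[8]x1xf32> to vector<[8]xf32>
```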

llvmbot (Member) commented Nov 28, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Full diff: https://github.com/llvm/llvm-project/pull/117991.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+18-9)
  • (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+11-6)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 06bb6c0fb1cac9..f3fffbef67dc71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2022,26 +2022,35 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
   // it matches one of the supported cases:
-  //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
-  //  2. exactly 2 dims are scalable and those are the _last two adjacent_
-  //     parallel dims
-  //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
+  //  1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
+  //    (*).
+  //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
+  //     parallel dims.
+  //  3. Exactly 1 reduction dim is scalable and that's the last (innermost) dim.
   // The 2nd restriction above means that only Matmul-like Ops are supported
   // when 2 dims are scalable, e.g. :
   //    * iterators = [parallel, parallel, reduction]
   //    * scalable flags = [true, true, false]
+  //
+  // (*) Non-unit dims get folded away in practice.
+  // TODO: Relax these conditions as good motivating examples are identified.
 
-  // Find the first scalable flag
-  bool seenParalell = false;
+  // Find the first scalable flag, and ...
+  bool seenNonUnitParallel = false;
   auto iterators = linalgOp.getIteratorTypesArray();
   SmallVector<bool> scalableFlags(inputScalableVecDims);
-  while (!scalableFlags.back()) {
-    seenParalell |= (iterators.back() == utils::IteratorType::parallel);
+  int64_t idx = scalableFlags.size() - 1;
+  while (!scalableFlags[idx]) {
+    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
+    seenNonUnitParallel |=
+        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
 
     iterators.pop_back();
     scalableFlags.pop_back();
+    idx--;
   }
 
+  // ... analyze the corresponding iterator.
   switch (iterators.back()) {
   case utils::IteratorType::reduction: {
     // Check 3. above is met.
@@ -2059,7 +2068,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   }
   case utils::IteratorType::parallel: {
     // Check 1. and 2. above are met.
-    if (seenParalell) {
+    if (seenNonUnitParallel) {
       LDBG("Inner parallel dim not requested for scalable "
            "vectorization\n");
       return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 68bac72a1465d0..227829238a3d79 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -122,22 +122,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
+// NOTE: Often, non-trailing scalable sizes are problematic - there are no
+// "scalable" arrays of vectors at the LLVM level (multi-dim vectors are
+// decomposed into arrays of aggregates). However, the trailing dim in this
+// case is 1 and that can be folded away later.
+
+func.func @vectorize_dynamic_fill_leading_scalable(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
   %0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_dynamic_fill
+// CHECK-LABEL: func.func @vectorize_dynamic_fill_leading_scalable
 //   CHECK: %[[DIM0:.*]] = tensor.dim
 //   CHECK: %[[DIM1:.*]] = tensor.dim
-//   CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1>
-//   CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32>
-//   CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1>
+//   CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x1xi1>
+//   CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<[8]x1xf32>
+//   CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<[8]x1xf32>, tensor<?x?xf32> } : vector<[8]x1xi1>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, [16]] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[8], 1] : !transform.any_op
     transform.yield
   }
 }
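
For readers skimming the diff, the updated precondition scan can be modeled with a short Python sketch. This is a hypothetical re-implementation for illustration, not the C++ code itself: the helper name is an assumption, and the real `vectorizeScalableVectorPrecondition` performs additional checks (e.g. on the number of scalable dims and the reduction case) that are simplified here.

```python
def scalable_precondition_ok(iterators, vector_sizes, scalable_flags):
    """Simplified model of the relaxed scalable-vectorization precondition.

    iterators:      per-dim iterator kinds, "parallel" or "reduction"
    vector_sizes:   per-dim vector sizes
    scalable_flags: per-dim scalability (True = scalable, i.e. [[n]])
    """
    if not any(scalable_flags):
        return True  # Nothing scalable, nothing to check.

    idx = len(scalable_flags) - 1
    seen_non_unit_parallel = False
    # Walk backwards over trailing non-scalable dims. After the patch, only
    # *non-unit* trailing parallel dims disqualify the configuration; unit
    # dims (size 1) are fine because they can be collapsed later.
    while not scalable_flags[idx]:
        if iterators[idx] == "parallel" and vector_sizes[idx] != 1:
            seen_non_unit_parallel = True
        idx -= 1

    if iterators[idx] == "reduction":
        # A scalable reduction dim is only supported as the innermost dim.
        return idx == len(iterators) - 1
    # Parallel scalable dim: reject only if a non-unit parallel dim trails it.
    return not seen_non_unit_parallel

# vector_sizes [[8], 1] on linalg.fill: now accepted.
print(scalable_precondition_ok(["parallel", "parallel"], [8, 1], [True, False]))
# vector_sizes [[8], 4]: trailing non-unit parallel dim, still rejected.
print(scalable_precondition_ok(["parallel", "parallel"], [8, 4], [True, False]))
```

This mirrors the renaming in the patch from `seenParalell` to `seenNonUnitParallel`: the loop still skips trailing non-scalable dims, but only remembers the ones that cannot be folded away.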

github-actions bot commented Nov 28, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

banach-space force-pushed the andrzej/relax_scalable_vec branch from 687cdfd to cb8fd0c on November 28, 2024 12:00
dcaballe (Contributor) left a comment:

LG. Minor comments


// Find the first scalable flag
bool seenParalell = false;
// Find the first scalable flag, and ...
Contributor:

the and ... got me. Thought something was missing :)
Needed?

Contributor Author:

Not at all, bad habit :)


iterators.pop_back();
scalableFlags.pop_back();
idx--;
Contributor:

nit: `idx--` --> `--idx`

Minor tweaks based on Diego's suggestions
banach-space merged commit aa9d368 into llvm:main on Nov 29, 2024
8 checks passed
banach-space deleted the andrzej/relax_scalable_vec branch on December 12, 2024 11:29