[mlir][vector] Add tests for scalable vectors #67806

Closed. banach-space wants to merge 1 commit.

banach-space (Contributor) commented Sep 29, 2023

Adds tests for scalable vectors in:

  • vector-contract-to-outerproduct-transforms.mlir

Every existing test is duplicated, with fixed-width vectors replaced by
scalable vectors. One test required a fix in LowerVectorContract.cpp.
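
To illustrate why the one-line fix matters (the diff below swaps `VectorType::get(vecType.getShape(), promotedType)` for `vecType.clone(promotedType)`), here is a minimal, self-contained sketch. `ToyVectorType` is a hypothetical stand-in and not the MLIR class; it only assumes that building a type from a shape alone leaves every dimension fixed-width, whereas cloning keeps the scalable flags.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a vector type (NOT mlir::VectorType): a shape,
// per-dimension "scalable" flags, and an element type name.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  std::string elementType;

  // Models building a type from shape + element type only: every
  // dimension defaults to fixed-width, so scalability is lost.
  static ToyVectorType get(std::vector<int64_t> shape,
                           std::string elementType) {
    std::vector<bool> fixed(shape.size(), false);
    return ToyVectorType{std::move(shape), std::move(fixed),
                         std::move(elementType)};
  }

  // Models cloning with a new element type: shape and scalable flags
  // are carried over unchanged.
  ToyVectorType clone(std::string newElementType) const {
    return ToyVectorType{shape, scalableDims, std::move(newElementType)};
  }

  // Prints in MLIR-like syntax, e.g. vector<[2]xf32>.
  std::string str() const {
    std::string out = "vector<";
    for (std::size_t i = 0; i < shape.size(); ++i) {
      std::string dim = std::to_string(shape[i]);
      out += scalableDims[i] ? "[" + dim + "]" : dim;
      out += "x";
    }
    return out + elementType + ">";
  }
};
```

In this model, promoting a `vector<[2]xf16>` operand via `get` yields a fixed-width `vector<2xf32>`, while `clone` yields `vector<[2]xf32>`, which is the type the mixed-precision scalable test expects for its `arith.extf` results.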

Note that this patch intentionally refactors the whole test file so that:

  • tests for fixed-width and scalable vectors are very similar (which makes them easier to compare),
  • auto-generated variable names (e.g. VAL_1) are replaced with something meaningful (e.g. MASK),
  • unused regex variables are no longer captured (e.g. %{{.*}}: instead of %[[VAL_1:.*]]),
  • types in key operations (e.g. vector.outerproduct) are captured, since that's the only way to verify fixed-width vs scalable lowering.

This change is part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:

  • https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Fixes #67804

llvmbot (Member) commented Sep 29, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Patch is 38.84 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67806.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+1-1)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+452-43)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..6e63d52d22a1f6b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
       return v;
     Type promotedType = dstElementType;
     if (vecType)
-      promotedType = VectorType::get(vecType.getShape(), promotedType);
+      promotedType = vecType.clone(promotedType);
     if (isa<FloatType>(dstElementType))
       return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..3746897bcd864f6 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -34,16 +34,16 @@
 // CHECK-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
 // CHECK-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
 // CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
 // CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
                                     %arg1: vector<3xf32>,
@@ -54,22 +54,46 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
+
+// CHECK-LABEL:   func.func @masked_extract_contract2_scalable(
+// CHECK-SAME:      %{{.*}}: vector<[2]x[3]xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[3]xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x[3]xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x[3]xi1> to vector<[3]x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK:         %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK:         %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
 func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
                                     %arg1: vector<5x7xf32>,
@@ -80,10 +104,36 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
   return %0 : vector<3x7xf32>
 }
 
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable(
+// CHECK-SAME:    %{{.*}}: vector<[3]x[5]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<[5]x[7]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<[3]x[7]xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+// CHECK:         %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<[3]x[7]x[5]xi1> to vector<[5]x[3]x[7]xi1>
+// CHECK:         %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+
+func.func @masked_extract_contract4_scalable(%arg0: vector<[3]x[5]xf32>,
+                                    %arg1: vector<[5]x[7]xf32>,
+                                    %arg2: vector<[3]x[7]xf32>,
+                                    %m : vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<[3]x[5]xf32>, vector<[5]x[7]xf32> into vector<[3]x[7]xf32> } : vector<[3]x[7]x[5]xi1> -> vector<[3]x[7]xf32>
+  return %0 : vector<[3]x[7]xf32>
+}
+
 // CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<4x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 // CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
 //
@@ -116,6 +166,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[4]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[4]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME:  : vector<[2]x[4]xf32> to vector<[4]x[2]xf32>
+//
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: return %[[c3]] : vector<[2]x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<[2]x[4]xf32>,
+                          %arg1: vector<[4]x[3]xf32>,
+                          %arg2: vector<[2]x[3]xf32>) -> vector<[2]x[3]xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<[2]x[4]xf32>, vector<[4]x[3]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 // CHECK-LABEL: func @matmul_0
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -133,6 +219,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
 // CHECK-LABEL: func @matmul_0_mixed
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -152,6 +255,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf16>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf16>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf16> from vector<[1]x[2]xf16>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<[1]x[3]xf16>
+//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<[2]xf16> to vector<[2]xf32>
+//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<[2]x[1]xf16>, %arg1: vector<[1]x[3]xf16>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf16>, vector<[1]x[3]xf16> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_1 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -163,9 +285,9 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
 }
 
 // CHECK-LABEL: func @matmul_1
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
@@ -180,6 +302,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[3]x[1]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[3]x[1]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[3]x[1]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_2 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -191,9 +331,9 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_2
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
@@ -206,6 +346,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[1]x[2]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<[1]x[2]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<[1]x[2]xf32>, vector<[1]x[3]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_3 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -217,9 +373,9 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_3
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
@@ -233,6 +389,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-L...
[truncated]

@@ -34,16 +34,16 @@
 // CHECK-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
 // CHECK-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
Collaborator commented:

Unrelated change, please remove.


+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable(%arg0: vector<[2]x[3]xf32>,
Collaborator commented:

Is this unrolling scalable dimensions? I don't understand how unrolling by a quantity that is unknown at compile time can be correct.

Member commented:

I think you're right here: this looks like it's only computing the fixed part of the output. You can sometimes unroll scalable things by extracting/inserting scalable subvectors (e.g. via vector.scalable.insert/extract), but I don't think that'd work for a 2D vector like this. I think those vector.extracts should be in a vscale-bounded loop.
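
The concern above can be sketched with plain scalar C++ (a simplified model, not MLIR; `dotFull`, `dotUnrolledByBase` and `baseK` are names made up for this illustration). If the reduction dimension has runtime length `vscale * baseK`, unrolling `baseK` steps at compile time covers only the first `baseK` terms, so the result is complete only when `vscale` happens to be 1.

```cpp
#include <cstddef>
#include <vector>

// Reference reduction over the full runtime length. For a scalable
// reduction dimension this length is vscale * baseK and is only known
// at run time.
int dotFull(const std::vector<int>& a, const std::vector<int>& b) {
  int acc = 0;
  for (std::size_t k = 0; k < a.size(); ++k)
    acc += a[k] * b[k];
  return acc;
}

// Reduction "unrolled" baseK times, as if the dimension were fixed-width.
// Models emitting exactly baseK extract + outerproduct steps; the caller
// must pass vectors with at least baseK elements.
int dotUnrolledByBase(const std::vector<int>& a, const std::vector<int>& b,
                      std::size_t baseK) {
  int acc = 0;
  for (std::size_t k = 0; k < baseK; ++k)
    acc += a[k] * b[k];
  return acc;
}
```

With `baseK = 3` and all-ones inputs, both functions agree for a length of 3 (vscale = 1), but for a length of 6 (vscale = 2) the unrolled version accumulates only half of the terms. A loop whose trip count depends on `vscale`, as suggested above, would avoid that.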

banach-space (Contributor, Author) commented:

I think that we need to split this into separate changes:

  • scalable vectors where the reduction dimension is fixed-width (should already work),
  • scalable vectors where the reduction dimension is scalable: more work is needed (that's what's "broken" in this patch).
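
As a rough sketch of the split described above (an assumption for illustration only, not taken from this patch or any follow-up), a lowering could check whether the reduction dimension is fixed-width before unrolling it. `reductionIsFixedWidth` is a hypothetical helper; the flags are per iteration-space dimension, e.g. `(m, n, k)`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical guard: returns true when the reduction dimension is
// fixed-width, i.e. its trip count is known at compile time and it can
// be unrolled. `scalableDims[i]` is true when iteration-space dimension
// i is scalable. Out-of-range indices are rejected.
bool reductionIsFixedWidth(const std::vector<bool>& scalableDims,
                           std::size_t reductionDim) {
  return reductionDim < scalableDims.size() && !scalableDims[reductionDim];
}
```

For an iteration space `(m, n, k)` with only `m` scalable, `reductionIsFixedWidth({true, false, false}, 2)` holds and the first case applies; with all three dimensions scalable it does not, which corresponds to the second case.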

Closing in favor of #68400 (and other patches that will follow).
