[mlir][vector] Add tests for scalable vectors #67806
Adds tests for scalable vectors in:
* vector-contract-to-outerproduct-transforms.mlir

Every existing test is duplicated, with fixed-width vectors replaced by scalable vectors. One test required a fix in:
* LowerVectorContract.cpp

This change is part of a larger effort to enable scalable vectorisation in Linalg. See this RFC for more context:
* https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Fixes llvm#67804
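The one-line fix in LowerVectorContract.cpp replaces `VectorType::get(vecType.getShape(), promotedType)` with `vecType.clone(promotedType)`. A plausible reading is that rebuilding a type from the shape alone loses the per-dimension scalable flags, while cloning keeps them. The sketch below illustrates that difference with a toy struct; `ToyVectorType` and its members are hypothetical names for illustration and are not the MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a vector type: a shape, one "scalable" flag per dimension,
// and an element type name. This is NOT the MLIR VectorType class.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  std::string elementType;

  // Models building a type from shape + element type only:
  // no flags are passed, so every dimension ends up fixed-width.
  static ToyVectorType get(std::vector<int64_t> shape, std::string elementType) {
    std::vector<bool> fixed(shape.size(), false);
    return ToyVectorType{std::move(shape), std::move(fixed), std::move(elementType)};
  }

  // Models cloning: shape and scalable flags are kept, only the element type changes.
  ToyVectorType clone(std::string newElementType) const {
    ToyVectorType copy = *this;
    copy.elementType = std::move(newElementType);
    return copy;
  }

  // Prints in the MLIR-like syntax used by the tests, e.g. vector<[2]xf32>.
  std::string str() const {
    assert(shape.size() == scalableDims.size());
    std::string out = "vector<";
    for (std::size_t i = 0; i < shape.size(); ++i) {
      std::string dim = std::to_string(shape[i]);
      out += scalableDims[i] ? "[" + dim + "]" : dim;
      out += "x";
    }
    return out + elementType + ">";
  }
};
```

In this model, promoting `vector<[2]xf16>` through `get` yields `vector<2xf32>`, whereas `clone` yields `vector<[2]xf32>`, which is the result type that the `arith.extf` check in `@matmul_0_mixed_scalable` expects.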
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Patch is 38.84 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67806.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..6e63d52d22a1f6b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
return v;
Type promotedType = dstElementType;
if (vecType)
- promotedType = VectorType::get(vecType.getShape(), promotedType);
+ promotedType = vecType.clone(promotedType);
if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..3746897bcd864f6 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -34,16 +34,16 @@
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3xf32>,
@@ -54,22 +54,46 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}
+
+// CHECK-LABEL: func.func @masked_extract_contract2_scalable(
+// CHECK-SAME: %{{.*}}: vector<[2]x[3]xf32>,
+// CHECK-SAME: %{{.*}}: vector<[3]xf32>,
+// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x[3]xi1>) -> vector<[2]xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x[3]xi1> to vector<[3]x[2]xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable(%arg0: vector<[2]x[3]xf32>,
+ %arg1: vector<[3]xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME: %[[M:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK: %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK: %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
%arg1: vector<5x7xf32>,
@@ -80,10 +104,36 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
return %0 : vector<3x7xf32>
}
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable(
+// CHECK-SAME: %{{.*}}: vector<[3]x[5]xf32>,
+// CHECK-SAME: %{{.*}}: vector<[5]x[7]xf32>,
+// CHECK-SAME: %{{.*}}: vector<[3]x[7]xf32>,
+// CHECK-SAME: %[[M:.*]]: vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+// CHECK: %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<[3]x[7]x[5]xi1> to vector<[5]x[3]x[7]xi1>
+// CHECK: %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK: %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK: %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK: %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK: %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+
+func.func @masked_extract_contract4_scalable(%arg0: vector<[3]x[5]xf32>,
+ %arg1: vector<[5]x[7]xf32>,
+ %arg2: vector<[3]x[7]xf32>,
+ %m : vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<[3]x[5]xf32>, vector<[5]x[7]xf32> into vector<[3]x[7]xf32> } : vector<[3]x[7]x[5]xi1> -> vector<[3]x[7]xf32>
+ return %0 : vector<[3]x[7]xf32>
+}
+
// CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<4x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
//
@@ -116,6 +166,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[4]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[4]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME: : vector<[2]x[4]xf32> to vector<[4]x[2]xf32>
+//
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[4]x[2]xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[4]x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<[2]xf32> from vector<[4]x[2]xf32>
+// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<[4]x[3]xf32>
+// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<[2]xf32> from vector<[4]x[2]xf32>
+// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<[4]x[3]xf32>
+// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<[2]xf32> from vector<[4]x[2]xf32>
+// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<[4]x[3]xf32>
+// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32>
+//
+// CHECK: return %[[c3]] : vector<[2]x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<[2]x[4]xf32>,
+ %arg1: vector<[4]x[3]xf32>,
+ %arg2: vector<[2]x[3]xf32>) -> vector<[2]x[3]xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<[2]x[4]xf32>, vector<[4]x[3]xf32> into vector<[2]x[3]xf32>
+ return %0 : vector<[2]x[3]xf32>
+}
+
// CHECK-LABEL: func @matmul_0
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -133,6 +219,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_0_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
// CHECK-LABEL: func @matmul_0_mixed
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -152,6 +255,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf16>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf16>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf16> from vector<[1]x[2]xf16>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<[1]x[3]xf16>
+// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<[2]xf16> to vector<[2]xf32>
+// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+// CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<[2]x[1]xf16>, %arg1: vector<[1]x[3]xf16>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<[2]x[1]xf16>, vector<[1]x[3]xf16> into vector<[2]x[3]xf32>
+ return %0 : vector<[2]x[3]xf32>
+}
+
#matmat_accesses_1 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -163,9 +285,9 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
}
// CHECK-LABEL: func @matmul_1
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
@@ -180,6 +302,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[3]x[1]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[3]x[1]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+ : vector<[2]x[1]xf32>, vector<[3]x[1]xf32> into vector<[2]x[3]xf32>
+ return %0 : vector<[2]x[3]xf32>
+}
+
#matmat_accesses_2 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -191,9 +331,9 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
}
// CHECK-LABEL: func @matmul_2
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
@@ -206,6 +346,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[1]x[2]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<[1]x[2]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+ : vector<[1]x[2]xf32>, vector<[1]x[3]xf32> into vector<[2]x[3]xf32>
+ return %0 : vector<[2]x[3]xf32>
+}
+
#matmat_accesses_3 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -217,9 +373,9 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
}
// CHECK-LABEL: func @matmul_3
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
@@ -233,6 +389,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-L...
[truncated]
@@ -34,16 +34,16 @@
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
unrelated change, please remove
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable(%arg0: vector<[2]x[3]xf32>,
this is unrolling scalable dimensions? I don't understand how you can unroll by a quantity unknown at compile-time and it be correct?
I think you're right here, this looks like it's only computing the fixed part of the output. You can sometimes unroll scalable things by extracting/inserting scalable subvectors (e.g. via vector.scalable.insert/extract), but I don't think that'd work for a 2D vector like this. Those vector.extracts should be in a vscale bounded loop, I think.
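The concern raised in these two comments can be illustrated with a small scalar model. The sketch below assumes, for simplicity, that only the reduction dimension is scalable, so a dimension written `[K]` holds `K * vscale` elements at runtime. All function names are hypothetical and nothing here is taken from the patch; it only contrasts unrolling by the compile-time base size with a loop bounded by `K * vscale`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Accumulates `steps` rank-1 updates: acc[m] += a[m][k] * x[k] for k in [0, steps).
// `a` is row-major with `kRuntime` columns; `acc` is taken by value and returned.
std::vector<float> contractSteps(const std::vector<float> &a,
                                 const std::vector<float> &x,
                                 std::vector<float> acc, std::size_t kRuntime,
                                 std::size_t steps) {
  assert(steps <= kRuntime && x.size() == kRuntime);
  assert(a.size() == acc.size() * kRuntime);
  for (std::size_t k = 0; k < steps; ++k)
    for (std::size_t m = 0; m < acc.size(); ++m)
      acc[m] += a[m * kRuntime + k] * x[k];
  return acc;
}

// Models unrolling by the compile-time base size: only `kBase` steps are emitted,
// even though the runtime extent is `kBase * vscale`.
std::vector<float> unrolledByBaseSize(const std::vector<float> &a,
                                      const std::vector<float> &x,
                                      const std::vector<float> &acc,
                                      std::size_t kBase, std::size_t vscale) {
  return contractSteps(a, x, acc, kBase * vscale, kBase);
}

// Models a vscale-bounded loop: the trip count is the runtime extent.
std::vector<float> vscaleBoundedLoop(const std::vector<float> &a,
                                     const std::vector<float> &x,
                                     const std::vector<float> &acc,
                                     std::size_t kBase, std::size_t vscale) {
  return contractSteps(a, x, acc, kBase * vscale, kBase * vscale);
}
```

With `vscale == 1` the two functions agree. For any larger `vscale`, the unrolled version skips the elements beyond the first `kBase`, which matches the description of "only computing the fixed part of the output".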
I think that we need to split this into a few separate changes:
Closing in favor of #68400 (and other patches that will follow).
Adds tests for scalable vectors in:
* vector-contract-to-outerproduct-transforms.mlir

Every existing test is duplicated, with fixed-width vectors replaced by scalable vectors. One test required a fix in:
* LowerVectorContract.cpp

Note that this patch intentionally refactors the whole test file so that:
* [...] (e.g. VAL_1) are replaced with something meaningful (e.g. MASK),
* [...] (%{{.*}}: instead of %[[VAL_1:.*]]),
* [...] (vector.outerproduct) are captured - that's the only way to verify fixed-width vs scalable lowering.

This change is a part of a larger effort to enable scalable vectorisation in Linalg. See this RFC for more context:
* https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Fixes #67804
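The point about capturing the full `vector.outerproduct` line can be shown with a rough model of pattern matching. The sketch below uses `std::regex` as a stand-in for a CHECK line; it is an approximation and not how FileCheck is implemented, and the helper name, patterns and sample lines are made up for illustration.

```cpp
#include <cassert>
#include <regex>
#include <string>

// Rough model of a CHECK line: true if `pattern` matches somewhere in `line`.
bool checkMatches(const std::string &pattern, const std::string &line) {
  assert(!pattern.empty());
  return std::regex_search(line, std::regex(pattern));
}

// Loose pattern: only the op name, as in the original masked tests.
const std::string kLoosePattern = R"(vector\.outerproduct)";

// Strict pattern: op name plus the scalable operand type.
const std::string kStrictPattern =
    R"(vector\.outerproduct .* : vector<\[2\]xf32>, f32)";

// Hypothetical lowered lines, one fixed-width and one scalable.
const std::string kFixedLine =
    "vector.outerproduct %a, %b, %c {kind = #vector.kind<add>} : vector<2xf32>, f32";
const std::string kScalableLine =
    "vector.outerproduct %a, %b, %c {kind = #vector.kind<add>} : vector<[2]xf32>, f32";
```

In this model the loose pattern accepts both lines, so it cannot tell a fixed-width lowering from a scalable one, while the strict pattern accepts only the scalable line.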