-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][SVE] Add more e2e test for vector.contract #70367
[mlir][SVE] Add more e2e test for vector.contract #70367
Conversation
@llvm/pr-subscribers-mlir-sve @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) ChangesAdds basic integration tests for Depends on #69845 Full diff: https://github.com/llvm/llvm-project/pull/70367.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8de132daf3c6a5d..d77476c10908395 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
return emitOpError("expected #1 operand dim to match result dim #1");
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
return emitOpError("expected #2 operand dim to match result dim #2");
- if (vRHS.isScalable() != vLHS.isScalable())
- return emitOpError("expected either all or none of vector operands #1 "
- "and #2 to be scalable");
+ if (vLHS.isScalable() && !vRHS.isScalable()) {
+ // This restriction reflects what's currently supported in terms of
+ // scalable vectors. However, we could relax this if there's a use case.
+ return emitOpError(
+ "expected either both or only #2 operand dim to be scalable");
+ }
} else {
// An AXPY operation.
if (vRES.getRank() != 1)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 3530393c013a1ae..44fb23088cea933 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
}
// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
%arg1: vector<5x7xf32>,
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
return %0 : vector<3x7xf32>
}
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+// Note that only the J dimension is scalable in this example. In theory, all
+// dimensions could be be scalable, but there is no target yet for which this
+// would make sense.
+func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x[7]xf32>,
+ %arg2: vector<3x[7]xf32>,
+ %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+ return %0 : vector<3x[7]xf32>
+}
+
// CHECK-LABEL: func @matmul
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
index 7d9923e036660c9..3b4e24da92aaacc 100644
--- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
%1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
- // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+ // expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
%op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+
+ return
}
+
// -----
func.func @invalid_outerproduct1(%src : memref<?xf32>) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
new file mode 100644
index 000000000000000..ef1b3d98412b60b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
@@ -0,0 +1,183 @@
+// DEFINE: %{compile} = mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
+// DEFINE: -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry} =
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
+
+// This check whether the files compiles and generates a temporary that will be executed further down.
+// RUN: %{compile}
+
+// REDEFINE: %{entry} = matmul_i32
+// RUN: %{run} | FileCheck %s --check-prefix=I32
+
+// REDEFINE: %{entry} = matmul_f32
+// RUN: %{run} | FileCheck %s --check-prefix=F32
+
+// REDEFINE: %{entry} = dot_product_i32
+// RUN: %{run} | FileCheck %s --check-prefix=DP
+
+// REDEFINE: %{entry} = matvec_i32
+// RUN: %{run} | FileCheck %s --check-prefix=MV
+
+// NOTE: These tests are meant to complement the integration tests from:
+// * ../test-contraction.mlir
+// (tests with fixed width vectors). Rather than duplicating those tests, this
+// file focuses on excercissing scalable vectors in a few most common cases.
+
+// TODO: Masks + matvec + dot product
+
+#dotp_accesses = [
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> ()>
+]
+#dotp_trait = {
+ indexing_maps = #dotp_accesses,
+ iterator_types = ["reduction"]
+}
+
+#matvec_accesses = [
+ affine_map<(i, j) -> (i, j)>,
+ affine_map<(i, j) -> (j)>,
+ affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+ indexing_maps = #matvec_accesses,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matmat_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Contraction: dot-product a x b.
+func.func @dot_product_i32() {
+ %acc = arith.constant 0: i32
+
+ %vector_a = arith.constant dense<123> : vector<[4]xi32>
+ %vector_b = arith.constant dense<314> : vector<[4]xi32>
+ %vector_c = arith.constant dense<0> : vector<[4]xi32>
+
+ // The result of this dot-product will depend
+ // on the vector length, so we are unable to verify it.
+ %dp1 = vector.contract #dotp_trait %vector_a, %vector_b, %acc
+ : vector<[4]xi32>, vector<[4]xi32> into i32
+ // DP: {{[0-9]*}}
+ vector.print %dp1 : i32
+
+ // The result of this dot-product should be 0.
+ %dp2 = vector.contract #dotp_trait %vector_a, %vector_c, %acc
+ : vector<[4]xi32>, vector<[4]xi32> into i32
+ // DP: 0
+ vector.print %dp2 : i32
+
+ // DP: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+// Contraction: matrix-vector A x c
+func.func @matvec_i32() {
+ %acc = arith.constant dense<0>: vector<3xi32>
+
+ %vector_a = arith.constant dense<123> : vector<3x[4]xi32>
+ %vector_b = arith.constant dense<314> : vector<[4]xi32>
+ %vector_c = arith.constant dense<0> : vector<[4]xi32>
+
+ // The result of this matvec will depend on the vector length, so we are
+ // unable to verify it.
+ %dp1 = vector.contract #matvec_trait %vector_a, %vector_b, %acc
+ : vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
+ // MV: {{[0-9]*}}, {{[0-9]*}}, {{[0-9]*}}
+ vector.print %dp1 : vector<3xi32>
+
+ // The result of this matvc should be a vector of 0s.
+ %dp2 = vector.contract #matvec_trait %vector_a, %vector_c, %acc
+ : vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
+ // MV: 0, 0, 0
+ vector.print %dp2 : vector<3xi32>
+
+ // MV: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+func.func @matmul_i32() {
+ // Setup vector A:
+ %vector_a = arith.constant dense<123> : vector<3x5xi32>
+
+ // Setup vector B:
+ %vector_b = arith.constant dense<123> : vector<5x[2]xi32>
+
+ // Setup vector C:
+ %vector_c = arith.constant dense<314> : vector<3x[2]xi32>
+
+ // Matmul
+ %0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+ : vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32>
+
+ // Print the output
+ %slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32: ( 75959, 75959
+ vector.print %slice1 : vector<[2]xi32>
+ %slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32-NEXT: ( 75959, 75959
+ vector.print %slice2 : vector<[2]xi32>
+ %slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32-NEXT: ( 75959, 75959
+ vector.print %slice3 : vector<[2]xi32>
+
+ // CHECK: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+func.func @matmul_f32() {
+ // Setup vector A:
+ %vector_a = arith.constant dense<1.23> : vector<3x5xf32>
+
+ // Setup vector B:
+ %vector_b = arith.constant dense<1.23> : vector<5x[2]xf32>
+
+ // Setup vector C:
+ %vector_c = arith.constant dense<3.14> : vector<3x[2]xf32>
+
+ // Matmul
+ %0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+ : vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32>
+
+ // Print the output
+ %slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32: ( 10.7045, 10.7045
+ vector.print %slice1 : vector<[2]xf32>
+ %slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32-NEXT: ( 10.7045, 10.7045
+ vector.print %slice2 : vector<[2]xf32>
+ %slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32-NEXT: ( 10.7045, 10.7045
+ vector.print %slice3 : vector<[2]xf32>
+
+ // CHECK: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ } : !transform.any_op
+}
|
e8fa6fc
to
1db7c9f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG! :)
Typo in PR description:
s/matviec/matvec/g
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
Adds basic integration tests for `vector.contract` for the dot product and matvec operations. These tests excercise scalable vectors.
1db7c9f
to
7a0d0c6
Compare
Update the dot product teset to check the actual result (thanks Ben)
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
Outdated
Show resolved
Hide resolved
0ac6707
to
028a5bb
Compare
Further generalization
028a5bb
to
e80df20
Compare
Adds basic integration tests for
vector.contract
for the dot productand matvec operations. These tests exercise scalable vectors.
Depends on #69845