Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][SVE] Add more e2e test for vector.contract #70367

Merged
merged 3 commits into from
Oct 27, 2023

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Oct 26, 2023

Adds basic integration tests for vector.contract for the dot product
and matvec operations. These tests exercise scalable vectors.

Depends on #69845

@banach-space banach-space changed the title andrzej/add e2e dot product [mlir][SVE] Add integration test for vector.contract Oct 26, 2023
@banach-space banach-space requested a review from MacDue October 26, 2023 19:09
@llvmbot
Copy link
Member

llvmbot commented Oct 26, 2023

@llvm/pr-subscribers-mlir-sve
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Adds basic integration tests for vector.contract for the dot product
and matviec operations. These tests excercise scalable vectors.

Depends on #69845


Full diff: https://github.com/llvm/llvm-project/pull/70367.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6-3)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+44-15)
  • (modified) mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir (+4-1)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir (+183)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8de132daf3c6a5d..d77476c10908395 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
       return emitOpError("expected #1 operand dim to match result dim #1");
     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
       return emitOpError("expected #2 operand dim to match result dim #2");
-    if (vRHS.isScalable() != vLHS.isScalable())
-      return emitOpError("expected either all or none of vector operands #1 "
-                         "and #2 to be scalable");
+    if (vLHS.isScalable() && !vRHS.isScalable()) {
+      // This restriction reflects what's currently supported in terms of
+      // scalable vectors. However, we could relax this if there's a use case.
+      return emitOpError(
+          "expected either both or only #2 operand dim to be scalable");
+    }
   } else {
     // An AXPY operation.
     if (vRES.getRank() != 1)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 3530393c013a1ae..44fb23088cea933 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
 }
 
 // CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
 func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
                                     %arg1: vector<5x7xf32>,
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
   return %0 : vector<3x7xf32>
 }
 
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+// Note that only the J dimension is scalable in this example. In theory, all
+// dimensions could be be scalable, but there is no target yet for which this
+// would make sense.
+func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
+                                    %arg1: vector<5x[7]xf32>,
+                                    %arg2: vector<3x[7]xf32>,
+                                    %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+  return %0 : vector<3x[7]xf32>
+}
+
 // CHECK-LABEL: func @matmul
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
index 7d9923e036660c9..3b4e24da92aaacc 100644
--- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
   %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
   %1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
 
-  // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+  // expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
   %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+
+  return
 }
+
 // -----
 
 func.func @invalid_outerproduct1(%src : memref<?xf32>) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
new file mode 100644
index 000000000000000..ef1b3d98412b60b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
@@ -0,0 +1,183 @@
+// DEFINE: %{compile} = mlir-opt %s  -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
+// DEFINE:    -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry} =
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
+
+// This check whether the files compiles and generates a temporary that will be executed further down.
+// RUN: %{compile}
+
+// REDEFINE: %{entry} = matmul_i32
+// RUN: %{run} | FileCheck %s --check-prefix=I32
+
+// REDEFINE: %{entry} = matmul_f32
+// RUN: %{run} | FileCheck %s --check-prefix=F32
+
+// REDEFINE: %{entry} = dot_product_i32
+// RUN: %{run} | FileCheck %s --check-prefix=DP
+
+// REDEFINE: %{entry} = matvec_i32
+// RUN: %{run} | FileCheck %s --check-prefix=MV
+
+// NOTE: These tests are meant to complement the integration tests from:
+//    * ../test-contraction.mlir
+// (tests with fixed width vectors). Rather than duplicating those tests, this
+// file focuses on excercissing scalable vectors in a few most common cases.
+
+// TODO: Masks + matvec + dot product
+
+#dotp_accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> ()>
+]
+#dotp_trait = {
+  indexing_maps = #dotp_accesses,
+  iterator_types = ["reduction"]
+}
+
+#matvec_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Contraction: dot-product a x b.
+func.func @dot_product_i32() {
+  %acc = arith.constant 0: i32
+
+  %vector_a = arith.constant dense<123> : vector<[4]xi32>
+  %vector_b = arith.constant dense<314> : vector<[4]xi32>
+  %vector_c = arith.constant dense<0> : vector<[4]xi32>
+
+  // The result of this dot-product will depend
+  // on the vector length, so we are unable to verify it.
+  %dp1 = vector.contract #dotp_trait %vector_a, %vector_b, %acc
+    : vector<[4]xi32>, vector<[4]xi32> into i32
+  // DP: {{[0-9]*}}
+  vector.print %dp1 : i32
+
+  // The result of this dot-product should be 0.
+  %dp2 = vector.contract #dotp_trait %vector_a, %vector_c, %acc
+    : vector<[4]xi32>, vector<[4]xi32> into i32
+  // DP: 0
+  vector.print %dp2 : i32
+
+  // DP: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+// Contraction: matrix-vector A x c
+func.func @matvec_i32() {
+  %acc = arith.constant dense<0>: vector<3xi32>
+
+  %vector_a = arith.constant dense<123> : vector<3x[4]xi32>
+  %vector_b = arith.constant dense<314> : vector<[4]xi32>
+  %vector_c = arith.constant dense<0> : vector<[4]xi32>
+
+  // The result of this matvec will depend on the vector length, so we are
+  // unable to verify it.
+  %dp1 = vector.contract #matvec_trait %vector_a, %vector_b, %acc
+    : vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
+  // MV: {{[0-9]*}}, {{[0-9]*}}, {{[0-9]*}}
+  vector.print %dp1 : vector<3xi32>
+
+  // The result of this matvc should be a vector of 0s.
+  %dp2 = vector.contract #matvec_trait %vector_a, %vector_c, %acc
+    : vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
+  // MV: 0, 0, 0
+  vector.print %dp2 : vector<3xi32>
+
+  // MV: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+func.func @matmul_i32() {
+  // Setup vector A:
+  %vector_a = arith.constant dense<123> : vector<3x5xi32>
+
+  // Setup vector B:
+  %vector_b = arith.constant dense<123> : vector<5x[2]xi32>
+
+  // Setup vector C:
+  %vector_c = arith.constant dense<314> : vector<3x[2]xi32>
+
+  // Matmul
+  %0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+    : vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32>
+
+  // Print the output
+  %slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
+  // I32: ( 75959, 75959
+  vector.print %slice1 : vector<[2]xi32>
+  %slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
+  // I32-NEXT: ( 75959, 75959
+  vector.print %slice2 : vector<[2]xi32>
+  %slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
+  // I32-NEXT: ( 75959, 75959
+  vector.print %slice3 : vector<[2]xi32>
+
+  // CHECK: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+func.func @matmul_f32() {
+  // Setup vector A:
+  %vector_a = arith.constant dense<1.23> : vector<3x5xf32>
+
+  // Setup vector B:
+  %vector_b = arith.constant dense<1.23> : vector<5x[2]xf32>
+
+  // Setup vector C:
+  %vector_c = arith.constant dense<3.14> : vector<3x[2]xf32>
+
+  // Matmul
+  %0 = vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+    : vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32>
+
+  // Print the output
+  %slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
+  // F32: ( 10.7045, 10.7045
+  vector.print %slice1 : vector<[2]xf32>
+  %slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
+  // F32-NEXT: ( 10.7045, 10.7045
+  vector.print %slice2 : vector<[2]xf32>
+  %slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
+  // F32-NEXT: ( 10.7045, 10.7045
+  vector.print %slice3 : vector<[2]xf32>
+
+  // CHECK: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+  } : !transform.any_op
+}

@banach-space banach-space force-pushed the andrzej/add_e2e_dot_product branch from e8fa6fc to 1db7c9f Compare October 26, 2023 20:00
@banach-space banach-space changed the title [mlir][SVE] Add integration test for vector.contract [mlir][SVE] Add more e2e test for vector.contract Oct 26, 2023
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG! :)
Typo in PR description:
s/matviec/matvec/g

Adds basic integration tests for `vector.contract` for the dot product
and matvec operations. These tests excercise scalable vectors.
@banach-space banach-space force-pushed the andrzej/add_e2e_dot_product branch from 1db7c9f to 7a0d0c6 Compare October 27, 2023 09:56
Update the dot product teset to check the actual result (thanks Ben)
@banach-space banach-space force-pushed the andrzej/add_e2e_dot_product branch from 028a5bb to e80df20 Compare October 27, 2023 13:59
@banach-space banach-space merged commit e9478b1 into llvm:main Oct 27, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants