From 5a9af39aab40bba52d4e46cabf4b1ab47f614fa2 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:12:47 -0800 Subject: [PATCH] [mlir][sparse] made sparse vectorizer more robust on position of invariants (#80766) Because the sparse vectorizer relies on the code coming out of the sparsifier, the "patterns" are not always made very general. However, a recent change in the generated code revealed an obvious situation where the subscript analysis could be made a bit more robust. Fixes: https://github.com/llvm/llvm-project/issues/79897 --- .../Dialect/SparseTensor/Transforms/SparseVectorization.cpp | 6 ++++++ mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir | 3 +-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 3a487a3bd6a0692..2b81d6cdc1eabe6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -316,6 +316,12 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp, if (auto load = cast.getDefiningOp()) { Value inv = load.getOperand(0); Value idx = load.getOperand(1); + // Swap non-invariant. + if (!isInvariantValue(inv, block)) { + inv = idx; + idx = load.getOperand(0); + } + // Inspect. if (isInvariantValue(inv, block)) { if (auto arg = llvm::dyn_cast(idx)) { if (isInvariantArg(arg, block) || !innermost) diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir index dfee2b1261b6cc2..e25c3a02f91271c 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir @@ -1,4 +1,3 @@ -// FIXME: re-enable. // RUN: mlir-opt %s -sparsifier="vl=8" | FileCheck %s #Dense = #sparse_tensor.encoding<{ @@ -16,7 +15,7 @@ } // CHECK-LABEL: llvm.func @kernel_matvec -// C_HECK: llvm.intr.vector.reduce.fadd +// CHECK: llvm.intr.vector.reduce.fadd func.func @kernel_matvec(%arga: tensor, %argb: tensor, %argx: tensor) -> tensor {