diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index e93e2aefb344fd..252908b026968a 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -50,6 +50,14 @@ enum class ReinterpretMapScope { #define GEN_PASS_DECL #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" +//===----------------------------------------------------------------------===// +// The SparseAssembler pass. +//===----------------------------------------------------------------------===// + +void populateSparseAssembler(RewritePatternSet &patterns); + +std::unique_ptr createSparseAssembler(); + //===----------------------------------------------------------------------===// // The SparseReinterpretMap pass. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index f38779ed9ed2b8..f0e5e8286c49fb 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -11,6 +11,26 @@ include "mlir/Pass/PassBase.td" +def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> { + let summary = "Add [dis]assemble operations on external sparse tensors"; + let description = [{ + A pass that converts public entry methods that use sparse tensors as + input parameters and/or output return values into wrapper functions + that [dis]assemble the individual tensors that constitute the actual + storage used externally into MLIR sparse tensors. This pass can be used + to prepare the public entry methods of a program that is compiled by the + MLIR sparsifier to interface with an external runtime, e.g., when passing + sparse tensors as numpy arrays from and to Python. Note that eventual + bufferization decisions (e.g. who [de]allocates the underlying memory) + should be resolved in agreement with the external runtime. + }]; + let constructor = "mlir::createSparseAssembler()"; + let dependentDialects = [ + "sparse_tensor::SparseTensorDialect", + "tensor::TensorDialect", + ]; +} + def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> { let summary = "Reinterprets sparse tensor type mappings"; let description = [{ @@ -183,7 +203,6 @@ def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> { ]; } - def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> { let summary = "Convert sparse tensors and primitives to library calls"; let description = [{ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt index 456e45a040193e..3c0f82fc00bb9d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms # Rewriting. BufferizableOpInterfaceImpl.cpp + SparseAssembler.cpp SparseBufferRewriting.cpp SparseGPUCodegen.cpp SparseReinterpretMap.cpp diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp new file mode 100644 index 00000000000000..f9b6397e0f086f --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp @@ -0,0 +1,239 @@ +//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Utils/CodegenUtils.h" + +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; +using namespace sparse_tensor; + +//===----------------------------------------------------------------------===// +// Helper methods. +//===----------------------------------------------------------------------===// + +// TODO: reuse StorageLayout::foreachField? + +// TODO: we need COO AoS and SoA + +// Convert type range to new types range, with sparse tensors externalized. +void convTypes(TypeRange types, SmallVectorImpl &convTypes, + SmallVectorImpl *extraTypes = nullptr) { + for (auto type : types) { + // All "dense" data passes through unmodified. + if (!getSparseTensorEncoding(type)) { + convTypes.push_back(type); + continue; + } + // Convert the external representation of the values array. + const SparseTensorType stt(cast(type)); + auto shape = {ShapedType::kDynamic}; + auto vtp = RankedTensorType::get(shape, stt.getElementType()); + convTypes.push_back(vtp); + if (extraTypes) + extraTypes->push_back(vtp); + // Convert the external representations of the pos/crd arrays. + for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { + const auto lt = stt.getLvlType(lvl); + if (isCompressedLT(lt) || isLooseCompressedLT(lt)) { + auto ptp = RankedTensorType::get(shape, stt.getPosType()); + auto ctp = RankedTensorType::get(shape, stt.getCrdType()); + convTypes.push_back(ptp); + convTypes.push_back(ctp); + if (extraTypes) { + extraTypes->push_back(ptp); + extraTypes->push_back(ctp); + } + } else { + assert(isDenseLT(lt)); // TODO: handle other cases + } + } + } +} + +// Convert input and output values to [dis[assemble ops for sparse tensors. +void convVals(OpBuilder &builder, Location loc, TypeRange types, + ValueRange fromVals, ValueRange extraVals, + SmallVectorImpl &toVals, unsigned extra, bool isIn) { + unsigned idx = 0; + for (auto type : types) { + // All "dense" data passes through unmodified. + if (!getSparseTensorEncoding(type)) { + toVals.push_back(fromVals[idx++]); + continue; + } + // Convert the external representation of the values array. + auto rtp = cast(type); + const SparseTensorType stt(rtp); + auto shape = {ShapedType::kDynamic}; + SmallVector inputs; + SmallVector retTypes; + SmallVector cntTypes; + // Collect the external representation of the values array for + // input or the outgoing sparse tensor for output. + inputs.push_back(fromVals[idx++]); + if (!isIn) { + inputs.push_back(extraVals[extra++]); + retTypes.push_back(RankedTensorType::get(shape, stt.getElementType())); + cntTypes.push_back(builder.getIndexType()); + } + // Collect the external representations of the pos/crd arrays. + for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { + const auto lt = stt.getLvlType(lvl); + if (isCompressedLT(lt) || isLooseCompressedLT(lt)) { + if (isIn) { + inputs.push_back(fromVals[idx++]); + inputs.push_back(fromVals[idx++]); + } else { + Type pTp = stt.getPosType(); + Type cTp = stt.getCrdType(); + inputs.push_back(extraVals[extra++]); + inputs.push_back(extraVals[extra++]); + retTypes.push_back(RankedTensorType::get(shape, pTp)); + retTypes.push_back(RankedTensorType::get(shape, cTp)); + cntTypes.push_back(pTp); + cntTypes.push_back(cTp); + } + } else { + assert(isDenseLT(lt)); // TODO: handle other cases + } + } + if (isIn) { + // Assemble multiple inputs into a single sparse tensor. + auto a = builder.create(loc, rtp, inputs); + toVals.push_back(a.getResult()); + } else { + // Disassemble a single sparse input into multiple outputs. + // Note that this includes the counters, which are dropped. + unsigned len = retTypes.size(); + retTypes.append(cntTypes); + auto d = + builder.create(loc, retTypes, inputs); + for (unsigned i = 0; i < len; i++) + toVals.push_back(d.getResult(i)); + } + } +} + +//===----------------------------------------------------------------------===// +// Rewriting rules. +//===----------------------------------------------------------------------===// + +namespace { + +// A rewriting rules that converts public entry methods that use sparse tensors +// as input parameters and/or output return values into wrapper functions +// that [dis]assemble the individual tensors that constitute the actual +// storage used externally into MLIR sparse tensors. +// +// In particular, each sparse tensor input +// +// void foo(..., t, ...) { } +// +// adds the following strucuture in a wrapper +// +// void spiface_foo(..., t1..tn, ...) { +// t = assemble t1..tn +// foo(..., t, ...) +// } +// +// and likewise, each output tensor +// +// ... T ... bar(...) { return ..., t, ...; } +// +// adds the following structure in a wrapper +// +// ... T1..TN ... spiface_bar(..., t1'..tn') { +// ..., t, ... = bar(...) +// t1..tn = disassemble t, t1'..tn' +// return ..., t1..tn, ... +// } +// +// TODO: refine output sparse tensors to work well with external framework +// +// TODO: use "inlining" instead of a wrapper? +// +struct SparseFuncAssembler : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::FuncOp funcOp, + PatternRewriter &rewriter) const override { + // Only a rewrite an entry with the c-interface requested. + if (!funcOp->getAttrOfType( + LLVM::LLVMDialect::getEmitCWrapperAttrName())) + return failure(); + + // Translate sparse tensor types to external types. + SmallVector inputTypes; + SmallVector outputTypes; + SmallVector extraTypes; + convTypes(funcOp.getArgumentTypes(), inputTypes); + convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes); + + // Only sparse inputs or outputs need a wrapper function. + if (inputTypes.size() == funcOp.getArgumentTypes().size() && + outputTypes.size() == funcOp.getResultTypes().size()) + return failure(); + + // Start the new wrapper function. Together with the c-interface mangling, + // a sparse external entry point eventually will have a name like: + // _mlir_ciface_spiface_XXX(...) + Location loc = funcOp.getLoc(); + ModuleOp modOp = funcOp->getParentOfType(); + MLIRContext *context = modOp.getContext(); + OpBuilder moduleBuilder(modOp.getBodyRegion()); + std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str(); + unsigned extra = inputTypes.size(); + inputTypes.append(extraTypes); + auto func = moduleBuilder.create( + loc, wrapper, FunctionType::get(context, inputTypes, outputTypes)); + func.setPublic(); + func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), + UnitAttr::get(context)); + + // Construct new wrapper function body. + auto org = SymbolRefAttr::get(context, funcOp.getName()); + OpBuilder::InsertionGuard insertionGuard(rewriter); + Block *body = func.addEntryBlock(); + rewriter.setInsertionPointToStart(body); + + // Convert inputs. + SmallVector inputs; + convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(), + ValueRange(), inputs, 0, /*isIn=*/true); + + // Call original function. + auto call = rewriter.create(loc, funcOp.getResultTypes(), org, + inputs); + + // Convert outputs and return. + SmallVector outputs; + convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(), + body->getArguments(), outputs, extra, /*isIn=*/false); + rewriter.create(loc, outputs); + + // Strip the c-interface attribute from the original function. + funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName()); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Public method for populating conversion rules. +//===----------------------------------------------------------------------===// + +void mlir::populateSparseAssembler(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 375e10f9068e43..40e98604848cd0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { +#define GEN_PASS_DEF_SPARSEASSEMBLER #define GEN_PASS_DEF_SPARSEREINTERPRETMAP #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE #define GEN_PASS_DEF_SPARSIFICATIONPASS @@ -46,6 +47,18 @@ namespace { // Passes implementation. //===----------------------------------------------------------------------===// +struct SparseAssembler : public impl::SparseAssemblerBase { + SparseAssembler() = default; + SparseAssembler(const SparseAssembler &pass) = default; + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateSparseAssembler(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct SparseReinterpretMap : public impl::SparseReinterpretMapBase { SparseReinterpretMap() = default; @@ -378,6 +391,10 @@ struct StorageSpecifierToLLVMPass // Pass creation methods. //===----------------------------------------------------------------------===// +std::unique_ptr mlir::createSparseAssembler() { + return std::make_unique(); +} + std::unique_ptr mlir::createSparseReinterpretMapPass() { return std::make_unique(); } diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir new file mode 100644 index 00000000000000..57df8aca3a6a5b --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/external.mlir @@ -0,0 +1,97 @@ +// RUN: mlir-opt %s --sparse-assembler -split-input-file | FileCheck %s + +// ----- + +// CHECK-LABEL: func.func @nop( +// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> attributes {llvm.emit_c_interface} { +// CHECK: return %[[A]] : tensor<100xf32> +// CHECK: } +func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes { llvm.emit_c_interface } { + return %arg0 : tensor<100xf32> +} + +// ----- + +// CHECK-LABEL: func.func @spiface_sparse_in( +// CHECK-SAME: %[[A:.*]]: tensor, +// CHECK-SAME: %[[B:.*]]: tensor, +// CHECK-SAME: %[[C:.*]]: tensor) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} { +// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]] +// CHECK: %[[F:.*]] = call @sparse_in(%[[I]]) +// CHECK: return %[[F]] : tensor<64x64xf32> +// CHECK: } +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } { + %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32> + return %0 : tensor<64x64xf32> +} + +// ----- + +// CHECK-LABEL: func.func @spiface_sparse_in2( +// CHECK-SAME: %[[X:.*]]: tensor<100xf32>, +// CHECK-SAME: %[[A:.*]]: tensor, +// CHECK-SAME: %[[B:.*]]: tensor, +// CHECK-SAME: %[[C:.*]]: tensor) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} { +// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]] +// CHECK: %[[F:.*]] = call @sparse_in2(%[[X]], %[[I]]) +// CHECK: return %[[F]] : tensor<64x64xf32> +// CHECK: } +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } { + %0 = sparse_tensor.convert %arg1 : tensor<64x64xf32, #sparse> to tensor<64x64xf32> + return %0 : tensor<64x64xf32> +} + +// ----- + +// CHECK-LABEL: func.func @spiface_sparse_out( +// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>, +// CHECK-SAME: %[[A:.*]]: tensor, +// CHECK-SAME: %[[B:.*]]: tensor, +// CHECK-SAME: %[[C:.*]]: tensor) -> (tensor, tensor, tensor) attributes {llvm.emit_c_interface} { +// CHECK: %[[F:.*]] = call @sparse_out(%[[X]]) +// CHECK: sparse_tensor.disassemble %[[F]] +// CHECK: return +// CHECK: } +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } { + %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse> + return %0 : tensor<64x64xf32, #sparse> +} + +// ----- + +// CHECK-LABEL: func.func @spiface_sparse_out2( +// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>, +// CHECK-SAME: %[[A:.*]]: tensor, +// CHECK-SAME: %[[B:.*]]: tensor, +// CHECK-SAME: %[[C:.*]]: tensor) -> (tensor<64x64xf32>, tensor, tensor, tensor) attributes {llvm.emit_c_interface} { +// CHECK: %[[F:.*]]:2 = call @sparse_out2(%[[X]]) +// CHECK: sparse_tensor.disassemble %[[F]]#1 +// CHECK: return %[[F]]#0 +// CHECK: } +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) attributes { llvm.emit_c_interface } { + %0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse> + return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse> +} + +// ----- + +// CHECK-LABEL: func.func @spiface_sparse_inout( +// CHECK-SAME: %[[A:.*0]]: tensor, +// CHECK-SAME: %[[B:.*1]]: tensor, +// CHECK-SAME: %[[C:.*2]]: tensor, +// CHECK-SAME: %[[D:.*3]]: tensor, +// CHECK-SAME: %[[E:.*4]]: tensor, +// CHECK-SAME: %[[F:.*5]]: tensor) -> (tensor, tensor, tensor) attributes {llvm.emit_c_interface} { +// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]] +// CHECK: %[[F:.*]] = call @sparse_inout(%[[I]]) +// CHECK: sparse_tensor.disassemble %[[F]] +// CHECK: return +// CHECK: } +#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } { + return %arg0 : tensor<64x64xf32, #sparse> +}