From bf7a0bc3495bd69e4c75dcd1e6bbb320c61c1c25 Mon Sep 17 00:00:00 2001 From: Daniel Kales <11509575+dkales@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:37:03 +0100 Subject: [PATCH] chore: also format other code --- .../src/Passes/mlir/Analysis/CountPass.cpp | 219 ++++------ .../src/Passes/mlir/Analysis/PrintPass.cpp | 205 +++++----- .../mlir/Conversion/AffineFullUnrollPass.cpp | 14 +- .../Conversion/AffineFullUnrollPattern.cpp | 32 +- .../mlir/Transform/ElimCopySignPass.cpp | 13 +- zkml-onnx-compiler/src/zkml-onnx-compiler.cpp | 383 +++++++++--------- 6 files changed, 384 insertions(+), 482 deletions(-) diff --git a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.cpp index ae07eaf..b31b9fa 100644 --- a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.cpp +++ b/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.cpp @@ -1,39 +1,34 @@ #include "CountPass.h" -int64_t zk_ml::evalAffineExpr(AffineExpr expr, ArrayRef dims, - ArrayRef symbols) -{ +int64_t zk_ml::evalAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols) { int64_t lhs = 0, rhs = 0; - if (auto bin = expr.dyn_cast()) - { + if (auto bin = expr.dyn_cast()) { lhs = evalAffineExpr(bin.getLHS(), dims, symbols); rhs = evalAffineExpr(bin.getRHS(), dims, symbols); } - switch (expr.getKind()) - { - case AffineExprKind::Add: - return lhs + rhs; - case AffineExprKind::Mul: - return lhs * rhs; - case AffineExprKind::Mod: - return mod(lhs, rhs); - case AffineExprKind::FloorDiv: - return floorDiv(lhs, rhs); - case AffineExprKind::CeilDiv: - return ceilDiv(lhs, rhs); - case AffineExprKind::Constant: - return expr.cast().getValue(); - case AffineExprKind::DimId: - return dims[expr.cast().getPosition()]; - case AffineExprKind::SymbolId: - return symbols[expr.cast().getPosition()]; - default: - llvm_unreachable("must be one of AffineExprKind"); + switch (expr.getKind()) { + case AffineExprKind::Add: + return lhs + rhs; + case AffineExprKind::Mul: + return lhs * rhs; + case AffineExprKind::Mod: + return mod(lhs, rhs); + case AffineExprKind::FloorDiv: + return floorDiv(lhs, rhs); + case AffineExprKind::CeilDiv: + return ceilDiv(lhs, rhs); + case AffineExprKind::Constant: + return expr.cast().getValue(); + case AffineExprKind::DimId: + return dims[expr.cast().getPosition()]; + case AffineExprKind::SymbolId: + return symbols[expr.cast().getPosition()]; + default: + llvm_unreachable("must be one of AffineExprKind"); } } -bool zk_ml::evalIntegerSet(IntegerSet set, ArrayRef dims, ArrayRef symbols) -{ +bool zk_ml::evalIntegerSet(IntegerSet set, ArrayRef dims, ArrayRef symbols) { // according to mlir/lib/IR/IntegerSetDetail.h constraints are either // an equality (affine_expr == 0) or an inequality (affine_expr >= 0). // Nevertheless, according to https://mlir.llvm.org/docs/Dialects/Affine/ @@ -43,105 +38,85 @@ bool zk_ml::evalIntegerSet(IntegerSet set, ArrayRef dims, ArrayRef= affine_expr // we have to stick to code anyway but somehow strange ArrayRef constraints = set.getConstraints(); - for (unsigned i = 0; i < constraints.size(); ++i) - { + for (unsigned i = 0; i < constraints.size(); ++i) { int64_t constraint = evalAffineExpr(constraints[i], dims, symbols); - if (set.isEq(i)) - { + if (set.isEq(i)) { llvm::outs() << "we have a equality????\n"; exit(-1); - } - else - { - if (constraint < 0) - { + } else { + if (constraint < 0) { return false; } } } return true; } -bool zk_ml::evalIntegerSet(IntegerSet set, - ArrayRef operands) -{ - return evalIntegerSet(set, operands.take_front(set.getNumDims()), - operands.drop_front(set.getNumDims())); +bool zk_ml::evalIntegerSet(IntegerSet set, ArrayRef operands) { + return evalIntegerSet(set, operands.take_front(set.getNumDims()), operands.drop_front(set.getNumDims())); } -SmallVector zk_ml::evalAffineMap(AffineMap map, ArrayRef dims, - ArrayRef symbols) -{ +SmallVector zk_ml::evalAffineMap(AffineMap map, ArrayRef dims, ArrayRef symbols) { SmallVector result; - for (auto expr : map.getResults()) - { + for (auto expr : map.getResults()) { result.push_back(evalAffineExpr(expr, dims, symbols)); } return result; } -llvm::SmallVector zk_ml::evalAffineMap(AffineMap map, - ArrayRef operands) -{ - return evalAffineMap(map, operands.take_front(map.getNumDims()), - operands.drop_front(map.getNumDims())); +llvm::SmallVector zk_ml::evalAffineMap(AffineMap map, ArrayRef operands) { + return evalAffineMap(map, operands.take_front(map.getNumDims()), operands.drop_front(map.getNumDims())); } // END COPY -StringRef zk_ml::CountPass::getArgument() const { return "count-pass"; } -StringRef zk_ml::CountPass::getDescription() const { return "Does some counting - lets see what"; } -void zk_ml::CountPass::runOnOperation() -{ +StringRef zk_ml::CountPass::getArgument() const { + return "count-pass"; +} +StringRef zk_ml::CountPass::getDescription() const { + return "Does some counting - lets see what"; +} +void zk_ml::CountPass::runOnOperation() { Operation *op = getOperation(); countDepth(op); - for (auto elem : this->counter) - { + for (auto elem : this->counter) { llvm::outs() << elem.first << ": " << elem.second << "\n"; } } -template -T zk_ml::CountPass::castFromAttr(Attribute attr) -{ +template +T zk_ml::CountPass::castFromAttr(Attribute attr) { T result = llvm::dyn_cast(attr); assert(result); return result; } -int64_t zk_ml::CountPass::getMaxFromVector(llvm::SmallVector v) -{ +int64_t zk_ml::CountPass::getMaxFromVector(llvm::SmallVector v) { assert(!v.empty()); int64_t currentMax = v[0]; - for (unsigned i = 1; i < v.size(); ++i) - { + for (unsigned i = 1; i < v.size(); ++i) { if (currentMax < v[i]) currentMax = v[i]; } return currentMax; } -int64_t zk_ml::CountPass::getMinFromVector(llvm::SmallVector v) -{ +int64_t zk_ml::CountPass::getMinFromVector(llvm::SmallVector v) { assert(!v.empty()); int64_t currentMin = v[0]; - for (unsigned i = 1; i < v.size(); ++i) - { + for (unsigned i = 1; i < v.size(); ++i) { if (currentMin > v[i]) currentMin = v[i]; } return currentMin; } -void zk_ml::CountPass::printIndent(unsigned offset) -{ - if (DEBUG_FLAG) - { +void zk_ml::CountPass::printIndent(unsigned offset) { + if (DEBUG_FLAG) { assert(indent >= offset); for (unsigned i = 0; i < indent - offset; ++i) llvm::outs() << " "; } } -void zk_ml::CountPass::doAffineFor(Operation *op, int64_t from, int64_t to, int64_t step) -{ +void zk_ml::CountPass::doAffineFor(Operation *op, int64_t from, int64_t to, int64_t step) { assert(from < to); assert(step); assert(op->getRegions().size() == 1); @@ -153,8 +128,7 @@ void zk_ml::CountPass::doAffineFor(Operation *op, int64_t from, int64_t to, int6 llvm::hash_code counterHash = hash_value(op->getRegions()[0].getArguments()[0]); DEBUG("inserting hash: " << counterHash << ":" << from); this->values.insert(std::make_pair(counterHash, from)); - while (from < to) - { + while (from < to) { for (Region ®ion : op->getRegions()) countRegion(region); from += step; @@ -169,11 +143,9 @@ void zk_ml::CountPass::doAffineFor(Operation *op, int64_t from, int64_t to, int6 indent--; } -template -void zk_ml::CountPass::printSmallvector(llvm::SmallVector &v) -{ - if (DEBUG_FLAG) - { +template +void zk_ml::CountPass::printSmallvector(llvm::SmallVector &v) { + if (DEBUG_FLAG) { llvm::outs() << "v["; for (auto c : v) llvm::outs() << c << ","; @@ -181,29 +153,21 @@ void zk_ml::CountPass::printSmallvector(llvm::SmallVector &v) } } -int64_t zk_ml::CountPass::evaluateForParameter(AffineMap &affineMap, llvm::SmallVector &operands, bool from) -{ - if (affineMap.isConstant()) - { +int64_t zk_ml::CountPass::evaluateForParameter(AffineMap &affineMap, llvm::SmallVector &operands, bool from) { + if (affineMap.isConstant()) { return affineMap.getResult(0).cast().getValue(); - } - else - { + } else { assert(affineMap.getNumInputs() == operands.size()); llvm::SmallVector inVector(affineMap.getNumInputs()); - for (unsigned i = 0; i < affineMap.getNumInputs(); ++i) - { + for (unsigned i = 0; i < affineMap.getNumInputs(); ++i) { llvm::hash_code hash = hash_value(operands[i]); DEBUG("looking for: " << hash); - if (values.find(hash) == values.end()) - { + if (values.find(hash) == values.end()) { DEBUG(affineMap); DEBUG("CANNOT FIND " << hash_value(operands[i])); DEBUG("CANNOT FIND " << operands[i]); exit(0); - } - else - { + } else { assert(values.find(hash) != values.end()); assert(values.count(hash)); inVector[i] = this->values[hash]; @@ -214,29 +178,27 @@ int64_t zk_ml::CountPass::evaluateForParameter(AffineMap &affineMap, llvm::Small } } -void zk_ml::CountPass::countDepth(Operation *op) -{ +void zk_ml::CountPass::countDepth(Operation *op) { // Print the operation itself and some of its properties // Print the operation attributes std::string opName = op->getName().getIdentifier().str(); // printIndent(); // DEBUG("visiting " << opName); - if (opName == AFFINE_FOR) - { + if (opName == AFFINE_FOR) { DEBUG("visiting affine for!"); assert(op->getAttrs().size() == 3); AffineMap fromMap = castFromAttr(op->getAttrs()[0].getValue()).getAffineMap(); int64_t step = llvm::dyn_cast(op->getAttrs()[1].getValue()).getInt(); AffineMap toMap = castFromAttr(op->getAttrs()[2].getValue()).getAffineMap(); assert(fromMap.getNumInputs() + toMap.getNumInputs() == op->getNumOperands()); - llvm::SmallVector operandsFrom(op->getOperands().begin(), op->getOperands().begin() + fromMap.getNumInputs()); - llvm::SmallVector operandsTo(op->getOperands().begin() + fromMap.getNumInputs(), op->getOperands().end()); + llvm::SmallVector operandsFrom(op->getOperands().begin(), + op->getOperands().begin() + fromMap.getNumInputs()); + llvm::SmallVector operandsTo(op->getOperands().begin() + fromMap.getNumInputs(), + op->getOperands().end()); int64_t from = evaluateForParameter(fromMap, operandsFrom, true); int64_t to = evaluateForParameter(toMap, operandsTo, false); doAffineFor(op, from, to, step); - } - else if (opName == AFFINE_IF) - { + } else if (opName == AFFINE_IF) { DEBUG("visiting affine if!"); assert(op->getAttrs().size() == 1); IntegerSet condition = castFromAttr(op->getAttrs()[0].getValue()).getValue(); @@ -245,25 +207,19 @@ void zk_ml::CountPass::countDepth(Operation *op) llvm::SmallVector operands(op->getNumOperands()); DEBUG(op->getAttrs()[0].getValue()); int i = 0; - for (auto operand : op->getOperands()) - { + for (auto operand : op->getOperands()) { llvm::hash_code hash = hash_value(operand); assert(values.find(hash) != values.end()); assert(values.count(hash)); int64_t test = this->values[hash]; operands[i++] = test; } - if (evalIntegerSet(condition, operands)) - { + if (evalIntegerSet(condition, operands)) { countRegion(op->getRegion(0)); - } - else - { + } else { countRegion(op->getRegion(1)); } - } - else if (opName == "affine.apply" || opName == "affine.min") - { + } else if (opName == "affine.apply" || opName == "affine.min") { DEBUG("got affine.apply"); assert(op->getResults().size() == 1); assert(op->getAttrs().size() == 1); @@ -271,9 +227,7 @@ void zk_ml::CountPass::countDepth(Operation *op) llvm::SmallVector operands(op->getOperands().begin(), op->getOperands().end()); int64_t result = evaluateForParameter(applyMap, operands, false); values.insert(std::make_pair(hash_value(op->getResults()[0]), result)); - } - else if (opName == "affine.max") - { + } else if (opName == "affine.max") { DEBUG("got affine.apply"); assert(op->getResults().size() == 1); assert(op->getAttrs().size() == 1); @@ -281,32 +235,22 @@ void zk_ml::CountPass::countDepth(Operation *op) llvm::SmallVector operands(op->getOperands().begin(), op->getOperands().end()); int64_t result = evaluateForParameter(applyMap, operands, true); values.insert(std::make_pair(hash_value(op->getResults()[0]), result)); - } - else if (opName == ARITH_CONST) - { + } else if (opName == ARITH_CONST) { assert(op->getNumResults() == 1); assert(op->getAttrs().size() == 1); Attribute contantValue = op->getAttrs()[0].getValue(); - if (contantValue.isa()) - { + if (contantValue.isa()) { int64_t value = llvm::dyn_cast(contantValue).getInt(); values.insert(std::make_pair(hash_value(op->getResult(0)), value)); - } - else - { + } else { DEBUG("ignoring non int constant"); } - } - else - { + } else { auto operationIter = this->counter.find(opName); - if (operationIter != this->counter.end()) - { + if (operationIter != this->counter.end()) { (*operationIter).second++; // std::cout << "increasing " << opName << std::endl; - } - else - { + } else { this->counter.insert(std::make_pair(opName, 1)); // std::cout << "inserting " << opName << std::endl; } @@ -317,19 +261,16 @@ void zk_ml::CountPass::countDepth(Operation *op) } } -void zk_ml::CountPass::countRegion(Region ®ion) -{ +void zk_ml::CountPass::countRegion(Region ®ion) { for (Block &block : region.getBlocks()) countBlock(block); } -void zk_ml::CountPass::countBlock(Block &block) -{ +void zk_ml::CountPass::countBlock(Block &block) { for (Operation &op : block.getOperations()) countDepth(&op); } -std::unique_ptr zk_ml::createCountPass() -{ +std::unique_ptr zk_ml::createCountPass() { return std::make_unique(); } diff --git a/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.cpp index 5043c3c..3de8cc3 100644 --- a/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.cpp +++ b/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.cpp @@ -1,137 +1,120 @@ #include "PrintPass.h" -StringRef zk_ml_toolchain::PrintPass::getArgument() const { return "print-pass"; } - -StringRef zk_ml_toolchain::PrintPass::getDescription() const -{ - return "Prints some Debug Information (copied from Tutorial)"; +StringRef zk_ml_toolchain::PrintPass::getArgument() const { + return "print-pass"; } -void zk_ml_toolchain::PrintPass::runOnOperation() -{ - Operation *op = getOperation(); - resetIndent(); - std::vector typeIds; - printOperation(op, typeIds); +StringRef zk_ml_toolchain::PrintPass::getDescription() const { + return "Prints some Debug Information (copied from Tutorial)"; } -void zk_ml_toolchain::PrintPass::printVector(std::vector &typeIds) -{ - std::cout << "["; - for (auto element : typeIds) - { - std::cout << element << ", "; - } - std::cout << "]" << std::endl; +void zk_ml_toolchain::PrintPass::runOnOperation() { + Operation *op = getOperation(); + resetIndent(); + std::vector typeIds; + printOperation(op, typeIds); } -void zk_ml_toolchain::PrintPass::printOperation(Operation *op, std::vector &typeIds) -{ - // Print the operation itself and some of its properties - std::string opName = op->getName().getIdentifier().str(); - if (opName == "krnl.gloabl") - { - printIndent() << "visiting: krnl.global"; - return; - } - unsigned numOperands = op->getNumOperands(); - unsigned numResults = op->getNumResults(); - printIndent() << "visiting op: '" << op->getName() << "' with " - << numOperands << " operands and " << numResults - << " results\n"; - // Print the operation attributes - if (!op->getAttrs().empty()) - { - printIndent() << op->getAttrs().size() << " attributes:\n"; - for (NamedAttribute attr : op->getAttrs()) - printIndent() << " - '" << attr.getName().getValue() << "' : '" - << attr.getValue() << "'\n"; - } - - // Recurse into each of the regions attached to the operation. - printIndent() << " " << op->getNumRegions() << " nested regions:\n"; - if (opName == "arith.constant") - { - if (numResults != 1) - { - std::cout << "whaaaat" << std::endl; - exit(0); +void zk_ml_toolchain::PrintPass::printVector(std::vector &typeIds) { + std::cout << "["; + for (auto element : typeIds) { + std::cout << element << ", "; } - llvm::hash_code hash = hash_value(op->getResults()[0]); - std::cout << hash << std::endl; - printVector(typeIds); - if (std::find(typeIds.begin(), typeIds.end(), hash) != typeIds.end()) - { - std::cout << "whaaaaaaaaat already in vector" << std::endl; - std::cout << *(std::find(typeIds.begin(), typeIds.end(), hash)) - << std::endl; + std::cout << "]" << std::endl; +} + +void zk_ml_toolchain::PrintPass::printOperation(Operation *op, std::vector &typeIds) { + // Print the operation itself and some of its properties + std::string opName = op->getName().getIdentifier().str(); + if (opName == "krnl.gloabl") { + printIndent() << "visiting: krnl.global"; + return; } - else - { - typeIds.emplace_back(hash); + unsigned numOperands = op->getNumOperands(); + unsigned numResults = op->getNumResults(); + printIndent() << "visiting op: '" << op->getName() << "' with " << numOperands << " operands and " << numResults + << " results\n"; + // Print the operation attributes + if (!op->getAttrs().empty()) { + printIndent() << op->getAttrs().size() << " attributes:\n"; + for (NamedAttribute attr : op->getAttrs()) + printIndent() << " - '" << attr.getName().getValue() << "' : '" << attr.getValue() << "'\n"; } - } - else if (opName == "affine.for" && numOperands > 0) - { - OperandRange operands = op->getOperands(); - for (uint64_t i = 0; i < operands.size(); ++i) - { - llvm::hash_code hash = hash_value(operands[i].getType()); - if (std::find(typeIds.begin(), typeIds.end(), hash) == typeIds.end()) - { - std::cout << "whaaaaaaaaat not in vector" << std::endl; - exit(0); - } + + // Recurse into each of the regions attached to the operation. + printIndent() << " " << op->getNumRegions() << " nested regions:\n"; + if (opName == "arith.constant") { + if (numResults != 1) { + std::cout << "whaaaat" << std::endl; + exit(0); + } + llvm::hash_code hash = hash_value(op->getResults()[0]); + std::cout << hash << std::endl; + printVector(typeIds); + if (std::find(typeIds.begin(), typeIds.end(), hash) != typeIds.end()) { + std::cout << "whaaaaaaaaat already in vector" << std::endl; + std::cout << *(std::find(typeIds.begin(), typeIds.end(), hash)) << std::endl; + } else { + typeIds.emplace_back(hash); + } + } else if (opName == "affine.for" && numOperands > 0) { + OperandRange operands = op->getOperands(); + for (uint64_t i = 0; i < operands.size(); ++i) { + llvm::hash_code hash = hash_value(operands[i].getType()); + if (std::find(typeIds.begin(), typeIds.end(), hash) == typeIds.end()) { + std::cout << "whaaaaaaaaat not in vector" << std::endl; + exit(0); + } + } } - } - auto indent = pushIndent(); - for (Region ®ion : op->getRegions()) - printRegion(region, typeIds); + auto indent = pushIndent(); + for (Region ®ion : op->getRegions()) + printRegion(region, typeIds); } -void zk_ml_toolchain::PrintPass::printRegion(Region ®ion, std::vector &typeIds) -{ - // A region does not hold anything by itself other than a list of blocks. - printIndent() << "Region with " << region.getBlocks().size() - << " blocks:\n"; - auto indent = pushIndent(); - for (Block &block : region.getBlocks()) - printBlock(block, typeIds); +void zk_ml_toolchain::PrintPass::printRegion(Region ®ion, std::vector &typeIds) { + // A region does not hold anything by itself other than a list of blocks. + printIndent() << "Region with " << region.getBlocks().size() << " blocks:\n"; + auto indent = pushIndent(); + for (Block &block : region.getBlocks()) + printBlock(block, typeIds); } -void zk_ml_toolchain::PrintPass::printBlock(Block &block, std::vector &typeIds) -{ - // Print the block intrinsics properties (basically: argument list) - printIndent() - << "Block with " << block.getNumArguments() << " arguments, " - << block.getNumSuccessors() - << " successors, and " - // Note, this `.size()` is traversing a linked-list and is O(n). - << block.getOperations().size() << " operations\n"; - - // Block main role is to hold a list of Operations: let's recurse. - auto indent = pushIndent(); - for (Operation &op : block.getOperations()) - printOperation(&op, typeIds); +void zk_ml_toolchain::PrintPass::printBlock(Block &block, std::vector &typeIds) { + // Print the block intrinsics properties (basically: argument list) + printIndent() << "Block with " << block.getNumArguments() << " arguments, " << block.getNumSuccessors() + << " successors, and " + // Note, this `.size()` is traversing a linked-list and is O(n). + << block.getOperations().size() << " operations\n"; + + // Block main role is to hold a list of Operations: let's recurse. + auto indent = pushIndent(); + for (Operation &op : block.getOperations()) + printOperation(&op, typeIds); } -zk_ml_toolchain::PrintPass::IdentRAII::IdentRAII(int &indent) : indent(indent) {} +zk_ml_toolchain::PrintPass::IdentRAII::IdentRAII(int &indent) : indent(indent) { +} -zk_ml_toolchain::PrintPass::IdentRAII::~IdentRAII() { --indent; } +zk_ml_toolchain::PrintPass::IdentRAII::~IdentRAII() { + --indent; +} -void zk_ml_toolchain::PrintPass::resetIndent() { indent = 0; } +void zk_ml_toolchain::PrintPass::resetIndent() { + indent = 0; +} -zk_ml_toolchain::PrintPass::IdentRAII zk_ml_toolchain::PrintPass::pushIndent() { return IdentRAII(++indent); } +zk_ml_toolchain::PrintPass::IdentRAII zk_ml_toolchain::PrintPass::pushIndent() { + return IdentRAII(++indent); +} -llvm::raw_ostream &zk_ml_toolchain::PrintPass::printIndent() -{ - for (int i = 0; i < indent; ++i) - llvm::outs() << " "; - return llvm::outs(); +llvm::raw_ostream &zk_ml_toolchain::PrintPass::printIndent() { + for (int i = 0; i < indent; ++i) + llvm::outs() << " "; + return llvm::outs(); } -std::unique_ptr zk_ml_toolchain::createPrintPass() -{ - return std::make_unique(); +std::unique_ptr zk_ml_toolchain::createPrintPass() { + return std::make_unique(); } diff --git a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.cpp index 753d569..6a912c3 100644 --- a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.cpp +++ b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.cpp @@ -8,13 +8,13 @@ using mlir::AffineForOp; using mlir::loopUnrollFull; void zk_ml::AffineFullUnrollPass::runOnOperation() { - getOperation().walk([&](AffineForOp op) { - if (failed(loopUnrollFull(op))) { - op.emitError("unrolling failed"); - signalPassFailure(); - } - }); + getOperation().walk([&](AffineForOp op) { + if (failed(loopUnrollFull(op))) { + op.emitError("unrolling failed"); + signalPassFailure(); + } + }); } std::unique_ptr zk_ml::createFullUnrollPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.cpp b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.cpp index bf68723..8152690 100644 --- a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.cpp +++ b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.cpp @@ -6,30 +6,28 @@ #include "AffineFullUnrollPattern.h" - using mlir::AffineForOp; using mlir::loopUnrollFull; namespace { - struct AffineFullUnrollPattern : public OpRewritePattern { + struct AffineFullUnrollPattern : public OpRewritePattern { - AffineFullUnrollPattern(MLIRContext *context) - : OpRewritePattern(context, /*benefit=*/1) {} - - LogicalResult matchAndRewrite(AffineForOp op, - PatternRewriter &rewriter) const override { - return loopUnrollFull(op); - } - }; -} + AffineFullUnrollPattern(MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1) { + } + + LogicalResult matchAndRewrite(AffineForOp op, PatternRewriter &rewriter) const override { + return loopUnrollFull(op); + } + }; +} // namespace void zk_ml::AffineFullUnrollPassAsPatternRewrite::runOnOperation() { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - // One could use GreedyRewriteConfig here to slightly tweak the behavior of - // the pattern application. - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + // One could use GreedyRewriteConfig here to slightly tweak the behavior of + // the pattern application. + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } std::unique_ptr zk_ml::createFullUnrollPassPatternRewriter() { - return std::make_unique(); + return std::make_unique(); } diff --git a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp index 4903462..93f6868 100644 --- a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp +++ b/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp @@ -1,21 +1,20 @@ #include "ElimCopySignPass.h" -StringRef zk_ml_toolchain::ElimCopySignPass::getArgument() const { return "elim-copysign-pass"; } +StringRef zk_ml_toolchain::ElimCopySignPass::getArgument() const { + return "elim-copysign-pass"; +} -StringRef zk_ml_toolchain::ElimCopySignPass::getDescription() const -{ +StringRef zk_ml_toolchain::ElimCopySignPass::getDescription() const { return "Eliminates redundant copysign operations that follow an frem operation"; } -void zk_ml_toolchain::ElimCopySignPass::runOnOperation() -{ +void zk_ml_toolchain::ElimCopySignPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -std::unique_ptr zk_ml_toolchain::createElimCopySignPass() -{ +std::unique_ptr zk_ml_toolchain::createElimCopySignPass() { return std::make_unique(); } diff --git a/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp b/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp index af27f01..a4cd281 100644 --- a/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp +++ b/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp @@ -15,229 +15,210 @@ enum EmitLevel { zkMLIR, ONNX, MLIR, LLVMIR }; -llvm::cl::opt InputFilename(llvm::cl::Positional, - llvm::cl::desc(""), - llvm::cl::Required); -llvm::cl::opt - OutputFilename("i", llvm::cl::desc("Specify output filename"), - llvm::cl::value_desc("filename"), - llvm::cl::init(STDOUT_MARKER)); - -llvm::cl::opt - EmitLevel(llvm::cl::desc("Which lowering level do you want?"), - llvm::cl::values(clEnumVal(ONNX, "Lower to \"ONNX\" dialect."), - clEnumVal(MLIR, "Lower to \"MLIR-IR\"."), - clEnumVal(zkMLIR, "Lower to \"zkMLIR-IR\"."), - clEnumVal(LLVMIR, "Lower to \"LLVM-IR\"."))); - - -llvm::cl::opt ZkMlDebugFlag("DEBUG", - llvm::cl::desc("turns on debugging log"), - llvm::cl::init(false)); +llvm::cl::opt InputFilename(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::Required); +llvm::cl::opt OutputFilename("i", llvm::cl::desc("Specify output filename"), + llvm::cl::value_desc("filename"), llvm::cl::init(STDOUT_MARKER)); + +llvm::cl::opt EmitLevel(llvm::cl::desc("Which lowering level do you want?"), + llvm::cl::values(clEnumVal(ONNX, "Lower to \"ONNX\" dialect."), + clEnumVal(MLIR, "Lower to \"MLIR-IR\"."), + clEnumVal(zkMLIR, "Lower to \"zkMLIR-IR\"."), + clEnumVal(LLVMIR, "Lower to \"LLVM-IR\"."))); + +llvm::cl::opt ZkMlDebugFlag("DEBUG", llvm::cl::desc("turns on debugging log"), llvm::cl::init(false)); bool hasEnding(std::string const &fullString, std::string const &ending) { - if (fullString.length() >= ending.length()) { - return (0 == fullString.compare(fullString.length() - ending.length(), - ending.length(), ending)); - } else { - return false; - } + if (fullString.length() >= ending.length()) { + return (0 == fullString.compare(fullString.length() - ending.length(), ending.length(), ending)); + } else { + return false; + } } std::string dirName(StringRef inputFilename) { - llvm::SmallVector path(inputFilename.begin(), inputFilename.end()); - llvm::sys::path::remove_filename(path); - return std::string(path.data(), path.size()); + llvm::SmallVector path(inputFilename.begin(), inputFilename.end()); + llvm::sys::path::remove_filename(path); + return std::string(path.data(), path.size()); } -int loadOnnxFile(StringRef inputFilename, mlir::MLIRContext &context, - mlir::OwningOpRef &module, +int loadOnnxFile(StringRef inputFilename, mlir::MLIRContext &context, mlir::OwningOpRef &module, std::string *errorMessage) { - // we use default options for now from onnx-mlir, lets see if we need - // something else - onnx_mlir::ImportOptions options; - options.useOnnxModelTypes = true; - options.invokeOnnxVersionConverter = false; - // TODO check the default value - options.shapeInformation = onnx_mlir::shapeInformation; - options.allowSorting = true; - options.externalDataDir = dirName(inputFilename); - // does not exist at commit a04f518c1 - // options.functionsToDecompose.insert(options.functionsToDecompose.end(), - // onnx_mlir::functionsToDecompose.begin(), - // onnx_mlir::functionsToDecompose.end()); - return onnx_mlir::ImportFrontendModelFile(inputFilename, context, module, - errorMessage); - // return onnx_mlir::ImportFrontendModelFile(inputFilename, context, module, - // errorMessage, options); + // we use default options for now from onnx-mlir, lets see if we need + // something else + onnx_mlir::ImportOptions options; + options.useOnnxModelTypes = true; + options.invokeOnnxVersionConverter = false; + // TODO check the default value + options.shapeInformation = onnx_mlir::shapeInformation; + options.allowSorting = true; + options.externalDataDir = dirName(inputFilename); + // does not exist at commit a04f518c1 + // options.functionsToDecompose.insert(options.functionsToDecompose.end(), + // onnx_mlir::functionsToDecompose.begin(), + // onnx_mlir::functionsToDecompose.end()); + return onnx_mlir::ImportFrontendModelFile(inputFilename, context, module, errorMessage); + // return onnx_mlir::ImportFrontendModelFile(inputFilename, context, module, + // errorMessage, options); } -std::unique_ptr -lowerToLLVM(llvm::LLVMContext &llvmContext, - mlir::OwningOpRef &mlirModule, int *error_code) { - std::error_code error; - - // TODO do we want to emit .bc? Or at least make it configureable - mlir::registerLLVMDialectTranslation(*mlirModule->getContext()); - std::unique_ptr llvmModule = - mlir::translateModuleToLLVMIR(*mlirModule, llvmContext); - if (!llvmModule) { - llvm::errs() << "Failed to translate module to LLVMIR.\n"; - *error_code = -1; - return nullptr; - } - return llvmModule; +std::unique_ptr lowerToLLVM(llvm::LLVMContext &llvmContext, mlir::OwningOpRef &mlirModule, + int *error_code) { + std::error_code error; + + // TODO do we want to emit .bc? Or at least make it configureable + mlir::registerLLVMDialectTranslation(*mlirModule->getContext()); + std::unique_ptr llvmModule = mlir::translateModuleToLLVMIR(*mlirModule, llvmContext); + if (!llvmModule) { + llvm::errs() << "Failed to translate module to LLVMIR.\n"; + *error_code = -1; + return nullptr; + } + return llvmModule; } -void runZkMlPasses(std::unique_ptr &llvm_module, - llvm::OptimizationLevel OptimizationLevel) { - // create all analyses - // llvm::ModuleAnalysisManager MAM; - // llvm::LoopAnalysisManager LAM; - // llvm::FunctionAnalysisManager FAM; - // llvm::CGSCCAnalysisManager CGAM; - - // llvm::PassBuilder PB; - //// Register all the basic analyses with the managers. - // PB.registerModuleAnalyses(MAM); - // PB.registerCGSCCAnalyses(CGAM); - // PB.registerFunctionAnalyses(FAM); - // PB.registerLoopAnalyses(LAM); - // PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); - - // This one corresponds to a typical -O2 optimization pipeline. - // llvm::ModulePassManager MPM = - // PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O1); - // // we got default Module Passmanager for corresponding OptimizationLevel - // // now add our passes - // llvm::FunctionPassManager FPM; - // FPM.addPass(zk_ml::AddCircuitFnAttrPass()); - // MPM.addPass(llvm::createModuleToFunctionPassAdaptor(std::move(FPM))); - // MPM.run(*llvm_module, MAM); +void runZkMlPasses(std::unique_ptr &llvm_module, llvm::OptimizationLevel OptimizationLevel) { + // create all analyses + // llvm::ModuleAnalysisManager MAM; + // llvm::LoopAnalysisManager LAM; + // llvm::FunctionAnalysisManager FAM; + // llvm::CGSCCAnalysisManager CGAM; + + // llvm::PassBuilder PB; + //// Register all the basic analyses with the managers. + // PB.registerModuleAnalyses(MAM); + // PB.registerCGSCCAnalyses(CGAM); + // PB.registerFunctionAnalyses(FAM); + // PB.registerLoopAnalyses(LAM); + // PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + // This one corresponds to a typical -O2 optimization pipeline. + // llvm::ModulePassManager MPM = + // PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O1); + // // we got default Module Passmanager for corresponding OptimizationLevel + // // now add our passes + // llvm::FunctionPassManager FPM; + // FPM.addPass(zk_ml::AddCircuitFnAttrPass()); + // MPM.addPass(llvm::createModuleToFunctionPassAdaptor(std::move(FPM))); + // MPM.run(*llvm_module, MAM); } -void outputModule(mlir::OwningOpRef &module, - std::string &outputFilename, int64_t largeElementLimit = -1) { - mlir::OpPrintingFlags flags; - if (onnx_mlir::preserveLocations) - flags.enableDebugInfo(); - if (largeElementLimit >= 0) - flags.elideLargeElementsAttrs(largeElementLimit); - // yeah zero means equal.... - if (outputFilename.compare(STDOUT_MARKER) == 0) { - module->print(llvm::outs(), flags); - } else { - std::error_code fileError; - llvm::raw_fd_ostream fileStream(llvm::StringRef(outputFilename), fileError); - module->print(fileStream, flags); - fileStream.close(); - } +void outputModule(mlir::OwningOpRef &module, std::string &outputFilename, + int64_t largeElementLimit = -1) { + mlir::OpPrintingFlags flags; + if (onnx_mlir::preserveLocations) + flags.enableDebugInfo(); + if (largeElementLimit >= 0) + flags.elideLargeElementsAttrs(largeElementLimit); + // yeah zero means equal.... + if (outputFilename.compare(STDOUT_MARKER) == 0) { + module->print(llvm::outs(), flags); + } else { + std::error_code fileError; + llvm::raw_fd_ostream fileStream(llvm::StringRef(outputFilename), fileError); + module->print(fileStream, flags); + fileStream.close(); + } } -std::unique_ptr -translateToLLVMIR(mlir::ModuleOp module, llvm::LLVMContext &llvm_context) { - mlir::registerLLVMDialectTranslation(*module.getContext()); - std::unique_ptr llvm_module = - mlir::translateModuleToLLVMIR(module, llvm_context); - if (!llvm_module) { - llvm::errs() << "Failed to translate module to LLVMIR.\n"; - } - return llvm_module; +std::unique_ptr translateToLLVMIR(mlir::ModuleOp module, llvm::LLVMContext &llvm_context) { + mlir::registerLLVMDialectTranslation(*module.getContext()); + std::unique_ptr llvm_module = mlir::translateModuleToLLVMIR(module, llvm_context); + if (!llvm_module) { + llvm::errs() << "Failed to translate module to LLVMIR.\n"; + } + return llvm_module; } int main(int argc, char **argv) { - /* int optLevel = std::stoi(argv[2]); - switch (optLevel) { - case 0: - onnx_mlir::setOptLevel(onnx_mlir::O0); - break; - case 1: - onnx_mlir::setOptLevel(onnx_mlir::O1); - break; - case 2: - onnx_mlir::setOptLevel(onnx_mlir::O2); - break; - case 3: - onnx_mlir::setOptLevel(onnx_mlir::O3); - break; - default: - llvm::outs() << "opt level must be on of {0,1,2,3}"; - return -2; - }*/ - llvm::cl::ParseCommandLineOptions(argc, argv); - std::string inputFilename = InputFilename.c_str(); - //=========================== - // LETS SEE IF WE NEED THIS - - // copied from onnx-mlir.cpp (lets see what we need) - // Register MLIR command line options. - mlir::registerAsmPrinterCLOptions(); - mlir::registerMLIRContextCLOptions(); - mlir::registerPassManagerCLOptions(); - mlir::registerDefaultTimingManagerCLOptions(); - mlir::registerAsmPrinterCLOptions(); - - llvm::cl::SetVersionPrinter(onnx_mlir::getVersionPrinter); - //=========================== - // - mlir::MLIRContext context; - // does not exist at commit a04f518c1 - // context.appendDialectRegistry(onnx_mlir::registerDialects(onnx_mlir::maccel)); - // context.loadAllAvailableDialects(); - onnx_mlir::registerDialects(context); - context.getOrLoadDialect(); - - mlir::OwningOpRef module; - std::string errorMessage; - if (int rc = loadOnnxFile(llvm::StringRef(inputFilename), context, module, - &errorMessage)) { - llvm::errs() << "Cannot load .onnx file:\n"; - llvm::errs() << errorMessage << "\n"; - return rc; - } - bool EmitMLIR = EmitLevel::zkMLIR == EmitLevel || EmitLevel::MLIR == EmitLevel; - mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit); - if (EmitLevel == EmitLevel::ONNX) { - onnx_mlir::addPasses(module, pm, onnx_mlir::EmissionTargetType::EmitONNXIR); - } else { - onnx_mlir::addPasses(module, pm, onnx_mlir::EmissionTargetType::EmitMLIR, EmitLevel == EmitLevel::zkMLIR); - pm.addPass(zk_ml_toolchain::createElimCopySignPass()); - if (!EmitMLIR) { - // third parameter here is optional in onnx-mlir. Maybe we should do that - // too? - onnx_mlir::addKrnlToLLVMPasses(pm, true, true); + /* int optLevel = std::stoi(argv[2]); + switch (optLevel) { + case 0: + onnx_mlir::setOptLevel(onnx_mlir::O0); + break; + case 1: + onnx_mlir::setOptLevel(onnx_mlir::O1); + break; + case 2: + onnx_mlir::setOptLevel(onnx_mlir::O2); + break; + case 3: + onnx_mlir::setOptLevel(onnx_mlir::O3); + break; + default: + llvm::outs() << "opt level must be on of {0,1,2,3}"; + return -2; + }*/ + llvm::cl::ParseCommandLineOptions(argc, argv); + std::string inputFilename = InputFilename.c_str(); + //=========================== + // LETS SEE IF WE NEED THIS + + // copied from onnx-mlir.cpp (lets see what we need) + // Register MLIR command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + mlir::registerDefaultTimingManagerCLOptions(); + mlir::registerAsmPrinterCLOptions(); + + llvm::cl::SetVersionPrinter(onnx_mlir::getVersionPrinter); + //=========================== + // + mlir::MLIRContext context; + // does not exist at commit a04f518c1 + // context.appendDialectRegistry(onnx_mlir::registerDialects(onnx_mlir::maccel)); + // context.loadAllAvailableDialects(); + onnx_mlir::registerDialects(context); + context.getOrLoadDialect(); + + mlir::OwningOpRef module; + std::string errorMessage; + if (int rc = loadOnnxFile(llvm::StringRef(inputFilename), context, module, &errorMessage)) { + llvm::errs() << "Cannot load .onnx file:\n"; + llvm::errs() << errorMessage << "\n"; + return rc; + } + bool EmitMLIR = EmitLevel::zkMLIR == EmitLevel || EmitLevel::MLIR == EmitLevel; + mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit); + if (EmitLevel == EmitLevel::ONNX) { + onnx_mlir::addPasses(module, pm, onnx_mlir::EmissionTargetType::EmitONNXIR); + } else { + onnx_mlir::addPasses(module, pm, onnx_mlir::EmissionTargetType::EmitMLIR, EmitLevel == EmitLevel::zkMLIR); + pm.addPass(zk_ml_toolchain::createElimCopySignPass()); + if (!EmitMLIR) { + // third parameter here is optional in onnx-mlir. Maybe we should do that + // too? + onnx_mlir::addKrnlToLLVMPasses(pm, true, true); + } } - } - - (void)mlir::applyPassManagerCLOptions(pm); - mlir::applyDefaultTimingPassManagerCLOptions(pm); - if (mlir::failed(pm.run(*module))) { - llvm::errs() << "Passmanager failed to run!\n"; - return -1; - } - std::string outputFilename = OutputFilename.c_str(); - - if (EmitMLIR || EmitLevel::ONNX == EmitLevel) { - outputModule(module, outputFilename); - return 0; - } else { - int error_code; - llvm::LLVMContext llvmContext; - std::unique_ptr llvm_module = - lowerToLLVM(llvmContext, module, &error_code); - if (!llvm_module) - return error_code; - if (outputFilename.compare(STDOUT_MARKER) == 0) { - llvm_module->print(llvm::outs(), nullptr); + (void)mlir::applyPassManagerCLOptions(pm); + mlir::applyDefaultTimingPassManagerCLOptions(pm); + if (mlir::failed(pm.run(*module))) { + llvm::errs() << "Passmanager failed to run!\n"; + return -1; + } + std::string outputFilename = OutputFilename.c_str(); + + if (EmitMLIR || EmitLevel::ONNX == EmitLevel) { + outputModule(module, outputFilename); + return 0; } else { - std::error_code fileError; - llvm::raw_fd_ostream fileStream(llvm::StringRef(outputFilename), - fileError); - llvm_module->print(fileStream, nullptr); - fileStream.close(); + int error_code; + llvm::LLVMContext llvmContext; + std::unique_ptr llvm_module = lowerToLLVM(llvmContext, module, &error_code); + if (!llvm_module) + return error_code; + + if (outputFilename.compare(STDOUT_MARKER) == 0) { + llvm_module->print(llvm::outs(), nullptr); + } else { + std::error_code fileError; + llvm::raw_fd_ostream fileStream(llvm::StringRef(outputFilename), fileError); + llvm_module->print(fileStream, nullptr); + fileStream.close(); + } } - } - return 0; + return 0; }