Skip to content

Commit

Permalink
feat: generic pow support (#25)
Browse files Browse the repository at this point in the history
* pow rewrite
* pinned onnx-mlir patch
* Problem with index to float cast
* Updated onnx testcases to problematic/finished
  • Loading branch information
0xThemis authored Jan 11, 2024
1 parent f152c12 commit c9ca86c
Show file tree
Hide file tree
Showing 73 changed files with 213 additions and 863 deletions.
2 changes: 1 addition & 1 deletion libs/onnx-mlir
17 changes: 10 additions & 7 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ namespace zk_ml_toolchain {
auto val = var_value(assignmnt, toConvert).data;
nil::blueprint::components::FixedPoint<BlueprintFieldType, 1, 1> out(val, 16);
return out.to_double();
// auto Lhs = frames.back().locals[mlir::hash_value(operation.getLhs())];
// auto Rhs = frames.back().locals[mlir::hash_value(operation.getRhs())];
// auto Result = frames.back().locals[mlir::hash_value(operation.getResult())];
// std::cout << toFixpoint(Lhs) << " * " << toFixpoint(Rhs) << " = " << toFixpoint(Result) << "\n";
}

void handleArithOperation(Operation *op) {
Expand Down Expand Up @@ -472,14 +476,14 @@ namespace zk_ml_toolchain {

typename BlueprintFieldType::value_type field_constant = value;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
frames.back().locals[mlir::hash_value(operation.getResult())] = val;
} else if (constantValue.isa<FloatAttr>()) {
double d = llvm::dyn_cast<FloatAttr>(constantValue).getValueAsDouble();
nil::blueprint::components::FixedPoint<BlueprintFieldType, 1, 1> fixed(d);
auto value = put_into_assignment(fixed.get_value());
// this insert is ok, since this should never change, so we
// don't override it if it is already there
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), value));
frames.back().locals[mlir::hash_value(operation.getResult())] = value;
} else {
logger << constantValue;
UNREACHABLE("unhandled constant");
Expand All @@ -489,6 +493,7 @@ namespace zk_ml_toolchain {
auto opHash = mlir::hash_value(operation->getOperand(0));
Type casteeType = operation->getOperand(0).getType();
if (casteeType.isa<IntegerType>()) {
UNREACHABLE("I SHOULD NOT WORK");
auto i = frames.back().locals.find(opHash);
assert(i != frames.back().locals.end());
frames.back().constant_values[mlir::hash_value(operation.getResult())] = resolve_number(i->second);
Expand All @@ -497,7 +502,7 @@ namespace zk_ml_toolchain {
assert(index != frames.back().constant_values.end());
typename BlueprintFieldType::value_type field_constant = index->second;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
frames.back().locals[mlir::hash_value(operation.getResult())] = val;
} else {
UNREACHABLE("unsupported Index Cast");
}
Expand All @@ -524,8 +529,6 @@ namespace zk_ml_toolchain {
std::uint32_t start_row = assignmnt.allocated_rows();
if (math::ExpOp operation = llvm::dyn_cast<math::ExpOp>(op)) {
handle_fixedpoint_exp_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::Exp2Op operation = llvm::dyn_cast<math::Exp2Op>(op)) {
UNREACHABLE("TODO: component for exp2 not ready");
} else if (math::LogOp operation = llvm::dyn_cast<math::LogOp>(op)) {
handle_fixedpoint_log_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::PowFOp operation = llvm::dyn_cast<math::PowFOp>(op)) {
Expand Down Expand Up @@ -874,8 +877,8 @@ namespace zk_ml_toolchain {
auto m = nil::blueprint::memref<VarType>(dims, type.getElementType());
auto hash = mlir::hash_value(operation.getMemref());
auto insert_res = frames.back().memrefs.insert({hash, m});
assert(insert_res.second); // Reallocating over an existing memref
// should not happen ATM
// assert(insert_res.second); // Reallocating over an existing memref
// should not happen ATM
logger.debug("inserting memref with hash %x", size_t(hash));
} else if (memref::AllocaOp operation = llvm::dyn_cast<memref::AllocaOp>(op)) {
// TACEO_TODO: handle cleanup of these stack memrefs
Expand Down
File renamed without changes.
Binary file not shown.
37 changes: 0 additions & 37 deletions mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.mlir

This file was deleted.

This file was deleted.

This file was deleted.

Binary file not shown.

This file was deleted.

This file was deleted.

This file was deleted.

Binary file not shown.

This file was deleted.

This file was deleted.

16 changes: 0 additions & 16 deletions mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.mlir

This file was deleted.

This file was deleted.

Loading

0 comments on commit c9ca86c

Please sign in to comment.