Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: input ops #21

Merged
merged 3 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

namespace nil {
namespace blueprint {
// TODO There is also the logic_and_flag. Should we use this one or should we use the logic_ops?????
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_and(
mlir::arith::AndIOp &operation,
Expand All @@ -54,7 +55,8 @@ namespace nil {
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

Expand All @@ -66,17 +68,23 @@ namespace nil {
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
//FIXME logic_or is commented out. As soon as it is enabled, remove add the liens above and it SHOULD work
UNREACHABLE("LogicOR not enabled in blueprint");
// using component_type = components::lookup_logic_or<
// crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
//
// auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp);
// const auto p = detail::PolicyManager::get_parameters(
// detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
//
// component_type component(p.witness);
// fill_trace(component, input, operation, frame, bp, assignment, start_row);
using component_type = components::logic_or_flag<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
using input_type = typename component_type::input_type;

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

input_type input;
input.x = lhs->second;
input.y = rhs->second;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -94,7 +102,8 @@ namespace nil {
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmax.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sqrt.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/logic_or_flag.hpp>

#define PREPARE_UNARY_INPUT(OP) \
prepare_unary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
Expand Down
159 changes: 139 additions & 20 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP
#define CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP

#include "nil/blueprint/blueprint/plonk/assignment.hpp"
#include <cassert>
#include <cstdint>
#define TEST_WITHOUT_LOOKUP_TABLES
Expand Down Expand Up @@ -53,7 +54,7 @@
#include <mlir-assigner/components/fixedpoint/subtraction.hpp>
#include <mlir-assigner/components/fixedpoint/dot_product.hpp>
#include <mlir-assigner/components/fixedpoint/trigonometric.hpp>
#include <mlir-assigner/components/boolean/and.hpp>
#include <mlir-assigner/components/boolean/logic_ops.hpp>
#include <mlir-assigner/components/fixedpoint/to_fixpoint.hpp>

#include <mlir-assigner/memory/memref.hpp>
Expand Down Expand Up @@ -190,6 +191,17 @@ namespace zk_ml_toolchain {
bool PrintCircuitOutput;
nil::blueprint::logger &logger;

template<typename NumberType>
NumberType resolve_number(VarType scalar) {
auto scalar_value = var_value(assignmnt, scalar);
static constexpr auto limit_value =
typename BlueprintFieldType::integral_type(std::numeric_limits<NumberType>::max());
auto integral_value = static_cast<typename BlueprintFieldType::integral_type>(scalar_value.data);
ASSERT_MSG(integral_value < limit_value, "Too large to cast");
NumberType number = static_cast<NumberType>(integral_value);
return number;
}

void doAffineFor(affine::AffineForOp &op, int64_t from, int64_t to, int64_t step) {
assert(from < to);
assert(step);
Expand Down Expand Up @@ -253,7 +265,45 @@ namespace zk_ml_toolchain {
} else if (arith::CmpFOp operation = llvm::dyn_cast<arith::CmpFOp>(op)) {
handle_fixedpoint_comparison_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::SelectOp operation = llvm::dyn_cast<arith::SelectOp>(op)) {
handle_select_component(operation, frames.back(), bp, assignmnt, start_row);
ASSERT(operation.getNumOperands() == 3 && "Select must have three operands");
ASSERT(operation->getOperand(1).getType() == operation->getOperand(2).getType() &&
"Select must operate on same type");
// check if we work on indices
Type operandType = operation->getOperand(1).getType();
auto i1Hash = mlir::hash_value(operation->getOperand(0));
if (operandType.isa<IndexType>()) {
// for now we expect that if we select on indices, that we also have the cmp result in
// constant values. Let's see if this holds true in the future
auto cmpResult = frames.back().constant_values.find(i1Hash);
ASSERT(cmpResult != frames.back().constant_values.end());
if (cmpResult->second) {
auto truthy = frames.back().constant_values.find(mlir::hash_value(operation->getOperand(1)));
ASSERT(truthy != frames.back().constant_values.end());
frames.back().constant_values[mlir::hash_value(operation->getResult(0))] = truthy->second;
} else {
auto falsy = frames.back().constant_values.find(mlir::hash_value(operation->getOperand(2)));
ASSERT(falsy != frames.back().constant_values.end());
frames.back().constant_values[mlir::hash_value(operation->getResult(0))] = falsy->second;
}
} else if (frames.back().constant_values.find(i1Hash) != frames.back().constant_values.end()) {
// we come from index comparision but we do not work on indices, ergo we need to get from locals
if (frames.back().constant_values[i1Hash]) {
auto truthy = frames.back().locals.find(mlir::hash_value(operation->getOperand(1)));
ASSERT(truthy != frames.back().locals.end());
frames.back().locals[mlir::hash_value(operation->getResult(0))] = truthy->second;
} else {
auto falsy = frames.back().locals.find(mlir::hash_value(operation->getOperand(2)));
ASSERT(falsy != frames.back().locals.end());
frames.back().locals[mlir::hash_value(operation->getResult(0))] = falsy->second;
}
} else if (operandType.isa<FloatType>()) {
handle_select_component(operation, frames.back(), bp, assignmnt, start_row);
} else {
std::string typeStr;
llvm::raw_string_ostream ss(typeStr);
ss << operandType;
UNREACHABLE(std::string("unhandled select operand: ") + typeStr);
}
} else if (arith::NegFOp operation = llvm::dyn_cast<arith::NegFOp>(op)) {
handle_fixedpoint_neg_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::AndIOp operation = llvm::dyn_cast<arith::AndIOp>(op)) {
Expand All @@ -267,14 +317,28 @@ namespace zk_ml_toolchain {
UNREACHABLE("TODO add Bitwise And Gadget");
}
} else if (arith::OrIOp operation = llvm::dyn_cast<arith::OrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
ASSERT(operation.getNumOperands() == 2 && "Or must have two operands");
ASSERT(operation->getOperand(0).getType() == operation->getOperand(1).getType() &&
"Or must operate on same type");
// check if we work on indices
// TODO this seems like a hack, maybe we can do something better
auto lhsHash = mlir::hash_value(operation.getLhs());
if (frames.back().constant_values.find(lhsHash) != frames.back().constant_values.end()) {
auto lhs = frames.back().constant_values[lhsHash];
auto rhs = frames.back().constant_values.find(mlir::hash_value(operation.getRhs()));
assert(rhs != frames.back().constant_values.end());
auto result = lhs | rhs->second;
frames.back().constant_values[mlir::hash_value(operation.getResult())] = result;
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
}
}
} else if (arith::XOrIOp operation = llvm::dyn_cast<arith::XOrIOp>(op)) {
// check if logical and or bitwise and
Expand All @@ -287,7 +351,6 @@ namespace zk_ml_toolchain {
UNREACHABLE("TODO add Bitwise XOr Gadget");
}
} else if (arith::AddIOp operation = llvm::dyn_cast<arith::AddIOp>(op)) {

// TODO: ATM, handle only the case where we work on indices that are
// constant values
auto lhs = frames.back().constant_values.find(mlir::hash_value(operation.getLhs()));
Expand Down Expand Up @@ -323,8 +386,53 @@ namespace zk_ml_toolchain {
frames.back().constant_values[mlir::hash_value(operation.getResult())] = result;

} else if (arith::CmpIOp operation = llvm::dyn_cast<arith::CmpIOp>(op)) {
llvm::outs() << "icmp\n";
exit(0);
assert(operation.getLhs().getType().isa<IndexType>());
assert(operation.getRhs().getType().isa<IndexType>());

// TODO: ATM, handle only the case where we work on indices that are
// constant values
auto lhs = frames.back().constant_values.find(mlir::hash_value(operation.getLhs()));
auto rhs = frames.back().constant_values.find(mlir::hash_value(operation.getRhs()));
assert(lhs != frames.back().constant_values.end());
assert(rhs != frames.back().constant_values.end());
int64_t cmpResult;
switch (operation.getPredicate()) {
case arith::CmpIPredicate::eq:
cmpResult = static_cast<int64_t>(lhs->second == rhs->second);
break;
case arith::CmpIPredicate::ne:
cmpResult = static_cast<int64_t>(lhs->second != rhs->second);
break;
case arith::CmpIPredicate::slt:
cmpResult = static_cast<int64_t>(lhs->second < rhs->second);
break;
case arith::CmpIPredicate::sle:
cmpResult = static_cast<int64_t>(lhs->second <= rhs->second);
break;
case arith::CmpIPredicate::sgt:
cmpResult = static_cast<int64_t>(lhs->second > rhs->second);
break;
case arith::CmpIPredicate::sge:
cmpResult = static_cast<int64_t>(lhs->second >= rhs->second);
break;
case arith::CmpIPredicate::ult:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) <
static_cast<uint64_t>(rhs->second));
break;
case arith::CmpIPredicate::ule:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) <=
static_cast<uint64_t>(rhs->second));
break;
case arith::CmpIPredicate::ugt:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) >
static_cast<uint64_t>(rhs->second));
break;
case arith::CmpIPredicate::uge:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) >=
static_cast<uint64_t>(rhs->second));
break;
}
frames.back().constant_values[mlir::hash_value(operation.getResult())] = cmpResult;
} else if (arith::ConstantOp operation = llvm::dyn_cast<arith::ConstantOp>(op)) {
TypedAttr constantValue = operation.getValueAttr();
if (constantValue.isa<IntegerAttr>()) {
Expand Down Expand Up @@ -358,10 +466,21 @@ namespace zk_ml_toolchain {
}
} else if (arith::IndexCastOp operation = llvm::dyn_cast<arith::IndexCastOp>(op)) {
assert(operation->getNumOperands() == 1 && "IndexCast must have exactly one operand");
auto index = frames.back().constant_values[mlir::hash_value(operation->getOperand(0))];
typename BlueprintFieldType::value_type field_constant = index;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
auto opHash = mlir::hash_value(operation->getOperand(0));
// from int to index
if (operation->getOperand(0).getType().isa<IntegerType>()) {
auto i = frames.back().locals.find(opHash);
assert(i != frames.back().locals.end());
frames.back().constant_values[mlir::hash_value(operation.getResult())] = resolve_number<int64_t>(i->second);
} else if (operation->getOperand(0).getType().isa<IndexType>()) {
auto index = frames.back().constant_values.find(opHash);
assert(index != frames.back().constant_values.end());
typename BlueprintFieldType::value_type field_constant = index->second;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
} else {
UNREACHABLE("unsupported Index Cast");
}
} else if (arith::SIToFPOp operation = llvm::dyn_cast<arith::SIToFPOp>(op)) {
// TODO this does not respect negative and no different ranges for ints...
handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row);
Expand Down Expand Up @@ -543,16 +662,16 @@ namespace zk_ml_toolchain {
// Create the global at the entry of the module.
assert(operation.getValue().has_value() && "Krnl Global must always have a value");
auto value = operation.getValue().value();
//TODO check other bit sizes. Also no range constraint is this necessary????
// TODO check other bit sizes. Also no range constraint is this necessary????
if (DenseElementsAttr attr = llvm::dyn_cast<DenseElementsAttr>(value)) {
mlir::Type attrType = attr.getElementType();
if (attrType.isa<mlir::IntegerType>()) {
auto ints = attr.tryGetValues<APInt>();
assert(!mlir::failed(ints) && "must work as we checked above");
size_t idx = 0;
for (auto a : ints.value()) {
auto var = put_into_assignment(a.getSExtValue());
m.put_flat(idx++, var);
auto var = put_into_assignment(a.getSExtValue());
m.put_flat(idx++, var);
}
} else if (attrType.isa<mlir::FloatType>()) {
auto floats = attr.tryGetValues<APFloat>();
Expand All @@ -572,7 +691,7 @@ namespace zk_ml_toolchain {
m.put_flat(idx++, var);
}
} else {
UNREACHABLE("Unsupported attribute type");
UNREACHABLE("Unsupported attribute type");
}
} else {
UNREACHABLE("Expected a DenseElementsAttr");
Expand Down
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<f32>[1]
3
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<10xf32>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
12
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
 :`

in_aout_a"ConstantOfShapeConstantOfShapeSimple*:
Bin_ab
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file added mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<3x12xf32>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]
11
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<3x12xi32>[5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 2]
11
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Or/OrSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi1>[1, 1, 0, 0, 1, 1, 0, 1, 1, 1]
23
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file added mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<2xf32>[3, 6]
4
Loading