Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: bitwise operations #14

Merged
merged 3 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions mlir-assigner/include/mlir-assigner/components/boolean/and.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//---------------------------------------------------------------------------//
// Copyright (c) 2023 Nikita Kaskov <[email protected]>
//
// MIT License
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//---------------------------------------------------------------------------//

#ifndef CRYPTO3_ASSIGNER_AND_HPP
#define CRYPTO3_ASSIGNER_AND_HPP

#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_and(
mlir::arith::AndIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::lookup_logic_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::AndIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_or(
mlir::arith::OrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
//FIXME logic_or is commented out. As soon as it is enabled, remove add the liens above and it SHOULD work
UNREACHABLE("LogicOR not enabled in blueprint");
// using component_type = components::lookup_logic_or<
// crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
//
// auto input = PREPARE_INPUT(mlir::arith::OrIOp);
// const auto p = detail::PolicyManager::get_parameters(
// detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
//
// component_type component(p.witness);
// fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_xor(
mlir::arith::XOrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::lookup_logic_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::XOrIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_LOGIC_OPS_HPP
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>

#include <mlir-assigner/components/fields/addition.hpp>

namespace nil {
namespace blueprint {
namespace detail {
Expand Down
117 changes: 117 additions & 0 deletions mlir-assigner/include/mlir-assigner/components/handle_component.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@

//---------------------------------------------------------------------------//
// Copyright (c) 2023 Nikita Kaskov <[email protected]>
//
// MIT License
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//---------------------------------------------------------------------------//

#ifndef CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP
#define CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP

#define PREPARE_INPUT(OP) \
prepare_binary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
typename component_type::input_type>(operation, frame, bp, assignment);

#include <mlir-assigner/memory/stack_frame.hpp>
namespace nil {
namespace blueprint {
template<typename BlueprintFieldType, typename ArithmetizationParams, typename BinOp, typename input_type>
input_type prepare_binary_operation_input(
BinOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

input_type instance_input;
instance_input.input[0] = x;
instance_input.input[1] = y;
return instance_input;
}

template<typename BlueprintFieldType, typename ArithmetizationParams, typename ComponentType>
void handle_component_input(
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
typename ComponentType::input_type &instance_input) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

std::vector<var> all_vars = instance_input.all_vars();
std::vector<std::reference_wrapper<var>> input(all_vars.begin(), all_vars.end());
const auto &used_rows = assignment.get_used_rows();

for (auto &v : input) {
bool found = (used_rows.find(v.get().rotation) != used_rows.end());
if (!found &&
(v.get().type == var::column_type::witness || v.get().type == var::column_type::constant)) {
const auto new_v = save_shared_var(assignment, v);
v.get().index = new_v.index;
v.get().rotation = new_v.rotation;
v.get().relative = new_v.relative;
v.get().type = new_v.type;
}
}
}

template<typename BlueprintFieldType, typename ArithmetizationParams, typename component_type, typename BinOp>
void fill_trace(
component_type &component,
typename component_type::input_type &input,
BinOp &mlir_op,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

handle_component_input<BlueprintFieldType, ArithmetizationParams, component_type>(assignment, input);

components::generate_circuit(component, bp, assignment, input, start_row);
auto result = components::generate_assignments(component, assignment, input, start_row);
frame.locals[mlir::hash_value(mlir_op.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP
32 changes: 24 additions & 8 deletions mlir-assigner/include/mlir-assigner/memory/memref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,37 @@ namespace nil {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void print(
llvm::raw_ostream &os,
std::ostream& os,
const assignment<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {
os << "memref<";
for (int i = 0; i < dims.size(); i++) {
os << dims[i];
os << "x";
}
os << type << ">[";
for (int i = 0; i < data.size(); i++) {
auto value = var_value(assignment, data[i]).data;
components::FixedPoint<BlueprintFieldType, 1, 1> out(value, 16);
os << out.to_double();
if (i != data.size() - 1)
os << ",";
std::string type_str;
llvm::raw_string_ostream ss(type_str);
ss << type << ">[";
os << type_str;
if (type.isa<mlir::IntegerType>()) {
if (type.getIntOrFloatBitWidth() == 1) {
//bool
for (int i = 0; i < data.size(); i++) {
os << var_value(assignment, data[i]).data;
if (i != data.size() - 1)
os << ",";
}
} else {
//int
}
} else if (type.isa<mlir::FloatType>()) {
for (int i = 0; i < data.size(); i++) {
auto value = var_value(assignment, data[i]).data;
components::FixedPoint<BlueprintFieldType, 1, 1> out(value, 16);
os << out.to_double();
if (i != data.size() - 1)
os << ",";
}
}
os << "]\n";
}
Expand Down
44 changes: 40 additions & 4 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP
#define CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP

#include <cassert>
#include <cstdint>
#define TEST_WITHOUT_LOOKUP_TABLES

#include "mlir-assigner/helper/asserts.hpp"
Expand Down Expand Up @@ -46,6 +48,7 @@
#include <mlir-assigner/components/fixedpoint/remainder.hpp>
#include <mlir-assigner/components/fixedpoint/subtraction.hpp>
#include <mlir-assigner/components/fixedpoint/dot_product.hpp>
#include <mlir-assigner/components/boolean/and.hpp>

#include <mlir-assigner/memory/memref.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
Expand Down Expand Up @@ -248,7 +251,35 @@ namespace zk_ml_toolchain {
} else if (arith::NegFOp operation = llvm::dyn_cast<arith::NegFOp>(op)) {
handle_fixedpoint_neg_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::AndIOp operation = llvm::dyn_cast<arith::AndIOp>(op)) {
UNREACHABLE("TODO component not finished at nils side");
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for AndIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_and(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise And Gadget");
}
} else if (arith::OrIOp operation = llvm::dyn_cast<arith::OrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
}
} else if (arith::XOrIOp operation = llvm::dyn_cast<arith::XOrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for XOrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_xor(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise XOr Gadget");
}
} else if (arith::AddIOp operation = llvm::dyn_cast<arith::AddIOp>(op)) {

// TODO: ATM, handle only the case where we work on indices that are
Expand Down Expand Up @@ -291,12 +322,17 @@ namespace zk_ml_toolchain {
} else if (arith::ConstantOp operation = llvm::dyn_cast<arith::ConstantOp>(op)) {
TypedAttr constantValue = operation.getValueAttr();
if (constantValue.isa<IntegerAttr>()) {
int64_t value = llvm::dyn_cast<IntegerAttr>(constantValue).getInt();
// this insert is ok, since this should never change, so we don't
// override it if it is already there

// TACEO_TODO: better separation of constant values that come from the
// loop bounds an normal ones, ATM just do both
int64_t value;
if (constantValue.isa<BoolAttr>()) {
value = llvm::dyn_cast<BoolAttr>(constantValue).getValue() ? 1 : 0;
} else {
value = llvm::dyn_cast<IntegerAttr>(constantValue).getInt();
}
frames.back().constant_values.insert(
std::make_pair(mlir::hash_value(operation.getResult()), value));

Expand Down Expand Up @@ -763,8 +799,8 @@ namespace zk_ml_toolchain {
auto retval = frames.back().memrefs.find(mlir::hash_value(ops[0]));
assert(retval != frames.back().memrefs.end());
if (PrintCircuitOutput) {
llvm::outs() << "Result:\n";
retval->second.print(llvm::outs(), assignmnt);
std::cout << "Result:\n";
retval->second.print(std::cout, assignmnt);
}
return;
}
Expand Down
Loading