Skip to content

Commit

Permalink
feat: bitwise operations (#14)
Browse files Browse the repository at this point in the history
* fixed bool constant loading
* added tests for bitwise operations
  • Loading branch information
0xThemis authored Jan 2, 2024
1 parent 7b9f76a commit 0a5df0d
Show file tree
Hide file tree
Showing 33 changed files with 501 additions and 25 deletions.
105 changes: 105 additions & 0 deletions mlir-assigner/include/mlir-assigner/components/boolean/and.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//---------------------------------------------------------------------------//
// Copyright (c) 2023 Nikita Kaskov <[email protected]>
//
// MIT License
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//---------------------------------------------------------------------------//

#ifndef CRYPTO3_ASSIGNER_AND_HPP
#define CRYPTO3_ASSIGNER_AND_HPP

#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_and(
mlir::arith::AndIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::lookup_logic_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::AndIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_or(
mlir::arith::OrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
//FIXME logic_or is commented out. As soon as it is enabled, remove add the liens above and it SHOULD work
UNREACHABLE("LogicOR not enabled in blueprint");
// using component_type = components::lookup_logic_or<
// crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
//
// auto input = PREPARE_INPUT(mlir::arith::OrIOp);
// const auto p = detail::PolicyManager::get_parameters(
// detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
//
// component_type component(p.witness);
// fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_xor(
mlir::arith::XOrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::lookup_logic_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::XOrIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_LOGIC_OPS_HPP
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>

#include <mlir-assigner/components/fields/addition.hpp>

namespace nil {
namespace blueprint {
namespace detail {
Expand Down
117 changes: 117 additions & 0 deletions mlir-assigner/include/mlir-assigner/components/handle_component.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@

//---------------------------------------------------------------------------//
// Copyright (c) 2023 Nikita Kaskov <[email protected]>
//
// MIT License
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//---------------------------------------------------------------------------//

#ifndef CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP
#define CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP

#define PREPARE_INPUT(OP) \
prepare_binary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
typename component_type::input_type>(operation, frame, bp, assignment);

#include <mlir-assigner/memory/stack_frame.hpp>
namespace nil {
namespace blueprint {
template<typename BlueprintFieldType, typename ArithmetizationParams, typename BinOp, typename input_type>
input_type prepare_binary_operation_input(
BinOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

input_type instance_input;
instance_input.input[0] = x;
instance_input.input[1] = y;
return instance_input;
}

template<typename BlueprintFieldType, typename ArithmetizationParams, typename ComponentType>
void handle_component_input(
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
typename ComponentType::input_type &instance_input) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

std::vector<var> all_vars = instance_input.all_vars();
std::vector<std::reference_wrapper<var>> input(all_vars.begin(), all_vars.end());
const auto &used_rows = assignment.get_used_rows();

for (auto &v : input) {
bool found = (used_rows.find(v.get().rotation) != used_rows.end());
if (!found &&
(v.get().type == var::column_type::witness || v.get().type == var::column_type::constant)) {
const auto new_v = save_shared_var(assignment, v);
v.get().index = new_v.index;
v.get().rotation = new_v.rotation;
v.get().relative = new_v.relative;
v.get().type = new_v.type;
}
}
}

template<typename BlueprintFieldType, typename ArithmetizationParams, typename component_type, typename BinOp>
void fill_trace(
component_type &component,
typename component_type::input_type &input,
BinOp &mlir_op,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

handle_component_input<BlueprintFieldType, ArithmetizationParams, component_type>(assignment, input);

components::generate_circuit(component, bp, assignment, input, start_row);
auto result = components::generate_assignments(component, assignment, input, start_row);
frame.locals[mlir::hash_value(mlir_op.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP
32 changes: 24 additions & 8 deletions mlir-assigner/include/mlir-assigner/memory/memref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,37 @@ namespace nil {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void print(
llvm::raw_ostream &os,
std::ostream& os,
const assignment<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {
os << "memref<";
for (int i = 0; i < dims.size(); i++) {
os << dims[i];
os << "x";
}
os << type << ">[";
for (int i = 0; i < data.size(); i++) {
auto value = var_value(assignment, data[i]).data;
components::FixedPoint<BlueprintFieldType, 1, 1> out(value, 16);
os << out.to_double();
if (i != data.size() - 1)
os << ",";
std::string type_str;
llvm::raw_string_ostream ss(type_str);
ss << type << ">[";
os << type_str;
if (type.isa<mlir::IntegerType>()) {
if (type.getIntOrFloatBitWidth() == 1) {
//bool
for (int i = 0; i < data.size(); i++) {
os << var_value(assignment, data[i]).data;
if (i != data.size() - 1)
os << ",";
}
} else {
//int
}
} else if (type.isa<mlir::FloatType>()) {
for (int i = 0; i < data.size(); i++) {
auto value = var_value(assignment, data[i]).data;
components::FixedPoint<BlueprintFieldType, 1, 1> out(value, 16);
os << out.to_double();
if (i != data.size() - 1)
os << ",";
}
}
os << "]\n";
}
Expand Down
44 changes: 40 additions & 4 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP
#define CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP

#include <cassert>
#include <cstdint>
#define TEST_WITHOUT_LOOKUP_TABLES

#include "mlir-assigner/helper/asserts.hpp"
Expand Down Expand Up @@ -46,6 +48,7 @@
#include <mlir-assigner/components/fixedpoint/remainder.hpp>
#include <mlir-assigner/components/fixedpoint/subtraction.hpp>
#include <mlir-assigner/components/fixedpoint/dot_product.hpp>
#include <mlir-assigner/components/boolean/and.hpp>

#include <mlir-assigner/memory/memref.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
Expand Down Expand Up @@ -248,7 +251,35 @@ namespace zk_ml_toolchain {
} else if (arith::NegFOp operation = llvm::dyn_cast<arith::NegFOp>(op)) {
handle_fixedpoint_neg_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::AndIOp operation = llvm::dyn_cast<arith::AndIOp>(op)) {
UNREACHABLE("TODO component not finished at nils side");
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for AndIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_and(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise And Gadget");
}
} else if (arith::OrIOp operation = llvm::dyn_cast<arith::OrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
}
} else if (arith::XOrIOp operation = llvm::dyn_cast<arith::XOrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for XOrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_xor(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise XOr Gadget");
}
} else if (arith::AddIOp operation = llvm::dyn_cast<arith::AddIOp>(op)) {

// TODO: ATM, handle only the case where we work on indices that are
Expand Down Expand Up @@ -291,12 +322,17 @@ namespace zk_ml_toolchain {
} else if (arith::ConstantOp operation = llvm::dyn_cast<arith::ConstantOp>(op)) {
TypedAttr constantValue = operation.getValueAttr();
if (constantValue.isa<IntegerAttr>()) {
int64_t value = llvm::dyn_cast<IntegerAttr>(constantValue).getInt();
// this insert is ok, since this should never change, so we don't
// override it if it is already there

// TACEO_TODO: better separation of constant values that come from the
// loop bounds an normal ones, ATM just do both
int64_t value;
if (constantValue.isa<BoolAttr>()) {
value = llvm::dyn_cast<BoolAttr>(constantValue).getValue() ? 1 : 0;
} else {
value = llvm::dyn_cast<IntegerAttr>(constantValue).getInt();
}
frames.back().constant_values.insert(
std::make_pair(mlir::hash_value(operation.getResult()), value));

Expand Down Expand Up @@ -763,8 +799,8 @@ namespace zk_ml_toolchain {
auto retval = frames.back().memrefs.find(mlir::hash_value(ops[0]));
assert(retval != frames.back().memrefs.end());
if (PrintCircuitOutput) {
llvm::outs() << "Result:\n";
retval->second.print(llvm::outs(), assignmnt);
std::cout << "Result:\n";
retval->second.print(std::cout, assignmnt);
}
return;
}
Expand Down
Loading

0 comments on commit 0a5df0d

Please sign in to comment.