Skip to content

Commit

Permalink
feat: bitwise operations (#27)
Browse files Browse the repository at this point in the history
* added functionality for bitwise ops in evaluator
* added tests for bitwise
* pinned blueprint to correct commit
  • Loading branch information
0xThemis authored Jan 11, 2024
1 parent 75f4b45 commit 537e221
Show file tree
Hide file tree
Showing 46 changed files with 352 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef CRYPTO3_ASSIGNER_AND_HPP
#define CRYPTO3_ASSIGNER_AND_HPP

#include <cstdint>
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>
Expand All @@ -51,7 +52,15 @@ namespace nil {
using component_type = components::lookup_logic_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::AndIOp);
auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

typename component_type::input_type input;
input.input[0] = lhs->second;
input.input[1] = rhs->second;

const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand All @@ -70,17 +79,8 @@ namespace nil {
std::uint32_t start_row) {
using component_type = components::logic_or_flag<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
using input_type = typename component_type::input_type;

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

input_type input;
input.x = lhs->second;
input.y = rhs->second;
const auto p = detail::PolicyManager::get_parameters(
auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp) const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
Expand All @@ -98,7 +98,14 @@ namespace nil {
using component_type = components::lookup_logic_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::XOrIOp);
auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

typename component_type::input_type input;
input.input[0] = lhs->second;
input.input[1] = rhs->second;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand All @@ -107,6 +114,72 @@ namespace nil {
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
void handle_bitwise_and(
mlir::arith::AndIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::bitwise_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::AndIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
void handle_bitwise_or(
mlir::arith::OrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::bitwise_or<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
void handle_bitwise_xor(
mlir::arith::XOrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::bitwise_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::XOrIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

} // namespace blueprint
} // namespace nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmax.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sqrt.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_and.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_or.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_xor.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/logic_or_flag.hpp>

Expand Down Expand Up @@ -75,19 +78,17 @@ namespace nil {
&assignment) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

input_type instance_input;
instance_input.input[0] = x;
instance_input.input[1] = y;
instance_input.x = lhs->second;
instance_input.y = rhs->second;
return instance_input;
}


template<typename BlueprintFieldType, typename ArithmetizationParams, typename ComponentType>
void handle_component_input(
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
Expand Down
34 changes: 28 additions & 6 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,24 @@ namespace zk_ml_toolchain {
// std::cout << toFixpoint(Lhs) << " * " << toFixpoint(Rhs) << " = " << toFixpoint(Result) << "\n";
}

#define BITSWITCHER(func, b) \
switch (b) { \
case 8: \
func<1>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
case 16: \
func<2>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
case 32: \
func<4>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
case 64: \
func<8>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
default: \
UNREACHABLE(std::string("unsupported int bit size for bitwise op: ") + std::to_string(b)); \
}

void handleArithOperation(Operation *op) {
std::uint32_t start_row = assignmnt.allocated_rows();
if (arith::AddFOp operation = llvm::dyn_cast<arith::AddFOp>(op)) {
Expand Down Expand Up @@ -332,10 +350,11 @@ namespace zk_ml_toolchain {
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for AndIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
uint8_t bits = LhsType.getIntOrFloatBitWidth();
if (1 == bits) {
handle_logic_and(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise And Gadget");
BITSWITCHER(handle_bitwise_and, bits);
}
} else if (arith::OrIOp operation = llvm::dyn_cast<arith::OrIOp>(op)) {
ASSERT(operation.getNumOperands() == 2 && "Or must have two operands");
Expand All @@ -355,21 +374,23 @@ namespace zk_ml_toolchain {
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
unsigned bits = LhsType.getIntOrFloatBitWidth();
if (1 == bits) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
BITSWITCHER(handle_bitwise_or, bits);
}
}
} else if (arith::XOrIOp operation = llvm::dyn_cast<arith::XOrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for XOrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
unsigned bits = LhsType.getIntOrFloatBitWidth();
if (1 == bits) {
handle_logic_xor(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise XOr Gadget");
BITSWITCHER(handle_bitwise_xor, bits);
}
} else if (arith::AddIOp operation = llvm::dyn_cast<arith::AddIOp>(op)) {
// TODO: ATM, handle only the case where we work on indices that are
Expand Down Expand Up @@ -524,6 +545,7 @@ namespace zk_ml_toolchain {
UNREACHABLE(std::string("unhandled arith operation: ") + opName);
}
}
#undef BITSWITCHER

void handleMathOperation(Operation *op) {
std::uint32_t start_row = assignmnt.allocated_rows();
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-23105, 19325, 20849, -30787, -28574, 6400, -17468, -25069, -3772, -5585], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [12315, -8527, -15259, -2999, 26965, -9250, -9078, 9823, -13382, -24594], "dims": [1, 10], "type": "int"}}]
20 changes: 20 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseAnd/BitwiseAndI16Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
 :

in_a
in_bout_a"
BitwiseAndBitwiseAndI16SimpleZ
in_a



Z
in_b



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi16>[8219, 18993, 16481, -31735, 64, 6400, -26496, 1555, -16128, -30162]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-17405, -6334, 17342, -2390, 11827, -4355, 97, 24390, 5832, 21905], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [18283, 12705, -22076, 23371, -15509, -21397, 857, -24556, -30933, -31965], "dims": [1, 10], "type": "int"}}]
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
 :|
 :

in_a
in_bout_a"
BitwiseAndBitwiseAndSimpleZ
BitwiseAndBitwiseAndI32SimpleZ
in_a


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi32>[1027, 8448, 388, 21002, 547, -21399, 65, 4, 1544, 257]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-31058, 24103, 20613, -14245, 30834, -9972, -16062, 24153, 6507, -2612], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [-5647, 8761, -16974, -18729, 26300, -32249, 2301, -19788, 27028, -9132], "dims": [1, 10], "type": "int"}}]
20 changes: 20 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseAnd/BitwiseAndI64Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
 :

in_a
in_bout_a"
BitwiseAndBitwiseAndI64SimpleZ
in_a



Z
in_b



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi64>[-32608, 545, 4224, -32685, 24624, -32764, 64, 4624, 2304, -11196]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-99, 68, 14, -106, 31, 114, 45, 30, 35, 50], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [51, 77, 19, -123, 102, -44, -103, -7, 84, -105], "dims": [1, 10], "type": "int"}}]
20 changes: 20 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseAnd/BitwiseAndI8Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
 :~

in_a
in_bout_a"
BitwiseAndBitwiseAndI8SimpleZ
in_a



Z
in_b



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi8>[17, 68, 2, -124, 6, 80, 9, 24, 0, 18]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-23105, 19325, 20849, -30787, -28574, 6400, -17468, -25069, -3772, -5585], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [12315, -8527, -15259, -2999, 26965, -9250, -9078, 9823, -13382, -24594], "dims": [1, 10], "type": "int"}}]
19 changes: 19 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseOr/BitwiseOrI16Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 :}

in_a
in_bout_a" BitwiseOrBitwiseOrI16SimpleZ
in_a



Z
in_b



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseOr/BitwiseOrI16Simple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi16>[-19009, -8195, -10891, -2051, -1673, -9250, -50, -16801, -1026, -17]
22
Loading

0 comments on commit 537e221

Please sign in to comment.