Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: bitwise operations #27

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef CRYPTO3_ASSIGNER_AND_HPP
#define CRYPTO3_ASSIGNER_AND_HPP

#include <cstdint>
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>
Expand All @@ -51,7 +52,15 @@ namespace nil {
using component_type = components::lookup_logic_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::AndIOp);
auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

typename component_type::input_type input;
input.input[0] = lhs->second;
input.input[1] = rhs->second;

const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand All @@ -70,17 +79,8 @@ namespace nil {
std::uint32_t start_row) {
using component_type = components::logic_or_flag<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
using input_type = typename component_type::input_type;

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

input_type input;
input.x = lhs->second;
input.y = rhs->second;
const auto p = detail::PolicyManager::get_parameters(
auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp) const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
Expand All @@ -98,7 +98,14 @@ namespace nil {
using component_type = components::lookup_logic_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::XOrIOp);
auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

typename component_type::input_type input;
input.input[0] = lhs->second;
input.input[1] = rhs->second;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand All @@ -107,6 +114,72 @@ namespace nil {
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
void handle_bitwise_and(
mlir::arith::AndIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::bitwise_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::AndIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
void handle_bitwise_or(
mlir::arith::OrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::bitwise_or<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
void handle_bitwise_xor(
mlir::arith::XOrIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::bitwise_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::XOrIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

} // namespace blueprint
} // namespace nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmax.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sqrt.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_and.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_or.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_xor.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/logic_or_flag.hpp>

Expand Down Expand Up @@ -75,19 +78,17 @@ namespace nil {
&assignment) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

input_type instance_input;
instance_input.input[0] = x;
instance_input.input[1] = y;
instance_input.x = lhs->second;
instance_input.y = rhs->second;
return instance_input;
}


template<typename BlueprintFieldType, typename ArithmetizationParams, typename ComponentType>
void handle_component_input(
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
Expand Down
34 changes: 28 additions & 6 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,24 @@ namespace zk_ml_toolchain {
// std::cout << toFixpoint(Lhs) << " * " << toFixpoint(Rhs) << " = " << toFixpoint(Result) << "\n";
}

#define BITSWITCHER(func, b) \
switch (b) { \
case 8: \
func<1>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
case 16: \
func<2>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
case 32: \
func<4>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
case 64: \
func<8>(operation, frames.back(), bp, assignmnt, start_row); \
break; \
default: \
UNREACHABLE(std::string("unsupported int bit size for bitwise op: ") + std::to_string(b)); \
}

void handleArithOperation(Operation *op) {
std::uint32_t start_row = assignmnt.allocated_rows();
if (arith::AddFOp operation = llvm::dyn_cast<arith::AddFOp>(op)) {
Expand Down Expand Up @@ -332,10 +350,11 @@ namespace zk_ml_toolchain {
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for AndIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
uint8_t bits = LhsType.getIntOrFloatBitWidth();
if (1 == bits) {
handle_logic_and(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise And Gadget");
BITSWITCHER(handle_bitwise_and, bits);
}
} else if (arith::OrIOp operation = llvm::dyn_cast<arith::OrIOp>(op)) {
ASSERT(operation.getNumOperands() == 2 && "Or must have two operands");
Expand All @@ -355,21 +374,23 @@ namespace zk_ml_toolchain {
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
unsigned bits = LhsType.getIntOrFloatBitWidth();
if (1 == bits) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
BITSWITCHER(handle_bitwise_or, bits);
}
}
} else if (arith::XOrIOp operation = llvm::dyn_cast<arith::XOrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for XOrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
unsigned bits = LhsType.getIntOrFloatBitWidth();
if (1 == bits) {
handle_logic_xor(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise XOr Gadget");
BITSWITCHER(handle_bitwise_xor, bits);
}
} else if (arith::AddIOp operation = llvm::dyn_cast<arith::AddIOp>(op)) {
// TODO: ATM, handle only the case where we work on indices that are
Expand Down Expand Up @@ -524,6 +545,7 @@ namespace zk_ml_toolchain {
UNREACHABLE(std::string("unhandled arith operation: ") + opName);
}
}
#undef BITSWITCHER

void handleMathOperation(Operation *op) {
std::uint32_t start_row = assignmnt.allocated_rows();
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-23105, 19325, 20849, -30787, -28574, 6400, -17468, -25069, -3772, -5585], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [12315, -8527, -15259, -2999, 26965, -9250, -9078, 9823, -13382, -24594], "dims": [1, 10], "type": "int"}}]
20 changes: 20 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseAnd/BitwiseAndI16Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
 :

in_a
in_bout_a"
BitwiseAndBitwiseAndI16SimpleZ
in_a



Z
in_b



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi16>[8219, 18993, 16481, -31735, 64, 6400, -26496, 1555, -16128, -30162]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-17405, -6334, 17342, -2390, 11827, -4355, 97, 24390, 5832, 21905], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [18283, 12705, -22076, 23371, -15509, -21397, 857, -24556, -30933, -31965], "dims": [1, 10], "type": "int"}}]
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
 :|
 :

in_a
in_bout_a"
BitwiseAndBitwiseAndSimpleZ
BitwiseAndBitwiseAndI32SimpleZ
in_a


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi32>[1027, 8448, 388, 21002, 547, -21399, 65, 4, 1544, 257]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-31058, 24103, 20613, -14245, 30834, -9972, -16062, 24153, 6507, -2612], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [-5647, 8761, -16974, -18729, 26300, -32249, 2301, -19788, 27028, -9132], "dims": [1, 10], "type": "int"}}]
20 changes: 20 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseAnd/BitwiseAndI64Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
 :

in_a
in_bout_a"
BitwiseAndBitwiseAndI64SimpleZ
in_a



Z
in_b



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi64>[-32608, 545, 4224, -32685, 24624, -32764, 64, 4624, 2304, -11196]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-99, 68, 14, -106, 31, 114, 45, 30, 35, 50], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [51, 77, 19, -123, 102, -44, -103, -7, 84, -105], "dims": [1, 10], "type": "int"}}]
20 changes: 20 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseAnd/BitwiseAndI8Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
 :~

in_a
in_bout_a"
BitwiseAndBitwiseAndI8SimpleZ
in_a



Z
in_b



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi8>[17, 68, 2, -124, 6, 80, 9, 24, 0, 18]
22
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-23105, 19325, 20849, -30787, -28574, 6400, -17468, -25069, -3772, -5585], "dims": [1, 10], "type": "int"}}, {"memref": {"data": [12315, -8527, -15259, -2999, 26965, -9250, -9078, 9823, -13382, -24594], "dims": [1, 10], "type": "int"}}]
19 changes: 19 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseOr/BitwiseOrI16Simple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 :}

in_a
in_bout_a" BitwiseOrBitwiseOrI16SimpleZ
in_a



Z
in_b



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/BitwiseOr/BitwiseOrI16Simple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi16>[-19009, -8195, -10891, -2051, -1673, -9250, -50, -16801, -1026, -17]
22
Loading