Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: apply new formatting
Browse files Browse the repository at this point in the history
dkales committed Dec 11, 2023

Verified

This commit was signed with the committer’s verified signature.
primeos Michael Weiss
1 parent 051cdbe commit cae3c94
Showing 30 changed files with 16,970 additions and 3,465 deletions.
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@
#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/components/algebra/fixedpoint/plonk/cmp_extended.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <nil/blueprint/component.hpp>

@@ -40,131 +40,115 @@
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>
handle_f_comparison_component(
mlir::arith::CmpFPredicate p,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;
using component_type = components::fix_cmp_extended<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(
0));
component_type component_instance(
params.witness,
ManifestReader<component_type, ArithmetizationParams, 1,
1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1,
1>::get_public_inputs(),
1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x, y},
start_row);
auto cmp_result = components::generate_assignments(
component_instance, assignment, {x, y}, start_row);

switch (p) {
case mlir::arith::CmpFPredicate::UGT:
case mlir::arith::CmpFPredicate::OGT: {
return cmp_result.gt;
}
case mlir::arith::CmpFPredicate::ULT:
case mlir::arith::CmpFPredicate::OLT: {
return cmp_result.lt;
}
case mlir::arith::CmpFPredicate::UGE:
case mlir::arith::CmpFPredicate::OGE: {
return cmp_result.geq;
}
case mlir::arith::CmpFPredicate::ULE:
case mlir::arith::CmpFPredicate::OLE: {
return cmp_result.leq;
}
case mlir::arith::CmpFPredicate::UNE:
case mlir::arith::CmpFPredicate::ONE: {
return cmp_result.neq;
}
case mlir::arith::CmpFPredicate::UEQ:
case mlir::arith::CmpFPredicate::OEQ: {
return cmp_result.eq;
}
case mlir::arith::CmpFPredicate::UNO:
case mlir::arith::CmpFPredicate::ORD:
case mlir::arith::CmpFPredicate::AlwaysFalse:
case mlir::arith::CmpFPredicate::AlwaysTrue: {
UNREACHABLE("TACEO_TODO implement fcmp");
break;
}
default:
UNREACHABLE("Unsupported fcmp predicate");
break;
}
}
} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_comparison_component(
mlir::arith::CmpFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto pred = operation.getPredicate();

auto x = lhs->second;
auto y = rhs->second;

// std::stringstream ss;
// ss << var_value(assignment, x) << " cmp " << var_value(assignment, y) <<
// "\n"; llvm::outs() << ss.str();

// TACEO_TODO: check types

auto result = detail::handle_f_comparison_component(pred, x, y, bp,
assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_F_COMPARISON_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
handle_f_comparison_component(
mlir::arith::CmpFPredicate p,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;
using component_type = components::fix_cmp_extended<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0));
component_type component_instance(
params.witness, ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_public_inputs(), 1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x, y}, start_row);
auto cmp_result = components::generate_assignments(component_instance, assignment, {x, y}, start_row);

switch (p) {
case mlir::arith::CmpFPredicate::UGT:
case mlir::arith::CmpFPredicate::OGT: {
return cmp_result.gt;
}
case mlir::arith::CmpFPredicate::ULT:
case mlir::arith::CmpFPredicate::OLT: {
return cmp_result.lt;
}
case mlir::arith::CmpFPredicate::UGE:
case mlir::arith::CmpFPredicate::OGE: {
return cmp_result.geq;
}
case mlir::arith::CmpFPredicate::ULE:
case mlir::arith::CmpFPredicate::OLE: {
return cmp_result.leq;
}
case mlir::arith::CmpFPredicate::UNE:
case mlir::arith::CmpFPredicate::ONE: {
return cmp_result.neq;
}
case mlir::arith::CmpFPredicate::UEQ:
case mlir::arith::CmpFPredicate::OEQ: {
return cmp_result.eq;
}
case mlir::arith::CmpFPredicate::UNO:
case mlir::arith::CmpFPredicate::ORD:
case mlir::arith::CmpFPredicate::AlwaysFalse:
case mlir::arith::CmpFPredicate::AlwaysTrue: {
UNREACHABLE("TACEO_TODO implement fcmp");
break;
}
default:
UNREACHABLE("Unsupported fcmp predicate");
break;
}
}
} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_comparison_component(
mlir::arith::CmpFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto pred = operation.getPredicate();

auto x = lhs->second;
auto y = rhs->second;

// std::stringstream ss;
// ss << var_value(assignment, x) << " cmp " << var_value(assignment, y) <<
// "\n"; llvm::outs() << ss.str();

// TACEO_TODO: check types

auto result = detail::handle_f_comparison_component(pred, x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_F_COMPARISON_HPP
124 changes: 54 additions & 70 deletions mlir-assigner/include/mlir-assigner/components/comparison/select.hpp
Original file line number Diff line number Diff line change
@@ -39,79 +39,63 @@
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_select_component(
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &c,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;
using component_type = components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
params.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type,
ArithmetizationParams>::get_public_inputs());
template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_select_component(
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &c,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;
using component_type = components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
params.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type, ArithmetizationParams>::get_public_inputs());

components::generate_circuit(component_instance, bp, assignment, {c, x, y},
start_row);
return components::generate_assignments(component_instance, assignment,
{c, x, y}, start_row);
}
} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_select_component(
mlir::arith::SelectOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
components::generate_circuit(component_instance, bp, assignment, {c, x, y}, start_row);
return components::generate_assignments(component_instance, assignment, {c, x, y}, start_row);
}
} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_select_component(
mlir::arith::SelectOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto false_value =
frame.locals.find(mlir::hash_value(operation.getFalseValue()));
ASSERT(false_value != frame.locals.end());
auto true_value =
frame.locals.find(mlir::hash_value(operation.getTrueValue()));
ASSERT(true_value != frame.locals.end());
auto condition =
frame.locals.find(mlir::hash_value(operation.getCondition()));
ASSERT(condition != frame.locals.end());
auto false_value = frame.locals.find(mlir::hash_value(operation.getFalseValue()));
ASSERT(false_value != frame.locals.end());
auto true_value = frame.locals.find(mlir::hash_value(operation.getTrueValue()));
ASSERT(true_value != frame.locals.end());
auto condition = frame.locals.find(mlir::hash_value(operation.getCondition()));
ASSERT(condition != frame.locals.end());

auto c = condition->second;
auto x = true_value->second;
auto y = false_value->second;
auto c = condition->second;
auto x = true_value->second;
auto y = false_value->second;

// TACEO_TODO: check types
// TACEO_TODO: check types

auto result =
detail::handle_select_component(c, x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_SELECT_HPP
auto result = detail::handle_select_component(c, x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_SELECT_HPP
Original file line number Diff line number Diff line change
@@ -36,62 +36,58 @@ namespace nil {
namespace detail {

struct FlexibleParameters {
std::vector <std::uint32_t> witness;
std::vector<std::uint32_t> witness;

FlexibleParameters(std::uint32_t witness_amount) {
witness.resize(witness_amount);
std::iota(witness.begin(), witness.end(), 0); // fill 0, 1, ...
std::iota(witness.begin(), witness.end(), 0); // fill 0, 1, ...
}
};

template<typename ArithmetizationParams>
struct CompilerRestrictions {
inline static compiler_manifest common_restriction_manifest = compiler_manifest(ArithmetizationParams::witness_columns,
std::numeric_limits<std::int32_t>::max() - 1,
std::numeric_limits<std::int32_t>::max(), true);
inline static compiler_manifest common_restriction_manifest = compiler_manifest(
ArithmetizationParams::witness_columns, std::numeric_limits<std::int32_t>::max() - 1,
std::numeric_limits<std::int32_t>::max(), true);
};

template<typename ComponentType, typename ArithmetizationParams, uint8_t... manifest_args>
struct ManifestReader {
inline static typename ComponentType::manifest_type manifest =
CompilerRestrictions<ArithmetizationParams>::common_restriction_manifest.intersect(ComponentType::get_manifest(manifest_args...));
CompilerRestrictions<ArithmetizationParams>::common_restriction_manifest.intersect(
ComponentType::get_manifest(manifest_args...));

template<typename... Args>
static std::vector <std::pair<std::uint32_t, std::uint32_t>>
get_witness(Args... args) {
static std::vector<std::pair<std::uint32_t, std::uint32_t>> get_witness(Args... args) {
ASSERT(manifest.is_satisfiable());
auto witness_amount_ptr = manifest.witness_amount;
std::vector <std::pair<std::uint32_t, std::uint32_t>> values;
for (auto it = witness_amount_ptr->begin();
it != witness_amount_ptr->end(); it++) {
std::vector<std::pair<std::uint32_t, std::uint32_t>> values;
for (auto it = witness_amount_ptr->begin(); it != witness_amount_ptr->end(); it++) {
const auto witness_amount = *it;
const auto rows_amount = ComponentType::get_rows_amount(witness_amount,
args...);
const auto rows_amount = ComponentType::get_rows_amount(witness_amount, args...);
const auto total_amount_rows_power_two = std::pow(2, std::ceil(std::log2(rows_amount)));
const auto total_amount_of_gates = ComponentType::get_gate_manifest(witness_amount, args...).get_gates_amount();
values.emplace_back(witness_amount,
total_amount_rows_power_two + total_amount_of_gates);
const auto total_amount_of_gates =
ComponentType::get_gate_manifest(witness_amount, args...).get_gates_amount();
values.emplace_back(witness_amount, total_amount_rows_power_two + total_amount_of_gates);
}
ASSERT(values.size() > 0);
return values;
}

static typename ComponentType::component_type::constant_container_type
get_constants() {
static typename ComponentType::component_type::constant_container_type get_constants() {
typename ComponentType::component_type::constant_container_type constants;
std::iota(constants.begin(), constants.end(), 0); // fill 0, 1, ...
std::iota(constants.begin(), constants.end(), 0); // fill 0, 1, ...
return constants;
}

static typename ComponentType::component_type::public_input_container_type
get_public_inputs() {
static typename ComponentType::component_type::public_input_container_type get_public_inputs() {
typename ComponentType::component_type::public_input_container_type public_inputs;
std::iota(public_inputs.begin(), public_inputs.end(), 0); // fill 0, 1, ...
std::iota(public_inputs.begin(), public_inputs.end(), 0); // fill 0, 1, ...
return public_inputs;
}
};
} // namespace detail
} // namespace blueprint
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_COMPONENT_MANIFEST_UTILITIES_HPP
70 changes: 32 additions & 38 deletions mlir-assigner/include/mlir-assigner/components/fields/addition.hpp
Original file line number Diff line number Diff line change
@@ -42,47 +42,41 @@
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::addition<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_native_field_addition_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::addition<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_native_field_addition_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::addition<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type,
ArithmetizationParams>::get_public_inputs());
using component_type = components::addition<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type, ArithmetizationParams>::get_public_inputs());

components::generate_circuit(component_instance, bp, assignment, {x, y},
start_row);
return components::generate_assignments(component_instance, assignment,
{x, y}, start_row);
}
components::generate_circuit(component_instance, bp, assignment, {x, y}, start_row);
return components::generate_assignments(component_instance, assignment, {x, y}, start_row);
}

} // namespace detail
} // namespace detail

} // namespace blueprint
} // namespace nil
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIELD_ADDITION_HPP
#endif // CRYPTO3_ASSIGNER_FIELD_ADDITION_HPP
Original file line number Diff line number Diff line change
@@ -42,47 +42,41 @@
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::subtraction<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_native_field_subtraction_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::subtraction<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_native_field_subtraction_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::subtraction<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type,
ArithmetizationParams>::get_public_inputs());
using component_type = components::subtraction<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type, ArithmetizationParams>::get_public_inputs());

components::generate_circuit(component_instance, bp, assignment, {x, y},
start_row);
return components::generate_assignments(component_instance, assignment,
{x, y}, start_row);
}
components::generate_circuit(component_instance, bp, assignment, {x, y}, start_row);
return components::generate_assignments(component_instance, assignment, {x, y}, start_row);
}

} // namespace detail
} // namespace detail

} // namespace blueprint
} // namespace nil
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIELD_SUBTRACTION_HPP
#endif // CRYPTO3_ASSIGNER_FIELD_SUBTRACTION_HPP
157 changes: 73 additions & 84 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/abs.hpp
Original file line number Diff line number Diff line change
@@ -8,93 +8,82 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sign_abs.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_sign_abs<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_abs_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;

using component_type = components::fix_sign_abs<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader =
ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x},
start_row);
return components::generate_assignments(component_instance, assignment, {x},
start_row);
}

} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_abs_component(
mlir::math::AbsFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result =
detail::handle_fixedpoint_abs_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.abs;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_ABS_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_sign_abs<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_abs_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

using component_type = components::fix_sign_abs<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader = ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x}, start_row);
return components::generate_assignments(component_instance, assignment, {x}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_abs_component(
mlir::math::AbsFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_abs_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.abs;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_ABS_HPP
Original file line number Diff line number Diff line change
@@ -15,34 +15,31 @@
#include <mlir-assigner/components/fields/addition.hpp>

namespace nil {
namespace blueprint {

template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_addition_component(
mlir::arith::AddFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_native_field_addition_component(
x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_ADDITION_HPP
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_addition_component(
mlir::arith::AddFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_native_field_addition_component(x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_ADDITION_HPP
157 changes: 73 additions & 84 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/ceil.hpp
Original file line number Diff line number Diff line change
@@ -8,93 +8,82 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/ceil.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_ceil<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_ceil_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;

using component_type = components::fix_ceil<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader =
ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x},
start_row);
return components::generate_assignments(component_instance, assignment, {x},
start_row);
}

} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_ceil_component(
mlir::math::CeilOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result =
detail::handle_fixedpoint_ceil_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_CEIL_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_ceil<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_ceil_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

using component_type = components::fix_ceil<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader = ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x}, start_row);
return components::generate_assignments(component_instance, assignment, {x}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_ceil_component(
mlir::math::CeilOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_ceil_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_CEIL_HPP
170 changes: 78 additions & 92 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/division.hpp
Original file line number Diff line number Diff line change
@@ -8,101 +8,87 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/div.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_div<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_division_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::fix_div<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(
0, 1, 1));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams, 1,
1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1,
1>::get_public_inputs(),
1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x, y},
start_row);
return components::generate_assignments(component_instance, assignment,
{x, y}, start_row);
}

} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_division_component(
mlir::arith::DivFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_division_component(
x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_DIVISION_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_div<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_division_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::fix_div<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));
component_type component_instance(
p.witness, ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_public_inputs(), 1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x, y}, start_row);
return components::generate_assignments(component_instance, assignment, {x, y}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_division_component(
mlir::arith::DivFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_division_component(x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_DIVISION_HPP
Original file line number Diff line number Diff line change
@@ -9,116 +9,100 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/dot_rescale_2_gates.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>

#include <mlir-assigner/components/fields/addition.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_dot_product_component(
memref<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>>
x,
memref<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>>
y,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
&zero_var,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;

using component_type = components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader =
ManifestReader<component_type, ArithmetizationParams, 1, 1>;
auto dims = x.getDims();
ASSERT(dims.size() == 1 && "must be one-dim for dot product");
const auto p = PolicyManager::get_parameters(
manifest_reader::get_witness(0, dims.front(), 1));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(),
dims.front(), 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

using DotProductInputType =
const typename components::plonk_fixedpoint_dot_rescale_2_gates<
BlueprintFieldType, ArithmetizationParams>::input_type;

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

DotProductInputType input = {x.getData(), y.getData(), zero_var};

components::generate_circuit(component_instance, bp, assignment, input,
start_row);
return components::generate_assignments(component_instance, assignment, input,
start_row);
}

} // namespace detail

template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_dot_product_component(
mlir::zkml::DotProductOp &operation,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
&zero_var,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment) {

auto lhs = frame.memrefs.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.memrefs.end());
auto rhs = frame.memrefs.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.memrefs.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_dot_product_component(
x, y, zero_var, bp, assignment, assignment.allocated_rows());
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_DOT_PRODUCT_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_dot_product_component(
memref<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>>
x,
memref<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>>
y,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &zero_var,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

using component_type = components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = ManifestReader<component_type, ArithmetizationParams, 1, 1>;
auto dims = x.getDims();
ASSERT(dims.size() == 1 && "must be one-dim for dot product");
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0, dims.front(), 1));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), dims.front(), 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

using DotProductInputType =
const typename components::plonk_fixedpoint_dot_rescale_2_gates<BlueprintFieldType,
ArithmetizationParams>::input_type;

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

DotProductInputType input = {x.getData(), y.getData(), zero_var};

components::generate_circuit(component_instance, bp, assignment, input, start_row);
return components::generate_assignments(component_instance, assignment, input, start_row);
}

} // namespace detail

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_dot_product_component(
mlir::zkml::DotProductOp &operation,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &zero_var,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {

auto lhs = frame.memrefs.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.memrefs.end());
auto rhs = frame.memrefs.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.memrefs.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_dot_product_component(x, y, zero_var, bp, assignment,
assignment.allocated_rows());
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_DOT_PRODUCT_HPP
163 changes: 76 additions & 87 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/exp.hpp
Original file line number Diff line number Diff line change
@@ -8,96 +8,85 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/exp.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_exp<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_exp_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;

using component_type = components::fix_exp<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams,
1>::get_public_inputs(),
1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x},
start_row);
return components::generate_assignments(component_instance, assignment, {x},
start_row);
}

} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_exp_component(
mlir::math::ExpOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result =
detail::handle_fixedpoint_exp_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_EXP_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_exp<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_exp_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

using component_type = components::fix_exp<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1>::get_public_inputs(),
1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x}, start_row);
return components::generate_assignments(component_instance, assignment, {x}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_exp_component(
mlir::math::ExpOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_exp_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_EXP_HPP
157 changes: 73 additions & 84 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/floor.hpp
Original file line number Diff line number Diff line change
@@ -8,93 +8,82 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/floor.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_floor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_floor_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;

using component_type = components::fix_floor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader =
ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x},
start_row);
return components::generate_assignments(component_instance, assignment, {x},
start_row);
}

} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_floor_component(
mlir::math::FloorOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result =
detail::handle_fixedpoint_floor_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_FLOOR_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_floor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_floor_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

using component_type = components::fix_floor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader = ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x}, start_row);
return components::generate_assignments(component_instance, assignment, {x}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_floor_component(
mlir::math::FloorOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_floor_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_FLOOR_HPP
Original file line number Diff line number Diff line change
@@ -8,99 +8,89 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/mul_rescale.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_mul_rescale<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_mul_rescale_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::fix_mul_rescale<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams,
1>::get_public_inputs(),
1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x, y},
start_row);
return components::generate_assignments(component_instance, assignment,
{x, y}, start_row);
}

} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_mul_rescale_component(
mlir::arith::MulFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_mul_rescale_component(
x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_MULTIPLICATION_RESCALE_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_mul_rescale<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_mul_rescale_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::fix_mul_rescale<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1>::get_witness(0));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1>::get_public_inputs(),
1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x, y}, start_row);
return components::generate_assignments(component_instance, assignment, {x, y}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_mul_rescale_component(
mlir::arith::MulFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_mul_rescale_component(x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_MULTIPLICATION_RESCALE_HPP
51 changes: 24 additions & 27 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/neg.hpp
Original file line number Diff line number Diff line change
@@ -11,30 +11,27 @@
#include <mlir-assigner/memory/stack_frame.hpp>

namespace nil {
namespace blueprint {

template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_neg_component(
mlir::arith::NegFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result =
detail::handle_integer_neg_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_NEG_HPP
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_neg_component(
mlir::arith::NegFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result = detail::handle_integer_neg_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_NEG_HPP
Original file line number Diff line number Diff line change
@@ -8,98 +8,87 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/rem.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_rem<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_remainder_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::fix_rem<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader =
ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p =
PolicyManager::get_parameters(manifest_reader::get_witness(0, 1, 1));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x, y},
start_row);
return components::generate_assignments(component_instance, assignment,
{x, y}, start_row);
}

} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_remainder_component(
mlir::arith::RemFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_remainder_component(
x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_DIVISION_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_rem<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_remainder_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using component_type = components::fix_rem<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader = ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0, 1, 1));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x, y}, start_row);
return components::generate_assignments(component_instance, assignment, {x, y}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_remainder_component(
mlir::arith::RemFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_remainder_component(x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_DIVISION_HPP
Original file line number Diff line number Diff line change
@@ -12,34 +12,31 @@
#include <mlir-assigner/memory/stack_frame.hpp>

namespace nil {
namespace blueprint {

template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_subtraction_component(
mlir::arith::SubFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_native_field_subtraction_component(
x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_SUBTRACTION_HPP
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_subtraction_component(
mlir::arith::SubFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto x = lhs->second;
auto y = rhs->second;

// TACEO_TODO: check types

auto result = detail::handle_native_field_subtraction_component(x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_SUBTRACTION_HPP
Original file line number Diff line number Diff line change
@@ -9,113 +9,113 @@
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/dot_rescale_2_gates.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>

#include <mlir-assigner/components/fields/addition.hpp>

namespace nil {
namespace blueprint {
/*
namespace detail {
template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_dot_product_component(
memref<crypto3::zk::snark::plonk_variable<typename
BlueprintFieldType::value_type>> x,
memref<crypto3::zk::snark::plonk_variable<typename
BlueprintFieldType::value_type>> y, crypto3::zk::snark::plonk_variable<typename
BlueprintFieldType::value_type> &zero_var,
circuit<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;
using component_type = components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader =
ManifestReader<component_type, ArithmetizationParams, 1, 1>;
auto dims = x.getDims();
ASSERT(dims.size() == 1 && "must be one-dim for dot product");
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0,
dims.front(), 1)); component_type component_instance(p.witness,
manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
dims.front(), 1);
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::
lookup_table_definition<BlueprintFieldType>>(t));
}
};
if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};
using DotProductInputType = const typename
components::plonk_fixedpoint_dot_rescale_2_gates< BlueprintFieldType,
ArithmetizationParams>::input_type;
// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
DotProductInputType input = {x.getData(), y.getData(), zero_var};
components::generate_circuit(component_instance, bp, assignment, input,
start_row);
return components::generate_assignments(component_instance, assignment, input,
start_row);
}
} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_dot_product_component(
mlir::zkml::DotProductOp &operation,
crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &zero_var,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment) {
auto lhs = frame.memrefs.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.memrefs.end());
auto rhs = frame.memrefs.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.memrefs.end());
auto x = lhs->second;
auto y = rhs->second;
// TACEO_TODO: check types
auto result = detail::handle_fixedpoint_dot_product_component(
x, y, zero_var, bp, assignment, assignment.allocated_rows());
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}*/
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_ACOS_HPP
namespace blueprint {
/*
namespace detail {
template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_dot_product_component(
memref<crypto3::zk::snark::plonk_variable<typename
BlueprintFieldType::value_type>> x,
memref<crypto3::zk::snark::plonk_variable<typename
BlueprintFieldType::value_type>> y, crypto3::zk::snark::plonk_variable<typename
BlueprintFieldType::value_type> &zero_var,
circuit<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;
using component_type = components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader =
ManifestReader<component_type, ArithmetizationParams, 1, 1>;
auto dims = x.getDims();
ASSERT(dims.size() == 1 && "must be one-dim for dot product");
const auto p = PolicyManager::get_parameters(manifest_reader::get_witness(0,
dims.front(), 1)); component_type component_instance(p.witness,
manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
dims.front(), 1);
if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::
lookup_table_definition<BlueprintFieldType>>(t));
}
};
if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};
using DotProductInputType = const typename
components::plonk_fixedpoint_dot_rescale_2_gates< BlueprintFieldType,
ArithmetizationParams>::input_type;
// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here
DotProductInputType input = {x.getData(), y.getData(), zero_var};
components::generate_circuit(component_instance, bp, assignment, input,
start_row);
return components::generate_assignments(component_instance, assignment, input,
start_row);
}
} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_dot_product_component(
mlir::zkml::DotProductOp &operation,
crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &zero_var,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment) {
auto lhs = frame.memrefs.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.memrefs.end());
auto rhs = frame.memrefs.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.memrefs.end());
auto x = lhs->second;
auto y = rhs->second;
// TACEO_TODO: check types
auto result = detail::handle_fixedpoint_dot_product_component(
x, y, zero_var, bp, assignment, assignment.allocated_rows());
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}*/
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_ACOS_HPP
Loading

0 comments on commit cae3c94

Please sign in to comment.