Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: apply new formatting
Browse files Browse the repository at this point in the history
dkales committed Dec 11, 2023
1 parent 051cdbe commit cae3c94
Showing 30 changed files with 16,970 additions and 3,465 deletions.
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@
#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/components/algebra/fixedpoint/plonk/cmp_extended.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <nil/blueprint/component.hpp>

@@ -40,131 +40,115 @@
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>
handle_f_comparison_component(
mlir::arith::CmpFPredicate p,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;
using component_type = components::fix_cmp_extended<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(
0));
component_type component_instance(
params.witness,
ManifestReader<component_type, ArithmetizationParams, 1,
1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1,
1>::get_public_inputs(),
1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<
BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x, y},
start_row);
auto cmp_result = components::generate_assignments(
component_instance, assignment, {x, y}, start_row);

switch (p) {
case mlir::arith::CmpFPredicate::UGT:
case mlir::arith::CmpFPredicate::OGT: {
return cmp_result.gt;
}
case mlir::arith::CmpFPredicate::ULT:
case mlir::arith::CmpFPredicate::OLT: {
return cmp_result.lt;
}
case mlir::arith::CmpFPredicate::UGE:
case mlir::arith::CmpFPredicate::OGE: {
return cmp_result.geq;
}
case mlir::arith::CmpFPredicate::ULE:
case mlir::arith::CmpFPredicate::OLE: {
return cmp_result.leq;
}
case mlir::arith::CmpFPredicate::UNE:
case mlir::arith::CmpFPredicate::ONE: {
return cmp_result.neq;
}
case mlir::arith::CmpFPredicate::UEQ:
case mlir::arith::CmpFPredicate::OEQ: {
return cmp_result.eq;
}
case mlir::arith::CmpFPredicate::UNO:
case mlir::arith::CmpFPredicate::ORD:
case mlir::arith::CmpFPredicate::AlwaysFalse:
case mlir::arith::CmpFPredicate::AlwaysTrue: {
UNREACHABLE("TACEO_TODO implement fcmp");
break;
}
default:
UNREACHABLE("Unsupported fcmp predicate");
break;
}
}
} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_comparison_component(
mlir::arith::CmpFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto pred = operation.getPredicate();

auto x = lhs->second;
auto y = rhs->second;

// std::stringstream ss;
// ss << var_value(assignment, x) << " cmp " << var_value(assignment, y) <<
// "\n"; llvm::outs() << ss.str();

// TACEO_TODO: check types

auto result = detail::handle_f_comparison_component(pred, x, y, bp,
assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_F_COMPARISON_HPP
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
handle_f_comparison_component(
mlir::arith::CmpFPredicate p,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;
using component_type = components::fix_cmp_extended<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0));
component_type component_instance(
params.witness, ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_public_inputs(), 1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

components::generate_circuit(component_instance, bp, assignment, {x, y}, start_row);
auto cmp_result = components::generate_assignments(component_instance, assignment, {x, y}, start_row);

switch (p) {
case mlir::arith::CmpFPredicate::UGT:
case mlir::arith::CmpFPredicate::OGT: {
return cmp_result.gt;
}
case mlir::arith::CmpFPredicate::ULT:
case mlir::arith::CmpFPredicate::OLT: {
return cmp_result.lt;
}
case mlir::arith::CmpFPredicate::UGE:
case mlir::arith::CmpFPredicate::OGE: {
return cmp_result.geq;
}
case mlir::arith::CmpFPredicate::ULE:
case mlir::arith::CmpFPredicate::OLE: {
return cmp_result.leq;
}
case mlir::arith::CmpFPredicate::UNE:
case mlir::arith::CmpFPredicate::ONE: {
return cmp_result.neq;
}
case mlir::arith::CmpFPredicate::UEQ:
case mlir::arith::CmpFPredicate::OEQ: {
return cmp_result.eq;
}
case mlir::arith::CmpFPredicate::UNO:
case mlir::arith::CmpFPredicate::ORD:
case mlir::arith::CmpFPredicate::AlwaysFalse:
case mlir::arith::CmpFPredicate::AlwaysTrue: {
UNREACHABLE("TACEO_TODO implement fcmp");
break;
}
default:
UNREACHABLE("Unsupported fcmp predicate");
break;
}
}
} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_comparison_component(
mlir::arith::CmpFOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.locals.end());
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(rhs != frame.locals.end());

auto pred = operation.getPredicate();

auto x = lhs->second;
auto y = rhs->second;

// std::stringstream ss;
// ss << var_value(assignment, x) << " cmp " << var_value(assignment, y) <<
// "\n"; llvm::outs() << ss.str();

// TACEO_TODO: check types

auto result = detail::handle_f_comparison_component(pred, x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_F_COMPARISON_HPP
124 changes: 54 additions & 70 deletions mlir-assigner/include/mlir-assigner/components/comparison/select.hpp
Original file line number Diff line number Diff line change
@@ -39,79 +39,63 @@
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {
namespace blueprint {
namespace detail {

template <typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>::result_type
handle_select_component(
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &c,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type> &y,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>;
using component_type = components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
params.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type,
ArithmetizationParams>::get_public_inputs());
template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_select_component(
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &c,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &x,
const typename crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &y,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
using non_native_policy_type = basic_non_native_policy<BlueprintFieldType>;
using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;
using component_type = components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto params = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
component_type component_instance(
params.witness,
ManifestReader<component_type, ArithmetizationParams>::get_constants(),
ManifestReader<component_type, ArithmetizationParams>::get_public_inputs());

components::generate_circuit(component_instance, bp, assignment, {c, x, y},
start_row);
return components::generate_assignments(component_instance, assignment,
{c, x, y}, start_row);
}
} // namespace detail
template <typename BlueprintFieldType, typename ArithmetizationParams>
void handle_select_component(
mlir::arith::SelectOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<
typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<
BlueprintFieldType, ArithmetizationParams>> &assignment,
std::uint32_t start_row) {
components::generate_circuit(component_instance, bp, assignment, {c, x, y}, start_row);
return components::generate_assignments(component_instance, assignment, {c, x, y}, start_row);
}
} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_select_component(
mlir::arith::SelectOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

auto false_value =
frame.locals.find(mlir::hash_value(operation.getFalseValue()));
ASSERT(false_value != frame.locals.end());
auto true_value =
frame.locals.find(mlir::hash_value(operation.getTrueValue()));
ASSERT(true_value != frame.locals.end());
auto condition =
frame.locals.find(mlir::hash_value(operation.getCondition()));
ASSERT(condition != frame.locals.end());
auto false_value = frame.locals.find(mlir::hash_value(operation.getFalseValue()));
ASSERT(false_value != frame.locals.end());
auto true_value = frame.locals.find(mlir::hash_value(operation.getTrueValue()));
ASSERT(true_value != frame.locals.end());
auto condition = frame.locals.find(mlir::hash_value(operation.getCondition()));
ASSERT(condition != frame.locals.end());

auto c = condition->second;
auto x = true_value->second;
auto y = false_value->second;
auto c = condition->second;
auto x = true_value->second;
auto y = false_value->second;

// TACEO_TODO: check types
// TACEO_TODO: check types

auto result =
detail::handle_select_component(c, x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_SELECT_HPP
auto result = detail::handle_select_component(c, x, y, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil
#endif // CRYPTO3_ASSIGNER_SELECT_HPP
Loading

0 comments on commit cae3c94

Please sign in to comment.