Skip to content

Commit

Permalink
feat: split assignment and circuit gen (#48)
Browse files Browse the repository at this point in the history
* build(deps):

* build(deps): update crypto3

* wip: introduce generation flag for splitting assignment and circuit gen

* chore: updated .gitignore

* chore: added command in justfile for deleting old output.json

* wip: add,sub working with only circuit

* wip: all now use comp_params

* wip: load store now use zero on gen mdoe circuit

* feat: handle input_reader for split assignment/circuit gen

* pined blueprint

* fix: input is 0-dimensionsl for ExpandSimple

* fix: re-introduce larger float types as inputs

* pined blueprint

* added circuit diff

* chore: add some asserts for unreachable paths

* fix: respect component sizes when not assigning

---------

Co-authored-by: Franco Nieddu <[email protected]>
  • Loading branch information
dkales and 0xThemis authored Feb 15, 2024
1 parent 13d97ef commit a13bb8d
Show file tree
Hide file tree
Showing 38 changed files with 768 additions and 316 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mlir-assigner/tests/Ops/Current/
**output.json
## Circifier, Assigner and Transpiler generated files:
/*.crct
Expand Down
5 changes: 4 additions & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,7 @@ buildbpfptester:
make -C build/ -j 12 blueprint_algebra_fixedpoint_plonk_tester_test

build-warn:
make -C build/ -j 12 zkml-onnx-compiler mlir-assigner
make -C build/ -j 12 zkml-onnx-compiler mlir-assigner

delete-output:
find . -type f -iname \*.output.json -delete
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::logic_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
typename component_type::input_type input;
Expand All @@ -62,7 +62,7 @@ namespace nil {

using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, stack, bp, assignment, start_row);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -72,7 +72,7 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::logic_or<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

Expand All @@ -84,7 +84,7 @@ namespace nil {
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, stack, bp, assignment, start_row);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -94,7 +94,7 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::logic_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
typename component_type::input_type input;
Expand All @@ -105,7 +105,7 @@ namespace nil {

using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, stack, bp, assignment, start_row);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -115,19 +115,18 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::bitwise_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::AndIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, stack, bp, assignment, start_row);
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
m);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -137,19 +136,18 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::bitwise_or<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, stack, bp, assignment, start_row);
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
m);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}

template<uint8_t m, typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -159,19 +157,18 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::bitwise_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, m>;

auto input = PREPARE_BINARY_INPUT(mlir::arith::XOrIOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, m));

component_type component(
p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), m);
fill_trace(component, input, operation, stack, bp, assignment, start_row);
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
m);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}

} // namespace blueprint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace nil {
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &nextIndex,
std::uint32_t start_row) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::fix_argmin<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
Expand All @@ -44,7 +45,7 @@ namespace nil {
PreLimbs, PostLimbs, var_value(assignment, nextIndex),
operation.getSelectLastIndex());

auto result = fill_trace_get_result(component, input, operation, stack, bp, assignment, start_row);
auto result = fill_trace_get_result(component, input, operation, stack, bp, assignment, compParams);
stack.push_local(operation.getResult(0), result.min);
stack.push_local(operation.getResult(1), result.index);
}
Expand All @@ -58,7 +59,8 @@ namespace nil {
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &nextIndex,
std::uint32_t start_row) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::fix_argmax<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
Expand All @@ -74,7 +76,7 @@ namespace nil {
PreLimbs, PostLimbs, var_value(assignment, nextIndex),
operation.getSelectLastIndex());

auto result = fill_trace_get_result(component, input, operation, stack, bp, assignment, start_row);
auto result = fill_trace_get_result(component, input, operation, stack, bp, assignment, compParams);
stack.push_local(operation.getResult(0), result.max);
stack.push_local(operation.getResult(1), result.index);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ namespace nil {
&bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = FCmpComponent;
using manifest_reader =
detail::ManifestReader<FCmpComponent, ArithmetizationParams, PreLimbs, PostLimbs>;
Expand All @@ -62,7 +63,7 @@ namespace nil {

FCmpComponent component(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), PreLimbs, PostLimbs);
return fill_trace_get_result(component, input, operation, stack, bp, assignment, start_row);
return fill_trace_get_result(component, input, operation, stack, bp, assignment, compParams);
}
} // namespace

Expand All @@ -74,13 +75,15 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::fix_cmp_extended<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
// we compare 64 bits with this configuration
auto result =
call_component<PreLimbs, PostLimbs, component_type>(operation, stack, bp, assignment, start_row);
call_component<PreLimbs, PostLimbs, component_type>(operation, stack, bp, assignment, compParams);
// TODO should we store zero instead???
switch (operation.getPredicate()) {
case mlir::arith::CmpFPredicate::UGT:
case mlir::arith::CmpFPredicate::OGT: {
Expand Down Expand Up @@ -132,12 +135,13 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::fix_cmp_extended<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
// we compare 64 bits with this configuration
auto result = call_component<2, 2, component_type>(operation, stack, bp, assignment, start_row);
auto result = call_component<2, 2, component_type>(operation, stack, bp, assignment, compParams);
switch (operation.getPredicate()) {
case mlir::arith::CmpIPredicate::sgt:
stack.push_local(operation.getResult(), result.gt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,13 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {

const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
auto c = stack.get_local(operation.getCondition());
auto x = stack.get_local(operation.getTrueValue());
auto y = stack.get_local(operation.getFalseValue());
using component_type = components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
typename component_type::input_type input;
Expand All @@ -65,7 +63,7 @@ namespace nil {
input.y = stack.get_local(operation.getFalseValue());
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, stack, bp, assignment, start_row);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}
} // namespace blueprint
} // namespace nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::addition<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
auto input = PREPARE_BINARY_INPUT(MlirOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, stack, bp, assignment, start_row);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}

template<typename MlirOp, typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -39,15 +41,17 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::subtraction<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
auto input = PREPARE_BINARY_INPUT(MlirOp);
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0));
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, stack, bp, assignment, start_row);
fill_trace(component, input, operation, stack, bp, assignment, compParams);
}
} // namespace blueprint
} // namespace nil
Expand Down
Loading

0 comments on commit a13bb8d

Please sign in to comment.