From 9f9ded2b99980b3b40fce9b55e72c91df1dc3d72 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Mon, 8 Jul 2024 19:48:58 +0200 Subject: [PATCH] feat!: constant inputs for blackbox (#7222) This PR allows to use constant values for blackbox inputs. Only MultiScalarMul is currently handling constant input, so it will fail if constant inputs are used for any other blackboxes. Noir does ensure that other blackboxes functions do not use constant inputs in this PR. I will make a follow-up PR once this one is merged to have more blackbox functions using constant inputs. --------- Co-authored-by: TomAFrench Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> --- .../acir_format/acir_to_constraint_buf.cpp | 198 +++++++++++++----- .../dsl/acir_format/ec_operations.test.cpp | 52 ++++- .../dsl/acir_format/multi_scalar_mul.cpp | 58 ++++- .../dsl/acir_format/multi_scalar_mul.hpp | 14 +- .../dsl/acir_format/serde/acir.hpp | 178 +++++++++++++++- .../noir-repo/acvm-repo/acir/codegen/acir.cpp | 151 ++++++++++++- .../acvm-repo/acir/src/circuit/mod.rs | 29 ++- .../acvm-repo/acir/src/circuit/opcodes.rs | 4 +- .../opcodes/black_box_function_call.rs | 180 +++++++++------- noir/noir-repo/acvm-repo/acir/src/lib.rs | 6 +- .../acir/tests/test_program_serialization.rs | 69 +++--- .../compiler/optimizers/redundant_range.rs | 28 +-- .../acvm-repo/acvm/src/pwg/blackbox/aes128.rs | 6 +- .../acvm-repo/acvm/src/pwg/blackbox/bigint.rs | 6 +- .../src/pwg/blackbox/embedded_curve_ops.rs | 42 ++-- .../acvm-repo/acvm/src/pwg/blackbox/hash.rs | 32 ++- .../acvm-repo/acvm/src/pwg/blackbox/logic.rs | 34 +-- .../acvm-repo/acvm/src/pwg/blackbox/mod.rs | 30 +-- .../acvm/src/pwg/blackbox/pedersen.rs | 14 +- .../acvm-repo/acvm/src/pwg/blackbox/range.rs | 8 +- .../acvm/src/pwg/blackbox/signature/ecdsa.rs | 16 +- .../src/pwg/blackbox/signature/schnorr.rs | 14 +- .../acvm-repo/acvm/src/pwg/blackbox/utils.rs | 10 +- noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs | 17 +- .../acvm_js/test/shared/multi_scalar_mul.ts | 6 +- .../acvm_js/test/shared/schnorr_verify.ts | 26 +-- .../src/ssa/acir_gen/acir_ir/acir_variable.rs | 28 ++- .../ssa/acir_gen/acir_ir/generated_acir.rs | 4 +- .../tooling/fuzzer/src/dictionary/mod.rs | 12 +- .../tooling/profiler/src/opcode_formatter.rs | 2 +- 30 files changed, 924 insertions(+), 350 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp index e949fa13305..7e887afea2f 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp @@ -189,6 +189,69 @@ void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, s } } +uint32_t get_witness_from_function_input(Program::FunctionInput input) +{ + auto input_witness = std::get(input.input.value); + return input_witness.value.value; +} + +WitnessConstant parse_input(Program::FunctionInput input) +{ + WitnessConstant result = std::visit( + [&](auto&& e) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return WitnessConstant{ + .index = e.value.value, + .value = bb::fr::zero(), + .is_constant = false, + }; + } else if constexpr (std::is_same_v) { + return WitnessConstant{ + .index = 0, + .value = uint256_t(e.value), + .is_constant = true, + }; + } else { + ASSERT(false); + } + return WitnessConstant{ + .index = 0, + .value = bb::fr::zero(), + .is_constant = true, + }; + }, + input.input.value); + return result; + + // WitnessConstant result = std::visit( + // [&](auto&& e) { + // using T = std::decay_t; + // if constexpr (std::is_same_v) { + // return WitnessConstant{ + // .index = e.value.witness.value, + // .value = bb::fr::zero(), + // .is_constant = false, + // }; + // } else if constexpr (std::is_same_v) { + // return WitnessConstant{ + // .index = 0, + // .value = uint256_t(e.value.constant), + // .is_constant = true, + // }; + // } else { + // ASSERT(false); + // } + // return WitnessConstant{ + // .index = 0, + // .value = bb::fr::zero(), + // .is_constant = true, + // }; + // }, + // input.value); + // return result; +} + void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, AcirFormat& af, bool honk_recursion, @@ -198,26 +261,31 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, [&](auto&& arg) { using T = std::decay_t; if constexpr (std::is_same_v) { + auto lhs_input = get_witness_from_function_input(arg.lhs); + auto rhs_input = get_witness_from_function_input(arg.rhs); af.logic_constraints.push_back(LogicConstraint{ - .a = arg.lhs.witness.value, - .b = arg.rhs.witness.value, + .a = lhs_input, + .b = rhs_input, .result = arg.output.value, .num_bits = arg.lhs.num_bits, .is_xor_gate = false, }); af.original_opcode_indices.logic_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { + auto lhs_input = get_witness_from_function_input(arg.lhs); + auto rhs_input = get_witness_from_function_input(arg.rhs); af.logic_constraints.push_back(LogicConstraint{ - .a = arg.lhs.witness.value, - .b = arg.rhs.witness.value, + .a = lhs_input, + .b = rhs_input, .result = arg.output.value, .num_bits = arg.lhs.num_bits, .is_xor_gate = true, }); af.original_opcode_indices.logic_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { + auto witness_input = get_witness_from_function_input(arg.input); af.range_constraints.push_back(RangeConstraint{ - .witness = arg.input.witness.value, + .witness = witness_input, .num_bits = arg.input.num_bits, }); af.original_opcode_indices.range_constraints.push_back(opcode_index); @@ -227,21 +295,23 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, .inputs = map(arg.inputs, [](auto& e) { return AES128Input{ - .witness = e.witness.value, + .witness = get_witness_from_function_input(e), .num_bits = e.num_bits, }; }), .iv = map(arg.iv, [](auto& e) { + auto witness = get_witness_from_function_input(e); return AES128Input{ - .witness = e.witness.value, + .witness = witness, .num_bits = e.num_bits, }; }), .key = map(arg.key, [](auto& e) { + auto input_witness = get_witness_from_function_input(e); return AES128Input{ - .witness = e.witness.value, + .witness = input_witness, .num_bits = e.num_bits, }; }), @@ -253,8 +323,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.sha256_constraints.push_back(Sha256Constraint{ .inputs = map(arg.inputs, [](auto& e) { + auto input_witness = get_witness_from_function_input(e); return Sha256Input{ - .witness = e.witness.value, + .witness = input_witness, .num_bits = e.num_bits, }; }), @@ -266,15 +337,17 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.sha256_compression.push_back(Sha256Compression{ .inputs = map(arg.inputs, [](auto& e) { + auto input_witness = get_witness_from_function_input(e); return Sha256Input{ - .witness = e.witness.value, + .witness = input_witness, .num_bits = e.num_bits, }; }), .hash_values = map(arg.hash_values, [](auto& e) { + auto input_witness = get_witness_from_function_input(e); return Sha256Input{ - .witness = e.witness.value, + .witness = input_witness, .num_bits = e.num_bits, }; }), @@ -285,8 +358,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.blake2s_constraints.push_back(Blake2sConstraint{ .inputs = map(arg.inputs, [](auto& e) { + auto input_witness = get_witness_from_function_input(e); return Blake2sInput{ - .witness = e.witness.value, + .witness = input_witness, .num_bits = e.num_bits, }; }), @@ -297,8 +371,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.blake3_constraints.push_back(Blake3Constraint{ .inputs = map(arg.inputs, [](auto& e) { + auto input_witness = get_witness_from_function_input(e); return Blake3Input{ - .witness = e.witness.value, + .witness = input_witness, .num_bits = e.num_bits, }; }), @@ -306,17 +381,20 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, }); af.original_opcode_indices.blake3_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { + auto input_pkey_x = get_witness_from_function_input(arg.public_key_x); + auto input_pkey_y = get_witness_from_function_input(arg.public_key_y); af.schnorr_constraints.push_back(SchnorrConstraint{ - .message = map(arg.message, [](auto& e) { return e.witness.value; }), - .public_key_x = arg.public_key_x.witness.value, - .public_key_y = arg.public_key_y.witness.value, + .message = map(arg.message, [](auto& e) { return get_witness_from_function_input(e); }), + .public_key_x = input_pkey_x, + .public_key_y = input_pkey_y, .result = arg.output.value, - .signature = map(arg.signature, [](auto& e) { return e.witness.value; }), + .signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }), }); af.original_opcode_indices.schnorr_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { + af.pedersen_constraints.push_back(PedersenConstraint{ - .scalars = map(arg.inputs, [](auto& e) { return e.witness.value; }), + .scalars = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }), .hash_index = arg.domain_separator, .result_x = arg.outputs[0].value, .result_y = arg.outputs[1].value, @@ -324,92 +402,111 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.original_opcode_indices.pedersen_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.pedersen_hash_constraints.push_back(PedersenHashConstraint{ - .scalars = map(arg.inputs, [](auto& e) { return e.witness.value; }), + .scalars = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }), .hash_index = arg.domain_separator, .result = arg.output.value, }); af.original_opcode_indices.pedersen_hash_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.ecdsa_k1_constraints.push_back(EcdsaSecp256k1Constraint{ - .hashed_message = map(arg.hashed_message, [](auto& e) { return e.witness.value; }), - .signature = map(arg.signature, [](auto& e) { return e.witness.value; }), - .pub_x_indices = map(arg.public_key_x, [](auto& e) { return e.witness.value; }), - .pub_y_indices = map(arg.public_key_y, [](auto& e) { return e.witness.value; }), + .hashed_message = + map(arg.hashed_message, [](auto& e) { return get_witness_from_function_input(e); }), + .signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }), + .pub_x_indices = map(arg.public_key_x, [](auto& e) { return get_witness_from_function_input(e); }), + .pub_y_indices = map(arg.public_key_y, [](auto& e) { return get_witness_from_function_input(e); }), .result = arg.output.value, }); af.original_opcode_indices.ecdsa_k1_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.ecdsa_r1_constraints.push_back(EcdsaSecp256r1Constraint{ - .hashed_message = map(arg.hashed_message, [](auto& e) { return e.witness.value; }), - .pub_x_indices = map(arg.public_key_x, [](auto& e) { return e.witness.value; }), - .pub_y_indices = map(arg.public_key_y, [](auto& e) { return e.witness.value; }), + .hashed_message = + map(arg.hashed_message, [](auto& e) { return get_witness_from_function_input(e); }), + .pub_x_indices = map(arg.public_key_x, [](auto& e) { return get_witness_from_function_input(e); }), + .pub_y_indices = map(arg.public_key_y, [](auto& e) { return get_witness_from_function_input(e); }), .result = arg.output.value, - .signature = map(arg.signature, [](auto& e) { return e.witness.value; }), + .signature = map(arg.signature, [](auto& e) { return get_witness_from_function_input(e); }), }); af.original_opcode_indices.ecdsa_r1_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.multi_scalar_mul_constraints.push_back(MultiScalarMul{ - .points = map(arg.points, [](auto& e) { return e.witness.value; }), - .scalars = map(arg.scalars, [](auto& e) { return e.witness.value; }), + .points = map(arg.points, [](auto& e) { return parse_input(e); }), + .scalars = map(arg.scalars, [](auto& e) { return parse_input(e); }), .out_point_x = arg.outputs[0].value, .out_point_y = arg.outputs[1].value, .out_point_is_infinite = arg.outputs[2].value, }); af.original_opcode_indices.multi_scalar_mul_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { + auto input_1_x = get_witness_from_function_input(arg.input1[0]); + auto input_1_y = get_witness_from_function_input(arg.input1[1]); + auto input_1_infinite = get_witness_from_function_input(arg.input1[2]); + auto input_2_x = get_witness_from_function_input(arg.input2[0]); + auto input_2_y = get_witness_from_function_input(arg.input2[1]); + auto input_2_infinite = get_witness_from_function_input(arg.input2[2]); + af.ec_add_constraints.push_back(EcAdd{ - .input1_x = arg.input1[0].witness.value, - .input1_y = arg.input1[1].witness.value, - .input1_infinite = arg.input1[2].witness.value, - .input2_x = arg.input2[0].witness.value, - .input2_y = arg.input2[1].witness.value, - .input2_infinite = arg.input2[2].witness.value, + .input1_x = input_1_x, + .input1_y = input_1_y, + .input1_infinite = input_1_infinite, + .input2_x = input_2_x, + .input2_y = input_2_y, + .input2_infinite = input_2_infinite, .result_x = arg.outputs[0].value, .result_y = arg.outputs[1].value, .result_infinite = arg.outputs[2].value, }); af.original_opcode_indices.ec_add_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { + auto input_var_message_size = get_witness_from_function_input(arg.var_message_size); af.keccak_constraints.push_back(KeccakConstraint{ .inputs = map(arg.inputs, [](auto& e) { + auto input_witness = get_witness_from_function_input(e); return HashInput{ - .witness = e.witness.value, + .witness = input_witness, .num_bits = e.num_bits, }; }), .result = map(arg.outputs, [](auto& e) { return e.value; }), - .var_message_size = arg.var_message_size.witness.value, + .var_message_size = input_var_message_size, }); af.original_opcode_indices.keccak_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.keccak_permutations.push_back(Keccakf1600{ - .state = map(arg.inputs, [](auto& e) { return e.witness.value; }), + .state = map(arg.inputs, + [](auto& e) { + auto input_witness = get_witness_from_function_input(e); + return input_witness; + }), .result = map(arg.outputs, [](auto& e) { return e.value; }), }); af.original_opcode_indices.keccak_permutations.push_back(opcode_index); } else if constexpr (std::is_same_v) { if (honk_recursion) { // if we're using the honk recursive verifier auto c = HonkRecursionConstraint{ - .key = map(arg.verification_key, [](auto& e) { return e.witness.value; }), - .proof = map(arg.proof, [](auto& e) { return e.witness.value; }), - .public_inputs = map(arg.public_inputs, [](auto& e) { return e.witness.value; }), + .key = map(arg.verification_key, [](auto& e) { return get_witness_from_function_input(e); }), + .proof = map(arg.proof, [](auto& e) { return get_witness_from_function_input(e); }), + .public_inputs = + map(arg.public_inputs, [](auto& e) { return get_witness_from_function_input(e); }), }; af.honk_recursion_constraints.push_back(c); af.original_opcode_indices.honk_recursion_constraints.push_back(opcode_index); } else { + auto input_key = get_witness_from_function_input(arg.key_hash); + auto c = RecursionConstraint{ - .key = map(arg.verification_key, [](auto& e) { return e.witness.value; }), - .proof = map(arg.proof, [](auto& e) { return e.witness.value; }), - .public_inputs = map(arg.public_inputs, [](auto& e) { return e.witness.value; }), - .key_hash = arg.key_hash.witness.value, + .key = map(arg.verification_key, [](auto& e) { return get_witness_from_function_input(e); }), + .proof = map(arg.proof, [](auto& e) { return get_witness_from_function_input(e); }), + .public_inputs = + map(arg.public_inputs, [](auto& e) { return get_witness_from_function_input(e); }), + .key_hash = input_key, }; af.recursion_constraints.push_back(c); af.original_opcode_indices.recursion_constraints.push_back(opcode_index); } } else if constexpr (std::is_same_v) { af.bigint_from_le_bytes_constraints.push_back(BigIntFromLeBytes{ - .inputs = map(arg.inputs, [](auto& e) { return e.witness.value; }), + .inputs = map(arg.inputs, [](auto& e) { return get_witness_from_function_input(e); }), .modulus = map(arg.modulus, [](auto& e) -> uint32_t { return e; }), .result = arg.output, }); @@ -454,7 +551,11 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.original_opcode_indices.bigint_operations.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.poseidon2_constraints.push_back(Poseidon2Constraint{ - .state = map(arg.inputs, [](auto& e) { return e.witness.value; }), + .state = map(arg.inputs, + [](auto& e) { + auto input_witness = get_witness_from_function_input(e); + return input_witness; + }), .result = map(arg.outputs, [](auto& e) { return e.value; }), .len = arg.len, }); @@ -484,7 +585,8 @@ BlockConstraint handle_memory_init(Program::Opcode::MemoryInit const& mem_init) }); } - // Databus is only supported for Goblin, non Goblin builders will treat call_data and return_data as normal array. + // Databus is only supported for Goblin, non Goblin builders will treat call_data and return_data as normal + // array. if (std::holds_alternative(mem_init.block_type.value)) { block.type = BlockType::CallData; } else if (std::holds_alternative(mem_init.block_type.value)) { diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp index 1dacfe85a0f..9ff5123ffd1 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp @@ -129,8 +129,56 @@ TEST_F(EcOperations, TestECMultiScalarMul) fr(0), }; msm_constrain = MultiScalarMul{ - .points = { 1, 2, 3, 1, 2, 3 }, - .scalars = { 4, 5, 4, 5 }, + .points = { WitnessConstant{ + .index = 1, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 2, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 3, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 1, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 2, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 3, + .value = fr(0), + .is_constant = false, + } }, + .scalars = { WitnessConstant{ + .index = 4, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 5, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 4, + .value = fr(0), + .is_constant = false, + }, + WitnessConstant{ + .index = 5, + .value = fr(0), + .is_constant = false, + } }, .out_point_x = 6, .out_point_y = 7, .out_point_is_infinite = 0, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp index 928eacb75f9..398c794dea1 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp @@ -1,4 +1,5 @@ #include "multi_scalar_mul.hpp" +#include "barretenberg/dsl/acir_format/serde/acir.hpp" #include "barretenberg/ecc/curves/bn254/fr.hpp" #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" #include "barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp" @@ -21,27 +22,62 @@ template void create_multi_scalar_mul_constraint(Builder& bui for (size_t i = 0; i < input.points.size(); i += 3) { // Instantiate the input point/variable base as `cycle_group_ct` - auto point_x = field_ct::from_witness_index(&builder, input.points[i]); - auto point_y = field_ct::from_witness_index(&builder, input.points[i + 1]); - auto infinite = bool_ct(field_ct::from_witness_index(&builder, input.points[i + 2])); + field_ct point_x; + field_ct point_y; + bool_ct infinite; + if (input.points[i].is_constant) { + point_x = field_ct(input.points[i].value); + } else { + point_x = field_ct::from_witness_index(&builder, input.points[i].index); + } + if (input.points[i + 1].is_constant) { + point_y = field_ct(input.points[i + 1].value); + } else { + point_y = field_ct::from_witness_index(&builder, input.points[i + 1].index); + } + if (input.points[i + 2].is_constant) { + infinite = bool_ct(field_ct(input.points[i + 2].value)); + } else { + infinite = bool_ct(field_ct::from_witness_index(&builder, input.points[i + 2].index)); + } cycle_group_ct input_point(point_x, point_y, infinite); // Reconstruct the scalar from the low and high limbs - field_ct scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3)]); - field_ct scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3) + 1]); + field_ct scalar_low_as_field; + field_ct scalar_high_as_field; + if (input.scalars[2 * (i / 3)].is_constant) { + scalar_low_as_field = field_ct(input.scalars[2 * (i / 3)].value); + } else { + scalar_low_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3)].index); + } + if (input.scalars[2 * (i / 3) + 1].is_constant) { + scalar_high_as_field = field_ct(input.scalars[2 * (i / 3) + 1].value); + } else { + scalar_high_as_field = field_ct::from_witness_index(&builder, input.scalars[2 * (i / 3) + 1].index); + } cycle_scalar_ct scalar(scalar_low_as_field, scalar_high_as_field); // Add the point and scalar to the vectors points.push_back(input_point); scalars.push_back(scalar); } - // Call batch_mul to multiply the points and scalars and sum the results auto output_point = cycle_group_ct::batch_mul(points, scalars).get_standard_form(); - - // Add the constraints - builder.assert_equal(output_point.x.get_witness_index(), input.out_point_x); - builder.assert_equal(output_point.y.get_witness_index(), input.out_point_y); - builder.assert_equal(output_point.is_point_at_infinity().witness_index, input.out_point_is_infinite); + // Add the constraints and handle constant values + if (output_point.is_point_at_infinity().is_constant()) { + builder.fix_witness(input.out_point_is_infinite, output_point.is_point_at_infinity().get_value()); + } else { + builder.assert_equal(output_point.is_point_at_infinity().witness_index, input.out_point_is_infinite); + } + if (output_point.x.is_constant()) { + builder.fix_witness(input.out_point_x, output_point.x.get_value()); + } else { + builder.assert_equal(output_point.x.get_witness_index(), input.out_point_x); + } + if (output_point.y.is_constant()) { + builder.fix_witness(input.out_point_y, output_point.y.get_value()); + } else { + builder.assert_equal(output_point.y.get_witness_index(), input.out_point_y); + } } template void create_multi_scalar_mul_constraint(UltraCircuitBuilder& builder, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp index 65e9a7c1b16..ef980a99ab6 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp @@ -1,13 +1,23 @@ #pragma once #include "barretenberg/serialize/msgpack.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" +#include "serde/index.hpp" #include #include namespace acir_format { +template struct WitnessConstant { + uint32_t index; + FF value; + bool is_constant; + MSGPACK_FIELDS(index, value, is_constant); + friend bool operator==(WitnessConstant const& lhs, WitnessConstant const& rhs) = default; +}; + struct MultiScalarMul { - std::vector points; - std::vector scalars; + std::vector> points; + std::vector> scalars; uint32_t out_point_x; uint32_t out_point_y; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index 2c8ad3cc5f8..2406e27a470 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -13,8 +13,33 @@ struct Witness { static Witness bincodeDeserialize(std::vector); }; +struct ConstantOrWitnessEnum { + + struct Constant { + std::string value; + + friend bool operator==(const Constant&, const Constant&); + std::vector bincodeSerialize() const; + static Constant bincodeDeserialize(std::vector); + }; + + struct Witness { + Program::Witness value; + + friend bool operator==(const Witness&, const Witness&); + std::vector bincodeSerialize() const; + static Witness bincodeDeserialize(std::vector); + }; + + std::variant value; + + friend bool operator==(const ConstantOrWitnessEnum&, const ConstantOrWitnessEnum&); + std::vector bincodeSerialize() const; + static ConstantOrWitnessEnum bincodeDeserialize(std::vector); +}; + struct FunctionInput { - Program::Witness witness; + Program::ConstantOrWitnessEnum input; uint32_t num_bits; friend bool operator==(const FunctionInput&, const FunctionInput&); @@ -6911,6 +6936,151 @@ Program::Circuit serde::Deserializable::deserialize(Deserializ namespace Program { +inline bool operator==(const ConstantOrWitnessEnum& lhs, const ConstantOrWitnessEnum& rhs) +{ + if (!(lhs.value == rhs.value)) { + return false; + } + return true; +} + +inline std::vector ConstantOrWitnessEnum::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline ConstantOrWitnessEnum ConstantOrWitnessEnum::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::ConstantOrWitnessEnum& obj, + Serializer& serializer) +{ + serializer.increase_container_depth(); + serde::Serializable::serialize(obj.value, serializer); + serializer.decrease_container_depth(); +} + +template <> +template +Program::ConstantOrWitnessEnum serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + deserializer.increase_container_depth(); + Program::ConstantOrWitnessEnum obj; + obj.value = serde::Deserializable::deserialize(deserializer); + deserializer.decrease_container_depth(); + return obj; +} + +namespace Program { + +inline bool operator==(const ConstantOrWitnessEnum::Constant& lhs, const ConstantOrWitnessEnum::Constant& rhs) +{ + if (!(lhs.value == rhs.value)) { + return false; + } + return true; +} + +inline std::vector ConstantOrWitnessEnum::Constant::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline ConstantOrWitnessEnum::Constant ConstantOrWitnessEnum::Constant::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize( + const Program::ConstantOrWitnessEnum::Constant& obj, Serializer& serializer) +{ + serde::Serializable::serialize(obj.value, serializer); +} + +template <> +template +Program::ConstantOrWitnessEnum::Constant serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + Program::ConstantOrWitnessEnum::Constant obj; + obj.value = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Program { + +inline bool operator==(const ConstantOrWitnessEnum::Witness& lhs, const ConstantOrWitnessEnum::Witness& rhs) +{ + if (!(lhs.value == rhs.value)) { + return false; + } + return true; +} + +inline std::vector ConstantOrWitnessEnum::Witness::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline ConstantOrWitnessEnum::Witness ConstantOrWitnessEnum::Witness::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize( + const Program::ConstantOrWitnessEnum::Witness& obj, Serializer& serializer) +{ + serde::Serializable::serialize(obj.value, serializer); +} + +template <> +template +Program::ConstantOrWitnessEnum::Witness serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + Program::ConstantOrWitnessEnum::Witness obj; + obj.value = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Program { + inline bool operator==(const Directive& lhs, const Directive& rhs) { if (!(lhs.value == rhs.value)) { @@ -7360,7 +7530,7 @@ namespace Program { inline bool operator==(const FunctionInput& lhs, const FunctionInput& rhs) { - if (!(lhs.witness == rhs.witness)) { + if (!(lhs.input == rhs.input)) { return false; } if (!(lhs.num_bits == rhs.num_bits)) { @@ -7393,7 +7563,7 @@ template void serde::Serializable::serialize(const Program::FunctionInput& obj, Serializer& serializer) { serializer.increase_container_depth(); - serde::Serializable::serialize(obj.witness, serializer); + serde::Serializable::serialize(obj.input, serializer); serde::Serializable::serialize(obj.num_bits, serializer); serializer.decrease_container_depth(); } @@ -7404,7 +7574,7 @@ Program::FunctionInput serde::Deserializable::deserializ { deserializer.increase_container_depth(); Program::FunctionInput obj; - obj.witness = serde::Deserializable::deserialize(deserializer); + obj.input = serde::Deserializable::deserialize(deserializer); obj.num_bits = serde::Deserializable::deserialize(deserializer); deserializer.decrease_container_depth(); return obj; diff --git a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp index 47e184a6332..c1160930571 100644 --- a/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp +++ b/noir/noir-repo/acvm-repo/acir/codegen/acir.cpp @@ -13,8 +13,33 @@ namespace Program { static Witness bincodeDeserialize(std::vector); }; + struct ConstantOrWitnessEnum { + + struct Constant { + std::string value; + + friend bool operator==(const Constant&, const Constant&); + std::vector bincodeSerialize() const; + static Constant bincodeDeserialize(std::vector); + }; + + struct Witness { + Program::Witness value; + + friend bool operator==(const Witness&, const Witness&); + std::vector bincodeSerialize() const; + static Witness bincodeDeserialize(std::vector); + }; + + std::variant value; + + friend bool operator==(const ConstantOrWitnessEnum&, const ConstantOrWitnessEnum&); + std::vector bincodeSerialize() const; + static ConstantOrWitnessEnum bincodeDeserialize(std::vector); + }; + struct FunctionInput { - Program::Witness witness; + Program::ConstantOrWitnessEnum input; uint32_t num_bits; friend bool operator==(const FunctionInput&, const FunctionInput&); @@ -5716,6 +5741,124 @@ Program::Circuit serde::Deserializable::deserialize(Deserializ return obj; } +namespace Program { + + inline bool operator==(const ConstantOrWitnessEnum &lhs, const ConstantOrWitnessEnum &rhs) { + if (!(lhs.value == rhs.value)) { return false; } + return true; + } + + inline std::vector ConstantOrWitnessEnum::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline ConstantOrWitnessEnum ConstantOrWitnessEnum::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::ConstantOrWitnessEnum &obj, Serializer &serializer) { + serializer.increase_container_depth(); + serde::Serializable::serialize(obj.value, serializer); + serializer.decrease_container_depth(); +} + +template <> +template +Program::ConstantOrWitnessEnum serde::Deserializable::deserialize(Deserializer &deserializer) { + deserializer.increase_container_depth(); + Program::ConstantOrWitnessEnum obj; + obj.value = serde::Deserializable::deserialize(deserializer); + deserializer.decrease_container_depth(); + return obj; +} + +namespace Program { + + inline bool operator==(const ConstantOrWitnessEnum::Constant &lhs, const ConstantOrWitnessEnum::Constant &rhs) { + if (!(lhs.value == rhs.value)) { return false; } + return true; + } + + inline std::vector ConstantOrWitnessEnum::Constant::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline ConstantOrWitnessEnum::Constant ConstantOrWitnessEnum::Constant::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::ConstantOrWitnessEnum::Constant &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.value, serializer); +} + +template <> +template +Program::ConstantOrWitnessEnum::Constant serde::Deserializable::deserialize(Deserializer &deserializer) { + Program::ConstantOrWitnessEnum::Constant obj; + obj.value = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Program { + + inline bool operator==(const ConstantOrWitnessEnum::Witness &lhs, const ConstantOrWitnessEnum::Witness &rhs) { + if (!(lhs.value == rhs.value)) { return false; } + return true; + } + + inline std::vector ConstantOrWitnessEnum::Witness::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline ConstantOrWitnessEnum::Witness ConstantOrWitnessEnum::Witness::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Program + +template <> +template +void serde::Serializable::serialize(const Program::ConstantOrWitnessEnum::Witness &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.value, serializer); +} + +template <> +template +Program::ConstantOrWitnessEnum::Witness serde::Deserializable::deserialize(Deserializer &deserializer) { + Program::ConstantOrWitnessEnum::Witness obj; + obj.value = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Program { inline bool operator==(const Directive &lhs, const Directive &rhs) { @@ -6086,7 +6229,7 @@ Program::ExpressionWidth::Bounded serde::Deserializable template void serde::Serializable::serialize(const Program::FunctionInput &obj, Serializer &serializer) { serializer.increase_container_depth(); - serde::Serializable::serialize(obj.witness, serializer); + serde::Serializable::serialize(obj.input, serializer); serde::Serializable::serialize(obj.num_bits, serializer); serializer.decrease_container_depth(); } @@ -6122,7 +6265,7 @@ template Program::FunctionInput serde::Deserializable::deserialize(Deserializer &deserializer) { deserializer.increase_container_depth(); Program::FunctionInput obj; - obj.witness = serde::Deserializable::deserialize(deserializer); + obj.input = serde::Deserializable::deserialize(deserializer); obj.num_bits = serde::Deserializable::deserialize(deserializer); deserializer.decrease_container_depth(); return obj; diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/mod.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/mod.rs index 7f3c1890717..115ecbf7992 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/mod.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/mod.rs @@ -360,7 +360,7 @@ mod tests { use std::collections::BTreeSet; use super::{ - opcodes::{BlackBoxFuncCall, FunctionInput}, + opcodes::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput}, Circuit, Compression, Opcode, PublicInputs, }; use crate::{ @@ -372,34 +372,29 @@ mod tests { fn and_opcode() -> Opcode { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND { - lhs: FunctionInput { witness: Witness(1), num_bits: 4 }, - rhs: FunctionInput { witness: Witness(2), num_bits: 4 }, + lhs: FunctionInput::witness(Witness(1), 4), + rhs: FunctionInput::witness(Witness(2), 4), output: Witness(3), }) } fn range_opcode() -> Opcode { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { witness: Witness(1), num_bits: 8 }, + input: FunctionInput::witness(Witness(1), 8), }) } fn keccakf1600_opcode() -> Opcode { - let inputs: Box<[FunctionInput; 25]> = Box::new(std::array::from_fn(|i| FunctionInput { - witness: Witness(i as u32 + 1), - num_bits: 8, - })); + let inputs: Box<[FunctionInput; 25]> = + Box::new(std::array::from_fn(|i| FunctionInput::witness(Witness(i as u32 + 1), 8))); let outputs: Box<[Witness; 25]> = Box::new(std::array::from_fn(|i| Witness(i as u32 + 26))); Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs }) } fn schnorr_verify_opcode() -> Opcode { - let public_key_x = - FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() }; - let public_key_y = - FunctionInput { witness: Witness(2), num_bits: FieldElement::max_num_bits() }; - let signature: Box<[FunctionInput; 64]> = Box::new(std::array::from_fn(|i| { - FunctionInput { witness: Witness(i as u32 + 3), num_bits: 8 } - })); - let message: Vec = vec![FunctionInput { witness: Witness(67), num_bits: 8 }]; + let public_key_x = FunctionInput::witness(Witness(1), FieldElement::max_num_bits()); + let public_key_y = FunctionInput::witness(Witness(2), FieldElement::max_num_bits()); + let signature: Box<[FunctionInput; 64]> = + Box::new(std::array::from_fn(|i| FunctionInput::witness(Witness(i as u32 + 3), 8))); + let message: Vec> = vec![FunctionInput::witness(Witness(67), 8)]; let output = Witness(68); Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SchnorrVerify { @@ -425,7 +420,7 @@ mod tests { }; let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() }; - fn read_write Deserialize<'a>>( + fn read_write Deserialize<'a> + AcirField>( program: Program, ) -> (Program, Program) { let bytes = Program::serialize_program(&program); diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs index 984422c5e3a..d303f9fbbab 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; mod black_box_function_call; mod memory_operation; -pub use black_box_function_call::{BlackBoxFuncCall, FunctionInput}; +pub use black_box_function_call::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput}; pub use memory_operation::{BlockId, MemOp}; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -67,7 +67,7 @@ pub enum Opcode { /// /// Aztec's Barretenberg uses BN254 as the main curve and Grumpkin as the /// embedded curve. - BlackBoxFuncCall(BlackBoxFuncCall), + BlackBoxFuncCall(BlackBoxFuncCall), /// This opcode is a specialization of a Brillig opcode. Instead of having /// some generic assembly code like Brillig, a directive has a hardcoded diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs index b8be81fcdef..7c560a0a346 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs @@ -4,124 +4,161 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; // Note: Some functions will not use all of the witness // So we need to supply how many bits of the witness is needed + #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub struct FunctionInput { - pub witness: Witness, +pub enum ConstantOrWitnessEnum { + Constant(F), + Witness(Witness), +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct FunctionInput { + pub input: ConstantOrWitnessEnum, pub num_bits: u32, } +impl FunctionInput { + pub fn to_witness(&self) -> Witness { + match self.input { + ConstantOrWitnessEnum::Constant(_) => unreachable!("ICE - Expected Witness"), + ConstantOrWitnessEnum::Witness(witness) => witness, + } + } + + pub fn num_bits(&self) -> u32 { + self.num_bits + } + + pub fn witness(witness: Witness, num_bits: u32) -> FunctionInput { + FunctionInput { input: ConstantOrWitnessEnum::Witness(witness), num_bits } + } + + pub fn constant(value: F, num_bits: u32) -> FunctionInput { + FunctionInput { input: ConstantOrWitnessEnum::Constant(value), num_bits } + } +} + +impl std::fmt::Display for FunctionInput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.input { + ConstantOrWitnessEnum::Constant(constant) => write!(f, "{constant}"), + ConstantOrWitnessEnum::Witness(witness) => write!(f, "{}", witness.0), + } + } +} + #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum BlackBoxFuncCall { +pub enum BlackBoxFuncCall { AES128Encrypt { - inputs: Vec, - iv: Box<[FunctionInput; 16]>, - key: Box<[FunctionInput; 16]>, + inputs: Vec>, + iv: Box<[FunctionInput; 16]>, + key: Box<[FunctionInput; 16]>, outputs: Vec, }, AND { - lhs: FunctionInput, - rhs: FunctionInput, + lhs: FunctionInput, + rhs: FunctionInput, output: Witness, }, XOR { - lhs: FunctionInput, - rhs: FunctionInput, + lhs: FunctionInput, + rhs: FunctionInput, output: Witness, }, RANGE { - input: FunctionInput, + input: FunctionInput, }, SHA256 { - inputs: Vec, + inputs: Vec>, outputs: Box<[Witness; 32]>, }, Blake2s { - inputs: Vec, + inputs: Vec>, outputs: Box<[Witness; 32]>, }, Blake3 { - inputs: Vec, + inputs: Vec>, outputs: Box<[Witness; 32]>, }, SchnorrVerify { - public_key_x: FunctionInput, - public_key_y: FunctionInput, + public_key_x: FunctionInput, + public_key_y: FunctionInput, #[serde( serialize_with = "serialize_big_array", deserialize_with = "deserialize_big_array_into_box" )] - signature: Box<[FunctionInput; 64]>, - message: Vec, + signature: Box<[FunctionInput; 64]>, + message: Vec>, output: Witness, }, /// Will be deprecated PedersenCommitment { - inputs: Vec, + inputs: Vec>, domain_separator: u32, outputs: (Witness, Witness), }, /// Will be deprecated PedersenHash { - inputs: Vec, + inputs: Vec>, domain_separator: u32, output: Witness, }, EcdsaSecp256k1 { - public_key_x: Box<[FunctionInput; 32]>, - public_key_y: Box<[FunctionInput; 32]>, + public_key_x: Box<[FunctionInput; 32]>, + public_key_y: Box<[FunctionInput; 32]>, #[serde( serialize_with = "serialize_big_array", deserialize_with = "deserialize_big_array_into_box" )] - signature: Box<[FunctionInput; 64]>, - hashed_message: Box<[FunctionInput; 32]>, + signature: Box<[FunctionInput; 64]>, + hashed_message: Box<[FunctionInput; 32]>, output: Witness, }, EcdsaSecp256r1 { - public_key_x: Box<[FunctionInput; 32]>, - public_key_y: Box<[FunctionInput; 32]>, + public_key_x: Box<[FunctionInput; 32]>, + public_key_y: Box<[FunctionInput; 32]>, #[serde( serialize_with = "serialize_big_array", deserialize_with = "deserialize_big_array_into_box" )] - signature: Box<[FunctionInput; 64]>, - hashed_message: Box<[FunctionInput; 32]>, + signature: Box<[FunctionInput; 64]>, + hashed_message: Box<[FunctionInput; 32]>, output: Witness, }, MultiScalarMul { - points: Vec, - scalars: Vec, + points: Vec>, + scalars: Vec>, outputs: (Witness, Witness, Witness), }, EmbeddedCurveAdd { - input1: Box<[FunctionInput; 3]>, - input2: Box<[FunctionInput; 3]>, + input1: Box<[FunctionInput; 3]>, + input2: Box<[FunctionInput; 3]>, outputs: (Witness, Witness, Witness), }, Keccak256 { - inputs: Vec, + inputs: Vec>, /// This is the number of bytes to take /// from the input. Note: if `var_message_size` /// is more than the number of bytes in the input, /// then an error is returned. - var_message_size: FunctionInput, + var_message_size: FunctionInput, outputs: Box<[Witness; 32]>, }, Keccakf1600 { - inputs: Box<[FunctionInput; 25]>, + inputs: Box<[FunctionInput; 25]>, outputs: Box<[Witness; 25]>, }, RecursiveAggregation { - verification_key: Vec, - proof: Vec, + verification_key: Vec>, + proof: Vec>, /// These represent the public inputs of the proof we are verifying /// They should be checked against in the circuit after construction /// of a new aggregation state - public_inputs: Vec, + public_inputs: Vec>, /// A key hash is used to check the validity of the verification key. /// The circuit implementing this opcode can use this hash to ensure that the /// key provided to the circuit matches the key produced by the circuit creator - key_hash: FunctionInput, + key_hash: FunctionInput, }, BigIntAdd { lhs: u32, @@ -144,7 +181,7 @@ pub enum BlackBoxFuncCall { output: u32, }, BigIntFromLeBytes { - inputs: Vec, + inputs: Vec>, modulus: Vec, output: u32, }, @@ -156,7 +193,7 @@ pub enum BlackBoxFuncCall { /// outputting the permuted state. Poseidon2Permutation { /// Input state for the permutation of Poseidon2 - inputs: Vec, + inputs: Vec>, /// Permuted state outputs: Vec, /// State length (in number of field elements) @@ -172,15 +209,15 @@ pub enum BlackBoxFuncCall { /// * `outputs` - result of the input compressed into 256 bits Sha256Compression { /// 512 bits of the input message, represented by 16 u32s - inputs: Box<[FunctionInput; 16]>, + inputs: Box<[FunctionInput; 16]>, /// Vector of 8 u32s used to compress the input - hash_values: Box<[FunctionInput; 8]>, + hash_values: Box<[FunctionInput; 8]>, /// Output of the compression, represented by 8 u32s outputs: Box<[Witness; 8]>, }, } -impl BlackBoxFuncCall { +impl BlackBoxFuncCall { pub fn get_black_box_func(&self) -> BlackBoxFunc { match self { BlackBoxFuncCall::AES128Encrypt { .. } => BlackBoxFunc::AES128Encrypt, @@ -215,7 +252,7 @@ impl BlackBoxFuncCall { self.get_black_box_func().name() } - pub fn get_inputs_vec(&self) -> Vec { + pub fn get_inputs_vec(&self) -> Vec> { match self { BlackBoxFuncCall::AES128Encrypt { inputs, .. } | BlackBoxFuncCall::SHA256 { inputs, .. } @@ -240,7 +277,7 @@ impl BlackBoxFuncCall { | BlackBoxFuncCall::BigIntDiv { .. } | BlackBoxFuncCall::BigIntToLeBytes { .. } => Vec::new(), BlackBoxFuncCall::MultiScalarMul { points, scalars, .. } => { - let mut inputs: Vec = Vec::with_capacity(points.len() * 2); + let mut inputs: Vec> = Vec::with_capacity(points.len() * 2); inputs.extend(points.iter().copied()); inputs.extend(scalars.iter().copied()); inputs @@ -256,7 +293,7 @@ impl BlackBoxFuncCall { message, .. } => { - let mut inputs: Vec = + let mut inputs: Vec> = Vec::with_capacity(2 + signature.len() + message.len()); inputs.push(*public_key_x); inputs.push(*public_key_y); @@ -364,7 +401,7 @@ impl BlackBoxFuncCall { const ABBREVIATION_LIMIT: usize = 5; -fn get_inputs_string(inputs: &[FunctionInput]) -> String { +fn get_inputs_string(inputs: &[FunctionInput]) -> String { // Once a vectors length gets above this limit, // instead of listing all of their elements, we use ellipses // to abbreviate them @@ -373,7 +410,7 @@ fn get_inputs_string(inputs: &[FunctionInput]) -> String { if should_abbreviate_inputs { let mut result = String::new(); for (index, inp) in inputs.iter().enumerate() { - result += &format!("(_{}, num_bits: {})", inp.witness.witness_index(), inp.num_bits); + result += &format!("({})", inp); // Add a comma, unless it is the last entry if index != inputs.len() - 1 { result += ", "; @@ -385,14 +422,7 @@ fn get_inputs_string(inputs: &[FunctionInput]) -> String { let last = inputs.last().unwrap(); let mut result = String::new(); - - result += &format!( - "(_{}, num_bits: {})...(_{}, num_bits: {})", - first.witness.witness_index(), - first.num_bits, - last.witness.witness_index(), - last.num_bits, - ); + result += &format!("({})...({})", first, last,); result } @@ -421,7 +451,7 @@ fn get_outputs_string(outputs: &[Witness]) -> String { } } -impl std::fmt::Display for BlackBoxFuncCall { +impl std::fmt::Display for BlackBoxFuncCall { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { BlackBoxFuncCall::PedersenCommitment { .. } => { @@ -454,13 +484,16 @@ impl std::fmt::Display for BlackBoxFuncCall { } } -impl std::fmt::Debug for BlackBoxFuncCall { +impl std::fmt::Debug for BlackBoxFuncCall { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } } -fn serialize_big_array(big_array: &[FunctionInput; 64], s: S) -> Result +fn serialize_big_array( + big_array: &[FunctionInput; 64], + s: S, +) -> Result where S: Serializer, { @@ -469,15 +502,15 @@ where (*big_array).serialize(s) } -fn deserialize_big_array_into_box<'de, D>( +fn deserialize_big_array_into_box<'de, D, F: Deserialize<'de>>( deserializer: D, -) -> Result, D::Error> +) -> Result; 64]>, D::Error> where D: Deserializer<'de>, { use serde_big_array::BigArray; - let big_array: [FunctionInput; 64] = BigArray::deserialize(deserializer)?; + let big_array: [FunctionInput; 64] = BigArray::deserialize(deserializer)?; Ok(Box::new(big_array)) } @@ -487,26 +520,21 @@ mod tests { use crate::{circuit::Opcode, native_types::Witness}; use acir_field::{AcirField, FieldElement}; - use super::{BlackBoxFuncCall, FunctionInput}; + use super::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput}; fn keccakf1600_opcode() -> Opcode { - let inputs: Box<[FunctionInput; 25]> = Box::new(std::array::from_fn(|i| FunctionInput { - witness: Witness(i as u32 + 1), - num_bits: 8, - })); + let inputs: Box<[FunctionInput; 25]> = + Box::new(std::array::from_fn(|i| FunctionInput::witness(Witness(i as u32 + 1), 8))); let outputs: Box<[Witness; 25]> = Box::new(std::array::from_fn(|i| Witness(i as u32 + 26))); Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs }) } fn schnorr_verify_opcode() -> Opcode { - let public_key_x = - FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() }; - let public_key_y = - FunctionInput { witness: Witness(2), num_bits: FieldElement::max_num_bits() }; - let signature: Box<[FunctionInput; 64]> = Box::new(std::array::from_fn(|i| { - FunctionInput { witness: Witness(i as u32 + 3), num_bits: 8 } - })); - let message: Vec = vec![FunctionInput { witness: Witness(67), num_bits: 8 }]; + let public_key_x = FunctionInput::witness(Witness(1), FieldElement::max_num_bits()); + let public_key_y = FunctionInput::witness(Witness(2), FieldElement::max_num_bits()); + let signature: Box<[FunctionInput; 64]> = + Box::new(std::array::from_fn(|i| FunctionInput::witness(Witness(i as u32 + 3), 8))); + let message: Vec> = vec![FunctionInput::witness(Witness(67), 8)]; let output = Witness(68); Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SchnorrVerify { diff --git a/noir/noir-repo/acvm-repo/acir/src/lib.rs b/noir/noir-repo/acvm-repo/acir/src/lib.rs index f064cfaca0e..540e0f07eb5 100644 --- a/noir/noir-repo/acvm-repo/acir/src/lib.rs +++ b/noir/noir-repo/acvm-repo/acir/src/lib.rs @@ -42,7 +42,7 @@ mod reflection { circuit::{ brillig::{BrilligInputs, BrilligOutputs}, directives::Directive, - opcodes::{BlackBoxFuncCall, BlockType}, + opcodes::{BlackBoxFuncCall, BlockType, ConstantOrWitnessEnum, FunctionInput}, AssertionPayload, Circuit, ExpressionOrMemory, ExpressionWidth, Opcode, OpcodeLocation, Program, }, @@ -68,7 +68,9 @@ mod reflection { tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::().unwrap(); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); + tracer.trace_simple_type::>().unwrap(); + tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::().unwrap(); tracer.trace_simple_type::>().unwrap(); diff --git a/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs b/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs index 84a9aa719f2..3a42cc41d47 100644 --- a/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs +++ b/noir/noir-repo/acvm-repo/acir/tests/test_program_serialization.rs @@ -14,7 +14,7 @@ use std::collections::BTreeSet; use acir::{ circuit::{ brillig::{BrilligBytecode, BrilligInputs, BrilligOutputs}, - opcodes::{BlackBoxFuncCall, BlockId, FunctionInput, MemOp}, + opcodes::{BlackBoxFuncCall, BlockId, ConstantOrWitnessEnum, FunctionInput, MemOp}, Circuit, Opcode, Program, PublicInputs, }, native_types::{Expression, Witness}, @@ -62,13 +62,13 @@ fn multi_scalar_mul_circuit() { let multi_scalar_mul: Opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::MultiScalarMul { points: vec![ - FunctionInput { witness: Witness(1), num_bits: 128 }, - FunctionInput { witness: Witness(2), num_bits: 128 }, - FunctionInput { witness: Witness(3), num_bits: 1 }, + FunctionInput::witness(Witness(1), 128), + FunctionInput::witness(Witness(2), 128), + FunctionInput::witness(Witness(3), 1), ], scalars: vec![ - FunctionInput { witness: Witness(4), num_bits: 128 }, - FunctionInput { witness: Witness(5), num_bits: 128 }, + FunctionInput::witness(Witness(4), 128), + FunctionInput::witness(Witness(5), 128), ], outputs: (Witness(6), Witness(7), Witness(8)), }); @@ -91,10 +91,10 @@ fn multi_scalar_mul_circuit() { let bytes = Program::serialize_program(&program); let expected_serialization: Vec = vec![ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 141, 219, 10, 0, 32, 8, 67, 243, 214, 5, 250, 232, - 62, 189, 69, 123, 176, 132, 195, 116, 50, 149, 114, 107, 0, 97, 127, 116, 2, 75, 243, 2, - 74, 53, 122, 202, 189, 211, 15, 106, 5, 13, 116, 238, 35, 221, 81, 230, 61, 249, 37, 253, - 250, 179, 79, 109, 218, 22, 67, 227, 173, 0, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 141, 11, 10, 0, 32, 8, 67, 43, 181, 15, 116, 232, + 142, 158, 210, 130, 149, 240, 112, 234, 212, 156, 78, 12, 39, 67, 71, 158, 142, 80, 29, 44, + 228, 66, 90, 168, 119, 189, 74, 115, 131, 174, 78, 115, 58, 124, 70, 254, 130, 59, 74, 253, + 68, 255, 255, 221, 39, 54, 221, 93, 91, 132, 193, 0, 0, 0, ]; assert_eq!(bytes, expected_serialization) @@ -102,18 +102,15 @@ fn multi_scalar_mul_circuit() { #[test] fn schnorr_verify_circuit() { - let public_key_x = - FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() }; - let public_key_y = - FunctionInput { witness: Witness(2), num_bits: FieldElement::max_num_bits() }; - let signature: [FunctionInput; 64] = (3..(3 + 64)) - .map(|i| FunctionInput { witness: Witness(i), num_bits: 8 }) + let public_key_x = FunctionInput::witness(Witness(1), FieldElement::max_num_bits()); + let public_key_y = FunctionInput::witness(Witness(2), FieldElement::max_num_bits()); + let signature: [FunctionInput; 64] = (3..(3 + 64)) + .map(|i| FunctionInput::witness(Witness(i), 8)) .collect::>() .try_into() .unwrap(); - let message = ((3 + 64)..(3 + 64 + 10)) - .map(|i| FunctionInput { witness: Witness(i), num_bits: 8 }) - .collect(); + let message = + ((3 + 64)..(3 + 64 + 10)).map(|i| FunctionInput::witness(Witness(i), 8)).collect(); let output = Witness(3 + 64 + 10); let last_input = output.witness_index() - 1; @@ -137,22 +134,24 @@ fn schnorr_verify_circuit() { let bytes = Program::serialize_program(&program); let expected_serialization: Vec = vec![ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 85, 210, 85, 78, 67, 81, 24, 133, 209, 226, 238, 238, - 238, 238, 238, 165, 148, 82, 102, 193, 252, 135, 64, 232, 78, 87, 147, 114, 147, 147, 5, - 47, 132, 252, 251, 107, 41, 212, 191, 159, 218, 107, 241, 115, 236, 226, 111, 237, 181, - 178, 173, 246, 186, 107, 175, 157, 29, 236, 100, 23, 27, 175, 135, 189, 236, 99, 63, 7, 56, - 200, 33, 14, 115, 132, 163, 28, 227, 56, 39, 56, 201, 41, 78, 115, 134, 179, 156, 227, 60, - 23, 184, 200, 37, 46, 115, 133, 171, 92, 227, 58, 55, 184, 201, 45, 110, 115, 135, 187, - 220, 227, 62, 15, 120, 200, 35, 30, 243, 132, 167, 60, 227, 57, 47, 120, 201, 43, 94, 243, - 134, 183, 188, 227, 61, 31, 248, 200, 39, 62, 243, 133, 175, 77, 59, 230, 123, 243, 123, - 145, 239, 44, 241, 131, 101, 126, 178, 194, 47, 86, 249, 237, 239, 86, 153, 238, 210, 92, - 122, 75, 107, 233, 44, 141, 53, 250, 234, 241, 191, 164, 167, 180, 148, 142, 210, 80, 250, - 73, 59, 233, 38, 205, 164, 151, 180, 146, 78, 210, 72, 250, 72, 27, 233, 34, 77, 164, 135, - 180, 144, 14, 210, 64, 246, 95, 46, 212, 119, 207, 230, 217, 59, 91, 103, 231, 108, 156, - 125, 183, 237, 186, 107, 207, 125, 59, 30, 218, 239, 216, 110, 167, 246, 58, 183, 211, 165, - 125, 174, 237, 114, 107, 143, 123, 59, 60, 186, 255, 179, 187, 191, 186, 115, 209, 125, 75, - 238, 90, 118, 207, 138, 59, 54, 110, 214, 184, 91, 161, 233, 158, 255, 190, 63, 71, 59, 68, - 130, 233, 3, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 85, 211, 103, 78, 2, 81, 24, 70, 225, 193, 6, 216, 123, + 47, 216, 123, 239, 136, 136, 136, 136, 136, 187, 96, 255, 75, 32, 112, 194, 55, 201, 129, + 100, 50, 79, 244, 7, 228, 222, 243, 102, 146, 254, 167, 221, 123, 50, 97, 222, 217, 120, + 243, 116, 226, 61, 36, 15, 247, 158, 92, 120, 68, 30, 149, 199, 228, 172, 156, 147, 243, + 242, 184, 60, 33, 79, 202, 83, 242, 180, 60, 35, 207, 202, 115, 242, 188, 188, 32, 47, 202, + 75, 242, 178, 188, 34, 175, 202, 107, 242, 186, 188, 33, 111, 202, 91, 242, 182, 188, 35, + 23, 228, 93, 121, 79, 222, 151, 15, 228, 67, 249, 72, 62, 150, 79, 228, 83, 249, 76, 62, + 151, 47, 228, 75, 249, 74, 190, 150, 111, 228, 91, 249, 78, 190, 151, 31, 228, 71, 249, 73, + 126, 150, 95, 228, 87, 185, 40, 191, 201, 37, 249, 93, 46, 203, 31, 114, 69, 254, 148, 171, + 97, 58, 77, 226, 111, 95, 250, 127, 77, 254, 150, 235, 242, 143, 220, 144, 127, 229, 166, + 252, 39, 183, 194, 255, 241, 253, 45, 253, 14, 182, 201, 38, 217, 34, 27, 100, 123, 233, + 230, 242, 241, 155, 217, 20, 91, 98, 67, 108, 135, 205, 176, 21, 54, 194, 54, 216, 4, 91, + 96, 3, 180, 79, 243, 180, 78, 227, 180, 77, 211, 180, 76, 195, 180, 75, 179, 133, 164, 223, + 40, 109, 210, 36, 45, 210, 32, 237, 209, 28, 173, 209, 24, 109, 209, 20, 45, 209, 16, 237, + 208, 12, 173, 208, 8, 109, 208, 4, 45, 208, 0, 119, 207, 157, 115, 215, 220, 113, 49, 238, + 180, 20, 119, 88, 142, 59, 171, 196, 29, 85, 227, 46, 106, 113, 246, 245, 56, 235, 70, 156, + 109, 51, 206, 50, 61, 179, 244, 220, 18, 157, 231, 192, 167, 11, 75, 28, 99, 152, 25, 5, 0, + 0, ]; assert_eq!(bytes, expected_serialization) diff --git a/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs b/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs index 7001c953d63..87a026148d7 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs @@ -1,6 +1,6 @@ use acir::{ circuit::{ - opcodes::{BlackBoxFuncCall, FunctionInput}, + opcodes::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput}, Circuit, Opcode, }, native_types::Witness, @@ -74,7 +74,8 @@ impl RangeOptimizer { } Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { witness, num_bits }, + input: + FunctionInput { input: ConstantOrWitnessEnum::Witness(witness), num_bits }, }) => Some((*witness, *num_bits)), _ => None, @@ -105,14 +106,15 @@ impl RangeOptimizer { let mut new_order_list = Vec::with_capacity(order_list.len()); let mut optimized_opcodes = Vec::with_capacity(self.circuit.opcodes.len()); for (idx, opcode) in self.circuit.opcodes.into_iter().enumerate() { - let (witness, num_bits) = match &opcode { - Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => { - (input.witness, input.num_bits) - } + let (witness, num_bits) = match opcode { + Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: + FunctionInput { input: ConstantOrWitnessEnum::Witness(w), num_bits: bits }, + }) => (w, bits), _ => { // If its not the range opcode, add it to the opcode // list and continue; - optimized_opcodes.push(opcode); + optimized_opcodes.push(opcode.clone()); new_order_list.push(order_list[idx]); continue; } @@ -133,7 +135,7 @@ impl RangeOptimizer { if is_lowest_bit_size { already_seen_witness.insert(witness); new_order_list.push(order_list[idx]); - optimized_opcodes.push(opcode); + optimized_opcodes.push(opcode.clone()); } } @@ -148,7 +150,7 @@ mod tests { use crate::compiler::optimizers::redundant_range::RangeOptimizer; use acir::{ circuit::{ - opcodes::{BlackBoxFuncCall, FunctionInput}, + opcodes::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput}, Circuit, ExpressionWidth, Opcode, PublicInputs, }, native_types::{Expression, Witness}, @@ -158,7 +160,7 @@ mod tests { fn test_circuit(ranges: Vec<(Witness, u32)>) -> Circuit { fn test_range_constraint(witness: Witness, num_bits: u32) -> Opcode { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { witness, num_bits }, + input: FunctionInput::witness(witness, num_bits), }) } @@ -201,7 +203,7 @@ mod tests { assert_eq!( optimized_circuit.opcodes[0], Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { witness: Witness(1), num_bits: 16 } + input: FunctionInput::witness(Witness(1), 16) }) ); } @@ -224,13 +226,13 @@ mod tests { assert_eq!( optimized_circuit.opcodes[0], Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { witness: Witness(1), num_bits: 16 } + input: FunctionInput::witness(Witness(1), 16) }) ); assert_eq!( optimized_circuit.opcodes[1], Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { witness: Witness(2), num_bits: 23 } + input: FunctionInput::witness(Witness(2), 23) }) ); } diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/aes128.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/aes128.rs index 181a78a2a6a..e3c8dc78aa6 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/aes128.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/aes128.rs @@ -11,9 +11,9 @@ use super::utils::{to_u8_array, to_u8_vec}; pub(super) fn solve_aes128_encryption_opcode( initial_witness: &mut WitnessMap, - inputs: &[FunctionInput], - iv: &[FunctionInput; 16], - key: &[FunctionInput; 16], + inputs: &[FunctionInput], + iv: &[FunctionInput; 16], + key: &[FunctionInput; 16], outputs: &[Witness], ) -> Result<(), OpcodeResolutionError> { let scalars = to_u8_vec(initial_witness, inputs)?; diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/bigint.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/bigint.rs index be5a4613a55..1bce4aa6c5e 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/bigint.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/bigint.rs @@ -1,9 +1,9 @@ +use crate::pwg::input_to_value; use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, AcirField, BlackBoxFunc, }; - use acvm_blackbox_solver::BigIntSolver; use crate::pwg::OpcodeResolutionError; @@ -20,14 +20,14 @@ pub(crate) struct AcvmBigIntSolver { impl AcvmBigIntSolver { pub(crate) fn bigint_from_bytes( &mut self, - inputs: &[FunctionInput], + inputs: &[FunctionInput], modulus: &[u8], output: u32, initial_witness: &mut WitnessMap, ) -> Result<(), OpcodeResolutionError> { let bytes = inputs .iter() - .map(|input| initial_witness.get(&input.witness).unwrap().to_u128() as u8) + .map(|input| input_to_value(initial_witness, *input).unwrap().to_u128() as u8) .collect::>(); self.bigint_solver.bigint_from_bytes(&bytes, modulus, output)?; Ok(()) diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs index 411a6d1b737..c290faeaa4a 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs @@ -5,28 +5,28 @@ use acir::{ }; use acvm_blackbox_solver::BlackBoxFunctionSolver; -use crate::pwg::{insert_value, witness_to_value, OpcodeResolutionError}; +use crate::pwg::{input_to_value, insert_value, OpcodeResolutionError}; pub(super) fn multi_scalar_mul( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, - points: &[FunctionInput], - scalars: &[FunctionInput], + points: &[FunctionInput], + scalars: &[FunctionInput], outputs: (Witness, Witness, Witness), ) -> Result<(), OpcodeResolutionError> { let points: Result, _> = - points.iter().map(|input| witness_to_value(initial_witness, input.witness)).collect(); - let points: Vec<_> = points?.into_iter().cloned().collect(); + points.iter().map(|input| input_to_value(initial_witness, *input)).collect(); + let points: Vec<_> = points?.into_iter().collect(); let scalars: Result, _> = - scalars.iter().map(|input| witness_to_value(initial_witness, input.witness)).collect(); + scalars.iter().map(|input| input_to_value(initial_witness, *input)).collect(); let mut scalars_lo = Vec::new(); let mut scalars_hi = Vec::new(); for (i, scalar) in scalars?.into_iter().enumerate() { if i % 2 == 0 { - scalars_lo.push(*scalar); + scalars_lo.push(scalar); } else { - scalars_hi.push(*scalar); + scalars_hi.push(scalar); } } // Call the backend's multi-scalar multiplication function @@ -43,18 +43,24 @@ pub(super) fn multi_scalar_mul( pub(super) fn embedded_curve_add( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, - input1: [FunctionInput; 3], - input2: [FunctionInput; 3], + input1: [FunctionInput; 3], + input2: [FunctionInput; 3], outputs: (Witness, Witness, Witness), ) -> Result<(), OpcodeResolutionError> { - let input1_x = witness_to_value(initial_witness, input1[0].witness)?; - let input1_y = witness_to_value(initial_witness, input1[1].witness)?; - let input1_infinite = witness_to_value(initial_witness, input1[2].witness)?; - let input2_x = witness_to_value(initial_witness, input2[0].witness)?; - let input2_y = witness_to_value(initial_witness, input2[1].witness)?; - let input2_infinite = witness_to_value(initial_witness, input2[2].witness)?; - let (res_x, res_y, res_infinite) = - backend.ec_add(input1_x, input1_y, input1_infinite, input2_x, input2_y, input2_infinite)?; + let input1_x = input_to_value(initial_witness, input1[0])?; + let input1_y = input_to_value(initial_witness, input1[1])?; + let input1_infinite = input_to_value(initial_witness, input1[2])?; + let input2_x = input_to_value(initial_witness, input2[0])?; + let input2_y = input_to_value(initial_witness, input2[1])?; + let input2_infinite = input_to_value(initial_witness, input2[2])?; + let (res_x, res_y, res_infinite) = backend.ec_add( + &input1_x, + &input1_y, + &input1_infinite, + &input2_x, + &input2_y, + &input2_infinite, + )?; insert_value(&outputs.0, res_x, initial_witness)?; insert_value(&outputs.1, res_y, initial_witness)?; diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/hash.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/hash.rs index fe9bd46b091..b51139f76b7 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/hash.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/hash.rs @@ -5,15 +5,15 @@ use acir::{ }; use acvm_blackbox_solver::{sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError}; -use crate::pwg::{insert_value, witness_to_value}; +use crate::pwg::{input_to_value, insert_value}; use crate::OpcodeResolutionError; /// Attempts to solve a 256 bit hash function opcode. /// If successful, `initial_witness` will be mutated to contain the new witness assignment. pub(super) fn solve_generic_256_hash_opcode( initial_witness: &mut WitnessMap, - inputs: &[FunctionInput], - var_message_size: Option<&FunctionInput>, + inputs: &[FunctionInput], + var_message_size: Option<&FunctionInput>, outputs: &[Witness; 32], hash_function: fn(data: &[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, ) -> Result<(), OpcodeResolutionError> { @@ -26,16 +26,15 @@ pub(super) fn solve_generic_256_hash_opcode( /// Reads the hash function input from a [`WitnessMap`]. fn get_hash_input( initial_witness: &WitnessMap, - inputs: &[FunctionInput], - message_size: Option<&FunctionInput>, + inputs: &[FunctionInput], + message_size: Option<&FunctionInput>, ) -> Result, OpcodeResolutionError> { // Read witness assignments. let mut message_input = Vec::new(); for input in inputs.iter() { - let witness = input.witness; - let num_bits = input.num_bits as usize; + let num_bits = input.num_bits() as usize; - let witness_assignment = witness_to_value(initial_witness, witness)?; + let witness_assignment = input_to_value(initial_witness, *input)?; let bytes = witness_assignment.fetch_nearest_bytes(num_bits); message_input.extend(bytes); } @@ -43,8 +42,7 @@ fn get_hash_input( // Truncate the message if there is a `message_size` parameter given match message_size { Some(input) => { - let num_bytes_to_take = - witness_to_value(initial_witness, input.witness)?.to_u128() as usize; + let num_bytes_to_take = input_to_value(initial_witness, *input)?.to_u128() as usize; // If the number of bytes to take is more than the amount of bytes available // in the message, then we error. @@ -76,11 +74,11 @@ fn write_digest_to_outputs( fn to_u32_array( initial_witness: &WitnessMap, - inputs: &[FunctionInput; N], + inputs: &[FunctionInput; N], ) -> Result<[u32; N], OpcodeResolutionError> { let mut result = [0; N]; for (it, input) in result.iter_mut().zip(inputs) { - let witness_value = witness_to_value(initial_witness, input.witness)?; + let witness_value = input_to_value(initial_witness, *input)?; *it = witness_value.to_u128() as u32; } Ok(result) @@ -88,8 +86,8 @@ fn to_u32_array( pub(crate) fn solve_sha_256_permutation_opcode( initial_witness: &mut WitnessMap, - inputs: &[FunctionInput; 16], - hash_values: &[FunctionInput; 8], + inputs: &[FunctionInput; 16], + hash_values: &[FunctionInput; 8], outputs: &[Witness; 8], ) -> Result<(), OpcodeResolutionError> { let message = to_u32_array(initial_witness, inputs)?; @@ -107,7 +105,7 @@ pub(crate) fn solve_sha_256_permutation_opcode( pub(crate) fn solve_poseidon2_permutation_opcode( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, - inputs: &[FunctionInput], + inputs: &[FunctionInput], outputs: &[Witness], len: u32, ) -> Result<(), OpcodeResolutionError> { @@ -135,8 +133,8 @@ pub(crate) fn solve_poseidon2_permutation_opcode( // Read witness assignments let mut state = Vec::new(); for input in inputs.iter() { - let witness_assignment = witness_to_value(initial_witness, input.witness)?; - state.push(*witness_assignment); + let witness_assignment = input_to_value(initial_witness, *input)?; + state.push(witness_assignment); } let state = backend.poseidon2_permutation(&state, len)?; diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/logic.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/logic.rs index 5be888c8ac6..7ce0827d932 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/logic.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/logic.rs @@ -1,4 +1,4 @@ -use crate::pwg::{insert_value, witness_to_value}; +use crate::pwg::{input_to_value, insert_value}; use crate::OpcodeResolutionError; use acir::{ circuit::opcodes::FunctionInput, @@ -11,16 +11,17 @@ use acvm_blackbox_solver::{bit_and, bit_xor}; /// the result into the supplied witness map pub(super) fn and( initial_witness: &mut WitnessMap, - lhs: &FunctionInput, - rhs: &FunctionInput, + lhs: &FunctionInput, + rhs: &FunctionInput, output: &Witness, ) -> Result<(), OpcodeResolutionError> { assert_eq!( - lhs.num_bits, rhs.num_bits, + lhs.num_bits(), + rhs.num_bits(), "number of bits specified for each input must be the same" ); - solve_logic_opcode(initial_witness, &lhs.witness, &rhs.witness, *output, |left, right| { - bit_and(left, right, lhs.num_bits) + solve_logic_opcode(initial_witness, lhs, rhs, *output, |left, right| { + bit_and(left, right, lhs.num_bits()) }) } @@ -28,30 +29,31 @@ pub(super) fn and( /// the result into the supplied witness map pub(super) fn xor( initial_witness: &mut WitnessMap, - lhs: &FunctionInput, - rhs: &FunctionInput, + lhs: &FunctionInput, + rhs: &FunctionInput, output: &Witness, ) -> Result<(), OpcodeResolutionError> { assert_eq!( - lhs.num_bits, rhs.num_bits, + lhs.num_bits(), + rhs.num_bits(), "number of bits specified for each input must be the same" ); - solve_logic_opcode(initial_witness, &lhs.witness, &rhs.witness, *output, |left, right| { - bit_xor(left, right, lhs.num_bits) + solve_logic_opcode(initial_witness, lhs, rhs, *output, |left, right| { + bit_xor(left, right, lhs.num_bits()) }) } /// Derives the rest of the witness based on the initial low level variables fn solve_logic_opcode( initial_witness: &mut WitnessMap, - a: &Witness, - b: &Witness, + a: &FunctionInput, + b: &FunctionInput, result: Witness, logic_op: impl Fn(F, F) -> F, ) -> Result<(), OpcodeResolutionError> { - let w_l_value = witness_to_value(initial_witness, *a)?; - let w_r_value = witness_to_value(initial_witness, *b)?; - let assignment = logic_op(*w_l_value, *w_r_value); + let w_l_value = input_to_value(initial_witness, *a)?; + let w_r_value = input_to_value(initial_witness, *b)?; + let assignment = logic_op(w_l_value, w_r_value); insert_value(&result, assignment, initial_witness) } diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs index 0c65759ebcd..def4216fe15 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/mod.rs @@ -1,5 +1,5 @@ use acir::{ - circuit::opcodes::{BlackBoxFuncCall, FunctionInput}, + circuit::opcodes::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput}, native_types::{Witness, WitnessMap}, AcirField, }; @@ -11,7 +11,7 @@ use self::{ }; use super::{insert_value, OpcodeNotSolvable, OpcodeResolutionError}; -use crate::{pwg::witness_to_value, BlackBoxFunctionSolver}; +use crate::{pwg::input_to_value, BlackBoxFunctionSolver}; mod aes128; pub(crate) mod bigint; @@ -39,26 +39,33 @@ use signature::{ /// Returns the first missing assignment if any are missing fn first_missing_assignment( witness_assignments: &WitnessMap, - inputs: &[FunctionInput], + inputs: &[FunctionInput], ) -> Option { inputs.iter().find_map(|input| { - if witness_assignments.contains_key(&input.witness) { - None + if let ConstantOrWitnessEnum::Witness(witness) = input.input { + if witness_assignments.contains_key(&witness) { + None + } else { + Some(witness) + } } else { - Some(input.witness) + None } }) } /// Check if all of the inputs to the function have assignments -fn contains_all_inputs(witness_assignments: &WitnessMap, inputs: &[FunctionInput]) -> bool { - inputs.iter().all(|input| witness_assignments.contains_key(&input.witness)) +fn contains_all_inputs( + witness_assignments: &WitnessMap, + inputs: &[FunctionInput], +) -> bool { + first_missing_assignment(witness_assignments, inputs).is_none() } pub(crate) fn solve( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, - bb_func: &BlackBoxFuncCall, + bb_func: &BlackBoxFuncCall, bigint_solver: &mut AcvmBigIntSolver, ) -> Result<(), OpcodeResolutionError> { let inputs = bb_func.get_inputs_vec(); @@ -99,10 +106,9 @@ pub(crate) fn solve( BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => { let mut state = [0; 25]; for (it, input) in state.iter_mut().zip(inputs.as_ref()) { - let witness = input.witness; - let num_bits = input.num_bits as usize; + let num_bits = input.num_bits() as usize; assert_eq!(num_bits, 64); - let witness_assignment = witness_to_value(initial_witness, witness)?; + let witness_assignment = input_to_value(initial_witness, *input)?; let lane = witness_assignment.try_to_u64(); *it = lane.unwrap(); } diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs index f64a3a79465..b1b95393b19 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs @@ -5,20 +5,20 @@ use acir::{ }; use crate::{ - pwg::{insert_value, witness_to_value, OpcodeResolutionError}, + pwg::{input_to_value, insert_value, OpcodeResolutionError}, BlackBoxFunctionSolver, }; pub(super) fn pedersen( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, - inputs: &[FunctionInput], + inputs: &[FunctionInput], domain_separator: u32, outputs: (Witness, Witness), ) -> Result<(), OpcodeResolutionError> { let scalars: Result, _> = - inputs.iter().map(|input| witness_to_value(initial_witness, input.witness)).collect(); - let scalars: Vec<_> = scalars?.into_iter().cloned().collect(); + inputs.iter().map(|input| input_to_value(initial_witness, *input)).collect(); + let scalars: Vec<_> = scalars?.into_iter().collect(); let (res_x, res_y) = backend.pedersen_commitment(&scalars, domain_separator)?; @@ -31,13 +31,13 @@ pub(super) fn pedersen( pub(super) fn pedersen_hash( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, - inputs: &[FunctionInput], + inputs: &[FunctionInput], domain_separator: u32, output: Witness, ) -> Result<(), OpcodeResolutionError> { let scalars: Result, _> = - inputs.iter().map(|input| witness_to_value(initial_witness, input.witness)).collect(); - let scalars: Vec<_> = scalars?.into_iter().cloned().collect(); + inputs.iter().map(|input| input_to_value(initial_witness, *input)).collect(); + let scalars: Vec<_> = scalars?.into_iter().collect(); let res = backend.pedersen_hash(&scalars, domain_separator)?; diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/range.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/range.rs index 0ca001aff7a..054730bb6c0 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/range.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/range.rs @@ -1,15 +1,15 @@ use crate::{ - pwg::{witness_to_value, ErrorLocation}, + pwg::{input_to_value, ErrorLocation}, OpcodeResolutionError, }; use acir::{circuit::opcodes::FunctionInput, native_types::WitnessMap, AcirField}; pub(crate) fn solve_range_opcode( initial_witness: &WitnessMap, - input: &FunctionInput, + input: &FunctionInput, ) -> Result<(), OpcodeResolutionError> { - let w_value = witness_to_value(initial_witness, input.witness)?; - if w_value.num_bits() > input.num_bits { + let w_value = input_to_value(initial_witness, *input)?; + if w_value.num_bits() > input.num_bits() { return Err(OpcodeResolutionError::UnsatisfiedConstrain { opcode_location: ErrorLocation::Unresolved, payload: None, diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs index 707e3f26af0..db92d27b871 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs @@ -15,10 +15,10 @@ use crate::{ pub(crate) fn secp256k1_prehashed( initial_witness: &mut WitnessMap, - public_key_x_inputs: &[FunctionInput; 32], - public_key_y_inputs: &[FunctionInput; 32], - signature_inputs: &[FunctionInput; 64], - hashed_message_inputs: &[FunctionInput], + public_key_x_inputs: &[FunctionInput; 32], + public_key_y_inputs: &[FunctionInput; 32], + signature_inputs: &[FunctionInput; 64], + hashed_message_inputs: &[FunctionInput], output: Witness, ) -> Result<(), OpcodeResolutionError> { let hashed_message = to_u8_vec(initial_witness, hashed_message_inputs)?; @@ -34,10 +34,10 @@ pub(crate) fn secp256k1_prehashed( pub(crate) fn secp256r1_prehashed( initial_witness: &mut WitnessMap, - public_key_x_inputs: &[FunctionInput; 32], - public_key_y_inputs: &[FunctionInput; 32], - signature_inputs: &[FunctionInput; 64], - hashed_message_inputs: &[FunctionInput], + public_key_x_inputs: &[FunctionInput; 32], + public_key_y_inputs: &[FunctionInput; 32], + signature_inputs: &[FunctionInput; 64], + hashed_message_inputs: &[FunctionInput], output: Witness, ) -> Result<(), OpcodeResolutionError> { let hashed_message = to_u8_vec(initial_witness, hashed_message_inputs)?; diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs index 5e0ac94f8be..4f8e88373ba 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs @@ -1,7 +1,7 @@ use crate::{ pwg::{ blackbox::utils::{to_u8_array, to_u8_vec}, - insert_value, witness_to_value, OpcodeResolutionError, + input_to_value, insert_value, OpcodeResolutionError, }, BlackBoxFunctionSolver, }; @@ -15,14 +15,14 @@ use acir::{ pub(crate) fn schnorr_verify( backend: &impl BlackBoxFunctionSolver, initial_witness: &mut WitnessMap, - public_key_x: FunctionInput, - public_key_y: FunctionInput, - signature: &[FunctionInput; 64], - message: &[FunctionInput], + public_key_x: FunctionInput, + public_key_y: FunctionInput, + signature: &[FunctionInput; 64], + message: &[FunctionInput], output: Witness, ) -> Result<(), OpcodeResolutionError> { - let public_key_x: &F = witness_to_value(initial_witness, public_key_x.witness)?; - let public_key_y: &F = witness_to_value(initial_witness, public_key_y.witness)?; + let public_key_x: &F = &input_to_value(initial_witness, public_key_x)?; + let public_key_y: &F = &input_to_value(initial_witness, public_key_y)?; let signature = to_u8_array(initial_witness, signature)?; let message = to_u8_vec(initial_witness, message)?; diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/utils.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/utils.rs index 6880d21a324..9b9157421e5 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/utils.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/blackbox/utils.rs @@ -1,14 +1,14 @@ use acir::{circuit::opcodes::FunctionInput, native_types::WitnessMap, AcirField}; -use crate::pwg::{witness_to_value, OpcodeResolutionError}; +use crate::pwg::{input_to_value, OpcodeResolutionError}; pub(crate) fn to_u8_array( initial_witness: &WitnessMap, - inputs: &[FunctionInput; N], + inputs: &[FunctionInput; N], ) -> Result<[u8; N], OpcodeResolutionError> { let mut result = [0; N]; for (it, input) in result.iter_mut().zip(inputs) { - let witness_value_bytes = witness_to_value(initial_witness, input.witness)?.to_be_bytes(); + let witness_value_bytes = input_to_value(initial_witness, *input)?.to_be_bytes(); let byte = witness_value_bytes .last() .expect("Field element must be represented by non-zero amount of bytes"); @@ -19,11 +19,11 @@ pub(crate) fn to_u8_array( pub(crate) fn to_u8_vec( initial_witness: &WitnessMap, - inputs: &[FunctionInput], + inputs: &[FunctionInput], ) -> Result, OpcodeResolutionError> { let mut result = Vec::with_capacity(inputs.len()); for input in inputs { - let witness_value_bytes = witness_to_value(initial_witness, input.witness)?.to_be_bytes(); + let witness_value_bytes = input_to_value(initial_witness, *input)?.to_be_bytes(); let byte = witness_value_bytes .last() .expect("Field element must be represented by non-zero amount of bytes"); diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs index 4f88e17d109..4292d72fad5 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/mod.rs @@ -5,9 +5,10 @@ use std::collections::HashMap; use acir::{ brillig::ForeignCallResult, circuit::{ - brillig::BrilligBytecode, opcodes::BlockId, AssertionPayload, ErrorSelector, - ExpressionOrMemory, Opcode, OpcodeLocation, RawAssertionPayload, ResolvedAssertionPayload, - STRING_ERROR_SELECTOR, + brillig::BrilligBytecode, + opcodes::{BlockId, ConstantOrWitnessEnum, FunctionInput}, + AssertionPayload, ErrorSelector, ExpressionOrMemory, Opcode, OpcodeLocation, + RawAssertionPayload, ResolvedAssertionPayload, STRING_ERROR_SELECTOR, }, native_types::{Expression, Witness, WitnessMap}, AcirField, BlackBoxFunc, @@ -629,6 +630,16 @@ pub fn witness_to_value( } } +pub fn input_to_value( + initial_witness: &WitnessMap, + input: FunctionInput, +) -> Result> { + match input.input { + ConstantOrWitnessEnum::Witness(witness) => Ok(*witness_to_value(initial_witness, witness)?), + ConstantOrWitnessEnum::Constant(value) => Ok(value), + } +} + // TODO: There is an issue open to decide on whether we need to get values from Expressions // TODO versus just getting values from Witness pub fn get_value( diff --git a/noir/noir-repo/acvm-repo/acvm_js/test/shared/multi_scalar_mul.ts b/noir/noir-repo/acvm-repo/acvm_js/test/shared/multi_scalar_mul.ts index 5401da76974..80fbf14e8f1 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/test/shared/multi_scalar_mul.ts +++ b/noir/noir-repo/acvm-repo/acvm_js/test/shared/multi_scalar_mul.ts @@ -1,8 +1,8 @@ // See `multi_scalar_mul_circuit` integration test in `acir/tests/test_program_serialization.rs`. export const bytecode = Uint8Array.from([ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 141, 219, 10, 0, 32, 8, 67, 243, 214, 5, 250, 232, 62, 189, 69, 123, 176, 132, - 195, 116, 50, 149, 114, 107, 0, 97, 127, 116, 2, 75, 243, 2, 74, 53, 122, 202, 189, 211, 15, 106, 5, 13, 116, 238, 35, - 221, 81, 230, 61, 249, 37, 253, 250, 179, 79, 109, 218, 22, 67, 227, 173, 0, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 141, 11, 10, 0, 32, 8, 67, 43, 181, 15, 116, 232, 142, 158, 210, 130, 149, 240, + 112, 234, 212, 156, 78, 12, 39, 67, 71, 158, 142, 80, 29, 44, 228, 66, 90, 168, 119, 189, 74, 115, 131, 174, 78, 115, + 58, 124, 70, 254, 130, 59, 74, 253, 68, 255, 255, 221, 39, 54, 221, 93, 91, 132, 193, 0, 0, 0, ]); export const initialWitnessMap = new Map([ [1, '0x0000000000000000000000000000000000000000000000000000000000000001'], diff --git a/noir/noir-repo/acvm-repo/acvm_js/test/shared/schnorr_verify.ts b/noir/noir-repo/acvm-repo/acvm_js/test/shared/schnorr_verify.ts index 05fcc47e3aa..c071c86f61f 100644 --- a/noir/noir-repo/acvm-repo/acvm_js/test/shared/schnorr_verify.ts +++ b/noir/noir-repo/acvm-repo/acvm_js/test/shared/schnorr_verify.ts @@ -1,17 +1,19 @@ // See `schnorr_verify_circuit` integration test in `acir/tests/test_program_serialization.rs`. export const bytecode = Uint8Array.from([ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 85, 210, 85, 78, 67, 81, 24, 133, 209, 226, 238, 238, 238, 238, 238, 165, 148, 82, - 102, 193, 252, 135, 64, 232, 78, 87, 147, 114, 147, 147, 5, 47, 132, 252, 251, 107, 41, 212, 191, 159, 218, 107, 241, - 115, 236, 226, 111, 237, 181, 178, 173, 246, 186, 107, 175, 157, 29, 236, 100, 23, 27, 175, 135, 189, 236, 99, 63, 7, - 56, 200, 33, 14, 115, 132, 163, 28, 227, 56, 39, 56, 201, 41, 78, 115, 134, 179, 156, 227, 60, 23, 184, 200, 37, 46, - 115, 133, 171, 92, 227, 58, 55, 184, 201, 45, 110, 115, 135, 187, 220, 227, 62, 15, 120, 200, 35, 30, 243, 132, 167, - 60, 227, 57, 47, 120, 201, 43, 94, 243, 134, 183, 188, 227, 61, 31, 248, 200, 39, 62, 243, 133, 175, 77, 59, 230, 123, - 243, 123, 145, 239, 44, 241, 131, 101, 126, 178, 194, 47, 86, 249, 237, 239, 86, 153, 238, 210, 92, 122, 75, 107, 233, - 44, 141, 53, 250, 234, 241, 191, 164, 167, 180, 148, 142, 210, 80, 250, 73, 59, 233, 38, 205, 164, 151, 180, 146, 78, - 210, 72, 250, 72, 27, 233, 34, 77, 164, 135, 180, 144, 14, 210, 64, 246, 95, 46, 212, 119, 207, 230, 217, 59, 91, 103, - 231, 108, 156, 125, 183, 237, 186, 107, 207, 125, 59, 30, 218, 239, 216, 110, 167, 246, 58, 183, 211, 165, 125, 174, - 237, 114, 107, 143, 123, 59, 60, 186, 255, 179, 187, 191, 186, 115, 209, 125, 75, 238, 90, 118, 207, 138, 59, 54, 110, - 214, 184, 91, 161, 233, 158, 255, 190, 63, 71, 59, 68, 130, 233, 3, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 85, 211, 103, 78, 2, 81, 24, 70, 225, 193, 6, 216, 123, 47, 216, 123, 239, 136, + 136, 136, 136, 136, 187, 96, 255, 75, 32, 112, 194, 55, 201, 129, 100, 50, 79, 244, 7, 228, 222, 243, 102, 146, 254, + 167, 221, 123, 50, 97, 222, 217, 120, 243, 116, 226, 61, 36, 15, 247, 158, 92, 120, 68, 30, 149, 199, 228, 172, 156, + 147, 243, 242, 184, 60, 33, 79, 202, 83, 242, 180, 60, 35, 207, 202, 115, 242, 188, 188, 32, 47, 202, 75, 242, 178, + 188, 34, 175, 202, 107, 242, 186, 188, 33, 111, 202, 91, 242, 182, 188, 35, 23, 228, 93, 121, 79, 222, 151, 15, 228, + 67, 249, 72, 62, 150, 79, 228, 83, 249, 76, 62, 151, 47, 228, 75, 249, 74, 190, 150, 111, 228, 91, 249, 78, 190, 151, + 31, 228, 71, 249, 73, 126, 150, 95, 228, 87, 185, 40, 191, 201, 37, 249, 93, 46, 203, 31, 114, 69, 254, 148, 171, 97, + 58, 77, 226, 111, 95, 250, 127, 77, 254, 150, 235, 242, 143, 220, 144, 127, 229, 166, 252, 39, 183, 194, 255, 241, + 253, 45, 253, 14, 182, 201, 38, 217, 34, 27, 100, 123, 233, 230, 242, 241, 155, 217, 20, 91, 98, 67, 108, 135, 205, + 176, 21, 54, 194, 54, 216, 4, 91, 96, 3, 180, 79, 243, 180, 78, 227, 180, 77, 211, 180, 76, 195, 180, 75, 179, 133, + 164, 223, 40, 109, 210, 36, 45, 210, 32, 237, 209, 28, 173, 209, 24, 109, 209, 20, 45, 209, 16, 237, 208, 12, 173, + 208, 8, 109, 208, 4, 45, 208, 0, 119, 207, 157, 115, 215, 220, 113, 49, 238, 180, 20, 119, 88, 142, 59, 171, 196, 29, + 85, 227, 46, 106, 113, 246, 245, 56, 235, 70, 156, 109, 51, 206, 50, 61, 179, 244, 220, 18, 157, 231, 192, 167, 11, + 75, 28, 99, 152, 25, 5, 0, 0, ]); export const initialWitnessMap = new Map([ diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 928a7b105ea..74149af25ef 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -1366,9 +1366,10 @@ impl AcirContext { } _ => (vec![], vec![]), }; - + // Allow constant inputs only for MSM for now + let allow_constant_inputs = name.eq(&BlackBoxFunc::MultiScalarMul); // Convert `AcirVar` to `FunctionInput` - let inputs = self.prepare_inputs_for_black_box_func_call(inputs)?; + let inputs = self.prepare_inputs_for_black_box_func_call(inputs, allow_constant_inputs)?; // Call Black box with `FunctionInput` let mut results = vecmap(&constant_outputs, |c| self.add_constant(*c)); let outputs = self.acir_ir.call_black_box( @@ -1396,18 +1397,23 @@ impl AcirContext { fn prepare_inputs_for_black_box_func_call( &mut self, inputs: Vec, - ) -> Result>, RuntimeError> { + allow_constant_inputs: bool, + ) -> Result>>, RuntimeError> { let mut witnesses = Vec::new(); for input in inputs { let mut single_val_witnesses = Vec::new(); for (input, typ) in self.flatten(input)? { - // Intrinsics only accept Witnesses. This is not a limitation of the - // intrinsics, its just how we have defined things. Ideally, we allow - // constants too. - let witness_var = self.get_or_create_witness_var(input)?; - let witness = self.var_to_witness(witness_var)?; let num_bits = typ.bit_size::(); - single_val_witnesses.push(FunctionInput { witness, num_bits }); + match self.vars[&input].as_constant() { + Some(constant) if allow_constant_inputs => { + single_val_witnesses.push(FunctionInput::constant(*constant, num_bits)); + } + _ => { + let witness_var = self.get_or_create_witness_var(input)?; + let witness = self.var_to_witness(witness_var)?; + single_val_witnesses.push(FunctionInput::witness(witness, num_bits)); + } + } } witnesses.push(single_val_witnesses); } @@ -1896,10 +1902,10 @@ impl AcirContext { output_count: usize, predicate: AcirVar, ) -> Result, RuntimeError> { - let inputs = self.prepare_inputs_for_black_box_func_call(inputs)?; + let inputs = self.prepare_inputs_for_black_box_func_call(inputs, false)?; let inputs = inputs .iter() - .flat_map(|input| vecmap(input, |input| input.witness)) + .flat_map(|input| vecmap(input, |input| input.to_witness())) .collect::>(); let outputs = vecmap(0..output_count, |_| self.acir_ir.next_witness_index()); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index bcccfac8950..9d29d1d24d6 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -165,7 +165,7 @@ impl GeneratedAcir { pub(crate) fn call_black_box( &mut self, func_name: BlackBoxFunc, - inputs: &[Vec], + inputs: &[Vec>], constant_inputs: Vec, constant_outputs: Vec, output_count: usize, @@ -571,7 +571,7 @@ impl GeneratedAcir { }; let constraint = AcirOpcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { witness, num_bits }, + input: FunctionInput::witness(witness, num_bits), }); self.push_opcode(constraint); diff --git a/noir/noir-repo/tooling/fuzzer/src/dictionary/mod.rs b/noir/noir-repo/tooling/fuzzer/src/dictionary/mod.rs index bf2ab87be29..a45b9c3abb2 100644 --- a/noir/noir-repo/tooling/fuzzer/src/dictionary/mod.rs +++ b/noir/noir-repo/tooling/fuzzer/src/dictionary/mod.rs @@ -10,7 +10,7 @@ use acvm::{ circuit::{ brillig::{BrilligBytecode, BrilligInputs}, directives::Directive, - opcodes::{BlackBoxFuncCall, FunctionInput}, + opcodes::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput}, Circuit, Opcode, Program, }, native_types::Expression, @@ -84,7 +84,15 @@ fn build_dictionary_from_circuit(circuit: &Circuit) -> HashSet< } Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput { num_bits, .. }, + input: FunctionInput { input: ConstantOrWitnessEnum::Constant(c), num_bits }, + }) => { + let field = 1u128.wrapping_shl(*num_bits); + constants.insert(F::from(field)); + constants.insert(F::from(field - 1)); + constants.insert(*c); + } + Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: FunctionInput { input: ConstantOrWitnessEnum::Witness(_), num_bits }, }) => { let field = 1u128.wrapping_shl(*num_bits); constants.insert(F::from(field)); diff --git a/noir/noir-repo/tooling/profiler/src/opcode_formatter.rs b/noir/noir-repo/tooling/profiler/src/opcode_formatter.rs index aba92c95d85..a33de42a0ff 100644 --- a/noir/noir-repo/tooling/profiler/src/opcode_formatter.rs +++ b/noir/noir-repo/tooling/profiler/src/opcode_formatter.rs @@ -1,6 +1,6 @@ use acir::circuit::{directives::Directive, opcodes::BlackBoxFuncCall, Opcode}; -fn format_blackbox_function(call: &BlackBoxFuncCall) -> String { +fn format_blackbox_function(call: &BlackBoxFuncCall) -> String { match call { BlackBoxFuncCall::AES128Encrypt { .. } => "aes128_encrypt".to_string(), BlackBoxFuncCall::AND { .. } => "and".to_string(),