From 3247058d2e54e1fb84680ad9fcf9ece611235e96 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Mon, 29 Jul 2024 19:54:29 +0200 Subject: [PATCH] chore: constant inputs for most blackboxes (#7613) blackbox can use constant inputs: aes128 keccakf1600 sha256compression ecoperations blake2s Blake3s logic sha256compression The following have not been modified to use constant inputs: sha256 and keccak256 because they are deprecated. block, recursion and honk recursion because I am not sure they support it. shnorr and ecdsa, because they might be replaced. --- .../dsl/acir_format/acir_format.test.cpp | 32 ++++- .../acir_format/acir_to_constraint_buf.cpp | 119 ++++-------------- .../dsl/acir_format/aes128_constraint.cpp | 43 ++++--- .../dsl/acir_format/aes128_constraint.hpp | 8 +- .../dsl/acir_format/blake2s_constraint.cpp | 4 +- .../dsl/acir_format/blake2s_constraint.hpp | 5 +- .../dsl/acir_format/blake3_constraint.cpp | 4 +- .../dsl/acir_format/blake3_constraint.hpp | 5 +- .../dsl/acir_format/ec_operations.cpp | 41 ++++-- .../dsl/acir_format/ec_operations.hpp | 13 +- .../dsl/acir_format/ec_operations.test.cpp | 32 ++--- .../honk_recursion_constraint.test.cpp | 4 +- .../dsl/acir_format/keccak_constraint.cpp | 2 +- .../dsl/acir_format/keccak_constraint.hpp | 3 +- .../dsl/acir_format/logic_constraint.cpp | 16 +-- .../dsl/acir_format/logic_constraint.hpp | 17 ++- .../dsl/acir_format/multi_scalar_mul.cpp | 18 +-- .../dsl/acir_format/multi_scalar_mul.hpp | 13 +- .../dsl/acir_format/poseidon2_constraint.cpp | 2 +- .../dsl/acir_format/poseidon2_constraint.hpp | 3 +- .../acir_format/poseidon2_constraint.test.cpp | 8 +- .../acir_format/recursion_constraint.test.cpp | 4 +- .../dsl/acir_format/sha256_constraint.cpp | 8 +- .../dsl/acir_format/sha256_constraint.hpp | 5 +- .../acir_format/sha256_constraint.test.cpp | 8 +- .../dsl/acir_format/witness_constant.hpp | 33 +++++ 26 files changed, 227 insertions(+), 223 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/dsl/acir_format/witness_constant.hpp diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp index 6085671fb0a..511228ac251 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.test.cpp @@ -107,8 +107,8 @@ TEST_F(AcirFormatTests, TestLogicGateFromNoirCircuit) }; LogicConstraint logic_constraint{ - .a = 0, - .b = 1, + .a = WitnessOrConstant::from_index(0), + .b = WitnessOrConstant::from_index(1), .result = 2, .num_bits = 32, .is_xor_gate = 1, @@ -510,7 +510,33 @@ TEST_F(AcirFormatTests, TestKeccakPermutation) { Keccakf1600 keccak_permutation{ - .state = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 }, + .state = { + WitnessOrConstant::from_index(1), + WitnessOrConstant::from_index(2), + WitnessOrConstant::from_index(3), + WitnessOrConstant::from_index(4), + WitnessOrConstant::from_index(5), + WitnessOrConstant::from_index(6), + WitnessOrConstant::from_index(7), + WitnessOrConstant::from_index(8), + WitnessOrConstant::from_index(9), + WitnessOrConstant::from_index(10), + WitnessOrConstant::from_index(11), + WitnessOrConstant::from_index(12), + WitnessOrConstant::from_index(13), + WitnessOrConstant::from_index(14), + WitnessOrConstant::from_index(15), + WitnessOrConstant::from_index(16), + WitnessOrConstant::from_index(17), + WitnessOrConstant::from_index(18), + WitnessOrConstant::from_index(19), + WitnessOrConstant::from_index(20), + WitnessOrConstant::from_index(21), + WitnessOrConstant::from_index(22), + WitnessOrConstant::from_index(23), + WitnessOrConstant::from_index(24), + WitnessOrConstant::from_index(25), + }, .result = { 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50 }, }; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp index 2a7882a2080..d50bcd90495 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_to_constraint_buf.cpp @@ -195,19 +195,19 @@ uint32_t get_witness_from_function_input(Program::FunctionInput input) return input_witness.value.value; } -WitnessConstant parse_input(Program::FunctionInput input) +WitnessOrConstant parse_input(Program::FunctionInput input) { - WitnessConstant result = std::visit( + WitnessOrConstant result = std::visit( [&](auto&& e) { using T = std::decay_t; if constexpr (std::is_same_v) { - return WitnessConstant{ + return WitnessOrConstant{ .index = e.value.value, .value = bb::fr::zero(), .is_constant = false, }; } else if constexpr (std::is_same_v) { - return WitnessConstant{ + return WitnessOrConstant{ .index = 0, .value = uint256_t(e.value), .is_constant = true, @@ -215,7 +215,7 @@ WitnessConstant parse_input(Program::FunctionInput input) } else { ASSERT(false); } - return WitnessConstant{ + return WitnessOrConstant{ .index = 0, .value = bb::fr::zero(), .is_constant = true, @@ -223,33 +223,6 @@ WitnessConstant parse_input(Program::FunctionInput input) }, input.input.value); return result; - - // WitnessConstant result = std::visit( - // [&](auto&& e) { - // using T = std::decay_t; - // if constexpr (std::is_same_v) { - // return WitnessConstant{ - // .index = e.value.witness.value, - // .value = bb::fr::zero(), - // .is_constant = false, - // }; - // } else if constexpr (std::is_same_v) { - // return WitnessConstant{ - // .index = 0, - // .value = uint256_t(e.value.constant), - // .is_constant = true, - // }; - // } else { - // ASSERT(false); - // } - // return WitnessConstant{ - // .index = 0, - // .value = bb::fr::zero(), - // .is_constant = true, - // }; - // }, - // input.value); - // return result; } void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, @@ -261,8 +234,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, [&](auto&& arg) { using T = std::decay_t; if constexpr (std::is_same_v) { - auto lhs_input = get_witness_from_function_input(arg.lhs); - auto rhs_input = get_witness_from_function_input(arg.rhs); + auto lhs_input = parse_input(arg.lhs); + auto rhs_input = parse_input(arg.rhs); af.logic_constraints.push_back(LogicConstraint{ .a = lhs_input, .b = rhs_input, @@ -272,8 +245,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, }); af.original_opcode_indices.logic_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { - auto lhs_input = get_witness_from_function_input(arg.lhs); - auto rhs_input = get_witness_from_function_input(arg.rhs); + auto lhs_input = parse_input(arg.lhs); + auto rhs_input = parse_input(arg.rhs); af.logic_constraints.push_back(LogicConstraint{ .a = lhs_input, .b = rhs_input, @@ -292,29 +265,9 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, } else if constexpr (std::is_same_v) { af.aes128_constraints.push_back(AES128Constraint{ - .inputs = map(arg.inputs, - [](auto& e) { - return AES128Input{ - .witness = get_witness_from_function_input(e), - .num_bits = e.num_bits, - }; - }), - .iv = map(arg.iv, - [](auto& e) { - auto witness = get_witness_from_function_input(e); - return AES128Input{ - .witness = witness, - .num_bits = e.num_bits, - }; - }), - .key = map(arg.key, - [](auto& e) { - auto input_witness = get_witness_from_function_input(e); - return AES128Input{ - .witness = input_witness, - .num_bits = e.num_bits, - }; - }), + .inputs = map(arg.inputs, [](auto& e) { return parse_input(e); }), + .iv = map(arg.iv, [](auto& e) { return parse_input(e); }), + .key = map(arg.key, [](auto& e) { return parse_input(e); }), .outputs = map(arg.outputs, [](auto& e) { return e.value; }), }); af.original_opcode_indices.aes128_constraints.push_back(opcode_index); @@ -335,22 +288,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, } else if constexpr (std::is_same_v) { af.sha256_compression.push_back(Sha256Compression{ - .inputs = map(arg.inputs, - [](auto& e) { - auto input_witness = get_witness_from_function_input(e); - return Sha256Input{ - .witness = input_witness, - .num_bits = e.num_bits, - }; - }), - .hash_values = map(arg.hash_values, - [](auto& e) { - auto input_witness = get_witness_from_function_input(e); - return Sha256Input{ - .witness = input_witness, - .num_bits = e.num_bits, - }; - }), + .inputs = map(arg.inputs, [](auto& e) { return parse_input(e); }), + .hash_values = map(arg.hash_values, [](auto& e) { return parse_input(e); }), .result = map(arg.outputs, [](auto& e) { return e.value; }), }); af.original_opcode_indices.sha256_compression.push_back(opcode_index); @@ -358,9 +297,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.blake2s_constraints.push_back(Blake2sConstraint{ .inputs = map(arg.inputs, [](auto& e) { - auto input_witness = get_witness_from_function_input(e); return Blake2sInput{ - .witness = input_witness, + .blackbox_input = parse_input(e), .num_bits = e.num_bits, }; }), @@ -371,9 +309,8 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.blake3_constraints.push_back(Blake3Constraint{ .inputs = map(arg.inputs, [](auto& e) { - auto input_witness = get_witness_from_function_input(e); return Blake3Input{ - .witness = input_witness, + .blackbox_input = parse_input(e), .num_bits = e.num_bits, }; }), @@ -437,12 +374,12 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, }); af.original_opcode_indices.multi_scalar_mul_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { - auto input_1_x = get_witness_from_function_input(arg.input1[0]); - auto input_1_y = get_witness_from_function_input(arg.input1[1]); - auto input_1_infinite = get_witness_from_function_input(arg.input1[2]); - auto input_2_x = get_witness_from_function_input(arg.input2[0]); - auto input_2_y = get_witness_from_function_input(arg.input2[1]); - auto input_2_infinite = get_witness_from_function_input(arg.input2[2]); + auto input_1_x = parse_input(arg.input1[0]); + auto input_1_y = parse_input(arg.input1[1]); + auto input_1_infinite = parse_input(arg.input1[2]); + auto input_2_x = parse_input(arg.input2[0]); + auto input_2_y = parse_input(arg.input2[1]); + auto input_2_infinite = parse_input(arg.input2[2]); af.ec_add_constraints.push_back(EcAdd{ .input1_x = input_1_x, @@ -473,11 +410,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.original_opcode_indices.keccak_constraints.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.keccak_permutations.push_back(Keccakf1600{ - .state = map(arg.inputs, - [](auto& e) { - auto input_witness = get_witness_from_function_input(e); - return input_witness; - }), + .state = map(arg.inputs, [](auto& e) { return parse_input(e); }), .result = map(arg.outputs, [](auto& e) { return e.value; }), }); af.original_opcode_indices.keccak_permutations.push_back(opcode_index); @@ -551,11 +484,7 @@ void handle_blackbox_func_call(Program::Opcode::BlackBoxFuncCall const& arg, af.original_opcode_indices.bigint_operations.push_back(opcode_index); } else if constexpr (std::is_same_v) { af.poseidon2_constraints.push_back(Poseidon2Constraint{ - .state = map(arg.inputs, - [](auto& e) { - auto input_witness = get_witness_from_function_input(e); - return input_witness; - }), + .state = map(arg.inputs, [](auto& e) { return parse_input(e); }), .result = map(arg.outputs, [](auto& e) { return e.value; }), .len = arg.len, }); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.cpp index c835322b52c..1605cf8eaf7 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.cpp @@ -1,4 +1,5 @@ #include "aes128_constraint.hpp" +#include "barretenberg/dsl/acir_format/acir_format.hpp" #include "barretenberg/stdlib/encryption/aes128/aes128.hpp" #include #include @@ -14,20 +15,21 @@ template void create_aes128_constraints(Builder& builder, con using field_ct = bb::stdlib::field_t; // Packs 16 bytes from the inputs (plaintext, iv, key) into a field element - const auto convert_input = [&](std::span inputs, size_t padding) { - field_ct converted = 0; - for (size_t i = 0; i < 16 - padding; ++i) { - converted *= 256; - field_ct byte = field_ct::from_witness_index(&builder, inputs[i].witness); - converted += byte; - } - for (size_t i = 0; i < padding; ++i) { - converted *= 256; - field_ct byte = padding; - converted += byte; - } - return converted; - }; + const auto convert_input = + [&](std::span, std::dynamic_extent> inputs, size_t padding, Builder& builder) { + field_ct converted = 0; + for (size_t i = 0; i < 16 - padding; ++i) { + converted *= 256; + field_ct byte = to_field_ct(inputs[i], builder); + converted += byte; + } + for (size_t i = 0; i < padding; ++i) { + converted *= 256; + field_ct byte = padding; + converted += byte; + } + return converted; + }; // Packs 16 bytes from the outputs (witness indexes) into a field element for comparison const auto convert_output = [&](std::span outputs) { @@ -47,11 +49,14 @@ template void create_aes128_constraints(Builder& builder, con for (size_t i = 0; i < constraint.inputs.size(); i += 16) { field_ct to_add; if (i + 16 > constraint.inputs.size()) { - to_add = convert_input( - std::span{ &constraint.inputs[i], 16 - padding_size }, - padding_size); + to_add = + convert_input(std::span, std::dynamic_extent>{ &constraint.inputs[i], + 16 - padding_size }, + padding_size, + builder); } else { - to_add = convert_input(std::span{ &constraint.inputs[i], 16 }, 0); + to_add = + convert_input(std::span, 16>{ &constraint.inputs[i], 16 }, 0, builder); } converted_inputs.emplace_back(to_add); } @@ -63,7 +68,7 @@ template void create_aes128_constraints(Builder& builder, con } const std::vector output_bytes = bb::stdlib::aes128::encrypt_buffer_cbc( - converted_inputs, convert_input(constraint.iv, 0), convert_input(constraint.key, 0)); + converted_inputs, convert_input(constraint.iv, 0, builder), convert_input(constraint.key, 0, builder)); for (size_t i = 0; i < output_bytes.size(); ++i) { builder.assert_equal(output_bytes[i].normalize().witness_index, converted_outputs[i].normalize().witness_index); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.hpp index 3149210a8be..fe3a78c058f 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/aes128_constraint.hpp @@ -1,5 +1,7 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" #include #include #include @@ -16,9 +18,9 @@ struct AES128Input { }; struct AES128Constraint { - std::vector inputs; - std::array iv; - std::array key; + std::vector> inputs; + std::array, 16> iv; + std::array, 16> key; std::vector outputs; // For serialization, update with any new fields diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.cpp index f38b1e2b0dc..00978243476 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.cpp @@ -18,13 +18,13 @@ template void create_blake2s_constraints(Builder& builder, co // Get the witness assignment for each witness index // Write the witness assignment to the byte_array for (const auto& witness_index_num_bits : constraint.inputs) { - auto witness_index = witness_index_num_bits.witness; + auto witness_index = witness_index_num_bits.blackbox_input; auto num_bits = witness_index_num_bits.num_bits; // XXX: The implementation requires us to truncate the element to the nearest byte and not bit auto num_bytes = round_to_nearest_byte(num_bits); - field_ct element = field_ct::from_witness_index(&builder, witness_index); + field_ct element = to_field_ct(witness_index, builder); byte_array_ct element_bytes(element, num_bytes); arr.write(element_bytes); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.hpp index d0fd67eaa55..10f10f69af2 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake2s_constraint.hpp @@ -1,4 +1,5 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" #include #include @@ -7,11 +8,11 @@ namespace acir_format { struct Blake2sInput { - uint32_t witness; + WitnessOrConstant blackbox_input; uint32_t num_bits; // For serialization, update with any new fields - MSGPACK_FIELDS(witness, num_bits); + MSGPACK_FIELDS(blackbox_input, num_bits); friend bool operator==(Blake2sInput const& lhs, Blake2sInput const& rhs) = default; }; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.cpp index 71fc56301a1..94cb3845e06 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.cpp @@ -16,13 +16,13 @@ template void create_blake3_constraints(Builder& builder, con // Get the witness assignment for each witness index // Write the witness assignment to the byte_array for (const auto& witness_index_num_bits : constraint.inputs) { - auto witness_index = witness_index_num_bits.witness; + auto witness_index = witness_index_num_bits.blackbox_input; auto num_bits = witness_index_num_bits.num_bits; // XXX: The implementation requires us to truncate the element to the nearest byte and not bit auto num_bytes = round_to_nearest_byte(num_bits); - field_ct element = field_ct::from_witness_index(&builder, witness_index); + field_ct element = to_field_ct(witness_index, builder); byte_array_ct element_bytes(element, num_bytes); arr.write(element_bytes); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.hpp index f5a9c0e546b..2a0ebc3835c 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/blake3_constraint.hpp @@ -1,4 +1,5 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" #include #include @@ -7,11 +8,11 @@ namespace acir_format { struct Blake3Input { - uint32_t witness; + WitnessOrConstant blackbox_input; uint32_t num_bits; // For serialization, update with any new fields - MSGPACK_FIELDS(witness, num_bits); + MSGPACK_FIELDS(blackbox_input, num_bits); friend bool operator==(Blake3Input const& lhs, Blake3Input const& rhs) = default; }; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.cpp index 86bde0342fe..cd05eb150ac 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.cpp @@ -12,24 +12,39 @@ void create_ec_add_constraint(Builder& builder, const EcAdd& input, bool has_val { // Input to cycle_group points using cycle_group_ct = bb::stdlib::cycle_group; - using field_ct = bb::stdlib::field_t; using bool_ct = bb::stdlib::bool_t; - auto x1 = field_ct::from_witness_index(&builder, input.input1_x); - auto y1 = field_ct::from_witness_index(&builder, input.input1_y); - auto x2 = field_ct::from_witness_index(&builder, input.input2_x); - auto y2 = field_ct::from_witness_index(&builder, input.input2_y); - auto infinite1 = bool_ct(field_ct::from_witness_index(&builder, input.input1_infinite)); - auto infinite2 = bool_ct(field_ct::from_witness_index(&builder, input.input2_infinite)); + auto x1 = to_field_ct(input.input1_x, builder); + auto y1 = to_field_ct(input.input1_y, builder); + auto x2 = to_field_ct(input.input2_x, builder); + + auto y2 = to_field_ct(input.input2_y, builder); + + auto infinite1 = bool_ct(to_field_ct(input.input1_infinite, builder)); + + auto infinite2 = bool_ct(to_field_ct(input.input2_infinite, builder)); + if (!has_valid_witness_assignments) { auto g1 = bb::grumpkin::g1::affine_one; // We need to have correct values representing points on the curve - builder.variables[input.input1_x] = g1.x; - builder.variables[input.input1_y] = g1.y; - builder.variables[input.input1_infinite] = bb::fr(0); - builder.variables[input.input2_x] = g1.x; - builder.variables[input.input2_y] = g1.y; - builder.variables[input.input2_infinite] = bb::fr(0); + if (!x1.is_constant()) { + builder.variables[x1.witness_index] = g1.x; + } + if (!y1.is_constant()) { + builder.variables[y1.witness_index] = g1.y; + } + if (!infinite1.is_constant()) { + builder.variables[infinite1.witness_index] = bb::fr(0); + } + if (!x2.is_constant()) { + builder.variables[x2.witness_index] = g1.x; + } + if (!y2.is_constant()) { + builder.variables[y2.witness_index] = g1.y; + } + if (!infinite2.is_constant()) { + builder.variables[infinite2.witness_index] = bb::fr(0); + } } cycle_group_ct input1_point(x1, y1, infinite1); cycle_group_ct input2_point(x2, y2, infinite2); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.hpp index 6c33bee13ee..1a83d549ee4 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.hpp @@ -1,16 +1,17 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" #include namespace acir_format { struct EcAdd { - uint32_t input1_x; - uint32_t input1_y; - uint32_t input1_infinite; - uint32_t input2_x; - uint32_t input2_y; - uint32_t input2_infinite; + WitnessOrConstant input1_x; + WitnessOrConstant input1_y; + WitnessOrConstant input1_infinite; + WitnessOrConstant input2_x; + WitnessOrConstant input2_y; + WitnessOrConstant input2_infinite; uint32_t result_x; uint32_t result_y; uint32_t result_infinite; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp index 9ff5123ffd1..9f43d49c875 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.test.cpp @@ -40,12 +40,12 @@ size_t generate_ec_add_constraint(EcAdd& ec_add_constraint, WitnessVector& witne witness_values.push_back(fr(0)); witness_values.push_back(fr(0)); ec_add_constraint = EcAdd{ - .input1_x = 1, - .input1_y = 2, - .input1_infinite = 7, - .input2_x = 3, - .input2_y = 4, - .input2_infinite = 7, + .input1_x = WitnessOrConstant::from_index(1), + .input1_y = WitnessOrConstant::from_index(2), + .input1_infinite = WitnessOrConstant::from_index(7), + .input2_x = WitnessOrConstant::from_index(3), + .input2_y = WitnessOrConstant::from_index(4), + .input2_infinite = WitnessOrConstant::from_index(7), .result_x = 5, .result_y = 6, .result_infinite = 8, @@ -129,52 +129,52 @@ TEST_F(EcOperations, TestECMultiScalarMul) fr(0), }; msm_constrain = MultiScalarMul{ - .points = { WitnessConstant{ + .points = { WitnessOrConstant{ .index = 1, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 2, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 3, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 1, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 2, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 3, .value = fr(0), .is_constant = false, } }, - .scalars = { WitnessConstant{ + .scalars = { WitnessOrConstant{ .index = 4, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 5, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 4, .value = fr(0), .is_constant = false, }, - WitnessConstant{ + WitnessOrConstant{ .index = 5, .value = fr(0), .is_constant = false, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/honk_recursion_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/honk_recursion_constraint.test.cpp index af86ada3185..b87bd4c9e39 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/honk_recursion_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/honk_recursion_constraint.test.cpp @@ -39,8 +39,8 @@ class AcirHonkRecursionConstraint : public ::testing::Test { }; LogicConstraint logic_constraint{ - .a = 0, - .b = 1, + .a = WitnessOrConstant::from_index(0), + .b = WitnessOrConstant::from_index(1), .result = 2, .num_bits = 32, .is_xor_gate = 1, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.cpp index 13c72f0f231..54139dba07d 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.cpp @@ -51,7 +51,7 @@ template void create_keccak_permutations(Builder& builder, co // Get the witness assignment for each witness index // Write the witness assignment to the byte_array for (size_t i = 0; i < constraint.state.size(); ++i) { - state[i] = field_ct::from_witness_index(&builder, constraint.state[i]); + state[i] = to_field_ct(constraint.state[i], builder); } std::array output_state = bb::stdlib::keccak::permutation_opcode(state, &builder); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.hpp index 6ae65a6e8d2..2bf194dfd9f 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/keccak_constraint.hpp @@ -1,4 +1,5 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" #include #include @@ -16,7 +17,7 @@ struct HashInput { }; struct Keccakf1600 { - std::array state; + std::array, 25> state; std::array result; // For serialization, update with any new fields diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.cpp index de6c2842d1d..cea21fc25ba 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.cpp @@ -7,16 +7,16 @@ using namespace bb::plonk; template void create_logic_gate(Builder& builder, - const uint32_t a, - const uint32_t b, + const WitnessOrConstant a, + const WitnessOrConstant b, const uint32_t result, const size_t num_bits, const bool is_xor_gate) { using field_ct = bb::stdlib::field_t; - field_ct left = field_ct::from_witness_index(&builder, a); - field_ct right = field_ct::from_witness_index(&builder, b); + field_ct left = to_field_ct(a, builder); + field_ct right = to_field_ct(b, builder); field_ct res = bb::stdlib::logic::create_logic_constraint(left, right, num_bits, is_xor_gate); field_ct our_res = field_ct::from_witness_index(&builder, result); @@ -24,14 +24,14 @@ void create_logic_gate(Builder& builder, } template void create_logic_gate(bb::MegaCircuitBuilder& builder, - const uint32_t a, - const uint32_t b, + const WitnessOrConstant a, + const WitnessOrConstant b, const uint32_t result, const size_t num_bits, const bool is_xor_gate); template void create_logic_gate(bb::UltraCircuitBuilder& builder, - const uint32_t a, - const uint32_t b, + const WitnessOrConstant a, + const WitnessOrConstant b, const uint32_t result, const size_t num_bits, const bool is_xor_gate); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.hpp index 16b17f73830..a7af4832ee0 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/logic_constraint.hpp @@ -1,4 +1,5 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" #include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp" #include @@ -8,8 +9,8 @@ namespace acir_format { using Builder = bb::UltraCircuitBuilder; struct LogicConstraint { - uint32_t a; - uint32_t b; + WitnessOrConstant a; + WitnessOrConstant b; uint32_t result; uint32_t num_bits; uint32_t is_xor_gate; @@ -21,10 +22,14 @@ struct LogicConstraint { }; template -void create_logic_gate( - Builder& builder, uint32_t a, uint32_t b, uint32_t result, std::size_t num_bits, bool is_xor_gate); +void create_logic_gate(Builder& builder, + WitnessOrConstant a, + WitnessOrConstant b, + uint32_t result, + std::size_t num_bits, + bool is_xor_gate); -void xor_gate(Builder& builder, uint32_t a, uint32_t b, uint32_t result); +void xor_gate(Builder& builder, WitnessOrConstant a, WitnessOrConstant b, uint32_t result); -void and_gate(Builder& builder, uint32_t a, uint32_t b, uint32_t result); +void and_gate(Builder& builder, WitnessOrConstant a, WitnessOrConstant b, uint32_t result); } // namespace acir_format diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp index 398c794dea1..1baf25b97ec 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.cpp @@ -25,21 +25,9 @@ template void create_multi_scalar_mul_constraint(Builder& bui field_ct point_x; field_ct point_y; bool_ct infinite; - if (input.points[i].is_constant) { - point_x = field_ct(input.points[i].value); - } else { - point_x = field_ct::from_witness_index(&builder, input.points[i].index); - } - if (input.points[i + 1].is_constant) { - point_y = field_ct(input.points[i + 1].value); - } else { - point_y = field_ct::from_witness_index(&builder, input.points[i + 1].index); - } - if (input.points[i + 2].is_constant) { - infinite = bool_ct(field_ct(input.points[i + 2].value)); - } else { - infinite = bool_ct(field_ct::from_witness_index(&builder, input.points[i + 2].index)); - } + point_x = to_field_ct(input.points[i], builder); + point_y = to_field_ct(input.points[i + 1], builder); + infinite = bool_ct(to_field_ct(input.points[i + 2], builder)); cycle_group_ct input_point(point_x, point_y, infinite); // Reconstruct the scalar from the low and high limbs field_ct scalar_low_as_field; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp index ef980a99ab6..0d13aaca513 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp @@ -2,22 +2,15 @@ #include "barretenberg/serialize/msgpack.hpp" #include "barretenberg/stdlib/primitives/field/field.hpp" #include "serde/index.hpp" +#include "witness_constant.hpp" #include #include namespace acir_format { -template struct WitnessConstant { - uint32_t index; - FF value; - bool is_constant; - MSGPACK_FIELDS(index, value, is_constant); - friend bool operator==(WitnessConstant const& lhs, WitnessConstant const& rhs) = default; -}; - struct MultiScalarMul { - std::vector> points; - std::vector> scalars; + std::vector> points; + std::vector> scalars; uint32_t out_point_x; uint32_t out_point_y; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.cpp index 1a8e0821070..03ecef4efbe 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.cpp @@ -20,7 +20,7 @@ template void create_poseidon2_permutations(Builder& builder, // Write the witness assignment to the byte_array state State state; for (size_t i = 0; i < constraint.state.size(); ++i) { - state[i] = field_ct::from_witness_index(&builder, constraint.state[i]); + state[i] = to_field_ct(constraint.state[i], builder); } State output_state; output_state = stdlib::Poseidon2Permutation::permutation(&builder, state); diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.hpp index 342d709fd59..b2ee452bcd3 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.hpp @@ -1,4 +1,5 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" #include #include @@ -6,7 +7,7 @@ namespace acir_format { struct Poseidon2Constraint { - std::vector state; + std::vector> state; std::vector result; uint32_t len; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp index cc144fa6fd9..b140797cf9e 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/poseidon2_constraint.test.cpp @@ -8,6 +8,7 @@ #include #include +#include #include namespace acir_format::tests { @@ -29,7 +30,12 @@ TEST_F(Poseidon2Tests, TestPoseidon2Permutation) { Poseidon2Constraint poseidon2_constraint{ - .state = { 1, 2, 3, 4, }, + .state = { + WitnessOrConstant::from_index(1), + WitnessOrConstant::from_index(2), + WitnessOrConstant::from_index(3), + WitnessOrConstant::from_index(4), + }, .result = { 5, 6, 7, 8, }, .len = 4, }; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp index 76e6d72bad1..abc6b33cb07 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/recursion_constraint.test.cpp @@ -37,8 +37,8 @@ Builder create_inner_circuit() }; LogicConstraint logic_constraint{ - .a = 0, - .b = 1, + .a = WitnessOrConstant::from_index(0), + .b = WitnessOrConstant::from_index(1), .result = 2, .num_bits = 32, .is_xor_gate = 1, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.cpp index eef6327ab12..99f88882937 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.cpp @@ -56,16 +56,12 @@ void create_sha256_compression_constraints(Builder& builder, const Sha256Compres // because of the lookup-tables. size_t i = 0; for (const auto& witness_index_num_bits : constraint.inputs) { - auto witness_index = witness_index_num_bits.witness; - field_ct element = field_ct::from_witness_index(&builder, witness_index); - inputs[i] = element; + inputs[i] = to_field_ct(witness_index_num_bits, builder); ++i; } i = 0; for (const auto& witness_index_num_bits : constraint.hash_values) { - auto witness_index = witness_index_num_bits.witness; - field_ct element = field_ct::from_witness_index(&builder, witness_index); - hash_inputs[i] = element; + hash_inputs[i] = to_field_ct(witness_index_num_bits, builder); ++i; } diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.hpp index 1843aa5198d..4b243b5b6ab 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.hpp @@ -1,4 +1,5 @@ #pragma once +#include "barretenberg/dsl/acir_format/witness_constant.hpp" #include "barretenberg/serialize/msgpack.hpp" #include #include @@ -25,8 +26,8 @@ struct Sha256Constraint { }; struct Sha256Compression { - std::array inputs; - std::array hash_values; + std::array, 16> inputs; + std::array, 8> hash_values; std::array result; friend bool operator==(Sha256Compression const& lhs, Sha256Compression const& rhs) = default; diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp index 72936adc776..39163fc7a6f 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/sha256_constraint.test.cpp @@ -21,13 +21,13 @@ class Sha256Tests : public ::testing::Test { TEST_F(Sha256Tests, TestSha256Compression) { - std::array inputs; + std::array, 16> inputs; for (size_t i = 0; i < 16; ++i) { - inputs[i] = { .witness = static_cast(i + 1), .num_bits = 32 }; + inputs[i] = WitnessOrConstant::from_index(static_cast(i + 1)); } - std::array hash_values; + std::array, 8> hash_values; for (size_t i = 0; i < 8; ++i) { - hash_values[i] = { .witness = static_cast(i + 17), .num_bits = 32 }; + hash_values[i] = WitnessOrConstant::from_index(static_cast(i + 17)); } Sha256Compression sha256_compression{ .inputs = inputs, diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/witness_constant.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/witness_constant.hpp new file mode 100644 index 00000000000..554c8ddf657 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/witness_constant.hpp @@ -0,0 +1,33 @@ +#pragma once +#include "barretenberg/serialize/msgpack.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" + +namespace acir_format { +template struct WitnessOrConstant { + + uint32_t index; + FF value; + bool is_constant; + MSGPACK_FIELDS(index, value, is_constant); + friend bool operator==(WitnessOrConstant const& lhs, WitnessOrConstant const& rhs) = default; + static WitnessOrConstant from_index(uint32_t index) + { + return WitnessOrConstant{ + .index = index, + .value = FF::zero(), + .is_constant = false, + }; + } +}; + +template +bb::stdlib::field_t to_field_ct(const WitnessOrConstant& input, Builder& builder) +{ + using field_ct = bb::stdlib::field_t; + if (input.is_constant) { + return field_ct(input.value); + } + return field_ct::from_witness_index(&builder, input.index); +} + +} // namespace acir_format \ No newline at end of file