From 0c376725a29ec18e25a7c9a89c0df8f5a1e06ff4 Mon Sep 17 00:00:00 2001 From: Sarkoxed <75146596+Sarkoxed@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:29:12 +0300 Subject: [PATCH] feat: Several Updates in SMT verification module (part 1) (#10437) This pr enhances symbolic circuit to produce valid witnesses. # Utils Added post processing functionality. So now, while optimizing something inside the circuit, you can postpone some witness calculations until here, in case the variable has been optimized and does not fit as an STerm. # Builders + Schema Added `circuit_finalized` flag to the export. Should be used in the future, during RAM/ROM processing. # StandardCircuit Pushed the post processing for standard logic operations. They were used to test the sha256 witness, which is coming in part 3, I guess. --- .../smt_verification/circuit/circuit_base.hpp | 6 + .../circuit/circuit_schema.hpp | 1 + .../circuit/standard_circuit.cpp | 211 +++++++++++++++--- .../circuit/standard_circuit.test.cpp | 141 ++++++++++++ .../smt_verification/terms/bvterm.test.cpp | 36 +-- .../smt_verification/terms/iterm.test.cpp | 16 +- .../smt_verification/util/smt_util.cpp | 58 +++-- .../smt_verification/util/smt_util.hpp | 18 +- .../circuit_builder_base.hpp | 1 + .../standard_circuit_builder.cpp | 1 + .../ultra_circuit_builder.cpp | 2 + 11 files changed, 414 insertions(+), 77 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp index e39d46a88e0..81584333656 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp @@ -41,6 +41,12 @@ class CircuitBase { std::unordered_map> cached_subcircuits; // caches subcircuits during optimization // No need to recompute them each time + std::unordered_map> + post_process; // Values idxs that should be post processed after the solver returns a witness. + // Basically it affects only optimized out variables. + // Because in BitVector case we can't collect negative values since they will not be + // the same in the field. That's why we store the expression and calculate it after the witness is + // obtained. Solver* solver; // pointer to the solver TermType type; // Type of the underlying Symbolic Terms diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp index cc3b37e8e41..099e60c1c73 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp @@ -33,6 +33,7 @@ struct CircuitSchema { std::vector>> lookup_tables; std::vector real_variable_tags; std::unordered_map range_tags; + bool circuit_finalized; MSGPACK_FIELDS(modulus, public_inps, vars_of_interest, diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp index 5b47ba23afb..bd49b18b82a 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp @@ -103,32 +103,32 @@ size_t StandardCircuit::prepare_gates(size_t cursor) // TODO(alex): Test the effect of this relaxation after the tests are merged. if (univariate_flag) { if ((q_m == 1) && (q_1 == 0) && (q_2 == 0) && (q_3 == -1) && (q_c == 0)) { - (Bool(symbolic_vars[w_l]) == + (Bool(this->symbolic_vars[w_l]) == Bool(STerm(0, this->solver, this->type)) | // STerm(0, this->solver, this->type)) | - Bool(symbolic_vars[w_l]) == + Bool(this->symbolic_vars[w_l]) == Bool(STerm(1, this->solver, this->type))) // STerm(1, this->solver, this->type))) .assert_term(); } else { this->handle_univariate_constraint(q_m, q_1, q_2, q_3, q_c, w_l); } } else { - STerm eq = symbolic_vars[0]; + STerm eq = this->symbolic_vars[this->variable_names_inverse["zero"]]; // mul selector if (q_m != 0) { - eq += symbolic_vars[w_l] * symbolic_vars[w_r] * q_m; + eq += this->symbolic_vars[w_l] * this->symbolic_vars[w_r] * q_m; } // left selector if (q_1 != 0) { - eq += symbolic_vars[w_l] * q_1; + eq += this->symbolic_vars[w_l] * q_1; } // right selector if (q_2 != 0) { - eq += symbolic_vars[w_r] * q_2; + eq += this->symbolic_vars[w_r] * q_2; } // out selector if (q_3 != 0) { - eq += symbolic_vars[w_o] * q_3; + eq += this->symbolic_vars[w_o] * q_3; } // constant selector if (q_c != 0) { @@ -157,7 +157,7 @@ void StandardCircuit::handle_univariate_constraint( bb::fr b = q_1 + q_2 + q_3; if (q_m == 0) { - symbolic_vars[w] == -q_c / b; + this->symbolic_vars[w] == -q_c / b; return; } @@ -169,10 +169,10 @@ void StandardCircuit::handle_univariate_constraint( bb::fr x2 = (-b - d.second) / (bb::fr(2) * q_m); if (d.second == 0) { - symbolic_vars[w] == STerm(x1, this->solver, type); + this->symbolic_vars[w] == STerm(x1, this->solver, type); } else { - ((Bool(symbolic_vars[w]) == Bool(STerm(x1, this->solver, this->type))) | - (Bool(symbolic_vars[w]) == Bool(STerm(x2, this->solver, this->type)))) + ((Bool(this->symbolic_vars[w]) == Bool(STerm(x1, this->solver, this->type))) | + (Bool(this->symbolic_vars[w]) == Bool(STerm(x2, this->solver, this->type)))) .assert_term(); } } @@ -285,8 +285,6 @@ size_t StandardCircuit::handle_logic_constraint(size_t cursor) } } - // TODO(alex): Figure out if I need to create range constraint here too or it'll be - // created anyway in any circuit if (res != static_cast(-1)) { CircuitProps xor_props = get_standard_logic_circuit(res, true); CircuitProps and_props = get_standard_logic_circuit(res, false); @@ -307,6 +305,47 @@ size_t StandardCircuit::handle_logic_constraint(size_t cursor) STerm right = this->symbolic_vars[right_idx]; STerm out = this->symbolic_vars[out_idx]; + // Initializing the parts of the witness that were optimized + // during the symbolic constraints initialization + // i.e. simulating the create_logic_constraint gate by gate using BitVectors/Integers + size_t num_bits = res; + size_t processed_gates = 0; + for (size_t i = num_bits - 1; i < num_bits; i -= 2) { + // 8 here is the number of gates we have to skip to get proper indices + processed_gates += 8; + uint32_t left_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + uint32_t left_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]]; + uint32_t left_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + processed_gates += 1; + uint32_t right_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + uint32_t right_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]]; + uint32_t right_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + processed_gates += 1; + uint32_t out_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + uint32_t out_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]]; + uint32_t out_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + processed_gates += 1; + uint32_t old_left_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + processed_gates += 1; + uint32_t old_right_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + processed_gates += 1; + uint32_t old_out_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + processed_gates += 1; + + this->symbolic_vars[old_left_acc_idx] == (left >> static_cast(i - 1)); + this->symbolic_vars[left_quad_idx] == (this->symbolic_vars[old_left_acc_idx] & 3); + this->symbolic_vars[left_lo_idx] == (this->symbolic_vars[left_quad_idx] & 1); + this->symbolic_vars[left_hi_idx] == (this->symbolic_vars[left_quad_idx] >> 1); + this->symbolic_vars[old_right_acc_idx] == (right >> static_cast(i - 1)); + this->symbolic_vars[right_quad_idx] == (this->symbolic_vars[old_right_acc_idx] & 3); + this->symbolic_vars[right_lo_idx] == (this->symbolic_vars[right_quad_idx] & 1); + this->symbolic_vars[right_hi_idx] == (this->symbolic_vars[right_quad_idx] >> 1); + this->symbolic_vars[old_out_acc_idx] == (out >> static_cast(i - 1)); + this->symbolic_vars[out_quad_idx] == (this->symbolic_vars[old_out_acc_idx] & 3); + this->symbolic_vars[out_lo_idx] == (this->symbolic_vars[out_quad_idx] & 1); + this->symbolic_vars[out_hi_idx] == (this->symbolic_vars[out_quad_idx] >> 1); + } + if (logic_flag) { (left ^ right) == out; } else { @@ -422,19 +461,44 @@ size_t StandardCircuit::handle_range_constraint(size_t cursor) // we need this because even right shifts do not create // any additional gates and therefore are undetectible - // TODO(alex): I think I should simulate the whole subcircuit at that point - // Otherwise optimized out variables are not correct in the final witness - // And I can't fix them by hand each time - size_t num_accs = range_props.gate_idxs.size() - 1; - for (size_t j = 1; j < num_accs + 1 && (this->type == TermType::BVTerm); j++) { - size_t acc_gate = range_props.gate_idxs[j]; - uint32_t acc_gate_idx = range_props.idxs[j]; + // Simulate the range constraint circuit using the bitwise operations + size_t num_bits = res; + size_t num_quads = num_bits >> 1; + num_quads += num_bits & 1; + uint32_t processed_gates = 0; + + // Initializing the parts of the witness that were optimized + // during the symbolic constraints initialization + // i.e. simulating the decompose_into_base4_accumulators gate by gate using BitVectors/Integers + for (size_t i = num_quads - 1; i < num_quads; i--) { + uint32_t lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + processed_gates += 1; + uint32_t quad_idx = 0; + uint32_t old_accumulator_idx = 0; + uint32_t hi_idx = 0; + + if (i == num_quads - 1 && ((num_bits & 1) == 1)) { + quad_idx = lo_idx; + } else { + hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + processed_gates += 1; + quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + processed_gates += 1; + } - uint32_t acc_idx = this->real_variable_index[this->wires_idxs[cursor + acc_gate][acc_gate_idx]]; + if (i == num_quads - 1) { + old_accumulator_idx = quad_idx; + } else { + old_accumulator_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + processed_gates += 1; + } - this->symbolic_vars[acc_idx] == (left >> static_cast(2 * j)); - // I think the following is worse. The name of the variable is lost after that - // this->symbolic_vars[acc_idx] = (left >> static_cast(2 * j)); + this->symbolic_vars[old_accumulator_idx] == (left >> static_cast(2 * i)); + this->symbolic_vars[quad_idx] == (this->symbolic_vars[old_accumulator_idx] & 3); + this->symbolic_vars[lo_idx] == (this->symbolic_vars[quad_idx] & 1); + if (i != (num_quads - 1) || ((num_bits)&1) != 1) { + this->symbolic_vars[hi_idx] == (this->symbolic_vars[quad_idx] >> 1); + } } left <= (bb::fr(2).pow(res) - 1); @@ -545,8 +609,37 @@ size_t StandardCircuit::handle_shr_constraint(size_t cursor) STerm left = this->symbolic_vars[left_idx]; STerm out = this->symbolic_vars[out_idx]; - STerm shled = left >> nr.second; - out == shled; + // Initializing the parts of the witness that were optimized + // during the symbolic constraints initialization + // i.e. simulating the uint's operator>> gate by gate using BitVectors/Integers + uint32_t shift = nr.second; + if ((shift & 1) == 1) { + size_t processed_gates = 0; + uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3); + STerm delta = this->symbolic_vars[delta_idx]; + processed_gates += 1; + uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + + // this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7); + this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } }); + + processed_gates += 1; + uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[r1_idx] == (delta >> 1) * 6; + processed_gates += 1; + uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[r2_idx] == (left >> shift) * 6; + processed_gates += 1; + uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + + // this->symbolic_vars[temp_idx] == -6 * out; + this->post_process.insert({ temp_idx, { out_idx, out_idx, 0, -6, 0, 0 } }); + } + + STerm shred = left >> nr.second; + out == shred; // You have to mark these arguments so they won't be optimized out optimized[left_idx] = false; @@ -652,7 +745,37 @@ size_t StandardCircuit::handle_shl_constraint(size_t cursor) STerm left = this->symbolic_vars[left_idx]; STerm out = this->symbolic_vars[out_idx]; - STerm shled = (left << nr.second) & (bb::fr(2).pow(nr.first) - 1); + // Initializing the parts of the witness that were optimized + // during the symbolic constraints initialization + // i.e. simulating the uint's operator<< gate by gate using BitVectors/Integers + uint32_t num_bits = nr.first; + uint32_t shift = nr.second; + if ((shift & 1) == 1) { + size_t processed_gates = 0; + uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3); + STerm delta = this->symbolic_vars[delta_idx]; + processed_gates += 1; + uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + + // this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7); + this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } }); + + processed_gates += 1; + uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[r1_idx] == (delta >> 1) * 6; + processed_gates += 1; + uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[r2_idx] == (left >> (num_bits - shift)) * 6; + processed_gates += 1; + uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + + // this->symbolic_vraiables[temp_idx] == -6 * r2 + this->post_process.insert({ temp_idx, { r2_idx, r2_idx, 0, -1, 0, 0 } }); + } + + STerm shled = (left << shift) & (bb::fr(2).pow(num_bits) - 1); out == shled; // You have to mark these arguments so they won't be optimized out @@ -760,7 +883,37 @@ size_t StandardCircuit::handle_ror_constraint(size_t cursor) STerm left = this->symbolic_vars[left_idx]; STerm out = this->symbolic_vars[out_idx]; - STerm rored = ((left >> nr.second) | (left << (nr.first - nr.second))) & (bb::fr(2).pow(nr.first) - 1); + // Initializing the parts of the witness that were optimized + // during the symbolic constraints initialization + // i.e. simulating the uint's rotate_right gate by gate using BitVectors/Integers + uint32_t num_bits = nr.first; + uint32_t rotation = nr.second; + if ((rotation & 1) == 1) { + size_t processed_gates = 0; + uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]]; + uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3); + STerm delta = this->symbolic_vars[delta_idx]; + processed_gates += 1; + uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + + // this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7); + this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } }); + + processed_gates += 1; + uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[r1_idx] == (delta >> 1) * 6; + processed_gates += 1; + uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + this->symbolic_vars[r2_idx] == (left >> rotation) * 6; + processed_gates += 1; + uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]]; + + // this->symbolic_vraiables[temp_idx] == -6 * r2 + this->post_process.insert({ temp_idx, { r2_idx, r2_idx, 0, -1, 0, 0 } }); + } + + STerm rored = ((left >> rotation) | (left << (num_bits - rotation))) & (bb::fr(2).pow(num_bits) - 1); out == rored; // You have to mark these arguments so they won't be optimized out @@ -909,4 +1062,4 @@ std::pair StandardCircuit::unique_witness(Circ } return { c1, c2 }; } -}; // namespace smt_circuit +}; // namespace smt_circuit \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp index ffded6864dd..3d9917d8997 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp @@ -376,3 +376,144 @@ TEST(standard_circuit, check_double_xor_bug) Solver s(circuit_info.modulus, default_solver_config, 16, 64); StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); } + +// Check that witness provided by the solver is the same as builder's witness +// Check that all the optimized out values are initialized and computed properly during post proccessing +TEST(standard_circuit, optimized_range_witness) +{ + uint32_t rbit = engine.get_random_uint8() & 1; + uint32_t num_bits = 32 + rbit; + info(num_bits); + + StandardCircuitBuilder builder; + field_t a = witness_t(&builder, engine.get_random_uint256() % (uint256_t(1) << num_bits)); + builder.create_range_constraint(a.get_witness_index(), num_bits); + builder.set_variable_name(a.get_witness_index(), "a"); + + CircuitSchema circuit_info = unpack_from_buffer(builder.export_circuit()); + Solver s(circuit_info.modulus, default_solver_config, 16, 64); + StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); + + circuit["a"] == a.get_value(); + + bool res = smt_timer(&s); + ASSERT_TRUE(res); + auto model_witness = default_model_single({ "a" }, circuit, "optimized_range_check.out"); + + ASSERT_EQ(model_witness.size(), builder.get_num_variables()); + for (size_t i = 0; i < model_witness.size(); i++) { + ASSERT_EQ(model_witness[i], builder.variables[i]); + } +} + +// Check that witness provided by the solver is the same as builder's witness +// Check that all the optimized out values are initialized and computed properly during post proccessing +TEST(standard_circuit, optimized_logic_witness) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_t(&builder, engine.get_random_uint32()); + builder.set_variable_name(a.get_witness_index(), "a"); + uint_ct b = witness_t(&builder, engine.get_random_uint32()); + builder.set_variable_name(b.get_witness_index(), "b"); + uint_ct c = a ^ b; + uint_ct d = a & b; + builder.set_variable_name(c.get_witness_index(), "c"); + builder.set_variable_name(d.get_witness_index(), "d"); + + CircuitSchema circuit_info = unpack_from_buffer(builder.export_circuit()); + Solver s(circuit_info.modulus, default_solver_config, 16, 64); + StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); + + circuit["a"] == a.get_value(); + circuit["b"] == b.get_value(); + + bool res = smt_timer(&s); + ASSERT_TRUE(res); + auto model_witness = default_model_single({ "a", "b", "c", "d" }, circuit, "optimized_xor_check.out"); + + ASSERT_EQ(model_witness.size(), builder.get_num_variables()); + for (size_t i = 0; i < model_witness.size(); i++) { + ASSERT_EQ(model_witness[i], builder.variables[i]); + } +} + +// Check that witness provided by the solver is the same as builder's witness +// Check that all the optimized out values are initialized and computed properly during post proccessing +TEST(standard_circuit, optimized_shr_witness) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_t(&builder, engine.get_random_uint32()); + builder.set_variable_name(a.get_witness_index(), "a"); + uint_ct b = a >> 0; + for (uint32_t i = 1; i < 32; i++) { + b = a >> i; + } + + CircuitSchema circuit_info = unpack_from_buffer(builder.export_circuit()); + Solver s(circuit_info.modulus, default_solver_config, 16, 64); + StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); + + circuit["a"] == a.get_value(); + bool res = smt_timer(&s); + ASSERT_TRUE(res); + auto model_witness = default_model_single({ "a" }, circuit, "optimized_xor_check.out"); + + ASSERT_EQ(model_witness.size(), builder.get_num_variables()); + for (size_t i = 0; i < model_witness.size(); i++) { + EXPECT_EQ(model_witness[i], builder.variables[i]); + } +} + +// Check that witness provided by the solver is the same as builder's witness +// Check that all the optimized out values are initialized and computed properly during post proccessing +TEST(standard_circuit, optimized_shl_witness) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_t(&builder, engine.get_random_uint32()); + builder.set_variable_name(a.get_witness_index(), "a"); + uint_ct b = a << 0; + for (uint32_t i = 1; i < 32; i++) { + b = a << i; + } + + CircuitSchema circuit_info = unpack_from_buffer(builder.export_circuit()); + Solver s(circuit_info.modulus, default_solver_config, 16, 64); + StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); + + circuit[a.get_witness_index()] == a.get_value(); + bool res = smt_timer(&s); + ASSERT_TRUE(res); + auto model_witness = default_model_single({ "a" }, circuit, "optimized_xor_check.out"); + + ASSERT_EQ(model_witness.size(), builder.get_num_variables()); + for (size_t i = 0; i < model_witness.size(); i++) { + EXPECT_EQ(model_witness[i], builder.variables[i]); + } +} + +// Check that witness provided by the solver is the same as builder's witness +// Check that all the optimized out values are initialized and computed properly during post proccessing +TEST(standard_circuit, optimized_ror_witness) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_t(&builder, engine.get_random_uint32()); + builder.set_variable_name(a.get_witness_index(), "a"); + uint_ct b = a.ror(0); + for (uint32_t i = 1; i < 32; i++) { + b = a.ror(i); + } + + CircuitSchema circuit_info = unpack_from_buffer(builder.export_circuit()); + Solver s(circuit_info.modulus, default_solver_config, 16, 64); + StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); + + circuit[a.get_witness_index()] == a.get_value(); + bool res = smt_timer(&s); + ASSERT_TRUE(res); + auto model_witness = default_model_single({ "a" }, circuit, "optimized_xor_check.out"); + + ASSERT_EQ(model_witness.size(), builder.get_num_variables()); + for (size_t i = 0; i < model_witness.size(); i++) { + EXPECT_EQ(model_witness[i], builder.variables[i]); + } +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp index 05d942cdcad..bb7de54a063 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp @@ -18,8 +18,8 @@ using namespace smt_terms; TEST(BVTerm, addition) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); - uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + uint_ct b = witness_ct(&builder, engine.get_random_uint32()); uint_ct c = a + b; uint32_t modulus_base = 16; @@ -47,8 +47,8 @@ TEST(BVTerm, addition) TEST(BVTerm, subtraction) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); - uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + uint_ct b = witness_ct(&builder, engine.get_random_uint32()); uint_ct c = a - b; uint32_t modulus_base = 16; @@ -76,8 +76,8 @@ TEST(BVTerm, subtraction) TEST(BVTerm, xor) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); - uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + uint_ct b = witness_ct(&builder, engine.get_random_uint32()); uint_ct c = a ^ b; uint32_t modulus_base = 16; @@ -105,7 +105,7 @@ TEST(BVTerm, xor) TEST(BVTerm, rotr) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); uint_ct b = a.ror(10); uint32_t modulus_base = 16; @@ -131,7 +131,7 @@ TEST(BVTerm, rotr) TEST(BVTerm, rotl) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); uint_ct b = a.rol(10); uint32_t modulus_base = 16; @@ -158,8 +158,8 @@ TEST(BVTerm, rotl) TEST(BVTerm, mul) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); - uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + uint_ct b = witness_ct(&builder, engine.get_random_uint32()); uint_ct c = a * b; uint32_t modulus_base = 16; @@ -187,8 +187,8 @@ TEST(BVTerm, mul) TEST(BVTerm, and) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); - uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + uint_ct b = witness_ct(&builder, engine.get_random_uint32()); uint_ct c = a & b; uint32_t modulus_base = 16; @@ -216,8 +216,8 @@ TEST(BVTerm, and) TEST(BVTerm, or) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); - uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + uint_ct b = witness_ct(&builder, engine.get_random_uint32()); uint_ct c = a | b; uint32_t modulus_base = 16; @@ -245,8 +245,8 @@ TEST(BVTerm, or) TEST(BVTerm, div) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); - uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); + uint_ct b = witness_ct(&builder, engine.get_random_uint32()); uint_ct c = a / b; uint32_t modulus_base = 16; @@ -274,7 +274,7 @@ TEST(BVTerm, div) TEST(BVTerm, shr) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); uint_ct b = a >> 5; uint32_t modulus_base = 16; @@ -300,7 +300,7 @@ TEST(BVTerm, shr) TEST(BVTerm, shl) { StandardCircuitBuilder builder; - uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct a = witness_ct(&builder, engine.get_random_uint32()); uint_ct b = a << 5; uint32_t modulus_base = 16; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/iterm.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/iterm.test.cpp index d86081a24bc..24710d7b938 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/iterm.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/iterm.test.cpp @@ -17,8 +17,8 @@ using namespace smt_terms; TEST(ITerm, addition) { StandardCircuitBuilder builder; - uint64_t a = static_cast(fr::random_element()) % (static_cast(1) << 31); - uint64_t b = static_cast(fr::random_element()) % (static_cast(1) << 31); + uint64_t a = engine.get_random_uint32() % (static_cast(1) << 31); + uint64_t b = engine.get_random_uint32() % (static_cast(1) << 31); uint64_t c = a + b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config); @@ -41,8 +41,8 @@ TEST(ITerm, addition) TEST(ITerm, subtraction) { StandardCircuitBuilder builder; - uint64_t c = static_cast(fr::random_element()) % (static_cast(1) << 31); - uint64_t b = static_cast(fr::random_element()) % (static_cast(1) << 31); + uint64_t c = engine.get_random_uint32() % (static_cast(1) << 31); + uint64_t b = engine.get_random_uint32() % (static_cast(1) << 31); uint64_t a = c + b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config); @@ -65,8 +65,8 @@ TEST(ITerm, subtraction) TEST(ITerm, mul) { StandardCircuitBuilder builder; - uint64_t a = static_cast(fr::random_element()) % (static_cast(1) << 31); - uint64_t b = static_cast(fr::random_element()) % (static_cast(1) << 31); + uint64_t a = engine.get_random_uint32() % (static_cast(1) << 31); + uint64_t b = engine.get_random_uint32() % (static_cast(1) << 31); uint64_t c = a * b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config); @@ -89,8 +89,8 @@ TEST(ITerm, mul) TEST(ITerm, div) { StandardCircuitBuilder builder; - uint64_t a = static_cast(fr::random_element()) % (static_cast(1) << 31); - uint64_t b = static_cast(fr::random_element()) % (static_cast(1) << 31) + 1; + uint64_t a = engine.get_random_uint32() % (static_cast(1) << 31); + uint64_t b = engine.get_random_uint32() % (static_cast(1) << 31) + 1; uint64_t c = a / b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config); diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp index 2a2ec75c54b..cb4bcb2b5e1 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp @@ -51,11 +51,11 @@ bb::fr string_to_fr(const std::string& number, int base, size_t step) * @param fname file to store the resulting witness if succeded * @param pack flags out to pack the resulting witness using msgpack */ -void default_model(const std::vector& special, - smt_circuit::CircuitBase& c1, - smt_circuit::CircuitBase& c2, - const std::string& fname, - bool pack) +std::vector> default_model(const std::vector& special, + smt_circuit::CircuitBase& c1, + smt_circuit::CircuitBase& c2, + const std::string& fname, + bool pack) { std::vector vterms1; std::vector vterms2; @@ -91,13 +91,28 @@ void default_model(const std::vector& special, info(RED, new_line, RESET); } myfile << new_line << std::endl; - ; - packed_witness.push_back({ string_to_fr(mmap1[vname1], base), string_to_fr(mmap2[vname2], base) }); } myfile << "};"; myfile.close(); + // Accessing post processing functionality of the symbolic circuit + // Once we obtained the witness, compatible with current configuration(e.g. BVTerm) + // We can further compute the remaining witness entries, which were optimized and hence + // are not provided by the solver + for (const auto& post : c1.post_process) { + uint32_t res_idx = post.first; + auto left_idx = static_cast(post.second[0]); + auto right_idx = static_cast(post.second[1]); + bb::fr q_m = post.second[2]; + bb::fr q_l = post.second[3]; + bb::fr q_r = post.second[4]; + bb::fr q_c = post.second[5]; + packed_witness[res_idx][0] = q_m * packed_witness[left_idx][0] * packed_witness[right_idx][0] + + q_l * packed_witness[left_idx][0] + q_r * packed_witness[right_idx][0] + q_c; + packed_witness[res_idx][1] = q_m * packed_witness[left_idx][1] * packed_witness[right_idx][1] + + q_l * packed_witness[left_idx][1] + q_r * packed_witness[right_idx][1] + q_c; + } if (pack) { msgpack::sbuffer buffer; msgpack::pack(buffer, packed_witness); @@ -119,6 +134,7 @@ void default_model(const std::vector& special, for (const auto& vname : special) { info(vname, "_1, ", vname, "_2 = ", mmap[vname + "_1"], ", ", mmap[vname + "_2"]); } + return packed_witness; } /** @@ -135,10 +151,10 @@ void default_model(const std::vector& special, * @param fname file to store the resulting witness if succeded * @param pack flags out to pack the resulting witness using msgpack */ -void default_model_single(const std::vector& special, - smt_circuit::CircuitBase& c, - const std::string& fname, - bool pack) +std::vector default_model_single(const std::vector& special, + smt_circuit::CircuitBase& c, + const std::string& fname, + bool pack) { std::vector vterms; vterms.reserve(c.get_num_vars()); @@ -157,7 +173,7 @@ void default_model_single(const std::vector& special, packed_witness.reserve(c.get_num_vars()); int base = c.type == smt_terms::TermType::BVTerm ? 2 : 10; - for (size_t i = 0; i < c.get_num_vars(); i++) { + for (uint32_t i = 0; i < c.get_num_vars(); i++) { std::string vname = vterms[i].toString(); std::string new_line = mmap[vname] + ", // " + vname; if (c.real_variable_index[i] != i) { @@ -169,6 +185,21 @@ void default_model_single(const std::vector& special, myfile << "};"; myfile.close(); + // Accessing post processing functionality of the symbolic circuit + // Once we obtained the witness, compatible with current configuration(e.g. BVTerm) + // We can further compute the remaining witness entries, which were optimized and hence + // are not provided by the solver + for (const auto& post : c.post_process) { + uint32_t res_idx = post.first; + auto left_idx = static_cast(post.second[0]); + auto right_idx = static_cast(post.second[1]); + bb::fr q_m = post.second[2]; + bb::fr q_l = post.second[3]; + bb::fr q_r = post.second[4]; + bb::fr q_c = post.second[5]; + packed_witness[res_idx] = q_m * packed_witness[left_idx] * packed_witness[right_idx] + + q_l * packed_witness[left_idx] + q_r * packed_witness[right_idx] + q_c; + } if (pack) { msgpack::sbuffer buffer; msgpack::pack(buffer, packed_witness); @@ -189,6 +220,7 @@ void default_model_single(const std::vector& special, for (const auto& vname : special) { info(vname, " = ", mmap1[vname]); } + return packed_witness; } /** @@ -311,4 +343,4 @@ void fix_range_lists(bb::UltraCircuitBuilder& builder) } builder.variables[list.second.variable_indices[num_multiples_of_three + 1]] = list.first; } -} \ No newline at end of file +} diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp index dcf08418028..5241960f5bf 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp @@ -6,15 +6,15 @@ #define RED "\033[31m" #define RESET "\033[0m" -void default_model(const std::vector& special, - smt_circuit::CircuitBase& c1, - smt_circuit::CircuitBase& c2, - const std::string& fname = "witness.out", - bool pack = true); -void default_model_single(const std::vector& special, - smt_circuit::CircuitBase& c, - const std::string& fname = "witness.out", - bool pack = true); +std::vector> default_model(const std::vector& special, + smt_circuit::CircuitBase& c1, + smt_circuit::CircuitBase& c2, + const std::string& fname = "witness.out", + bool pack = true); +std::vector default_model_single(const std::vector& special, + smt_circuit::CircuitBase& c, + const std::string& fname = "witness.out", + bool pack = true); bool smt_timer(smt_solver::Solver* s); std::pair, std::vector> base4(uint32_t el); diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/circuit_builder_base.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/circuit_builder_base.hpp index 28b2c9ad8f6..b2eedbd74e5 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/circuit_builder_base.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/circuit_builder_base.hpp @@ -251,6 +251,7 @@ template struct CircuitSchemaInternal { std::vector>> lookup_tables; std::vector real_variable_tags; std::unordered_map range_tags; + bool circuit_finalized; MSGPACK_FIELDS(modulus, public_inps, vars_of_interest, diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/standard_circuit_builder.cpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/standard_circuit_builder.cpp index 09e74567e40..03effb88841 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/standard_circuit_builder.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/standard_circuit_builder.cpp @@ -565,6 +565,7 @@ template msgpack::sbuffer StandardCircuitBuilder_::export_circ cir.wires.push_back(arith_wires); cir.real_variable_index = this->real_variable_index; + cir.circuit_finalized = true; msgpack::sbuffer buffer; msgpack::pack(buffer, cir); diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_circuit_builder.cpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_circuit_builder.cpp index 29bf01cc42c..b49b3a7ad7c 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_circuit_builder.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_circuit_builder.cpp @@ -2956,6 +2956,8 @@ template msgpack::sbuffer UltraCircuitBuilder_circuit_finalized; + msgpack::sbuffer buffer; msgpack::pack(buffer, cir); return buffer;