Skip to content

Commit

Permalink
fix: set denominator to 1 during verification of dsl/big-field divisi…
Browse files Browse the repository at this point in the history
…on (#5188)

This PR solve the verification issue with bigint division in Noir:
noir-lang/noir#4530

---------

Co-authored-by: Tom French <[email protected]>
Co-authored-by: Rumata888 <[email protected]>
  • Loading branch information
3 people authored Mar 19, 2024
1 parent ccc5016 commit 253d002
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ void build_constraints(Builder& builder, AcirFormat const& constraint_system, bo

// Add big_int constraints
DSLBigInts<Builder> dsl_bigints;
dsl_bigints.set_builder(&builder);
for (const auto& constraint : constraint_system.bigint_from_le_bytes_constraints) {
create_bigint_from_le_bytes_constraint(builder, constraint, dsl_bigints);
}
for (const auto& constraint : constraint_system.bigint_operations) {
create_bigint_operations_constraint<Builder>(constraint, dsl_bigints);
create_bigint_operations_constraint<Builder>(constraint, dsl_bigints, has_valid_witness_assignments);
}
for (const auto& constraint : constraint_system.bigint_to_le_bytes_constraints) {
create_bigint_to_le_bytes_constraint(builder, constraint, dsl_bigints);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "barretenberg/common/assert.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/numeric/uintx/uintx.hpp"
#include "barretenberg/stdlib/primitives/bigfield/bigfield.hpp"
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -34,14 +35,16 @@ ModulusId modulus_param_to_id(ModulusParam param)
secp256r1::FrParams::modulus_2 == param.modulus_2 && secp256r1::FrParams::modulus_3 == param.modulus_3) {
return ModulusId::SECP256R1_FR;
}

return ModulusId::UNKNOWN;
}

template void create_bigint_operations_constraint<UltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<UltraCircuitBuilder>& dsl_bigint);
DSLBigInts<UltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);
template void create_bigint_operations_constraint<GoblinUltraCircuitBuilder>(
const BigIntOperation& input, DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint);
const BigIntOperation& input,
DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);
template void create_bigint_addition_constraint<UltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<UltraCircuitBuilder>& dsl_bigint);
template void create_bigint_addition_constraint<GoblinUltraCircuitBuilder>(
Expand All @@ -55,9 +58,11 @@ template void create_bigint_mul_constraint<UltraCircuitBuilder>(const BigIntOper
template void create_bigint_mul_constraint<GoblinUltraCircuitBuilder>(
const BigIntOperation& input, DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint);
template void create_bigint_div_constraint<UltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<UltraCircuitBuilder>& dsl_bigint);
template void create_bigint_div_constraint<GoblinUltraCircuitBuilder>(
const BigIntOperation& input, DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint);
DSLBigInts<UltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);
template void create_bigint_div_constraint<GoblinUltraCircuitBuilder>(const BigIntOperation& input,
DSLBigInts<GoblinUltraCircuitBuilder>& dsl_bigint,
bool has_valid_witness_assignments);

template <typename Builder>
void create_bigint_addition_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigint)
Expand Down Expand Up @@ -198,8 +203,18 @@ void create_bigint_mul_constraint(const BigIntOperation& input, DSLBigInts<Build
}

template <typename Builder>
void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigint)
void create_bigint_div_constraint(const BigIntOperation& input,
DSLBigInts<Builder>& dsl_bigint,
bool has_valid_witness_assignments)
{
if (!has_valid_witness_assignments) {
// Asserts catch the case where the divisor is zero, so we need to provide a different value (1) to avoid the
// assert
std::array<uint32_t, 5> limbs_idx;
dsl_bigint.get_witness_idx_of_limbs(input.rhs, limbs_idx);
dsl_bigint.set_value(1, limbs_idx);
}

switch (dsl_bigint.get_modulus_id(input.lhs)) {
case ModulusId::BN254_FR: {
auto lhs = dsl_bigint.bn254_fr(input.lhs);
Expand Down Expand Up @@ -244,7 +259,9 @@ void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Build
}

template <typename Builder>
void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigint)
void create_bigint_operations_constraint(const BigIntOperation& input,
DSLBigInts<Builder>& dsl_bigint,
bool has_valid_witness_assignments)
{
switch (input.opcode) {
case BigIntOperationType::Add: {
Expand All @@ -260,7 +277,7 @@ void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInt
break;
}
case BigIntOperationType::Div: {
create_bigint_div_constraint<Builder>(input, dsl_bigint);
create_bigint_div_constraint<Builder>(input, dsl_bigint, has_valid_witness_assignments);
break;
}
default: {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/serialize/msgpack.hpp"

#include <array>
#include <cstdint>
#include <vector>

Expand Down Expand Up @@ -77,9 +79,13 @@ template <typename Builder> class DSLBigInts {
std::map<uint32_t, big_secp256r1_fq> m_secp256r1_fq;
std::map<uint32_t, big_secp256r1_fr> m_secp256r1_fr;

Builder* builder;

public:
DSLBigInts() = default;

void set_builder(Builder* ctx) { builder = ctx; }

ModulusId get_modulus_id(uint32_t bigint_id)
{
if (this->m_bn254_fq.contains(bigint_id)) {
Expand All @@ -104,6 +110,62 @@ template <typename Builder> class DSLBigInts {
return ModulusId::UNKNOWN;
}

/// Set value of the witnesses representing the bigfield element
/// so that the bigfield value is the input value.
/// The input value is decomposed into the binary basis for the binary limbs
/// The input array must be:
/// the 4 witness index of the binary limbs, and the index of the prime limb
void set_value(uint256_t value, const std::array<uint32_t, 5> limbs_idx)
{
uint256_t limb_modulus = uint256_t(1) << big_bn254_fq::NUM_LIMB_BITS;
builder->variables[limbs_idx[4]] = value;
for (uint32_t i = 0; i < 4; i++) {
uint256_t limb = value % limb_modulus;
value = (value - limb) / limb_modulus;
builder->variables[limbs_idx[i]] = limb;
}
}

/// Utility function that retrieve the witness indexes of a bigfield element
/// for use in set_value()
void get_witness_idx_of_limbs(uint32_t bigint_id, std::array<uint32_t, 5>& limbs_idx)
{
if (m_bn254_fr.contains(bigint_id)) {
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = m_bn254_fr[bigint_id].binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = m_bn254_fr[bigint_id].prime_basis_limb.witness_index;
} else if (m_bn254_fq.contains(bigint_id)) {
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = m_bn254_fq[bigint_id].binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = m_bn254_fq[bigint_id].prime_basis_limb.witness_index;
} else if (m_secp256k1_fq.contains(bigint_id)) {
auto big_field = m_secp256k1_fq[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
} else if (m_secp256k1_fr.contains(bigint_id)) {
auto big_field = m_secp256k1_fr[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
} else if (m_secp256r1_fr.contains(bigint_id)) {
auto big_field = m_secp256r1_fr[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
} else if (m_secp256r1_fq.contains(bigint_id)) {
auto big_field = m_secp256r1_fq[bigint_id];
for (uint32_t i = 0; i < 4; i++) {
limbs_idx[i] = big_field.binary_basis_limbs[i].element.witness_index;
}
limbs_idx[4] = big_field.prime_basis_limb.witness_index;
}
}
big_bn254_fr bn254_fr(uint32_t bigint_id)
{
if (this->m_bn254_fr.contains(bigint_id)) {
Expand Down Expand Up @@ -192,14 +254,14 @@ void create_bigint_to_le_bytes_constraint(Builder& builder,
DSLBigInts<Builder>& dsl_bigints);

template <typename Builder>
void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
void create_bigint_operations_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints, bool);
template <typename Builder>
void create_bigint_addition_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
template <typename Builder>
void create_bigint_sub_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
template <typename Builder>
void create_bigint_mul_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
template <typename Builder>
void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints);
void create_bigint_div_constraint(const BigIntOperation& input, DSLBigInts<Builder>& dsl_bigints, bool);

} // namespace acir_format
Original file line number Diff line number Diff line change
Expand Up @@ -392,4 +392,78 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse2)
EXPECT_EQ(verifier.verify_proof(proof), true);
}

TEST_F(BigIntTests, TestBigIntDIV)
{
// 6 / 3 = 2
// 6 = bigint(1) = from_bytes(w(1))
// 3 = bigint(2) = from_bytes(w(2))
// 2 = bigint(3) = to_bytes(w(3))
BigIntOperation div_constraint{
.lhs = 1,
.rhs = 2,
.result = 3,
.opcode = BigIntOperationType::Div,
};

BigIntFromLeBytes from_le_bytes_constraint_bigint1{
.inputs = { 1 },
.modulus = { 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA,
0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF },
.result = 1,
};
BigIntFromLeBytes from_le_bytes_constraint_bigint2{
.inputs = { 2 },
.modulus = { 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA,
0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF },
.result = 2,
};

BigIntToLeBytes result3_to_le_bytes{
.input = 3, .result = { 3 }, //
};

AcirFormat constraint_system{
.varnum = 5,
.recursive = false,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
.sha256_constraints = {},
.sha256_compression = {},
.schnorr_constraints = {},
.ecdsa_k1_constraints = {},
.ecdsa_r1_constraints = {},
.blake2s_constraints = {},
.blake3_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.poseidon2_constraints = {},
.fixed_base_scalar_mul_constraints = {},
.ec_add_constraints = {},
.recursion_constraints = {},
.bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1, from_le_bytes_constraint_bigint2 },
.bigint_to_le_bytes_constraints = { result3_to_le_bytes },
.bigint_operations = { div_constraint },
.constraints = {},
.block_constraints = {},

};

WitnessVector witness{
0, 6, 3, 2, 0,
};
auto builder = create_circuit(constraint_system, /*size_hint*/ 0, witness);
auto composer = Composer();
auto prover = composer.create_ultra_with_keccak_prover(builder);
auto proof = prover.construct_proof();
EXPECT_TRUE(CircuitChecker::check(builder));

auto builder2 = create_circuit(constraint_system, /*size_hint*/ 0, WitnessVector{});
EXPECT_TRUE(CircuitChecker::check(builder));
auto verifier2 = composer.create_ultra_with_keccak_verifier(builder);
EXPECT_EQ(verifier2.verify_proof(proof), true);
}
} // namespace acir_format::tests

0 comments on commit 253d002

Please sign in to comment.