Skip to content

Commit

Permalink
chore: Shared Permutation+Lookup relation arithmetic (AztecProtocol/b…
Browse files Browse the repository at this point in the history
…arretenberg#559)

Co-authored-by: ledwards2225 <[email protected]>
  • Loading branch information
zac-williamson and ledwards2225 authored Jul 21, 2023
1 parent 97c9132 commit 4638333
Show file tree
Hide file tree
Showing 23 changed files with 564 additions and 510 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ namespace proof_system::honk {
* @tparam Program settings needed to establish if w_4 is being used.
* */
template <StandardFlavor Flavor>
void StandardComposer_<Flavor>::compute_witness(const CircuitBuilder& circuit_constructor,
const size_t minimum_circuit_size)
void StandardComposer_<Flavor>::compute_witness(const CircuitBuilder& circuit_constructor, const size_t /*unused*/)
{
if (computed_witness) {
return;
Expand Down Expand Up @@ -72,7 +71,7 @@ std::shared_ptr<typename Flavor::ProvingKey> StandardComposer_<Flavor>::compute_
* */
template <StandardFlavor Flavor>
std::shared_ptr<typename Flavor::VerificationKey> StandardComposer_<Flavor>::compute_verification_key(
const CircuitBuilder& circuit_constructor)
const CircuitBuilder& /*unused*/)
{
if (verification_key) {
return verification_key;
Expand Down Expand Up @@ -124,7 +123,7 @@ StandardProver_<Flavor> StandardComposer_<Flavor>::create_prover(const CircuitBu
compute_proving_key(circuit_constructor);
compute_witness(circuit_constructor);

compute_commitment_key(proving_key->circuit_size, crs_factory_);
compute_commitment_key(proving_key->circuit_size);

StandardProver_<Flavor> output_state(proving_key, commitment_key);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ template <StandardFlavor Flavor> class StandardComposer_ {

void compute_witness(const CircuitBuilder& circuit_constructor, const size_t minimum_circuit_size = 0);

void compute_commitment_key(size_t circuit_size, std::shared_ptr<srs::factories::CrsFactory> crs_factory)
void compute_commitment_key(size_t circuit_size)
{
commitment_key = std::make_shared<typename PCSParams::CommitmentKey>(circuit_size, crs_factory_);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ void UltraComposer_<Flavor>::compute_circuit_size_parameters(CircuitBuilder& cir
lookups_size += table.lookup_gates.size();
}

const size_t num_gates = circuit_constructor.num_gates;
num_public_inputs = circuit_constructor.public_inputs.size();

// minimum circuit size due to the length of lookups plus tables
Expand Down Expand Up @@ -170,7 +169,7 @@ UltraProver_<Flavor> UltraComposer_<Flavor>::create_prover(CircuitBuilder& circu
compute_proving_key(circuit_constructor);
compute_witness(circuit_constructor);

compute_commitment_key(proving_key->circuit_size, crs_factory_);
compute_commitment_key(proving_key->circuit_size);

UltraProver_<Flavor> output_state(proving_key, commitment_key);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ template <UltraFlavor Flavor> class UltraComposer_ {

void add_table_column_selector_poly_to_proving_key(polynomial& small, const std::string& tag);

void compute_commitment_key(size_t circuit_size, std::shared_ptr<srs::factories::CrsFactory> crs_factory)
void compute_commitment_key(size_t circuit_size)
{
commitment_key = std::make_shared<typename PCSParams::CommitmentKey>(circuit_size, crs_factory_);
};
Expand Down
1 change: 1 addition & 0 deletions barretenberg/cpp/src/barretenberg/honk/flavor/standard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Standard {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 4;

using GrandProductRelations = std::tuple<sumcheck::PermutationRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::ArithmeticRelation<FF>, sumcheck::PermutationRelation<FF>>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class StandardGrumpkin {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 4;

// define the tuple of Relations that require grand products
using GrandProductRelations = std::tuple<sumcheck::PermutationRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::ArithmeticRelation<FF>, sumcheck::PermutationRelation<FF>>;

Expand Down
1 change: 1 addition & 0 deletions barretenberg/cpp/src/barretenberg/honk/flavor/ultra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Ultra {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 11;

using GrandProductRelations = std::tuple<sumcheck::UltraPermutationRelation<FF>, sumcheck::LookupRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::UltraArithmeticRelation<FF>,
sumcheck::UltraPermutationRelation<FF>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class UltraGrumpkin {
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 11;

using GrandProductRelations = std::tuple<sumcheck::UltraPermutationRelation<FF>, sumcheck::LookupRelation<FF>>;
// define the tuple of Relations that comprise the Sumcheck relation
using Relations = std::tuple<sumcheck::UltraArithmeticRelation<FF>,
sumcheck::UltraPermutationRelation<FF>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#pragma once
#include "barretenberg/honk/sumcheck/sumcheck.hpp"
#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/polynomials/polynomial.hpp"
#include <typeinfo>

namespace proof_system::honk::grand_product_library {

// TODO(luke): This contains utilities for grand product computation and is not specific to the permutation grand
// product. Update comments accordingly.
/**
* @brief Compute a permutation grand product polynomial Z_perm(X)
* *
* @details
* Z_perm may be defined in terms of its values on X_i = 0,1,...,n-1 as Z_perm[0] = 1 and for i = 1:n-1
* relation::numerator(j)
* Z_perm[i] = ∏ --------------------------------------------------------------------------------
* relation::denominator(j)
*
* where ∏ := ∏_{j=0:i-1}
*
* The specific algebraic relation used by Z_perm is defined by Flavor::GrandProductRelations
*
* For example, in Flavor::Standard the relation describes:
*
* (w_1(j) + β⋅id_1(j) + γ) ⋅ (w_2(j) + β⋅id_2(j) + γ) ⋅ (w_3(j) + β⋅id_3(j) + γ)
* Z_perm[i] = ∏ --------------------------------------------------------------------------------
* (w_1(j) + β⋅σ_1(j) + γ) ⋅ (w_2(j) + β⋅σ_2(j) + γ) ⋅ (w_3(j) + β⋅σ_3(j) + γ)
* where ∏ := ∏_{j=0:i-1} and id_i(X) = id(X) + n*(i-1)
*
* For Flavor::Ultra both the UltraPermutation and Lookup grand products are computed by this method.
*
* The grand product is constructed over the course of three steps.
*
* For expositional simplicity, write Z_perm[i] as
*
* A(j)
* Z_perm[i] = ∏ --------------------------
* B(h)
*
* Step 1) Compute 2 length-n polynomials A, B
* Step 2) Compute 2 length-n polynomials numerator = ∏ A(j), nenominator = ∏ B(j)
* Step 3) Compute Z_perm[i + 1] = numerator[i] / denominator[i] (recall: Z_perm[0] = 1)
*
* Note: Step (3) utilizes Montgomery batch inversion to replace n-many inversions with
*/
template <typename Flavor, typename GrandProdRelation>
void compute_grand_product(const size_t circuit_size,
auto& full_polynomials,
sumcheck::RelationParameters<typename Flavor::FF>& relation_parameters)
{
using FF = typename Flavor::FF;
using Polynomial = typename Flavor::Polynomial;
using ValueAccumTypes = typename GrandProdRelation::ValueAccumTypes;

// Allocate numerator/denominator polynomials that will serve as scratch space
// TODO(zac) we can re-use the permutation polynomial as the numerator polynomial. Reduces readability
Polynomial numerator = Polynomial{ circuit_size };
Polynomial denominator = Polynomial{ circuit_size };

// Step (1)
// Populate `numerator` and `denominator` with the algebra described by Relation
const size_t num_threads = circuit_size >= get_num_cpus_pow2() ? get_num_cpus_pow2() : 1;
const size_t block_size = circuit_size / num_threads;
parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx + 1) * block_size;
for (size_t i = start; i < end; ++i) {

typename Flavor::ClaimedEvaluations evaluations;
for (size_t k = 0; k < Flavor::NUM_ALL_ENTITIES; ++k) {
evaluations[k] = full_polynomials[k].size() > i ? full_polynomials[k][i] : 0;
}
numerator[i] = GrandProdRelation::template compute_grand_product_numerator<ValueAccumTypes>(
evaluations, relation_parameters, i);
denominator[i] = GrandProdRelation::template compute_grand_product_denominator<ValueAccumTypes>(
evaluations, relation_parameters, i);
}
});

// Step (2)
// Compute the accumulating product of the numerator and denominator terms.
// This step is split into three parts for efficient multithreading:
// (i) compute ∏ A(j), ∏ B(j) subproducts for each thread
// (ii) compute scaling factor required to convert each subproduct into a single running product
// (ii) combine subproducts into a single running product
//
// For example, consider 4 threads and a size-8 numerator { a0, a1, a2, a3, a4, a5, a6, a7 }
// (i) Each thread computes 1 element of N = {{ a0, a0a1 }, { a2, a2a3 }, { a4, a4a5 }, { a6, a6a7 }}
// (ii) Take partial products P = { 1, a0a1, a2a3, a4a5 }
// (iii) Each thread j computes N[i][j]*P[j]=
// {{a0,a0a1},{a0a1a2,a0a1a2a3},{a0a1a2a3a4,a0a1a2a3a4a5},{a0a1a2a3a4a5a6,a0a1a2a3a4a5a6a7}}
std::vector<FF> partial_numerators(num_threads);
std::vector<FF> partial_denominators(num_threads);

parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx + 1) * block_size;
for (size_t i = start; i < end - 1; ++i) {
numerator[i + 1] *= numerator[i];
denominator[i + 1] *= denominator[i];
}
partial_numerators[thread_idx] = numerator[end - 1];
partial_denominators[thread_idx] = denominator[end - 1];
});

parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx + 1) * block_size;
if (thread_idx > 0) {
FF numerator_scaling = 1;
FF denominator_scaling = 1;

for (size_t j = 0; j < thread_idx; ++j) {
numerator_scaling *= partial_numerators[j];
denominator_scaling *= partial_denominators[j];
}
for (size_t i = start; i < end; ++i) {
numerator[i] *= numerator_scaling;
denominator[i] *= denominator_scaling;
}
}

// Final step: invert denominator
FF::batch_invert(std::span{ &denominator[start], block_size });
});

// Step (3) Compute z_perm[i] = numerator[i] / denominator[i]
auto& grand_product_polynomial = GrandProdRelation::get_grand_product_polynomial(full_polynomials);
grand_product_polynomial[0] = 0;
parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx == num_threads - 1) ? circuit_size - 1 : (thread_idx + 1) * block_size;
for (size_t i = start; i < end; ++i) {
grand_product_polynomial[i + 1] = numerator[i] * denominator[i];
}
});
}

template <typename Flavor>
void compute_grand_products(std::shared_ptr<typename Flavor::ProvingKey>& key,
typename Flavor::ProverPolynomials& full_polynomials,
sumcheck::RelationParameters<typename Flavor::FF>& relation_parameters)
{
using GrandProductRelations = typename Flavor::GrandProductRelations;
using FF = typename Flavor::FF;

constexpr size_t NUM_RELATIONS = std::tuple_size<GrandProductRelations>{};
barretenberg::constexpr_for<0, NUM_RELATIONS, 1>([&]<size_t i>() {
using GrandProdRelation = typename std::tuple_element<i, GrandProductRelations>::type;

// Assign the grand product polynomial to the relevant std::span member of `full_polynomials` (and its shift)
// For example, for UltraPermutationRelation, this will be `full_polynomials.z_perm`
// For example, for LookupRelation, this will be `full_polynomials.z_lookup`
std::span<FF>& full_polynomial = GrandProdRelation::get_grand_product_polynomial(full_polynomials);
auto& key_polynomial = GrandProdRelation::get_grand_product_polynomial(*key);
full_polynomial = key_polynomial;

compute_grand_product<Flavor, GrandProdRelation>(key->circuit_size, full_polynomials, relation_parameters);
std::span<FF>& full_polynomial_shift =
GrandProdRelation::get_shifted_grand_product_polynomial(full_polynomials);
full_polynomial_shift = key_polynomial.shifted();
});
}

} // namespace proof_system::honk::grand_product_library
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "prover.hpp"
#include "barretenberg/honk/proof_system/grand_product_library.hpp"
#include "barretenberg/honk/proof_system/prover_library.hpp"
#include "barretenberg/honk/sumcheck/sumcheck.hpp"
#include "barretenberg/honk/transcript/transcript.hpp"
Expand Down Expand Up @@ -112,12 +113,9 @@ template <StandardFlavor Flavor> void StandardProver_<Flavor>::execute_grand_pro
.public_input_delta = public_input_delta,
};

key->z_perm = prover_library::compute_permutation_grand_product<Flavor>(key, beta, gamma);
grand_product_library::compute_grand_products<Flavor>(key, prover_polynomials, relation_parameters);

queue.add_commitment(key->z_perm, commitment_labels.z_perm);

prover_polynomials.z_perm = key->z_perm;
prover_polynomials.z_perm_shift = key->z_perm.shifted();
}

/**
Expand Down
Loading

0 comments on commit 4638333

Please sign in to comment.