Skip to content

Commit

Permalink
template recursive verifier on goblin flag
Browse files Browse the repository at this point in the history
  • Loading branch information
ledwards2225 committed Aug 28, 2023
1 parent ce56534 commit 90c3b19
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ template <typename Curve> class GeminiProver_ {
const Fr& r_challenge);
}; // namespace proof_system::honk::pcs::gemini

template <typename Curve> class GeminiVerifier_ {
template <typename Curve, bool goblin_flag = false> class GeminiVerifier_ {
using Fr = typename Curve::ScalarField;
using GroupElement = typename Curve::Element;
using Commitment = typename Curve::AffineElement;
Expand Down Expand Up @@ -242,8 +242,8 @@ template <typename Curve> class GeminiVerifier_ {
std::vector<GroupElement> commitments = {batched_f, batched_g};
auto one = Fr::from_witness(r.get_context(), 1);
// Note: these batch muls are not optimal since we are performing a mul by 1.
C0_r_pos = GroupElement::batch_mul(commitments, {one, r_inv});
C0_r_neg = GroupElement::batch_mul(commitments, {one, -r_inv});
C0_r_pos = GroupElement::template batch_mul<goblin_flag>(commitments, {one, r_inv});
C0_r_neg = GroupElement::template batch_mul<goblin_flag>(commitments, {one, -r_inv});
} else {
C0_r_pos = batched_f;
C0_r_neg = batched_f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ template <typename Curve> class ShplonkProver_ {
* @brief Shplonk Verifier
*
*/
template <typename Curve> class ShplonkVerifier_ {
template <typename Curve, bool goblin_flag = false> class ShplonkVerifier_ {
using Fr = typename Curve::ScalarField;
using GroupElement = typename Curve::Element;
using Commitment = typename Curve::AffineElement;
Expand Down Expand Up @@ -233,7 +233,7 @@ template <typename Curve> class ShplonkVerifier_ {
scalars.emplace_back(G_commitment_constant);

// [G] += G₀⋅[1] = [G] + (∑ⱼ ρʲ ⋅ vⱼ / ( r − xⱼ ))⋅[1]
G_commitment = GroupElement::batch_mul(commitments, scalars);
G_commitment = GroupElement::template batch_mul<goblin_flag>(commitments, scalars);

} else {
evaluation_zero = Fr(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ template <class Composer, class Fq, class Fr, class NativeGroup> class element {
// only works with Plookup!
template <size_t max_num_bits = 0>
static element wnaf_batch_mul(const std::vector<element>& points, const std::vector<Fr>& scalars);
template <bool use_goblin = false>
static element batch_mul(const std::vector<element>& points,
const std::vector<Fr>& scalars,
const size_t max_num_bits = 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,51 +597,55 @@ std::pair<element<C, Fq, Fr, G>, element<C, Fq, Fr, G>> element<C, Fq, Fr, G>::c
* scalars See `bn254_endo_batch_mul` for description of algorithm
**/
template <typename C, class Fq, class Fr, class G>
template <bool use_goblin>
element<C, Fq, Fr, G> element<C, Fq, Fr, G>::batch_mul(const std::vector<element>& points,
const std::vector<Fr>& scalars,
const size_t max_num_bits)
{

const size_t num_points = points.size();
ASSERT(scalars.size() == num_points);
batch_lookup_table point_table(points);
const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits;

std::vector<std::vector<bool_t<C>>> naf_entries;
for (size_t i = 0; i < num_points; ++i) {
naf_entries.emplace_back(compute_naf(scalars[i], max_num_bits));
}
const auto offset_generators = compute_offset_generators(num_rounds);
element accumulator =
element::chain_add_end(element::chain_add(offset_generators.first, point_table.get_chain_initial_entry()));

constexpr size_t num_rounds_per_iteration = 4;
size_t num_iterations = num_rounds / num_rounds_per_iteration;
num_iterations += ((num_iterations * num_rounds_per_iteration) == num_rounds) ? 0 : 1;
const size_t num_rounds_per_final_iteration = (num_rounds - 1) - ((num_iterations - 1) * num_rounds_per_iteration);
for (size_t i = 0; i < num_iterations; ++i) {

std::vector<bool_t<C>> nafs(num_points);
std::vector<element::chain_add_accumulator> to_add;
const size_t inner_num_rounds =
(i != num_iterations - 1) ? num_rounds_per_iteration : num_rounds_per_final_iteration;
for (size_t j = 0; j < inner_num_rounds; ++j) {
for (size_t k = 0; k < num_points; ++k) {
nafs[k] = (naf_entries[k][i * num_rounds_per_iteration + j + 1]);
if constexpr (use_goblin) {
return goblin_batch_mul(points, scalars);
} else {
const size_t num_points = points.size();
ASSERT(scalars.size() == num_points);
batch_lookup_table point_table(points);
const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits;

std::vector<std::vector<bool_t<C>>> naf_entries;
for (size_t i = 0; i < num_points; ++i) {
naf_entries.emplace_back(compute_naf(scalars[i], max_num_bits));
}
const auto offset_generators = compute_offset_generators(num_rounds);
element accumulator =
element::chain_add_end(element::chain_add(offset_generators.first, point_table.get_chain_initial_entry()));

constexpr size_t num_rounds_per_iteration = 4;
size_t num_iterations = num_rounds / num_rounds_per_iteration;
num_iterations += ((num_iterations * num_rounds_per_iteration) == num_rounds) ? 0 : 1;
const size_t num_rounds_per_final_iteration = (num_rounds - 1) - ((num_iterations - 1) * num_rounds_per_iteration);
for (size_t i = 0; i < num_iterations; ++i) {

std::vector<bool_t<C>> nafs(num_points);
std::vector<element::chain_add_accumulator> to_add;
const size_t inner_num_rounds =
(i != num_iterations - 1) ? num_rounds_per_iteration : num_rounds_per_final_iteration;
for (size_t j = 0; j < inner_num_rounds; ++j) {
for (size_t k = 0; k < num_points; ++k) {
nafs[k] = (naf_entries[k][i * num_rounds_per_iteration + j + 1]);
}
to_add.emplace_back(point_table.get_chain_add_accumulator(nafs));
}
to_add.emplace_back(point_table.get_chain_add_accumulator(nafs));
accumulator = accumulator.multiple_montgomery_ladder(to_add);
}
accumulator = accumulator.multiple_montgomery_ladder(to_add);
}
for (size_t i = 0; i < num_points; ++i) {
element skew = accumulator - points[i];
Fq out_x = accumulator.x.conditional_select(skew.x, naf_entries[i][num_rounds]);
Fq out_y = accumulator.y.conditional_select(skew.y, naf_entries[i][num_rounds]);
accumulator = element(out_x, out_y);
}
accumulator = accumulator - offset_generators.second;
for (size_t i = 0; i < num_points; ++i) {
element skew = accumulator - points[i];
Fq out_x = accumulator.x.conditional_select(skew.x, naf_entries[i][num_rounds]);
Fq out_y = accumulator.y.conditional_select(skew.y, naf_entries[i][num_rounds]);
accumulator = element(out_x, out_y);
}
accumulator = accumulator - offset_generators.second;

return accumulator;
return accumulator;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,26 @@
#include "barretenberg/numeric/bitop/get_msb.hpp"

namespace proof_system::plonk::stdlib::recursion::honk {
template <typename Flavor>
UltraRecursiveVerifier_<Flavor>::UltraRecursiveVerifier_(Builder* builder,
std::shared_ptr<typename Flavor::VerificationKey> verifier_key)

template <typename Flavor, bool goblin_flag>
UltraRecursiveVerifier_<Flavor, goblin_flag>::UltraRecursiveVerifier_(Builder* builder,
std::shared_ptr<VerificationKey> verifier_key)
: key(verifier_key)
, builder(builder)
{}

template <typename Flavor>
UltraRecursiveVerifier_<Flavor>::UltraRecursiveVerifier_(UltraRecursiveVerifier_&& other) noexcept
: key(std::move(other.key))
, pcs_verification_key(std::move(other.pcs_verification_key))
{}

template <typename Flavor>
UltraRecursiveVerifier_<Flavor>& UltraRecursiveVerifier_<Flavor>::operator=(UltraRecursiveVerifier_&& other) noexcept
{
key = other.key;
pcs_verification_key = (std::move(other.pcs_verification_key));
commitments.clear();
pcs_fr_elements.clear();
return *this;
}

/**
* @brief This function constructs a recursive verifier circuit for an Ultra Honk proof of a given flavor.
*
*/
template <typename Flavor>
std::array<typename Flavor::GroupElement, 2> UltraRecursiveVerifier_<Flavor>::verify_proof(const plonk::proof& proof, bool use_goblin)
template <typename Flavor, bool goblin_flag>
std::array<typename Flavor::GroupElement, 2> UltraRecursiveVerifier_<Flavor, goblin_flag>::verify_proof(
const plonk::proof& proof)
{
using FF = typename Flavor::FF;
using GroupElement = typename Flavor::GroupElement;
using Commitment = typename Flavor::Commitment;
using Sumcheck = ::proof_system::honk::sumcheck::SumcheckVerifier<Flavor>;
using Curve = typename Flavor::Curve;
using Gemini = ::proof_system::honk::pcs::gemini::GeminiVerifier_<Curve>;
using Shplonk = ::proof_system::honk::pcs::shplonk::ShplonkVerifier_<Curve>;
using Gemini = ::proof_system::honk::pcs::gemini::GeminiVerifier_<Curve, goblin_flag>;
using Shplonk = ::proof_system::honk::pcs::shplonk::ShplonkVerifier_<Curve, goblin_flag>;
using PCS = typename Flavor::PCS; // note: This can only be KZG
using VerifierCommitments = typename Flavor::VerifierCommitments;
using CommitmentLabels = typename Flavor::CommitmentLabels;
Expand Down Expand Up @@ -163,22 +146,14 @@ std::array<typename Flavor::GroupElement, 2> UltraRecursiveVerifier_<Flavor>::ve
scalars_unshifted[0] = FF::from_witness(builder, 1);

// Batch the commitments to the unshifted and to-be-shifted polynomials using powers of rho
GroupElement batched_commitment_unshifted;
if (use_goblin) {
batched_commitment_unshifted = GroupElement::goblin_batch_mul(commitments.get_unshifted(), scalars_unshifted);
} else {
batched_commitment_unshifted = GroupElement::batch_mul(commitments.get_unshifted(), scalars_unshifted);
}
auto batched_commitment_unshifted =
GroupElement::template batch_mul<goblin_flag>(commitments.get_unshifted(), scalars_unshifted);

info("Batch mul (unshifted): num gates = ", builder->get_num_gates() - prev_num_gates);
prev_num_gates = builder->get_num_gates();

GroupElement batched_commitment_to_be_shifted;
if (use_goblin) {
batched_commitment_to_be_shifted = GroupElement::goblin_batch_mul(commitments.get_to_be_shifted(), scalars_to_be_shifted);
} else {
batched_commitment_to_be_shifted = GroupElement::batch_mul(commitments.get_to_be_shifted(), scalars_to_be_shifted);
}
auto batched_commitment_to_be_shifted =
GroupElement::template batch_mul<goblin_flag>(commitments.get_to_be_shifted(), scalars_to_be_shifted);

info("Batch mul (to-be-shited): num gates = ", builder->get_num_gates() - prev_num_gates);
prev_num_gates = builder->get_num_gates();
Expand All @@ -201,14 +176,18 @@ std::array<typename Flavor::GroupElement, 2> UltraRecursiveVerifier_<Flavor>::ve
info("Shplonk: num gates = ", builder->get_num_gates() - prev_num_gates);
prev_num_gates = builder->get_num_gates();

info("Total: num gates = ", builder->get_num_gates());

// Constuct the inputs to the final KZG pairing check
auto pairing_points = PCS::compute_pairing_points(shplonk_claim, transcript);

info("KZG: num gates = ", builder->get_num_gates() - prev_num_gates);

info("Total: num gates = ", builder->get_num_gates());

return pairing_points;
}

template class UltraRecursiveVerifier_<proof_system::honk::flavor::UltraRecursive>;
using UltraRecursiveFlavor = proof_system::honk::flavor::UltraRecursive;
template class UltraRecursiveVerifier_<UltraRecursiveFlavor, /*goblin_flag*/ false>;
template class UltraRecursiveVerifier_<UltraRecursiveFlavor, /*goblin_flag*/ true>;

} // namespace proof_system::plonk::stdlib::recursion::honk
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,26 @@
#include "barretenberg/stdlib/recursion/honk/transcript/transcript.hpp"

namespace proof_system::plonk::stdlib::recursion::honk {
template <typename Flavor> class UltraRecursiveVerifier_ {
template <typename Flavor, bool goblin_flag = false> class UltraRecursiveVerifier_ {
using FF = typename Flavor::FF;
using Commitment = typename Flavor::Commitment;
using GroupElement = typename Flavor::GroupElement;
using VerificationKey = typename Flavor::VerificationKey;
using VerifierCommitmentKey = typename Flavor::VerifierCommitmentKey;
using Builder = typename Flavor::CircuitBuilder;
using PairingPoints = std::array<typename Flavor::GroupElement, 2>;
using PairingPoints = std::array<GroupElement, 2>;

public:
explicit UltraRecursiveVerifier_(Builder* builder, std::shared_ptr<VerificationKey> verifier_key = nullptr);
UltraRecursiveVerifier_(UltraRecursiveVerifier_&& other) noexcept;
UltraRecursiveVerifier_(UltraRecursiveVerifier_&& other) = delete;
UltraRecursiveVerifier_(const UltraRecursiveVerifier_& other) = delete;
UltraRecursiveVerifier_& operator=(const UltraRecursiveVerifier_& other) = delete;
UltraRecursiveVerifier_& operator=(UltraRecursiveVerifier_&& other) noexcept;
UltraRecursiveVerifier_& operator=(UltraRecursiveVerifier_&& other) = delete;
~UltraRecursiveVerifier_() = default;

// TODO(luke): Eventually this will return something like aggregation_state but I'm simplifying for now until we
// determine the exact interface. Simply returns the two pairing points.
PairingPoints verify_proof(const plonk::proof& proof, bool use_goblin = false);
PairingPoints verify_proof(const plonk::proof& proof);

std::shared_ptr<VerificationKey> key;
std::map<std::string, Commitment> commitments;
Expand All @@ -35,8 +37,10 @@ template <typename Flavor> class UltraRecursiveVerifier_ {
Transcript<Builder> transcript;
};

extern template class UltraRecursiveVerifier_<proof_system::honk::flavor::UltraRecursive>;
extern template class UltraRecursiveVerifier_<proof_system::honk::flavor::UltraRecursive, /*goblin_flag*/ false>;
extern template class UltraRecursiveVerifier_<proof_system::honk::flavor::UltraRecursive, /*goblin_flag*/ true>;

using UltraRecursiveVerifier = UltraRecursiveVerifier_<proof_system::honk::flavor::UltraRecursive>;
using UltraRecursiveVerifier =
UltraRecursiveVerifier_<proof_system::honk::flavor::UltraRecursive, /*goblin_flag*/ false>;

} // namespace proof_system::plonk::stdlib::recursion::honk
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ namespace proof_system::plonk::stdlib::recursion::honk {

template <typename UseGoblinFlag> class RecursiveVerifierTest : public testing::Test {

static constexpr bool use_goblin_flag = UseGoblinFlag::value;
static constexpr bool goblin_flag = UseGoblinFlag::value;

using InnerComposer = ::proof_system::honk::UltraComposer;
using InnerBuilder = typename InnerComposer::CircuitBuilder;

using OuterBuilder = ::proof_system::UltraCircuitBuilder;

using NativeVerifier = ::proof_system::honk::UltraVerifier_<::proof_system::honk::flavor::Ultra>;
using RecursiveVerifier = UltraRecursiveVerifier_<::proof_system::honk::flavor::UltraRecursive>;
using RecursiveVerifier = UltraRecursiveVerifier_<::proof_system::honk::flavor::UltraRecursive, goblin_flag>;
using VerificationKey = ::proof_system::honk::flavor::UltraRecursive::VerificationKey;

using inner_curve = bn254<InnerBuilder>;
Expand Down Expand Up @@ -117,7 +117,7 @@ template <typename UseGoblinFlag> class RecursiveVerifierTest : public testing::

// Instantiate the recursive verifier and construct the recusive verification circuit
RecursiveVerifier verifier(&outer_builder, verification_key);
auto pairing_points = verifier.verify_proof(proof_to_recursively_verify, use_goblin_flag);
auto pairing_points = verifier.verify_proof(proof_to_recursively_verify);

// For testing purposes only, perform native verification and compare the result
auto native_verifier = inner_composer.create_verifier(inner_circuit);
Expand Down

0 comments on commit 90c3b19

Please sign in to comment.