Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(bb): use std::span in pippenger for scalars #8269

Merged
merged 8 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ int pippenger()
scalar_multiplication::pippenger_runtime_state<curve::BN254> state(NUM_POINTS);
std::chrono::steady_clock::time_point time_start = std::chrono::steady_clock::now();
g1::element result = scalar_multiplication::pippenger_unsafe<curve::BN254>(
&scalars[0], reference_string->get_monomial_points(), NUM_POINTS, state);
{ &scalars[0], NUM_POINTS }, reference_string->get_monomial_points(), NUM_POINTS, state);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this syntax construct a span from the localation of scalars[0] up to NUM_POINTS? I assume so. It'd be good if you could document this in whatever place you see best

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commented

std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
std::chrono::microseconds diff = std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_start);
std::cout << "run time: " << diff.count() << "us" << std::endl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ template <class Curve> class CommitmentKey {
ASSERT(false);
}
return scalar_multiplication::pippenger_unsafe<Curve>(
const_cast<Fr*>(polynomial.data()), srs->get_monomial_points(), degree, pippenger_runtime_state);
polynomial, srs->get_monomial_points(), degree, pippenger_runtime_state);
};

/**
Expand Down Expand Up @@ -146,7 +146,7 @@ template <class Curve> class CommitmentKey {

// Call the version of pippenger which assumes all points are distinct
return scalar_multiplication::pippenger_unsafe<Curve>(
scalars.data(), points.data(), scalars.size(), pippenger_runtime_state);
scalars, points.data(), scalars.size(), pippenger_runtime_state);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,13 @@ template <typename Curve_> class IPA {
// Step 6.a (using letters, because doxygen automaticall converts the sublist counters to letters :( )
// L_i = < a_vec_lo, G_vec_hi > + inner_prod_L * aux_generator
L_i = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&a_vec[0], &G_vec_local[round_size], round_size, ck->pippenger_runtime_state);
{&a_vec[0], round_size}, &G_vec_local[round_size], round_size, ck->pippenger_runtime_state);
L_i += aux_generator * inner_prod_L;

// Step 6.b
// R_i = < a_vec_hi, G_vec_lo > + inner_prod_R * aux_generator
R_i = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&a_vec[round_size], &G_vec_local[0], round_size, ck->pippenger_runtime_state);
{&a_vec[round_size], round_size}, &G_vec_local[0], round_size, ck->pippenger_runtime_state);
R_i += aux_generator * inner_prod_R;

// Step 6.c
Expand Down Expand Up @@ -345,7 +345,7 @@ template <typename Curve_> class IPA {
// Step 5.
// Compute C₀ = C' + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j
GroupElement LR_sums = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&msm_scalars[0], &msm_elements[0], pippenger_size, vk->pippenger_runtime_state);
{&msm_scalars[0], pippenger_size}, &msm_elements[0], pippenger_size, vk->pippenger_runtime_state);
GroupElement C_zero = C_prime + LR_sums;

// Step 6.
Expand Down Expand Up @@ -394,7 +394,7 @@ template <typename Curve_> class IPA {
// Step 8.
// Compute G₀
Commitment G_zero = bb::scalar_multiplication::pippenger_without_endomorphism_basis_points<Curve>(
&s_vec[0], &G_vec_local[0], poly_length, vk->pippenger_runtime_state);
{&s_vec[0], poly_length}, &G_vec_local[0], poly_length, vk->pippenger_runtime_state);

// Step 9.
// Receive a₀ from the prover
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ template <typename Curve>
void compute_wnaf_states(uint64_t* point_schedule,
bool* input_skew_table,
uint64_t* round_counts,
const typename Curve::ScalarField* scalars,
const std::span<const typename Curve::ScalarField> scalars,
const size_t num_initial_points)
{
using Fr = typename Curve::ScalarField;
Expand Down Expand Up @@ -857,21 +857,26 @@ typename Curve::Element evaluate_pippenger_rounds(pippenger_runtime_state<Curve>

template <typename Curve>
typename Curve::Element pippenger_internal(typename Curve::AffineElement* points,
typename Curve::ScalarField* scalars,
std::span<const typename Curve::ScalarField> scalars,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
bool handle_edge_cases)
{
size_t num_initial_points_power_2 = 1 << numeric::get_msb(num_initial_points);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a sente commenting these three lines

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These shouldn't be here actually right now, will remove

if (num_initial_points_power_2 != num_initial_points) {
num_initial_points_power_2 *= 2; // Round up
}
// multiplication_runtime_state state;
compute_wnaf_states<Curve>(state.point_schedule, state.skew_table, state.round_counts, scalars, num_initial_points);
organize_buckets(state.point_schedule, num_initial_points * 2);
compute_wnaf_states<Curve>(
state.point_schedule, state.skew_table, state.round_counts, scalars, num_initial_points_power_2);
organize_buckets(state.point_schedule, num_initial_points_power_2 * 2);
typename Curve::Element result =
evaluate_pippenger_rounds<Curve>(state, points, num_initial_points * 2, handle_edge_cases);
evaluate_pippenger_rounds<Curve>(state, points, num_initial_points_power_2 * 2, handle_edge_cases);
return result;
}

template <typename Curve>
typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
Expand Down Expand Up @@ -910,10 +915,9 @@ typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
const auto num_slice_points = static_cast<size_t>(1ULL << slice_bits);

Element result = pippenger_internal(points, scalars, num_slice_points, state, handle_edge_cases);

if (num_slice_points != num_initial_points) {
const uint64_t leftover_points = num_initial_points - num_slice_points;
return result + pippenger(scalars + num_slice_points,
return result + pippenger(scalars.subspan(num_slice_points),
points + static_cast<size_t>(num_slice_points * 2),
static_cast<size_t>(leftover_points),
state,
Expand All @@ -938,7 +942,7 @@ typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
*
**/
template <typename Curve>
typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger_unsafe(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state)
Expand All @@ -947,10 +951,11 @@ typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
}

template <typename Curve>
typename Curve::Element pippenger_without_endomorphism_basis_points(typename Curve::ScalarField* scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state)
typename Curve::Element pippenger_without_endomorphism_basis_points(
std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<Curve>& state)
{
std::vector<typename Curve::AffineElement> G_mod(num_initial_points * 2);
bb::scalar_multiplication::generate_pippenger_point_table<Curve>(points, &G_mod[0], num_initial_points);
Expand Down Expand Up @@ -978,7 +983,7 @@ template void evaluate_addition_chains<curve::BN254>(affine_product_runtime_stat
const size_t max_bucket_bits,
bool handle_edge_cases);
template curve::BN254::Element pippenger_internal<curve::BN254>(curve::BN254::AffineElement* points,
curve::BN254::ScalarField* scalars,
std::span<const curve::BN254::ScalarField> scalars,
const size_t num_initial_points,
pippenger_runtime_state<curve::BN254>& state,
bool handle_edge_cases);
Expand All @@ -992,19 +997,19 @@ template curve::BN254::AffineElement* reduce_buckets<curve::BN254>(affine_produc
bool first_round = true,
bool handle_edge_cases = false);

template curve::BN254::Element pippenger<curve::BN254>(curve::BN254::ScalarField* scalars,
template curve::BN254::Element pippenger<curve::BN254>(std::span<const curve::BN254::ScalarField> scalars,
curve::BN254::AffineElement* points,
const size_t num_points,
pippenger_runtime_state<curve::BN254>& state,
bool handle_edge_cases = true);

template curve::BN254::Element pippenger_unsafe<curve::BN254>(curve::BN254::ScalarField* scalars,
template curve::BN254::Element pippenger_unsafe<curve::BN254>(std::span<const curve::BN254::ScalarField> scalars,
curve::BN254::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::BN254>& state);

template curve::BN254::Element pippenger_without_endomorphism_basis_points<curve::BN254>(
curve::BN254::ScalarField* scalars,
std::span<const curve::BN254::ScalarField> scalars,
curve::BN254::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::BN254>& state);
Expand All @@ -1028,11 +1033,12 @@ template void add_affine_points_with_edge_cases<curve::Grumpkin>(curve::Grumpkin
template void evaluate_addition_chains<curve::Grumpkin>(affine_product_runtime_state<curve::Grumpkin>& state,
const size_t max_bucket_bits,
bool handle_edge_cases);
template curve::Grumpkin::Element pippenger_internal<curve::Grumpkin>(curve::Grumpkin::AffineElement* points,
curve::Grumpkin::ScalarField* scalars,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state,
bool handle_edge_cases);
template curve::Grumpkin::Element pippenger_internal<curve::Grumpkin>(
curve::Grumpkin::AffineElement* points,
std::span<const curve::Grumpkin::ScalarField> scalars,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state,
bool handle_edge_cases);

template curve::Grumpkin::Element evaluate_pippenger_rounds<curve::Grumpkin>(
pippenger_runtime_state<curve::Grumpkin>& state,
Expand All @@ -1043,19 +1049,20 @@ template curve::Grumpkin::Element evaluate_pippenger_rounds<curve::Grumpkin>(
template curve::Grumpkin::AffineElement* reduce_buckets<curve::Grumpkin>(
affine_product_runtime_state<curve::Grumpkin>& state, bool first_round = true, bool handle_edge_cases = false);

template curve::Grumpkin::Element pippenger<curve::Grumpkin>(curve::Grumpkin::ScalarField* scalars,
template curve::Grumpkin::Element pippenger<curve::Grumpkin>(std::span<const curve::Grumpkin::ScalarField> scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_points,
pippenger_runtime_state<curve::Grumpkin>& state,
bool handle_edge_cases = true);

template curve::Grumpkin::Element pippenger_unsafe<curve::Grumpkin>(curve::Grumpkin::ScalarField* scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state);
template curve::Grumpkin::Element pippenger_unsafe<curve::Grumpkin>(
std::span<const curve::Grumpkin::ScalarField> scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state);

template curve::Grumpkin::Element pippenger_without_endomorphism_basis_points<curve::Grumpkin>(
curve::Grumpkin::ScalarField* scalars,
std::span<const curve::Grumpkin::ScalarField> scalars,
curve::Grumpkin::AffineElement* points,
const size_t num_initial_points,
pippenger_runtime_state<curve::Grumpkin>& state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ template <typename Curve>
void compute_wnaf_states(uint64_t* point_schedule,
bool* input_skew_table,
uint64_t* round_counts,
const typename Curve::ScalarField* scalars,
std::span<const typename Curve::ScalarField> scalars,
size_t num_initial_points);

template <typename Curve>
Expand Down Expand Up @@ -135,7 +135,7 @@ void evaluate_addition_chains(affine_product_runtime_state<Curve>& state,
bool handle_edge_cases);
template <typename Curve>
typename Curve::Element pippenger_internal(typename Curve::AffineElement* points,
typename Curve::ScalarField* scalars,
std::span<const typename Curve::ScalarField> scalars,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
bool handle_edge_cases);
Expand All @@ -152,23 +152,24 @@ typename Curve::AffineElement* reduce_buckets(affine_product_runtime_state<Curve
bool handle_edge_cases = false);

template <typename Curve>
typename Curve::Element pippenger(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state,
bool handle_edge_cases = true);

template <typename Curve>
typename Curve::Element pippenger_unsafe(typename Curve::ScalarField* scalars,
typename Curve::Element pippenger_unsafe(std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state);

template <typename Curve>
typename Curve::Element pippenger_without_endomorphism_basis_points(typename Curve::ScalarField* scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state);
typename Curve::Element pippenger_without_endomorphism_basis_points(
std::span<const typename Curve::ScalarField> scalars,
typename Curve::AffineElement* points,
size_t num_initial_points,
pippenger_runtime_state<Curve>& state);

// Explicit instantiation
// BN254
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ template <typename program_settings> bool VerifierBase<program_settings>::verify

g1::element P[2];

P[0] = bb::scalar_multiplication::pippenger<curve::BN254>(&scalars[0], &elements[0], num_elements, state);
P[0] = bb::scalar_multiplication::pippenger<curve::BN254>(
{ &scalars[0], num_elements }, &elements[0], num_elements, state);
P[1] = -(g1::element(PI_Z_OMEGA) * separator_challenge + PI_Z);

if (key->contains_recursive_proof) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ plonk::Verifier generate_verifier(std::shared_ptr<proving_key> circuit_proving_k
commitments.resize(8);

for (size_t i = 0; i < 8; ++i) {
commitments[i] = g1::affine_element(
scalar_multiplication::pippenger<curve::BN254>(poly_coefficients[i].get(),
circuit_proving_key->reference_string->get_monomial_points(),
circuit_proving_key->circuit_size,
state));
commitments[i] = g1::affine_element(scalar_multiplication::pippenger<curve::BN254>(
{ poly_coefficients[i].get(), circuit_proving_key->circuit_size },
circuit_proving_key->reference_string->get_monomial_points(),
circuit_proving_key->circuit_size,
state));
}

auto crs = std::make_shared<bb::srs::factories::FileVerifierCrs<curve::BN254>>("../srs_db/ignition");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ void work_queue::process_queue()
// Run pippenger multi-scalar multiplication.
auto runtime_state = bb::scalar_multiplication::pippenger_runtime_state<curve::BN254>(msm_size);
bb::g1::affine_element result(bb::scalar_multiplication::pippenger_unsafe<curve::BN254>(
item.mul_scalars.get(), srs_points, msm_size, runtime_state));
{ item.mul_scalars.get(), msm_size }, srs_points, msm_size, runtime_state));

transcript->add_element(item.tag, result.to_buffer());

Expand Down
Loading
Loading