Skip to content

Commit

Permalink
feat: add parallelize_in helper function (#46)
Browse files Browse the repository at this point in the history
Multi-threading of witness generation is tricky because one has to
ensure the circuit column assignment order stays deterministic. To
provide a good developer experience and help avoid pitfalls, we add a
new helper function, `parallelize_in`, for this.

Co-authored-by: Jonathan Wang <[email protected]>
  • Loading branch information
jonathanpwang and jonathanpwang committed May 23, 2023
1 parent 805a21c commit 0fff063
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 73 deletions.
1 change: 1 addition & 0 deletions halo2-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ num-traits = "0.2"
rand_chacha = "0.3"
rustc-hash = "1.1"
ff = "0.12"
rayon = "1.6.1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
log = "0.4"
Expand Down
3 changes: 3 additions & 0 deletions halo2-base/src/gates/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ use std::{
env::{set_var, var},
};

mod parallelize;
pub use parallelize::*;

/// Vector of thread advice column break points
pub type ThreadBreakPoints = Vec<usize>;
/// Vector of vectors tracking the thread break points across different halo2 phases
Expand Down
38 changes: 38 additions & 0 deletions halo2-base/src/gates/builder/parallelize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use itertools::Itertools;
use rayon::prelude::*;

use crate::{utils::ScalarField, Context};

use super::GateThreadBuilder;

/// Runs `f` once per element of `input`, in parallel, giving every call its own freshly
/// created [`Context`], and then registers those contexts with `builder` under `phase`.
///
/// * `phase`: index into `builder.threads` selecting where the new contexts are stored.
/// * `builder`: supplies the `witness_gen_only` flag and the new thread ids, and receives
///   the contexts created here.
/// * `input`: one item per parallel task; each item is moved into its call of `f`.
/// * `f`: the work to perform for a single item inside its dedicated [`Context`].
///
/// Returns the results of `f`, collected into a `Vec` with one entry per element of `input`.
///
/// All thread ids are requested from `builder` sequentially *before* the parallel section
/// starts, and the new contexts are appended to `builder.threads[phase]` as one batch
/// afterwards, so the order of the stored contexts does not depend on thread scheduling.
pub fn parallelize_in<F, T, R, FR>(
    phase: usize,
    builder: &mut GateThreadBuilder<F>,
    input: Vec<T>,
    f: FR,
) -> Vec<R>
where
    F: ScalarField,
    T: Send,
    R: Send,
    FR: Fn(&mut Context<F>, T) -> R + Send + Sync,
{
    let witness_gen_only = builder.witness_gen_only();
    // Reserve one id per input item up front: handing out ids from inside the parallel
    // section would make the id <-> item pairing depend on scheduling.
    let thread_ids = (0..input.len()).map(|_| builder.get_new_thread_id()).collect_vec();
    let (results, new_threads): (Vec<_>, Vec<_>) = thread_ids
        .into_par_iter()
        .zip(input.into_par_iter())
        .map(|(thread_id, item)| {
            // every task works in its own brand-new context
            let mut thread = Context::new(witness_gen_only, thread_id);
            let result = f(&mut thread, item);
            (result, thread)
        })
        .unzip();
    // The new threads are stored as a single batch in a FIXED order; otherwise the later
    // `assign_threads_in` will get confused.
    builder.threads[phase].extend(new_threads);

    results
}
31 changes: 15 additions & 16 deletions halo2-ecc/src/ecc/fixed_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip};
use crate::fields::{FieldChip, PrimeField, Selectable};
use group::Curve;
use halo2_base::gates::builder::GateThreadBuilder;
use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder};
use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context};
use itertools::Itertools;
use rayon::prelude::*;
Expand Down Expand Up @@ -107,6 +107,7 @@ where
curr_point.unwrap()
}

/* To reduce total amount of code, just always use msm_par below.
// basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation
// we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO)
pub fn msm<F, FC, C>(
Expand Down Expand Up @@ -212,6 +213,7 @@ where
.collect_vec();
chip.sum::<C>(ctx, scalar_mults)
}
*/

/// # Assumptions
/// * `points.len() = scalars.len()`
Expand Down Expand Up @@ -269,25 +271,23 @@ where
C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);

let field_chip = chip.field_chip();
let witness_gen_only = builder.witness_gen_only();

let zero = builder.main(phase).load_zero();
let thread_ids = (0..scalars.len()).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>();
let (new_threads, scalar_mults): (Vec<_>, Vec<_>) = cached_points_affine
.par_chunks(cached_points_affine.len() / points.len())
.zip_eq(scalars.into_par_iter())
.zip(thread_ids.into_par_iter())
.map(|((cached_points, scalar), thread_id)| {
let mut thread = Context::new(witness_gen_only, thread_id);
let ctx = &mut thread;

let scalar_mults = parallelize_in(
phase,
builder,
cached_points_affine
.chunks(cached_points_affine.len() / points.len())
.zip_eq(scalars)
.collect(),
|ctx, (cached_points, scalar)| {
let cached_points = cached_points
.iter()
.map(|point| chip.assign_constant_point(ctx, *point))
.collect_vec();
let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();

debug_assert_eq!(scalar.len(), scalar_len);
assert_eq!(scalar.len(), scalar_len);
let bits = scalar
.into_iter()
.flat_map(|scalar_chunk| {
Expand Down Expand Up @@ -319,9 +319,8 @@ where
field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window)
};
}
(thread, curr_point.unwrap())
})
.unzip();
builder.threads[phase].extend(new_threads);
curr_point.unwrap()
},
);
chip.sum::<C>(builder.main(phase), scalar_mults)
}
25 changes: 10 additions & 15 deletions halo2-ecc/src/ecc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> {
self.field_chip.assert_equal(ctx, P.y, Q.y);
}

/// None of the elements in `points` can be the point at infinity.
pub fn sum<C>(
&self,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -1153,21 +1154,15 @@ impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> {
#[cfg(feature = "display")]
println!("computing length {} fixed base msm", points.len());

// heuristic to decide when to use parallelism
if points.len() < 25 {
let ctx = builder.main(phase);
fixed_base::msm(self, ctx, points, scalars, max_scalar_bits_per_cell, clump_factor)
} else {
fixed_base::msm_par(
self,
builder,
points,
scalars,
max_scalar_bits_per_cell,
clump_factor,
phase,
)
}
fixed_base::msm_par(
self,
builder,
points,
scalars,
max_scalar_bits_per_cell,
clump_factor,
phase,
)

// Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator`
// Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4
Expand Down
69 changes: 27 additions & 42 deletions halo2-ecc/src/ecc/pippenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ use crate::{
fields::{FieldChip, PrimeField, Selectable},
};
use halo2_base::{
gates::{builder::GateThreadBuilder, GateInstructions},
gates::{
builder::{parallelize_in, GateThreadBuilder},
GateInstructions,
},
utils::CurveAffineExt,
AssignedValue, Context,
AssignedValue,
};
use rayon::prelude::*;

// Reference: https://jbootle.github.io/Misc/pippenger.pdf

Expand Down Expand Up @@ -238,7 +240,6 @@ where

// get a main thread
let ctx = builder.main(phase);
let witness_gen_only = ctx.witness_gen_only();
// single-threaded computation:
for scalar in scalars {
for (scalar_chunk, bool_chunk) in
Expand All @@ -250,32 +251,28 @@ where
}
}
}
// see multi-product comments for explanation of below

let c = clump_factor;
let num_rounds = (points.len() + c - 1) / c;
// to avoid adding two points that are equal or negative of each other,
// we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness
// note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints
// we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do
let any_base = load_random_point::<F, FC, C>(chip, ctx);
let mut any_points = Vec::with_capacity(num_rounds);
any_points.push(any_base);
for _ in 1..num_rounds {
any_points.push(ec_double(chip, ctx, any_points.last().unwrap()));
}
// we will use a different thread per round
// to prevent concurrency issues with context id, we generate all the ids first
let thread_ids = (0..num_rounds).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>();
// now begins multi-threading

// now begins multi-threading
// multi_prods is 2d vector of size `num_rounds` by `scalar_bits`
let (new_threads, multi_prods): (Vec<_>, Vec<_>) = points
.par_chunks(c)
.zip(any_points.par_iter())
.zip(thread_ids.into_par_iter())
.enumerate()
.map(|(round, ((points_clump, any_point), thread_id))| {
let multi_prods = parallelize_in(
phase,
builder,
points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(),
|ctx, (round, (points_clump, any_point))| {
// compute all possible multi-products of elements in points[round * c .. (round + 1) * c]
// create new thread
let mut thread = Context::new(witness_gen_only, thread_id);
let ctx = &mut thread;
// stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... }
let mut bucket = Vec::with_capacity(1 << c);
let any_point = into_strict_point(chip, ctx, any_point.clone());
Expand All @@ -294,7 +291,7 @@ where
bucket.push(new_point);
}
}
let multi_prods = bool_scalars
bool_scalars
.iter()
.map(|bits| {
strict_ec_select_from_bits(
Expand All @@ -304,31 +301,19 @@ where
&bits[round * c..round * c + points_clump.len()],
)
})
.collect::<Vec<_>>();

(thread, multi_prods)
})
.unzip();
// we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused
builder.threads[phase].extend(new_threads);
.collect::<Vec<_>>()
},
);

// agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits
let thread_ids = (0..scalar_bits).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>();
let (new_threads, mut agg): (Vec<_>, Vec<_>) = thread_ids
.into_par_iter()
.enumerate()
.map(|(i, thread_id)| {
let mut thread = Context::new(witness_gen_only, thread_id);
let ctx = &mut thread;
let mut acc = multi_prods[0][i].clone();
for multi_prod in multi_prods.iter().skip(1) {
let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true);
acc = into_strict_point(chip, ctx, _acc);
}
(thread, acc)
})
.unzip();
builder.threads[phase].extend(new_threads);
let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| {
let mut acc = multi_prods[0][i].clone();
for multi_prod in multi_prods.iter().skip(1) {
let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true);
acc = into_strict_point(chip, ctx, _acc);
}
acc
});

// gets the LAST thread for single threaded work
let ctx = builder.main(phase);
Expand Down

0 comments on commit 0fff063

Please sign in to comment.