Skip to content

Commit

Permalink
[chore] Parallelize Poseidon trace generation (#1045)
Browse files Browse the repository at this point in the history
* Parallelize Poseidon trace generation

* chore: par_extend

---------

Co-authored-by: Jonathan Wang <[email protected]>
  • Loading branch information
nyunyunyunyu and jonathanpwang authored Dec 15, 2024
1 parent a26c3de commit 6c0c0e2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ derivative.workspace = true
static_assertions.workspace = true
async-trait.workspace = true
getset.workspace = true
rayon = { workspace = true, optional = true }

[dev-dependencies]
p3-dft = { workspace = true }
Expand Down Expand Up @@ -70,7 +71,7 @@ hex.workspace = true

[features]
default = ["parallel", "mimalloc"]
parallel = ["openvm-stark-backend/parallel"]
parallel = ["openvm-stark-backend/parallel", "dep:rayon"]
test-utils = ["openvm-ecc-guest/halo2curves", "dep:openvm-stark-sdk"]
bench-metrics = [
"dep:metrics",
Expand Down
20 changes: 16 additions & 4 deletions crates/vm/src/system/poseidon2/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ use openvm_stark_backend::{
p3_air::BaseAir,
p3_field::PrimeField32,
p3_matrix::dense::RowMajorMatrix,
p3_maybe_rayon::prelude::*,
prover::types::AirProofInput,
rap::{get_air_name, AnyRap},
Chip, ChipUsageGetter,
};
#[cfg(feature = "parallel")]
use rayon::iter::ParallelExtend;

use super::{columns::*, Poseidon2Chip};

Expand All @@ -35,12 +38,21 @@ where

let aux_cols_factory = memory_controller.borrow().aux_cols_factory();
let mut flat_rows: Vec<_> = records
.into_iter()
.into_par_iter()
.flat_map(|record| Self::record_to_cols(&aux_cols_factory, record).flatten())
.collect();
for _ in 0..diff {
flat_rows.extend(Poseidon2VmCols::<Val<SC>>::blank_row(&air).flatten());
}
#[cfg(feature = "parallel")]
flat_rows.par_extend(
vec![Poseidon2VmCols::<Val<SC>>::blank_row(&air).flatten(); diff]
.into_par_iter()
.flatten(),
);
#[cfg(not(feature = "parallel"))]
flat_rows.extend(
vec![Poseidon2VmCols::<Val<SC>>::blank_row(&air).flatten(); diff]
.into_iter()
.flatten(),
);

AirProofInput::simple_no_pis(
Arc::new(air.clone()),
Expand Down

0 comments on commit 6c0c0e2

Please sign in to comment.