From 566656e7e858b7c9ea3b0425c9814137268e26cb Mon Sep 17 00:00:00 2001 From: Guillaume Endignoux Date: Sun, 1 Dec 2024 15:41:35 +0100 Subject: [PATCH] Add option to use Rayon's with_max_len() parameter. --- src/main.rs | 10 ++++++ src/meek.rs | 24 ++++++++++++- src/vote_count.rs | 71 +++++++++++++++++++++++++++++--------- tools/benchmark-max-len.sh | 48 ++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 17 deletions(-) create mode 100755 tools/benchmark-max-len.sh diff --git a/src/main.rs b/src/main.rs index fa71a69..28afe7c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -89,6 +89,11 @@ struct MeekParams { #[arg(long)] disable_work_stealing: bool, + /// Maximal length of a serial run of items. Ignored if `--parallel` isn't + /// set to "rayon". + #[arg(long)] + max_serial_len: Option, + /// Enable a bug-fix in the surplus calculation, preventing it from being /// negative. Results may differ from Droop.py, but this prevents /// crashes. @@ -202,6 +207,7 @@ impl Cli { meek_params.parallel, meek_params.num_threads, meek_params.disable_work_stealing, + meek_params.max_serial_len, meek_params.force_positive_surplus, meek_params.equalize, )?; @@ -273,6 +279,7 @@ mod test { disable_work_stealing: false, force_positive_surplus: false, equalize: false, + max_serial_len: None, }) } ); @@ -323,6 +330,7 @@ mod test { disable_work_stealing: true, force_positive_surplus: true, equalize: true, + max_serial_len: None, }), } ); @@ -360,6 +368,7 @@ mod test { disable_work_stealing: true, force_positive_surplus: true, equalize: true, + max_serial_len: None, }), } ); @@ -437,6 +446,7 @@ mod test { disable_work_stealing: false, force_positive_surplus: false, equalize, + max_serial_len: None, }), } } diff --git a/src/meek.rs b/src/meek.rs index 7537488..2fcbf00 100644 --- a/src/meek.rs +++ b/src/meek.rs @@ -38,6 +38,7 @@ pub fn stv_droop( parallel: Parallel, num_threads: Option, disable_work_stealing: bool, + max_serial_len: Option, force_positive_surplus: bool, equalize: bool, ) -> io::Result @@ -76,6 +77,7 @@ where force_positive_surplus, pascal, None, + None, ); state.run(stdout, package_name, omega_exponent) } @@ -95,6 +97,7 @@ where force_positive_surplus, pascal, None, + max_serial_len.map(|x| x.into()), ); state.run(stdout, package_name, omega_exponent) } @@ -130,6 +133,7 @@ where force_positive_surplus, pascal, Some(thread_pool), + None, ); state.run(stdout, package_name, omega_exponent) }), @@ -201,6 +205,7 @@ pub struct State<'scope, 'e, I, R> { /// enabled. pascal: Option<&'e [Vec]>, thread_pool: Option>, + max_serial_len: Option, _phantom: PhantomData, } @@ -225,6 +230,7 @@ where force_positive_surplus: bool, pascal: Option<&'a [Vec]>, thread_pool: Option>, + max_serial_len: Option, ) -> State<'scope, 'a, I, R> { State { election, @@ -250,6 +256,7 @@ where force_positive_surplus, pascal, thread_pool, + max_serial_len, _phantom: PhantomData, } } @@ -337,6 +344,7 @@ where self.parallel, self.thread_pool.as_ref(), self.pascal, + self.max_serial_len, ) } @@ -936,6 +944,7 @@ mod test { force_positive_surplus: self.force_positive_surplus.unwrap(), pascal: self.pascal.unwrap(), thread_pool: None, + max_serial_len: None, _phantom: PhantomData, } } @@ -1061,6 +1070,7 @@ mod test { parallel, None, false, + None, false, false, ) @@ -1288,6 +1298,7 @@ Action: Count Complete parallel, None, false, + None, false, true, ) @@ -1454,6 +1465,7 @@ Action: Count Complete Parallel::No, None, false, + None, false, false, ) @@ -1603,6 +1615,7 @@ Action: Count Complete Parallel::Custom, /* num_threads = */ Some(NonZeroUsize::new(2).unwrap()), false, + None, false, false, ); @@ -1621,6 +1634,7 @@ Action: Count Complete Parallel::Custom, /* num_threads = */ Some(NonZeroUsize::new(2).unwrap()), false, + None, false, false, ) @@ -1739,7 +1753,15 @@ Action: Count Complete ]) .build(); let omega_exponent = 6; - let state = State::new(&election, omega_exponent, Parallel::No, false, None, None); + let state = State::new( + &election, + omega_exponent, + Parallel::No, + false, + None, + None, + None, + ); let mut buf = Vec::new(); let count = state diff --git a/src/vote_count.rs b/src/vote_count.rs index 82ddd09..09d1240 100644 --- a/src/vote_count.rs +++ b/src/vote_count.rs @@ -110,10 +110,13 @@ where parallel: Parallel, thread_pool: Option<&ThreadPool<'_, I, R>>, pascal: Option<&[Vec]>, + max_serial_len: Option, ) -> Self { let vote_accumulator = match parallel { Parallel::No => Self::accumulate_votes_serial(election, keep_factors, pascal), - Parallel::Rayon => Self::accumulate_votes_rayon(election, keep_factors, pascal), + Parallel::Rayon => { + Self::accumulate_votes_rayon(election, keep_factors, pascal, max_serial_len) + } Parallel::Custom => thread_pool.unwrap().accumulate_votes(keep_factors), }; @@ -141,22 +144,53 @@ where election: &Election, keep_factors: &[R], pascal: Option<&[Vec]>, + max_serial_len: Option, ) -> VoteAccumulator { - election - .ballots - .par_iter() - .enumerate() - .fold_with( - VoteAccumulator::new(election.num_candidates), - |mut vote_accumulator, (i, ballot)| { - Self::process_ballot(&mut vote_accumulator, keep_factors, pascal, i, ballot); - vote_accumulator - }, - ) - .reduce( - || VoteAccumulator::new(election.num_candidates), - |a, b| a.reduce(b), - ) + match max_serial_len { + None => election + .ballots + .par_iter() + .enumerate() + .fold_with( + VoteAccumulator::new(election.num_candidates), + |mut vote_accumulator, (i, ballot)| { + Self::process_ballot( + &mut vote_accumulator, + keep_factors, + pascal, + i, + ballot, + ); + vote_accumulator + }, + ) + .reduce( + || VoteAccumulator::new(election.num_candidates), + |a, b| a.reduce(b), + ), + Some(max_len) => election + .ballots + .par_iter() + .with_max_len(max_len) + .enumerate() + .fold_with( + VoteAccumulator::new(election.num_candidates), + |mut vote_accumulator, (i, ballot)| { + Self::process_ballot( + &mut vote_accumulator, + keep_factors, + pascal, + i, + ballot, + ); + vote_accumulator + }, + ) + .reduce( + || VoteAccumulator::new(election.num_candidates), + |a, b| a.reduce(b), + ), + } } } @@ -1071,6 +1105,7 @@ mod test { Parallel::No, None, None, + None, ); let vote_count_parallel = VoteCount::::count_votes( &election, @@ -1078,6 +1113,7 @@ mod test { Parallel::Rayon, None, None, + None, ); assert_eq!( vote_count, vote_count_parallel, @@ -1103,6 +1139,7 @@ mod test { Parallel::No, None, None, + None, ); for num_threads in 1..=10 { @@ -1120,6 +1157,7 @@ mod test { Parallel::Custom, Some(&thread_pool), None, + None, ); assert_eq!( vote_count, vote_count_parallel, @@ -1666,6 +1704,7 @@ mod test { parallel, None, None, + None, ) }); } diff --git a/tools/benchmark-max-len.sh b/tools/benchmark-max-len.sh new file mode 100755 index 0000000..9c95b0e --- /dev/null +++ b/tools/benchmark-max-len.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +set -eux + +if [ -e "${HOME}/.cargo/bin/hyperfine" ]; then + HYPERFINE_PATH=${HOME}/.cargo/bin/hyperfine +elif [ -e "/usr/bin/hyperfine" ]; then + HYPERFINE_PATH=/usr/bin/hyperfine +else + echo "Hyperfine is not installed. Please install it with 'apt install hyperfine' or 'cargo install hyperfine'." + exit 42 +fi + +./tools/check-inputs.sh + +cargo build --release +sleep 15 + +BINARY=./target/release/stv-rs +if [ ! -e "${BINARY}" ]; then + echo "Binary not found at ${BINARY}." + exit 42 +fi + +SLEEP_SECONDS=15 + +function benchmark() { + local ARITHMETIC=$1 + local EQUALIZE=$2 + local INPUT=$3 + local NUM_THREADS=4 + + "${HYPERFINE_PATH}" \ + --style color \ + --sort command \ + --setup "sleep ${SLEEP_SECONDS}" \ + --warmup 1 \ + --parameter-list PARALLEL '--parallel=rayon --max-serial-len=1','--parallel=rayon --max-serial-len=2','--parallel=rayon --max-serial-len=4','--parallel=rayon --max-serial-len=8','--parallel=rayon --max-serial-len=16','--parallel=rayon --max-serial-len=32','--parallel=rayon --max-serial-len=64','--parallel=rayon --max-serial-len=128','--parallel=rayon','--parallel=custom' \ + "${BINARY} --arithmetic ${ARITHMETIC} --input ${INPUT} meek ${EQUALIZE} --num-threads=${NUM_THREADS} {PARALLEL} > /dev/null" +} + +benchmark bigfixed9 "" "testdata/ballots/random/rand_2x10.blt" +benchmark fixed9 "" "testdata/ballots/random/rand_mixed_5k.blt" +benchmark fixed9 --equalize "testdata/ballots/random/rand_mixed_5k.blt" +benchmark bigfixed9 "" "testdata/ballots/random/rand_mixed_5k.blt" +benchmark bigfixed9 --equalize "testdata/ballots/random/rand_mixed_5k.blt" +benchmark fixed9 "" "testdata/ballots/random/rand_hypergeometric_10k.blt" +benchmark fixed9 --equalize "testdata/ballots/random/rand_hypergeometric_10k.blt"