Skip to content

Commit

Permalink
Add option to use Rayon's with_max_len() parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
gendx committed Dec 1, 2024
1 parent 473bcd6 commit 566656e
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 17 deletions.
10 changes: 10 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ struct MeekParams {
#[arg(long)]
disable_work_stealing: bool,

/// Maximal length of a serial run of items. Ignored if `--parallel` isn't
/// set to "rayon".
#[arg(long)]
max_serial_len: Option<NonZeroUsize>,

/// Enable a bug-fix in the surplus calculation, preventing it from being
/// negative. Results may differ from Droop.py, but this prevents
/// crashes.
Expand Down Expand Up @@ -202,6 +207,7 @@ impl Cli {
meek_params.parallel,
meek_params.num_threads,
meek_params.disable_work_stealing,
meek_params.max_serial_len,
meek_params.force_positive_surplus,
meek_params.equalize,
)?;
Expand Down Expand Up @@ -273,6 +279,7 @@ mod test {
disable_work_stealing: false,
force_positive_surplus: false,
equalize: false,
max_serial_len: None,
})
}
);
Expand Down Expand Up @@ -323,6 +330,7 @@ mod test {
disable_work_stealing: true,
force_positive_surplus: true,
equalize: true,
max_serial_len: None,
}),
}
);
Expand Down Expand Up @@ -360,6 +368,7 @@ mod test {
disable_work_stealing: true,
force_positive_surplus: true,
equalize: true,
max_serial_len: None,
}),
}
);
Expand Down Expand Up @@ -437,6 +446,7 @@ mod test {
disable_work_stealing: false,
force_positive_surplus: false,
equalize,
max_serial_len: None,
}),
}
}
Expand Down
24 changes: 23 additions & 1 deletion src/meek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub fn stv_droop<I, R>(
parallel: Parallel,
num_threads: Option<NonZeroUsize>,
disable_work_stealing: bool,
max_serial_len: Option<NonZeroUsize>,
force_positive_surplus: bool,
equalize: bool,
) -> io::Result<ElectionResult>
Expand Down Expand Up @@ -76,6 +77,7 @@ where
force_positive_surplus,
pascal,
None,
None,
);
state.run(stdout, package_name, omega_exponent)
}
Expand All @@ -95,6 +97,7 @@ where
force_positive_surplus,
pascal,
None,
max_serial_len.map(|x| x.into()),
);
state.run(stdout, package_name, omega_exponent)
}
Expand Down Expand Up @@ -130,6 +133,7 @@ where
force_positive_surplus,
pascal,
Some(thread_pool),
None,
);
state.run(stdout, package_name, omega_exponent)
}),
Expand Down Expand Up @@ -201,6 +205,7 @@ pub struct State<'scope, 'e, I, R> {
/// enabled.
pascal: Option<&'e [Vec<I>]>,
thread_pool: Option<ThreadPool<'scope, I, R>>,
max_serial_len: Option<usize>,
_phantom: PhantomData<I>,
}

Expand All @@ -225,6 +230,7 @@ where
force_positive_surplus: bool,
pascal: Option<&'a [Vec<I>]>,
thread_pool: Option<ThreadPool<'scope, I, R>>,
max_serial_len: Option<usize>,
) -> State<'scope, 'a, I, R> {
State {
election,
Expand All @@ -250,6 +256,7 @@ where
force_positive_surplus,
pascal,
thread_pool,
max_serial_len,
_phantom: PhantomData,
}
}
Expand Down Expand Up @@ -337,6 +344,7 @@ where
self.parallel,
self.thread_pool.as_ref(),
self.pascal,
self.max_serial_len,
)
}

Expand Down Expand Up @@ -936,6 +944,7 @@ mod test {
force_positive_surplus: self.force_positive_surplus.unwrap(),
pascal: self.pascal.unwrap(),
thread_pool: None,
max_serial_len: None,
_phantom: PhantomData,
}
}
Expand Down Expand Up @@ -1061,6 +1070,7 @@ mod test {
parallel,
None,
false,
None,
false,
false,
)
Expand Down Expand Up @@ -1288,6 +1298,7 @@ Action: Count Complete
parallel,
None,
false,
None,
false,
true,
)
Expand Down Expand Up @@ -1454,6 +1465,7 @@ Action: Count Complete
Parallel::No,
None,
false,
None,
false,
false,
)
Expand Down Expand Up @@ -1603,6 +1615,7 @@ Action: Count Complete
Parallel::Custom,
/* num_threads = */ Some(NonZeroUsize::new(2).unwrap()),
false,
None,
false,
false,
);
Expand All @@ -1621,6 +1634,7 @@ Action: Count Complete
Parallel::Custom,
/* num_threads = */ Some(NonZeroUsize::new(2).unwrap()),
false,
None,
false,
false,
)
Expand Down Expand Up @@ -1739,7 +1753,15 @@ Action: Count Complete
])
.build();
let omega_exponent = 6;
let state = State::new(&election, omega_exponent, Parallel::No, false, None, None);
let state = State::new(
&election,
omega_exponent,
Parallel::No,
false,
None,
None,
None,
);

let mut buf = Vec::new();
let count = state
Expand Down
71 changes: 55 additions & 16 deletions src/vote_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,13 @@ where
parallel: Parallel,
thread_pool: Option<&ThreadPool<'_, I, R>>,
pascal: Option<&[Vec<I>]>,
max_serial_len: Option<usize>,
) -> Self {
let vote_accumulator = match parallel {
Parallel::No => Self::accumulate_votes_serial(election, keep_factors, pascal),
Parallel::Rayon => Self::accumulate_votes_rayon(election, keep_factors, pascal),
Parallel::Rayon => {
Self::accumulate_votes_rayon(election, keep_factors, pascal, max_serial_len)
}
Parallel::Custom => thread_pool.unwrap().accumulate_votes(keep_factors),
};

Expand Down Expand Up @@ -141,22 +144,53 @@ where
election: &Election,
keep_factors: &[R],
pascal: Option<&[Vec<I>]>,
max_serial_len: Option<usize>,
) -> VoteAccumulator<I, R> {
election
.ballots
.par_iter()
.enumerate()
.fold_with(
VoteAccumulator::new(election.num_candidates),
|mut vote_accumulator, (i, ballot)| {
Self::process_ballot(&mut vote_accumulator, keep_factors, pascal, i, ballot);
vote_accumulator
},
)
.reduce(
|| VoteAccumulator::new(election.num_candidates),
|a, b| a.reduce(b),
)
match max_serial_len {
None => election
.ballots
.par_iter()
.enumerate()
.fold_with(
VoteAccumulator::new(election.num_candidates),
|mut vote_accumulator, (i, ballot)| {
Self::process_ballot(
&mut vote_accumulator,
keep_factors,
pascal,
i,
ballot,
);
vote_accumulator
},
)
.reduce(
|| VoteAccumulator::new(election.num_candidates),
|a, b| a.reduce(b),
),
Some(max_len) => election
.ballots
.par_iter()
.with_max_len(max_len)
.enumerate()
.fold_with(
VoteAccumulator::new(election.num_candidates),
|mut vote_accumulator, (i, ballot)| {
Self::process_ballot(
&mut vote_accumulator,
keep_factors,
pascal,
i,
ballot,
);
vote_accumulator
},
)
.reduce(
|| VoteAccumulator::new(election.num_candidates),
|a, b| a.reduce(b),
),
}
}
}

Expand Down Expand Up @@ -1071,13 +1105,15 @@ mod test {
Parallel::No,
None,
None,
None,
);
let vote_count_parallel = VoteCount::<I, R>::count_votes(
&election,
&keep_factors,
Parallel::Rayon,
None,
None,
None,
);
assert_eq!(
vote_count, vote_count_parallel,
Expand All @@ -1103,6 +1139,7 @@ mod test {
Parallel::No,
None,
None,
None,
);

for num_threads in 1..=10 {
Expand All @@ -1120,6 +1157,7 @@ mod test {
Parallel::Custom,
Some(&thread_pool),
None,
None,
);
assert_eq!(
vote_count, vote_count_parallel,
Expand Down Expand Up @@ -1666,6 +1704,7 @@ mod test {
parallel,
None,
None,
None,
)
});
}
Expand Down
48 changes: 48 additions & 0 deletions tools/benchmark-max-len.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/bin/bash

set -eux

if [ -e "${HOME}/.cargo/bin/hyperfine" ]; then
HYPERFINE_PATH=${HOME}/.cargo/bin/hyperfine
elif [ -e "/usr/bin/hyperfine" ]; then
HYPERFINE_PATH=/usr/bin/hyperfine
else
echo "Hyperfine is not installed. Please install it with 'apt install hyperfine' or 'cargo install hyperfine'."
exit 42
fi

./tools/check-inputs.sh

cargo build --release
sleep 15

BINARY=./target/release/stv-rs
if [ ! -e "${BINARY}" ]; then
echo "Binary not found at ${BINARY}."
exit 42
fi

SLEEP_SECONDS=15

function benchmark() {
local ARITHMETIC=$1
local EQUALIZE=$2
local INPUT=$3
local NUM_THREADS=4

"${HYPERFINE_PATH}" \
--style color \
--sort command \
--setup "sleep ${SLEEP_SECONDS}" \
--warmup 1 \
--parameter-list PARALLEL '--parallel=rayon --max-serial-len=1','--parallel=rayon --max-serial-len=2','--parallel=rayon --max-serial-len=4','--parallel=rayon --max-serial-len=8','--parallel=rayon --max-serial-len=16','--parallel=rayon --max-serial-len=32','--parallel=rayon --max-serial-len=64','--parallel=rayon --max-serial-len=128','--parallel=rayon','--parallel=custom' \
"${BINARY} --arithmetic ${ARITHMETIC} --input ${INPUT} meek ${EQUALIZE} --num-threads=${NUM_THREADS} {PARALLEL} > /dev/null"
}

benchmark bigfixed9 "" "testdata/ballots/random/rand_2x10.blt"
benchmark fixed9 "" "testdata/ballots/random/rand_mixed_5k.blt"
benchmark fixed9 --equalize "testdata/ballots/random/rand_mixed_5k.blt"
benchmark bigfixed9 "" "testdata/ballots/random/rand_mixed_5k.blt"
benchmark bigfixed9 --equalize "testdata/ballots/random/rand_mixed_5k.blt"
benchmark fixed9 "" "testdata/ballots/random/rand_hypergeometric_10k.blt"
benchmark fixed9 --equalize "testdata/ballots/random/rand_hypergeometric_10k.blt"

0 comments on commit 566656e

Please sign in to comment.