Skip to content

Commit

Permalink
Merge pull request #63 from Yoii-Inc/fix/role_assignment_r1cs
Browse files Browse the repository at this point in the history
Fix/role assignment r1cs
  • Loading branch information
taskooh authored Oct 7, 2024
2 parents 73e6019 + 9cf11eb commit 603208f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 17 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/rust_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,7 @@ jobs:
run: ./run_werewolf.zsh night

- name: Run werewolf vote
run: ./run_werewolf.zsh vote
run: ./run_werewolf.zsh vote

- name: Run werewolf role assignment
run: ./run_werewolf.zsh role_assignment
10 changes: 10 additions & 0 deletions examples/bin_werewolf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ fn role_assignment(opt: &Opt) -> Result<(), std::io::Error> {
// prove
let local_role_circuit = RoleAssignmentCircuit {
num_players: n,
max_group_size: grouping_parameter.get_max_group_size(),
pedersen_param: pedersen_param.clone(),
tau_matrix: na::DMatrix::<Fr>::zeros(n + m, n + m),
shuffle_matrices: vec![na::DMatrix::<Fr>::zeros(n + m, n + m); 2],
Expand Down Expand Up @@ -420,6 +421,7 @@ fn role_assignment(opt: &Opt) -> Result<(), std::io::Error> {

let mpc_role_circuit = RoleAssignmentCircuit {
num_players: n,
max_group_size: grouping_parameter.get_max_group_size(),
pedersen_param: mpc_pedersen_param,
tau_matrix: grouping_parameter.generate_tau_matrix(),
shuffle_matrices: shuffle_matrix,
Expand Down Expand Up @@ -622,6 +624,14 @@ impl GroupingParameter {
self.0.values().map(|x| x.0).sum()
}

fn get_max_group_size(&self) -> usize {
self.0
.values()
.map(|(count, is_not_alone)| if *is_not_alone { *count } else { 1 })
.max()
.expect("Error: No max value found")
}

fn get_corresponding_role(&self, role_id: usize) -> Roles {
let mut count = self.get_num_players();
for (role, (role_count, is_not_alone)) in self.0.iter() {
Expand Down
30 changes: 14 additions & 16 deletions src/circuits/werewolf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ impl ConstraintSynthesizer<mm::MpcField<Fr>> for WinningJudgeCircuit<mm::MpcFiel
pub struct RoleAssignmentCircuit<F: PrimeField + LocalOrMPC<F>> {
// parameter
pub num_players: usize,
pub max_group_size: usize,
pub pedersen_param: F::PedersenParam,

// instance
Expand Down Expand Up @@ -961,7 +962,7 @@ impl ConstraintSynthesizer<Fr> for RoleAssignmentCircuit<Fr> {

let calced_role = calced_vec
.iter()
.map(|val| test_max(val, false).unwrap())
.map(|val| test_max(val, self.max_group_size + 1, true).unwrap())
.collect::<Vec<_>>();

// commitment
Expand Down Expand Up @@ -1077,7 +1078,7 @@ impl ConstraintSynthesizer<mm::MpcField<Fr>> for RoleAssignmentCircuit<mm::MpcFi

let calced_role = calced_vec
.iter()
.map(|val| test_max_mpc(val, false).unwrap())
.map(|val| test_max_mpc(val, self.max_group_size + 1, true).unwrap())
.collect::<Vec<_>>();

// commitment
Expand Down Expand Up @@ -1328,8 +1329,10 @@ impl ElGamalLocalOrMPC<mm::MpcField<Fr>> for mm::MpcField<Fr> {
}
}

// return maximum value in the vector a, index runs from 0 to use_index_len
fn test_max<F: PrimeField>(
a: &[FpVar<F>],
use_index_len: usize,
should_enforce: bool,
) -> Result<FpVar<F>, SynthesisError> {
let cs = a[0].cs().clone();
Expand All @@ -1340,18 +1343,17 @@ fn test_max<F: PrimeField>(

if should_enforce {
// each element must be less than half of the modulus
a.iter().for_each(|x| {
max_var
.enforce_cmp(x, core::cmp::Ordering::Greater, true)
.unwrap()
});
for i in 0..use_index_len {
a[i].enforce_cmp(&max_var, core::cmp::Ordering::Less, true)?;
}
}

Ok(max_var)
}

fn test_max_mpc<F: PrimeField + SquareRootField + BitDecomposition + EqualityZero>(
a: &[MpcFpVar<F>],
use_index_len: usize,
should_enforce: bool,
) -> Result<MpcFpVar<F>, SynthesisError> {
let cs = a[0].cs().clone();
Expand All @@ -1361,12 +1363,9 @@ fn test_max_mpc<F: PrimeField + SquareRootField + BitDecomposition + EqualityZer
})?;

if should_enforce {
// [ ]: implement correctly
a.iter().for_each(|x| {
max_var
.enforce_cmp(x, core::cmp::Ordering::Greater, true)
.unwrap()
});
for i in 0..use_index_len {
a[i].enforce_cmp(&max_var, core::cmp::Ordering::Less, true)?;
}
}

Ok(max_var)
Expand Down Expand Up @@ -1533,10 +1532,9 @@ where
// all check 0 or 1 -> row sum and column sum is 1
let val = &matrix[(i, j)];

(<MpcFpVar<F> as Zero>::zero() - val)
.is_zero()
val.is_zero()
.unwrap()
.or(&(<MpcFpVar<F> as One>::one() - val).is_zero().unwrap())
.or(&(val - <MpcFpVar<F> as One>::one()).is_zero().unwrap())
.unwrap()
.enforce_equal(&MpcBoolean::TRUE)?;

Expand Down

0 comments on commit 603208f

Please sign in to comment.