Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

feat(acvm): Directive for sorting networks #77

Merged
merged 4 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions acir/src/circuit/directives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ pub enum Directive {
radix: u32,
},

// Sort directive, using a sorting network
// This directive is used to generate the values of the control bits for the sorting network such that its outputs are properly sorted according to sort_by
PermutationSort {
inputs: Vec<Vec<Expression>>, // Array of tuples to sort
tuple: u32, // tuple size; if 1 then inputs is a single array [a0,a1,..], if 2 then inputs=[(a0,b0),..] is [a0,b0,a1,b1,..], etc..
bits: Vec<Witness>, // control bits of the network which permutes the inputs into its sorted version
sort_by: Vec<u32>, // specify primary index to sort by, then the secondary,... For instance, if tuple is 2 and sort_by is [1,0], then a=[(a0,b0),..] is sorted by bi and then ai.
},
Log(LogInfo),
}

Expand All @@ -60,6 +68,7 @@ impl Directive {
Directive::Truncate { .. } => "truncate",
Directive::OddRange { .. } => "odd_range",
Directive::ToRadix { .. } => "to_radix",
Directive::PermutationSort { .. } => "permutation_sort",
Directive::Log { .. } => "log",
}
}
Expand All @@ -71,6 +80,7 @@ impl Directive {
Directive::OddRange { .. } => 3,
Directive::ToRadix { .. } => 4,
Directive::Log { .. } => 5,
Directive::PermutationSort { .. } => 6,
}
}

Expand Down Expand Up @@ -120,6 +130,28 @@ impl Directive {
}
write_u32(&mut writer, *radix)?;
}
Directive::PermutationSort {
inputs: a,
tuple,
bits,
sort_by,
} => {
write_u32(&mut writer, *tuple)?;
write_u32(&mut writer, a.len() as u32)?;
for e in a {
for i in 0..*tuple {
e[i as usize].write(&mut writer)?;
}
}
write_u32(&mut writer, bits.len() as u32)?;
for b in bits {
write_u32(&mut writer, b.witness_index())?;
}
write_u32(&mut writer, sort_by.len() as u32)?;
for i in sort_by {
write_u32(&mut writer, *i)?;
}
}
Directive::Log(info) => match info {
LogInfo::FinalizedOutput(output_string) => {
write_bytes(&mut writer, output_string.as_bytes())?;
Expand Down Expand Up @@ -193,6 +225,35 @@ impl Directive {

Ok(Directive::ToRadix { a, b, radix })
}
6 => {
let tuple = read_u32(&mut reader)?;
let a_len = read_u32(&mut reader)?;
let mut a = Vec::with_capacity(a_len as usize);
for _ in 0..a_len {
let mut element = Vec::new();
for _ in 0..tuple {
element.push(Expression::read(&mut reader)?);
}
a.push(element);
}

let bits_len = read_u32(&mut reader)?;
let mut bits = Vec::with_capacity(bits_len as usize);
for _ in 0..bits_len {
bits.push(Witness(read_u32(&mut reader)?));
}
let sort_by_len = read_u32(&mut reader)?;
let mut sort_by = Vec::with_capacity(sort_by_len as usize);
for _ in 0..sort_by_len {
sort_by.push(read_u32(&mut reader)?);
}
Ok(Directive::PermutationSort {
inputs: a,
tuple,
bits,
sort_by,
})
}

_ => Err(std::io::ErrorKind::InvalidData.into()),
}
Expand Down
18 changes: 18 additions & 0 deletions acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ impl std::fmt::Display for Opcode {
b.last().unwrap().witness_index(),
)
}
Opcode::Directive(Directive::PermutationSort {
inputs: a,
tuple,
bits,
sort_by,
}) => {
write!(f, "DIR::PERMUTATIONSORT ")?;
write!(
f,
"(permutation size: {} {}-tuples, sort_by: {:#?}, bits: [_{}..._{}]))",
a.len(),
tuple,
sort_by,
// (Note): the bits do not have contiguous index but there are too many for display
bits.first().unwrap().witness_index(),
bits.last().unwrap().witness_index(),
)
}
Opcode::Directive(Directive::Log(info)) => match info {
LogInfo::FinalizedOutput(output_string) => write!(f, "Log: {output_string}"),
LogInfo::WitnessOutput(witnesses) => write!(
Expand Down
1 change: 1 addition & 0 deletions acvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ bls12_381 = ["acir_field/bls12_381"]

[dev-dependencies]
tempfile = "3.2.0"
rand="0.8.5"
1 change: 1 addition & 0 deletions acvm/src/pwg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod hash;
pub mod logic;
pub mod range;
pub mod signature;
pub mod sorting;

// Returns the concrete value for a particular witness
// If the witness has no assignment, then
Expand Down
94 changes: 69 additions & 25 deletions acvm/src/pwg/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl ArithmeticSolver {
initial_witness: &mut BTreeMap<Witness, FieldElement>,
gate: &Expression,
) -> Result<(), OpcodeResolutionError> {
let gate = &ArithmeticSolver::evaluate(gate, initial_witness);
// Evaluate multiplication term
let mul_result = ArithmeticSolver::solve_mul_term(gate, initial_witness);
// Evaluate the fan-in terms
Expand Down Expand Up @@ -127,26 +128,41 @@ impl ArithmeticSolver {
// We are assuming it has been optimized.
match arith_gate.mul_terms.len() {
0 => MulTerm::Solved(FieldElement::zero()),
1 => {
let q_m = &arith_gate.mul_terms[0].0;
let w_l = &arith_gate.mul_terms[0].1;
let w_r = &arith_gate.mul_terms[0].2;

// Check if these values are in the witness assignments
let w_l_value = witness_assignments.get(w_l);
let w_r_value = witness_assignments.get(w_r);

match (w_l_value, w_r_value) {
(None, None) => MulTerm::TooManyUnknowns,
(Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r),
(None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l),
(Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r),
}
}
1 => ArithmeticSolver::solve_mul_term_helper(
&arith_gate.mul_terms[0],
witness_assignments,
),
_ => panic!("Mul term in the arithmetic gate must contain either zero or one term"),
guipublic marked this conversation as resolved.
Show resolved Hide resolved
}
}

fn solve_mul_term_helper(
term: &(FieldElement, Witness, Witness),
witness_assignments: &BTreeMap<Witness, FieldElement>,
) -> MulTerm {
let (q_m, w_l, w_r) = term;
// Check if these values are in the witness assignments
let w_l_value = witness_assignments.get(w_l);
let w_r_value = witness_assignments.get(w_r);

match (w_l_value, w_r_value) {
(None, None) => MulTerm::TooManyUnknowns,
(Some(w_l), Some(w_r)) => MulTerm::Solved(*q_m * *w_l * *w_r),
(None, Some(w_r)) => MulTerm::OneUnknown(*q_m * *w_r, *w_l),
(Some(w_l), None) => MulTerm::OneUnknown(*q_m * *w_l, *w_r),
}
}

fn solve_fan_in_term_helper(
term: &(FieldElement, Witness),
witness_assignments: &BTreeMap<Witness, FieldElement>,
) -> Option<FieldElement> {
let (q_l, w_l) = term;
// Check if we have w_l
let w_l_value = witness_assignments.get(w_l);
w_l_value.map(|a| *q_l * *a)
}

/// Returns the summation of all of the variables, plus the unknown variable
/// Returns None, if there is more than one unknown variable
/// We cannot assign
Expand All @@ -163,19 +179,14 @@ impl ArithmeticSolver {
let mut result = FieldElement::zero();

for term in arith_gate.linear_combinations.iter() {
let q_l = term.0;
let w_l = &term.1;

// Check if we have w_l
let w_l_value = witness_assignments.get(w_l);

match w_l_value {
Some(a) => result += q_l * *a,
let value = ArithmeticSolver::solve_fan_in_term_helper(term, witness_assignments);
match value {
Some(a) => result += a,
None => {
unknown_variable = *term;
num_unknowns += 1;
}
};
}

// If we have more than 1 unknown, then we cannot solve this equation
if num_unknowns > 1 {
Expand All @@ -189,6 +200,39 @@ impl ArithmeticSolver {

GateStatus::GateSolvable(result, unknown_variable)
}

// Partially evaluate the gate using the known witnesses
pub fn evaluate(
expr: &Expression,
initial_witness: &BTreeMap<Witness, FieldElement>,
) -> Expression {
let mut result = Expression::default();
for &(c, w1, w2) in &expr.mul_terms {
let mul_result = ArithmeticSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness);
match mul_result {
MulTerm::OneUnknown(v, w) => {
if !v.is_zero() {
result.linear_combinations.push((v, w));
}
}
MulTerm::TooManyUnknowns => {
if !c.is_zero() {
result.mul_terms.push((c, w1, w2));
}
}
MulTerm::Solved(f) => result.q_c += f,
}
}
for &(c, w) in &expr.linear_combinations {
if let Some(f) = ArithmeticSolver::solve_fan_in_term_helper(&(c, w), initial_witness) {
result.q_c += f;
} else if !c.is_zero() {
result.linear_combinations.push((c, w));
}
}
result.q_c += expr.q_c;
result
}
}

#[test]
Expand Down
Loading