From 9c7383d52fb9878c7b7283e4c2e3aceff3ec0341 Mon Sep 17 00:00:00 2001 From: Tom French Date: Sat, 8 Apr 2023 02:15:58 +0100 Subject: [PATCH] chore: organise operator implementations for `Expression` --- .../{arithmetic.rs => expression/mod.rs} | 238 ++---------------- acir/src/native_types/expression/operators.rs | 152 +++++++++++ acir/src/native_types/expression/ordering.rs | 99 ++++++++ acir/src/native_types/mod.rs | 4 +- acvm/src/compiler/transformers/fallback.rs | 10 +- acvm/src/pwg/block.rs | 8 +- stdlib/src/fallback.rs | 2 +- 7 files changed, 287 insertions(+), 226 deletions(-) rename acir/src/native_types/{arithmetic.rs => expression/mod.rs} (68%) create mode 100644 acir/src/native_types/expression/operators.rs create mode 100644 acir/src/native_types/expression/ordering.rs diff --git a/acir/src/native_types/arithmetic.rs b/acir/src/native_types/expression/mod.rs similarity index 68% rename from acir/src/native_types/arithmetic.rs rename to acir/src/native_types/expression/mod.rs index 1311cd7a4..40f08e560 100644 --- a/acir/src/native_types/arithmetic.rs +++ b/acir/src/native_types/expression/mod.rs @@ -2,9 +2,10 @@ use crate::native_types::Witness; use crate::serialization::{read_field_element, read_u32, write_bytes, write_u32}; use acir_field::FieldElement; use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; use std::io::{Read, Write}; -use std::ops::{Add, Mul, Neg, Sub}; + +mod operators; +mod ordering; // In the addition polynomial // We can have arbitrary fan-in/out, so we need more than wL,wR and wO @@ -47,36 +48,6 @@ impl std::fmt::Display for Expression { } } -// TODO: possibly remove, and move to noir repo. -impl Ord for Expression { - fn cmp(&self, other: &Self) -> Ordering { - let mut i1 = self.get_max_idx(); - let mut i2 = other.get_max_idx(); - let mut result = Ordering::Equal; - while result == Ordering::Equal { - let m1 = self.get_max_term(&mut i1); - let m2 = other.get_max_term(&mut i2); - if m1.is_none() && m2.is_none() { - return Ordering::Equal; - } - result = Expression::cmp_max(m1, m2); - } - result - } -} -// TODO: possibly remove, and move to noir repo. -impl PartialOrd for Expression { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -// TODO: possibly remove, and move to noir repo. -struct WitnessIdx { - linear: usize, - mul: usize, - second_term: bool, -} - impl Expression { // TODO: possibly remove, and move to noir repo. pub const fn can_defer_constraint(&self) -> bool { @@ -250,195 +221,13 @@ impl Expression { None } - fn get_max_idx(&self) -> WitnessIdx { - WitnessIdx { - linear: self.linear_combinations.len(), - mul: self.mul_terms.len(), - second_term: true, - } - } - // Returns the maximum witness at the provided position, and decrement the position - // This function assumes the gate is sorted - // TODO: possibly remove, and move to noir repo. - fn get_max_term(&self, idx: &mut WitnessIdx) -> Option { - if idx.linear > 0 { - if idx.mul > 0 { - let mul_term = if idx.second_term { - self.mul_terms[idx.mul - 1].2 - } else { - self.mul_terms[idx.mul - 1].1 - }; - if self.linear_combinations[idx.linear - 1].1 > mul_term { - idx.linear -= 1; - Some(self.linear_combinations[idx.linear].1) - } else { - if idx.second_term { - idx.second_term = false; - } else { - idx.mul -= 1; - } - Some(mul_term) - } - } else { - idx.linear -= 1; - Some(self.linear_combinations[idx.linear].1) - } - } else if idx.mul > 0 { - if idx.second_term { - idx.second_term = false; - Some(self.mul_terms[idx.mul - 1].2) - } else { - idx.mul -= 1; - Some(self.mul_terms[idx.mul].1) - } - } else { - None - } - } - - // TODO: possibly remove, and move to noir repo. - fn cmp_max(m1: Option, m2: Option) -> Ordering { - if let Some(m1) = m1 { - if let Some(m2) = m2 { - m1.cmp(&m2) - } else { - Ordering::Greater - } - } else if m2.is_some() { - Ordering::Less - } else { - Ordering::Equal - } - } - /// Sorts gate in a deterministic order /// XXX: We can probably make this more efficient by sorting on each phase. We only care if it is deterministic pub fn sort(&mut self) { self.mul_terms.sort_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2))); self.linear_combinations.sort_by(|a, b| a.1.cmp(&b.1)); } -} - -impl Mul<&FieldElement> for &Expression { - type Output = Expression; - fn mul(self, rhs: &FieldElement) -> Self::Output { - // Scale the mul terms - let mul_terms: Vec<_> = - self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * *rhs, *w_l, *w_r)).collect(); - - // Scale the linear combinations terms - let lin_combinations: Vec<_> = - self.linear_combinations.iter().map(|(q_l, w_l)| (*q_l * *rhs, *w_l)).collect(); - - // Scale the constant - let q_c = self.q_c * *rhs; - - Expression { mul_terms, q_c, linear_combinations: lin_combinations } - } -} -impl Add<&FieldElement> for Expression { - type Output = Expression; - fn add(self, rhs: &FieldElement) -> Self::Output { - // Increase the constant - let q_c = self.q_c + *rhs; - - Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations } - } -} -impl Sub<&FieldElement> for Expression { - type Output = Expression; - fn sub(self, rhs: &FieldElement) -> Self::Output { - // Increase the constant - let q_c = self.q_c - *rhs; - - Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations } - } -} - -impl Add<&Expression> for &Expression { - type Output = Expression; - fn add(self, rhs: &Expression) -> Expression { - // XXX(med) : Implement an efficient way to do this - - let mul_terms: Vec<_> = - self.mul_terms.iter().cloned().chain(rhs.mul_terms.iter().cloned()).collect(); - - let linear_combinations: Vec<_> = self - .linear_combinations - .iter() - .cloned() - .chain(rhs.linear_combinations.iter().cloned()) - .collect(); - let q_c = self.q_c + rhs.q_c; - - Expression { mul_terms, linear_combinations, q_c } - } -} - -impl Neg for &Expression { - type Output = Expression; - fn neg(self) -> Self::Output { - // XXX(med) : Implement an efficient way to do this - - let mul_terms: Vec<_> = - self.mul_terms.iter().map(|(q_m, w_l, w_r)| (-*q_m, *w_l, *w_r)).collect(); - - let linear_combinations: Vec<_> = - self.linear_combinations.iter().map(|(q_k, w_k)| (-*q_k, *w_k)).collect(); - let q_c = -self.q_c; - - Expression { mul_terms, linear_combinations, q_c } - } -} - -impl Sub<&Expression> for &Expression { - type Output = Expression; - fn sub(self, rhs: &Expression) -> Expression { - self + &-rhs - } -} - -impl From for Expression { - fn from(constant: FieldElement) -> Expression { - Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() } - } -} - -impl From<&FieldElement> for Expression { - fn from(constant: &FieldElement) -> Expression { - (*constant).into() - } -} - -impl From for Expression { - /// Creates an Expression from a Witness. - /// - /// This is infallible since an `Expression` is - /// a multi-variate polynomial and a `Witness` - /// can be seen as a univariate polynomial - fn from(wit: Witness) -> Expression { - Expression { - q_c: FieldElement::zero(), - linear_combinations: vec![(FieldElement::one(), wit)], - mul_terms: Vec::new(), - } - } -} - -impl From<&Witness> for Expression { - fn from(wit: &Witness) -> Expression { - (*wit).into() - } -} - -impl Sub<&Witness> for &Expression { - type Output = Expression; - fn sub(self, rhs: &Witness) -> Expression { - self - &Expression::from(rhs) - } -} -impl Expression { /// Checks if this polynomial can fit into one arithmetic identity pub fn fits_in_one_identity(&self, width: usize) -> bool { // A Polynomial with more than one mul term cannot fit into one gate @@ -495,6 +284,27 @@ impl Expression { } } +impl From for Expression { + fn from(constant: FieldElement) -> Expression { + Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() } + } +} + +impl From for Expression { + /// Creates an Expression from a Witness. + /// + /// This is infallible since an `Expression` is + /// a multi-variate polynomial and a `Witness` + /// can be seen as a univariate polynomial + fn from(wit: Witness) -> Expression { + Expression { + q_c: FieldElement::zero(), + linear_combinations: vec![(FieldElement::one(), wit)], + mul_terms: Vec::new(), + } + } +} + #[test] fn serialization_roundtrip() { // Empty expression diff --git a/acir/src/native_types/expression/operators.rs b/acir/src/native_types/expression/operators.rs new file mode 100644 index 000000000..a2e90766f --- /dev/null +++ b/acir/src/native_types/expression/operators.rs @@ -0,0 +1,152 @@ +use crate::native_types::Witness; +use acir_field::FieldElement; +use std::ops::{Add, Mul, Neg, Sub}; + +use super::Expression; + +// Negation + +impl Neg for &Expression { + type Output = Expression; + fn neg(self) -> Self::Output { + // XXX(med) : Implement an efficient way to do this + + let mul_terms: Vec<_> = + self.mul_terms.iter().map(|(q_m, w_l, w_r)| (-*q_m, *w_l, *w_r)).collect(); + + let linear_combinations: Vec<_> = + self.linear_combinations.iter().map(|(q_k, w_k)| (-*q_k, *w_k)).collect(); + let q_c = -self.q_c; + + Expression { mul_terms, linear_combinations, q_c } + } +} + +// FieldElement + +impl Add for Expression { + type Output = Expression; + fn add(self, rhs: FieldElement) -> Self::Output { + // Increase the constant + let q_c = self.q_c + rhs; + + Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations } + } +} + +impl Add for FieldElement { + type Output = Expression; + #[inline] + fn add(self, rhs: Expression) -> Self::Output { + rhs + self + } +} + +impl Sub for Expression { + type Output = Expression; + fn sub(self, rhs: FieldElement) -> Self::Output { + // Increase the constant + let q_c = self.q_c - rhs; + + Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations } + } +} + +impl Sub for FieldElement { + type Output = Expression; + #[inline] + fn sub(self, rhs: Expression) -> Self::Output { + rhs - self + } +} + +impl Mul for &Expression { + type Output = Expression; + fn mul(self, rhs: FieldElement) -> Self::Output { + // Scale the mul terms + let mul_terms: Vec<_> = + self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * rhs, *w_l, *w_r)).collect(); + + // Scale the linear combinations terms + let lin_combinations: Vec<_> = + self.linear_combinations.iter().map(|(q_l, w_l)| (*q_l * rhs, *w_l)).collect(); + + // Scale the constant + let q_c = self.q_c * rhs; + + Expression { mul_terms, q_c, linear_combinations: lin_combinations } + } +} + +impl Mul<&Expression> for FieldElement { + type Output = Expression; + #[inline] + fn mul(self, rhs: &Expression) -> Self::Output { + rhs * self + } +} + +// Witness + +impl Add for &Expression { + type Output = Expression; + fn add(self, rhs: Witness) -> Expression { + self + &Expression::from(rhs) + } +} + +impl Add<&Expression> for Witness { + type Output = Expression; + #[inline] + fn add(self, rhs: &Expression) -> Expression { + rhs + self + } +} + +impl Sub for &Expression { + type Output = Expression; + fn sub(self, rhs: Witness) -> Expression { + self - &Expression::from(rhs) + } +} + +impl Sub<&Expression> for Witness { + type Output = Expression; + #[inline] + fn sub(self, rhs: &Expression) -> Expression { + rhs - self + } +} + +// Mul is not implemented as this could result in degree 3 terms. + +// Expression + +impl Add<&Expression> for &Expression { + type Output = Expression; + fn add(self, rhs: &Expression) -> Expression { + // XXX(med) : Implement an efficient way to do this + + let mul_terms: Vec<_> = + self.mul_terms.iter().cloned().chain(rhs.mul_terms.iter().cloned()).collect(); + + let linear_combinations: Vec<_> = self + .linear_combinations + .iter() + .cloned() + .chain(rhs.linear_combinations.iter().cloned()) + .collect(); + let q_c = self.q_c + rhs.q_c; + + Expression { mul_terms, linear_combinations, q_c } + } +} + +impl Sub<&Expression> for &Expression { + type Output = Expression; + fn sub(self, rhs: &Expression) -> Expression { + self + &-rhs + } +} + +// Mul is not implemented as this could result in degree 3+ terms. diff --git a/acir/src/native_types/expression/ordering.rs b/acir/src/native_types/expression/ordering.rs new file mode 100644 index 000000000..e24a25ec3 --- /dev/null +++ b/acir/src/native_types/expression/ordering.rs @@ -0,0 +1,99 @@ +use crate::native_types::Witness; +use std::cmp::Ordering; + +use super::Expression; + +// TODO: It's undecided whether `Expression` should implement `Ord/PartialOrd`. +// This is currently used in ACVM in the compiler. + +impl Ord for Expression { + fn cmp(&self, other: &Self) -> Ordering { + let mut i1 = self.get_max_idx(); + let mut i2 = other.get_max_idx(); + let mut result = Ordering::Equal; + while result == Ordering::Equal { + let m1 = self.get_max_term(&mut i1); + let m2 = other.get_max_term(&mut i2); + if m1.is_none() && m2.is_none() { + return Ordering::Equal; + } + result = Expression::cmp_max(m1, m2); + } + result + } +} + +impl PartialOrd for Expression { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +struct WitnessIdx { + linear: usize, + mul: usize, + second_term: bool, +} + +impl Expression { + fn get_max_idx(&self) -> WitnessIdx { + WitnessIdx { + linear: self.linear_combinations.len(), + mul: self.mul_terms.len(), + second_term: true, + } + } + + /// Returns the maximum witness at the provided position, and decrement the position. + /// + /// This function assumes the gate is sorted + fn get_max_term(&self, idx: &mut WitnessIdx) -> Option { + if idx.linear > 0 { + if idx.mul > 0 { + let mul_term = if idx.second_term { + self.mul_terms[idx.mul - 1].2 + } else { + self.mul_terms[idx.mul - 1].1 + }; + if self.linear_combinations[idx.linear - 1].1 > mul_term { + idx.linear -= 1; + Some(self.linear_combinations[idx.linear].1) + } else { + if idx.second_term { + idx.second_term = false; + } else { + idx.mul -= 1; + } + Some(mul_term) + } + } else { + idx.linear -= 1; + Some(self.linear_combinations[idx.linear].1) + } + } else if idx.mul > 0 { + if idx.second_term { + idx.second_term = false; + Some(self.mul_terms[idx.mul - 1].2) + } else { + idx.mul -= 1; + Some(self.mul_terms[idx.mul].1) + } + } else { + None + } + } + + fn cmp_max(m1: Option, m2: Option) -> Ordering { + if let Some(m1) = m1 { + if let Some(m2) = m2 { + m1.cmp(&m2) + } else { + Ordering::Greater + } + } else if m2.is_some() { + Ordering::Less + } else { + Ordering::Equal + } + } +} diff --git a/acir/src/native_types/mod.rs b/acir/src/native_types/mod.rs index b0efba220..4b54d9388 100644 --- a/acir/src/native_types/mod.rs +++ b/acir/src/native_types/mod.rs @@ -1,5 +1,5 @@ -mod arithmetic; +mod expression; mod witness; -pub use arithmetic::Expression; +pub use expression::Expression; pub use witness::Witness; diff --git a/acvm/src/compiler/transformers/fallback.rs b/acvm/src/compiler/transformers/fallback.rs index 33c8c031e..6b669b18a 100644 --- a/acvm/src/compiler/transformers/fallback.rs +++ b/acvm/src/compiler/transformers/fallback.rs @@ -70,8 +70,8 @@ impl FallbackTransformer { BlackBoxFunc::AND => { let (lhs, rhs, result, num_bits) = crate::pwg::logic::extract_input_output(gc); stdlib::fallback::and( - Expression::from(&lhs), - Expression::from(&rhs), + Expression::from(lhs), + Expression::from(rhs), result, num_bits, current_witness_idx, @@ -80,8 +80,8 @@ impl FallbackTransformer { BlackBoxFunc::XOR => { let (lhs, rhs, result, num_bits) = crate::pwg::logic::extract_input_output(gc); stdlib::fallback::xor( - Expression::from(&lhs), - Expression::from(&rhs), + Expression::from(lhs), + Expression::from(rhs), result, num_bits, current_witness_idx, @@ -93,7 +93,7 @@ impl FallbackTransformer { let input = &gc.inputs[0]; // Note there are no outputs because range produces no outputs stdlib::fallback::range( - Expression::from(&input.witness), + Expression::from(input.witness), input.num_bits, current_witness_idx, ) diff --git a/acvm/src/pwg/block.rs b/acvm/src/pwg/block.rs index 1970d660e..5c33ad3ef 100644 --- a/acvm/src/pwg/block.rs +++ b/acvm/src/pwg/block.rs @@ -141,24 +141,24 @@ mod test { let mut trace = vec![MemOp { operation: Expression::one(), index: Expression::from_field(index), - value: Expression::from(&Witness(1)), + value: Expression::from(Witness(1)), }]; index += FieldElement::one(); trace.push(MemOp { operation: Expression::one(), index: Expression::from_field(index), - value: Expression::from(&Witness(2)), + value: Expression::from(Witness(2)), }); index += FieldElement::one(); trace.push(MemOp { operation: Expression::one(), index: Expression::from_field(index), - value: Expression::from(&Witness(3)), + value: Expression::from(Witness(3)), }); trace.push(MemOp { operation: Expression::zero(), index: Expression::one(), - value: Expression::from(&Witness(4)), + value: Expression::from(Witness(4)), }); let id = BlockId::default(); let mut initial_witness = BTreeMap::new(); diff --git a/stdlib/src/fallback.rs b/stdlib/src/fallback.rs index fe045cf37..3eaab2932 100644 --- a/stdlib/src/fallback.rs +++ b/stdlib/src/fallback.rs @@ -48,7 +48,7 @@ pub(crate) fn bit_decomposition( // of the input and each bit is actually a bit let mut binary_exprs = Vec::new(); let mut bit_decomp_constraint = gate; - let mut two_pow = FieldElement::one(); + let mut two_pow: FieldElement = FieldElement::one(); let two = FieldElement::from(2_i128); for &bit in &bit_vector { // Bit constraint to ensure each bit is a zero or one; bit^2 - bit = 0