Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

chore!: organise operator implementations for Expression #190

Merged
merged 1 commit into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use crate::native_types::Witness;
use crate::serialization::{read_field_element, read_u32, write_bytes, write_u32};
use acir_field::FieldElement;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::io::{Read, Write};
use std::ops::{Add, Mul, Neg, Sub};

mod operators;
mod ordering;

// In the addition polynomial
// We can have arbitrary fan-in/out, so we need more than wL,wR and wO
Expand Down Expand Up @@ -47,36 +48,6 @@ impl std::fmt::Display for Expression {
}
}

// TODO: possibly remove, and move to noir repo.
impl Ord for Expression {
fn cmp(&self, other: &Self) -> Ordering {
let mut i1 = self.get_max_idx();
let mut i2 = other.get_max_idx();
let mut result = Ordering::Equal;
while result == Ordering::Equal {
let m1 = self.get_max_term(&mut i1);
let m2 = other.get_max_term(&mut i2);
if m1.is_none() && m2.is_none() {
return Ordering::Equal;
}
result = Expression::cmp_max(m1, m2);
}
result
}
}
// TODO: possibly remove, and move to noir repo.
impl PartialOrd for Expression {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
// TODO: possibly remove, and move to noir repo.
struct WitnessIdx {
linear: usize,
mul: usize,
second_term: bool,
}

impl Expression {
// TODO: possibly remove, and move to noir repo.
pub const fn can_defer_constraint(&self) -> bool {
Expand Down Expand Up @@ -250,195 +221,13 @@ impl Expression {
None
}

fn get_max_idx(&self) -> WitnessIdx {
WitnessIdx {
linear: self.linear_combinations.len(),
mul: self.mul_terms.len(),
second_term: true,
}
}
// Returns the maximum witness at the provided position, and decrement the position
// This function assumes the gate is sorted
// TODO: possibly remove, and move to noir repo.
fn get_max_term(&self, idx: &mut WitnessIdx) -> Option<Witness> {
if idx.linear > 0 {
if idx.mul > 0 {
let mul_term = if idx.second_term {
self.mul_terms[idx.mul - 1].2
} else {
self.mul_terms[idx.mul - 1].1
};
if self.linear_combinations[idx.linear - 1].1 > mul_term {
idx.linear -= 1;
Some(self.linear_combinations[idx.linear].1)
} else {
if idx.second_term {
idx.second_term = false;
} else {
idx.mul -= 1;
}
Some(mul_term)
}
} else {
idx.linear -= 1;
Some(self.linear_combinations[idx.linear].1)
}
} else if idx.mul > 0 {
if idx.second_term {
idx.second_term = false;
Some(self.mul_terms[idx.mul - 1].2)
} else {
idx.mul -= 1;
Some(self.mul_terms[idx.mul].1)
}
} else {
None
}
}

// TODO: possibly remove, and move to noir repo.
fn cmp_max(m1: Option<Witness>, m2: Option<Witness>) -> Ordering {
if let Some(m1) = m1 {
if let Some(m2) = m2 {
m1.cmp(&m2)
} else {
Ordering::Greater
}
} else if m2.is_some() {
Ordering::Less
} else {
Ordering::Equal
}
}

/// Sorts gate in a deterministic order
/// XXX: We can probably make this more efficient by sorting on each phase. We only care if it is deterministic
pub fn sort(&mut self) {
self.mul_terms.sort_by(|a, b| a.1.cmp(&b.1).then(a.2.cmp(&b.2)));
self.linear_combinations.sort_by(|a, b| a.1.cmp(&b.1));
}
}

impl Mul<&FieldElement> for &Expression {
type Output = Expression;
fn mul(self, rhs: &FieldElement) -> Self::Output {
// Scale the mul terms
let mul_terms: Vec<_> =
self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * *rhs, *w_l, *w_r)).collect();

// Scale the linear combinations terms
let lin_combinations: Vec<_> =
self.linear_combinations.iter().map(|(q_l, w_l)| (*q_l * *rhs, *w_l)).collect();

// Scale the constant
let q_c = self.q_c * *rhs;

Expression { mul_terms, q_c, linear_combinations: lin_combinations }
}
}
impl Add<&FieldElement> for Expression {
type Output = Expression;
fn add(self, rhs: &FieldElement) -> Self::Output {
// Increase the constant
let q_c = self.q_c + *rhs;

Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
}
}
impl Sub<&FieldElement> for Expression {
type Output = Expression;
fn sub(self, rhs: &FieldElement) -> Self::Output {
// Increase the constant
let q_c = self.q_c - *rhs;

Expression { mul_terms: self.mul_terms, q_c, linear_combinations: self.linear_combinations }
}
}

impl Add<&Expression> for &Expression {
type Output = Expression;
fn add(self, rhs: &Expression) -> Expression {
// XXX(med) : Implement an efficient way to do this

let mul_terms: Vec<_> =
self.mul_terms.iter().cloned().chain(rhs.mul_terms.iter().cloned()).collect();

let linear_combinations: Vec<_> = self
.linear_combinations
.iter()
.cloned()
.chain(rhs.linear_combinations.iter().cloned())
.collect();
let q_c = self.q_c + rhs.q_c;

Expression { mul_terms, linear_combinations, q_c }
}
}

impl Neg for &Expression {
type Output = Expression;
fn neg(self) -> Self::Output {
// XXX(med) : Implement an efficient way to do this

let mul_terms: Vec<_> =
self.mul_terms.iter().map(|(q_m, w_l, w_r)| (-*q_m, *w_l, *w_r)).collect();

let linear_combinations: Vec<_> =
self.linear_combinations.iter().map(|(q_k, w_k)| (-*q_k, *w_k)).collect();
let q_c = -self.q_c;

Expression { mul_terms, linear_combinations, q_c }
}
}

impl Sub<&Expression> for &Expression {
type Output = Expression;
fn sub(self, rhs: &Expression) -> Expression {
self + &-rhs
}
}

impl From<FieldElement> for Expression {
fn from(constant: FieldElement) -> Expression {
Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() }
}
}

impl From<&FieldElement> for Expression {
fn from(constant: &FieldElement) -> Expression {
(*constant).into()
}
}

impl From<Witness> for Expression {
/// Creates an Expression from a Witness.
///
/// This is infallible since an `Expression` is
/// a multi-variate polynomial and a `Witness`
/// can be seen as a univariate polynomial
fn from(wit: Witness) -> Expression {
Expression {
q_c: FieldElement::zero(),
linear_combinations: vec![(FieldElement::one(), wit)],
mul_terms: Vec::new(),
}
}
}

impl From<&Witness> for Expression {
fn from(wit: &Witness) -> Expression {
(*wit).into()
}
}

impl Sub<&Witness> for &Expression {
type Output = Expression;
fn sub(self, rhs: &Witness) -> Expression {
self - &Expression::from(rhs)
}
}

impl Expression {
/// Checks if this polynomial can fit into one arithmetic identity
pub fn fits_in_one_identity(&self, width: usize) -> bool {
// A Polynomial with more than one mul term cannot fit into one gate
Expand Down Expand Up @@ -495,6 +284,27 @@ impl Expression {
}
}

impl From<FieldElement> for Expression {
fn from(constant: FieldElement) -> Expression {
Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() }
}
}

impl From<Witness> for Expression {
/// Creates an Expression from a Witness.
///
/// This is infallible since an `Expression` is
/// a multi-variate polynomial and a `Witness`
/// can be seen as a univariate polynomial
fn from(wit: Witness) -> Expression {
Expression {
q_c: FieldElement::zero(),
linear_combinations: vec![(FieldElement::one(), wit)],
mul_terms: Vec::new(),
}
}
}

#[test]
fn serialization_roundtrip() {
// Empty expression
Expand Down
Loading