Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

Commit

Permalink
feat: implement add_mul on Expression (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Apr 20, 2023
1 parent a619df6 commit f156e18
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 17 deletions.
143 changes: 143 additions & 0 deletions acir/src/native_types/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::native_types::Witness;
use crate::serialization::{read_field_element, read_u32, write_bytes, write_u32};
use acir_field::FieldElement;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::io::{Read, Write};

mod operators;
Expand Down Expand Up @@ -282,6 +283,114 @@ impl Expression {

found_x & found_y
}

/// Returns `self + k*b`
pub fn add_mul(&self, k: FieldElement, b: &Expression) -> Expression {
if k.is_zero() {
return self.clone();
} else if self.is_const() {
return self.q_c + (k * b);
} else if b.is_const() {
return self.clone() + (k * b.q_c);
}

let mut mul_terms: Vec<(FieldElement, Witness, Witness)> =
Vec::with_capacity(self.mul_terms.len() + b.mul_terms.len());
let mut linear_combinations: Vec<(FieldElement, Witness)> =
Vec::with_capacity(self.linear_combinations.len() + b.linear_combinations.len());
let q_c = self.q_c + k * b.q_c;

//linear combinations
let mut i1 = 0; //a
let mut i2 = 0; //b
while i1 < self.linear_combinations.len() && i2 < b.linear_combinations.len() {
let (a_c, a_w) = self.linear_combinations[i1];
let (b_c, b_w) = b.linear_combinations[i2];

let (coeff, witness) = match a_w.cmp(&b_w) {
Ordering::Greater => {
i2 += 1;
(k * b_c, b_w)
}
Ordering::Less => {
i1 += 1;
(a_c, a_w)
}
Ordering::Equal => {
// Here we're taking both witnesses as the witness indices are equal.
// We then advance both `i1` and `i2`.
i1 += 1;
i2 += 1;
(a_c + k * b_c, a_w)
}
};

if !coeff.is_zero() {
linear_combinations.push((coeff, witness));
}
}

// Finally process all the remaining terms which we didn't handle in the above loop.
while i1 < self.linear_combinations.len() {
linear_combinations.push(self.linear_combinations[i1]);
i1 += 1;
}
while i2 < b.linear_combinations.len() {
let (b_c, b_w) = b.linear_combinations[i2];
let coeff = b_c * k;
if !coeff.is_zero() {
linear_combinations.push((coeff, b_w));
}
i2 += 1;
}

//mul terms

i1 = 0; //a
i2 = 0; //b
while i1 < self.mul_terms.len() && i2 < b.mul_terms.len() {
let (a_c, a_wl, a_wr) = self.mul_terms[i1];
let (b_c, b_wl, b_wr) = b.mul_terms[i2];

let (coeff, wl, wr) = match (a_wl, a_wr).cmp(&(b_wl, b_wr)) {
Ordering::Greater => {
i2 += 1;
(k * b_c, b_wl, b_wr)
}
Ordering::Less => {
i1 += 1;
(a_c, a_wl, a_wr)
}
Ordering::Equal => {
// Here we're taking both terms as the witness indices are equal.
// We then advance both `i1` and `i2`.
i2 += 1;
i1 += 1;
(a_c + k * b_c, a_wl, a_wr)
}
};

if !coeff.is_zero() {
mul_terms.push((coeff, wl, wr));
}
}

// Finally process all the remaining terms which we didn't handle in the above loop.
while i1 < self.mul_terms.len() {
mul_terms.push(self.mul_terms[i1]);
i1 += 1;
}
while i2 < b.mul_terms.len() {
let (b_c, b_wl, b_wr) = b.mul_terms[i2];
let coeff = b_c * k;
if coeff != FieldElement::zero() {
mul_terms.push((coeff, b_wl, b_wr));
}
i2 += 1;
}

Expression { mul_terms, linear_combinations, q_c }
}
}

impl From<FieldElement> for Expression {
Expand Down Expand Up @@ -330,3 +439,37 @@ fn serialization_roundtrip() {
let (expr, got_expr) = read_write(expr);
assert_eq!(expr, got_expr);
}

#[test]
fn add_mul_smoketest() {
let a = Expression {
mul_terms: vec![(FieldElement::from(2u128), Witness(1), Witness(2))],
..Default::default()
};

let k = FieldElement::from(10u128);

let b = Expression {
mul_terms: vec![
(FieldElement::from(3u128), Witness(0), Witness(2)),
(FieldElement::from(3u128), Witness(1), Witness(2)),
(FieldElement::from(4u128), Witness(4), Witness(5)),
],
linear_combinations: vec![(FieldElement::from(4u128), Witness(4))],
q_c: FieldElement::one(),
};

let result = a.add_mul(k, &b);
assert_eq!(
result,
Expression {
mul_terms: vec![
(FieldElement::from(30u128), Witness(0), Witness(2)),
(FieldElement::from(32u128), Witness(1), Witness(2)),
(FieldElement::from(40u128), Witness(4), Witness(5)),
],
linear_combinations: vec![(FieldElement::from(40u128), Witness(4))],
q_c: FieldElement::from(10u128)
}
)
}
106 changes: 89 additions & 17 deletions acir/src/native_types/expression/operators.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::native_types::Witness;
use acir_field::FieldElement;
use std::ops::{Add, Mul, Neg, Sub};
use std::{
cmp::Ordering,
ops::{Add, Mul, Neg, Sub},
};

use super::Expression;

Expand Down Expand Up @@ -125,28 +128,97 @@ impl Sub<&Expression> for Witness {
impl Add<&Expression> for &Expression {
type Output = Expression;
fn add(self, rhs: &Expression) -> Expression {
// XXX(med) : Implement an efficient way to do this

let mul_terms: Vec<_> =
self.mul_terms.iter().cloned().chain(rhs.mul_terms.iter().cloned()).collect();

let linear_combinations: Vec<_> = self
.linear_combinations
.iter()
.cloned()
.chain(rhs.linear_combinations.iter().cloned())
.collect();
let q_c = self.q_c + rhs.q_c;

Expression { mul_terms, linear_combinations, q_c }
self.add_mul(FieldElement::one(), rhs)
}
}

impl Sub<&Expression> for &Expression {
type Output = Expression;
fn sub(self, rhs: &Expression) -> Expression {
self + &-rhs
self.add_mul(-FieldElement::one(), rhs)
}
}

// Mul<Expression> is not implemented as this could result in degree 3+ terms.
impl Mul<&Expression> for &Expression {
type Output = Expression;
fn mul(self, rhs: &Expression) -> Expression {
if self.is_const() {
return self.q_c * rhs;
} else if rhs.is_const() {
return self * rhs.q_c;
} else if !(self.is_linear() && rhs.is_linear()) {
// `Expression`s can only represent terms which are up to degree 2.
// We then disallow multiplication of `Expression`s which have degree 2 terms.
unreachable!("Can only multiply linear terms");
}

let mut output = Expression::from_field(self.q_c * rhs.q_c);

//TODO to optimize...
for lc in &self.linear_combinations {
let single = single_mul(lc.1, rhs);
output = output.add_mul(lc.0, &single);
}

//linear terms
let mut i1 = 0; //a
let mut i2 = 0; //b
while i1 < self.linear_combinations.len() && i2 < rhs.linear_combinations.len() {
let (a_c, a_w) = self.linear_combinations[i1];
let (b_c, b_w) = rhs.linear_combinations[i2];

// Apply scaling from multiplication
let a_c = rhs.q_c * a_c;
let b_c = self.q_c * b_c;

let (coeff, witness) = match a_w.cmp(&b_w) {
Ordering::Greater => {
i2 += 1;
(b_c, b_w)
}
Ordering::Less => {
i1 += 1;
(a_c, a_w)
}
Ordering::Equal => {
// Here we're taking both terms as the witness indices are equal.
// We then advance both `i1` and `i2`.
i1 += 1;
i2 += 1;
(a_c + b_c, a_w)
}
};

if !coeff.is_zero() {
output.linear_combinations.push((coeff, witness));
}
}
while i1 < self.linear_combinations.len() {
let (a_c, a_w) = self.linear_combinations[i1];
output.linear_combinations.push((rhs.q_c * a_c, a_w));
i1 += 1;
}
while i2 < rhs.linear_combinations.len() {
let (b_c, b_w) = rhs.linear_combinations[i1];
output.linear_combinations.push((self.q_c * b_c, b_w));
i2 += 1;
}

output
}
}

/// Returns `w*b.linear_combinations`
fn single_mul(w: Witness, b: &Expression) -> Expression {
Expression {
mul_terms: b
.linear_combinations
.iter()
.map(|(a, wit)| {
let (wl, wr) = if w < *wit { (w, *wit) } else { (*wit, w) };
(*a, wl, wr)
})
.collect(),
..Default::default()
}
}

0 comments on commit f156e18

Please sign in to comment.