Skip to content

Commit

Permalink
chore: simplify MSM with constant folding (noir-lang#6650)
Browse files Browse the repository at this point in the history
  • Loading branch information
guipublic authored Dec 6, 2024
1 parent df71df7 commit 6d0f86b
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 42 deletions.
6 changes: 2 additions & 4 deletions compiler/noirc_evaluator/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,18 @@ pub mod ssa;
pub use ssa::create_program;
pub use ssa::ir::instruction::ErrorType;

/// Trims leading whitespace from each line of the input string, according to
/// how much leading whitespace there is on the first non-empty line.
/// Trims leading whitespace from each line of the input string
#[cfg(test)]
pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String {
let mut lines = src.trim_end().lines();
let mut first_line = lines.next().unwrap();
while first_line.is_empty() {
first_line = lines.next().unwrap();
}
let indent = first_line.len() - first_line.trim_start().len();
let mut result = first_line.trim_start().to_string();
for line in lines {
result.push('\n');
result.push_str(&line[indent..]);
result.push_str(line.trim_start());
}
result
}
Expand Down
240 changes: 202 additions & 38 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::sync::Arc;

use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement};

use crate::ssa::ir::instruction::BlackBoxFunc;
use crate::ssa::ir::{
basic_block::BasicBlockId,
dfg::{CallStack, DataFlowGraph},
instruction::{Instruction, SimplifyResult},
instruction::{Instruction, Intrinsic, SimplifyResult},
types::Type,
value::ValueId,
};
Expand Down Expand Up @@ -70,52 +71,125 @@ pub(super) fn simplify_msm(
block: BasicBlockId,
call_stack: &CallStack,
) -> SimplifyResult {
// TODO: Handle MSMs where a subset of the terms are constant.
let mut is_constant;

match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) {
(Some((points, _)), Some((scalars, _))) => {
let Some(points) = points
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};

let Some(scalars) = scalars
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};
// We decompose points and scalars into constant and non-constant parts in order to simplify MSMs where a subset of the terms are constant.
let mut constant_points = vec![];
let mut constant_scalars_lo = vec![];
let mut constant_scalars_hi = vec![];
let mut var_points = vec![];
let mut var_scalars = vec![];
let len = scalars.len() / 2;
for i in 0..len {
match (
dfg.get_numeric_constant(scalars[2 * i]),
dfg.get_numeric_constant(scalars[2 * i + 1]),
dfg.get_numeric_constant(points[3 * i]),
dfg.get_numeric_constant(points[3 * i + 1]),
dfg.get_numeric_constant(points[3 * i + 2]),
) {
(Some(lo), Some(hi), _, _, _) if lo.is_zero() && hi.is_zero() => {
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(_, _, _, _, Some(infinity)) if infinity.is_one() => {
is_constant = true;
constant_scalars_lo.push(FieldElement::zero());
constant_scalars_hi.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(Some(lo), Some(hi), Some(x), Some(y), Some(infinity)) => {
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(x);
constant_points.push(y);
constant_points.push(infinity);
}
_ => {
is_constant = false;
}
}

let mut scalars_lo = Vec::new();
let mut scalars_hi = Vec::new();
for (i, scalar) in scalars.into_iter().enumerate() {
if i % 2 == 0 {
scalars_lo.push(scalar);
} else {
scalars_hi.push(scalar);
if !is_constant {
var_points.push(points[3 * i]);
var_points.push(points[3 * i + 1]);
var_points.push(points[3 * i + 2]);
var_scalars.push(scalars[2 * i]);
var_scalars.push(scalars[2 * i + 1]);
}
}

let Ok((result_x, result_y, result_is_infinity)) =
solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)
else {
// If there are no constant terms, we can't simplify
if constant_scalars_lo.is_empty() {
return SimplifyResult::None;
}
let Ok((result_x, result_y, result_is_infinity)) = solver.multi_scalar_mul(
&constant_points,
&constant_scalars_lo,
&constant_scalars_hi,
) else {
return SimplifyResult::None;
};

let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array =
dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone());

SimplifyResult::SimplifiedTo(result_array.first())
// If there are no variable term, we can directly return the constant result
if var_scalars.is_empty() {
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array = dfg.insert_instruction_and_results(
instruction,
block,
None,
call_stack.clone(),
);

return SimplifyResult::SimplifiedTo(result_array.first());
}
// If there is only one non-null constant term, we cannot simplify
if constant_scalars_lo.len() == 1 && result_is_infinity != FieldElement::one() {
return SimplifyResult::None;
}
// Add the constant part back to the non-constant part, if it is not null
let one = dfg.make_constant(FieldElement::one(), Type::field());
let zero = dfg.make_constant(FieldElement::zero(), Type::field());
if result_is_infinity.is_zero() {
var_scalars.push(one);
var_scalars.push(zero);
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());
var_points.push(result_x);
var_points.push(result_y);
var_points.push(result_is_infinity);
}
// Construct the simplified MSM expression
let typ = Type::Array(Arc::new(vec![Type::field()]), var_scalars.len() as u32);
let scalars = Instruction::MakeArray { elements: var_scalars.into(), typ };
let scalars = dfg
.insert_instruction_and_results(scalars, block, None, call_stack.clone())
.first();
let typ = Type::Array(Arc::new(vec![Type::field()]), var_points.len() as u32);
let points = Instruction::MakeArray { elements: var_points.into(), typ };
let points =
dfg.insert_instruction_and_results(points, block, None, call_stack.clone()).first();
let msm = dfg.import_intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul));
SimplifyResult::SimplifiedToInstruction(Instruction::Call {
func: msm,
arguments: vec![points, scalars],
})
}
_ => SimplifyResult::None,
}
Expand Down Expand Up @@ -261,3 +335,93 @@ pub(super) fn simplify_signature(
_ => SimplifyResult::None,
}
}

#[cfg(feature = "bn254")]
#[cfg(test)]
mod test {
use crate::ssa::opt::assert_normalized_ssa_equals;
use crate::ssa::Ssa;

#[cfg(feature = "bn254")]
#[test]
fn full_constant_folding() {
let src = r#"
acir(inline) fn main f0 {
b0():
v0 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4]
v1 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6]
v2 = call multi_scalar_mul (v1, v0) -> [Field; 3]
return v2
}"#;
let ssa = Ssa::from_str(src).unwrap();

let expected_src = r#"
acir(inline) fn main f0 {
b0():
v3 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4]
v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6]
v10 = make_array [Field 1478523918288173385110236399861791147958001875200066088686689589556927843200, Field 700144278551281040379388961242974992655630750193306467120985766322057145630, Field 0] : [Field; 3]
return v10
}
"#;
assert_normalized_ssa_equals(ssa, expected_src);
}

#[cfg(feature = "bn254")]
#[test]
fn simplify_zero() {
let src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v2 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6]
v3 = make_array [
Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9]
v4 = call multi_scalar_mul (v3, v2) -> [Field; 3]
return v4
}"#;
let ssa = Ssa::from_str(src).unwrap();
//First point is zero, second scalar is zero, so we should be left with the scalar mul of the last point.
let expected_src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v3 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6]
v5 = make_array [Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9]
v6 = make_array [v0, Field 0] : [Field; 2]
v7 = make_array [Field 1, v0, Field 0] : [Field; 3]
v9 = call multi_scalar_mul(v7, v6) -> [Field; 3]
return v9
}
"#;
assert_normalized_ssa_equals(ssa, expected_src);
}

#[cfg(feature = "bn254")]
#[test]
fn partial_constant_folding() {
let src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v2 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6]
v3 = make_array [
Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9]
v4 = call multi_scalar_mul (v3, v2) -> [Field; 3]
return v4
}"#;
let ssa = Ssa::from_str(src).unwrap();
//First and last scalar/point are constant, so we should be left with the msm of the middle point and the folded constant point
let expected_src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v5 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6]
v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9]
v8 = make_array [v0, Field 0, Field 1, Field 0] : [Field; 4]
v12 = make_array [v0, v1, Field 0, Field -3227352362257037263902424173275354266044964400219754872043023745437788450996, Field 8902249110305491597038405103722863701255802573786510474664632793109847672620, u1 0] : [Field; 6]
v14 = call multi_scalar_mul(v12, v8) -> [Field; 3]
return v14
}
"#;
assert_normalized_ssa_equals(ssa, expected_src);
}
}
18 changes: 18 additions & 0 deletions test_programs/execution_success/embedded_curve_ops/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,22 @@ fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) {

// The results should be double the g1 point because the scalars are 1 and we pass in g1 twice
assert(double.x == res.x);

// Tests for #6549
let const_scalar1 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 23, hi: 0 };
let const_scalar2 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 23 };
let const_scalar3 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 13, hi: 4 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[g1, double, pub_point, g1, g1],
[scalar, const_scalar1, scalar, const_scalar2, const_scalar3],
);
assert(partial_mul.x == 0x2024c4eebfbc8a20018f8c95c7aab77c6f34f10cf785a6f04e97452d8708fda7);
// Check simplification by zero
let zero_point = std::embedded_curve_ops::EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true };
let const_zero = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 0 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[zero_point, double, g1],
[scalar, const_zero, scalar],
);
assert(partial_mul == g1);
}

0 comments on commit 6d0f86b

Please sign in to comment.