Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: simplify MSM with constant folding #6650

Merged
merged 13 commits into from
Dec 6, 2024
6 changes: 2 additions & 4 deletions compiler/noirc_evaluator/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,18 @@ pub mod ssa;
pub use ssa::create_program;
pub use ssa::ir::instruction::ErrorType;

/// Trims leading whitespace from each line of the input string, according to
/// how much leading whitespace there is on the first non-empty line.
/// Trims leading whitespace from each line of the input string
#[cfg(test)]
pub(crate) fn trim_leading_whitespace_from_lines(src: &str) -> String {
let mut lines = src.trim_end().lines();
let mut first_line = lines.next().unwrap();
while first_line.is_empty() {
first_line = lines.next().unwrap();
}
let indent = first_line.len() - first_line.trim_start().len();
let mut result = first_line.trim_start().to_string();
for line in lines {
result.push('\n');
result.push_str(&line[indent..]);
result.push_str(line.trim_start());
}
result
}
Expand Down
240 changes: 202 additions & 38 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::sync::Arc;

use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement};

use crate::ssa::ir::instruction::BlackBoxFunc;
use crate::ssa::ir::{
basic_block::BasicBlockId,
dfg::{CallStack, DataFlowGraph},
instruction::{Instruction, SimplifyResult},
instruction::{Instruction, Intrinsic, SimplifyResult},
types::Type,
value::ValueId,
};
Expand Down Expand Up @@ -70,52 +71,125 @@ pub(super) fn simplify_msm(
block: BasicBlockId,
call_stack: &CallStack,
) -> SimplifyResult {
// TODO: Handle MSMs where a subset of the terms are constant.
let mut is_constant;

match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) {
(Some((points, _)), Some((scalars, _))) => {
let Some(points) = points
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};

let Some(scalars) = scalars
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};
// We decompose points and scalars into constant and non-constant parts in order to simplify MSMs where a subset of the terms are constant.
let mut constant_points = vec![];
let mut constant_scalars_lo = vec![];
let mut constant_scalars_hi = vec![];
let mut var_points = vec![];
let mut var_scalars = vec![];
let len = scalars.len() / 2;
for i in 0..len {
match (
dfg.get_numeric_constant(scalars[2 * i]),
dfg.get_numeric_constant(scalars[2 * i + 1]),
dfg.get_numeric_constant(points[3 * i]),
dfg.get_numeric_constant(points[3 * i + 1]),
dfg.get_numeric_constant(points[3 * i + 2]),
) {
(Some(lo), Some(hi), _, _, _) if lo.is_zero() && hi.is_zero() => {
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(_, _, _, _, Some(infinity)) if infinity.is_one() => {
is_constant = true;
constant_scalars_lo.push(FieldElement::zero());
constant_scalars_hi.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(Some(lo), Some(hi), Some(x), Some(y), Some(infinity)) => {
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(x);
constant_points.push(y);
constant_points.push(infinity);
}
_ => {
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
is_constant = false;
}
}

let mut scalars_lo = Vec::new();
let mut scalars_hi = Vec::new();
for (i, scalar) in scalars.into_iter().enumerate() {
if i % 2 == 0 {
scalars_lo.push(scalar);
} else {
scalars_hi.push(scalar);
if !is_constant {
var_points.push(points[3 * i]);
var_points.push(points[3 * i + 1]);
var_points.push(points[3 * i + 2]);
var_scalars.push(scalars[2 * i]);
var_scalars.push(scalars[2 * i + 1]);
}
}

let Ok((result_x, result_y, result_is_infinity)) =
solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)
else {
// If there are no constant terms, we can't simplify
if constant_scalars_lo.is_empty() {
return SimplifyResult::None;
}
let Ok((result_x, result_y, result_is_infinity)) = solver.multi_scalar_mul(
&constant_points,
&constant_scalars_lo,
&constant_scalars_hi,
) else {
return SimplifyResult::None;
};

let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array =
dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone());

SimplifyResult::SimplifiedTo(result_array.first())
// If there are no variable term, we can directly return the constant result
if var_scalars.is_empty() {
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array = dfg.insert_instruction_and_results(
instruction,
block,
None,
call_stack.clone(),
);

return SimplifyResult::SimplifiedTo(result_array.first());
}
// If there is only one non-null constant term, we cannot simplify
if constant_scalars_lo.len() == 1 && result_is_infinity != FieldElement::one() {
return SimplifyResult::None;
}
// Add the constant part back to the non-constant part, if it is not null
let one = dfg.make_constant(FieldElement::one(), Type::field());
let zero = dfg.make_constant(FieldElement::zero(), Type::field());
if result_is_infinity.is_zero() {
var_scalars.push(one);
var_scalars.push(zero);
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());
var_points.push(result_x);
var_points.push(result_y);
var_points.push(result_is_infinity);
}
// Construct the simplified MSM expression
let typ = Type::Array(Arc::new(vec![Type::field()]), var_scalars.len() as u32);
let scalars = Instruction::MakeArray { elements: var_scalars.into(), typ };
let scalars = dfg
.insert_instruction_and_results(scalars, block, None, call_stack.clone())
.first();
let typ = Type::Array(Arc::new(vec![Type::field()]), var_points.len() as u32);
let points = Instruction::MakeArray { elements: var_points.into(), typ };
let points =
dfg.insert_instruction_and_results(points, block, None, call_stack.clone()).first();
let msm = dfg.import_intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul));
SimplifyResult::SimplifiedToInstruction(Instruction::Call {
func: msm,
arguments: vec![points, scalars],
})
}
_ => SimplifyResult::None,
}
Expand Down Expand Up @@ -261,3 +335,93 @@ pub(super) fn simplify_signature(
_ => SimplifyResult::None,
}
}

#[cfg(feature = "bn254")]
#[cfg(test)]
mod test {
use crate::ssa::opt::assert_normalized_ssa_equals;
use crate::ssa::Ssa;

#[cfg(feature = "bn254")]
#[test]
fn full_constant_folding() {
let src = r#"
acir(inline) fn main f0 {
b0():
v0 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4]
v1 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6]
v2 = call multi_scalar_mul (v1, v0) -> [Field; 3]
return v2
}"#;
let ssa = Ssa::from_str(src).unwrap();

let expected_src = r#"
acir(inline) fn main f0 {
b0():
v3 = make_array [Field 2, Field 3, Field 5, Field 5] : [Field; 4]
v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 6]
v10 = make_array [Field 1478523918288173385110236399861791147958001875200066088686689589556927843200, Field 700144278551281040379388961242974992655630750193306467120985766322057145630, Field 0] : [Field; 3]
return v10
}
"#;
assert_normalized_ssa_equals(ssa, expected_src);
}

#[cfg(feature = "bn254")]
#[test]
fn simplify_zero() {
let src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v2 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6]
v3 = make_array [
Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9]
v4 = call multi_scalar_mul (v3, v2) -> [Field; 3]

return v4

}"#;
let ssa = Ssa::from_str(src).unwrap();
//First point is zero, second scalar is zero, so we should be left with the scalar mul of the last point.
let expected_src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v3 = make_array [v0, Field 0, Field 0, Field 0, v0, Field 0] : [Field; 6]
v5 = make_array [Field 0, Field 0, Field 1, v0, v1, Field 0, Field 1, v0, Field 0] : [Field; 9]
v6 = make_array [v0, Field 0] : [Field; 2]
v7 = make_array [Field 1, v0, Field 0] : [Field; 3]
v9 = call multi_scalar_mul(v7, v6) -> [Field; 3]
return v9
}
"#;
assert_normalized_ssa_equals(ssa, expected_src);
}

#[cfg(feature = "bn254")]
#[test]
fn partial_constant_folding() {
let src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v2 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6]
v3 = make_array [
Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9]
v4 = call multi_scalar_mul (v3, v2) -> [Field; 3]
return v4
}"#;
let ssa = Ssa::from_str(src).unwrap();
//First and last scalar/point are constant, so we should be left with the msm of the middle point and the folded constant point
let expected_src = r#"
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v5 = make_array [Field 1, Field 0, v0, Field 0, Field 2, Field 0] : [Field; 6]
v7 = make_array [Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0, v0, v1, Field 0, Field 1, Field 17631683881184975370165255887551781615748388533673675138860, Field 0] : [Field; 9]
v8 = make_array [v0, Field 0, Field 1, Field 0] : [Field; 4]
v12 = make_array [v0, v1, Field 0, Field -3227352362257037263902424173275354266044964400219754872043023745437788450996, Field 8902249110305491597038405103722863701255802573786510474664632793109847672620, u1 0] : [Field; 6]
v14 = call multi_scalar_mul(v12, v8) -> [Field; 3]
return v14
}
"#;
assert_normalized_ssa_equals(ssa, expected_src);
}
}
18 changes: 18 additions & 0 deletions test_programs/execution_success/embedded_curve_ops/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,22 @@ fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) {

// The results should be double the g1 point because the scalars are 1 and we pass in g1 twice
assert(double.x == res.x);

// Tests for #6549
let const_scalar1 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 23, hi: 0 };
let const_scalar2 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 23 };
let const_scalar3 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 13, hi: 4 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[g1, double, pub_point, g1, g1],
[scalar, const_scalar1, scalar, const_scalar2, const_scalar3],
);
assert(partial_mul.x == 0x2024c4eebfbc8a20018f8c95c7aab77c6f34f10cf785a6f04e97452d8708fda7);
// Check simplification by zero
let zero_point = std::embedded_curve_ops::EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true };
let const_zero = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 0 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[zero_point, double, g1],
[scalar, const_zero, scalar],
);
assert(partial_mul == g1);
}
Loading