Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: simplify MSM with constant folding #6650

Merged
merged 13 commits into from
Dec 6, 2024
152 changes: 114 additions & 38 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::sync::Arc;

use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement};

use crate::ssa::ir::instruction::BlackBoxFunc;
use crate::ssa::ir::{
basic_block::BasicBlockId,
dfg::{CallStack, DataFlowGraph},
instruction::{Instruction, SimplifyResult},
instruction::{Instruction, Intrinsic, SimplifyResult},
types::Type,
value::ValueId,
};
Expand Down Expand Up @@ -70,52 +71,127 @@ pub(super) fn simplify_msm(
block: BasicBlockId,
call_stack: &CallStack,
) -> SimplifyResult {
// TODO: Handle MSMs where a subset of the terms are constant.
let mut is_constant;

match (dfg.get_array_constant(arguments[0]), dfg.get_array_constant(arguments[1])) {
(Some((points, _)), Some((scalars, _))) => {
let Some(points) = points
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};

let Some(scalars) = scalars
.into_iter()
.map(|id| dfg.get_numeric_constant(id))
.collect::<Option<Vec<_>>>()
else {
return SimplifyResult::None;
};
// We decompose points and scalars into constant and non-constant parts in order to simplify MSMs where a subset of the terms are constant.
let mut constant_points = vec![];
let mut constant_scalars_lo = vec![];
let mut constant_scalars_hi = vec![];
let mut var_points = vec![];
let mut var_scalars = vec![];
let len = scalars.len() / 2;
for i in 0..len {
match (
dfg.get_numeric_constant(scalars[2 * i]),
dfg.get_numeric_constant(scalars[2 * i + 1]),
dfg.get_numeric_constant(points[3 * i]),
dfg.get_numeric_constant(points[3 * i + 1]),
dfg.get_numeric_constant(points[3 * i + 2]),
) {
(Some(lo), Some(hi), _, _, _)
if lo == FieldElement::zero() && hi == FieldElement::zero() =>
guipublic marked this conversation as resolved.
Show resolved Hide resolved
{
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(_, _, _, _, Some(infinity)) if infinity == FieldElement::one() => {
guipublic marked this conversation as resolved.
Show resolved Hide resolved
is_constant = true;
constant_scalars_lo.push(FieldElement::zero());
constant_scalars_hi.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::zero());
constant_points.push(FieldElement::one());
}
(Some(lo), Some(hi), Some(x), Some(y), Some(infinity)) => {
is_constant = true;
constant_scalars_lo.push(lo);
constant_scalars_hi.push(hi);
constant_points.push(x);
constant_points.push(y);
constant_points.push(infinity);
}
_ => {
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
is_constant = false;
}
}

let mut scalars_lo = Vec::new();
let mut scalars_hi = Vec::new();
for (i, scalar) in scalars.into_iter().enumerate() {
if i % 2 == 0 {
scalars_lo.push(scalar);
} else {
scalars_hi.push(scalar);
if !is_constant {
var_points.push(points[3 * i]);
var_points.push(points[3 * i + 1]);
var_points.push(points[3 * i + 2]);
var_scalars.push(scalars[2 * i]);
var_scalars.push(scalars[2 * i + 1]);
}
}

let Ok((result_x, result_y, result_is_infinity)) =
solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)
else {
// If there are no constant terms, we can't simplify
if constant_scalars_lo.is_empty() {
return SimplifyResult::None;
}
let Ok((result_x, result_y, result_is_infinity)) = solver.multi_scalar_mul(
&constant_points,
&constant_scalars_lo,
&constant_scalars_hi,
) else {
return SimplifyResult::None;
};

let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array =
dfg.insert_instruction_and_results(instruction, block, None, call_stack.clone());

SimplifyResult::SimplifiedTo(result_array.first())
// If there are no variable term, we can directly return the constant result
if var_scalars.is_empty() {
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::field());

let elements = im::vector![result_x, result_y, result_is_infinity];
let typ = Type::Array(Arc::new(vec![Type::field()]), 3);
let instruction = Instruction::MakeArray { elements, typ };
let result_array = dfg.insert_instruction_and_results(
instruction,
block,
None,
call_stack.clone(),
);

return SimplifyResult::SimplifiedTo(result_array.first());
}
// If there is only one non-null constant term, we cannot simplify
if constant_scalars_lo.len() == 1 && result_is_infinity != FieldElement::one() {
return SimplifyResult::None;
}
// Add the constant part back to the non-constant part, if it is not null
if result_is_infinity != FieldElement::one() {
let one = dfg.make_constant(FieldElement::one(), Type::field());
let zero = dfg.make_constant(FieldElement::zero(), Type::field());
guipublic marked this conversation as resolved.
Show resolved Hide resolved
var_scalars.push(one);
var_scalars.push(zero);
let result_x = dfg.make_constant(result_x, Type::field());
let result_y = dfg.make_constant(result_y, Type::field());
let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool());
var_points.push(result_x);
var_points.push(result_y);
var_points.push(result_is_infinity);
}
// Construct the simplified MSM expression
let typ = Type::Array(Arc::new(vec![Type::field()]), var_scalars.len());
let scalars = Instruction::MakeArray { elements: var_scalars.into(), typ };
let scalars = dfg
.insert_instruction_and_results(scalars, block, None, call_stack.clone())
.first();
let typ = Type::Array(Arc::new(vec![Type::field()]), var_points.len());
let points = Instruction::MakeArray { elements: var_points.into(), typ };
let points =
dfg.insert_instruction_and_results(points, block, None, call_stack.clone()).first();
let msm = dfg.import_intrinsic(Intrinsic::BlackBox(BlackBoxFunc::MultiScalarMul));
SimplifyResult::SimplifiedToInstruction(Instruction::Call {
func: msm,
arguments: vec![points, scalars],
})
}
_ => SimplifyResult::None,
}
Expand Down
18 changes: 18 additions & 0 deletions test_programs/execution_success/embedded_curve_ops/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,22 @@ fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) {

// The results should be double the g1 point because the scalars are 1 and we pass in g1 twice
assert(double.x == res.x);

// Tests for #6549
let const_scalar1 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 23, hi: 0 };
let const_scalar2 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 23 };
let const_scalar3 = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 13, hi: 4 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[g1, double, pub_point, g1, g1],
[scalar, const_scalar1, scalar, const_scalar2, const_scalar3],
);
assert(partial_mul.x == 0x2024c4eebfbc8a20018f8c95c7aab77c6f34f10cf785a6f04e97452d8708fda7);
// Check simplification by zero
let zero_point = std::embedded_curve_ops::EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true };
let const_zero = std::embedded_curve_ops::EmbeddedCurveScalar { lo: 0, hi: 0 };
let partial_mul = std::embedded_curve_ops::multi_scalar_mul(
[zero_point, double, g1],
[scalar, const_zero, scalar],
);
assert(partial_mul == g1);
}
Loading