Skip to content

Commit

Permalink
fix: Do not emit range check for multiplication by bool (#6983)
Browse files Browse the repository at this point in the history
  • Loading branch information
aakoshh authored Jan 8, 2025
1 parent bf474c0 commit c0a4010
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 122 deletions.
85 changes: 42 additions & 43 deletions compiler/noirc_evaluator/src/acir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1984,14 +1984,7 @@ impl<'a> Context<'a> {

if let NumericType::Unsigned { bit_size } = &num_type {
// Check for integer overflow
self.check_unsigned_overflow(
result,
*bit_size,
binary.lhs,
binary.rhs,
dfg,
binary.operator,
)?;
self.check_unsigned_overflow(result, *bit_size, binary, dfg)?;
}

Ok(result)
Expand All @@ -2002,47 +1995,18 @@ impl<'a> Context<'a> {
&mut self,
result: AcirVar,
bit_size: u32,
lhs: ValueId,
rhs: ValueId,
binary: &Binary,
dfg: &DataFlowGraph,
op: BinaryOp,
) -> Result<(), RuntimeError> {
// We try to optimize away operations that are guaranteed not to overflow
let max_lhs_bits = dfg.get_value_max_num_bits(lhs);
let max_rhs_bits = dfg.get_value_max_num_bits(rhs);

let msg = match op {
BinaryOp::Add => {
if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size {
// `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return Ok(());
}
"attempt to add with overflow".to_string()
}
BinaryOp::Sub => {
if dfg.is_constant(lhs) && max_lhs_bits > max_rhs_bits {
// `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0`
// Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible.
return Ok(());
}
"attempt to subtract with overflow".to_string()
}
BinaryOp::Mul => {
if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size {
// Either performing boolean multiplication (which cannot overflow),
// or `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return Ok(());
}
"attempt to multiply with overflow".to_string()
}
_ => return Ok(()),
let Some(msg) = binary.check_unsigned_overflow_msg(dfg, bit_size) else {
return Ok(());
};

let with_pred = self.acir_context.mul_var(result, self.current_side_effects_enabled_var)?;
self.acir_context.range_constrain_var(
with_pred,
&NumericType::Unsigned { bit_size },
Some(msg),
Some(msg.to_string()),
)?;
Ok(())
}
Expand Down Expand Up @@ -2888,8 +2852,9 @@ mod test {
use acvm::{
acir::{
circuit::{
brillig::BrilligFunctionId, opcodes::AcirFunctionId, ExpressionWidth, Opcode,
OpcodeLocation,
brillig::BrilligFunctionId,
opcodes::{AcirFunctionId, BlackBoxFuncCall},
ExpressionWidth, Opcode, OpcodeLocation,
},
native_types::Witness,
},
Expand All @@ -2913,6 +2878,8 @@ mod test {
},
};

use super::Ssa;

fn build_basic_foo_with_return(
builder: &mut FunctionBuilder,
foo_id: FunctionId,
Expand Down Expand Up @@ -3659,4 +3626,36 @@ mod test {
"Should have {expected_num_normal_calls} BrilligCall opcodes to normal Brillig functions but got {num_normal_brillig_calls}"
);
}

#[test]
fn multiply_with_bool_should_not_emit_range_check() {
let src = "
acir(inline) fn main f0 {
b0(v0: bool, v1: u32):
enable_side_effects v0
v2 = cast v0 as u32
v3 = mul v2, v1
return v3
}
";
let ssa = Ssa::from_str(src).unwrap();
let brillig = ssa.to_brillig(false);

let (mut acir_functions, _brillig_functions, _, _) = ssa
.into_acir(&brillig, ExpressionWidth::default())
.expect("Should compile manually written SSA into ACIR");

assert_eq!(acir_functions.len(), 1);

let opcodes = acir_functions[0].take_opcodes();

for opcode in opcodes {
if let Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) = opcode {
assert!(
input.to_witness().0 <= 1,
"only input witnesses should have range checks: {opcode:?}"
);
}
}
}
}
130 changes: 56 additions & 74 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1478,88 +1478,70 @@ impl<'block> BrilligBlock<'block> {
is_signed: bool,
) {
let bit_size = left.bit_size;
let max_lhs_bits = dfg.get_value_max_num_bits(binary.lhs);
let max_rhs_bits = dfg.get_value_max_num_bits(binary.rhs);

if bit_size == FieldElement::max_num_bits() {
if bit_size == FieldElement::max_num_bits() || is_signed {
return;
}

match (binary_operation, is_signed) {
(BrilligBinaryOp::Add, false) => {
if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size {
// `left` and `right` have both been casted up from smaller types and so cannot overflow.
return;
}

let condition =
SingleAddrVariable::new(self.brillig_context.allocate_register(), 1);
// Check that lhs <= result
self.brillig_context.binary_instruction(
left,
result,
condition,
BrilligBinaryOp::LessThanEquals,
);
self.brillig_context
.codegen_constrain(condition, Some("attempt to add with overflow".to_string()));
self.brillig_context.deallocate_single_addr(condition);
}
(BrilligBinaryOp::Sub, false) => {
if dfg.is_constant(binary.lhs) && max_lhs_bits > max_rhs_bits {
// `left` is a fixed constant and `right` is restricted such that `left - right > 0`
// Note strict inequality as `right > left` while `max_lhs_bits == max_rhs_bits` is possible.
return;
}

let condition =
SingleAddrVariable::new(self.brillig_context.allocate_register(), 1);
// Check that rhs <= lhs
self.brillig_context.binary_instruction(
right,
left,
condition,
BrilligBinaryOp::LessThanEquals,
);
self.brillig_context.codegen_constrain(
condition,
Some("attempt to subtract with overflow".to_string()),
);
self.brillig_context.deallocate_single_addr(condition);
}
(BrilligBinaryOp::Mul, false) => {
if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size {
// Either performing boolean multiplication (which cannot overflow),
// or `left` and `right` have both been casted up from smaller types and so cannot overflow.
return;
if let Some(msg) = binary.check_unsigned_overflow_msg(dfg, bit_size) {
match binary_operation {
BrilligBinaryOp::Add => {
let condition =
SingleAddrVariable::new(self.brillig_context.allocate_register(), 1);
// Check that lhs <= result
self.brillig_context.binary_instruction(
left,
result,
condition,
BrilligBinaryOp::LessThanEquals,
);
self.brillig_context.codegen_constrain(condition, Some(msg.to_string()));
self.brillig_context.deallocate_single_addr(condition);
}

let is_right_zero =
SingleAddrVariable::new(self.brillig_context.allocate_register(), 1);
let zero = self.brillig_context.make_constant_instruction(0_usize.into(), bit_size);
self.brillig_context.binary_instruction(
zero,
right,
is_right_zero,
BrilligBinaryOp::Equals,
);
self.brillig_context.codegen_if_not(is_right_zero.address, |ctx| {
let condition = SingleAddrVariable::new(ctx.allocate_register(), 1);
let division = SingleAddrVariable::new(ctx.allocate_register(), bit_size);
// Check that result / rhs == lhs
ctx.binary_instruction(result, right, division, BrilligBinaryOp::UnsignedDiv);
ctx.binary_instruction(division, left, condition, BrilligBinaryOp::Equals);
ctx.codegen_constrain(
BrilligBinaryOp::Sub => {
let condition =
SingleAddrVariable::new(self.brillig_context.allocate_register(), 1);
// Check that rhs <= lhs
self.brillig_context.binary_instruction(
right,
left,
condition,
Some("attempt to multiply with overflow".to_string()),
BrilligBinaryOp::LessThanEquals,
);
ctx.deallocate_single_addr(condition);
ctx.deallocate_single_addr(division);
});
self.brillig_context.deallocate_single_addr(is_right_zero);
self.brillig_context.deallocate_single_addr(zero);
self.brillig_context.codegen_constrain(condition, Some(msg.to_string()));
self.brillig_context.deallocate_single_addr(condition);
}
BrilligBinaryOp::Mul => {
let is_right_zero =
SingleAddrVariable::new(self.brillig_context.allocate_register(), 1);
let zero =
self.brillig_context.make_constant_instruction(0_usize.into(), bit_size);
self.brillig_context.binary_instruction(
zero,
right,
is_right_zero,
BrilligBinaryOp::Equals,
);
self.brillig_context.codegen_if_not(is_right_zero.address, |ctx| {
let condition = SingleAddrVariable::new(ctx.allocate_register(), 1);
let division = SingleAddrVariable::new(ctx.allocate_register(), bit_size);
// Check that result / rhs == lhs
ctx.binary_instruction(
result,
right,
division,
BrilligBinaryOp::UnsignedDiv,
);
ctx.binary_instruction(division, left, condition, BrilligBinaryOp::Equals);
ctx.codegen_constrain(condition, Some(msg.to_string()));
ctx.deallocate_single_addr(condition);
ctx.deallocate_single_addr(division);
});
self.brillig_context.deallocate_single_addr(is_right_zero);
self.brillig_context.deallocate_single_addr(zero);
}
_ => {}
}
_ => {}
}
}

Expand Down
43 changes: 43 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,49 @@ impl Binary {
};
SimplifyResult::None
}

/// Check if unsigned overflow is possible, and if so return some message to be used if it fails.
pub(crate) fn check_unsigned_overflow_msg(
&self,
dfg: &DataFlowGraph,
bit_size: u32,
) -> Option<&'static str> {
// We try to optimize away operations that are guaranteed not to overflow
let max_lhs_bits = dfg.get_value_max_num_bits(self.lhs);
let max_rhs_bits = dfg.get_value_max_num_bits(self.rhs);

let msg = match self.operator {
BinaryOp::Add => {
if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size {
// `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return None;
}
"attempt to add with overflow"
}
BinaryOp::Sub => {
if dfg.is_constant(self.lhs) && max_lhs_bits > max_rhs_bits {
// `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0`
// Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible.
return None;
}
"attempt to subtract with overflow"
}
BinaryOp::Mul => {
if bit_size == 1
|| max_lhs_bits + max_rhs_bits <= bit_size
|| max_lhs_bits == 1
|| max_rhs_bits == 1
{
// Either performing boolean multiplication (which cannot overflow),
// or `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return None;
}
"attempt to multiply with overflow"
}
_ => return None,
};
Some(msg)
}
}

/// Evaluate a binary operation with constant arguments.
Expand Down
12 changes: 8 additions & 4 deletions test_programs/gates_report_brillig_execution.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ set -e

# These tests are incompatible with gas reporting
excluded_dirs=(
"workspace"
"workspace_default_member"
"double_verify_nested_proof"
"overlapping_dep_and_mod"
"workspace"
"workspace_default_member"
"double_verify_nested_proof"
"overlapping_dep_and_mod"
"comptime_println"
# bit sizes for bigint operation doesn't match up.
"bigint"
Expand All @@ -33,6 +33,10 @@ for dir in $test_dirs; do
continue
fi

if [[ ! -f "${base_path}/${dir}/Nargo.toml" ]]; then
continue
fi

echo " \"execution_success/$dir\"," >> Nargo.toml
done

Expand Down
9 changes: 8 additions & 1 deletion tooling/nargo_cli/src/cli/info_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,14 @@ fn profile_brillig_execution(
initial_witness,
&Bn254BlackBoxSolver(pedantic_solving),
&mut DefaultForeignCallBuilder::default().build(),
)?;
)
.map_err(|e| {
CliError::Generic(format!(
"failed to execute '{}': {}",
package.root_dir.to_string_lossy(),
e
))
})?;

let expression_width = get_target_width(package.expression_width, expression_width);

Expand Down

0 comments on commit c0a4010

Please sign in to comment.