Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Avoid unnecessary range checks by inspecting instructions for casts #4039

Merged
merged 4 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,9 @@ impl FunctionBuilder {
if let Some(rhs_constant) = self.current_function.dfg.get_numeric_constant(rhs) {
// Happy case is that we know precisely by how many bits the the integer will
// increase: lhs_bit_size + rhs
let (rhs_bit_size_pow_2, overflows) =
2_u128.overflowing_pow(rhs_constant.to_u128() as u32);
let bit_shift_size = rhs_constant.to_u128() as u32;

let (rhs_bit_size_pow_2, overflows) = 2_u128.overflowing_pow(bit_shift_size);
if overflows {
assert!(bit_size < 128, "ICE - shift left with big integers are not supported");
if bit_size < 128 {
Expand All @@ -303,7 +304,18 @@ impl FunctionBuilder {
}
}
let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ);
(bit_size + (rhs_constant.to_u128() as u32), pow)

let mut max_lhs_bit = bit_size;

let dfg = &self.current_function.dfg;
if let Value::Instruction { instruction, .. } = dfg[lhs] {
if let Instruction::Cast(original_lhs, _) = dfg[instruction] {
let original_type = self.type_of_value(original_lhs);
max_lhs_bit = original_type.bit_size();
}
}

(max_lhs_bit + bit_shift_size, pow)
} else {
// we use a predicate to nullify the result in case of overflow
let bit_size_var =
Expand Down
107 changes: 86 additions & 21 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::ssa::ir::instruction::BinaryOp;
use crate::ssa::ir::instruction::Instruction;
use crate::ssa::ir::map::AtomicCounter;
use crate::ssa::ir::types::{NumericType, Type};
use crate::ssa::ir::value::ValueId;
use crate::ssa::ir::value::{Value as IrValue, ValueId};

use super::value::{Tree, Value, Values};
use fxhash::FxHashMap as HashMap;
Expand Down Expand Up @@ -335,29 +335,94 @@ impl<'a> FunctionContext<'a> {
}
}
Type::Numeric(NumericType::Unsigned { bit_size }) => {
let op_name = match operator {
BinaryOpKind::Add => "add",
BinaryOpKind::Subtract => "subtract",
BinaryOpKind::Multiply => "multiply",
BinaryOpKind::ShiftLeft => "left shift",
_ => unreachable!("operator {} should not overflow", operator),
let dfg = &self.builder.current_function.dfg;

let max_lhs_bits = match dfg[lhs] {
IrValue::Instruction { instruction, .. } => {
if let Instruction::Cast(original_lhs, _) = dfg[instruction] {
self.builder.type_of_value(original_lhs).bit_size()
} else {
self.builder.type_of_value(lhs).bit_size()
}
}

IrValue::NumericConstant { constant, .. } => constant.num_bits(),
_ => self.builder.type_of_value(lhs).bit_size(),
};
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved

if operator == BinaryOpKind::Multiply && bit_size == 1 {
result
} else if operator == BinaryOpKind::ShiftLeft
|| operator == BinaryOpKind::ShiftRight
{
self.check_shift_overflow(result, rhs, bit_size, location, false)
} else {
let message = format!("attempt to {} with overflow", op_name);
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
result
let max_rhs_bits = match dfg[rhs] {
IrValue::Instruction { instruction, .. } => {
if let Instruction::Cast(original_rhs, _) = dfg[instruction] {
self.builder.type_of_value(original_rhs).bit_size()
} else {
self.builder.type_of_value(rhs).bit_size()
}
}

IrValue::NumericConstant { constant, .. } => constant.num_bits(),
_ => self.builder.type_of_value(rhs).bit_size(),
};

match operator {
BinaryOpKind::Add => {
if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size {
// `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return result;
}

let message = "attempt to add with overflow".to_string();
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
}
BinaryOpKind::Subtract => {
if dfg.is_constant(lhs) && max_lhs_bits > max_rhs_bits {
// `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0`
// Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible.
return result;
}

let message = "attempt to subtract with overflow".to_string();
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
}
BinaryOpKind::Multiply => {
if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size {
// Either performing boolean multiplication (which cannot overflow),
// or `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return result;
}

let message = "attempt to multiply with overflow".to_string();
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
}
BinaryOpKind::ShiftLeft => {
if let Some(rhs_const) = dfg.get_numeric_constant(rhs) {
let bit_shift_size = rhs_const.to_u128() as u32;

if max_lhs_bits + bit_shift_size <= bit_size {
// `lhs` has been casted up from a smaller type such that shifting it by a constant
// `rhs` is known not to exceed the maximum bit size.
return result;
}
}

self.check_shift_overflow(result, rhs, bit_size, location, false);
}

_ => unreachable!("operator {} should not overflow", operator),
}

result
}
_ => result,
}
Expand Down
Loading