Skip to content

Commit

Permalink
chore(ssa refactor): Add code to handle less than comparison (#1433)
Browse files Browse the repository at this point in the history
* add field mul and div

* add code to process field mul and div

* add assert example

* add `is_equal` constraint

* add `eq_var` method for AcirVar

* process `Constrain` instruction and BinaryOp::Eq

* add TODO for more than the maximum number of bits

* add numeric_cast_var method which constrains a variable to be equal to a NumericType

* implement casting for numeric types

* add simple range constraint example

* add constraints for `more_than_eq`

* - add more_than_eq method
- This method needs to know the bit size, so we cache this information whenever we do a range constraint.
We should ideally also cache it for constants too since we can figure out their bit-sizes easily

* add method to process less than binary operation

* add example

* assign result of cast operation

* add `y` as an input value

* return optimized circuit

* Addressed in Address GtEq extra opcodes #1444
  • Loading branch information
kevaundray authored May 30, 2023
1 parent 6faffad commit 92796aa
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "3"
y = "4"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Tests a very simple program.
//
// The features being tested is comparison
fn main(x : Field, y : Field) {
assert(x as u32 < y as u32);
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ pub(crate) struct AcirContext {
/// then the `acir_ir` will be populated to assert this
/// addition.
acir_ir: GeneratedAcir,

/// Maps an `AcirVar` to its known bit size.
variables_to_bit_sizes: HashMap<AcirVar, u32>,
}

impl AcirContext {
Expand Down Expand Up @@ -269,6 +272,8 @@ impl AcirContext {
let data_expr = data.to_expression();
let witness = self.acir_ir.get_or_create_witness(&data_expr);
self.acir_ir.range_constraint(witness, *bit_size)?;
// Log the bit size for this variable
self.variables_to_bit_sizes.insert(variable, *bit_size);
}
NumericType::NativeField => {
// If someone has made a cast to a `Field` type then this is a Noop.
Expand All @@ -280,6 +285,52 @@ impl AcirContext {
Ok(variable)
}

/// Returns an `AcirVar` which will be `1` if lhs >= rhs
/// and `0` otherwise.
fn more_than_eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.data[&lhs];
let rhs_data = &self.data[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();

let lhs_bit_size = self.variables_to_bit_sizes.get(&lhs).expect("comparisons cannot be made on variables with no known max bit size. This should have been caught by the frontend");
let rhs_bit_size = self.variables_to_bit_sizes.get(&rhs).expect("comparisons cannot be made on variables with no known max bit size. This should have been caught by the frontend");

// This is a conservative choice. Technically, we should just be able to take
// the bit size of the `lhs` (upper bound), but we need to check/document what happens
// if the bit_size is not enough to represent both witnesses.
// An example is the following: (a as u8) >= (b as u32)
// If the equality is true, then it means that `b` also fits inside
// of a u8.
// But its not clear what happens if the equality is false,
// and we 8 bits to `more_than_eq_comparison`. The conservative
// choice chosen is to use 32.
let bit_size = *std::cmp::max(lhs_bit_size, rhs_bit_size);

let is_greater_than_eq =
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size)?;

Ok(self.add_data(AcirVarData::Witness(is_greater_than_eq)))
}

/// Returns an `AcirVar` which will be `1` if lhs < rhs
/// and `0` otherwise.
pub(crate) fn less_than_var(
&mut self,
lhs: AcirVar,
rhs: AcirVar,
) -> Result<AcirVar, AcirGenError> {
// Flip the result of calling more than equal method to
// compute less than.
let comparison = self.more_than_eq_var(lhs, rhs)?;

let one = self.add_constant(FieldElement::one());
let comparison_negated = self.sub_var(one, comparison);

Ok(comparison_negated)
}

/// Terminates the context and takes the resulting `GeneratedAcir`
pub(crate) fn finish(self) -> GeneratedAcir {
self.acir_ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
//! program as it is being converted from SSA form.
use super::errors::AcirGenError;
use acvm::acir::{
circuit::opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode},
circuit::{
directives::QuotientDirective,
opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode},
},
native_types::Witness,
};
use acvm::{
Expand Down Expand Up @@ -240,4 +243,59 @@ impl GeneratedAcir {

Ok(())
}

/// Returns a `Witness` that is constrained to be:
/// - `1` if lhs >= rhs
/// - `0` otherwise
///
/// See [R1CS Workshop - Section 10](https://github.com/mir-protocol/r1cs-workshop/blob/master/workshop.pdf)
/// for an explanation.
pub(crate) fn more_than_eq_comparison(
&mut self,
a: &Expression,
b: &Expression,
max_bits: u32,
) -> Result<Witness, AcirGenError> {
// Ensure that 2^{max_bits + 1} is less than the field size
//
// TODO: perhaps this should be a user error, instead of an assert
assert!(max_bits + 1 < FieldElement::max_num_bits());

// Compute : 2^max_bits + a - b
let mut comparison_evaluation = a - b;
let two = FieldElement::from(2_i128);
let two_max_bits = two.pow(&FieldElement::from(max_bits as i128));
comparison_evaluation.q_c += two_max_bits;

let q_witness = self.next_witness_index();
let r_witness = self.next_witness_index();

// Add constraint : 2^{max_bits} + a - b = q * 2^{max_bits} + r
let mut expr = Expression::default();
expr.push_addition_term(two_max_bits, q_witness);
expr.push_addition_term(FieldElement::one(), r_witness);
self.push_opcode(AcirOpcode::Arithmetic(&comparison_evaluation - &expr));

self.push_opcode(AcirOpcode::Directive(Directive::Quotient(QuotientDirective {
a: comparison_evaluation,
b: Expression::from_field(two_max_bits),
q: q_witness,
r: r_witness,
predicate: None,
})));

// Add constraint to ensure `r` is correctly bounded
// between [0, 2^{max_bits}-1]
self.range_constraint(r_witness, max_bits)?;
// Add constraint to ensure that `q` is a boolean value
// in particular it should be the `n` bit of the comparison_evaluation
// which will indicate whether a >= b
//
// In the document linked above, they mention negating the value of `q`
// which would tell us whether a < b. Since we do not negate `q`
// what we get is a boolean indicating whether a >= b.
self.range_constraint(q_witness, 1)?;

Ok(q_witness)
}
}
9 changes: 6 additions & 3 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ impl Context {
Value::Intrinsic(..) => todo!(),
Value::Function(..) => unreachable!("ICE: All functions should have been inlined"),
Value::Instruction { .. } | Value::Param { .. } => {
unreachable!("ICE: Should have been in cache")
unreachable!("ICE: Should have been in cache {value:?}")
}
};
self.ssa_value_to_acir_var.insert(value_id, acir_var);
Expand All @@ -181,11 +181,14 @@ impl Context {
// Note: that this produces unnecessary constraints when
// this Eq instruction is being used for a constrain statement
BinaryOp::Eq => self.acir_context.eq_var(lhs, rhs),
BinaryOp::Lt => self
.acir_context
.less_than_var(lhs, rhs)
.expect("add Result types to all methods so errors bubble up"),
_ => todo!(),
}
}
/// Returns an `AcirVar` that is constrained to be `Type`.
/// Currently, we only allow casting to a NumericType.
/// Returns an `AcirVar` that is constrained to be
fn convert_ssa_cast(&mut self, value_id: &ValueId, typ: &Type, dfg: &DataFlowGraph) -> AcirVar {
let variable = self.convert_ssa_value(*value_id, dfg);

Expand Down

0 comments on commit 92796aa

Please sign in to comment.