Skip to content

Commit

Permalink
chore: Use helper functions for getting values of AcirVars (#2194)
Browse files Browse the repository at this point in the history
* chore: Use helper functions for getting values of `AcirVar`s

* chore: reorganise code
  • Loading branch information
TomAFrench authored Aug 7, 2023
1 parent 620517f commit aad82c3
Showing 1 changed file with 37 additions and 70 deletions.
107 changes: 37 additions & 70 deletions crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,29 @@ impl AcirContext {
self.add_data(var_data)
}

pub(crate) fn get_location(&mut self) -> Option<Location> {
pub(crate) fn get_location(&self) -> Option<Location> {
self.acir_ir.current_location
}

pub(crate) fn set_location(&mut self, location: Option<Location>) {
self.acir_ir.current_location = location;
}

/// Converts an [`AcirVar`] to a [`Witness`]
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
let expression = self.var_to_expression(var)?;
Ok(self.acir_ir.get_or_create_witness(&expression))
}

/// Converts an [`AcirVar`] to an [`Expression`]
fn var_to_expression(&self, var: AcirVar) -> Result<Expression, InternalError> {
let var_data = match self.vars.get(&var) {
Some(var_data) => var_data,
None => return Err(InternalError::UndeclaredAcirVar { location: self.get_location() }),
};
Ok(var_data.to_expression().into_owned())
}

/// True if the given AcirVar refers to a constant one value
pub(crate) fn is_constant_one(&self, var: &AcirVar) -> bool {
match self.vars[var] {
Expand Down Expand Up @@ -246,11 +261,8 @@ impl AcirContext {
/// Returns an `AcirVar` that is `1` if `lhs` equals `rhs` and
/// 0 otherwise.
pub(crate) fn eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, RuntimeError> {
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
let lhs_expr = self.var_to_expression(lhs)?;
let rhs_expr = self.var_to_expression(rhs)?;

let is_equal_witness = self.acir_ir.is_equal(&lhs_expr, &rhs_expr);
let result_var = self.add_data(AcirVarData::Witness(is_equal_witness));
Expand Down Expand Up @@ -479,13 +491,9 @@ impl AcirContext {
bit_size: u32,
predicate: AcirVar,
) -> Result<(AcirVar, AcirVar), RuntimeError> {
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
let predicate_data = &self.vars[&predicate];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
let predicate_expr = predicate_data.to_expression();
let lhs_expr = self.var_to_expression(lhs)?;
let rhs_expr = self.var_to_expression(rhs)?;
let predicate_expr = self.var_to_expression(predicate)?;

let (quotient, remainder) =
self.acir_ir.euclidean_division(&lhs_expr, &rhs_expr, bit_size, &predicate_expr)?;
Expand All @@ -500,24 +508,15 @@ impl AcirContext {
/// and |remainder| < |rhs|
/// and remainder has the same sign than lhs
/// Note that this is not the euclidian division, where we have instead remainder < |rhs|
///
///
///
///
fn signed_division_var(
&mut self,
lhs: AcirVar,
rhs: AcirVar,
bit_size: u32,
) -> Result<(AcirVar, AcirVar), RuntimeError> {
let lhs_data = &self.vars[&lhs].clone();
let rhs_data = &self.vars[&rhs].clone();
let l_witness = self.var_to_witness(lhs)?;
let r_witness = self.var_to_witness(rhs)?;

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
let l_witness = self.acir_ir.get_or_create_witness(&lhs_expr);
let r_witness = self.acir_ir.get_or_create_witness(&rhs_expr);
assert_ne!(bit_size, 0, "signed integer should have at least one bit");
let (q, r) =
self.acir_ir.signed_division(&l_witness.into(), &r_witness.into(), bit_size)?;
Expand Down Expand Up @@ -571,18 +570,7 @@ impl AcirContext {
/// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the
/// `GeneratedAcir`'s return witnesses.
pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> {
let acir_var_data = match self.vars.get(&acir_var) {
Some(acir_var_data) => acir_var_data,
None => return Err(InternalError::UndeclaredAcirVar { location: self.get_location() }),
};
// TODO: Add caching to prevent expressions from being needlessly duplicated
let witness = match acir_var_data {
AcirVarData::Const(constant) => {
self.acir_ir.get_or_create_witness(&Expression::from(*constant))
}
AcirVarData::Expr(expr) => self.acir_ir.get_or_create_witness(expr),
AcirVarData::Witness(witness) => *witness,
};
let witness = self.var_to_witness(acir_var)?;
self.acir_ir.push_return_witness(witness);
Ok(())
}
Expand All @@ -593,11 +581,9 @@ impl AcirContext {
variable: AcirVar,
numeric_type: &NumericType,
) -> Result<AcirVar, RuntimeError> {
let data = &self.vars[&variable];
match numeric_type {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let data_expr = data.to_expression();
let witness = self.acir_ir.get_or_create_witness(&data_expr);
let witness = self.var_to_witness(variable)?;
self.acir_ir.range_constraint(witness, *bit_size)?;
}
NumericType::NativeField => {
Expand All @@ -616,8 +602,7 @@ impl AcirContext {
rhs: u32,
max_bit_size: u32,
) -> Result<AcirVar, RuntimeError> {
let lhs_data = &self.vars[&lhs];
let lhs_expr = lhs_data.to_expression();
let lhs_expr = self.var_to_expression(lhs)?;

// 2^{rhs}
let divisor = FieldElement::from(2_i128).pow(&FieldElement::from(rhs as i128));
Expand All @@ -641,17 +626,12 @@ impl AcirContext {
bit_size: u32,
predicate: AcirVar,
) -> Result<AcirVar, RuntimeError> {
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();

let predicate_data = &self.vars[&predicate];
let predicate = predicate_data.to_expression().into_owned();
let lhs_expr = self.var_to_expression(lhs)?;
let rhs_expr = self.var_to_expression(rhs)?;
let predicate_expr = self.var_to_expression(predicate)?;

let is_greater_than_eq =
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size, predicate)?;
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size, predicate_expr)?;

Ok(self.add_data(AcirVarData::Witness(is_greater_than_eq)))
}
Expand Down Expand Up @@ -736,13 +716,10 @@ impl AcirContext {
for input in inputs {
let mut single_val_witnesses = Vec::new();
for (input, typ) in input.flatten() {
let var_data = &self.vars[&input];

// Intrinsics only accept Witnesses. This is not a limitation of the
// intrinsics, its just how we have defined things. Ideally, we allow
// constants too.
let expr = var_data.to_expression();
let witness = self.acir_ir.get_or_create_witness(&expr);
let witness = self.var_to_witness(input)?;
let num_bits = typ.bit_size();
single_val_witnesses.push(FunctionInput { witness, num_bits });
}
Expand Down Expand Up @@ -785,10 +762,10 @@ impl AcirContext {
}
};

let input_expr = &self.vars[&input_var].to_expression();
let input_expr = self.var_to_expression(input_var)?;

let bit_size = u32::BITS - (radix - 1).leading_zeros();
let limbs = self.acir_ir.radix_le_decompose(input_expr, radix, limb_count, bit_size)?;
let limbs = self.acir_ir.radix_le_decompose(&input_expr, radix, limb_count, bit_size)?;

let mut limb_vars = vecmap(limbs, |witness| {
let witness = self.add_data(AcirVarData::Witness(witness));
Expand Down Expand Up @@ -873,9 +850,7 @@ impl AcirContext {
outputs: Vec<AcirType>,
) -> Result<Vec<AcirValue>, InternalError> {
let b_inputs = try_vecmap(inputs, |i| match i {
AcirValue::Var(var, _) => {
Ok(BrilligInputs::Single(self.vars[&var].to_expression().into_owned()))
}
AcirValue::Var(var, _) => Ok(BrilligInputs::Single(self.var_to_expression(var)?)),
AcirValue::Array(vars) => {
let mut var_expressions: Vec<Expression> = Vec::new();
for var in vars {
Expand Down Expand Up @@ -904,7 +879,7 @@ impl AcirContext {
acir_value
}
});
let predicate = self.vars[&predicate].to_expression().into_owned();
let predicate = self.var_to_expression(predicate)?;
self.acir_ir.brillig(Some(predicate), code, b_inputs, b_outputs);

Ok(outputs_var)
Expand All @@ -917,7 +892,7 @@ impl AcirContext {
) -> Result<(), InternalError> {
match input {
AcirValue::Var(var, _) => {
var_expressions.push(self.vars[&var].to_expression().into_owned());
var_expressions.push(self.var_to_expression(var)?);
}
AcirValue::Array(vars) => {
for var in vars {
Expand Down Expand Up @@ -988,7 +963,7 @@ impl AcirContext {
) -> Result<Vec<AcirVar>, RuntimeError> {
let len = inputs.len();
// Convert the inputs into expressions
let inputs_expr = vecmap(inputs, |input| self.vars[&input].to_expression().into_owned());
let inputs_expr = try_vecmap(inputs, |input| self.var_to_expression(input))?;
// Generate output witnesses
let outputs_witness = vecmap(0..len, |_| self.acir_ir.next_witness_index());
let output_expr =
Expand All @@ -1007,14 +982,6 @@ impl AcirContext {

Ok(outputs_var)
}
/// Converts an AcirVar to a Witness
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
let var_data = match self.vars.get(&var) {
Some(var_data) => var_data,
None => return Err(InternalError::UndeclaredAcirVar { location: self.get_location() }),
};
Ok(self.acir_ir.get_or_create_witness(&var_data.to_expression()))
}

/// Constrain lhs to be less than rhs
fn less_than_constrain(
Expand Down

0 comments on commit aad82c3

Please sign in to comment.