Skip to content

Commit

Permalink
fix(ssa refactor): Speedup acir-gen (#1793)
Browse files Browse the repository at this point in the history
* Speedup acir-gen

* Remove outdated test
  • Loading branch information
jfecher authored Jun 22, 2023
1 parent 401888c commit 1e75f0e
Showing 1 changed file with 30 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{errors::AcirGenError, generated_acir::GeneratedAcir};
use crate::brillig::brillig_gen::brillig_directive;
use crate::ssa_refactor::acir_gen::AcirValue;
use crate::ssa_refactor::ir::types::Type as SsaType;
use crate::ssa_refactor::ir::{instruction::Endian, map::TwoWayMap, types::NumericType};
use crate::ssa_refactor::ir::{instruction::Endian, types::NumericType};
use acvm::acir::{
brillig_vm::Opcode as BrilligOpcode,
circuit::brillig::{BrilligInputs, BrilligOutputs},
Expand All @@ -16,6 +16,7 @@ use acvm::{
FieldElement,
};
use iter_extended::vecmap;
use std::collections::HashMap;
use std::{borrow::Cow, hash::Hash};

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -92,7 +93,7 @@ pub(crate) struct AcirContext {
/// Two-way map that links `AcirVar` to `AcirVarData`.
///
/// The vars object is an instance of the `TwoWayMap`, which provides a bidirectional mapping between `AcirVar` and `AcirVarData`.
vars: TwoWayMap<AcirVar, AcirVarData>,
vars: HashMap<AcirVar, AcirVarData>,

/// An in-memory representation of ACIR.
///
Expand Down Expand Up @@ -126,7 +127,7 @@ impl AcirContext {
///
/// Note: `Variables` are immutable.
pub(crate) fn neg_var(&mut self, var: AcirVar) -> AcirVar {
let var_data = &self.vars[var];
let var_data = &self.vars[&var];
let result_data = if let AcirVarData::Const(constant) = var_data {
AcirVarData::Const(-*constant)
} else {
Expand All @@ -138,7 +139,7 @@ impl AcirContext {
/// Adds a new Variable to context whose value will
/// be constrained to be the inverse of `var`.
pub(crate) fn inv_var(&mut self, var: AcirVar) -> Result<AcirVar, AcirGenError> {
let var_data = &self.vars[var];
let var_data = &self.vars[&var];
if let AcirVarData::Const(constant) = var_data {
// Note that this will return a 0 if the inverse is not available
let result_var = self.add_data(AcirVarData::Const(constant.inverse()));
Expand Down Expand Up @@ -179,8 +180,8 @@ impl AcirContext {
/// Returns an `AcirVar` that is `1` if `lhs` equals `rhs` and
/// 0 otherwise.
pub(crate) fn eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
Expand Down Expand Up @@ -245,8 +246,8 @@ impl AcirContext {
/// Constrains the `lhs` and `rhs` to be equal.
pub(crate) fn assert_eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<(), AcirGenError> {
// TODO: could use sub_var and then assert_eq_zero
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
if let (AcirVarData::Const(lhs_const), AcirVarData::Const(rhs_const)) = (lhs_data, rhs_data)
{
if lhs_const == rhs_const {
Expand Down Expand Up @@ -297,8 +298,8 @@ impl AcirContext {
/// Adds a new Variable to context whose value will
/// be constrained to be the multiplication of `lhs` and `rhs`
pub(crate) fn mul_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
let result = match (lhs_data, rhs_data) {
(AcirVarData::Witness(witness), AcirVarData::Expr(expr))
| (AcirVarData::Expr(expr), AcirVarData::Witness(witness)) => {
Expand Down Expand Up @@ -351,8 +352,8 @@ impl AcirContext {
/// Adds a new Variable to context whose value will
/// be constrained to be the addition of `lhs` and `rhs`
pub(crate) fn add_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
let result_data = if let (AcirVarData::Const(lhs_const), AcirVarData::Const(rhs_const)) =
(lhs_data, rhs_data)
{
Expand Down Expand Up @@ -385,7 +386,7 @@ impl AcirContext {
rhs: AcirVar,
_typ: AcirType,
) -> Result<AcirVar, AcirGenError> {
let rhs_data = &self.vars[rhs];
let rhs_data = &self.vars[&rhs];

// Compute 2^{rhs}
let two_pow_rhs = match rhs_data.as_constant() {
Expand All @@ -406,8 +407,8 @@ impl AcirContext {
) -> Result<(AcirVar, AcirVar), AcirGenError> {
let predicate = Expression::one();

let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
Expand Down Expand Up @@ -448,7 +449,7 @@ impl AcirContext {
rhs: AcirVar,
typ: AcirType,
) -> Result<AcirVar, AcirGenError> {
let rhs_data = &self.vars[rhs];
let rhs_data = &self.vars[&rhs];

// Compute 2^{rhs}
let two_pow_rhs = match rhs_data.as_constant() {
Expand Down Expand Up @@ -481,7 +482,7 @@ impl AcirContext {
variable: AcirVar,
numeric_type: &NumericType,
) -> Result<AcirVar, AcirGenError> {
let data = &self.vars[variable];
let data = &self.vars[&variable];
match numeric_type {
NumericType::Signed { .. } => todo!("signed integer constraining is unimplemented"),
NumericType::Unsigned { bit_size } => {
Expand All @@ -503,7 +504,7 @@ impl AcirContext {
rhs: u32,
max_bit_size: u32,
) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let lhs_data = &self.vars[&lhs];
let lhs_expr = lhs_data.to_expression();

let result_expr = self.acir_ir.truncate(&lhs_expr, rhs, max_bit_size)?;
Expand All @@ -520,8 +521,8 @@ impl AcirContext {
bit_size: u32,
predicate: Option<AcirVar>,
) -> Result<AcirVar, AcirGenError> {
let lhs_data = &self.vars[lhs];
let rhs_data = &self.vars[rhs];
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
Expand All @@ -530,7 +531,7 @@ impl AcirContext {
// TODO: The frontend should shout in this case

let predicate = predicate.map(|acir_var| {
let predicate_data = &self.vars[acir_var];
let predicate_data = &self.vars[&acir_var];
predicate_data.to_expression().into_owned()
});
let is_greater_than_eq =
Expand Down Expand Up @@ -570,7 +571,7 @@ impl AcirContext {
let domain_var =
inputs.pop().expect("ICE: Pedersen call requires domain separator").into_var();

let domain_constant = self.vars[domain_var]
let domain_constant = self.vars[&domain_var]
.as_constant()
.expect("ICE: Domain separator must be a constant");

Expand Down Expand Up @@ -603,7 +604,7 @@ impl AcirContext {
let mut witnesses = Vec::new();
for input in inputs {
for (input, typ) in input.flatten() {
let var_data = &self.vars[input];
let var_data = &self.vars[&input];

// Intrinsics only accept Witnesses. This is not a limitation of the
// intrinsics, its just how we have defined things. Ideally, we allow
Expand Down Expand Up @@ -635,12 +636,12 @@ impl AcirContext {
self.vars[&radix_var].as_constant().expect("ICE: radix should be a constant").to_u128()
as u32;

let limb_count = self.vars[limb_count_var]
let limb_count = self.vars[&limb_count_var]
.as_constant()
.expect("ICE: limb_size should be a constant")
.to_u128() as u32;

let input_expr = &self.vars[input_var].to_expression();
let input_expr = &self.vars[&input_var].to_expression();

let bit_size = u32::BITS - (radix - 1).leading_zeros();
let limbs = self.acir_ir.radix_le_decompose(input_expr, radix, limb_count, bit_size)?;
Expand Down Expand Up @@ -684,7 +685,7 @@ impl AcirContext {
let input = Self::flatten_values(input);

let witnesses = vecmap(input, |acir_var| {
let var_data = &self.vars[acir_var];
let var_data = &self.vars[&acir_var];
let expr = var_data.to_expression();
self.acir_ir.get_or_create_witness(&expr)
});
Expand Down Expand Up @@ -727,7 +728,8 @@ impl AcirContext {
/// either the key or the value.
fn add_data(&mut self, data: AcirVarData) -> AcirVar {
let id = AcirVar(self.vars.len());
self.vars.insert(id, data)
self.vars.insert(id, data);
id
}

pub(crate) fn brillig(
Expand Down Expand Up @@ -781,7 +783,7 @@ impl AcirContext {
fn brillig_array_input(&self, var_expressions: &mut Vec<Expression>, input: AcirValue) {
match input {
AcirValue::Var(var, _) => {
var_expressions.push(self.vars[var].to_expression().into_owned());
var_expressions.push(self.vars[&var].to_expression().into_owned());
}
AcirValue::Array(vars) => {
for var in vars {
Expand Down Expand Up @@ -841,18 +843,3 @@ impl AcirVarData {
/// A Reference to an `AcirVarData`
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) struct AcirVar(usize);

#[test]
fn repeat_op() {
let mut ctx = AcirContext::default();

let var_a = ctx.add_variable();
let var_b = ctx.add_variable();

// Multiplying the same variables twice should yield
// the same output.
let var_c = ctx.mul_var(var_a, var_b);
let should_be_var_c = ctx.mul_var(var_a, var_b);

assert_eq!(var_c, should_be_var_c);
}

0 comments on commit 1e75f0e

Please sign in to comment.