Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: remove 'single use' intermediate variables #6268

Merged
merged 9 commits into from
Oct 24, 2024
12 changes: 12 additions & 0 deletions acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashSet;

use crate::native_types::Witness;
use crate::{AcirField, BlackBoxFunc};

Expand Down Expand Up @@ -389,6 +391,16 @@ impl<F: Copy> BlackBoxFuncCall<F> {
BlackBoxFuncCall::BigIntToLeBytes { outputs, .. } => outputs.to_vec(),
}
}

pub fn get_input_witnesses(&self) -> HashSet<Witness> {
let mut result = HashSet::new();
for input in self.get_inputs_vec() {
if let ConstantOrWitnessEnum::Witness(w) = input.input() {
result.insert(w);
}
}
result
}
}

const ABBREVIATION_LIMIT: usize = 5;
Expand Down
202 changes: 202 additions & 0 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
use std::collections::{HashMap, HashSet};

use acir::{
circuit::{brillig::BrilligInputs, directives::Directive, opcodes::BlockId, Circuit, Opcode},
native_types::{Expression, Witness},
AcirField,
};

pub(crate) struct MergeExpressionsOptimizer {
resolved_blocks: HashMap<BlockId, HashSet<Witness>>,
}

impl MergeExpressionsOptimizer {
pub(crate) fn new() -> Self {
MergeExpressionsOptimizer { resolved_blocks: HashMap::new() }
}
/// This pass analyzes the circuit and identifies intermediate variables that are
/// only used in two gates. It then merges the gate that produces the
/// intermediate variable into the second one that uses it
/// Note: This pass is only relevant for backends that can handle unlimited width
pub(crate) fn eliminate_intermediate_variable<F: AcirField>(
&mut self,
circuit: &Circuit<F>,
acir_opcode_positions: Vec<usize>,
) -> (Vec<Opcode<F>>, Vec<usize>) {
// Keep track, for each witness, of the gates that use it
let circuit_inputs = circuit.circuit_arguments();
self.resolved_blocks = HashMap::new();
let mut used_witness: HashMap<Witness, HashSet<usize>> = HashMap::new();
for (i, opcode) in circuit.opcodes.iter().enumerate() {
let witnesses = self.witness_inputs(opcode);
if let Opcode::MemoryInit { block_id, .. } = opcode {
self.resolved_blocks.insert(*block_id, witnesses.clone());
}
for w in witnesses {
// We do not simplify circuit inputs
if !circuit_inputs.contains(&w) {
used_witness.entry(w).or_default().insert(i);
}
}
}

let mut modified_gates: HashMap<usize, Opcode<F>> = HashMap::new();
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();
// For each opcode, try to get a target opcode to merge with
for (i, opcode) in circuit.opcodes.iter().enumerate() {
if !matches!(opcode, Opcode::AssertZero(_)) {
new_circuit.push(opcode.clone());
new_acir_opcode_positions.push(acir_opcode_positions[i]);
continue;
}
let opcode = modified_gates.get(&i).unwrap_or(opcode).clone();
let mut to_keep = true;
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses.clone() {
let empty_gates = HashSet::new();
let gates_using_w = used_witness.get(&w).unwrap_or(&empty_gates);
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let gates_using_w: Vec<_> = gates_using_w.iter().collect();
let mut b = *gates_using_w[1];
if b == i {
b = *gates_using_w[0];
} else {
// sanity check
assert!(i == *gates_using_w[0]);
}
let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]).clone();
if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) =
(opcode.clone(), second_gate)
{
if let Some(expr) = Self::merge(&expr_use, &expr_define, w) {
// sanity check
assert!(i < b);
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in Self::expr_wit(&expr_define) {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(b);
v.remove(&i);
used_witness.insert(w2, v);
}
}
// We need to stop here and continue with the next opcode
// because the merge invalidate the current opcode
break;
}
}
}
}

if to_keep {
if modified_gates.contains_key(&i) {
new_circuit.push(modified_gates[&i].clone());
} else {
new_circuit.push(opcode.clone());
}
new_acir_opcode_positions.push(acir_opcode_positions[i]);
}
}
(new_circuit, new_acir_opcode_positions)
}

fn expr_wit<F>(expr: &Expression<F>) -> HashSet<Witness> {
let mut result = HashSet::new();
result.extend(expr.mul_terms.iter().flat_map(|i| vec![i.1, i.2]));
result.extend(expr.linear_combinations.iter().map(|i| i.1));
result
}

fn brillig_input_wit<F>(&self, input: &BrilligInputs<F>) -> HashSet<Witness> {
let mut result = HashSet::new();
match input {
BrilligInputs::Single(expr) => {
result.extend(Self::expr_wit(expr));
}
BrilligInputs::Array(exprs) => {
for expr in exprs {
result.extend(Self::expr_wit(expr));
}
}
BrilligInputs::MemoryArray(block_id) => {
let witnesses = self.resolved_blocks.get(block_id).expect("Unknown block id");
result.extend(witnesses);
}
}
result
}

// Returns the input witnesses used by the opcode
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> HashSet<Witness> {
let mut witnesses = HashSet::new();
match opcode {
Opcode::AssertZero(expr) => Self::expr_wit(expr),
Opcode::BlackBoxFuncCall(bb_func) => bb_func.get_input_witnesses(),
Opcode::Directive(Directive::ToLeRadix { a, .. }) => Self::expr_wit(a),
Opcode::MemoryOp { block_id: _, op, predicate } => {
//index et value, et predicate
let mut witnesses = HashSet::new();
witnesses.extend(Self::expr_wit(&op.index));
witnesses.extend(Self::expr_wit(&op.value));
if let Some(p) = predicate {
witnesses.extend(Self::expr_wit(p));
}
witnesses
}

Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
init.iter().cloned().collect()
}
Opcode::BrilligCall { inputs, .. } => {
for i in inputs {
witnesses.extend(self.brillig_input_wit(i));
}
witnesses
}
Opcode::Call { id: _, inputs, outputs: _, predicate } => {
for i in inputs {
witnesses.insert(*i);
}
if let Some(p) = predicate {
witnesses.extend(Self::expr_wit(p));
}
witnesses
}
}
}

// Merge 'expr' into 'target' via Gaussian elimination on 'w'
// Returns None if the expressions cannot be merged
fn merge<F: AcirField>(
target: &Expression<F>,
expr: &Expression<F>,
w: Witness,
) -> Option<Expression<F>> {
// Check that the witness is not part of multiplication terms
for m in &target.mul_terms {
if m.1 == w || m.2 == w {
return None;
}
}
for m in &expr.mul_terms {
if m.1 == w || m.2 == w {
return None;
}
}

for k in &target.linear_combinations {
if k.1 == w {
for i in &expr.linear_combinations {
if i.1 == w {
return Some(target.add_mul(-(k.0 / i.0), expr));
}
}
}
}
None
}
}
2 changes: 2 additions & 0 deletions acvm-repo/acvm/src/compiler/optimizers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ use acir::{

// mod constant_backpropagation;
mod general;
mod merge_expressions;
mod redundant_range;
mod unused_memory;

pub(crate) use general::GeneralOptimizer;
pub(crate) use merge_expressions::MergeExpressionsOptimizer;
pub(crate) use redundant_range::RangeOptimizer;
use tracing::info;

Expand Down
16 changes: 14 additions & 2 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ mod csat;
pub(crate) use csat::CSatTransformer;
pub use csat::MIN_EXPRESSION_WIDTH;

use super::{transform_assert_messages, AcirTransformationMap};
use super::{
optimizers::MergeExpressionsOptimizer, transform_assert_messages, AcirTransformationMap,
};

/// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`].
pub fn transform<F: AcirField>(
Expand Down Expand Up @@ -166,6 +168,16 @@ pub(super) fn transform_internal<F: AcirField>(
// The transformer does not add new public inputs
..acir
};

let mut merge_optimizer = MergeExpressionsOptimizer::new();
let (opcodes, new_acir_opcode_positions) =
merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions);
// n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
let acir = Circuit {
current_witness_index,
expression_width,
opcodes,
// The optimizer does not add new public inputs
..acir
};
(acir, new_acir_opcode_positions)
}
Loading