diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index 0a55e4ca17c..f49cd61e813 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -12,26 +12,36 @@ use acir::{ use crate::compiler::CircuitSimulator; -pub(crate) struct MergeExpressionsOptimizer { +pub(crate) struct MergeExpressionsOptimizer { resolved_blocks: HashMap>, + modified_gates: HashMap>, + deleted_gates: BTreeSet, } -impl MergeExpressionsOptimizer { +impl MergeExpressionsOptimizer { pub(crate) fn new() -> Self { - MergeExpressionsOptimizer { resolved_blocks: HashMap::new() } + MergeExpressionsOptimizer { + resolved_blocks: HashMap::new(), + modified_gates: HashMap::new(), + deleted_gates: BTreeSet::new(), + } } /// This pass analyzes the circuit and identifies intermediate variables that are /// only used in two gates. It then merges the gate that produces the /// intermediate variable into the second one that uses it /// Note: This pass is only relevant for backends that can handle unlimited width - pub(crate) fn eliminate_intermediate_variable( + pub(crate) fn eliminate_intermediate_variable( &mut self, circuit: &Circuit, acir_opcode_positions: Vec, ) -> (Vec>, Vec) { + // Initialization + self.modified_gates.clear(); + self.deleted_gates.clear(); + self.resolved_blocks.clear(); + // Keep track, for each witness, of the gates that use it let circuit_inputs = circuit.circuit_arguments(); - self.resolved_blocks = HashMap::new(); let mut used_witness: BTreeMap> = BTreeMap::new(); for (i, opcode) in circuit.opcodes.iter().enumerate() { let witnesses = self.witness_inputs(opcode); @@ -46,80 +56,89 @@ impl MergeExpressionsOptimizer { } } - let mut modified_gates: HashMap> = HashMap::new(); - let mut new_circuit = Vec::new(); - let mut new_acir_opcode_positions = Vec::new(); // For each opcode, try to get a target opcode to merge with - for (i, (opcode, opcode_position)) in - circuit.opcodes.iter().zip(acir_opcode_positions).enumerate() - { + for (i, opcode) in circuit.opcodes.iter().enumerate() { if !matches!(opcode, Opcode::AssertZero(_)) { - new_circuit.push(opcode.clone()); - new_acir_opcode_positions.push(opcode_position); continue; } - let opcode = modified_gates.get(&i).unwrap_or(opcode).clone(); - let mut to_keep = true; - let input_witnesses = self.witness_inputs(&opcode); - for w in input_witnesses { - let Some(gates_using_w) = used_witness.get(&w) else { - continue; - }; - // We only consider witness which are used in exactly two arithmetic gates - if gates_using_w.len() == 2 { - let first = *gates_using_w.first().expect("gates_using_w.len == 2"); - let second = *gates_using_w.last().expect("gates_using_w.len == 2"); - let b = if second == i { - first - } else { - // sanity check - assert!(i == first); - second + if let Some(opcode) = self.get_opcode(i, circuit) { + let input_witnesses = self.witness_inputs(&opcode); + for w in input_witnesses { + let Some(gates_using_w) = used_witness.get(&w) else { + continue; }; - - let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]); - if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) = - (&opcode, second_gate) - { - // We cannot merge an expression into an earlier opcode, because this - // would break the 'execution ordering' of the opcodes - // This case can happen because a previous merge would change an opcode - // and eliminate a witness from it, giving new opportunities for this - // witness to be used in only two expressions - // TODO: the missed optimization for the i>b case can be handled by - // - doing this pass again until there is no change, or - // - merging 'b' into 'i' instead - if i < b { - if let Some(expr) = Self::merge(expr_use, expr_define, w) { - modified_gates.insert(b, Opcode::AssertZero(expr)); - to_keep = false; - // Update the 'used_witness' map to account for the merge. - for w2 in CircuitSimulator::expr_wit(expr_define) { - if !circuit_inputs.contains(&w2) { - let v = used_witness.entry(w2).or_default(); - v.insert(b); - v.remove(&i); + // We only consider witness which are used in exactly two arithmetic gates + if gates_using_w.len() == 2 { + let first = *gates_using_w.first().expect("gates_using_w.len == 2"); + let second = *gates_using_w.last().expect("gates_using_w.len == 2"); + let b = if second == i { + first + } else { + // sanity check + assert!(i == first); + second + }; + // Merge the opcode with smaller index into the other one + // by updating modified_gates/deleted_gates/used_witness + // returns false if it could not merge them + let mut merge_opcodes = |op1, op2| -> bool { + if op1 == op2 { + return false; + } + let (source, target) = if op1 < op2 { (op1, op2) } else { (op2, op1) }; + let source_opcode = self.get_opcode(source, circuit); + let target_opcode = self.get_opcode(target, circuit); + if let ( + Some(Opcode::AssertZero(expr_use)), + Some(Opcode::AssertZero(expr_define)), + ) = (target_opcode, source_opcode) + { + if let Some(expr) = + Self::merge_expression(&expr_use, &expr_define, w) + { + self.modified_gates.insert(target, Opcode::AssertZero(expr)); + self.deleted_gates.insert(source); + // Update the 'used_witness' map to account for the merge. + let mut witness_list = CircuitSimulator::expr_wit(&expr_use); + witness_list.extend(CircuitSimulator::expr_wit(&expr_define)); + for w2 in witness_list { + if !circuit_inputs.contains(&w2) { + used_witness.entry(w2).and_modify(|v| { + v.insert(target); + v.remove(&source); + }); + } } + return true; } - // We need to stop here and continue with the next opcode - // because the merge invalidates the current opcode. - break; } + false + }; + + if merge_opcodes(b, i) { + // We need to stop here and continue with the next opcode + // because the merge invalidates the current opcode. + break; } } } } + } + + // Construct the new circuit from modified/deleted gates + let mut new_circuit = Vec::new(); + let mut new_acir_opcode_positions = Vec::new(); - if to_keep { - let opcode = modified_gates.get(&i).cloned().unwrap_or(opcode); - new_circuit.push(opcode); - new_acir_opcode_positions.push(opcode_position); + for (i, opcode_position) in acir_opcode_positions.iter().enumerate() { + if let Some(op) = self.get_opcode(i, circuit) { + new_circuit.push(op); + new_acir_opcode_positions.push(*opcode_position); } } (new_circuit, new_acir_opcode_positions) } - fn brillig_input_wit(&self, input: &BrilligInputs) -> BTreeSet { + fn brillig_input_wit(&self, input: &BrilligInputs) -> BTreeSet { let mut result = BTreeSet::new(); match input { BrilligInputs::Single(expr) => { @@ -152,7 +171,7 @@ impl MergeExpressionsOptimizer { } // Returns the input witnesses used by the opcode - fn witness_inputs(&self, opcode: &Opcode) -> BTreeSet { + fn witness_inputs(&self, opcode: &Opcode) -> BTreeSet { match opcode { Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr), Opcode::BlackBoxFuncCall(bb_func) => { @@ -198,7 +217,7 @@ impl MergeExpressionsOptimizer { // Merge 'expr' into 'target' via Gaussian elimination on 'w' // Returns None if the expressions cannot be merged - fn merge( + fn merge_expression( target: &Expression, expr: &Expression, w: Witness, @@ -226,6 +245,13 @@ impl MergeExpressionsOptimizer { } None } + + fn get_opcode(&self, g: usize, circuit: &Circuit) -> Option> { + if self.deleted_gates.contains(&g) { + return None; + } + self.modified_gates.get(&g).or(circuit.opcodes.get(g)).cloned() + } } #[cfg(test)] diff --git a/test_programs/execution_success/regression_6451/src/main.nr b/test_programs/execution_success/regression_6451/src/main.nr index fbee6956dfa..b13b6c25a7e 100644 --- a/test_programs/execution_success/regression_6451/src/main.nr +++ b/test_programs/execution_success/regression_6451/src/main.nr @@ -9,7 +9,7 @@ fn main(x: Field) { value += term2; value.assert_max_bit_size::<1>(); - // Regression test for Aztec Packages issue #6451 + // Regression test for #6447 (Aztec Packages issue #9488) let y = unsafe { empty(x + 1) }; let z = y + x + 1; let z1 = z + y;