diff --git a/crates/noirc_evaluator/src/ssa/acir_gen/acir_mem.rs b/crates/noirc_evaluator/src/ssa/acir_gen/acir_mem.rs index 12c85fe1d44..0c55f61ca20 100644 --- a/crates/noirc_evaluator/src/ssa/acir_gen/acir_mem.rs +++ b/crates/noirc_evaluator/src/ssa/acir_gen/acir_mem.rs @@ -15,7 +15,7 @@ use acvm::{ }; use iter_extended::vecmap; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashSet}; use super::{ constraints::{self, mul_with_witness, subtract}, @@ -25,7 +25,7 @@ use super::{ /// Represent a memory operation on the ArrayHeap, at the specified index /// Operation is one for a store and 0 for a load #[derive(Clone, Debug)] -struct MemOp { +pub(crate) struct MemOp { operation: Expression, value: Expression, index: Expression, @@ -33,6 +33,26 @@ struct MemOp { type MemAddress = u32; +enum ArrayType { + /// Initialization phase: initializing the array with writes on the 0..array.len range + /// It contains the HashSet of the initialized indexes and the maximum of these indexes + Init(HashSet, MemAddress), + /// Array is only written on, never read + WriteOnly, + /// Initialization phase and then only read, and optionally a bunch of writes at the end + /// The optional usize indicates the position of the ending writes if any: after this position, there are only writes + ReadOnly(Option), + /// Reads and writes outside the initialization phase + /// The optional usize indicates the position of the ending writes if any: after this position, there are only writes + ReadWrite(Option), +} + +impl Default for ArrayType { + fn default() -> Self { + ArrayType::Init(HashSet::default(), 0) + } +} + #[derive(Default)] struct ArrayHeap { // maps memory address to InternalVar @@ -40,6 +60,7 @@ struct ArrayHeap { trace: Vec, // maps memory address to (values,operation) that must be committed to the trace staged: BTreeMap, + typ: ArrayType, } impl ArrayHeap { @@ -55,12 +76,50 @@ impl ArrayHeap { self.staged.clear(); } - fn push(&mut self, index: Expression, value: Expression, op: Expression) { - let item = MemOp { operation: op, value, index }; + fn push(&mut self, item: MemOp) { + let is_load = item.operation == Expression::zero(); + let index_const = item.index.to_const(); + self.typ = match &self.typ { + ArrayType::Init(init_idx, len) => match (is_load, index_const) { + (false, Some(idx)) => { + let idx: MemAddress = idx.to_u128().try_into().unwrap(); + let mut init_idx2 = init_idx.clone(); + init_idx2.insert(idx); + let len2 = std::cmp::max(idx + 1, *len); + ArrayType::Init(init_idx2, len2) + } + (false, None) => ArrayType::WriteOnly, + (true, _) => { + if *len as usize == init_idx.len() { + ArrayType::ReadOnly(None) + } else { + ArrayType::ReadWrite(None) + } + } + }, + ArrayType::WriteOnly => { + if is_load { + ArrayType::ReadWrite(None) + } else { + ArrayType::WriteOnly + } + } + ArrayType::ReadOnly(last) => match (is_load, last) { + (true, Some(_)) => ArrayType::ReadWrite(None), + (true, None) => ArrayType::ReadOnly(None), + (false, None) => ArrayType::ReadOnly(Some(self.trace.len())), + (false, Some(_)) => ArrayType::ReadOnly(*last), + }, + ArrayType::ReadWrite(last) => match (is_load, last) { + (true, _) => ArrayType::ReadWrite(None), + (false, None) => ArrayType::ReadWrite(Some(self.trace.len())), + (false, Some(_)) => ArrayType::ReadWrite(*last), + }, + }; self.trace.push(item); } - fn stage(&mut self, index: u32, value: Expression, op: Expression) { + fn stage(&mut self, index: MemAddress, value: Expression, op: Expression) { self.staged.insert(index, (value, op)); } @@ -77,8 +136,14 @@ impl ArrayHeap { } outputs } + pub(crate) fn acir_gen(&self, evaluator: &mut Evaluator) { - let len = self.trace.len(); + let (len, read_write) = match self.typ { + ArrayType::Init(_, _) | ArrayType::WriteOnly => (0, true), + ArrayType::ReadOnly(last) => (last.unwrap_or(self.trace.len()), false), + ArrayType::ReadWrite(last) => (last.unwrap_or(self.trace.len()), true), + }; + if len == 0 { return; } @@ -90,20 +155,25 @@ impl ArrayHeap { let mut in_op = Vec::new(); let mut tuple_expressions = Vec::new(); - for (counter, item) in self.trace.iter().enumerate() { + for (counter, item) in self.trace.iter().take(len).enumerate() { let counter_expr = Expression::from_field(FieldElement::from(counter as i128)); in_counter.push(counter_expr.clone()); in_index.push(item.index.clone()); in_value.push(item.value.clone()); - in_op.push(item.operation.clone()); + if read_write { + in_op.push(item.operation.clone()); + } tuple_expressions.push(vec![item.index.clone(), counter_expr.clone()]); } let mut bit_counter = Vec::new(); let out_counter = Self::generate_outputs(in_counter, &mut bit_counter, evaluator); let out_index = Self::generate_outputs(in_index, &mut bit_counter, evaluator); let out_value = Self::generate_outputs(in_value, &mut bit_counter, evaluator); - let out_op = Self::generate_outputs(in_op, &mut bit_counter, evaluator); - + let out_op = if read_write { + Self::generate_outputs(in_op, &mut bit_counter, evaluator) + } else { + Vec::new() + }; // sort directive evaluator.opcodes.push(AcirOpcode::Directive(Directive::PermutationSort { inputs: tuple_expressions, @@ -111,8 +181,10 @@ impl ArrayHeap { bits: bit_counter, sort_by: vec![0, 1], })); - let init = subtract(&out_op[0], FieldElement::one(), &Expression::one()); - evaluator.opcodes.push(AcirOpcode::Arithmetic(init)); + if read_write { + let init = subtract(&out_op[0], FieldElement::one(), &Expression::one()); + evaluator.opcodes.push(AcirOpcode::Arithmetic(init)); + } for i in 0..len - 1 { // index sort let index_sub = subtract(&out_index[i + 1], FieldElement::one(), &out_index[i]); @@ -134,11 +206,19 @@ impl ArrayHeap { ); evaluator.opcodes.push(AcirOpcode::Arithmetic(secondary_order)); // consistency checks - let sub1 = subtract(&Expression::one(), FieldElement::one(), &out_op[i + 1]); let sub2 = subtract(&out_value[i + 1], FieldElement::one(), &out_value[i]); - let load_on_same_adr = mul_with_witness(evaluator, &sub1, &sub2); - let store_on_new_adr = mul_with_witness(evaluator, &index_sub, &sub1); - evaluator.opcodes.push(AcirOpcode::Arithmetic(store_on_new_adr)); + let load_on_same_adr = if read_write { + let sub1 = subtract(&Expression::one(), FieldElement::one(), &out_op[i + 1]); + let store_on_new_adr = mul_with_witness(evaluator, &index_sub, &sub1); + evaluator.opcodes.push(AcirOpcode::Arithmetic(store_on_new_adr)); + mul_with_witness(evaluator, &sub1, &sub2) + } else { + subtract( + &mul_with_witness(evaluator, &index_sub, &sub2), + FieldElement::one(), + &sub2, + ) + }; evaluator.opcodes.push(AcirOpcode::Arithmetic(load_on_same_adr)); } } @@ -152,7 +232,7 @@ pub(crate) struct AcirMem { impl AcirMem { // Returns the memory_map for the array - fn array_map_mut(&mut self, array_id: ArrayId) -> &mut BTreeMap { + fn array_map_mut(&mut self, array_id: ArrayId) -> &mut BTreeMap { &mut self.virtual_memory.entry(array_id).or_default().memory_map } @@ -235,7 +315,8 @@ impl AcirMem { op: Expression, ) { self.commit(array_id, op != Expression::zero()); - self.array_heap_mut(*array_id).push(index, value, op); + let item = MemOp { operation: op, value, index }; + self.array_heap_mut(*array_id).push(item); } pub(crate) fn acir_gen(&self, evaluator: &mut Evaluator) { for mem in &self.virtual_memory {