diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/Nargo.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/Nargo.toml new file mode 100644 index 00000000000..670888e37cd --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/Nargo.toml @@ -0,0 +1,5 @@ +[package] +authors = [""] +compiler_version = "0.6.0" + +[dependencies] \ No newline at end of file diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/Prover.toml b/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/Prover.toml new file mode 100644 index 00000000000..e0d79da4da6 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/Prover.toml @@ -0,0 +1 @@ +xs = [2, 1, 3] diff --git a/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/src/main.nr b/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/src/main.nr new file mode 100644 index 00000000000..17df7b23551 --- /dev/null +++ b/crates/nargo_cli/tests/test_data_ssa_refactor/array_sort/src/main.nr @@ -0,0 +1,6 @@ +fn main(xs : [u8; 3]) { + let sorted = xs.sort(); + assert(sorted[0] == 1); + assert(sorted[1] == 2); + assert(sorted[2] == 3); +} diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir.rs index ae183aa962f..6e715002161 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir.rs @@ -1,3 +1,4 @@ pub(crate) mod acir_variable; pub(crate) mod errors; pub(crate) mod generated_acir; +pub(crate) mod sort; diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs index df8afc29d99..b4aa7258726 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/acir_variable.rs @@ -7,6 +7,7 @@ use acvm::acir::{ brillig_vm::Opcode as BrilligOpcode, circuit::brillig::{BrilligInputs, BrilligOutputs}, }; + use acvm::{ acir::{ circuit::opcodes::FunctionInput, @@ -800,6 +801,46 @@ impl AcirContext { } } } + + /// Generate output variables that are constrained to be the sorted inputs + /// The outputs are the sorted inputs iff + /// outputs are sorted and + /// outputs are a permutation of the inputs + pub(crate) fn sort( + &mut self, + inputs: Vec, + bit_size: u32, + ) -> Result, AcirGenError> { + let len = inputs.len(); + // Convert the inputs into expressions + let inputs_expr = vecmap(inputs, |input| self.vars[&input].to_expression().into_owned()); + // Generate output witnesses + let outputs_witness = vecmap(0..len, |_| self.acir_ir.next_witness_index()); + let output_expr = + vecmap(&outputs_witness, |witness_index| Expression::from(*witness_index)); + let outputs_var = vecmap(&outputs_witness, |witness_index| { + self.add_data(AcirVarData::Witness(*witness_index)) + }); + // Enforce the outputs to be sorted + for i in 0..(outputs_var.len() - 1) { + self.less_than_constrain(outputs_var[i], outputs_var[i + 1], bit_size)?; + } + // Enforce the outputs to be a permutation of the inputs + self.acir_ir.permutation(&inputs_expr, &output_expr); + + Ok(outputs_var) + } + + /// Constrain lhs to be less than rhs + fn less_than_constrain( + &mut self, + lhs: AcirVar, + rhs: AcirVar, + bit_size: u32, + ) -> Result<(), AcirGenError> { + let lhs_less_than_rhs = self.more_than_eq_var(rhs, lhs, bit_size, None)?; + self.assert_eq_one(lhs_less_than_rhs) + } } /// Enum representing the possible values that a diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs index 67b2ce984d9..11b1b6a6d92 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/generated_acir.rs @@ -684,6 +684,29 @@ impl GeneratedAcir { }); self.push_opcode(opcode); } + + /// Generate gates and control bits witnesses which ensure that out_expr is a permutation of in_expr + /// Add the control bits of the sorting network used to generate the constrains + /// into the PermutationSort directive for solving in ACVM. + /// The directive is solving the control bits so that the outputs are sorted in increasing order. + /// + /// n.b. A sorting network is a predetermined set of switches, + /// the control bits indicate the configuration of each switch: false for pass-through and true for cross-over + pub(crate) fn permutation(&mut self, in_expr: &[Expression], out_expr: &[Expression]) { + let bits = Vec::new(); + let (w, b) = self.permutation_layer(in_expr, &bits, true); + // Constrain the network output to out_expr + for (b, o) in b.iter().zip(out_expr) { + self.push_opcode(AcirOpcode::Arithmetic(b - o)); + } + let inputs = in_expr.iter().map(|a| vec![a.clone()]).collect(); + self.push_opcode(AcirOpcode::Directive(Directive::PermutationSort { + inputs, + tuple: 1, + bits: w, + sort_by: vec![0], + })); + } } /// This function will return the number of inputs that a blackbox function diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/sort.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/sort.rs new file mode 100644 index 00000000000..02898bacde4 --- /dev/null +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/sort.rs @@ -0,0 +1,101 @@ +use acvm::acir::native_types::{Expression, Witness}; + +use super::generated_acir::GeneratedAcir; + +impl GeneratedAcir { + // Generates gates for a sorting network + // returns witness corresponding to the network configuration and the expressions corresponding to the network output + // in_expr: inputs of the sorting network + // if generate_witness is false, it uses the witness provided in bits instead of generating them + // in both cases it returns the witness of the network configuration + // if generate_witness is true, bits is ignored + pub(crate) fn permutation_layer( + &mut self, + in_expr: &[Expression], + bits: &[Witness], + generate_witness: bool, + ) -> (Vec, Vec) { + let n = in_expr.len(); + if n == 1 { + return (Vec::new(), in_expr.to_vec()); + } + let n1 = n / 2; + + // witness for the input switches + let mut conf = iter_extended::vecmap(0..n1, |i| { + if generate_witness { + self.next_witness_index() + } else { + bits[i] + } + }); + + // compute expressions after the input switches + // If inputs are a1,a2, and the switch value is c, then we compute expressions b1,b2 where + // b1 = a1+q, b2 = a2-q, q = c(a2-a1) + let mut in_sub1 = Vec::new(); + let mut in_sub2 = Vec::new(); + for i in 0..n1 { + //q = c*(a2-a1); + let intermediate = self.mul_with_witness( + &Expression::from(conf[i]), + &(&in_expr[2 * i + 1] - &in_expr[2 * i]), + ); + //b1=a1+q + in_sub1.push(&intermediate + &in_expr[2 * i]); + //b2=a2-q + in_sub2.push(&in_expr[2 * i + 1] - &intermediate); + } + if n % 2 == 1 { + in_sub2.push(in_expr.last().unwrap().clone()); + } + let mut out_expr = Vec::new(); + // compute results for the sub networks + let bits1 = if generate_witness { bits } else { &bits[n1 + (n - 1) / 2..] }; + let (w1, b1) = self.permutation_layer(&in_sub1, bits1, generate_witness); + let bits2 = if generate_witness { bits } else { &bits[n1 + (n - 1) / 2 + w1.len()..] }; + let (w2, b2) = self.permutation_layer(&in_sub2, bits2, generate_witness); + // apply the output switches + for i in 0..(n - 1) / 2 { + let c = if generate_witness { self.next_witness_index() } else { bits[n1 + i] }; + conf.push(c); + let intermediate = self.mul_with_witness(&Expression::from(c), &(&b2[i] - &b1[i])); + out_expr.push(&intermediate + &b1[i]); + out_expr.push(&b2[i] - &intermediate); + } + if n % 2 == 0 { + out_expr.push(b1.last().unwrap().clone()); + } + out_expr.push(b2.last().unwrap().clone()); + conf.extend(w1); + conf.extend(w2); + (conf, out_expr) + } + + /// Returns an expression which represents a*b + /// If one has multiplicative term and the other is of degree one or more, + /// the function creates intermediate variables accordindly + fn mul_with_witness(&mut self, a: &Expression, b: &Expression) -> Expression { + let a_arith; + let a_arith = if !a.mul_terms.is_empty() && !b.is_const() { + let a_witness = self.get_or_create_witness(a); + a_arith = Expression::from(a_witness); + &a_arith + } else { + a + }; + let b_arith; + let b_arith = if !b.mul_terms.is_empty() && !a.is_const() { + if a == b { + a_arith + } else { + let b_witness = self.get_or_create_witness(a); + b_arith = Expression::from(b_witness); + &b_arith + } + } else { + b + }; + a_arith * b_arith + } +} diff --git a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs index ed07a6b773c..27053716859 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs @@ -651,7 +651,31 @@ impl Context { } Vec::new() } - _ => todo!("expected a black box function"), + Intrinsic::Sort => { + let inputs = vecmap(arguments, |arg| self.convert_value(*arg, dfg)); + // We flatten the inputs and retrieve the bit_size of the elements + let mut input_vars = Vec::new(); + let mut bit_size = 0; + for input in inputs { + for (var, typ) in input.flatten() { + input_vars.push(var); + if bit_size == 0 { + bit_size = typ.bit_size(); + } else { + assert_eq!( + bit_size, + typ.bit_size(), + "cannot sort element of different bit size" + ); + } + } + } + // Generate the sorted output variables + let out_vars = + self.acir_context.sort(input_vars, bit_size).expect("Could not sort"); + + Self::convert_vars_to_values(out_vars, dfg, result_ids) + } } }