Skip to content

Commit

Permalink
chore: array.sort() intrinsic function for the new ssa (noir-lang#1782)
Browse files Browse the repository at this point in the history
* array.sort() intrinsic function for the new ssa

* Code review

* Add array_sort test
  • Loading branch information
guipublic authored Jun 26, 2023
1 parent fa9be1d commit 6fa751b
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.6.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
xs = [2, 1, 3]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(xs : [u8; 3]) {
let sorted = xs.sort();
assert(sorted[0] == 1);
assert(sorted[1] == 2);
assert(sorted[2] == 3);
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub(crate) mod acir_variable;
pub(crate) mod errors;
pub(crate) mod generated_acir;
pub(crate) mod sort;
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use acvm::acir::{
brillig_vm::Opcode as BrilligOpcode,
circuit::brillig::{BrilligInputs, BrilligOutputs},
};

use acvm::{
acir::{
circuit::opcodes::FunctionInput,
Expand Down Expand Up @@ -800,6 +801,46 @@ impl AcirContext {
}
}
}

/// Generate output variables that are constrained to be the sorted inputs
/// The outputs are the sorted inputs iff
/// outputs are sorted and
/// outputs are a permutation of the inputs
pub(crate) fn sort(
&mut self,
inputs: Vec<AcirVar>,
bit_size: u32,
) -> Result<Vec<AcirVar>, AcirGenError> {
let len = inputs.len();
// Convert the inputs into expressions
let inputs_expr = vecmap(inputs, |input| self.vars[&input].to_expression().into_owned());
// Generate output witnesses
let outputs_witness = vecmap(0..len, |_| self.acir_ir.next_witness_index());
let output_expr =
vecmap(&outputs_witness, |witness_index| Expression::from(*witness_index));
let outputs_var = vecmap(&outputs_witness, |witness_index| {
self.add_data(AcirVarData::Witness(*witness_index))
});
// Enforce the outputs to be sorted
for i in 0..(outputs_var.len() - 1) {
self.less_than_constrain(outputs_var[i], outputs_var[i + 1], bit_size)?;
}
// Enforce the outputs to be a permutation of the inputs
self.acir_ir.permutation(&inputs_expr, &output_expr);

Ok(outputs_var)
}

/// Constrain lhs to be less than rhs
fn less_than_constrain(
&mut self,
lhs: AcirVar,
rhs: AcirVar,
bit_size: u32,
) -> Result<(), AcirGenError> {
let lhs_less_than_rhs = self.more_than_eq_var(rhs, lhs, bit_size, None)?;
self.assert_eq_one(lhs_less_than_rhs)
}
}

/// Enum representing the possible values that a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,29 @@ impl GeneratedAcir {
});
self.push_opcode(opcode);
}

/// Generate gates and control bits witnesses which ensure that out_expr is a permutation of in_expr
/// Add the control bits of the sorting network used to generate the constrains
/// into the PermutationSort directive for solving in ACVM.
/// The directive is solving the control bits so that the outputs are sorted in increasing order.
///
/// n.b. A sorting network is a predetermined set of switches,
/// the control bits indicate the configuration of each switch: false for pass-through and true for cross-over
pub(crate) fn permutation(&mut self, in_expr: &[Expression], out_expr: &[Expression]) {
let bits = Vec::new();
let (w, b) = self.permutation_layer(in_expr, &bits, true);
// Constrain the network output to out_expr
for (b, o) in b.iter().zip(out_expr) {
self.push_opcode(AcirOpcode::Arithmetic(b - o));
}
let inputs = in_expr.iter().map(|a| vec![a.clone()]).collect();
self.push_opcode(AcirOpcode::Directive(Directive::PermutationSort {
inputs,
tuple: 1,
bits: w,
sort_by: vec![0],
}));
}
}

/// This function will return the number of inputs that a blackbox function
Expand Down
101 changes: 101 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/acir_gen/acir_ir/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use acvm::acir::native_types::{Expression, Witness};

use super::generated_acir::GeneratedAcir;

impl GeneratedAcir {
// Generates gates for a sorting network
// returns witness corresponding to the network configuration and the expressions corresponding to the network output
// in_expr: inputs of the sorting network
// if generate_witness is false, it uses the witness provided in bits instead of generating them
// in both cases it returns the witness of the network configuration
// if generate_witness is true, bits is ignored
pub(crate) fn permutation_layer(
&mut self,
in_expr: &[Expression],
bits: &[Witness],
generate_witness: bool,
) -> (Vec<Witness>, Vec<Expression>) {
let n = in_expr.len();
if n == 1 {
return (Vec::new(), in_expr.to_vec());
}
let n1 = n / 2;

// witness for the input switches
let mut conf = iter_extended::vecmap(0..n1, |i| {
if generate_witness {
self.next_witness_index()
} else {
bits[i]
}
});

// compute expressions after the input switches
// If inputs are a1,a2, and the switch value is c, then we compute expressions b1,b2 where
// b1 = a1+q, b2 = a2-q, q = c(a2-a1)
let mut in_sub1 = Vec::new();
let mut in_sub2 = Vec::new();
for i in 0..n1 {
//q = c*(a2-a1);
let intermediate = self.mul_with_witness(
&Expression::from(conf[i]),
&(&in_expr[2 * i + 1] - &in_expr[2 * i]),
);
//b1=a1+q
in_sub1.push(&intermediate + &in_expr[2 * i]);
//b2=a2-q
in_sub2.push(&in_expr[2 * i + 1] - &intermediate);
}
if n % 2 == 1 {
in_sub2.push(in_expr.last().unwrap().clone());
}
let mut out_expr = Vec::new();
// compute results for the sub networks
let bits1 = if generate_witness { bits } else { &bits[n1 + (n - 1) / 2..] };
let (w1, b1) = self.permutation_layer(&in_sub1, bits1, generate_witness);
let bits2 = if generate_witness { bits } else { &bits[n1 + (n - 1) / 2 + w1.len()..] };
let (w2, b2) = self.permutation_layer(&in_sub2, bits2, generate_witness);
// apply the output switches
for i in 0..(n - 1) / 2 {
let c = if generate_witness { self.next_witness_index() } else { bits[n1 + i] };
conf.push(c);
let intermediate = self.mul_with_witness(&Expression::from(c), &(&b2[i] - &b1[i]));
out_expr.push(&intermediate + &b1[i]);
out_expr.push(&b2[i] - &intermediate);
}
if n % 2 == 0 {
out_expr.push(b1.last().unwrap().clone());
}
out_expr.push(b2.last().unwrap().clone());
conf.extend(w1);
conf.extend(w2);
(conf, out_expr)
}

/// Returns an expression which represents a*b
/// If one has multiplicative term and the other is of degree one or more,
/// the function creates intermediate variables accordindly
fn mul_with_witness(&mut self, a: &Expression, b: &Expression) -> Expression {
let a_arith;
let a_arith = if !a.mul_terms.is_empty() && !b.is_const() {
let a_witness = self.get_or_create_witness(a);
a_arith = Expression::from(a_witness);
&a_arith
} else {
a
};
let b_arith;
let b_arith = if !b.mul_terms.is_empty() && !a.is_const() {
if a == b {
a_arith
} else {
let b_witness = self.get_or_create_witness(a);
b_arith = Expression::from(b_witness);
&b_arith
}
} else {
b
};
a_arith * b_arith
}
}
26 changes: 25 additions & 1 deletion crates/noirc_evaluator/src/ssa_refactor/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,31 @@ impl Context {
}
Vec::new()
}
_ => todo!("expected a black box function"),
Intrinsic::Sort => {
let inputs = vecmap(arguments, |arg| self.convert_value(*arg, dfg));
// We flatten the inputs and retrieve the bit_size of the elements
let mut input_vars = Vec::new();
let mut bit_size = 0;
for input in inputs {
for (var, typ) in input.flatten() {
input_vars.push(var);
if bit_size == 0 {
bit_size = typ.bit_size();
} else {
assert_eq!(
bit_size,
typ.bit_size(),
"cannot sort element of different bit size"
);
}
}
}
// Generate the sorted output variables
let out_vars =
self.acir_context.sort(input_vars, bit_size).expect("Could not sort");

Self::convert_vars_to_values(out_vars, dfg, result_ids)
}
}
}

Expand Down

0 comments on commit 6fa751b

Please sign in to comment.