Skip to content

Commit

Permalink
feat(stdlib): Add higher order array functions (#833)
Browse files Browse the repository at this point in the history
* Add higher order array functions

* Fix new functions

* Replace for-each loops with regular for loops:

* Fix type errors

* Remove redundant clone
  • Loading branch information
jfecher authored Feb 14, 2023
1 parent ad5d889 commit 9c62fef
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
22 changes: 21 additions & 1 deletion crates/nargo/tests/test_data/higher-order-functions/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use dep::std;

fn main() -> pub Field {
let f = if 3 * 7 > 200 { foo } else { bar };
constrain f()[1] == 2;
Expand All @@ -22,8 +24,26 @@ fn main() -> pub Field {
x = x + 1;
constrain (|y| y + z)(1) == 4;
x = x + 1;
let ret = twice(add1, 3);

test_array_functions();

ret
}

/// Test the array functions in std::array
fn test_array_functions() {
let myarray: [i32; 3] = [1, 2, 3];
constrain std::array::any(myarray, |n| n > 2);

let evens: [i32; 3] = [2, 4, 6];
constrain std::array::all(evens, |n| n > 1);

twice(add1, 3)
constrain std::array::fold(evens, 0, |a, b| a + b) == 12;
constrain std::array::reduce(evens, |a, b| a + b) == 12;

let descending = std::array::sort_via(myarray, |a, b| a > b);
constrain descending == [3, 2, 1];
}

fn foo() -> [u32; 2] {
Expand Down
2 changes: 1 addition & 1 deletion crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ pub fn comparator_operand_type_rules(

let comptime = CompTime::No(None);
if other.try_bind_to_polymorphic_int(var, &comptime, true, op.location.span).is_ok() || other == &Type::Error {
Ok(other.clone())
Ok(Bool(comptime))
} else {
Err(format!("Types in a binary operation should match, but found {lhs_type} and {rhs_type}"))
}
Expand Down
53 changes: 53 additions & 0 deletions noir_stdlib/src/array.nr
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,56 @@ fn len<T>(_input : [T]) -> comptime Field {}

#[builtin(arraysort)]
fn sort<T, N>(_a: [T; N]) -> [T; N] {}

// Sort with a custom sorting function.
fn sort_via<T, N>(mut a: [T; N], ordering: fn(T, T) -> bool) -> [T; N] {
for i in 1..len(a) {
for j in 0..i {
if ordering(a[i], a[j]) {
let old_a_j = a[j];
a[j] = a[i];
a[i] = old_a_j;
}
}
}
a
}

// Apply a function to each element of the array and an accumulator value,
// returning the final accumulated value. This function is also sometimes
// called `foldl`, `fold_left`, `reduce`, or `inject`.
fn fold<T, U, N>(array: [T; N], mut accumulator: U, f: fn(U, T) -> U) -> U {
for i in 0 .. len(array) {
accumulator = f(accumulator, array[i]);
}
accumulator
}

// Apply a function to each element of the array and an accumulator value,
// returning the final accumulated value. Unlike fold, reduce uses the first
// element of the given array as its starting accumulator value.
fn reduce<T, N>(array: [T; N], f: fn(T, T) -> T) -> T {
let mut accumulator = array[0];
for i in 1 .. len(array) {
accumulator = f(accumulator, array[i]);
}
accumulator
}

// Returns true if all elements in the array satisfy the predicate
fn all<T, N>(array: [T; N], predicate: fn(T) -> bool) -> bool {
let mut ret = true;
for i in 0 .. len(array) {
ret &= predicate(array[i]);
}
ret
}

// Returns true if any element in the array satisfies the predicate
fn any<T, N>(array: [T; N], predicate: fn(T) -> bool) -> bool {
let mut ret = false;
for i in 0 .. len(array) {
ret |= predicate(array[i]);
}
ret
}

0 comments on commit 9c62fef

Please sign in to comment.