Skip to content

Commit

Permalink
feat: Initial work on rewriting closures to regular functions with hi… (
Browse files Browse the repository at this point in the history
#1959)

* feat: Initial work on rewriting closures to regular functions with hidden env

This commit implements the following mechanism:

On a line where a lambda expression is encountered, we initialize a tuple for the captured lambda environment and we rewrite the lambda to a regular function taking this environment as an additional parameter. All calls to the closure are then modified to insert this hidden parameter.

In other words, the following code:

```
let x = some_value;
let closure = |a| x + a;

println(closure(10));
println(closure(20));
```

is rewritten to:

```
fn closure(env: (Field,), a: Field) -> Field {
  env.0 + a
}

let x = some_value;
let closure_env = (x,);

println(closure(closure_env, 10));
println(closure(closure_env, 20));
```

In the presence of nested closures, we propagate the captured variables implicitly through all intermediate closures:

```
let x = some_value;

let closure = |a, c|
  # here, `x` is initialized from the hidden env of the outer closure
  let inner_closure = |b|
    a + b + x

  inner_closure(c)
```

To make these transforms possible, the following changes were made to the logic of the HIR resolver and the monomorphization pass:

* In the HIR resolver pass, the code determines the precise list of variables captured by each lambda. Along with the list, we compute the index of each captured var within the parent closure's environment (when the capture is propagated).

* Introduction of a new `Closure` type in order to be able to recognize the call-sites that need the automatic environment variable treatment. It's a bit unfortunate that the Closure type is defined within the `AST` modules that are used to describe the output of the monomorphization pass, because we aim to eliminate all closures during the pass. A better solution would have been possible if the type check pass after HIR resolution was outputting types specific to the HIR pass (then the closures would exist only within this separate non-simplified type system).

* The majority of the work is in the Lambda processing step in the monomorphizer which performs the necessary transformations based on the above information.

Remaining things to do:

* There are a number of pending TODO items for various minor unresolved loose ends in the code.

* There are a lot of possible additional tests to be written.

* Update docs

* refactor: use panic, instead of println+assert

Co-authored-by: jfecher <[email protected]>

* test: add an initial monomorphization rewrite test

a lot of the machinery is copied from similar existing tests
the original authors also note some of those can be refactored
in something reusable

* fix: address some PR comments: comment/refactor/small fixes

* fix: use an unified Function object, fix some problems, comments

* fix: fix code, addressing `cargo clippy` warnings

* fix: replace type_of usage and remove it, as hinted in review

* test: move closure-related tests to test_data

* test: update closure rewrite test output

* chore: apply cargo fmt changes

* test: capture some variables in some tests, fix warnings, add a TODO

add a TODO about returning closures

* test: add simplification of #1088 as a resolve test, enable another test

* fix: fix unify for closures, fix display for fn/closure types

* test: update closure tests after resolving mutable bug

* fix: address some review comments for closure PR: fixes/cleanup

* refactor: cleanup, remove a line

Co-authored-by: jfecher <[email protected]>

* refactor: cleanup

Co-authored-by: jfecher <[email protected]>

* fix: fix bind_function_type env_type handling type variable binding

* test: improve higher_order_fn_selector test

* fix: remove skip_params/additional param logic from typechecking/display

* fix: don't use closure capture logic for lambdas without captures

* fix: apply cargo fmt & clippy

* chore: apply cargo fmt

* test: fix closure rewrite test: actually capture

* chore: remove type annotation for `params`

* chore: run cargo fmt

---------

Co-authored-by: jfecher <[email protected]>
Co-authored-by: Alex Vitkov <[email protected]>
  • Loading branch information
3 people authored Aug 2, 2023
1 parent f3f6fbe commit 35404ba
Show file tree
Hide file tree
Showing 27 changed files with 1,078 additions and 125 deletions.
6 changes: 6 additions & 0 deletions crates/nargo_cli/tests/test_data/closures_mut_ref/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "closures_mut_ref"
authors = [""]
compiler_version = "0.8.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "0"
20 changes: 20 additions & 0 deletions crates/nargo_cli/tests/test_data/closures_mut_ref/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use dep::std;

fn main(mut x: Field) {
let one = 1;
let add1 = |z| {
*z = *z + one;
};

let two = 2;
let add2 = |z| {
*z = *z + two;
};

add1(&mut x);
assert(x == 1);

add2(&mut x);
assert(x == 3);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "higher_order_fn_selector"
authors = [""]
compiler_version = "0.8.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use dep::std;

fn g(x: &mut Field) -> () {
*x *= 2;
}

fn h(x: &mut Field) -> () {
*x *= 3;
}

fn selector(flag: &mut bool) -> fn(&mut Field) -> () {
let my_func = if *flag {
g
} else {
h
};

// Flip the flag for the next function call
*flag = !(*flag);
my_func
}

fn main() {

let mut flag: bool = true;

let mut x: Field = 100;
let returned_func = selector(&mut flag);
returned_func(&mut x);

assert(x == 200);

let mut y: Field = 100;
let returned_func2 = selector(&mut flag);
returned_func2(&mut y);

assert(y == 300);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "higher_order_functions"
authors = [""]
compiler_version = "0.1"

[dependencies]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use dep::std;

fn main() -> pub Field {
let f = if 3 * 7 > 200 as u32 { foo } else { bar };
assert(f()[1] == 2);
// Lambdas:
assert(twice(|x| x * 2, 5) == 20);
assert((|x, y| x + y + 1)(2, 3) == 6);

// nested lambdas
assert((|a, b| {
a + (|c| c + 2)(b)
})(0, 1) == 3);


// Closures:
let a = 42;
let g = || a;
assert(g() == 42);

// When you copy mutable variables,
// the capture of the copies shouldn't change:
let mut x = 2;
x = x + 1;
let z = x;

// Add extra mutations to ensure we can mutate x without the
// captured z changing.
x = x + 1;
assert((|y| y + z)(1) == 4);

// When you capture mutable variables,
// again, the captured variable doesn't change:
let closure_capturing_mutable = (|y| y + x);
assert(closure_capturing_mutable(1) == 5);
x += 1;
assert(closure_capturing_mutable(1) == 5);

let ret = twice(add1, 3);

test_array_functions();
ret
}

/// Test the array functions in std::array
fn test_array_functions() {
let myarray: [i32; 3] = [1, 2, 3];
assert(myarray.any(|n| n > 2));

let evens: [i32; 3] = [2, 4, 6];
assert(evens.all(|n| n > 1));

assert(evens.fold(0, |a, b| a + b) == 12);
assert(evens.reduce(|a, b| a + b) == 12);

// TODO: is this a sort_via issue with the new backend,
// or something more general?
//
// currently it fails only with `--experimental-ssa` with
// "not yet implemented: Cast into signed"
// but it worked with the original ssa backend
// (before dropping it)
//
// opened #2121 for it
// https://github.com/noir-lang/noir/issues/2121

// let descending = myarray.sort_via(|a, b| a > b);
// assert(descending == [3, 2, 1]);

assert(evens.map(|n| n / 2) == myarray);
}

fn foo() -> [u32; 2] {
[1, 3]
}

fn bar() -> [u32; 2] {
[3, 2]
}

fn add1(x: Field) -> Field {
x + 1
}

fn twice(f: fn(Field) -> Field, x: Field) -> Field {
f(f(x))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"backend":"acvm-backend-barretenberg","abi":{"parameters":[],"param_witnesses":{},"return_type":null,"return_witnesses":[]},"bytecode":[155,194,56,97,194,4,0],"proving_key":null,"verification_key":null}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"backend":"acvm-backend-barretenberg","abi":{"parameters":[{"name":"x","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"y","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"z","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"}],"param_witnesses":{"x":[1],"y":[2],"z":[3]},"return_type":null,"return_witnesses":[]},"bytecode":"H4sIAAAAAAAA/9WUTW6DMBSEJ/yFhoY26bYLjoAxBLPrVYpK7n+EgmoHamWXeShYQsYSvJ+Z9/kDwCf+1m58ArsXi3PgnUN7dt/u7P9fdi8fW8rlATduCW89GFe5l2iMES90YBd+EyTyjIjtGYIm+HF1eanroa0GpdV3WXW9acq66S9GGdWY5qcyWg+mNm3Xd23ZqVoP6tp0+moDJ5AxNOTUWdk6VUTsOSb6wtRPCuDYziaZAzGA92OMFCsAPCUqMAOcQg5gZwIb4BdsA+A9seeU6AtTPymAUzubZA7EAD6MMTKsAPCUqMAMcAY5gJ0JbIBfsQ2AD8SeM6IvTP2kAM7sbJI5EAP4OMbIsQLAU6ICM8A55AB2JrABfsM2AD4Se86Jvjy5freeQ2LPObGud6J+Ce5ADz6LzJqX9Z4W75HdgzszkQj0BC+Pr6PohSpl0kkg7hm84Zfq+8z36N/l9OyaLtcv2EfpKJUUAAA=","proving_key":null,"verification_key":null}
Binary file not shown.
6 changes: 6 additions & 0 deletions crates/nargo_cli/tests/test_data/inner_outer_cl/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "inner_outer_cl"
authors = [""]
compiler_version = "0.7.1"

[dependencies]
12 changes: 12 additions & 0 deletions crates/nargo_cli/tests/test_data/inner_outer_cl/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fn main() {
let z1 = 0;
let z2 = 1;
let cl_outer = |x| {
let cl_inner = |y| {
x + y + z2
};
cl_inner(1) + z1
};
let result = cl_outer(1);
assert(result == 3);
}
6 changes: 6 additions & 0 deletions crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "ret_fn_ret_cl"
authors = [""]
compiler_version = "0.7.1"

[dependencies]
1 change: 1 addition & 0 deletions crates/nargo_cli/tests/test_data/ret_fn_ret_cl/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "10"
39 changes: 39 additions & 0 deletions crates/nargo_cli/tests/test_data/ret_fn_ret_cl/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use dep::std;

fn f(x: Field) -> Field {
x + 1
}

fn ret_fn() -> fn(Field) -> Field {
f
}

// TODO: in the advanced implicitly generic function with closures branch
// which would support higher-order functions in a better way
// support returning closures:
//
// fn ret_closure() -> fn(Field) -> Field {
// let y = 1;
// let inner_closure = |z| -> Field{
// z + y
// };
// inner_closure
// }

fn ret_lambda() -> fn(Field) -> Field {
let cl = |z: Field| -> Field {
z + 1
};
cl
}

fn main(x : Field) {
let result_fn = ret_fn();
assert(result_fn(x) == x + 1);

// let result_closure = ret_closure();
// assert(result_closure(x) == x + 1);

let result_lambda = ret_lambda();
assert(result_lambda(x) == x + 1);
}
2 changes: 1 addition & 1 deletion crates/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ impl<'a> FunctionContext<'a> {
}
ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"),
ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"),
ast::Type::Function(_, _) => Type::Function,
ast::Type::Function(_, _, _) => Type::Function,
ast::Type::Slice(element) => {
let element_types = Self::convert_type(element).flatten();
Type::Slice(Rc::new(element_types))
Expand Down
4 changes: 2 additions & 2 deletions crates/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use crate::hir::type_check::{type_check_func, TypeChecker};
use crate::hir::Context;
use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TypeAliasId};
use crate::{
ExpressionKind, Generics, Ident, LetStatement, NoirFunction, NoirStruct, NoirTypeAlias,
ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType, Literal,
ExpressionKind, Generics, Ident, LetStatement, Literal, NoirFunction, NoirStruct,
NoirTypeAlias, ParsedModule, Shared, Type, TypeBinding, UnresolvedGenerics, UnresolvedType,
};
use fm::FileId;
use iter_extended::vecmap;
Expand Down
Loading

0 comments on commit 35404ba

Please sign in to comment.