Skip to content

Commit

Permalink
feat: add syntax for specifying function type environments (#2357)
Browse files Browse the repository at this point in the history
alexvitkov authored Aug 18, 2023
1 parent 36fe1ee commit 495a479
Showing 9 changed files with 158 additions and 80 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "closure_explicit_types"
type = "bin"
authors = [""]
compiler_version = "0.10.3"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

fn ret_normal_lambda1() -> fn() -> Field {
|| 10
}

// explicitly specified empty capture group
fn ret_normal_lambda2() -> fn[]() -> Field {
|| 20
}

// return lamda that captures a thing
fn ret_closure1() -> fn[Field]() -> Field {
let x = 20;
|| x + 10
}

// return lamda that captures two things
fn ret_closure2() -> fn[Field,Field]() -> Field {
let x = 20;
let y = 10;
|| x + y + 10
}

// return lamda that captures two things with different types
fn ret_closure3() -> fn[u32,u64]() -> u64 {
let x: u32 = 20;
let y: u64 = 10;
|| x as u64 + y + 10
}

// accepts closure that has 1 thing in its env, calls it and returns the result
fn accepts_closure1(f: fn[Field]() -> Field) -> Field {
f()
}

// accepts closure that has 1 thing in its env and returns it
fn accepts_closure2(f: fn[Field]() -> Field) -> fn[Field]() -> Field {
f
}

// accepts closure with different types in the capture group
fn accepts_closure3(f: fn[u32, u64]() -> u64) -> u64 {
f()
}

fn main() {
assert(ret_normal_lambda1()() == 10);
assert(ret_normal_lambda2()() == 20);
assert(ret_closure1()() == 30);
assert(ret_closure2()() == 40);
assert(ret_closure3()() == 40);

let x = 50;
assert(accepts_closure1(|| x) == 50);
assert(accepts_closure2(|| x + 10)() == 60);

let y: u32 = 30;
let z: u64 = 40;
assert(accepts_closure3(|| y as u64 + z) == 70);
}
20 changes: 17 additions & 3 deletions crates/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
@@ -50,7 +50,11 @@ pub enum UnresolvedType {
// Note: Tuples have no visibility, instead each of their elements may have one.
Tuple(Vec<UnresolvedType>),

Function(/*args:*/ Vec<UnresolvedType>, /*ret:*/ Box<UnresolvedType>),
Function(
/*args:*/ Vec<UnresolvedType>,
/*ret:*/ Box<UnresolvedType>,
/*env:*/ Box<UnresolvedType>,
),

Unspecified, // This is for when the user declares a variable without specifying it's type
Error,
@@ -109,9 +113,19 @@ impl std::fmt::Display for UnresolvedType {
Some(len) => write!(f, "str<{len}>"),
},
FormatString(len, elements) => write!(f, "fmt<{len}, {elements}"),
Function(args, ret) => {
Function(args, ret, env) => {
let args = vecmap(args, ToString::to_string);
write!(f, "fn({}) -> {ret}", args.join(", "))

match &**env {
UnresolvedType::Unit => {
write!(f, "fn({}) -> {ret}", args.join(", "))
}
UnresolvedType::Tuple(env_types) => {
let env_types = vecmap(env_types, ToString::to_string);
write!(f, "fn[{}]({}) -> {ret}", env_types.join(", "), args.join(", "))
}
_ => unreachable!(),
}
}
MutableReference(element) => write!(f, "&mut {element}"),
Unit => write!(f, "()"),
4 changes: 2 additions & 2 deletions crates/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
@@ -361,10 +361,10 @@ impl<'a> Resolver<'a> {
UnresolvedType::Tuple(fields) => {
Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables)))
}
UnresolvedType::Function(args, ret) => {
UnresolvedType::Function(args, ret, env) => {
let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables));
let ret = Box::new(self.resolve_type_inner(*ret, new_variables));
let env = Box::new(Type::Unit);
let env = Box::new(self.resolve_type_inner(*env, new_variables));
Type::Function(args, ret, env)
}
UnresolvedType::MutableReference(element) => {
12 changes: 5 additions & 7 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
@@ -837,13 +837,11 @@ impl<'interner> TypeChecker<'interner> {
}

for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) {
if arg.try_unify_allow_incompat_lambdas(param).is_err() {
self.errors.push(TypeCheckError::TypeMismatch {
expected_typ: param.to_string(),
expr_typ: arg.to_string(),
expr_span: *arg_span,
});
}
self.unify(arg, param, || TypeCheckError::TypeMismatch {
expected_typ: param.to_string(),
expr_typ: arg.to_string(),
expr_span: *arg_span,
});
}

fn_ret.clone()
45 changes: 20 additions & 25 deletions crates/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
@@ -63,33 +63,28 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
let (expr_span, empty_function) = function_info(interner, function_body_id);

let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type.clone(),
span: func_span,
source: Source::Return(meta.return_type, expr_span),
};

let result = function_last_type.try_unify_allow_incompat_lambdas(&declared_return_type);

if result.is_err() {
function_last_type.unify_with_coercions(
&declared_return_type,
*function_body_id,
interner,
&mut errors,
|| {
let mut error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type.clone(),
span: func_span,
source: Source::Return(meta.return_type, expr_span),
};

if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
}
if empty_function {
error = error.add_context(
"implicitly returns `()` as its body has no tail or `return` expression",
);
}

error
},
);
}
error
},
);
}

errors
29 changes: 0 additions & 29 deletions crates/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
@@ -947,35 +947,6 @@ impl Type {
}
}

/// Similar to try_unify() but allows non-matching capture groups for function types
pub fn try_unify_allow_incompat_lambdas(&self, other: &Type) -> Result<(), UnificationError> {
use Type::*;
use TypeVariableKind::*;

match (self, other) {
(TypeVariable(binding, Normal), other) | (other, TypeVariable(binding, Normal)) => {
if let TypeBinding::Bound(link) = &*binding.borrow() {
return link.try_unify_allow_incompat_lambdas(other);
}

other.try_bind_to(binding)
}
(Function(params_a, ret_a, _), Function(params_b, ret_b, _)) => {
if params_a.len() == params_b.len() {
for (a, b) in params_a.iter().zip(params_b.iter()) {
a.try_unify_allow_incompat_lambdas(b)?;
}

// no check for environments here!
ret_b.try_unify_allow_incompat_lambdas(ret_a)
} else {
Err(UnificationError)
}
}
_ => self.try_unify(other),
}
}

/// Similar to `unify` but if the check fails this will attempt to coerce the
/// argument to the target type. When this happens, the given expression is wrapped in
/// a new expression to convert its type. E.g. `array` -> `array.as_slice()`
37 changes: 26 additions & 11 deletions crates/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
@@ -784,15 +784,27 @@ impl<'interner> Monomorphizer<'interner> {

let is_closure = self.is_function_closure(call.func);
if is_closure {
let extracted_func: ast::Expression;
let hir_call_func = self.interner.expression(&call.func);
if let HirExpression::Lambda(l) = hir_call_func {
let (setup, closure_variable) = self.lambda_with_setup(l, call.func);
block_expressions.push(setup);
extracted_func = closure_variable;
} else {
extracted_func = *original_func;
}
let local_id = self.next_local_id();

// store the function in a temporary variable before calling it
// this is needed for example if call.func is of the form `foo()()`
// without this, we would translate it to `foo().1(foo().0)`
let let_stmt = ast::Expression::Let(ast::Let {
id: local_id,
mutable: false,
name: "tmp".to_string(),
expression: Box::new(*original_func),
});
block_expressions.push(let_stmt);

let extracted_func = ast::Expression::Ident(ast::Ident {
location: None,
definition: Definition::Local(local_id),
mutable: false,
name: "tmp".to_string(),
typ: Self::convert_type(&self.interner.id_type(call.func)),
});

func = Box::new(ast::Expression::ExtractTupleField(
Box::new(extracted_func.clone()),
1usize,
@@ -1435,7 +1447,7 @@ mod tests {
#[test]
fn simple_closure_with_no_captured_variables() {
let src = r#"
fn main() -> Field {
fn main() -> pub Field {
let x = 1;
let closure = || x;
closure()
@@ -1451,7 +1463,10 @@ mod tests {
};
closure_variable$l2
};
closure$l3.1(closure$l3.0)
{
let tmp$4 = closure$l3;
tmp$l4.1(tmp$l4.0)
}
}
fn lambda$f1(mut env$l1: (Field)) -> Field {
env$l1.0
24 changes: 21 additions & 3 deletions crates/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
@@ -971,12 +971,30 @@ fn function_type<T>(type_parser: T) -> impl NoirParser<UnresolvedType>
where
T: NoirParser<UnresolvedType>,
{
let args = parenthesized(type_parser.clone().separated_by(just(Token::Comma)).allow_trailing());
let types = type_parser.clone().separated_by(just(Token::Comma)).allow_trailing();
let args = parenthesized(types.clone());

let env = just(Token::LeftBracket)
.ignore_then(types)
.then_ignore(just(Token::RightBracket))
.or_not()
.map(|args| match args {
Some(args) => {
if args.is_empty() {
UnresolvedType::Unit
} else {
UnresolvedType::Tuple(args)
}
}
None => UnresolvedType::Unit,
});

keyword(Keyword::Fn)
.ignore_then(args)
.ignore_then(env)
.then(args)
.then_ignore(just(Token::Arrow))
.then(type_parser)
.map(|(args, ret)| UnresolvedType::Function(args, Box::new(ret)))
.map(|((env, args), ret)| UnresolvedType::Function(args, Box::new(ret), Box::new(env)))
}

fn mutable_reference_type<T>(type_parser: T) -> impl NoirParser<UnresolvedType>

0 comments on commit 495a479

Please sign in to comment.