Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add de-sugaring for impl Trait in function parameters #4919

Merged
merged 9 commits into from
Apr 29, 2024
52 changes: 48 additions & 4 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ArrayLiteral, BinaryOpKind, BlockExpression, Distinctness, Expression, ExpressionKind,
ForRange, FunctionDefinition, FunctionKind, FunctionReturnType, Ident, ItemVisibility, LValue,
LetStatement, Literal, NoirFunction, NoirStruct, NoirTypeAlias, Param, Path, PathKind, Pattern,
Statement, StatementKind, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint,
Statement, StatementKind, TraitBound, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint,
UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
use crate::graph::CrateId;
Expand Down Expand Up @@ -193,16 +193,55 @@
/// Since lowering would require scope data, unless we add an extra resolution field to the AST
pub fn resolve_function(
mut self,
func: NoirFunction,
mut func: NoirFunction,
func_id: FuncId,
) -> (HirFunction, FuncMeta, Vec<ResolverError>) {
self.scopes.start_function();
self.current_item = Some(DependencyId::Function(func_id));

// Check whether the function has globals in the local module and add them to the scope
self.resolve_local_globals();

self.add_generics(&func.def.generics);

// TODO: better way to get a unique generic/global ident?
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
// e.g. split ident into (current | internal_counter)
let mut impl_trait_generics = HashSet::new();
let mut counter: usize = 0;
for parameter in func.def.parameters.iter_mut() {
if let UnresolvedTypeData::TraitAsType(path, args) = &parameter.typ.typ {
let mut new_generic_ident: Ident = "T_impl_trait".into();
let mut new_generic_path = Path::from_ident(new_generic_ident.clone());
while impl_trait_generics.contains(&new_generic_ident)
|| self.lookup_generic_or_global_type(&new_generic_path).is_some()
{
new_generic_ident = format!("T_impl_trait_{}", counter).into();
new_generic_path = Path::from_ident(new_generic_ident.clone());
counter += 1;
}
impl_trait_generics.insert(new_generic_ident.clone());

let is_synthesized = true;
let new_generic_type_data =
UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized);
let new_generic_type =
UnresolvedType { typ: new_generic_type_data.clone(), span: None };
let new_trait_bound = TraitBound {
trait_path: path.clone(),
trait_id: None,
trait_generics: args.to_vec(),
};
let new_trait_constraint = UnresolvedTraitConstraint {
typ: new_generic_type,
trait_bound: new_trait_bound,
};

parameter.typ.typ = new_generic_type_data;
func.def.generics.push(new_generic_ident);
func.def.where_clause.push(new_trait_constraint);
}
}
self.add_generics(&impl_trait_generics.into_iter().collect());

self.trait_bounds = func.def.where_clause.clone();

let is_low_level_or_oracle = func
Expand Down Expand Up @@ -624,7 +663,7 @@
.iter()
.any(|attr| matches!(attr, SecondaryAttribute::Abi(_)))
{
self.push_err(ResolverError::AbiAttributeOusideContract {

Check warning on line 666 in compiler/noirc_frontend/src/hir/resolution/resolver.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Ouside)
span: struct_type.borrow().name.span(),
});
}
Expand Down Expand Up @@ -1123,10 +1162,15 @@
| Type::TypeVariable(_, _)
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::TraitAsType(..)
| Type::Code
| Type::Forall(_, _) => (),

Type::TraitAsType(_, _, args) => {
for arg in args {
Self::find_numeric_generics_in_type(arg, found);
}
}

Type::Array(length, element_type) => {
if let Type::NamedGeneric(type_variable, name) = length.as_ref() {
found.insert(name.to_string(), type_variable.clone());
Expand Down Expand Up @@ -1201,7 +1245,7 @@
if !self.in_contract
&& let_stmt.attributes.iter().any(|attr| matches!(attr, SecondaryAttribute::Abi(_)))
{
self.push_err(ResolverError::AbiAttributeOusideContract {

Check warning on line 1248 in compiler/noirc_frontend/src/hir/resolution/resolver.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Ouside)
span: let_stmt.pattern.span(),
});
}
Expand Down
40 changes: 25 additions & 15 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,11 @@
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::Forall(_, _)
| Type::Code
| Type::TraitAsType(..) => false,
| Type::Code => false,

Type::TraitAsType(_, _, args) => {
args.iter().any(|generic| generic.contains_numeric_typevar(target_id))
}
Type::Array(length, elem) => {
elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length)
}
Expand Down Expand Up @@ -756,9 +758,12 @@
| Type::MutableReference(_)
| Type::Forall(_, _)
// TODO: probably can allow code as it is all compile time
| Type::Code
| Type::TraitAsType(..) => false,
| Type::Code => false,

| Type::TraitAsType(s, name, generics) => {
panic!("is_valid_non_inlined_function_input: {:?}, {:?}, {:?}", *s, name.clone(), generics.clone())

},
Type::Alias(alias, generics) => {
let alias = alias.borrow();
alias.get_type(generics).is_valid_non_inlined_function_input()
Expand Down Expand Up @@ -1571,7 +1576,7 @@
Type::Tuple(fields)
}
Type::Forall(typevars, typ) => {
// Trying to substitute_helper a variable de, substitute_bound_typevarsfined within a nested Forall

Check warning on line 1579 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (typevarsfined)
// is usually impossible and indicative of an error in the type checker somewhere.
for var in typevars {
assert!(!type_bindings.contains_key(&var.id()));
Expand All @@ -1591,11 +1596,17 @@
element.substitute_helper(type_bindings, substitute_bound_typevars),
)),

Type::TraitAsType(s, name, args) => {
let args = vecmap(args, |arg| {
arg.substitute_helper(type_bindings, substitute_bound_typevars)
});
Type::TraitAsType(*s, name.clone(), args)
}

Type::FieldElement
| Type::Integer(_, _)
| Type::Bool
| Type::Constant(_)
| Type::TraitAsType(..)
| Type::Error
| Type::Code
| Type::Unit => self.clone(),
Expand All @@ -1613,7 +1624,9 @@
let field_occurs = fields.occurs(target_id);
len_occurs || field_occurs
}
Type::Struct(_, generic_args) | Type::Alias(_, generic_args) => {
Type::Struct(_, generic_args)
| Type::Alias(_, generic_args)
| Type::TraitAsType(_, _, generic_args) => {
generic_args.iter().any(|arg| arg.occurs(target_id))
}
Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)),
Expand All @@ -1637,7 +1650,6 @@
| Type::Integer(_, _)
| Type::Bool
| Type::Constant(_)
| Type::TraitAsType(..)
| Type::Error
| Type::Code
| Type::Unit => false,
Expand Down Expand Up @@ -1689,16 +1701,14 @@

MutableReference(element) => MutableReference(Box::new(element.follow_bindings())),

TraitAsType(s, name, args) => {
let args = vecmap(args, |arg| arg.follow_bindings());
TraitAsType(*s, name.clone(), args)
}

// Expect that this function should only be called on instantiated types
Forall(..) => unreachable!(),
TraitAsType(..)
| FieldElement
| Integer(_, _)
| Bool
| Constant(_)
| Unit
| Code
| Error => self.clone(),
FieldElement | Integer(_, _) | Bool | Constant(_) | Unit | Code | Error => self.clone(),
}
}

Expand Down
11 changes: 11 additions & 0 deletions compiler/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1840,4 +1840,15 @@ mod test {

check_cases_with_errors(&cases[..], block(fresh_statement()));
}

#[test]
fn parse_function_impl_parameter() {
parse_all(
program(),
vec![
"fn func_name(x: impl Eq) {}",
"fn func_name<T>(x: impl Eq, y : T) where T: SomeTrait + Eq {}",
],
);
}
}
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/parser/parser/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ mod test {
"fn func_name<T>(f: Field, y : T) where T: SomeTrait + {}",
// The following should produce compile error on later stage. From the parser's perspective it's fine
"fn func_name<A>(f: Field, y : Field, z : Field) where T: SomeTrait {}",
// TODO: this fails with known EOF != EOF error
// fn func_name(x: impl Eq) {} with error Expected an end of input but found end of input
// "fn func_name(x: impl Eq) {}",
"fn func_name<T>(x: impl Eq, y : T) where T: SomeTrait + Eq {}",
],
);

Expand Down
64 changes: 64 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,70 @@ mod test {
}
}

#[test]
fn check_trait_as_type_as_fn_parameter() {
let src = "
trait Eq {
fn eq(self, other: Self) -> bool;
}

struct Foo {
a: u64,
}

impl Eq for Foo {
fn eq(self, other: Foo) -> bool { self.a == other.a }
}

fn test_eq(x: impl Eq) -> bool {
x.eq(x)
}

fn main(a: Foo) -> pub bool {
test_eq(a)
}";

let errors = get_program_errors(src);
errors.iter().for_each(|err| println!("{:?}", err));
assert!(errors.is_empty());
}

#[test]
fn check_trait_as_type_as_two_fn_parameters() {
let src = "
trait Eq {
fn eq(self, other: Self) -> bool;
}

trait Test {
fn test(self) -> bool;
}

struct Foo {
a: u64,
}

impl Eq for Foo {
fn eq(self, other: Foo) -> bool { self.a == other.a }
}

impl Test for u64 {
fn test(self) -> bool { self == self }
}

fn test_eq(x: impl Eq, y: impl Test) -> bool {
x.eq(x) == y.test()
}

fn main(a: Foo, b: u64) -> pub bool {
test_eq(a, b)
}";

let errors = get_program_errors(src);
errors.iter().for_each(|err| println!("{:?}", err));
assert!(errors.is_empty());
}

fn get_program_captures(src: &str) -> Vec<Vec<String>> {
let (program, context, _errors) = get_program(src);
let interner = context.def_interner;
Expand Down
3 changes: 3 additions & 0 deletions tooling/nargo_fmt/tests/expected/impl_trait_fn_parameter.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn func_name(x: impl Eq) {}

fn func_name<T>(x: impl Eq, y: T) where T: SomeTrait + Eq {}
3 changes: 3 additions & 0 deletions tooling/nargo_fmt/tests/input/impl_trait_fn_parameter.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn func_name(x: impl Eq) {}

fn func_name<T>(x: impl Eq, y: T) where T: SomeTrait + Eq {}
Loading