Skip to content

Commit

Permalink
feat!: type-check trait default methods (noir-lang#6645)
Browse files Browse the repository at this point in the history
  • Loading branch information
asterite authored Jan 7, 2025
1 parent 5b9a113 commit 8bb3908
Show file tree
Hide file tree
Showing 17 changed files with 306 additions and 82 deletions.
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,8 @@ impl FunctionDefinition {
is_unconstrained: bool,
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
body: &BlockExpression,
where_clause: &[UnresolvedTraitConstraint],
body: BlockExpression,
where_clause: Vec<UnresolvedTraitConstraint>,
return_type: &FunctionReturnType,
) -> FunctionDefinition {
let p = parameters
Expand All @@ -843,9 +843,9 @@ impl FunctionDefinition {
visibility: ItemVisibility::Private,
generics: generics.clone(),
parameters: p,
body: body.clone(),
body,
span: name.span(),
where_clause: where_clause.to_vec(),
where_clause,
return_type: return_type.clone(),
return_visibility: Visibility::Private,
}
Expand Down
9 changes: 7 additions & 2 deletions compiler/noirc_frontend/src/ast/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,27 @@ pub struct NoirFunction {
pub def: FunctionDefinition,
}

/// Currently, we support three types of functions:
/// Currently, we support four types of functions:
/// - Normal functions
/// - LowLevel/Foreign which link to an OPCODE in ACIR
/// - BuiltIn which are provided by the runtime
/// - TraitFunctionWithoutBody for which we don't type-check their body
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FunctionKind {
LowLevel,
Builtin,
Normal,
Oracle,
TraitFunctionWithoutBody,
}

impl FunctionKind {
pub fn can_ignore_return_type(self) -> bool {
match self {
FunctionKind::LowLevel | FunctionKind::Builtin | FunctionKind::Oracle => true,
FunctionKind::LowLevel
| FunctionKind::Builtin
| FunctionKind::Oracle
| FunctionKind::TraitFunctionWithoutBody => true,
FunctionKind::Normal => false,
}
}
Expand Down
45 changes: 37 additions & 8 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ impl<'context> Elaborator<'context> {
self.elaborate_functions(functions);
}

for (trait_id, unresolved_trait) in items.traits {
self.current_trait = Some(trait_id);
self.elaborate_functions(unresolved_trait.fns_with_default_impl);
}
self.current_trait = None;

for impls in items.impls.into_values() {
self.elaborate_impls(impls);
}
Expand Down Expand Up @@ -472,9 +478,10 @@ impl<'context> Elaborator<'context> {
self.add_trait_constraints_to_scope(&func_meta);

let (hir_func, body_type) = match kind {
FunctionKind::Builtin | FunctionKind::LowLevel | FunctionKind::Oracle => {
(HirFunction::empty(), Type::Error)
}
FunctionKind::Builtin
| FunctionKind::LowLevel
| FunctionKind::Oracle
| FunctionKind::TraitFunctionWithoutBody => (HirFunction::empty(), Type::Error),
FunctionKind::Normal => {
let (block, body_type) = self.elaborate_block(body);
let expr_id = self.intern_expr(block, body_span);
Expand All @@ -494,11 +501,7 @@ impl<'context> Elaborator<'context> {
// when multiple impls are available. Instead we default first to choose the Field or u64 impl.
self.check_and_pop_function_context();

// Now remove all the `where` clause constraints we added
for constraint in &func_meta.trait_constraints {
self.interner
.remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id);
}
self.remove_trait_constraints_from_scope(&func_meta);

let func_scope_tree = self.scopes.end_function();

Expand Down Expand Up @@ -1019,6 +1022,32 @@ impl<'context> Elaborator<'context> {
constraint.trait_bound.trait_id,
);
}

// Also assume `self` implements the current trait if we are inside a trait definition
if let Some(trait_id) = self.current_trait {
let the_trait = self.interner.get_trait(trait_id);
let constraint = the_trait.as_constraint(the_trait.name.span());
let self_type =
self.self_type.clone().expect("Expected a self type if there's a current trait");
self.add_trait_bound_to_scope(
func_meta,
&self_type,
&constraint.trait_bound,
constraint.trait_bound.trait_id,
);
}
}

fn remove_trait_constraints_from_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
self.interner
.remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id);
}

// Also remove the assumed trait implementation for `self` if this is a trait definition
if let Some(trait_id) = self.current_trait {
self.interner.remove_assumed_trait_implementations_for_trait(trait_id);
}
}

fn add_trait_bound_to_scope(
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ impl<'context> Elaborator<'context> {

let impl_kind = match method {
HirMethodReference::FuncId(_) => ImplKind::NotATraitMethod,
HirMethodReference::TraitMethodId(method_id, generics) => {
HirMethodReference::TraitMethodId(method_id, generics, _) => {
let mut constraint =
self.interner.get_trait(method_id.trait_id).as_constraint(span);
constraint.trait_bound.trait_generics = generics;
Expand Down
43 changes: 34 additions & 9 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ impl<'context> Elaborator<'context> {
self.recover_generics(|this| {
this.current_trait = Some(*trait_id);

let the_trait = this.interner.get_trait(*trait_id);
let self_typevar = the_trait.self_type_typevar.clone();
let self_type = Type::TypeVariable(self_typevar.clone());
this.self_type = Some(self_type.clone());

let resolved_generics = this.interner.get_trait(*trait_id).generics.clone();
this.add_existing_generics(
&unresolved_trait.trait_def.generics,
Expand All @@ -48,12 +53,15 @@ impl<'context> Elaborator<'context> {
.add_trait_dependency(DependencyId::Trait(bound.trait_id), *trait_id);
}

this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});

let methods = this.resolve_trait_methods(*trait_id, unresolved_trait);

this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_methods(methods);
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});
});

Expand Down Expand Up @@ -94,7 +102,7 @@ impl<'context> Elaborator<'context> {
parameters,
return_type,
where_clause,
body: _,
body,
is_unconstrained,
visibility: _,
is_comptime: _,
Expand All @@ -103,7 +111,6 @@ impl<'context> Elaborator<'context> {
self.recover_generics(|this| {
let the_trait = this.interner.get_trait(trait_id);
let self_typevar = the_trait.self_type_typevar.clone();
let self_type = Type::TypeVariable(self_typevar.clone());
let name_span = the_trait.name.span();

this.add_existing_generic(
Expand All @@ -115,9 +122,12 @@ impl<'context> Elaborator<'context> {
span: name_span,
},
);
this.self_type = Some(self_type.clone());

let func_id = unresolved_trait.method_ids[&name.0.contents];
let mut where_clause = where_clause.to_vec();

// Attach any trait constraints on the trait to the function
where_clause.extend(unresolved_trait.trait_def.where_clause.clone());

this.resolve_trait_function(
trait_id,
Expand All @@ -127,6 +137,7 @@ impl<'context> Elaborator<'context> {
parameters,
return_type,
where_clause,
body,
unresolved_trait.trait_def.visibility,
func_id,
);
Expand Down Expand Up @@ -189,21 +200,29 @@ impl<'context> Elaborator<'context> {
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
return_type: &FunctionReturnType,
where_clause: &[UnresolvedTraitConstraint],
where_clause: Vec<UnresolvedTraitConstraint>,
body: &Option<BlockExpression>,
trait_visibility: ItemVisibility,
func_id: FuncId,
) {
let old_generic_count = self.generics.len();

self.scopes.start_function();

let kind = FunctionKind::Normal;
let has_body = body.is_some();

let body = match body {
Some(body) => body.clone(),
None => BlockExpression { statements: Vec::new() },
};
let kind =
if has_body { FunctionKind::Normal } else { FunctionKind::TraitFunctionWithoutBody };
let mut def = FunctionDefinition::normal(
name,
is_unconstrained,
generics,
parameters,
&BlockExpression { statements: Vec::new() },
body,
where_clause,
return_type,
);
Expand All @@ -213,7 +232,13 @@ impl<'context> Elaborator<'context> {

let mut function = NoirFunction { kind, def };
self.define_function_meta(&mut function, func_id, Some(trait_id));
self.elaborate_function(func_id);

// Here we elaborate functions without a body, mainly to check the arguments and return types.
// Later on we'll elaborate functions with a body by fully type-checking them.
if !has_body {
self.elaborate_function(func_id);
}

let _ = self.scopes.end_function();
// Don't check the scope tree for unused variables, they can't be used in a declaration anyway.
self.generics.truncate(old_generic_count);
Expand Down
29 changes: 27 additions & 2 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,12 +566,17 @@ impl<'context> Elaborator<'context> {
}

// this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type)
// or inside a trait default method.
//
// Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not
// E.g. `t.method()` with `where T: Foo<Bar>` in scope will return `(Foo::method, T, vec![Bar])`
fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option<TraitPathResolution> {
let trait_impl = self.current_trait_impl?;
let trait_id = self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id;
let trait_id = if let Some(current_trait) = self.current_trait {
current_trait
} else {
let trait_impl = self.current_trait_impl?;
self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id
};

if path.kind == PathKind::Plain && path.segments.len() == 2 {
let name = &path.segments[0].ident.0.contents;
Expand Down Expand Up @@ -1395,6 +1400,25 @@ impl<'context> Elaborator<'context> {
};
let func_meta = self.interner.function_meta(&func_id);

// If inside a trait method, check if it's a method on `self`
if let Some(trait_id) = func_meta.trait_id {
if Some(object_type) == self.self_type.as_ref() {
let the_trait = self.interner.get_trait(trait_id);
let constraint = the_trait.as_constraint(the_trait.name.span());
if let Some(HirMethodReference::TraitMethodId(method_id, generics, _)) = self
.lookup_method_in_trait(
the_trait,
method_name,
&constraint.trait_bound,
the_trait.id,
)
{
// If it is, it's an assumed trait
return Some(HirMethodReference::TraitMethodId(method_id, generics, true));
}
}
}

for constraint in &func_meta.trait_constraints {
if *object_type == constraint.typ {
if let Some(the_trait) =
Expand Down Expand Up @@ -1432,6 +1456,7 @@ impl<'context> Elaborator<'context> {
return Some(HirMethodReference::TraitMethodId(
trait_method,
trait_bound.trait_generics.clone(),
false,
));
}

Expand Down
8 changes: 5 additions & 3 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ impl<'a> ModCollector<'a> {
is_comptime,
} => {
let func_id = context.def_interner.push_empty_fn();
method_ids.insert(name.to_string(), func_id);
if !method_ids.contains_key(&name.0.contents) {
method_ids.insert(name.to_string(), func_id);
}

let location = Location::new(name.span(), self.file_id);
let modifiers = FunctionModifiers {
Expand Down Expand Up @@ -521,8 +523,8 @@ impl<'a> ModCollector<'a> {
*is_unconstrained,
generics,
parameters,
body,
where_clause,
body.clone(),
where_clause.clone(),
return_type,
));
unresolved_functions.push_fn(
Expand Down
9 changes: 5 additions & 4 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ pub enum HirMethodReference {
/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphization.
TraitMethodId(TraitMethodId, TraitGenerics),
TraitMethodId(TraitMethodId, TraitGenerics, bool /* assumed */),
}

impl HirMethodReference {
pub fn func_id(&self, interner: &NodeInterner) -> Option<FuncId> {
match self {
HirMethodReference::FuncId(func_id) => Some(*func_id),
HirMethodReference::TraitMethodId(method_id, _) => {
HirMethodReference::TraitMethodId(method_id, _, _) => {
let id = interner.trait_method_id(*method_id);
match &interner.try_definition(id)?.kind {
DefinitionKind::Function(func_id) => Some(*func_id),
Expand Down Expand Up @@ -246,7 +246,7 @@ impl HirMethodCallExpression {
HirMethodReference::FuncId(func_id) => {
(interner.function_definition_id(func_id), ImplKind::NotATraitMethod)
}
HirMethodReference::TraitMethodId(method_id, trait_generics) => {
HirMethodReference::TraitMethodId(method_id, trait_generics, assumed) => {
let id = interner.trait_method_id(method_id);
let constraint = TraitConstraint {
typ: object_type,
Expand All @@ -256,7 +256,8 @@ impl HirMethodCallExpression {
span: location.span,
},
};
(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed: false }))

(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed }))
}
};
let func_var = HirIdent { location, id, impl_kind };
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_frontend/src/hir_def/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,13 @@ pub enum FunctionBody {

impl FuncMeta {
/// A stub function does not have a body. This includes Builtin, LowLevel,
/// and Oracle functions in addition to method declarations within a trait.
/// and Oracle functions in addition to method declarations within a trait
/// without a body.
///
/// We don't check the return type of these functions since it will always have
/// an empty body, and we don't check for unused parameters.
pub fn is_stub(&self) -> bool {
self.kind.can_ignore_return_type() || self.trait_id.is_some()
self.kind.can_ignore_return_type()
}

pub fn function_signature(&self) -> FunctionSignature {
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl<'interner> Monomorphizer<'interner> {
);
Definition::Builtin(opcode.to_string())
}
FunctionKind::Normal => {
FunctionKind::Normal | FunctionKind::TraitFunctionWithoutBody => {
let id =
self.queue_function(id, expr_id, typ, turbofish_generics, trait_method);
Definition::Function(id)
Expand Down
Loading

0 comments on commit 8bb3908

Please sign in to comment.