Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: type-check trait default methods #6645

Merged
merged 15 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,8 @@ impl FunctionDefinition {
is_unconstrained: bool,
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
body: &BlockExpression,
where_clause: &[UnresolvedTraitConstraint],
body: BlockExpression,
where_clause: Vec<UnresolvedTraitConstraint>,
return_type: &FunctionReturnType,
) -> FunctionDefinition {
let p = parameters
Expand All @@ -843,9 +843,9 @@ impl FunctionDefinition {
visibility: ItemVisibility::Private,
generics: generics.clone(),
parameters: p,
body: body.clone(),
body,
span: name.span(),
where_clause: where_clause.to_vec(),
where_clause,
return_type: return_type.clone(),
return_visibility: Visibility::Private,
}
Expand Down
9 changes: 7 additions & 2 deletions compiler/noirc_frontend/src/ast/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,27 @@ pub struct NoirFunction {
pub def: FunctionDefinition,
}

/// Currently, we support three types of functions:
/// Currently, we support four types of functions:
/// - Normal functions
/// - LowLevel/Foreign which link to an OPCODE in ACIR
/// - BuiltIn which are provided by the runtime
/// - TraitFunctionWithoutBody for which we don't type-check their body
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FunctionKind {
LowLevel,
Builtin,
Normal,
Oracle,
TraitFunctionWithoutBody,
}

impl FunctionKind {
pub fn can_ignore_return_type(self) -> bool {
match self {
FunctionKind::LowLevel | FunctionKind::Builtin | FunctionKind::Oracle => true,
FunctionKind::LowLevel
| FunctionKind::Builtin
| FunctionKind::Oracle
| FunctionKind::TraitFunctionWithoutBody => true,
FunctionKind::Normal => false,
}
}
Expand Down
45 changes: 37 additions & 8 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,12 @@ impl<'context> Elaborator<'context> {
self.elaborate_functions(functions);
}

for (trait_id, unresolved_trait) in items.traits {
self.current_trait = Some(trait_id);
self.elaborate_functions(unresolved_trait.fns_with_default_impl);
}
self.current_trait = None;

for impls in items.impls.into_values() {
self.elaborate_impls(impls);
}
Expand Down Expand Up @@ -454,9 +460,10 @@ impl<'context> Elaborator<'context> {
self.add_trait_constraints_to_scope(&func_meta);

let (hir_func, body_type) = match kind {
FunctionKind::Builtin | FunctionKind::LowLevel | FunctionKind::Oracle => {
(HirFunction::empty(), Type::Error)
}
FunctionKind::Builtin
| FunctionKind::LowLevel
| FunctionKind::Oracle
| FunctionKind::TraitFunctionWithoutBody => (HirFunction::empty(), Type::Error),
FunctionKind::Normal => {
let (block, body_type) = self.elaborate_block(body);
let expr_id = self.intern_expr(block, body_span);
Expand All @@ -476,11 +483,7 @@ impl<'context> Elaborator<'context> {
// when multiple impls are available. Instead we default first to choose the Field or u64 impl.
self.check_and_pop_function_context();

// Now remove all the `where` clause constraints we added
for constraint in &func_meta.trait_constraints {
self.interner
.remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id);
}
self.remove_trait_constraints_from_scope(&func_meta);

let func_scope_tree = self.scopes.end_function();

Expand Down Expand Up @@ -1001,6 +1004,32 @@ impl<'context> Elaborator<'context> {
constraint.trait_bound.trait_id,
);
}

// Also assume `self` implements the current trait if we are inside a trait definition
if let Some(trait_id) = self.current_trait {
let the_trait = self.interner.get_trait(trait_id);
let constraint = the_trait.as_constraint(the_trait.name.span());
let self_type =
self.self_type.clone().expect("Expected a self type if there's a current trait");
self.add_trait_bound_to_scope(
func_meta,
&self_type,
&constraint.trait_bound,
constraint.trait_bound.trait_id,
);
}
}

fn remove_trait_constraints_from_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
self.interner
.remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id);
}

// Also remove the assumed trait implementation for `self` if this is a trait definition
if let Some(trait_id) = self.current_trait {
self.interner.remove_assumed_trait_implementations_for_trait(trait_id);
}
}

fn add_trait_bound_to_scope(
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ impl<'context> Elaborator<'context> {

let impl_kind = match method {
HirMethodReference::FuncId(_) => ImplKind::NotATraitMethod,
HirMethodReference::TraitMethodId(method_id, generics) => {
HirMethodReference::TraitMethodId(method_id, generics, _) => {
let mut constraint =
self.interner.get_trait(method_id.trait_id).as_constraint(span);
constraint.trait_bound.trait_generics = generics;
Expand Down
43 changes: 34 additions & 9 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ impl<'context> Elaborator<'context> {
self.recover_generics(|this| {
this.current_trait = Some(*trait_id);

let the_trait = this.interner.get_trait(*trait_id);
let self_typevar = the_trait.self_type_typevar.clone();
let self_type = Type::TypeVariable(self_typevar.clone());
this.self_type = Some(self_type.clone());

let resolved_generics = this.interner.get_trait(*trait_id).generics.clone();
this.add_existing_generics(
&unresolved_trait.trait_def.generics,
Expand All @@ -48,12 +53,15 @@ impl<'context> Elaborator<'context> {
.add_trait_dependency(DependencyId::Trait(bound.trait_id), *trait_id);
}

this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});

let methods = this.resolve_trait_methods(*trait_id, unresolved_trait);

this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_methods(methods);
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});
});

Expand Down Expand Up @@ -94,7 +102,7 @@ impl<'context> Elaborator<'context> {
parameters,
return_type,
where_clause,
body: _,
body,
is_unconstrained,
visibility: _,
is_comptime: _,
Expand All @@ -103,7 +111,6 @@ impl<'context> Elaborator<'context> {
self.recover_generics(|this| {
let the_trait = this.interner.get_trait(trait_id);
let self_typevar = the_trait.self_type_typevar.clone();
let self_type = Type::TypeVariable(self_typevar.clone());
let name_span = the_trait.name.span();

this.add_existing_generic(
Expand All @@ -115,9 +122,12 @@ impl<'context> Elaborator<'context> {
span: name_span,
},
);
this.self_type = Some(self_type.clone());

let func_id = unresolved_trait.method_ids[&name.0.contents];
let mut where_clause = where_clause.to_vec();

// Attach any trait constraints on the trait to the function
where_clause.extend(unresolved_trait.trait_def.where_clause.clone());

this.resolve_trait_function(
trait_id,
Expand All @@ -127,6 +137,7 @@ impl<'context> Elaborator<'context> {
parameters,
return_type,
where_clause,
body,
unresolved_trait.trait_def.visibility,
func_id,
);
Expand Down Expand Up @@ -189,21 +200,29 @@ impl<'context> Elaborator<'context> {
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
return_type: &FunctionReturnType,
where_clause: &[UnresolvedTraitConstraint],
where_clause: Vec<UnresolvedTraitConstraint>,
body: &Option<BlockExpression>,
trait_visibility: ItemVisibility,
func_id: FuncId,
) {
let old_generic_count = self.generics.len();

self.scopes.start_function();

let kind = FunctionKind::Normal;
let has_body = body.is_some();

let body = match body {
Some(body) => body.clone(),
None => BlockExpression { statements: Vec::new() },
};
let kind =
if has_body { FunctionKind::Normal } else { FunctionKind::TraitFunctionWithoutBody };
let mut def = FunctionDefinition::normal(
name,
is_unconstrained,
generics,
parameters,
&BlockExpression { statements: Vec::new() },
body,
where_clause,
return_type,
);
Expand All @@ -213,7 +232,13 @@ impl<'context> Elaborator<'context> {

let mut function = NoirFunction { kind, def };
self.define_function_meta(&mut function, func_id, Some(trait_id));
self.elaborate_function(func_id);

// Here we elaborate functions without a body, mainly to check the arguments and return types.
// Later on we'll elaborate functions with a body by fully type-checking them.
if !has_body {
self.elaborate_function(func_id);
}

let _ = self.scopes.end_function();
// Don't check the scope tree for unused variables, they can't be used in a declaration anyway.
self.generics.truncate(old_generic_count);
Expand Down
29 changes: 27 additions & 2 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,12 +566,17 @@ impl<'context> Elaborator<'context> {
}

// this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type)
// or inside a trait default method.
//
// Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not
// E.g. `t.method()` with `where T: Foo<Bar>` in scope will return `(Foo::method, T, vec![Bar])`
fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option<TraitPathResolution> {
let trait_impl = self.current_trait_impl?;
let trait_id = self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id;
let trait_id = if let Some(current_trait) = self.current_trait {
current_trait
} else {
let trait_impl = self.current_trait_impl?;
self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id
};

if path.kind == PathKind::Plain && path.segments.len() == 2 {
let name = &path.segments[0].ident.0.contents;
Expand Down Expand Up @@ -1395,6 +1400,25 @@ impl<'context> Elaborator<'context> {
};
let func_meta = self.interner.function_meta(&func_id);

// If inside a trait method, check if it's a method on `self`
if let Some(trait_id) = func_meta.trait_id {
if Some(object_type) == self.self_type.as_ref() {
let the_trait = self.interner.get_trait(trait_id);
let constraint = the_trait.as_constraint(the_trait.name.span());
if let Some(HirMethodReference::TraitMethodId(method_id, generics, _)) = self
.lookup_method_in_trait(
the_trait,
method_name,
&constraint.trait_bound,
the_trait.id,
)
{
// If it is, it's an assumed trait
return Some(HirMethodReference::TraitMethodId(method_id, generics, true));
}
}
}

for constraint in &func_meta.trait_constraints {
if *object_type == constraint.typ {
if let Some(the_trait) =
Expand Down Expand Up @@ -1432,6 +1456,7 @@ impl<'context> Elaborator<'context> {
return Some(HirMethodReference::TraitMethodId(
trait_method,
trait_bound.trait_generics.clone(),
false,
));
}

Expand Down
8 changes: 5 additions & 3 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,9 @@ impl<'a> ModCollector<'a> {
is_comptime,
} => {
let func_id = context.def_interner.push_empty_fn();
method_ids.insert(name.to_string(), func_id);
if !method_ids.contains_key(&name.0.contents) {
method_ids.insert(name.to_string(), func_id);
}

let location = Location::new(name.span(), self.file_id);
let modifiers = FunctionModifiers {
Expand Down Expand Up @@ -521,8 +523,8 @@ impl<'a> ModCollector<'a> {
*is_unconstrained,
generics,
parameters,
body,
where_clause,
body.clone(),
where_clause.clone(),
return_type,
));
unresolved_functions.push_fn(
Expand Down
9 changes: 5 additions & 4 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ pub enum HirMethodReference {
/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphization.
TraitMethodId(TraitMethodId, TraitGenerics),
TraitMethodId(TraitMethodId, TraitGenerics, bool /* assumed */),
}

impl HirMethodReference {
pub fn func_id(&self, interner: &NodeInterner) -> Option<FuncId> {
match self {
HirMethodReference::FuncId(func_id) => Some(*func_id),
HirMethodReference::TraitMethodId(method_id, _) => {
HirMethodReference::TraitMethodId(method_id, _, _) => {
let id = interner.trait_method_id(*method_id);
match &interner.try_definition(id)?.kind {
DefinitionKind::Function(func_id) => Some(*func_id),
Expand Down Expand Up @@ -246,7 +246,7 @@ impl HirMethodCallExpression {
HirMethodReference::FuncId(func_id) => {
(interner.function_definition_id(func_id), ImplKind::NotATraitMethod)
}
HirMethodReference::TraitMethodId(method_id, trait_generics) => {
HirMethodReference::TraitMethodId(method_id, trait_generics, assumed) => {
let id = interner.trait_method_id(method_id);
let constraint = TraitConstraint {
typ: object_type,
Expand All @@ -256,7 +256,8 @@ impl HirMethodCallExpression {
span: location.span,
},
};
(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed: false }))

(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed }))
}
};
let func_var = HirIdent { location, id, impl_kind };
Expand Down
5 changes: 3 additions & 2 deletions compiler/noirc_frontend/src/hir_def/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,13 @@ pub enum FunctionBody {

impl FuncMeta {
/// A stub function does not have a body. This includes Builtin, LowLevel,
/// and Oracle functions in addition to method declarations within a trait.
/// and Oracle functions in addition to method declarations within a trait
/// without a body.
///
/// We don't check the return type of these functions since it will always have
/// an empty body, and we don't check for unused parameters.
pub fn is_stub(&self) -> bool {
self.kind.can_ignore_return_type() || self.trait_id.is_some()
self.kind.can_ignore_return_type()
}

pub fn function_signature(&self) -> FunctionSignature {
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl<'interner> Monomorphizer<'interner> {
);
Definition::Builtin(opcode.to_string())
}
FunctionKind::Normal => {
FunctionKind::Normal | FunctionKind::TraitFunctionWithoutBody => {
let id =
self.queue_function(id, expr_id, typ, turbofish_generics, trait_method);
Definition::Function(id)
Expand Down
Loading
Loading