Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(traits): Implement trait bounds typechecker + monomorphizer passes #2717

Merged
merged 7 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ impl<'a> Resolver<'a> {
) -> Vec<TraitConstraint> {
vecmap(where_clause, |constraint| TraitConstraint {
typ: self.resolve_type(constraint.typ.clone()),
trait_id: constraint.trait_bound.trait_id,
trait_id: constraint.trait_bound.trait_id.unwrap_or_else(TraitId::dummy_id),
})
}

Expand Down
117 changes: 81 additions & 36 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use crate::{
hir_def::{
expr::{
self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral, HirMethodCallExpression,
HirPrefixExpression,
HirMethodReference, HirPrefixExpression,
},
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId},
node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId},
Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp,
};

Expand Down Expand Up @@ -144,7 +144,7 @@ impl<'interner> TypeChecker<'interner> {
let object_type = self.check_expression(&method_call.object).follow_bindings();
let method_name = method_call.method.0.contents.as_str();
match self.lookup_method(&object_type, method_name, expr_id) {
Some(method_id) => {
Some(method_ref) => {
let mut args = vec![(
object_type,
method_call.object,
Expand All @@ -160,22 +160,27 @@ impl<'interner> TypeChecker<'interner> {
// so that the backend doesn't need to worry about methods
let location = method_call.location;

// Automatically add `&mut` if the method expects a mutable reference and
// the object is not already one.
if method_id != FuncId::dummy_id() {
let func_meta = self.interner.function_meta(&method_id);
self.try_add_mutable_reference_to_object(
&mut method_call,
&func_meta.typ,
&mut args,
);
if let HirMethodReference::FuncId(func_id) = method_ref {
// Automatically add `&mut` if the method expects a mutable reference and
// the object is not already one.
if func_id != FuncId::dummy_id() {
let func_meta = self.interner.function_meta(&func_id);
self.try_add_mutable_reference_to_object(
&mut method_call,
&func_meta.typ,
&mut args,
);
}
}

let (function_id, function_call) =
method_call.into_function_call(method_id, location, self.interner);
let (function_id, function_call) = method_call.into_function_call(
method_ref,
location,
self.interner,
);

let span = self.interner.expr_span(expr_id);
let ret = self.check_method_call(&function_id, &method_id, args, span);
let ret = self.check_method_call(&function_id, method_ref, args, span);

self.interner.replace_expr(expr_id, function_call);
ret
Expand Down Expand Up @@ -286,6 +291,7 @@ impl<'interner> TypeChecker<'interner> {

Type::Function(params, Box::new(lambda.return_type), Box::new(env_type))
}
HirExpression::TraitMethodReference(_) => unreachable!("unexpected TraitMethodReference - they should be added after initial type checking"),
};

self.interner.push_expr_type(expr_id, typ.clone());
Expand Down Expand Up @@ -477,34 +483,46 @@ impl<'interner> TypeChecker<'interner> {
fn check_method_call(
&mut self,
function_ident_id: &ExprId,
func_id: &FuncId,
method_ref: HirMethodReference,
arguments: Vec<(Type, ExprId, Span)>,
span: Span,
) -> Type {
if func_id == &FuncId::dummy_id() {
Type::Error
} else {
let func_meta = self.interner.function_meta(func_id);
let (fntyp, param_len) = match method_ref {
HirMethodReference::FuncId(func_id) => {
if func_id == FuncId::dummy_id() {
return Type::Error;
}

// Check function call arity is correct
let param_len = func_meta.parameters.len();
let arg_len = arguments.len();
let func_meta = self.interner.function_meta(&func_id);
let param_len = func_meta.parameters.len();

if param_len != arg_len {
self.errors.push(TypeCheckError::ArityMisMatch {
expected: param_len as u16,
found: arg_len as u16,
span,
});
(func_meta.typ, param_len)
}
HirMethodReference::TraitMethodId(method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let the_trait = the_trait.borrow();
let method = &the_trait.methods[method.method_index];

let (function_type, instantiation_bindings) = func_meta.typ.instantiate(self.interner);
(method.get_type(), method.arguments.len())
}
};

self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings);
self.interner.push_expr_type(function_ident_id, function_type.clone());
let arg_len = arguments.len();

self.bind_function_type(function_type, arguments, span)
if param_len != arg_len {
self.errors.push(TypeCheckError::ArityMisMatch {
expected: param_len as u16,
found: arg_len as u16,
span,
});
}

let (function_type, instantiation_bindings) = fntyp.instantiate(self.interner);

self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings);
self.interner.push_expr_type(function_ident_id, function_type.clone());

self.bind_function_type(function_type, arguments, span)
}

fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type {
Expand Down Expand Up @@ -818,11 +836,11 @@ impl<'interner> TypeChecker<'interner> {
object_type: &Type,
method_name: &str,
expr_id: &ExprId,
) -> Option<FuncId> {
) -> Option<HirMethodReference> {
match object_type {
Type::Struct(typ, _args) => {
match self.interner.lookup_method(typ.borrow().id, method_name) {
Some(method_id) => Some(method_id),
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
Expand All @@ -833,6 +851,33 @@ impl<'interner> TypeChecker<'interner> {
}
}
}
Type::NamedGeneric(_, _) => {
let func_meta = self.interner.function_meta(
&self.current_function.expect("unexpected method outside a function"),
);

for constraint in func_meta.trait_constraints {
if *object_type == constraint.typ {
let the_trait = self.interner.get_trait(constraint.trait_id);
let the_trait = the_trait.borrow();

for (method_index, method) in the_trait.methods.iter().enumerate() {
if method.name.0.contents == method_name {
let trait_method =
TraitMethodId { trait_id: constraint.trait_id, method_index };
return Some(HirMethodReference::TraitMethodId(trait_method));
}
}
}
}

self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
object_type: object_type.clone(),
span: self.interner.expr_span(expr_id),
});
None
}
// Mutable references to another type should resolve to methods of their element type.
// This may be a struct or a primitive type.
Type::MutableReference(element) => self.lookup_method(element, method_name, expr_id),
Expand All @@ -843,7 +888,7 @@ impl<'interner> TypeChecker<'interner> {
// In the future we could support methods for non-struct types if we have a context
// (in the interner?) essentially resembling HashMap<Type, Methods>
other => match self.interner.lookup_primitive_method(other, method_name) {
Some(method_id) => Some(method_id),
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
Expand Down
11 changes: 9 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct TypeChecker<'interner> {
delayed_type_checks: Vec<TypeCheckFn>,
interner: &'interner mut NodeInterner,
errors: Vec<TypeCheckError>,
current_function: Option<FuncId>,
}

/// Type checks a function and assigns the
Expand All @@ -40,6 +41,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
let function_body_id = function_body.as_expr();

let mut type_checker = TypeChecker::new(interner);
type_checker.current_function = Some(func_id);

// Bind each parameter to its annotated type.
// This is locally obvious, but it must be bound here so that the
Expand Down Expand Up @@ -111,7 +113,7 @@ fn function_info(interner: &NodeInterner, function_body_id: &ExprId) -> (noirc_e

impl<'interner> TypeChecker<'interner> {
fn new(interner: &'interner mut NodeInterner) -> Self {
Self { delayed_type_checks: Vec::new(), interner, errors: vec![] }
Self { delayed_type_checks: Vec::new(), interner, errors: vec![], current_function: None }
}

pub fn push_delayed_type_check(&mut self, f: TypeCheckFn) {
Expand All @@ -127,7 +129,12 @@ impl<'interner> TypeChecker<'interner> {
}

pub fn check_global(id: &StmtId, interner: &'interner mut NodeInterner) -> Vec<TypeCheckError> {
let mut this = Self { delayed_type_checks: Vec::new(), interner, errors: vec![] };
let mut this = Self {
delayed_type_checks: Vec::new(),
interner,
errors: vec![],
current_function: None,
};
this.check_statement(id);
this.errors
}
Expand Down
32 changes: 26 additions & 6 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use acvm::FieldElement;
use fm::FileId;
use noirc_errors::Location;

use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId};
use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId, TraitMethodId};
use crate::{BinaryOp, BinaryOpKind, Ident, Shared, UnaryOp};

use super::stmt::HirPattern;
Expand Down Expand Up @@ -30,6 +30,7 @@ pub enum HirExpression {
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
TraitMethodReference(TraitMethodId),
Error,
}

Expand Down Expand Up @@ -150,20 +151,39 @@ pub struct HirMethodCallExpression {
pub location: Location,
}

#[derive(Debug, Copy, Clone)]
pub enum HirMethodReference {
/// A method can be defined in a regular `impl` block, in which case
/// it's syntax sugar for a normal function call, and can be
/// translated to one during type checking
FuncId(FuncId),

/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphizaiton.
TraitMethodId(TraitMethodId),
}

impl HirMethodCallExpression {
pub fn into_function_call(
mut self,
func: FuncId,
method: HirMethodReference,
location: Location,
interner: &mut NodeInterner,
) -> (ExprId, HirExpression) {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);

let id = interner.function_definition_id(func);
let ident = HirExpression::Ident(HirIdent { location, id });
let func = interner.push_expr(ident);

let expr = match method {
HirMethodReference::FuncId(func_id) => {
let id = interner.function_definition_id(func_id);
HirExpression::Ident(HirIdent { location, id })
}
HirMethodReference::TraitMethodId(method_id) => {
HirExpression::TraitMethodReference(method_id)
}
};
let func = interner.push_expr(expr);
(func, HirExpression::Call(HirCallExpression { func, arguments, location }))
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub struct TraitImpl {
#[derive(Debug, Clone)]
pub struct TraitConstraint {
pub typ: Type,
pub trait_id: Option<TraitId>,
pub trait_id: TraitId,
// pub trait_generics: Generics, TODO
}

Expand Down
57 changes: 55 additions & 2 deletions compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
use acvm::FieldElement;
use iter_extended::{btree_map, vecmap};
use noirc_printable_type::PrintableType;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::{
collections::{BTreeMap, HashMap, VecDeque},
unreachable,
};

use crate::{
hir_def::{
Expand All @@ -20,7 +23,7 @@ use crate::{
stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement},
types,
},
node_interner::{self, DefinitionKind, NodeInterner, StmtId},
node_interner::{self, DefinitionKind, NodeInterner, StmtId, TraitImplKey, TraitMethodId},
token::FunctionAttribute,
ContractFunctionType, FunctionKind, Type, TypeBinding, TypeBindings, TypeVariableKind,
Visibility,
Expand Down Expand Up @@ -375,6 +378,17 @@ impl<'interner> Monomorphizer<'interner> {

HirExpression::Lambda(lambda) => self.lambda(lambda, expr),

HirExpression::TraitMethodReference(method) => {
if let Type::Function(args, _, _) = self.interner.id_type(expr) {
let self_type = args[0].clone();
self.resolve_trait_method_reference(self_type, expr, method)
} else {
unreachable!(
"Calling a non-function, this should've been caught in typechecking"
);
}
}

HirExpression::MethodCall(_) => {
unreachable!("Encountered HirExpression::MethodCall during monomorphization")
}
Expand Down Expand Up @@ -777,6 +791,45 @@ impl<'interner> Monomorphizer<'interner> {
}
}

fn resolve_trait_method_reference(
&mut self,
self_type: HirType,
expr_id: node_interner::ExprId,
method: TraitMethodId,
) -> ast::Expression {
let function_type = self.interner.id_type(expr_id);

// the substitute() here is to replace all internal occurences of the 'Self' typevar
// with whatever 'Self' is currently bound to, so we don't lose type information
// if we need to rebind the trait.
let trait_impl = self
.interner
.get_trait_implementation(&TraitImplKey {
typ: self_type.follow_bindings(),
trait_id: method.trait_id,
})
.expect("ICE: missing trait impl - should be caught during type checking");

let hir_func_id = trait_impl.borrow().methods[method.method_index];

let func_def = self.lookup_function(hir_func_id, expr_id, &function_type);
let func_id = match func_def {
Definition::Function(func_id) => func_id,
_ => unreachable!(),
};

let the_trait = self.interner.get_trait(method.trait_id);
let the_trait = the_trait.borrow();

ast::Expression::Ident(ast::Ident {
definition: Definition::Function(func_id),
mutable: false,
location: None,
name: the_trait.methods[method.method_index].name.0.contents.clone(),
typ: self.convert_type(&function_type),
})
}

fn function_call(
&mut self,
call: HirCallExpression,
Expand Down
Loading