Skip to content

Commit

Permalink
Implement type constraints late passes
Browse files Browse the repository at this point in the history
  • Loading branch information
alexvitkov committed Sep 25, 2023
1 parent f7529b8 commit 05e0b21
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 51 deletions.
122 changes: 85 additions & 37 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use crate::{
hir_def::{
expr::{
self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral, HirMethodCallExpression,
HirPrefixExpression,
HirMethodReference, HirPrefixExpression,
},
traits::TraitFunction,
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId},
node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId},
Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp,
};

Expand Down Expand Up @@ -144,7 +145,7 @@ impl<'interner> TypeChecker<'interner> {
let object_type = self.check_expression(&method_call.object).follow_bindings();
let method_name = method_call.method.0.contents.as_str();
match self.lookup_method(&object_type, method_name, expr_id) {
Some(method_id) => {
Some(method_ref) => {
let mut args = vec![(
object_type,
method_call.object,
Expand All @@ -160,22 +161,27 @@ impl<'interner> TypeChecker<'interner> {
// so that the backend doesn't need to worry about methods
let location = method_call.location;

// Automatically add `&mut` if the method expects a mutable reference and
// the object is not already one.
if method_id != FuncId::dummy_id() {
let func_meta = self.interner.function_meta(&method_id);
self.try_add_mutable_reference_to_object(
&mut method_call,
&func_meta.typ,
&mut args,
);
}
if let HirMethodReference::FuncId(func_id) = method_ref {
// Automatically add `&mut` if the method expects a mutable reference and
// the object is not already one.
if func_id != FuncId::dummy_id() {
let func_meta = self.interner.function_meta(&func_id);
self.try_add_mutable_reference_to_object(
&mut method_call,
&func_meta.typ,
&mut args,
);
}
};

let (function_id, function_call) =
method_call.into_function_call(method_id, location, self.interner);
let (function_id, function_call) = method_call.into_function_call(
method_ref.clone(),
location,
self.interner,
);

let span = self.interner.expr_span(expr_id);
let ret = self.check_method_call(&function_id, &method_id, args, span);
let ret = self.check_method_call(&function_id, &method_ref, args, span);

self.interner.replace_expr(expr_id, function_call);
ret
Expand Down Expand Up @@ -286,6 +292,7 @@ impl<'interner> TypeChecker<'interner> {

Type::Function(params, Box::new(lambda.return_type), Box::new(env_type))
}
HirExpression::TraitMethodReference(_) => unreachable!(),
};

self.interner.push_expr_type(expr_id, typ.clone());
Expand Down Expand Up @@ -477,34 +484,46 @@ impl<'interner> TypeChecker<'interner> {
fn check_method_call(
&mut self,
function_ident_id: &ExprId,
func_id: &FuncId,
method_ref: &HirMethodReference,
arguments: Vec<(Type, ExprId, Span)>,
span: Span,
) -> Type {
if func_id == &FuncId::dummy_id() {
Type::Error
} else {
let func_meta = self.interner.function_meta(func_id);
let (fntyp, param_len) = match method_ref {
HirMethodReference::FuncId(func_id) => {
if func_id == &FuncId::dummy_id() {
return Type::Error;
}

// Check function call arity is correct
let param_len = func_meta.parameters.len();
let arg_len = arguments.len();
let func_meta = self.interner.function_meta(func_id);
let param_len = func_meta.parameters.len();

if param_len != arg_len {
self.errors.push(TypeCheckError::ArityMisMatch {
expected: param_len as u16,
found: arg_len as u16,
span,
});
(func_meta.typ, param_len)
}
HirMethodReference::TraitMethodId(method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let the_trait = the_trait.borrow();
let method: &TraitFunction = &the_trait.methods[method.method_index];

let (function_type, instantiation_bindings) = func_meta.typ.instantiate(self.interner);
(method.get_type(), method.arguments.len())
}
};

self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings);
self.interner.push_expr_type(function_ident_id, function_type.clone());
let arg_len = arguments.len();

self.bind_function_type(function_type, arguments, span)
if param_len != arg_len {
self.errors.push(TypeCheckError::ArityMisMatch {
expected: param_len as u16,
found: arg_len as u16,
span,
});
}

let (function_type, instantiation_bindings) = fntyp.instantiate(self.interner);

self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings);
self.interner.push_expr_type(function_ident_id, function_type.clone());

self.bind_function_type(function_type, arguments, span)
}

fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type {
Expand Down Expand Up @@ -818,11 +837,11 @@ impl<'interner> TypeChecker<'interner> {
object_type: &Type,
method_name: &str,
expr_id: &ExprId,
) -> Option<FuncId> {
) -> Option<HirMethodReference> {
match object_type {
Type::Struct(typ, _args) => {
match self.interner.lookup_method(typ.borrow().id, method_name) {
Some(method_id) => Some(method_id),
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
Expand All @@ -833,6 +852,35 @@ impl<'interner> TypeChecker<'interner> {
}
}
}
Type::NamedGeneric(_, _) => {
let func_meta = self.interner.function_meta(&self.current_function.unwrap());

for constraint in func_meta.trait_constraints {
if let Some(trait_id) = constraint.trait_id {
// TODO(#2568): == on types is sketchy, since Field != TypeVar::Bound(Field)
// unify() is sketchier here though, since it may accidentally commit typebindings.
// this works for now, but likely needs to be revisited when we implement generic traits
if *object_type == constraint.typ {
let the_trait = self.interner.get_trait(trait_id);
let the_trait = the_trait.borrow();

for (method_index, method) in the_trait.methods.iter().enumerate() {
if method.name.0.contents == method_name {
let trait_method = TraitMethodId { trait_id, method_index };
return Some(HirMethodReference::TraitMethodId(trait_method));
}
}
}
}
}

self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
object_type: object_type.clone(),
span: self.interner.expr_span(expr_id),
});
None
}
// Mutable references to another type should resolve to methods of their element type.
// This may be a struct or a primitive type.
Type::MutableReference(element) => self.lookup_method(element, method_name, expr_id),
Expand All @@ -843,7 +891,7 @@ impl<'interner> TypeChecker<'interner> {
// In the future we could support methods for non-struct types if we have a context
// (in the interner?) essentially resembling HashMap<Type, Methods>
other => match self.interner.lookup_primitive_method(other, method_name) {
Some(method_id) => Some(method_id),
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
Expand Down
11 changes: 9 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct TypeChecker<'interner> {
delayed_type_checks: Vec<TypeCheckFn>,
interner: &'interner mut NodeInterner,
errors: Vec<TypeCheckError>,
current_function: Option<FuncId>,
}

/// Type checks a function and assigns the
Expand All @@ -40,6 +41,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
let function_body_id = function_body.as_expr();

let mut type_checker = TypeChecker::new(interner);
type_checker.current_function = Some(func_id);

// Bind each parameter to its annotated type.
// This is locally obvious, but it must be bound here so that the
Expand Down Expand Up @@ -111,7 +113,7 @@ fn function_info(interner: &NodeInterner, function_body_id: &ExprId) -> (noirc_e

impl<'interner> TypeChecker<'interner> {
fn new(interner: &'interner mut NodeInterner) -> Self {
Self { delayed_type_checks: Vec::new(), interner, errors: vec![] }
Self { delayed_type_checks: Vec::new(), interner, errors: vec![], current_function: None }
}

pub fn push_delayed_type_check(&mut self, f: TypeCheckFn) {
Expand All @@ -127,7 +129,12 @@ impl<'interner> TypeChecker<'interner> {
}

pub fn check_global(id: &StmtId, interner: &'interner mut NodeInterner) -> Vec<TypeCheckError> {
let mut this = Self { delayed_type_checks: Vec::new(), interner, errors: vec![] };
let mut this = Self {
delayed_type_checks: Vec::new(),
interner,
errors: vec![],
current_function: None,
};
this.check_statement(id);
this.errors
}
Expand Down
32 changes: 26 additions & 6 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use acvm::FieldElement;
use fm::FileId;
use noirc_errors::Location;

use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId};
use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId, TraitMethodId};
use crate::{BinaryOp, BinaryOpKind, Ident, Shared, UnaryOp};

use super::stmt::HirPattern;
Expand Down Expand Up @@ -30,6 +30,7 @@ pub enum HirExpression {
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
TraitMethodReference(TraitMethodId),
Error,
}

Expand Down Expand Up @@ -150,20 +151,39 @@ pub struct HirMethodCallExpression {
pub location: Location,
}

#[derive(Debug, Clone)]
pub enum HirMethodReference {
/// A method can be defined in a regular `impl` block, in which case
/// it's syntax sugar for a normal function call, and can be
/// translated to one during type checking
FuncId(FuncId),

/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphizaiton.
TraitMethodId(TraitMethodId),
}

impl HirMethodCallExpression {
pub fn into_function_call(
mut self,
func: FuncId,
method: HirMethodReference,
location: Location,
interner: &mut NodeInterner,
) -> (ExprId, HirExpression) {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);

let id = interner.function_definition_id(func);
let ident = HirExpression::Ident(HirIdent { location, id });
let func = interner.push_expr(ident);

let expr = match method {
HirMethodReference::FuncId(func_id) => {
let id = interner.function_definition_id(func_id);
HirExpression::Ident(HirIdent { location, id })
}
HirMethodReference::TraitMethodId(method_id) => {
HirExpression::TraitMethodReference(method_id)
}
};
let func = interner.push_expr(expr);
(func, HirExpression::Call(HirCallExpression { func, arguments, location }))
}
}
Expand Down
4 changes: 0 additions & 4 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1017,10 +1017,6 @@ impl Type {
/// given bindings if found. If a type variable is not found within
/// the given TypeBindings, it is unchanged.
pub fn substitute(&self, type_bindings: &TypeBindings) -> Type {
if type_bindings.is_empty() {
return self.clone();
}

let substitute_binding = |binding: &TypeVariable| match &*binding.borrow() {
TypeBinding::Bound(binding) => binding.substitute(type_bindings),
TypeBinding::Unbound(id) => match type_bindings.get(id) {
Expand Down
57 changes: 55 additions & 2 deletions compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
use acvm::FieldElement;
use iter_extended::{btree_map, vecmap};
use noirc_printable_type::PrintableType;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::{
collections::{BTreeMap, HashMap, VecDeque},
unreachable,
};

use crate::{
hir_def::{
Expand All @@ -20,7 +23,7 @@ use crate::{
stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement},
types,
},
node_interner::{self, DefinitionKind, NodeInterner, StmtId},
node_interner::{self, DefinitionKind, NodeInterner, StmtId, TraitImplKey, TraitMethodId},
token::FunctionAttribute,
ContractFunctionType, FunctionKind, Type, TypeBinding, TypeBindings, TypeVariableKind,
Visibility,
Expand Down Expand Up @@ -375,6 +378,17 @@ impl<'interner> Monomorphizer<'interner> {

HirExpression::Lambda(lambda) => self.lambda(lambda, expr),

HirExpression::TraitMethodReference(method) => {
if let Type::Function(args, _, _) = self.interner.id_type(expr) {
let self_type = args[0].clone();
self.resolve_trait_method_reference(self_type, expr, method)
} else {
unreachable!(
"Calling a non-function, this should've been caught in typechecking"
);
}
}

HirExpression::MethodCall(_) => {
unreachable!("Encountered HirExpression::MethodCall during monomorphization")
}
Expand Down Expand Up @@ -777,6 +791,45 @@ impl<'interner> Monomorphizer<'interner> {
}
}

fn resolve_trait_method_reference(
&mut self,
self_type: HirType,
expr_id: node_interner::ExprId,
method: TraitMethodId,
) -> ast::Expression {
let function_type = self.interner.id_type(expr_id);

// the substitute() here is to replace all internal occurences of the 'Self' typevar
// with whatever 'Self' is currently bound to, so we don't lose type information
// if we need to rebind the trait.
let trait_impl = self
.interner
.get_trait_implementation(&TraitImplKey {
typ: self_type.substitute(&HashMap::new()),
trait_id: method.trait_id,
})
.expect("ICE: missing trait impl - should be caught during type checking");

let hir_func_id = trait_impl.borrow().methods[method.method_index];

let func_def = self.lookup_function(hir_func_id, expr_id, &function_type);
let func_id = match func_def {
Definition::Function(func_id) => func_id,
_ => unreachable!(),
};

let the_trait = self.interner.get_trait(method.trait_id);
let the_trait = the_trait.borrow();

ast::Expression::Ident(ast::Ident {
definition: Definition::Function(func_id),
mutable: false,
location: None,
name: the_trait.methods[method.method_index].name.0.contents.clone(),
typ: Self::convert_type(&function_type),
})
}

fn function_call(
&mut self,
call: HirCallExpression,
Expand Down
Loading

0 comments on commit 05e0b21

Please sign in to comment.