diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 295297cc738..6e2756f0301 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -553,7 +553,7 @@ impl<'context> Elaborator<'context> { fn elaborate_cast(&mut self, cast: CastExpression, span: Span) -> (HirExpression, Type) { let (lhs, lhs_type) = self.elaborate_expression(cast.lhs); let r#type = self.resolve_type(cast.r#type); - let result = self.check_cast(lhs_type, &r#type, span); + let result = self.check_cast(&lhs_type, &r#type, span); let expr = HirExpression::Cast(HirCastExpression { lhs, r#type }); (expr, result) } diff --git a/compiler/noirc_frontend/src/elaborator/scope.rs b/compiler/noirc_frontend/src/elaborator/scope.rs index 73aed9bf06c..258c32d4427 100644 --- a/compiler/noirc_frontend/src/elaborator/scope.rs +++ b/compiler/noirc_frontend/src/elaborator/scope.rs @@ -1,6 +1,6 @@ use noirc_errors::{Location, Spanned}; -use crate::ast::ERROR_IDENT; +use crate::ast::{PathKind, ERROR_IDENT}; use crate::hir::def_map::{LocalModuleId, ModuleId}; use crate::hir::resolution::path_resolver::{PathResolver, StandardPathResolver}; use crate::hir::scope::{Scope as GenericScope, ScopeTree as GenericScopeTree}; @@ -43,7 +43,34 @@ impl<'context> Elaborator<'context> { } pub(super) fn resolve_path(&mut self, path: Path) -> Result { - let resolver = StandardPathResolver::new(self.module_id()); + let mut module_id = self.module_id(); + let mut path = path; + + if path.kind == PathKind::Plain && path.first_name() == SELF_TYPE_NAME { + if let Some(Type::Struct(struct_type, _)) = &self.self_type { + let struct_type = struct_type.borrow(); + if path.segments.len() == 1 { + return Ok(ModuleDefId::TypeId(struct_type.id)); + } + + module_id = struct_type.id.module_id(); + path = Path { + segments: path.segments[1..].to_vec(), + kind: PathKind::Plain, + span: path.span(), + }; + } + } + + self.resolve_path_in_module(path, module_id) + } + + fn resolve_path_in_module( + &mut self, + path: Path, + module_id: ModuleId, + ) -> Result { + let resolver = StandardPathResolver::new(module_id); let path_resolution; if self.interner.track_references { diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index c134820811e..0973e592c1e 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -779,7 +779,7 @@ impl<'context> Elaborator<'context> { } } - pub(super) fn check_cast(&mut self, from: Type, to: &Type, span: Span) -> Type { + pub(super) fn check_cast(&mut self, from: &Type, to: &Type, span: Span) -> Type { match from.follow_bindings() { Type::Integer(..) | Type::FieldElement @@ -788,8 +788,13 @@ impl<'context> Elaborator<'context> { | Type::Bool => (), Type::TypeVariable(_, _) => { - self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); - return Type::Error; + // NOTE: in reality the expected type can also include bool, but for the compiler's simplicity + // we only allow integer types. If a bool is in `from` it will need an explicit type annotation. + let expected = Type::polymorphic_integer_or_field(self.interner); + self.unify(from, &expected, || TypeCheckError::InvalidCast { + from: from.clone(), + span, + }); } Type::Error => return Type::Error, from => { diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index c63a6961da5..e6506a5fde6 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -383,8 +383,8 @@ impl<'interner> Monomorphizer<'interner> { self.parameter(field, &typ, new_params)?; } } - HirPattern::Struct(_, fields, _) => { - let struct_field_types = unwrap_struct_type(typ); + HirPattern::Struct(_, fields, location) => { + let struct_field_types = unwrap_struct_type(typ, *location)?; assert_eq!(struct_field_types.len(), fields.len()); let mut fields = @@ -663,8 +663,10 @@ impl<'interner> Monomorphizer<'interner> { constructor: HirConstructorExpression, id: node_interner::ExprId, ) -> Result { + let location = self.interner.expr_location(&id); + let typ = self.interner.id_type(id); - let field_types = unwrap_struct_type(&typ); + let field_types = unwrap_struct_type(&typ, location)?; let field_type_map = btree_map(&field_types, |x| x.clone()); @@ -740,8 +742,8 @@ impl<'interner> Monomorphizer<'interner> { let fields = unwrap_tuple_type(typ); self.unpack_tuple_pattern(value, patterns.into_iter().zip(fields)) } - HirPattern::Struct(_, patterns, _) => { - let fields = unwrap_struct_type(typ); + HirPattern::Struct(_, patterns, location) => { + let fields = unwrap_struct_type(typ, location)?; assert_eq!(patterns.len(), fields.len()); let mut patterns = @@ -975,12 +977,24 @@ impl<'interner> Monomorphizer<'interner> { } HirType::Struct(def, args) => { + // Not all generic arguments may be used in a struct's fields so we have to check + // the arguments as well as the fields in case any need to be defaulted or are unbound. + for arg in args { + Self::check_type(arg, location)?; + } + let fields = def.borrow().get_fields(args); let fields = try_vecmap(fields, |(_, field)| Self::convert_type(&field, location))?; ast::Type::Tuple(fields) } HirType::Alias(def, args) => { + // Similar to the struct case above: generics of an alias might not end up being + // used in the type that is aliased. + for arg in args { + Self::check_type(arg, location)?; + } + Self::convert_type(&def.borrow().get_type(args), location)? } @@ -1019,6 +1033,83 @@ impl<'interner> Monomorphizer<'interner> { }) } + // Similar to `convert_type` but returns an error if any type variable can't be defaulted. + fn check_type(typ: &HirType, location: Location) -> Result<(), MonomorphizationError> { + match typ { + HirType::FieldElement + | HirType::Integer(..) + | HirType::Bool + | HirType::String(..) + | HirType::Unit + | HirType::TraitAsType(..) + | HirType::Forall(_, _) + | HirType::Constant(_) + | HirType::Error + | HirType::Quoted(_) => Ok(()), + HirType::FmtString(_size, fields) => Self::check_type(fields.as_ref(), location), + HirType::Array(_length, element) => Self::check_type(element.as_ref(), location), + HirType::Slice(element) => Self::check_type(element.as_ref(), location), + HirType::NamedGeneric(binding, _, _) => { + if let TypeBinding::Bound(binding) = &*binding.borrow() { + return Self::check_type(binding, location); + } + + Ok(()) + } + + HirType::TypeVariable(binding, kind) => { + if let TypeBinding::Bound(binding) = &*binding.borrow() { + return Self::check_type(binding, location); + } + + // Default any remaining unbound type variables. + // This should only happen if the variable in question is unused + // and within a larger generic type. + let default = match kind.default_type() { + Some(typ) => typ, + None => return Err(MonomorphizationError::TypeAnnotationsNeeded { location }), + }; + + Self::check_type(&default, location) + } + + HirType::Struct(_def, args) => { + for arg in args { + Self::check_type(arg, location)?; + } + + Ok(()) + } + + HirType::Alias(_def, args) => { + for arg in args { + Self::check_type(arg, location)?; + } + + Ok(()) + } + + HirType::Tuple(fields) => { + for field in fields { + Self::check_type(field, location)?; + } + + Ok(()) + } + + HirType::Function(args, ret, env) => { + for arg in args { + Self::check_type(arg, location)?; + } + + Self::check_type(ret, location)?; + Self::check_type(env, location) + } + + HirType::MutableReference(element) => Self::check_type(element, location), + } + } + fn is_function_closure(&self, t: ast::Type) -> bool { if self.is_function_closure_type(&t) { true @@ -1753,9 +1844,19 @@ fn unwrap_tuple_type(typ: &HirType) -> Vec { } } -fn unwrap_struct_type(typ: &HirType) -> Vec<(String, HirType)> { +fn unwrap_struct_type( + typ: &HirType, + location: Location, +) -> Result, MonomorphizationError> { match typ.follow_bindings() { - HirType::Struct(def, args) => def.borrow().get_fields(&args), + HirType::Struct(def, args) => { + // Some of args might not be mentioned in fields, so we need to check that they aren't unbound. + for arg in &args { + Monomorphizer::check_type(arg, location)?; + } + + Ok(def.borrow().get_fields(&args)) + } other => unreachable!("unwrap_struct_type: expected struct, found {:?}", other), } } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index a21259a4f0d..8ce430b6e48 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2729,3 +2729,120 @@ fn incorrect_generic_count_on_type_alias() { assert_eq!(actual, 1); assert_eq!(expected, 0); } + +#[test] +fn uses_self_type_for_struct_function_call() { + let src = r#" + struct S { } + + impl S { + fn one() -> Field { + 1 + } + + fn two() -> Field { + Self::one() + Self::one() + } + } + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn uses_self_type_inside_trait() { + let src = r#" + trait Foo { + fn foo() -> Self { + Self::bar() + } + + fn bar() -> Self; + } + + impl Foo for Field { + fn bar() -> Self { + 1 + } + } + + fn main() { + let _: Field = Foo::foo(); + } + "#; + assert_no_errors(src); +} + +#[test] +fn uses_self_type_in_trait_where_clause() { + let src = r#" + trait Trait { + fn trait_func() -> bool; + } + + trait Foo where Self: Trait { + fn foo(self) -> bool { + self.trait_func() + } + } + + struct Bar { + + } + + impl Foo for Bar { + + } + + fn main() {} + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { method_name, .. }) = + &errors[0].0 + else { + panic!("Expected an unresolved method call error, got {:?}", errors[0].0); + }; + + assert_eq!(method_name, "trait_func"); +} + +#[test] +fn do_not_eagerly_error_on_cast_on_type_variable() { + let src = r#" + pub fn foo(x: T, f: fn(T) -> U) -> U { + f(x) + } + + fn main() { + let x: u8 = 1; + let _: Field = foo(x, |x| x as Field); + } + "#; + assert_no_errors(src); +} + +#[test] +fn error_on_cast_over_type_variable() { + let src = r#" + pub fn foo(x: T, f: fn(T) -> U) -> U { + f(x) + } + + fn main() { + let x = "a"; + let _: Field = foo(x, |x| x as Field); + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + assert!(matches!( + errors[0].0, + CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) + )); +} diff --git a/test_programs/compile_failure/type_annotation_needed_on_struct_constructor/Nargo.toml b/test_programs/compile_failure/type_annotation_needed_on_struct_constructor/Nargo.toml new file mode 100644 index 00000000000..ac7933fa250 --- /dev/null +++ b/test_programs/compile_failure/type_annotation_needed_on_struct_constructor/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "type_annotation_needed_on_struct_constructor" +type = "bin" +authors = [""] +compiler_version = ">=0.31.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_failure/type_annotation_needed_on_struct_constructor/src/main.nr b/test_programs/compile_failure/type_annotation_needed_on_struct_constructor/src/main.nr new file mode 100644 index 00000000000..5207210dfbf --- /dev/null +++ b/test_programs/compile_failure/type_annotation_needed_on_struct_constructor/src/main.nr @@ -0,0 +1,6 @@ +struct Foo { +} + +fn main() { + let foo = Foo {}; +} diff --git a/test_programs/compile_failure/type_annotation_needed_on_struct_new/Nargo.toml b/test_programs/compile_failure/type_annotation_needed_on_struct_new/Nargo.toml new file mode 100644 index 00000000000..cb53d2924f4 --- /dev/null +++ b/test_programs/compile_failure/type_annotation_needed_on_struct_new/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "type_annotation_needed_on_struct_new" +type = "bin" +authors = [""] +compiler_version = ">=0.31.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_failure/type_annotation_needed_on_struct_new/src/main.nr b/test_programs/compile_failure/type_annotation_needed_on_struct_new/src/main.nr new file mode 100644 index 00000000000..f740dfa6d37 --- /dev/null +++ b/test_programs/compile_failure/type_annotation_needed_on_struct_new/src/main.nr @@ -0,0 +1,12 @@ +struct Foo { +} + +impl Foo { + fn new() -> Foo { + Foo {} + } +} + +fn main() { + let foo = Foo::new(); +}