diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index f8b7e7ee4e1..375df4e532a 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -215,7 +215,7 @@ impl<'context> Elaborator<'context> { UnresolvedTypeExpression::Constant(0, span) }); - let length = self.convert_expression_type(length, span); + let length = self.convert_expression_type(length, &Kind::u32(), span); let (repeated_element, elem_type) = self.elaborate_expression(*repeated_element); let length_clone = length.clone(); diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index bb1161650c3..816535ba564 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -77,22 +77,22 @@ impl<'context> Elaborator<'context> { FieldElement => Type::FieldElement, Array(size, elem) => { let elem = Box::new(self.resolve_type_inner(*elem, kind)); - let size = self.convert_expression_type(size, span); + let size = self.convert_expression_type(size, &Kind::u32(), span); Type::Array(Box::new(size), elem) } Slice(elem) => { let elem = Box::new(self.resolve_type_inner(*elem, kind)); Type::Slice(elem) } - Expression(expr) => self.convert_expression_type(expr, span), + Expression(expr) => self.convert_expression_type(expr, kind, span), Integer(sign, bits) => Type::Integer(sign, bits), Bool => Type::Bool, String(size) => { - let resolved_size = self.convert_expression_type(size, span); + let resolved_size = self.convert_expression_type(size, &Kind::u32(), span); Type::String(Box::new(resolved_size)) } FormatString(size, fields) => { - let resolved_size = self.convert_expression_type(size, span); + let resolved_size = self.convert_expression_type(size, &Kind::u32(), span); let fields = self.resolve_type_inner(*fields, kind); Type::FmtString(Box::new(resolved_size), Box::new(fields)) } @@ -426,37 +426,25 @@ impl<'context> Elaborator<'context> { pub(super) fn convert_expression_type( &mut self, length: UnresolvedTypeExpression, + expected_kind: &Kind, span: Span, ) -> Type { match length { UnresolvedTypeExpression::Variable(path) => { - let resolved_length = - self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { - self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); - Type::Constant(0, Kind::u32()) - }); - - if let Type::NamedGeneric(ref _type_var, ref _name, ref kind) = resolved_length { - if !kind.is_numeric() { - self.push_err(TypeCheckError::TypeKindMismatch { - expected_kind: Kind::u32().to_string(), - expr_kind: kind.to_string(), - expr_span: span, - }); - return Type::Error; - } - } - resolved_length + let typ = self.resolve_named_type(path, GenericTypeArgs::default()); + self.check_kind(typ, expected_kind, span) + } + UnresolvedTypeExpression::Constant(int, _span) => { + Type::Constant(int, expected_kind.clone()) } - UnresolvedTypeExpression::Constant(int, _span) => Type::Constant(int, Kind::u32()), UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); - let lhs = self.convert_expression_type(*lhs, lhs_span); - let rhs = self.convert_expression_type(*rhs, rhs_span); + let lhs = self.convert_expression_type(*lhs, expected_kind, lhs_span); + let rhs = self.convert_expression_type(*rhs, expected_kind, rhs_span); match (lhs, rhs) { (Type::Constant(lhs, lhs_kind), Type::Constant(rhs, rhs_kind)) => { - if lhs_kind != rhs_kind { + if !lhs_kind.unifies(&rhs_kind) { self.push_err(TypeCheckError::TypeKindMismatch { expected_kind: lhs_kind.to_string(), expr_kind: rhs_kind.to_string(), @@ -474,10 +462,27 @@ impl<'context> Elaborator<'context> { (lhs, rhs) => Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)).canonicalize(), } } - UnresolvedTypeExpression::AsTraitPath(path) => self.resolve_as_trait_path(*path), + UnresolvedTypeExpression::AsTraitPath(path) => { + let typ = self.resolve_as_trait_path(*path); + self.check_kind(typ, expected_kind, span) + } } } + fn check_kind(&mut self, typ: Type, expected_kind: &Kind, span: Span) -> Type { + if let Some(kind) = typ.kind() { + if !kind.unifies(expected_kind) { + self.push_err(TypeCheckError::TypeKindMismatch { + expected_kind: expected_kind.to_string(), + expr_kind: kind.to_string(), + expr_span: span, + }); + return Type::Error; + } + } + typ + } + fn resolve_as_trait_path(&mut self, path: AsTraitPath) -> Type { let span = path.trait_path.span; let Some(trait_id) = self.resolve_trait_by_path(path.trait_path.clone()) else { diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index a24ee2635be..7bb9fb83e70 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -151,12 +151,36 @@ impl Kind { } pub(crate) fn matches_opt(&self, other: Option) -> bool { - other.as_ref().map_or(true, |other_kind| self == other_kind) + other.as_ref().map_or(true, |other_kind| self.unifies(other_kind)) } pub(crate) fn u32() -> Self { Self::Numeric(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo))) } + + /// Unifies this kind with the other. Returns true on success + pub(crate) fn unifies(&self, other: &Kind) -> bool { + match (self, other) { + (Kind::Normal, Kind::Normal) => true, + (Kind::Numeric(lhs), Kind::Numeric(rhs)) => { + let mut bindings = TypeBindings::new(); + let unifies = lhs.try_unify(rhs, &mut bindings).is_ok(); + if unifies { + Type::apply_type_bindings(bindings); + } + unifies + } + _ => false, + } + } + + pub(crate) fn unify(&self, other: &Kind) -> Result<(), UnificationError> { + if self.unifies(other) { + Ok(()) + } else { + Err(UnificationError) + } + } } impl std::fmt::Display for Kind { @@ -1465,13 +1489,13 @@ impl Type { } } - (NamedGeneric(binding_a, name_a, _), NamedGeneric(binding_b, name_b, _)) => { + (NamedGeneric(binding_a, name_a, kind_a), NamedGeneric(binding_b, name_b, kind_b)) => { // Bound NamedGenerics are caught by the check above assert!(binding_a.borrow().is_unbound()); assert!(binding_b.borrow().is_unbound()); if name_a == name_b { - Ok(()) + kind_a.unify(kind_b) } else { Err(UnificationError) } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 44e1cce5bf8..0b8773eea77 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -1616,25 +1616,30 @@ fn numeric_generic_binary_operation_type_mismatch() { #[test] fn bool_generic_as_loop_bound() { let src = r#" - pub fn read() { - let mut fields = [0; N]; - for i in 0..N { + pub fn read() { // error here + let mut fields = [0; N]; // error here + for i in 0..N { // error here fields[i] = i + 1; } assert(fields[0] == 1); } "#; let errors = get_program_errors(src); - assert_eq!(errors.len(), 2); + assert_eq!(errors.len(), 3); assert!(matches!( errors[0].0, CompilationError::ResolverError(ResolverError::UnsupportedNumericGenericType { .. }), )); + assert!(matches!( + errors[1].0, + CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. }), + )); + let CompilationError::TypeError(TypeCheckError::TypeMismatch { expected_typ, expr_typ, .. - }) = &errors[1].0 + }) = &errors[2].0 else { panic!("Got an error other than a type mismatch"); }; @@ -1646,7 +1651,7 @@ fn bool_generic_as_loop_bound() { #[test] fn numeric_generic_in_function_signature() { let src = r#" - pub fn foo(arr: [Field; N]) -> [Field; N] { arr } + pub fn foo(arr: [Field; N]) -> [Field; N] { arr } "#; assert_no_errors(src); } @@ -3644,3 +3649,54 @@ fn does_not_crash_when_passing_mutable_undefined_variable() { assert_eq!(name, "undefined"); } + +#[test] +fn infer_globals_to_u32_from_type_use() { + let src = r#" + global ARRAY_LEN = 3; + global STR_LEN = 2; + global FMT_STR_LEN = 2; + + fn main() { + let _a: [u32; ARRAY_LEN] = [1, 2, 3]; + let _b: str = "hi"; + let _c: fmtstr = f"hi"; + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 0); +} + +#[test] +fn non_u32_in_array_length() { + let src = r#" + global ARRAY_LEN: u8 = 3; + + fn main() { + let _a: [u32; ARRAY_LEN] = [1, 2, 3]; + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + assert!(matches!( + errors[0].0, + CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. }) + )); +} + +#[test] +fn use_non_u32_generic_in_struct() { + let src = r#" + struct S {} + + fn main() { + let _: S<3> = S {}; + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 0); +} diff --git a/test_programs/compile_success_empty/numeric_generics_explicit/src/main.nr b/test_programs/compile_success_empty/numeric_generics_explicit/src/main.nr index 7c4f7761ff6..5c618e9db36 100644 --- a/test_programs/compile_success_empty/numeric_generics_explicit/src/main.nr +++ b/test_programs/compile_success_empty/numeric_generics_explicit/src/main.nr @@ -28,7 +28,7 @@ fn main() { } // Used in the signature of a function -fn id(x: [Field; I]) -> [Field; I] { +fn id(x: [Field; I]) -> [Field; I] { x }