diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index b1e559895b8..f0fc482cae0 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -20,8 +20,8 @@ use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TraitId, Type use crate::parser::{ParserError, SortedModule}; use crate::{ ExpressionKind, Ident, LetStatement, Literal, NoirFunction, NoirStruct, NoirTrait, - NoirTypeAlias, Path, PathKind, Type, UnresolvedGenerics, UnresolvedTraitConstraint, - UnresolvedType, + NoirTypeAlias, Path, PathKind, Type, TypeBindings, UnresolvedGenerics, + UnresolvedTraitConstraint, UnresolvedType, }; use fm::FileId; use iter_extended::vecmap; @@ -90,6 +90,7 @@ pub struct UnresolvedTraitImpl { pub file_id: FileId, pub module_id: LocalModuleId, pub trait_id: Option, + pub trait_generics: Vec, pub trait_path: Path, pub object_type: UnresolvedType, pub methods: UnresolvedFunctions, @@ -456,19 +457,44 @@ fn type_check_functions( } // TODO(vitkov): Move this out of here and into type_check +#[allow(clippy::too_many_arguments)] pub(crate) fn check_methods_signatures( resolver: &mut Resolver, impl_methods: &Vec<(FileId, FuncId)>, trait_id: TraitId, + trait_name_span: Span, + // These are the generics on the trait itself from the impl. + // E.g. in `impl Foo for Bar`, this is `vec![A, B]`. + trait_generics: Vec, trait_impl_generic_count: usize, + file_id: FileId, errors: &mut Vec<(CompilationError, FileId)>, ) { let self_type = resolver.get_self_type().expect("trait impl must have a Self type").clone(); + let trait_generics = vecmap(trait_generics, |typ| resolver.resolve_type(typ)); // Temporarily bind the trait's Self type to self_type so we can type check let the_trait = resolver.interner.get_trait_mut(trait_id); the_trait.self_type_typevar.bind(self_type); + if trait_generics.len() != the_trait.generics.len() { + let error = DefCollectorErrorKind::MismatchGenericCount { + actual_generic_count: trait_generics.len(), + expected_generic_count: the_trait.generics.len(), + // Preferring to use 'here' over a more precise term like 'this reference' + // to try to make the error easier to understand for newer users. + location: "here it", + origin: the_trait.name.to_string(), + span: trait_name_span, + }; + errors.push((error.into(), file_id)); + } + + // We also need to bind the traits generics to the trait's generics on the impl + for ((_, generic), binding) in the_trait.generics.iter().zip(trait_generics) { + generic.bind(binding); + } + // Temporarily take the trait's methods so we can use both them and a mutable reference // to the interner within the loop. let trait_methods = std::mem::take(&mut the_trait.methods); @@ -482,49 +508,44 @@ pub(crate) fn check_methods_signatures( if let Some(trait_method) = trait_methods.iter().find(|method| method.name.0.contents == func_name) { - let mut typecheck_errors = Vec::new(); let impl_method = resolver.interner.function_meta(func_id); - let (impl_function_type, _) = impl_method.typ.instantiate(resolver.interner); - let impl_method_generic_count = impl_method.typ.generic_count() - trait_impl_generic_count; // We subtract 1 here to account for the implicit generic `Self` type that is on all // traits (and thus trait methods) but is not required (or allowed) for users to specify. - let trait_method_generic_count = trait_method.generics().len() - 1; + let the_trait = resolver.interner.get_trait(trait_id); + let trait_method_generic_count = + trait_method.generics().len() - 1 - the_trait.generics.len(); if impl_method_generic_count != trait_method_generic_count { - let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics { - impl_method_generic_count, - trait_method_generic_count, - trait_name: resolver.interner.get_trait(trait_id).name.to_string(), - method_name: func_name.to_string(), + let trait_name = resolver.interner.get_trait(trait_id).name.clone(); + + let error = DefCollectorErrorKind::MismatchGenericCount { + actual_generic_count: impl_method_generic_count, + expected_generic_count: trait_method_generic_count, + origin: format!("{}::{}", trait_name, func_name), + location: "this method", span: impl_method.location.span, }; errors.push((error.into(), *file_id)); } - if let Type::Function(impl_params, _, _) = impl_function_type { - if trait_method.arguments().len() == impl_params.len() { - // Check the parameters of the impl method against the parameters of the trait method - let args = trait_method.arguments().iter(); - let args_and_params = args.zip(&impl_params).zip(&impl_method.parameters.0); - - for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in - args_and_params.enumerate() - { - expected.unify(actual, &mut typecheck_errors, || { - TypeCheckError::TraitMethodParameterTypeMismatch { - method_name: func_name.to_string(), - expected_typ: expected.to_string(), - actual_typ: actual.to_string(), - parameter_span: hir_pattern.span(), - parameter_index: parameter_index + 1, - } - }); - } - } else { + // This instantiation is technically not needed. We could bind each generic in the + // trait function to the impl's corresponding generic but to do so we'd have to rely + // on the trait function's generics being first in the generic list, since the same + // list also contains the generic `Self` variable, and any generics on the trait itself. + // + // Instantiating the impl method's generics here instead is a bit less precise but + // doesn't rely on any orderings that may be changed. + let impl_function_type = impl_method.typ.instantiate(resolver.interner).0; + + let mut bindings = TypeBindings::new(); + let mut typecheck_errors = Vec::new(); + + if let Type::Function(impl_params, impl_return, _) = impl_function_type.as_monotype() { + if trait_method.arguments().len() != impl_params.len() { let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters { actual_num_parameters: impl_method.parameters.0.len(), expected_num_parameters: trait_method.arguments().len(), @@ -534,28 +555,51 @@ pub(crate) fn check_methods_signatures( }; errors.push((error.into(), *file_id)); } - } - // Check that impl method return type matches trait return type: - let resolved_return_type = - resolver.resolve_type(impl_method.return_type.get_type().into_owned()); + // Check the parameters of the impl method against the parameters of the trait method + let args = trait_method.arguments().iter(); + let args_and_params = args.zip(impl_params).zip(&impl_method.parameters.0); - // TODO: This is not right since it may bind generic return types - trait_method.return_type().unify(&resolved_return_type, &mut typecheck_errors, || { - let impl_method = resolver.interner.function_meta(func_id); - let ret_type_span = impl_method.return_type.get_type().span; - let expr_span = ret_type_span.expect("return type must always have a span"); + for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in + args_and_params.enumerate() + { + if expected.try_unify(actual, &mut bindings).is_err() { + typecheck_errors.push(TypeCheckError::TraitMethodParameterTypeMismatch { + method_name: func_name.to_string(), + expected_typ: expected.to_string(), + actual_typ: actual.to_string(), + parameter_span: hir_pattern.span(), + parameter_index: parameter_index + 1, + }); + } + } - let expected_typ = trait_method.return_type().to_string(); - let expr_typ = impl_method.return_type().to_string(); - TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span } - }); + if trait_method.return_type().try_unify(impl_return, &mut bindings).is_err() { + let impl_method = resolver.interner.function_meta(func_id); + let ret_type_span = impl_method.return_type.get_type().span; + let expr_span = ret_type_span.expect("return type must always have a span"); + + let expected_typ = trait_method.return_type().to_string(); + let expr_typ = impl_method.return_type().to_string(); + let error = TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span }; + typecheck_errors.push(error); + } + } else { + unreachable!( + "impl_function_type is not a function type, it is: {impl_function_type}" + ); + } errors.extend(typecheck_errors.iter().cloned().map(|e| (e.into(), *file_id))); } } + // Now unbind `Self` and the trait's generics let the_trait = resolver.interner.get_trait_mut(trait_id); the_trait.set_methods(trait_methods); the_trait.self_type_typevar.unbind(the_trait.self_type_typevar_id); + + for (old_id, generic) in &the_trait.generics { + generic.unbind(*old_id); + } } diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 04791b11b2a..0fd5d415724 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -72,7 +72,7 @@ pub fn collect_defs( errors.extend(collector.collect_functions(context, ast.functions, crate_id)); - errors.extend(collector.collect_trait_impls(context, ast.trait_impls, crate_id)); + collector.collect_trait_impls(context, ast.trait_impls, crate_id); collector.collect_impls(context, ast.impls, crate_id); @@ -144,7 +144,7 @@ impl<'a> ModCollector<'a> { context: &mut Context, impls: Vec, krate: CrateId, - ) -> Vec<(CompilationError, fm::FileId)> { + ) { for trait_impl in impls { let trait_name = trait_impl.trait_name.clone(); @@ -168,11 +168,11 @@ impl<'a> ModCollector<'a> { generics: trait_impl.impl_generics, where_clause: trait_impl.where_clause, trait_id: None, // will be filled later + trait_generics: trait_impl.trait_generics, }; self.def_collector.collected_traits_impls.push(unresolved_trait_impl); } - vec![] } fn collect_trait_impl_function_overrides( diff --git a/compiler/noirc_frontend/src/hir/def_collector/errors.rs b/compiler/noirc_frontend/src/hir/def_collector/errors.rs index 2b91c4b36c5..de45be48c4e 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/errors.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/errors.rs @@ -49,12 +49,12 @@ pub enum DefCollectorErrorKind { method_name: String, span: Span, }, - #[error("Mismatched number of generics in impl method")] - MismatchTraitImplementationNumGenerics { - impl_method_generic_count: usize, - trait_method_generic_count: usize, - trait_name: String, - method_name: String, + #[error("Mismatched number of generics in {location}")] + MismatchGenericCount { + actual_generic_count: usize, + expected_generic_count: usize, + location: &'static str, + origin: String, span: Span, }, #[error("Method is not defined in trait")] @@ -188,16 +188,16 @@ impl From for Diagnostic { "`{trait_name}::{method_name}` expects {expected_num_parameters} parameter{plural}, but this method has {actual_num_parameters}"); Diagnostic::simple_error(primary_message, "".to_string(), span) } - DefCollectorErrorKind::MismatchTraitImplementationNumGenerics { - impl_method_generic_count, - trait_method_generic_count, - trait_name, - method_name, + DefCollectorErrorKind::MismatchGenericCount { + actual_generic_count, + expected_generic_count, + location, + origin, span, } => { - let plural = if trait_method_generic_count == 1 { "" } else { "s" }; + let plural = if expected_generic_count == 1 { "" } else { "s" }; let primary_message = format!( - "`{trait_name}::{method_name}` expects {trait_method_generic_count} generic{plural}, but this method has {impl_method_generic_count}"); + "`{origin}` expects {expected_generic_count} generic{plural}, but {location} has {actual_generic_count}"); Diagnostic::simple_error(primary_message, "".to_string(), span) } DefCollectorErrorKind::MethodNotInTrait { trait_name, impl_method } => { diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index c2f787313c6..e95f533f090 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -60,8 +60,8 @@ pub enum ResolverError { NonStructWithGenerics { span: Span }, #[error("Cannot apply generics on Self type")] GenericsOnSelfType { span: Span }, - #[error("Incorrect amount of arguments to generic type constructor")] - IncorrectGenericCount { span: Span, struct_type: String, actual: usize, expected: usize }, + #[error("Incorrect amount of arguments to {item_name}")] + IncorrectGenericCount { span: Span, item_name: String, actual: usize, expected: usize }, #[error("{0}")] ParserError(Box), #[error("Function is not defined in a contract yet sets its contract visibility")] @@ -259,12 +259,12 @@ impl From for Diagnostic { "Use an explicit type name or apply the generics at the start of the impl instead".into(), span, ), - ResolverError::IncorrectGenericCount { span, struct_type, actual, expected } => { + ResolverError::IncorrectGenericCount { span, item_name, actual, expected } => { let expected_plural = if expected == 1 { "" } else { "s" }; let actual_plural = if actual == 1 { "is" } else { "are" }; Diagnostic::simple_error( - format!("The struct type {struct_type} has {expected} generic{expected_plural} but {actual} {actual_plural} given here"), + format!("`{item_name}` has {expected} generic argument{expected_plural} but {actual} {actual_plural} given here"), "Incorrect number of generic arguments".into(), span, ) diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index b9770f34e1e..b78c6a9e86d 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -15,7 +15,7 @@ use crate::hir_def::expr::{ HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar, HirCastExpression, HirConstructorExpression, HirExpression, HirIdent, HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, - HirMethodCallExpression, HirPrefixExpression, + HirMethodCallExpression, HirPrefixExpression, ImplKind, }; use crate::hir_def::traits::{Trait, TraitConstraint}; @@ -29,7 +29,7 @@ use crate::hir::def_map::{LocalModuleId, ModuleDefId, TryFromModuleDefId, MAIN_F use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern}; use crate::node_interner::{ DefinitionId, DefinitionKind, ExprId, FuncId, NodeInterner, StmtId, StructId, TraitId, - TraitImplId, TraitImplKind, + TraitImplId, TraitMethodId, }; use crate::{ hir::{def_map::CrateDefMap, resolution::path_resolver::PathResolver}, @@ -40,8 +40,8 @@ use crate::{ ArrayLiteral, ContractFunctionType, Distinctness, ForRange, FunctionDefinition, FunctionReturnType, FunctionVisibility, Generics, LValue, NoirStruct, NoirTypeAlias, Param, Path, PathKind, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable, - UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, - UnresolvedTypeExpression, Visibility, ERROR_IDENT, + TypeVariableKind, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, + UnresolvedTypeData, UnresolvedTypeExpression, Visibility, ERROR_IDENT, }; use fm::FileId; use iter_extended::vecmap; @@ -267,7 +267,7 @@ impl<'a> Resolver<'a> { let has_underscore_prefix = variable_name.starts_with('_'); // XXX: This is used for development mode, and will be removed metadata.warn_if_unused && metadata.num_times_used == 0 && !has_underscore_prefix }); - unused_vars.extend(unused_variables.map(|(_, meta)| meta.ident)); + unused_vars.extend(unused_variables.map(|(_, meta)| meta.ident.clone())); } /// Run the given function in a new scope. @@ -304,8 +304,9 @@ impl<'a> Resolver<'a> { let location = Location::new(name.span(), self.file); let id = self.interner.push_definition(name.0.contents.clone(), mutable, definition, location); - let ident = HirIdent { location, id }; - let resolver_meta = ResolverMeta { num_times_used: 0, ident, warn_if_unused }; + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused }; let scope = self.scopes.get_mut_scope(); let old_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); @@ -325,8 +326,6 @@ impl<'a> Resolver<'a> { fn add_global_variable_decl(&mut self, name: Ident, definition: DefinitionKind) -> HirIdent { let scope = self.scopes.get_mut_scope(); - let ident; - let resolver_meta; // This check is necessary to maintain the same definition ids in the interner. Currently, each function uses a new resolver that has its own ScopeForest and thus global scope. // We must first check whether an existing definition ID has been inserted as otherwise there will be multiple definitions for the same global statement. @@ -341,17 +340,20 @@ impl<'a> Resolver<'a> { } } - if let Some(id) = stmt_id { + let (ident, resolver_meta) = if let Some(id) = stmt_id { let hir_let_stmt = self.interner.let_statement(&id); - ident = hir_let_stmt.ident(); - resolver_meta = ResolverMeta { num_times_used: 0, ident, warn_if_unused: true }; + let ident = hir_let_stmt.ident(); + let resolver_meta = ResolverMeta { num_times_used: 0, ident, warn_if_unused: true }; + (hir_let_stmt.ident(), resolver_meta) } else { let location = Location::new(name.span(), self.file); let id = self.interner.push_definition(name.0.contents.clone(), false, definition, location); - ident = HirIdent { location, id }; - resolver_meta = ResolverMeta { num_times_used: 0, ident, warn_if_unused: true }; - } + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused: true }; + (ident, resolver_meta) + }; let old_global_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); if let Some(old_global_value) = old_global_value { @@ -376,7 +378,7 @@ impl<'a> Resolver<'a> { self.push_err(error); let id = DefinitionId::dummy_id(); let location = Location::new(name.span(), self.file); - (HirIdent { location, id }, 0) + (HirIdent::non_trait_method(id, location), 0) }) } @@ -389,7 +391,7 @@ impl<'a> Resolver<'a> { if let Some((variable_found, scope)) = variable { variable_found.num_times_used += 1; let id = variable_found.ident.id; - Ok((HirIdent { location, id }, scope)) + Ok((HirIdent::non_trait_method(id, location), scope)) } else { Err(ResolverError::VariableNotDeclared { name: name.0.contents.clone(), @@ -419,8 +421,27 @@ impl<'a> Resolver<'a> { constraint: UnresolvedTraitConstraint, ) -> Option { let typ = self.resolve_type(constraint.typ); - let trait_id = self.lookup_trait_or_error(constraint.trait_bound.trait_path)?.id; - Some(TraitConstraint { typ, trait_id }) + let trait_generics = + vecmap(constraint.trait_bound.trait_generics, |typ| self.resolve_type(typ)); + + let span = constraint.trait_bound.trait_path.span(); + let the_trait = self.lookup_trait_or_error(constraint.trait_bound.trait_path)?; + let trait_id = the_trait.id; + + let expected_generics = the_trait.generics.len(); + let actual_generics = trait_generics.len(); + + if actual_generics != expected_generics { + let item_name = the_trait.name.to_string(); + self.push_err(ResolverError::IncorrectGenericCount { + span, + item_name, + actual: actual_generics, + expected: expected_generics, + }); + } + + Some(TraitConstraint { typ, trait_id, trait_generics }) } /// Translates an UnresolvedType into a Type and appends any @@ -564,11 +585,13 @@ impl<'a> Resolver<'a> { fn resolve_trait_as_type( &mut self, path: Path, - _args: Vec, - _new_variables: &mut Generics, + args: Vec, + new_variables: &mut Generics, ) -> Type { + let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + if let Some(t) = self.lookup_trait_or_error(path) { - Type::TraitAsType(t.id, Rc::new(t.name.to_string())) + Type::TraitAsType(t.id, Rc::new(t.name.to_string()), args) } else { Type::Error } @@ -584,7 +607,7 @@ impl<'a> Resolver<'a> { if args.len() != expected_count { self.errors.push(ResolverError::IncorrectGenericCount { span, - struct_type: type_name(), + item_name: type_name(), actual: args.len(), expected: expected_count, }); @@ -667,17 +690,17 @@ impl<'a> Resolver<'a> { Some(Ok(found)) => return found, // Try to look it up as a global, but still issue the first error if we fail Some(Err(error)) => match self.lookup_global(path) { - Ok(id) => return (HirIdent { location, id }, 0), + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), Err(_) => error, }, None => match self.lookup_global(path) { - Ok(id) => return (HirIdent { location, id }, 0), + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), Err(error) => error, }, }; self.push_err(error); let id = DefinitionId::dummy_id(); - (HirIdent { location, id }, 0) + (HirIdent::non_trait_method(id, location), 0) } /// Translates an UnresolvedType to a Type @@ -736,7 +759,6 @@ impl<'a> Resolver<'a> { let name = Rc::new(generic.0.contents.clone()); if let Some((_, _, first_span)) = self.find_generic(&name) { - let span = generic.0.span(); self.errors.push(ResolverError::DuplicateDefinition { name: generic.0.contents.clone(), first_span: *first_span, @@ -750,6 +772,32 @@ impl<'a> Resolver<'a> { }) } + /// Add the given existing generics to scope. + /// This is useful for adding the same generics to many items. E.g. apply impl generics + /// to each function in the impl or trait generics to each item in the trait. + pub fn add_existing_generics(&mut self, names: &UnresolvedGenerics, generics: &Generics) { + assert_eq!(names.len(), generics.len()); + + for (name, (_id, typevar)) in names.iter().zip(generics) { + self.add_existing_generic(&name.0.contents, name.0.span(), typevar.clone()); + } + } + + pub fn add_existing_generic(&mut self, name: &str, span: Span, typevar: TypeVariable) { + // Check for name collisions of this generic + let rc_name = Rc::new(name.to_owned()); + + if let Some((_, _, first_span)) = self.find_generic(&rc_name) { + self.errors.push(ResolverError::DuplicateDefinition { + name: name.to_owned(), + first_span: *first_span, + second_span: span, + }); + } else { + self.generics.push((rc_name, typevar, span)); + } + } + pub fn resolve_struct_fields( mut self, unresolved: NoirStruct, @@ -778,12 +826,13 @@ impl<'a> Resolver<'a> { /// there's a bunch of other places where trait constraints can pop up fn resolve_trait_constraints( &mut self, - where_clause: &Vec, + where_clause: &[UnresolvedTraitConstraint], ) -> Vec { - vecmap(where_clause, |constraint| TraitConstraint { - typ: self.resolve_type(constraint.typ.clone()), - trait_id: constraint.trait_bound.trait_id.unwrap_or_else(TraitId::dummy_id), - }) + where_clause + .iter() + .cloned() + .filter_map(|constraint| self.resolve_trait_constraint(constraint)) + .collect() } /// Extract metadata from a NoirFunction @@ -793,7 +842,7 @@ impl<'a> Resolver<'a> { fn extract_meta(&mut self, func: &NoirFunction, func_id: FuncId) -> FuncMeta { let location = Location::new(func.name_ident().span(), self.file); let id = self.interner.function_definition_id(func_id); - let name_ident = HirIdent { id, location }; + let name_ident = HirIdent::non_trait_method(id, location); let attributes = func.attributes().clone(); @@ -1115,7 +1164,7 @@ impl<'a> Resolver<'a> { match lvalue { LValue::Ident(ident) => { let ident = self.find_variable_or_default(&ident); - self.resolve_local_variable(ident.0, ident.1); + self.resolve_local_variable(ident.0.clone(), ident.1); HirLValue::Ident(ident.0, Type::Error) } @@ -1201,9 +1250,10 @@ impl<'a> Resolver<'a> { .position(|capture| capture.ident.id == hir_ident.id); if pos.is_none() { - self.lambda_stack[lambda_index] - .captures - .push(HirCapturedVar { ident: hir_ident, transitive_capture_index }); + self.lambda_stack[lambda_index].captures.push(HirCapturedVar { + ident: hir_ident.clone(), + transitive_capture_index, + }); } if lambda_index + 1 < self.lambda_stack.len() { @@ -1250,14 +1300,13 @@ impl<'a> Resolver<'a> { Literal::Unit => HirLiteral::Unit, }), ExpressionKind::Variable(path) => { - if let Some((hir_expr, object_type)) = self.resolve_trait_generic_path(&path) { - let expr_id = self.interner.push_expr(hir_expr); - self.interner.push_expr_location(expr_id, expr.span, self.file); - self.interner.select_impl_for_expression( - expr_id, - TraitImplKind::Assumed { object_type }, - ); - return expr_id; + if let Some((method, constraint, assumed)) = self.resolve_trait_generic_path(&path) + { + HirExpression::Ident(HirIdent { + location: Location::new(expr.span, self.file), + id: self.interner.trait_method_id(method), + impl_kind: ImplKind::TraitMethod(method, constraint, assumed), + }) } else { // If the Path is being used as an Expression, then it is referring to a global from a separate module // Otherwise, then it is referring to an Identifier @@ -1292,7 +1341,7 @@ impl<'a> Resolver<'a> { } DefinitionKind::Local(_) => { // only local variables can be captured by closures. - self.resolve_local_variable(hir_ident, var_scope_index); + self.resolve_local_variable(hir_ident.clone(), var_scope_index); } } } @@ -1632,7 +1681,7 @@ impl<'a> Resolver<'a> { fn resolve_trait_static_method_by_self( &mut self, path: &Path, - ) -> Option<(HirExpression, Type)> { + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { let trait_id = self.trait_id?; if path.kind == PathKind::Plain && path.segments.len() == 2 { @@ -1642,15 +1691,23 @@ impl<'a> Resolver<'a> { if name == SELF_TYPE_NAME { let the_trait = self.interner.get_trait(trait_id); let method = the_trait.find_method(method.0.contents.as_str())?; - let self_type = self.self_type.clone()?; - return Some((HirExpression::TraitMethodReference(method), self_type)); + + let constraint = TraitConstraint { + typ: self.self_type.clone()?, + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); } } None } // this resolves TraitName::some_static_method - fn resolve_trait_static_method(&mut self, path: &Path) -> Option<(HirExpression, Type)> { + fn resolve_trait_static_method( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { if path.kind == PathKind::Plain && path.segments.len() == 2 { let method = &path.segments[1]; @@ -1660,17 +1717,27 @@ impl<'a> Resolver<'a> { let the_trait = self.interner.get_trait(trait_id); let method = the_trait.find_method(method.0.contents.as_str())?; - let self_type = Type::type_variable(the_trait.self_type_typevar_id); - return Some((HirExpression::TraitMethodReference(method), self_type)); + let constraint = TraitConstraint { + typ: Type::TypeVariable( + the_trait.self_type_typevar.clone(), + TypeVariableKind::Normal, + ), + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); } None } - // this resolves a static trait method T::trait_method by iterating over the where clause + // This resolves a static trait method T::trait_method by iterating over the where clause + // + // Returns the trait method, object type, and the trait generics. + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` fn resolve_trait_method_by_named_generic( &mut self, path: &Path, - ) -> Option<(HirExpression, Type)> { + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { if path.segments.len() != 2 { return None; } @@ -1691,8 +1758,14 @@ impl<'a> Resolver<'a> { if let Some(method) = the_trait.find_method(path.segments.last().unwrap().0.contents.as_str()) { - let self_type = self.resolve_type(typ.clone()); - return Some((HirExpression::TraitMethodReference(method), self_type)); + let constraint = TraitConstraint { + trait_id, + typ: self.resolve_type(typ.clone()), + trait_generics: vecmap(trait_bound.trait_generics, |typ| { + self.resolve_type(typ) + }), + }; + return Some((method, constraint, true)); } } } @@ -1700,7 +1773,14 @@ impl<'a> Resolver<'a> { None } - fn resolve_trait_generic_path(&mut self, path: &Path) -> Option<(HirExpression, Type)> { + // Try to resolve the given trait method path. + // + // Returns the trait method, object type, and the trait generics. + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_generic_path( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { self.resolve_trait_static_method_by_self(path) .or_else(|| self.resolve_trait_static_method(path)) .or_else(|| self.resolve_trait_method_by_named_generic(path)) @@ -1765,7 +1845,8 @@ impl<'a> Resolver<'a> { let variable = scope_tree.find(ident_name); if let Some((old_value, _)) = variable { old_value.num_times_used += 1; - let expr_id = self.interner.push_expr(HirExpression::Ident(old_value.ident)); + let ident = HirExpression::Ident(old_value.ident.clone()); + let expr_id = self.interner.push_expr(ident); self.interner.push_expr_location(expr_id, call_expr_span, self.file); fmt_str_idents.push(expr_id); } else if ident_name.parse::().is_ok() { diff --git a/compiler/noirc_frontend/src/hir/resolution/traits.rs b/compiler/noirc_frontend/src/hir/resolution/traits.rs index 545a46fd8e4..f08d9c50c84 100644 --- a/compiler/noirc_frontend/src/hir/resolution/traits.rs +++ b/compiler/noirc_frontend/src/hir/resolution/traits.rs @@ -18,7 +18,7 @@ use crate::{ }, hir_def::traits::{TraitConstant, TraitFunction, TraitImpl, TraitType}, node_interner::{FuncId, NodeInterner, TraitId}, - Path, Shared, TraitItem, Type, TypeBinding, TypeVariableKind, + Generics, Path, Shared, TraitItem, Type, TypeBinding, TypeVariable, TypeVariableKind, }; use super::{ @@ -38,8 +38,14 @@ pub(crate) fn resolve_traits( for (trait_id, unresolved_trait) in &traits { context.def_interner.push_empty_trait(*trait_id, unresolved_trait); } - let mut res: Vec<(CompilationError, FileId)> = vec![]; + let mut all_errors = Vec::new(); + for (trait_id, unresolved_trait) in traits { + let generics = vecmap(&unresolved_trait.trait_def.generics, |_| { + let id = context.def_interner.next_type_variable_id(); + (id, TypeVariable::unbound(id)) + }); + // Resolve order // 1. Trait Types ( Trait constants can have a trait type, therefore types before constants) let _ = resolve_trait_types(context, crate_id, &unresolved_trait); @@ -47,10 +53,13 @@ pub(crate) fn resolve_traits( let _ = resolve_trait_constants(context, crate_id, &unresolved_trait); // 3. Trait Methods let (methods, errors) = - resolve_trait_methods(context, trait_id, crate_id, &unresolved_trait); - res.extend(errors); + resolve_trait_methods(context, trait_id, crate_id, &unresolved_trait, &generics); + + all_errors.extend(errors); + context.def_interner.update_trait(trait_id, |trait_def| { trait_def.set_methods(methods); + trait_def.generics = generics; }); // This check needs to be after the trait's methods are set since @@ -60,7 +69,7 @@ pub(crate) fn resolve_traits( context.def_interner.try_add_operator_trait(trait_id); } } - res + all_errors } fn resolve_trait_types( @@ -85,6 +94,7 @@ fn resolve_trait_methods( trait_id: TraitId, crate_id: CrateId, unresolved_trait: &UnresolvedTrait, + trait_generics: &Generics, ) -> (Vec, Vec<(CompilationError, FileId)>) { let interner = &mut context.def_interner; let def_maps = &mut context.def_maps; @@ -109,12 +119,15 @@ fn resolve_trait_methods( } = item { let the_trait = interner.get_trait(trait_id); - let self_type = - Type::TypeVariable(the_trait.self_type_typevar.clone(), TypeVariableKind::Normal); + let self_typevar = the_trait.self_type_typevar.clone(); + let self_type = Type::TypeVariable(self_typevar.clone(), TypeVariableKind::Normal); + let name_span = the_trait.name.span(); let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file); resolver.add_generics(generics); - resolver.set_self_type(Some(self_type)); + resolver.add_existing_generics(&unresolved_trait.trait_def.generics, trait_generics); + resolver.add_existing_generic("Self", name_span, self_typevar); + resolver.set_self_type(Some(self_type.clone())); let func_id = unresolved_trait.method_ids[&name.0.contents]; let (_, func_meta) = resolver.resolve_trait_function( @@ -129,16 +142,17 @@ fn resolve_trait_methods( let arguments = vecmap(parameters, |param| resolver.resolve_type(param.1.clone())); let return_type = resolver.resolve_type(return_type.get_type().into_owned()); - let mut generics = vecmap(resolver.get_generics(), |(_, type_var, _)| match &*type_var - .borrow() - { - TypeBinding::Unbound(id) => (*id, type_var.clone()), - TypeBinding::Bound(binding) => unreachable!("Trait generic was bound to {binding}"), - }); + let generics = + vecmap(resolver.get_generics(), |(_, type_var, _)| match &*type_var.borrow() { + TypeBinding::Unbound(id) => (*id, type_var.clone()), + TypeBinding::Bound(binding) => { + unreachable!("Trait generic was bound to {binding}") + } + }); // Ensure the trait is generic over the Self type as well - let the_trait = resolver.interner.get_trait(trait_id); - generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar.clone())); + // let the_trait = resolver.interner.get_trait(trait_id); + // generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar.clone())); let default_impl_list: Vec<_> = unresolved_trait .fns_with_default_impl @@ -382,9 +396,12 @@ pub(crate) fn resolve_trait_impls( let mut resolver = Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id); resolver.add_generics(&trait_impl.generics); - let self_type = resolver.resolve_type(unresolved_type.clone()); - let generics = resolver.get_generics().to_vec(); + let trait_generics = + vecmap(&trait_impl.trait_generics, |generic| resolver.resolve_type(generic.clone())); + + let self_type = resolver.resolve_type(unresolved_type.clone()); + let impl_generics = resolver.get_generics().to_vec(); let impl_id = interner.next_trait_impl_id(); let mut impl_methods = functions::resolve_function_set( @@ -394,7 +411,7 @@ pub(crate) fn resolve_trait_impls( trait_impl.methods.clone(), Some(self_type.clone()), Some(impl_id), - generics.clone(), + impl_generics.clone(), errors, ); @@ -414,7 +431,7 @@ pub(crate) fn resolve_trait_impls( let mut new_resolver = Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id); - new_resolver.set_generics(generics); + new_resolver.set_generics(impl_generics.clone()); new_resolver.set_self_type(Some(self_type.clone())); if let Some(trait_id) = maybe_trait_id { @@ -422,7 +439,10 @@ pub(crate) fn resolve_trait_impls( &mut new_resolver, &impl_methods, trait_id, + trait_impl.trait_path.span(), + trait_impl.trait_generics, trait_impl.generics.len(), + trait_impl.file_id, errors, ); @@ -432,19 +452,28 @@ pub(crate) fn resolve_trait_impls( .flat_map(|item| new_resolver.resolve_trait_constraint(item)) .collect(); + let resolver_errors = new_resolver.take_errors().into_iter(); + errors.extend(resolver_errors.map(|error| (error.into(), trait_impl.file_id))); + let resolved_trait_impl = Shared::new(TraitImpl { ident: trait_impl.trait_path.last_segment().clone(), typ: self_type.clone(), trait_id, + trait_generics: trait_generics.clone(), file: trait_impl.file_id, where_clause, methods: vecmap(&impl_methods, |(_, func_id)| *func_id), }); + let impl_generics = + vecmap(impl_generics, |(_, type_variable, _)| (type_variable.id(), type_variable)); + if let Err((prev_span, prev_file)) = interner.add_trait_implementation( self_type.clone(), trait_id, + trait_generics, impl_id, + impl_generics, resolved_trait_impl, ) { let error = DefCollectorErrorKind::OverlappingImpl { diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 50ed98a794a..b583959bfb1 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -6,7 +6,7 @@ use crate::{ hir_def::{ expr::{ self, HirArrayLiteral, HirBinaryOp, HirExpression, HirLiteral, HirMethodCallExpression, - HirMethodReference, HirPrefixExpression, + HirMethodReference, HirPrefixExpression, ImplKind, }, types::Type, }, @@ -18,7 +18,7 @@ use super::{errors::TypeCheckError, TypeChecker}; impl<'interner> TypeChecker<'interner> { fn check_if_deprecated(&mut self, expr: &ExprId) { - if let HirExpression::Ident(expr::HirIdent { location, id }) = + if let HirExpression::Ident(expr::HirIdent { location, id, impl_kind: _ }) = self.interner.expression(expr) { if let Some(DefinitionKind::Function(func_id)) = @@ -52,6 +52,10 @@ impl<'interner> TypeChecker<'interner> { // We must instantiate identifiers at every call site to replace this T with a new type // variable to handle generic functions. let t = self.interner.id_type_substitute_trait_as_type(ident.id); + + // This instantiate's a trait's generics as well which need to be set + // when the constraint below is later solved for when the function is + // finished. How to link the two? let (typ, bindings) = t.instantiate(self.interner); // Push any trait constraints required by this definition to the context @@ -59,13 +63,30 @@ impl<'interner> TypeChecker<'interner> { if let Some(definition) = self.interner.try_definition(ident.id) { if let DefinitionKind::Function(function) = definition.kind { let function = self.interner.function_meta(&function); + for mut constraint in function.trait_constraints.clone() { - constraint.typ = constraint.typ.substitute(&bindings); + constraint.apply_bindings(&bindings); self.trait_constraints.push((constraint, *expr_id)); } } } + if let ImplKind::TraitMethod(_, mut constraint, assumed) = ident.impl_kind { + constraint.apply_bindings(&bindings); + if assumed { + let trait_impl = TraitImplKind::Assumed { + object_type: constraint.typ, + trait_generics: constraint.trait_generics, + }; + self.interner.select_impl_for_expression(*expr_id, trait_impl); + } else { + // Currently only one impl can be selected per expr_id, so this + // constraint needs to be pushed after any other constraints so + // that monomorphization can resolve this trait method to the correct impl. + self.trait_constraints.push((constraint, *expr_id)); + } + } + self.interner.store_instantiation_bindings(*expr_id, bindings); typ } @@ -141,7 +162,14 @@ impl<'interner> TypeChecker<'interner> { Ok((typ, use_impl)) => { if use_impl { let id = infix_expr.trait_method_id; - self.verify_trait_constraint(&lhs_type, id.trait_id, *expr_id, span); + // Assume operators have no trait generics + self.verify_trait_constraint( + &lhs_type, + id.trait_id, + &[], + *expr_id, + span, + ); self.typecheck_operator_method(*expr_id, id, &lhs_type, span); } typ @@ -207,11 +235,12 @@ impl<'interner> TypeChecker<'interner> { .trait_id }) } - HirMethodReference::TraitMethodId(method) => Some(method.trait_id), + HirMethodReference::TraitMethodId(method, _) => Some(method.trait_id), }; let (function_id, function_call) = method_call.into_function_call( - method_ref.clone(), + &method_ref, + object_type.clone(), location, self.interner, ); @@ -220,7 +249,15 @@ impl<'interner> TypeChecker<'interner> { let ret = self.check_method_call(&function_id, method_ref, args, span); if let Some(trait_id) = trait_id { - self.verify_trait_constraint(&object_type, trait_id, function_id, span); + // Assume no trait generics were specified + // TODO: Fill in type variables + self.verify_trait_constraint( + &object_type, + trait_id, + &[], + function_id, + span, + ); } self.interner.replace_expr(expr_id, function_call); @@ -298,30 +335,6 @@ impl<'interner> TypeChecker<'interner> { Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } - HirExpression::TraitMethodReference(method) => { - let the_trait = self.interner.get_trait(method.trait_id); - let typ2 = &the_trait.methods[method.method_index].typ; - let (typ, mut bindings) = typ2.instantiate(self.interner); - - // We must also remember to apply these substitutions to the object_type - // referenced by the selected trait impl, if one has yet to be selected. - let impl_kind = self.interner.get_selected_impl_for_expression(*expr_id); - if let Some(TraitImplKind::Assumed { object_type }) = impl_kind { - let the_trait = self.interner.get_trait(method.trait_id); - let object_type = object_type.substitute(&bindings); - bindings.insert( - the_trait.self_type_typevar_id, - (the_trait.self_type_typevar.clone(), object_type.clone()), - ); - self.interner.select_impl_for_expression( - *expr_id, - TraitImplKind::Assumed { object_type }, - ); - } - - self.interner.store_instantiation_bindings(*expr_id, bindings); - typ - } }; self.interner.push_expr_type(expr_id, typ.clone()); @@ -332,11 +345,14 @@ impl<'interner> TypeChecker<'interner> { &mut self, object_type: &Type, trait_id: TraitId, + trait_generics: &[Type], function_ident_id: ExprId, span: Span, ) { - match self.interner.lookup_trait_implementation(object_type, trait_id) { - Ok(impl_kind) => self.interner.select_impl_for_expression(function_ident_id, impl_kind), + match self.interner.lookup_trait_implementation(object_type, trait_id, trait_generics) { + Ok(impl_kind) => { + self.interner.select_impl_for_expression(function_ident_id, impl_kind); + } Err(erroring_constraints) => { // Don't show any errors where try_get_trait returns None. // This can happen if a trait is used that was never declared. @@ -344,7 +360,12 @@ impl<'interner> TypeChecker<'interner> { .into_iter() .map(|constraint| { let r#trait = self.interner.try_get_trait(constraint.trait_id)?; - Some((constraint.typ, r#trait.name.to_string())) + let mut name = r#trait.name.to_string(); + if !constraint.trait_generics.is_empty() { + let generics = vecmap(&constraint.trait_generics, ToString::to_string); + name += &format!("<{}>", generics.join(", ")); + } + Some((constraint.typ, name)) }) .collect::>>(); @@ -554,7 +575,7 @@ impl<'interner> TypeChecker<'interner> { arguments: Vec<(Type, ExprId, Span)>, span: Span, ) -> Type { - let (fn_typ, param_len) = match method_ref { + let (fn_typ, param_len, generic_bindings) = match method_ref { HirMethodReference::FuncId(func_id) => { if func_id == FuncId::dummy_id() { return Type::Error; @@ -562,12 +583,22 @@ impl<'interner> TypeChecker<'interner> { let func_meta = self.interner.function_meta(&func_id); let param_len = func_meta.parameters.len(); - (func_meta.typ.clone(), param_len) + (func_meta.typ.clone(), param_len, TypeBindings::new()) } - HirMethodReference::TraitMethodId(method) => { + HirMethodReference::TraitMethodId(method, generics) => { let the_trait = self.interner.get_trait(method.trait_id); let method = &the_trait.methods[method.method_index]; - (method.typ.clone(), method.arguments().len()) + + // These are any bindings from the trait's generics itself, + // rather than an impl or method's generics. + let generic_bindings = the_trait + .generics + .iter() + .zip(generics) + .map(|((id, var), arg)| (*id, (var.clone(), arg))) + .collect(); + + (method.typ.clone(), method.arguments().len(), generic_bindings) } }; @@ -581,11 +612,12 @@ impl<'interner> TypeChecker<'interner> { }); } - let (function_type, instantiation_bindings) = fn_typ.instantiate(self.interner); + let (function_type, instantiation_bindings) = + fn_typ.instantiate_with_bindings(generic_bindings, self.interner); self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings); self.interner.push_expr_type(function_ident_id, function_type.clone()); - self.bind_function_type(function_type, arguments, span) + self.bind_function_type(function_type.clone(), arguments, span) } fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type { @@ -926,7 +958,10 @@ impl<'interner> TypeChecker<'interner> { trait_id: constraint.trait_id, method_index, }; - return Some(HirMethodReference::TraitMethodId(trait_method)); + return Some(HirMethodReference::TraitMethodId( + trait_method, + constraint.trait_generics.clone(), + )); } } } @@ -1233,15 +1268,17 @@ impl<'interner> TypeChecker<'interner> { // We must also remember to apply these substitutions to the object_type // referenced by the selected trait impl, if one has yet to be selected. let impl_kind = self.interner.get_selected_impl_for_expression(expr_id); - if let Some(TraitImplKind::Assumed { object_type }) = impl_kind { + if let Some(TraitImplKind::Assumed { object_type, trait_generics }) = impl_kind { let the_trait = self.interner.get_trait(trait_method_id.trait_id); let object_type = object_type.substitute(&bindings); bindings.insert( the_trait.self_type_typevar_id, (the_trait.self_type_typevar.clone(), object_type.clone()), ); - self.interner - .select_impl_for_expression(expr_id, TraitImplKind::Assumed { object_type }); + self.interner.select_impl_for_expression( + expr_id, + TraitImplKind::Assumed { object_type, trait_generics }, + ); } self.interner.store_instantiation_bindings(expr_id, bindings); diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index c05c233fe34..3c2a970ee84 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -61,8 +61,9 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec Vec Vec TypeChecker<'interner> { typ.follow_bindings() }; - (typ.clone(), HirLValue::Ident(*ident, typ), mutable) + (typ.clone(), HirLValue::Ident(ident.clone(), typ), mutable) } HirLValue::MemberAccess { object, field_name, .. } => { let (lhs_type, object, mut mutable) = self.check_lvalue(object, assign_span); @@ -216,8 +216,8 @@ impl<'interner> TypeChecker<'interner> { // we eventually reassign to it. let id = DefinitionId::dummy_id(); let location = Location::new(span, fm::FileId::dummy()); - let tmp_value = - HirLValue::Ident(HirIdent { location, id }, Type::Error); + let ident = HirIdent::non_trait_method(id, location); + let tmp_value = HirLValue::Ident(ident, Type::Error); let lvalue = std::mem::replace(object_ref, Box::new(tmp_value)); *object_ref = Box::new(HirLValue::Dereference { lvalue, element_type }); diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 7c04398ca88..fe1cd78b5ed 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -6,6 +6,7 @@ use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId, T use crate::{BinaryOp, BinaryOpKind, Ident, Shared, UnaryOp}; use super::stmt::HirPattern; +use super::traits::TraitConstraint; use super::types::{StructType, Type}; /// A HirExpression is the result of an Expression in the AST undergoing @@ -29,7 +30,6 @@ pub enum HirExpression { If(HirIfExpression), Tuple(Vec), Lambda(HirLambda), - TraitMethodReference(TraitMethodId), Error, } @@ -41,10 +41,45 @@ impl HirExpression { } /// Corresponds to a variable in the source code -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct HirIdent { pub location: Location, pub id: DefinitionId, + + /// If this HirIdent refers to a trait method, this field stores + /// whether the impl for this method is known or not. + pub impl_kind: ImplKind, +} + +impl HirIdent { + pub fn non_trait_method(id: DefinitionId, location: Location) -> Self { + Self { id, location, impl_kind: ImplKind::NotATraitMethod } + } +} + +#[derive(Debug, Clone)] +pub enum ImplKind { + /// This ident is not a trait method + NotATraitMethod, + + /// This ident refers to a trait method and its impl needs to be verified, + /// and eventually linked to this id. The boolean indicates whether the impl + /// is already assumed to exist - e.g. when resolving a path such as `T::default` + /// when there is a corresponding `T: Default` constraint in scope. + TraitMethod(TraitMethodId, TraitConstraint, bool), +} + +impl Eq for HirIdent {} +impl PartialEq for HirIdent { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl std::hash::Hash for HirIdent { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } } #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -162,28 +197,35 @@ pub enum HirMethodReference { /// Or a method can come from a Trait impl block, in which case /// the actual function called will depend on the instantiated type, /// which can be only known during monomorphization. - TraitMethodId(TraitMethodId), + TraitMethodId(TraitMethodId, /*trait generics:*/ Vec), } impl HirMethodCallExpression { pub fn into_function_call( mut self, - method: HirMethodReference, + method: &HirMethodReference, + object_type: Type, location: Location, interner: &mut NodeInterner, ) -> (ExprId, HirExpression) { let mut arguments = vec![self.object]; arguments.append(&mut self.arguments); - let expr = match method { + let (id, impl_kind) = match method { HirMethodReference::FuncId(func_id) => { - let id = interner.function_definition_id(func_id); - HirExpression::Ident(HirIdent { location, id }) + (interner.function_definition_id(*func_id), ImplKind::NotATraitMethod) } - HirMethodReference::TraitMethodId(method_id) => { - HirExpression::TraitMethodReference(method_id) + HirMethodReference::TraitMethodId(method_id, generics) => { + let id = interner.trait_method_id(*method_id); + let constraint = TraitConstraint { + typ: object_type, + trait_id: method_id.trait_id, + trait_generics: generics.clone(), + }; + (id, ImplKind::TraitMethod(*method_id, constraint, false)) } }; + let expr = HirExpression::Ident(HirIdent { location, id, impl_kind }); let func = interner.push_expr(expr); (func, HirExpression::Call(HirCallExpression { func, arguments, location })) } diff --git a/compiler/noirc_frontend/src/hir_def/stmt.rs b/compiler/noirc_frontend/src/hir_def/stmt.rs index 21f9b431b3a..34c9302c251 100644 --- a/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -28,8 +28,8 @@ pub struct HirLetStatement { impl HirLetStatement { pub fn ident(&self) -> HirIdent { - match self.pattern { - HirPattern::Identifier(ident) => ident, + match &self.pattern { + HirPattern::Identifier(ident) => ident.clone(), _ => panic!("can only fetch hir ident from HirPattern::Identifier"), } } diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 1d0449b6568..85c292ac5f3 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use crate::{ graph::CrateId, node_interner::{FuncId, TraitId, TraitMethodId}, - Generics, Ident, NoirFunction, Type, TypeVariable, TypeVariableId, + Generics, Ident, NoirFunction, Type, TypeBindings, TypeVariable, TypeVariableId, }; use fm::FileId; use noirc_errors::{Location, Span}; @@ -70,6 +70,7 @@ pub struct TraitImpl { pub ident: Ident, pub typ: Type, pub trait_id: TraitId, + pub trait_generics: Vec, pub file: FileId, pub methods: Vec, // methods[i] is the implementation of trait.methods[i] for Type typ @@ -84,12 +85,20 @@ pub struct TraitImpl { pub struct TraitConstraint { pub typ: Type, pub trait_id: TraitId, - // pub trait_generics: Generics, TODO + pub trait_generics: Vec, } impl TraitConstraint { - pub fn new(typ: Type, trait_id: TraitId) -> Self { - Self { typ, trait_id } + pub fn new(typ: Type, trait_id: TraitId, trait_generics: Vec) -> Self { + Self { typ, trait_id, trait_generics } + } + + pub fn apply_bindings(&mut self, type_bindings: &TypeBindings) { + self.typ = self.typ.substitute(type_bindings); + + for typ in &mut self.trait_generics { + *typ = typ.substitute(type_bindings); + } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 981d7e41b6e..69ae6e36d22 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -66,7 +66,7 @@ pub enum Type { /// `impl Trait` when used in a type position. /// These are only matched based on the TraitId. The trait name paramer is only /// used for displaying error messages using the name of the trait. - TraitAsType(TraitId, /*name:*/ Rc), + TraitAsType(TraitId, /*name:*/ Rc, /*generics:*/ Vec), /// NamedGenerics are the 'T' or 'U' in a user-defined generic function /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. @@ -423,6 +423,10 @@ impl TypeVariable { TypeVariable(id, Shared::new(TypeBinding::Unbound(id))) } + pub fn id(&self) -> TypeVariableId { + self.0 + } + /// Bind this type variable to a value. /// /// Panics if this TypeVariable is already Bound. @@ -734,8 +738,13 @@ impl std::fmt::Display for Type { write!(f, "{}<{}>", s.borrow(), args.join(", ")) } } - Type::TraitAsType(_id, name) => { - write!(f, "impl {}", name) + Type::TraitAsType(_id, name, generics) => { + write!(f, "impl {}", name)?; + if !generics.is_empty() { + let generics = vecmap(generics, ToString::to_string).join(", "); + write!(f, "<{generics}>")?; + } + Ok(()) } Type::Tuple(elements) => { let elements = vecmap(elements, ToString::to_string); @@ -1251,6 +1260,29 @@ impl Type { } } + /// Instantiate this type with the given type bindings. + /// If any type variables which would be instantiated are contained in the + /// given type bindings instead, the value from the type bindings is used. + pub fn instantiate_with_bindings( + &self, + mut bindings: TypeBindings, + interner: &NodeInterner, + ) -> (Type, TypeBindings) { + match self { + Type::Forall(typevars, typ) => { + for (id, var) in typevars { + bindings + .entry(*id) + .or_insert_with(|| (var.clone(), interner.next_type_variable())); + } + + let instantiated = typ.force_substitute(&bindings); + (instantiated, bindings) + } + other => (other.clone(), bindings), + } + } + /// Instantiate this type, replacing any type variables it is quantified /// over with fresh type variables. If this type is not a Type::Forall, /// it is unchanged. @@ -1272,82 +1304,6 @@ impl Type { } } - /// Replace each NamedGeneric (and TypeVariable) in this type with a fresh type variable - pub(crate) fn instantiate_type_variables( - &self, - interner: &NodeInterner, - ) -> (Type, TypeBindings) { - let mut type_variables = HashMap::new(); - self.find_all_unbound_type_variables(&mut type_variables); - - let substitutions = type_variables - .into_iter() - .map(|(id, type_var)| (id, (type_var, interner.next_type_variable()))) - .collect(); - - (self.substitute(&substitutions), substitutions) - } - - /// For each unbound type variable in the current type, add a type binding to the given list - /// to bind the unbound type variable to a fresh type variable. - fn find_all_unbound_type_variables( - &self, - type_variables: &mut HashMap, - ) { - match self { - Type::FieldElement - | Type::Integer(_, _) - | Type::Bool - | Type::Unit - | Type::TraitAsType(..) - | Type::Constant(_) - | Type::NotConstant - | Type::Error => (), - Type::Array(length, elem) => { - length.find_all_unbound_type_variables(type_variables); - elem.find_all_unbound_type_variables(type_variables); - } - Type::String(length) => length.find_all_unbound_type_variables(type_variables), - Type::FmtString(length, env) => { - length.find_all_unbound_type_variables(type_variables); - env.find_all_unbound_type_variables(type_variables); - } - Type::Struct(_, generics) => { - for generic in generics { - generic.find_all_unbound_type_variables(type_variables); - } - } - Type::Tuple(fields) => { - for field in fields { - field.find_all_unbound_type_variables(type_variables); - } - } - Type::Function(args, ret, env) => { - for arg in args { - arg.find_all_unbound_type_variables(type_variables); - } - ret.find_all_unbound_type_variables(type_variables); - env.find_all_unbound_type_variables(type_variables); - } - Type::MutableReference(elem) => { - elem.find_all_unbound_type_variables(type_variables); - } - Type::Forall(_, typ) => typ.find_all_unbound_type_variables(type_variables), - Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => { - match &*type_variable.borrow() { - TypeBinding::Bound(binding) => { - binding.find_all_unbound_type_variables(type_variables); - } - TypeBinding::Unbound(id) => { - if !type_variables.contains_key(id) { - type_variables.insert(*id, type_variable.clone()); - } - } - } - } - } - } - /// Substitute any type variables found within this type with the /// given bindings if found. If a type variable is not found within /// the given TypeBindings, it is unchanged. @@ -1554,6 +1510,10 @@ impl Type { | NotConstant => self.clone(), } } + + pub fn from_generics(generics: &Generics) -> Vec { + vecmap(generics, |(_, var)| Type::TypeVariable(var.clone(), TypeVariableKind::Normal)) + } } /// Wraps a given `expression` in `expression.as_slice()` @@ -1569,7 +1529,7 @@ fn convert_array_expression_to_slice( let as_slice_id = interner.function_definition_id(as_slice_method); let location = interner.expr_location(&expression); - let as_slice = HirExpression::Ident(HirIdent { location, id: as_slice_id }); + let as_slice = HirExpression::Ident(HirIdent::non_trait_method(as_slice_id, location)); let func = interner.push_expr(as_slice); let arguments = vec![expression]; diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 3510475e881..ac11e00ad20 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -431,11 +431,6 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Lambda(lambda) => self.lambda(lambda, expr), - HirExpression::TraitMethodReference(method) => { - let function_type = self.interner.id_type(expr); - self.resolve_trait_method_reference(expr, function_type, method) - } - HirExpression::MethodCall(hir_method_call) => { unreachable!("Encountered HirExpression::MethodCall during monomorphization {hir_method_call:?}") } @@ -682,6 +677,12 @@ impl<'interner> Monomorphizer<'interner> { } fn ident(&mut self, ident: HirIdent, expr_id: node_interner::ExprId) -> ast::Expression { + let typ = self.interner.id_type(expr_id); + + if let ImplKind::TraitMethod(method, _, _) = ident.impl_kind { + return self.resolve_trait_method_reference(expr_id, typ, method); + } + let definition = self.interner.definition(ident.id); match &definition.kind { DefinitionKind::Function(func_id) => { @@ -866,8 +867,12 @@ impl<'interner> Monomorphizer<'interner> { self.interner.get_trait_implementation(impl_id).borrow().methods [method.method_index] } - node_interner::TraitImplKind::Assumed { object_type } => { - match self.interner.lookup_trait_implementation(&object_type, method.trait_id) { + node_interner::TraitImplKind::Assumed { object_type, trait_generics } => { + match self.interner.lookup_trait_implementation( + &object_type, + method.trait_id, + &trait_generics, + ) { Ok(TraitImplKind::Normal(impl_id)) => { self.interner.get_trait_implementation(impl_id).borrow().methods [method.method_index] @@ -889,14 +894,12 @@ impl<'interner> Monomorphizer<'interner> { } }; - let func_def = self.lookup_function(func_id, expr_id, &function_type, Some(method)); - let func_id = match func_def { + let func_id = match self.lookup_function(func_id, expr_id, &function_type, Some(method)) { Definition::Function(func_id) => func_id, _ => unreachable!(), }; let the_trait = self.interner.get_trait(method.trait_id); - ast::Expression::Ident(ast::Ident { definition: Definition::Function(func_id), mutable: false, diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index f9d273d774f..1027dac813e 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -153,6 +153,19 @@ pub enum TraitImplKind { /// Assumed impls don't have an impl id since they don't link back to any concrete part of the source code. Assumed { object_type: Type, + + /// The trait generics to use - if specified. + /// This is allowed to be empty when they are inferred. E.g. for: + /// + /// ``` + /// trait Into { + /// fn into(self) -> T; + /// } + /// ``` + /// + /// The reference `Into::into(x)` would have inferred generics, but + /// `x.into()` with a `X: Into` in scope would not. + trait_generics: Vec, }, } @@ -941,6 +954,16 @@ impl NodeInterner { self.function_definition_ids[&function] } + /// Returns the DefinitionId of a trait's method, panics if the given trait method + /// is not a valid method of the trait or if the trait has not yet had + /// its methods ids set during name resolution. + pub fn trait_method_id(&self, trait_method: TraitMethodId) -> DefinitionId { + let the_trait = self.get_trait(trait_method.trait_id); + let method_name = &the_trait.methods[trait_method.method_index].name; + let function_id = the_trait.method_ids[&method_name.0.contents]; + self.function_definition_id(function_id) + } + /// Adds a non-trait method to a type. /// /// Returns `Some(duplicate)` if a matching method was already defined. @@ -996,8 +1019,11 @@ impl NodeInterner { &self, object_type: &Type, trait_id: TraitId, + trait_generics: &[Type], ) -> Result> { - let (impl_kind, bindings) = self.try_lookup_trait_implementation(object_type, trait_id)?; + let (impl_kind, bindings) = + self.try_lookup_trait_implementation(object_type, trait_id, trait_generics)?; + Type::apply_type_bindings(bindings); Ok(impl_kind) } @@ -1007,11 +1033,13 @@ impl NodeInterner { &self, object_type: &Type, trait_id: TraitId, + trait_generics: &[Type], ) -> Result<(TraitImplKind, TypeBindings), Vec> { let mut bindings = TypeBindings::new(); let impl_kind = self.lookup_trait_implementation_helper( object_type, trait_id, + trait_generics, &mut bindings, IMPL_SEARCH_RECURSION_LIMIT, )?; @@ -1022,10 +1050,12 @@ impl NodeInterner { &self, object_type: &Type, trait_id: TraitId, + trait_generics: &[Type], type_bindings: &mut TypeBindings, recursion_limit: u32, ) -> Result> { - let make_constraint = || TraitConstraint::new(object_type.clone(), trait_id); + let make_constraint = + || TraitConstraint::new(object_type.clone(), trait_id, trait_generics.to_vec()); // Prevent infinite recursion when looking for impls if recursion_limit == 0 { @@ -1037,12 +1067,35 @@ impl NodeInterner { let impls = self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; - for (existing_object_type, impl_kind) in impls { + for (existing_object_type2, impl_kind) in impls { + // Bug: We're instantiating only the object type's generics here, not all of the trait's generics like we need to let (existing_object_type, instantiation_bindings) = - existing_object_type.instantiate(self); + existing_object_type2.instantiate(self); let mut fresh_bindings = TypeBindings::new(); + let mut check_trait_generics = |impl_generics: &[Type]| { + trait_generics.iter().zip(impl_generics).all(|(trait_generic, impl_generic2)| { + let impl_generic = impl_generic2.substitute(&instantiation_bindings); + trait_generic.try_unify(&impl_generic, &mut fresh_bindings).is_ok() + }) + }; + + let generics_match = match impl_kind { + TraitImplKind::Normal(id) => { + let shared_impl = self.get_trait_implementation(*id); + let shared_impl = shared_impl.borrow(); + check_trait_generics(&shared_impl.trait_generics) + } + TraitImplKind::Assumed { trait_generics, .. } => { + check_trait_generics(trait_generics) + } + }; + + if !generics_match { + continue; + } + if object_type.try_unify(&existing_object_type, &mut fresh_bindings).is_ok() { // The unification was successful so we can append fresh_bindings to our bindings list type_bindings.extend(fresh_bindings); @@ -1085,9 +1138,15 @@ impl NodeInterner { let constraint_type = constraint.typ.force_substitute(instantiation_bindings); let constraint_type = constraint_type.substitute(type_bindings); + let trait_generics = vecmap(&constraint.trait_generics, |generic| { + let generic = generic.force_substitute(instantiation_bindings); + generic.substitute(type_bindings) + }); + self.lookup_trait_implementation_helper( &constraint_type, constraint.trait_id, + &trait_generics, // Use a fresh set of type bindings here since the constraint_type originates from // our impl list, which we don't want to bind to. &mut TypeBindings::new(), @@ -1109,14 +1168,15 @@ impl NodeInterner { &mut self, object_type: Type, trait_id: TraitId, + trait_generics: Vec, ) -> bool { // Make sure there are no overlapping impls - if self.try_lookup_trait_implementation(&object_type, trait_id).is_ok() { + if self.try_lookup_trait_implementation(&object_type, trait_id, &trait_generics).is_ok() { return false; } let entries = self.trait_implementation_map.entry(trait_id).or_default(); - entries.push((object_type.clone(), TraitImplKind::Assumed { object_type })); + entries.push((object_type.clone(), TraitImplKind::Assumed { object_type, trait_generics })); true } @@ -1125,7 +1185,9 @@ impl NodeInterner { &mut self, object_type: Type, trait_id: TraitId, + trait_generics: Vec, impl_id: TraitImplId, + impl_generics: Generics, trait_impl: Shared, ) -> Result<(), (Span, FileId)> { assert_eq!(impl_id.0, self.trait_implementations.len(), "trait impl defined out of order"); @@ -1136,12 +1198,20 @@ impl NodeInterner { // It should never happen since impls are defined at global scope, but even // if they were, we should never prevent defining a new impl because a where // clause already assumes it exists. - let (instantiated_object_type, substitutions) = - object_type.instantiate_type_variables(self); - if let Ok((TraitImplKind::Normal(existing), _)) = - self.try_lookup_trait_implementation(&instantiated_object_type, trait_id) - { + // Replace each generic with a fresh type variable + let substitutions = impl_generics + .into_iter() + .map(|(id, typevar)| (id, (typevar, self.next_type_variable()))) + .collect(); + + let instantiated_object_type = object_type.substitute(&substitutions); + + if let Ok((TraitImplKind::Normal(existing), _)) = self.try_lookup_trait_implementation( + &instantiated_object_type, + trait_id, + &trait_generics, + ) { let existing_impl = self.get_trait_implementation(existing); let existing_impl = existing_impl.borrow(); return Err((existing_impl.ident.span(), existing_impl.file)); @@ -1155,6 +1225,7 @@ impl NodeInterner { // The object type is generalized so that a generic impl will apply // to any type T, rather than just the generic type named T. let generalized_object_type = object_type.generalize_from_substitutions(substitutions); + let entries = self.trait_implementation_map.entry(trait_id).or_default(); entries.push((generalized_object_type, TraitImplKind::Normal(impl_id))); Ok(()) diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 954b531abff..cdfdc570949 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -413,13 +413,7 @@ fn trait_definition() -> impl NoirParser { .then_ignore(just(Token::LeftBrace)) .then(trait_body()) .then_ignore(just(Token::RightBrace)) - .validate(|(((name, generics), where_clause), items), span, emit| { - if !generics.is_empty() { - emit(ParserError::with_reason( - ParserErrorReason::ExperimentalFeature("Generic traits"), - span, - )); - } + .map_with_span(|(((name, generics), where_clause), items), span| { TopLevelStatement::Trait(NoirTrait { name, generics, where_clause, span, items }) }) } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 10684f76169..a56c3a7755f 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -118,6 +118,12 @@ mod test { fn eq(self, other: Foo) -> bool { self.a == other.a } } + impl Default for u64 { + fn default() -> Self { + 0 + } + } + impl Default for Foo { fn default() -> Self { Foo { a: Default::default() } diff --git a/docs/docs/noir/concepts/traits.md b/docs/docs/noir/concepts/traits.md index 7ba07e74f40..f8a50071a4b 100644 --- a/docs/docs/noir/concepts/traits.md +++ b/docs/docs/noir/concepts/traits.md @@ -168,6 +168,46 @@ impl Eq for [T; N] where T: Eq { } ``` +## Generic Traits + +Traits themselves can also be generic by placing the generic arguments after the trait name. These generics are in +scope of every item within the trait. + +```rust +trait Into { + // Convert `self` to type `T` + fn into(self) -> T; +} +``` + +When implementing generic traits the generic arguments of the trait must be specified. This is also true anytime +when referencing a generic trait (e.g. in a `where` clause). + +```rust +struct MyStruct { + array: [Field; 2], +} + +impl Into<[Field; 2]> for MyStruct { + fn into(self) -> [Field; 2] { + self.array + } +} + +fn as_array(x: T) -> [Field; 2] + where T: Into<[Field; 2]> +{ + x.into() +} + +fn main() { + let array = [1, 2]; + let my_struct = MyStruct { array }; + + assert_eq(as_array(my_struct), array); +} +``` + ## Trait Methods With No `self` A trait can contain any number of methods, each of which have access to the `Self` type which represents each type diff --git a/test_programs/compile_success_empty/trait_generics/Nargo.toml b/test_programs/compile_success_empty/trait_generics/Nargo.toml index c1b5d0aaa6c..7fdd5975541 100644 --- a/test_programs/compile_success_empty/trait_generics/Nargo.toml +++ b/test_programs/compile_success_empty/trait_generics/Nargo.toml @@ -2,5 +2,6 @@ name = "trait_generics" type = "bin" authors = [""] +compiler_version = ">=0.22.0" [dependencies] diff --git a/test_programs/compile_success_empty/trait_generics/src/main.nr b/test_programs/compile_success_empty/trait_generics/src/main.nr index bb6d6e74726..9a3c54c3fa1 100644 --- a/test_programs/compile_success_empty/trait_generics/src/main.nr +++ b/test_programs/compile_success_empty/trait_generics/src/main.nr @@ -1,59 +1,57 @@ -struct Empty {} -trait Foo { - fn foo(self) -> u32; -} +fn main() { + let xs: [Field; 1] = [3]; + let ys: [u32; 1] = [3]; + foo(xs, ys); -impl Foo for Empty { - fn foo(_self: Self) -> u32 { 32 } + assert_eq(15, sum(Data { a: 5, b: 10 })); + assert_eq(15, sum_static(Data { a: 5, b: 10 })); } -impl Foo for Empty { - fn foo(_self: Self) -> u32 { 64 } +fn foo(x: T, u: U) where T: Into, U: Eq { + assert(x.into() == u); } -fn main() { - let x: Empty = Empty {}; - let y: Empty = Empty {}; - let z = Empty {}; - - assert(x.foo() == 32); - assert(y.foo() == 64); - // Types matching multiple impls will currently choose - // the first matching one instead of erroring - assert(z.foo() == 32); - - call_impl_with_generic_struct(); - call_impl_with_generic_function(); +trait Into { + fn into(self) -> T; } -// Ensure we can call a generic impl -fn call_impl_with_generic_struct() { - let x: u8 = 7; - let y: i8 = 8; - let s2_u8 = S2 { x }; - let s2_i8 = S2 { x: y }; - assert(s2_u8.t2().x == 7); - assert(s2_i8.t2().x == 8); + +impl Into<[U; N]> for [T; N] where T: Into { + fn into(self) -> [U; N] { + self.map(|x: T| x.into()) + } } -trait T2 { - fn t2(self) -> Self; +impl Into for Field { + fn into(self) -> u32 { + self as u32 + } } -struct S2 { x: T } +/// Serialize example + +trait Serializable { + fn serialize(self) -> [Field; N]; +} -impl T2 for S2 { - fn t2(self) -> Self { self } +struct Data { + a: Field, + b: Field, } -fn call_impl_with_generic_function() { - assert(3.t3(7) == 7); +impl Serializable<2> for Data { + fn serialize(self) -> [Field; 2] { + [self.a, self.b] + } } -trait T3 { - fn t3(self, x: T) -> T; +fn sum(data: T) -> Field where T: Serializable { + let serialized = data.serialize(); + serialized.fold(0, |acc, elem| acc + elem) } -impl T3 for u32 { - fn t3(self, y: U) -> U { y } +// Test static trait method syntax +fn sum_static(data: T) -> Field where T: Serializable { + let serialized = Serializable::serialize(data); + serialized.fold(0, |acc, elem| acc + elem) } diff --git a/test_programs/compile_success_empty/trait_impl_generics/Nargo.toml b/test_programs/compile_success_empty/trait_impl_generics/Nargo.toml new file mode 100644 index 00000000000..b10b5dab6aa --- /dev/null +++ b/test_programs/compile_success_empty/trait_impl_generics/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "trait_impl_generics" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/compile_success_empty/trait_impl_generics/src/main.nr b/test_programs/compile_success_empty/trait_impl_generics/src/main.nr new file mode 100644 index 00000000000..c46c41cbdd7 --- /dev/null +++ b/test_programs/compile_success_empty/trait_impl_generics/src/main.nr @@ -0,0 +1,59 @@ +struct Empty {} + +trait Foo { + fn foo(self) -> u32; +} + +impl Foo for Empty { + fn foo(_self: Self) -> u32 { 32 } +} + +impl Foo for Empty { + fn foo(_self: Self) -> u32 { 64 } +} + +fn main() { + let x: Empty = Empty {}; + let y: Empty = Empty {}; + let z = Empty {}; + + assert(x.foo() == 32); + assert(y.foo() == 64); + // Types matching multiple impls will currently choose + // the first matching one instead of erroring + assert(z.foo() == 32); + + call_impl_with_generic_struct(); + call_impl_with_generic_function(); +} +// Ensure we can call a generic impl +fn call_impl_with_generic_struct() { + let x: u8 = 7; + let y: i8 = 8; + let s2_u8 = S2 { x }; + let s2_i8 = S2 { x: y }; + assert(s2_u8.t2().x == 7); + assert(s2_i8.t2().x == 8); +} + +trait T2 { + fn t2(self) -> Self; +} + +struct S2 { x: T } + +impl T2 for S2 { + fn t2(self) -> Self { self } +} + +fn call_impl_with_generic_function() { + assert(3.t3(7) == 7); +} + +trait T3 { + fn t3(self, x: T) -> T; +} + +impl T3 for u32 { + fn t3(_self: Self, y: U) -> U { y } +}