diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs index 402ff31dafe..20d75a704d3 100644 --- a/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -1,6 +1,30 @@ +use std::{collections::BTreeMap, fmt::Display}; + +use chumsky::Parser; +use fm::FileId; +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; + use crate::{ + hir::{ + comptime::{Interpreter, InterpreterError, Value}, + def_collector::{ + dc_crate::{ + CollectedItems, CompilationError, UnresolvedFunctions, UnresolvedStruct, + UnresolvedTrait, UnresolvedTraitImpl, + }, + dc_mod, + }, + resolution::errors::ResolverError, + }, hir_def::expr::HirIdent, - node_interner::{DependencyId, FuncId}, + lexer::Lexer, + macros_api::{ + Expression, ExpressionKind, HirExpression, NodeInterner, SecondaryAttribute, StructId, + }, + node_interner::{DefinitionKind, DependencyId, FuncId, TraitId}, + parser::{self, TopLevelStatement}, + Type, TypeBindings, }; use super::{Elaborator, FunctionContext, ResolverMeta}; @@ -35,13 +59,11 @@ impl<'context> Elaborator<'context> { elaborator.introduce_generics_into_scope(meta.all_generics.clone()); } - elaborator.comptime_scopes = std::mem::take(&mut self.comptime_scopes); elaborator.populate_scope_from_comptime_scopes(); let result = f(&mut elaborator); elaborator.check_and_pop_function_context(); - self.comptime_scopes = elaborator.comptime_scopes; self.errors.append(&mut elaborator.errors); result } @@ -50,7 +72,7 @@ impl<'context> Elaborator<'context> { // Take the comptime scope to be our runtime scope. // Iterate from global scope to the most local scope so that the // later definitions will naturally shadow the former. - for scope in &self.comptime_scopes { + for scope in &self.interner.comptime_scopes { for definition_id in scope.keys() { let definition = self.interner.definition(*definition_id); let name = definition.name.clone(); @@ -63,4 +85,326 @@ impl<'context> Elaborator<'context> { } } } + + pub(super) fn run_comptime_attributes_on_item( + &mut self, + attributes: &[SecondaryAttribute], + item: Value, + span: Span, + generated_items: &mut CollectedItems, + ) { + for attribute in attributes { + if let SecondaryAttribute::Custom(name) = attribute { + if let Err(error) = + self.run_comptime_attribute_on_item(name, item.clone(), span, generated_items) + { + self.errors.push(error); + } + } + } + } + + fn run_comptime_attribute_on_item( + &mut self, + attribute: &str, + item: Value, + span: Span, + generated_items: &mut CollectedItems, + ) -> Result<(), (CompilationError, FileId)> { + let location = Location::new(span, self.file); + let Some((function, arguments)) = Self::parse_attribute(attribute, self.file)? else { + // Do not issue an error if the attribute is unknown + return Ok(()); + }; + + // Elaborate the function, rolling back any errors generated in case it is unknown + let error_count = self.errors.len(); + let function = self.elaborate_expression(function).0; + self.errors.truncate(error_count); + + let definition_id = match self.interner.expression(&function) { + HirExpression::Ident(ident, _) => ident.id, + _ => return Ok(()), + }; + + let Some(definition) = self.interner.try_definition(definition_id) else { + // If there's no such function, don't return an error. + // This preserves backwards compatibility in allowing custom attributes that + // do not refer to comptime functions. + return Ok(()); + }; + + let DefinitionKind::Function(function) = definition.kind else { + return Err((ResolverError::NonFunctionInAnnotation { span }.into(), self.file)); + }; + + let mut interpreter = self.setup_interpreter(); + let mut arguments = + Self::handle_attribute_arguments(&mut interpreter, function, arguments, location) + .map_err(|error| { + let file = error.get_location().file; + (error.into(), file) + })?; + + arguments.insert(0, (item, location)); + + let value = interpreter + .call_function(function, arguments, TypeBindings::new(), location) + .map_err(|error| error.into_compilation_error_pair())?; + + if value != Value::Unit { + let items = value + .into_top_level_items(location, self.interner) + .map_err(|error| error.into_compilation_error_pair())?; + + self.add_items(items, generated_items, location); + } + + Ok(()) + } + + /// Parses an attribute in the form of a function call (e.g. `#[foo(a b, c d)]`) into + /// the function and quoted arguments called (e.g. `("foo", vec![(a b, location), (c d, location)])`) + #[allow(clippy::type_complexity)] + fn parse_attribute( + annotation: &str, + file: FileId, + ) -> Result)>, (CompilationError, FileId)> { + let (tokens, mut lexing_errors) = Lexer::lex(annotation); + if !lexing_errors.is_empty() { + return Err((lexing_errors.swap_remove(0).into(), file)); + } + + let expression = parser::expression() + .parse(tokens) + .map_err(|mut errors| (errors.swap_remove(0).into(), file))?; + + Ok(match expression.kind { + ExpressionKind::Call(call) => Some((*call.func, call.arguments)), + ExpressionKind::Variable(_) => Some((expression, Vec::new())), + _ => None, + }) + } + + fn handle_attribute_arguments( + interpreter: &mut Interpreter, + function: FuncId, + arguments: Vec, + location: Location, + ) -> Result, InterpreterError> { + let meta = interpreter.elaborator.interner.function_meta(&function); + let mut parameters = vecmap(&meta.parameters.0, |(_, typ, _)| typ.clone()); + + // Remove the initial parameter for the comptime item since that is not included + // in `arguments` at this point. + parameters.remove(0); + + // If the function is varargs, push the type of the last slice element N times + // to account for N extra arguments. + let modifiers = interpreter.elaborator.interner.function_modifiers(&function); + let is_varargs = modifiers.attributes.is_varargs(); + let varargs_type = if is_varargs { parameters.pop() } else { None }; + + let varargs_elem_type = varargs_type.as_ref().and_then(|t| t.slice_element_type()); + + let mut new_arguments = Vec::with_capacity(arguments.len()); + let mut varargs = im::Vector::new(); + + for (i, arg) in arguments.into_iter().enumerate() { + let param_type = parameters.get(i).or(varargs_elem_type).unwrap_or(&Type::Error); + + let mut push_arg = |arg| { + if i >= parameters.len() { + varargs.push_back(arg); + } else { + new_arguments.push((arg, location)); + } + }; + + if *param_type == Type::Quoted(crate::QuotedType::TraitDefinition) { + let trait_id = match arg.kind { + ExpressionKind::Variable(path) => interpreter + .elaborator + .resolve_trait_by_path(path) + .ok_or(InterpreterError::FailedToResolveTraitDefinition { location }), + _ => Err(InterpreterError::TraitDefinitionMustBeAPath { location }), + }?; + push_arg(Value::TraitDefinition(trait_id)); + } else { + let expr_id = interpreter.elaborator.elaborate_expression(arg).0; + push_arg(interpreter.evaluate(expr_id)?); + } + } + + if is_varargs { + let typ = varargs_type.unwrap_or(Type::Error); + new_arguments.push((Value::Slice(varargs, typ), location)); + } + + Ok(new_arguments) + } + + fn add_items( + &mut self, + items: Vec, + generated_items: &mut CollectedItems, + location: Location, + ) { + for item in items { + self.add_item(item, generated_items, location); + } + } + + fn add_item( + &mut self, + item: TopLevelStatement, + generated_items: &mut CollectedItems, + location: Location, + ) { + match item { + TopLevelStatement::Function(function) => { + let id = self.interner.push_empty_fn(); + let module = self.module_id(); + self.interner.push_function(id, &function.def, module, location); + let functions = vec![(self.local_module, id, function)]; + generated_items.functions.push(UnresolvedFunctions { + file_id: self.file, + functions, + trait_id: None, + self_type: None, + }); + } + TopLevelStatement::TraitImpl(mut trait_impl) => { + let methods = dc_mod::collect_trait_impl_functions( + self.interner, + &mut trait_impl, + self.crate_id, + self.file, + self.local_module, + ); + + generated_items.trait_impls.push(UnresolvedTraitImpl { + file_id: self.file, + module_id: self.local_module, + trait_generics: trait_impl.trait_generics, + trait_path: trait_impl.trait_name, + object_type: trait_impl.object_type, + methods, + generics: trait_impl.impl_generics, + where_clause: trait_impl.where_clause, + + // These last fields are filled in later + trait_id: None, + impl_id: None, + resolved_object_type: None, + resolved_generics: Vec::new(), + resolved_trait_generics: Vec::new(), + }); + } + TopLevelStatement::Global(global) => { + let (global, error) = dc_mod::collect_global( + self.interner, + self.def_maps.get_mut(&self.crate_id).unwrap(), + global, + self.file, + self.local_module, + self.crate_id, + ); + + generated_items.globals.push(global); + if let Some(error) = error { + self.errors.push(error); + } + } + // Assume that an error has already been issued + TopLevelStatement::Error => (), + + TopLevelStatement::Module(_) + | TopLevelStatement::Import(_) + | TopLevelStatement::Struct(_) + | TopLevelStatement::Trait(_) + | TopLevelStatement::Impl(_) + | TopLevelStatement::TypeAlias(_) + | TopLevelStatement::SubModule(_) => { + let item = item.to_string(); + let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location }; + self.errors.push(error.into_compilation_error_pair()); + } + } + } + + pub fn setup_interpreter<'local>(&'local mut self) -> Interpreter<'local, 'context> { + let current_function = match self.current_item { + Some(DependencyId::Function(function)) => Some(function), + _ => None, + }; + Interpreter::new(self, self.crate_id, current_function) + } + + pub(super) fn debug_comptime T>( + &mut self, + location: Location, + mut expr_f: F, + ) { + if Some(location.file) == self.debug_comptime_in_file { + let displayed_expr = expr_f(self.interner); + self.errors.push(( + InterpreterError::debug_evaluate_comptime(displayed_expr, location).into(), + location.file, + )); + } + } + + /// Run all the attributes on each item. The ordering is unspecified to users but currently + /// we run trait attributes first to (e.g.) register derive handlers before derive is + /// called on structs. + /// Returns any new items generated by attributes. + pub(super) fn run_attributes( + &mut self, + traits: &BTreeMap, + types: &BTreeMap, + functions: &[UnresolvedFunctions], + ) -> CollectedItems { + let mut generated_items = CollectedItems::default(); + + for (trait_id, trait_) in traits { + let attributes = &trait_.trait_def.attributes; + let item = Value::TraitDefinition(*trait_id); + let span = trait_.trait_def.span; + self.local_module = trait_.module_id; + self.file = trait_.file_id; + self.run_comptime_attributes_on_item(attributes, item, span, &mut generated_items); + } + + for (struct_id, struct_def) in types { + let attributes = &struct_def.struct_def.attributes; + let item = Value::StructDefinition(*struct_id); + let span = struct_def.struct_def.span; + self.local_module = struct_def.module_id; + self.file = struct_def.file_id; + self.run_comptime_attributes_on_item(attributes, item, span, &mut generated_items); + } + + self.run_attributes_on_functions(functions, &mut generated_items); + generated_items + } + + fn run_attributes_on_functions( + &mut self, + function_sets: &[UnresolvedFunctions], + generated_items: &mut CollectedItems, + ) { + for function_set in function_sets { + self.file = function_set.file_id; + self.self_type = function_set.self_type.clone(); + + for (local_module, function_id, function) in &function_set.functions { + self.local_module = *local_module; + let attributes = function.secondary_attributes(); + let item = Value::FunctionDefinition(*function_id); + let span = function.span(); + self.run_comptime_attributes_on_item(attributes, item, span, generated_items); + } + } + } } diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 0b6233b3445..e7f53ebb916 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -1,19 +1,14 @@ use std::{ collections::{BTreeMap, BTreeSet}, - fmt::Display, rc::Rc, }; use crate::{ ast::{FunctionKind, UnresolvedTraitConstraint}, hir::{ - comptime::{Interpreter, InterpreterError, Value}, - def_collector::{ - dc_crate::{ - filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal, - UnresolvedStruct, UnresolvedTrait, UnresolvedTypeAlias, - }, - dc_mod, + def_collector::dc_crate::{ + filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal, UnresolvedStruct, + UnresolvedTypeAlias, }, def_map::DefMaps, resolution::{errors::ResolverError, path_resolver::PathResolver}, @@ -26,18 +21,14 @@ use crate::{ traits::TraitConstraint, types::{Generics, Kind, ResolvedGeneric}, }, - lexer::Lexer, macros_api::{ BlockExpression, Ident, NodeInterner, NoirFunction, NoirStruct, Pattern, SecondaryAttribute, StructId, }, node_interner::{ - DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, ReferenceId, TraitId, - TypeAliasId, + DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, ReferenceId, TraitId, TypeAliasId, }, - parser::TopLevelStatement, - token::Tokens, - Shared, Type, TypeBindings, TypeVariable, + Shared, Type, TypeVariable, }; use crate::{ ast::{TraitBound, UnresolvedGeneric, UnresolvedGenerics}, @@ -75,7 +66,6 @@ mod unquote; use fm::FileId; use iter_extended::vecmap; use noirc_errors::{Location, Span}; -use rustc_hash::FxHashMap as HashMap; use self::traits::check_trait_impl_method_matches_declaration; @@ -164,11 +154,6 @@ pub struct Elaborator<'context> { crate_id: CrateId, - /// Each value currently in scope in the comptime interpreter. - /// Each element of the Vec represents a scope with every scope together making - /// up all currently visible definitions. The first scope is always the global scope. - pub(crate) comptime_scopes: Vec>, - /// The scope of --debug-comptime, or None if unset debug_comptime_in_file: Option, @@ -216,7 +201,6 @@ impl<'context> Elaborator<'context> { trait_bounds: Vec::new(), function_context: vec![FunctionContext::default()], current_trait_impl: None, - comptime_scopes: vec![HashMap::default()], debug_comptime_in_file, unresolved_globals: BTreeMap::new(), } @@ -1192,118 +1176,6 @@ impl<'context> Elaborator<'context> { } } - fn run_comptime_attributes_on_item( - &mut self, - attributes: &[SecondaryAttribute], - item: Value, - span: Span, - generated_items: &mut CollectedItems, - ) { - for attribute in attributes { - if let SecondaryAttribute::Custom(name) = attribute { - if let Err(error) = - self.run_comptime_attribute_on_item(name, item.clone(), span, generated_items) - { - self.errors.push(error); - } - } - } - } - - fn run_comptime_attribute_on_item( - &mut self, - attribute: &str, - item: Value, - span: Span, - generated_items: &mut CollectedItems, - ) -> Result<(), (CompilationError, FileId)> { - let location = Location::new(span, self.file); - let (function_name, mut arguments) = Self::parse_attribute(attribute, location) - .unwrap_or_else(|| (attribute.to_string(), Vec::new())); - - let Ok(id) = self.lookup_global(Path::from_single(function_name, span)) else { - // Do not issue an error if the attribute is unknown - return Ok(()); - }; - - let definition = self.interner.definition(id); - let DefinitionKind::Function(function) = definition.kind else { - return Err((ResolverError::NonFunctionInAnnotation { span }.into(), self.file)); - }; - - self.handle_varargs_attribute(function, &mut arguments, location); - arguments.insert(0, (item, location)); - - let mut interpreter = self.setup_interpreter(); - - let value = interpreter - .call_function(function, arguments, TypeBindings::new(), location) - .map_err(|error| error.into_compilation_error_pair())?; - - if value != Value::Unit { - let items = value - .into_top_level_items(location) - .map_err(|error| error.into_compilation_error_pair())?; - - self.add_items(items, generated_items, location); - } - - Ok(()) - } - - /// Parses an attribute in the form of a function call (e.g. `#[foo(a b, c d)]`) into - /// the function and quoted arguments called (e.g. `("foo", vec![(a b, location), (c d, location)])`) - fn parse_attribute( - annotation: &str, - location: Location, - ) -> Option<(String, Vec<(Value, Location)>)> { - let (tokens, errors) = Lexer::lex(annotation); - if !errors.is_empty() { - return None; - } - - let mut tokens = tokens.0; - if tokens.len() >= 4 { - // Remove the outer `ident ( )` wrapping the function arguments - let first = tokens.remove(0).into_token(); - let second = tokens.remove(0).into_token(); - - // Last token is always an EndOfInput - let _ = tokens.pop().unwrap().into_token(); - let last = tokens.pop().unwrap().into_token(); - - use crate::lexer::token::Token::*; - if let (Ident(name), LeftParen, RightParen) = (first, second, last) { - let args = tokens.split(|token| *token.token() == Comma); - let args = - vecmap(args, |arg| (Value::Code(Rc::new(Tokens(arg.to_vec()))), location)); - return Some((name, args)); - } - } - - None - } - - /// Checks if the given attribute function is a varargs function. - /// If so, we should pass its arguments in one slice rather than as separate arguments. - fn handle_varargs_attribute( - &mut self, - function: FuncId, - arguments: &mut Vec<(Value, Location)>, - location: Location, - ) { - let meta = self.interner.function_meta(&function); - let parameters = &meta.parameters.0; - - // If the last parameter is a slice, this is a varargs function. - if parameters.last().map_or(false, |(_, typ, _)| matches!(typ, Type::Slice(_))) { - let typ = Type::Slice(Box::new(Type::Quoted(crate::QuotedType::Quoted))); - let slice_elements = arguments.drain(..).map(|(value, _)| value); - let slice = Value::Slice(slice_elements.collect(), typ); - arguments.push((slice, location)); - } - } - pub fn resolve_struct_fields( &mut self, unresolved: &NoirStruct, @@ -1497,168 +1369,4 @@ impl<'context> Elaborator<'context> { _ => true, }) } - - fn add_items( - &mut self, - items: Vec, - generated_items: &mut CollectedItems, - location: Location, - ) { - for item in items { - self.add_item(item, generated_items, location); - } - } - - fn add_item( - &mut self, - item: TopLevelStatement, - generated_items: &mut CollectedItems, - location: Location, - ) { - match item { - TopLevelStatement::Function(function) => { - let id = self.interner.push_empty_fn(); - let module = self.module_id(); - self.interner.push_function(id, &function.def, module, location); - let functions = vec![(self.local_module, id, function)]; - generated_items.functions.push(UnresolvedFunctions { - file_id: self.file, - functions, - trait_id: None, - self_type: None, - }); - } - TopLevelStatement::TraitImpl(mut trait_impl) => { - let methods = dc_mod::collect_trait_impl_functions( - self.interner, - &mut trait_impl, - self.crate_id, - self.file, - self.local_module, - ); - - generated_items.trait_impls.push(UnresolvedTraitImpl { - file_id: self.file, - module_id: self.local_module, - trait_generics: trait_impl.trait_generics, - trait_path: trait_impl.trait_name, - object_type: trait_impl.object_type, - methods, - generics: trait_impl.impl_generics, - where_clause: trait_impl.where_clause, - - // These last fields are filled in later - trait_id: None, - impl_id: None, - resolved_object_type: None, - resolved_generics: Vec::new(), - resolved_trait_generics: Vec::new(), - }); - } - TopLevelStatement::Global(global) => { - let (global, error) = dc_mod::collect_global( - self.interner, - self.def_maps.get_mut(&self.crate_id).unwrap(), - global, - self.file, - self.local_module, - self.crate_id, - ); - - generated_items.globals.push(global); - if let Some(error) = error { - self.errors.push(error); - } - } - // Assume that an error has already been issued - TopLevelStatement::Error => (), - - TopLevelStatement::Module(_) - | TopLevelStatement::Import(_) - | TopLevelStatement::Struct(_) - | TopLevelStatement::Trait(_) - | TopLevelStatement::Impl(_) - | TopLevelStatement::TypeAlias(_) - | TopLevelStatement::SubModule(_) => { - let item = item.to_string(); - let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location }; - self.errors.push(error.into_compilation_error_pair()); - } - } - } - - pub fn setup_interpreter<'local>(&'local mut self) -> Interpreter<'local, 'context> { - let current_function = match self.current_item { - Some(DependencyId::Function(function)) => Some(function), - _ => None, - }; - Interpreter::new(self, self.crate_id, current_function) - } - - fn debug_comptime T>( - &mut self, - location: Location, - mut expr_f: F, - ) { - if Some(location.file) == self.debug_comptime_in_file { - let displayed_expr = expr_f(self.interner); - self.errors.push(( - InterpreterError::debug_evaluate_comptime(displayed_expr, location).into(), - location.file, - )); - } - } - - /// Run all the attributes on each item. The ordering is unspecified to users but currently - /// we run trait attributes first to (e.g.) register derive handlers before derive is - /// called on structs. - /// Returns any new items generated by attributes. - fn run_attributes( - &mut self, - traits: &BTreeMap, - types: &BTreeMap, - functions: &[UnresolvedFunctions], - ) -> CollectedItems { - let mut generated_items = CollectedItems::default(); - - for (trait_id, trait_) in traits { - let attributes = &trait_.trait_def.attributes; - let item = Value::TraitDefinition(*trait_id); - let span = trait_.trait_def.span; - self.local_module = trait_.module_id; - self.file = trait_.file_id; - self.run_comptime_attributes_on_item(attributes, item, span, &mut generated_items); - } - - for (struct_id, struct_def) in types { - let attributes = &struct_def.struct_def.attributes; - let item = Value::StructDefinition(*struct_id); - let span = struct_def.struct_def.span; - self.local_module = struct_def.module_id; - self.file = struct_def.file_id; - self.run_comptime_attributes_on_item(attributes, item, span, &mut generated_items); - } - - self.run_attributes_on_functions(functions, &mut generated_items); - generated_items - } - - fn run_attributes_on_functions( - &mut self, - function_sets: &[UnresolvedFunctions], - generated_items: &mut CollectedItems, - ) { - for function_set in function_sets { - self.file = function_set.file_id; - self.self_type = function_set.self_type.clone(); - - for (local_module, function_id, function) in &function_set.functions { - self.local_module = *local_module; - let attributes = function.secondary_attributes(); - let item = Value::FunctionDefinition(*function_id); - let span = function.span(); - self.run_comptime_attributes_on_item(attributes, item, span, generated_items); - } - } - } } diff --git a/compiler/noirc_frontend/src/elaborator/scope.rs b/compiler/noirc_frontend/src/elaborator/scope.rs index ae9a2c75ab6..b2367e0cf0e 100644 --- a/compiler/noirc_frontend/src/elaborator/scope.rs +++ b/compiler/noirc_frontend/src/elaborator/scope.rs @@ -171,12 +171,12 @@ impl<'context> Elaborator<'context> { pub fn push_scope(&mut self) { self.scopes.start_scope(); - self.comptime_scopes.push(Default::default()); + self.interner.comptime_scopes.push(Default::default()); } pub fn pop_scope(&mut self) { let scope = self.scopes.end_scope(); - self.comptime_scopes.pop(); + self.interner.comptime_scopes.pop(); self.check_for_unused_variables_in_scope_tree(scope.into()); } diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index 137433b48ef..7898f13945f 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -5,7 +5,7 @@ use crate::{ ast::TraitBound, hir::{def_collector::dc_crate::CompilationError, type_check::NoMatchingImplFoundError}, parser::ParserError, - token::Tokens, + token::Token, Type, }; use acvm::{acir::AcirField, BlackBoxResolutionError, FieldElement}; @@ -13,53 +13,183 @@ use fm::FileId; use iter_extended::vecmap; use noirc_errors::{CustomDiagnostic, Location}; -use super::value::Value; - /// The possible errors that can halt the interpreter. #[derive(Debug, Clone, PartialEq, Eq)] pub enum InterpreterError { - ArgumentCountMismatch { expected: usize, actual: usize, location: Location }, - TypeMismatch { expected: Type, value: Value, location: Location }, - NonComptimeVarReferenced { name: String, location: Location }, - VariableNotInScope { location: Location }, - IntegerOutOfRangeForType { value: FieldElement, typ: Type, location: Location }, - ErrorNodeEncountered { location: Location }, - NonFunctionCalled { value: Value, location: Location }, - NonBoolUsedInIf { value: Value, location: Location }, - NonBoolUsedInConstrain { value: Value, location: Location }, - FailingConstraint { message: Option, location: Location }, - NoMethodFound { name: String, typ: Type, location: Location }, - NonIntegerUsedInLoop { value: Value, location: Location }, - NonPointerDereferenced { value: Value, location: Location }, - NonTupleOrStructInMemberAccess { value: Value, location: Location }, - NonArrayIndexed { value: Value, location: Location }, - NonIntegerUsedAsIndex { value: Value, location: Location }, - NonIntegerIntegerLiteral { typ: Type, location: Location }, - NonIntegerArrayLength { typ: Type, location: Location }, - NonNumericCasted { value: Value, location: Location }, - IndexOutOfBounds { index: usize, length: usize, location: Location }, - ExpectedStructToHaveField { value: Value, field_name: String, location: Location }, - TypeUnsupported { typ: Type, location: Location }, - InvalidValueForUnary { value: Value, operator: &'static str, location: Location }, - InvalidValuesForBinary { lhs: Value, rhs: Value, operator: &'static str, location: Location }, - CastToNonNumericType { typ: Type, location: Location }, - QuoteInRuntimeCode { location: Location }, - NonStructInConstructor { typ: Type, location: Location }, - CannotInlineMacro { value: Value, location: Location }, - UnquoteFoundDuringEvaluation { location: Location }, - DebugEvaluateComptime { diagnostic: CustomDiagnostic, location: Location }, - FailedToParseMacro { error: ParserError, tokens: Rc, rule: &'static str, file: FileId }, - UnsupportedTopLevelItemUnquote { item: String, location: Location }, - ComptimeDependencyCycle { function: String, location: Location }, - NoImpl { location: Location }, - NoMatchingImplFound { error: NoMatchingImplFoundError, file: FileId }, - ImplMethodTypeMismatch { expected: Type, actual: Type, location: Location }, - BreakNotInLoop { location: Location }, - ContinueNotInLoop { location: Location }, + ArgumentCountMismatch { + expected: usize, + actual: usize, + location: Location, + }, + TypeMismatch { + expected: Type, + actual: Type, + location: Location, + }, + NonComptimeVarReferenced { + name: String, + location: Location, + }, + VariableNotInScope { + location: Location, + }, + IntegerOutOfRangeForType { + value: FieldElement, + typ: Type, + location: Location, + }, + ErrorNodeEncountered { + location: Location, + }, + NonFunctionCalled { + typ: Type, + location: Location, + }, + NonBoolUsedInIf { + typ: Type, + location: Location, + }, + NonBoolUsedInConstrain { + typ: Type, + location: Location, + }, + FailingConstraint { + message: Option, + location: Location, + }, + NoMethodFound { + name: String, + typ: Type, + location: Location, + }, + NonIntegerUsedInLoop { + typ: Type, + location: Location, + }, + NonPointerDereferenced { + typ: Type, + location: Location, + }, + NonTupleOrStructInMemberAccess { + typ: Type, + location: Location, + }, + NonArrayIndexed { + typ: Type, + location: Location, + }, + NonIntegerUsedAsIndex { + typ: Type, + location: Location, + }, + NonIntegerIntegerLiteral { + typ: Type, + location: Location, + }, + NonIntegerArrayLength { + typ: Type, + location: Location, + }, + NonNumericCasted { + typ: Type, + location: Location, + }, + IndexOutOfBounds { + index: usize, + length: usize, + location: Location, + }, + ExpectedStructToHaveField { + typ: Type, + field_name: String, + location: Location, + }, + TypeUnsupported { + typ: Type, + location: Location, + }, + InvalidValueForUnary { + typ: Type, + operator: &'static str, + location: Location, + }, + InvalidValuesForBinary { + lhs: Type, + rhs: Type, + operator: &'static str, + location: Location, + }, + CastToNonNumericType { + typ: Type, + location: Location, + }, + QuoteInRuntimeCode { + location: Location, + }, + NonStructInConstructor { + typ: Type, + location: Location, + }, + CannotInlineMacro { + value: String, + typ: Type, + location: Location, + }, + UnquoteFoundDuringEvaluation { + location: Location, + }, + DebugEvaluateComptime { + diagnostic: CustomDiagnostic, + location: Location, + }, + FailedToParseMacro { + error: ParserError, + tokens: Rc>, + rule: &'static str, + file: FileId, + }, + UnsupportedTopLevelItemUnquote { + item: String, + location: Location, + }, + ComptimeDependencyCycle { + function: String, + location: Location, + }, + NoImpl { + location: Location, + }, + NoMatchingImplFound { + error: NoMatchingImplFoundError, + file: FileId, + }, + ImplMethodTypeMismatch { + expected: Type, + actual: Type, + location: Location, + }, + BreakNotInLoop { + location: Location, + }, + ContinueNotInLoop { + location: Location, + }, BlackBoxError(BlackBoxResolutionError, Location), - FailedToResolveTraitBound { trait_bound: TraitBound, location: Location }, + FailedToResolveTraitBound { + trait_bound: TraitBound, + location: Location, + }, + TraitDefinitionMustBeAPath { + location: Location, + }, + FailedToResolveTraitDefinition { + location: Location, + }, - Unimplemented { item: String, location: Location }, + Unimplemented { + item: String, + location: Location, + }, // These cases are not errors, they are just used to prevent us from running more code // until the loop can be resumed properly. These cases will never be displayed to users. @@ -122,6 +252,8 @@ impl InterpreterError { | InterpreterError::BlackBoxError(_, location) | InterpreterError::BreakNotInLoop { location, .. } | InterpreterError::ContinueNotInLoop { location, .. } + | InterpreterError::TraitDefinitionMustBeAPath { location } + | InterpreterError::FailedToResolveTraitDefinition { location } | InterpreterError::FailedToResolveTraitBound { location, .. } => *location, InterpreterError::FailedToParseMacro { error, file, .. } => { @@ -166,9 +298,8 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let secondary = format!("Too {few_many} arguments"); CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::TypeMismatch { expected, value, location } => { - let typ = value.get_type(); - let msg = format!("Expected `{expected}` but a value of type `{typ}` was given"); + InterpreterError::TypeMismatch { expected, actual, location } => { + let msg = format!("Expected `{expected}` but a value of type `{actual}` was given"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } InterpreterError::NonComptimeVarReferenced { name, location } => { @@ -194,23 +325,23 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let secondary = "This is a bug, please report this if found!".to_string(); CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::NonFunctionCalled { value, location } => { + InterpreterError::NonFunctionCalled { typ, location } => { let msg = "Only functions may be called".to_string(); - let secondary = format!("Expression has type {}", value.get_type()); + let secondary = format!("Expression has type {typ}"); CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::NonBoolUsedInIf { value, location } => { - let msg = format!("Expected a `bool` but found `{}`", value.get_type()); + InterpreterError::NonBoolUsedInIf { typ, location } => { + let msg = format!("Expected a `bool` but found `{typ}`"); let secondary = "If conditions must be a boolean value".to_string(); CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::NonBoolUsedInConstrain { value, location } => { - let msg = format!("Expected a `bool` but found `{}`", value.get_type()); + InterpreterError::NonBoolUsedInConstrain { typ, location } => { + let msg = format!("Expected a `bool` but found `{typ}`"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } InterpreterError::FailingConstraint { message, location } => { let (primary, secondary) = match message { - Some(msg) => (format!("{msg}"), "Assertion failed".into()), + Some(msg) => (msg.clone(), "Assertion failed".into()), None => ("Assertion failed".into(), String::new()), }; CustomDiagnostic::simple_error(primary, secondary, location.span) @@ -219,32 +350,30 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let msg = format!("No method named `{name}` found for type `{typ}`"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::NonIntegerUsedInLoop { value, location } => { - let typ = value.get_type(); + InterpreterError::NonIntegerUsedInLoop { typ, location } => { let msg = format!("Non-integer type `{typ}` used in for loop"); - let secondary = if matches!(typ.as_ref(), &Type::FieldElement) { + let secondary = if matches!(typ, Type::FieldElement) { "`field` is not an integer type, try `u32` instead".to_string() } else { String::new() }; CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::NonPointerDereferenced { value, location } => { - let typ = value.get_type(); + InterpreterError::NonPointerDereferenced { typ, location } => { let msg = format!("Only references may be dereferenced, but found `{typ}`"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::NonTupleOrStructInMemberAccess { value, location } => { - let msg = format!("The type `{}` has no fields to access", value.get_type()); + InterpreterError::NonTupleOrStructInMemberAccess { typ, location } => { + let msg = format!("The type `{typ}` has no fields to access"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::NonArrayIndexed { value, location } => { - let msg = format!("Expected an array or slice but found a(n) {}", value.get_type()); + InterpreterError::NonArrayIndexed { typ, location } => { + let msg = format!("Expected an array or slice but found a(n) {typ}"); let secondary = "Only arrays or slices may be indexed".into(); CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::NonIntegerUsedAsIndex { value, location } => { - let msg = format!("Expected an integer but found a(n) {}", value.get_type()); + InterpreterError::NonIntegerUsedAsIndex { typ, location } => { + let msg = format!("Expected an integer but found a(n) {typ}"); let secondary = "Only integers may be indexed. Note that this excludes `field`s".into(); CustomDiagnostic::simple_error(msg, secondary, location.span) @@ -259,17 +388,16 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let secondary = "Array lengths must be integers".into(); CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::NonNumericCasted { value, location } => { + InterpreterError::NonNumericCasted { typ, location } => { let msg = "Only numeric types may be casted".into(); - let secondary = format!("`{}` is non-numeric", value.get_type()); + let secondary = format!("`{typ}` is non-numeric"); CustomDiagnostic::simple_error(msg, secondary, location.span) } InterpreterError::IndexOutOfBounds { index, length, location } => { let msg = format!("{index} is out of bounds for the array of length {length}"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::ExpectedStructToHaveField { value, field_name, location } => { - let typ = value.get_type(); + InterpreterError::ExpectedStructToHaveField { typ, field_name, location } => { let msg = format!("The type `{typ}` has no field named `{field_name}`"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } @@ -278,13 +406,11 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { format!("The type `{typ}` is currently unsupported in comptime expressions"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::InvalidValueForUnary { value, operator, location } => { - let msg = format!("`{}` cannot be used with unary {operator}", value.get_type()); + InterpreterError::InvalidValueForUnary { typ, operator, location } => { + let msg = format!("`{typ}` cannot be used with unary {operator}"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } InterpreterError::InvalidValuesForBinary { lhs, rhs, operator, location } => { - let lhs = lhs.get_type(); - let rhs = rhs.get_type(); let msg = format!("No implementation for `{lhs}` {operator} `{rhs}`",); CustomDiagnostic::simple_error(msg, String::new(), location.span) } @@ -300,10 +426,9 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let msg = format!("`{typ}` is not a struct type"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::CannotInlineMacro { value, location } => { - let typ = value.get_type(); + InterpreterError::CannotInlineMacro { value, typ, location } => { let msg = format!("Cannot inline values of type `{typ}` into this position"); - let secondary = format!("Cannot inline value {value:?}"); + let secondary = format!("Cannot inline value `{value}`"); CustomDiagnostic::simple_error(msg, secondary, location.span) } InterpreterError::UnquoteFoundDuringEvaluation { location } => { @@ -314,7 +439,7 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { InterpreterError::DebugEvaluateComptime { diagnostic, .. } => diagnostic.clone(), InterpreterError::FailedToParseMacro { error, tokens, rule, file: _ } => { let message = format!("Failed to parse macro's token stream into {rule}"); - let tokens = vecmap(&tokens.0, ToString::to_string).join(" "); + let tokens = vecmap(tokens.iter(), ToString::to_string).join(" "); // 10 is an aribtrary number of tokens here chosen to fit roughly onto one line let token_stream = if tokens.len() > 10 { @@ -383,6 +508,14 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { InterpreterError::NoMatchingImplFound { error, .. } => error.into(), InterpreterError::Break => unreachable!("Uncaught InterpreterError::Break"), InterpreterError::Continue => unreachable!("Uncaught InterpreterError::Continue"), + InterpreterError::TraitDefinitionMustBeAPath { location } => { + let msg = "Trait definition arguments must be a variable or path".to_string(); + CustomDiagnostic::simple_error(msg, String::new(), location.span) + } + InterpreterError::FailedToResolveTraitDefinition { location } => { + let msg = "Failed to resolve to a trait definition".to_string(); + CustomDiagnostic::simple_error(msg, String::new(), location.span) + } } } } diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 8f3f1295cac..888fa17d0af 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -258,8 +258,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { pub(super) fn enter_function(&mut self) -> (bool, Vec>) { // Drain every scope except the global scope let mut scope = Vec::new(); - if self.elaborator.comptime_scopes.len() > 1 { - scope = self.elaborator.comptime_scopes.drain(1..).collect(); + if self.elaborator.interner.comptime_scopes.len() > 1 { + scope = self.elaborator.interner.comptime_scopes.drain(1..).collect(); } self.push_scope(); (std::mem::take(&mut self.in_loop), scope) @@ -269,21 +269,21 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { self.in_loop = state.0; // Keep only the global scope - self.elaborator.comptime_scopes.truncate(1); - self.elaborator.comptime_scopes.append(&mut state.1); + self.elaborator.interner.comptime_scopes.truncate(1); + self.elaborator.interner.comptime_scopes.append(&mut state.1); } pub(super) fn push_scope(&mut self) { - self.elaborator.comptime_scopes.push(HashMap::default()); + self.elaborator.interner.comptime_scopes.push(HashMap::default()); } pub(super) fn pop_scope(&mut self) { - self.elaborator.comptime_scopes.pop(); + self.elaborator.interner.comptime_scopes.pop(); } fn current_scope_mut(&mut self) -> &mut HashMap { // the global scope is always at index zero, so this is always Some - self.elaborator.comptime_scopes.last_mut().unwrap() + self.elaborator.interner.comptime_scopes.last_mut().unwrap() } fn unbind_generics_from_previous_function(&mut self) { @@ -351,7 +351,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Ok(()) } (value, _) => { - Err(InterpreterError::TypeMismatch { expected: typ.clone(), value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected: typ.clone(), actual, location }) } }, HirPattern::Struct(struct_type, pattern_fields, _) => { @@ -362,7 +363,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { for (field_name, field_pattern) in pattern_fields { let field = fields.get(&field_name.0.contents).ok_or_else(|| { InterpreterError::ExpectedStructToHaveField { - value: Value::Struct(fields.clone(), struct_type.clone()), + typ: struct_type.clone(), field_name: field_name.0.contents.clone(), location, } @@ -380,7 +381,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } value => Err(InterpreterError::TypeMismatch { expected: typ.clone(), - value, + actual: value.get_type().into_owned(), location, }), }; @@ -402,7 +403,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { return Ok(()); } - for scope in self.elaborator.comptime_scopes.iter_mut().rev() { + for scope in self.elaborator.interner.comptime_scopes.iter_mut().rev() { if let Entry::Occupied(mut entry) = scope.entry(id) { entry.insert(argument); return Ok(()); @@ -416,7 +417,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } pub fn lookup_id(&self, id: DefinitionId, location: Location) -> IResult { - for scope in self.elaborator.comptime_scopes.iter().rev() { + for scope in self.elaborator.interner.comptime_scopes.iter().rev() { if let Some(value) = scope.get(&id) { return Ok(value.clone()); } @@ -569,7 +570,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { consuming = false; if let Some(value) = values.pop_front() { - result.push_str(&value.to_string()); + result.push_str(&value.display(self.elaborator.interner).to_string()); } } other if !consuming => { @@ -769,7 +770,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { value => { let location = self.elaborator.interner.expr_location(&id); let operator = "minus"; - Err(InterpreterError::InvalidValueForUnary { value, location, operator }) + let typ = value.get_type().into_owned(); + Err(InterpreterError::InvalidValueForUnary { typ, location, operator }) } }, UnaryOp::Not => match rhs { @@ -784,7 +786,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Value::U64(value) => Ok(Value::U64(!value)), value => { let location = self.elaborator.interner.expr_location(&id); - Err(InterpreterError::InvalidValueForUnary { value, location, operator: "not" }) + let typ = value.get_type().into_owned(); + Err(InterpreterError::InvalidValueForUnary { typ, location, operator: "not" }) } }, UnaryOp::MutableReference => { @@ -800,7 +803,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Value::Pointer(element, _) => Ok(element.borrow().clone()), value => { let location = self.elaborator.interner.expr_location(&id); - Err(InterpreterError::NonPointerDereferenced { value, location }) + let typ = value.get_type().into_owned(); + Err(InterpreterError::NonPointerDereferenced { typ, location }) } }, } @@ -814,6 +818,13 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { return self.evaluate_overloaded_infix(infix, lhs, rhs, id); } + let make_error = |this: &mut Self, lhs: Value, rhs: Value, operator| { + let location = this.elaborator.interner.expr_location(&id); + let lhs = lhs.get_type().into_owned(); + let rhs = rhs.get_type().into_owned(); + Err(InvalidValuesForBinary { lhs, rhs, location, operator }) + }; + use InterpreterError::InvalidValuesForBinary; match infix.operator.kind { BinaryOpKind::Add => match (lhs, rhs) { @@ -826,10 +837,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs + rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs + rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs + rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "+" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "+"), }, BinaryOpKind::Subtract => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs - rhs)), @@ -841,10 +849,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs - rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs - rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs - rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "-" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "-"), }, BinaryOpKind::Multiply => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs * rhs)), @@ -856,10 +861,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs * rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs * rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs * rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "*" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "*"), }, BinaryOpKind::Divide => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs / rhs)), @@ -871,10 +873,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs / rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs / rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs / rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "/" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "/"), }, BinaryOpKind::Equal => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs == rhs)), @@ -886,10 +885,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs == rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "==" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "=="), }, BinaryOpKind::NotEqual => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs != rhs)), @@ -901,10 +897,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs != rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "!=" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "!="), }, BinaryOpKind::Less => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs < rhs)), @@ -916,10 +909,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs < rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "<"), }, BinaryOpKind::LessEqual => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs <= rhs)), @@ -931,10 +921,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs <= rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<=" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "<="), }, BinaryOpKind::Greater => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs > rhs)), @@ -946,10 +933,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs > rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, ">"), }, BinaryOpKind::GreaterEqual => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs >= rhs)), @@ -961,10 +945,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs >= rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">=" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, ">="), }, BinaryOpKind::And => match (lhs, rhs) { (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs & rhs)), @@ -976,10 +957,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs & rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs & rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs & rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "&" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "&"), }, BinaryOpKind::Or => match (lhs, rhs) { (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs | rhs)), @@ -991,10 +969,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs | rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs | rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs | rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "|" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "|"), }, BinaryOpKind::Xor => match (lhs, rhs) { (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs ^ rhs)), @@ -1006,10 +981,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs ^ rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs ^ rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs ^ rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "^" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "^"), }, BinaryOpKind::ShiftRight => match (lhs, rhs) { (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs >> rhs)), @@ -1020,10 +992,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs >> rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs >> rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs >> rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">>" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, ">>"), }, BinaryOpKind::ShiftLeft => match (lhs, rhs) { (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs << rhs)), @@ -1034,10 +1003,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs << rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs << rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs << rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<<" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "<<"), }, BinaryOpKind::Modulo => match (lhs, rhs) { (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs % rhs)), @@ -1048,10 +1014,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs % rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs % rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs % rhs)), - (lhs, rhs) => { - let location = self.elaborator.interner.expr_location(&id); - Err(InvalidValuesForBinary { lhs, rhs, location, operator: "%" }) - } + (lhs, rhs) => make_error(self, lhs, rhs, "%"), }, } } @@ -1155,7 +1118,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Value::Array(array, _) => array, Value::Slice(array, _) => array, value => { - return Err(InterpreterError::NonArrayIndexed { value, location }); + let typ = value.get_type().into_owned(); + return Err(InterpreterError::NonArrayIndexed { typ, location }); } }; @@ -1175,7 +1139,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Value::U32(value) => value as usize, Value::U64(value) => value as usize, value => { - return Err(InterpreterError::NonIntegerUsedAsIndex { value, location }); + let typ = value.get_type().into_owned(); + return Err(InterpreterError::NonIntegerUsedAsIndex { typ, location }); } }; @@ -1222,7 +1187,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } value => { let location = self.elaborator.interner.expr_location(&id); - return Err(InterpreterError::NonTupleOrStructInMemberAccess { value, location }); + let typ = value.get_type().into_owned(); + return Err(InterpreterError::NonTupleOrStructInMemberAccess { typ, location }); } }; @@ -1230,7 +1196,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let location = self.elaborator.interner.expr_location(&id); let value = Value::Struct(fields, struct_type); let field_name = access.rhs.0.contents; - InterpreterError::ExpectedStructToHaveField { value, field_name, location } + let typ = value.get_type().into_owned(); + InterpreterError::ExpectedStructToHaveField { typ, field_name, location } }) } @@ -1255,7 +1222,10 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Ok(result) } Value::Closure(closure, env, _) => self.call_closure(closure, env, arguments, location), - value => Err(InterpreterError::NonFunctionCalled { value, location }), + value => { + let typ = value.get_type().into_owned(); + Err(InterpreterError::NonFunctionCalled { typ, location }) + } } } @@ -1331,7 +1301,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } value => { let location = interner.expr_location(&id); - return Err(InterpreterError::NonNumericCasted { value, location }); + let typ = value.get_type().into_owned(); + return Err(InterpreterError::NonNumericCasted { typ, location }); } }; @@ -1396,7 +1367,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Value::Bool(value) => value, value => { let location = self.elaborator.interner.expr_location(&id); - return Err(InterpreterError::NonBoolUsedInIf { value, location }); + let typ = value.get_type().into_owned(); + return Err(InterpreterError::NonBoolUsedInIf { typ, location }); } }; @@ -1436,8 +1408,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { fn evaluate_quote(&mut self, mut tokens: Tokens, expr_id: ExprId) -> IResult { let location = self.elaborator.interner.expr_location(&expr_id); - tokens = self.substitute_unquoted_values_into_tokens(tokens, location)?; - Ok(Value::Code(Rc::new(tokens))) + let tokens = self.substitute_unquoted_values_into_tokens(tokens, location)?; + Ok(Value::Quoted(Rc::new(tokens))) } pub fn evaluate_statement(&mut self, statement: StmtId) -> IResult { @@ -1474,11 +1446,14 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Value::Bool(false) => { let location = self.elaborator.interner.expr_location(&constrain.0); let message = constrain.2.and_then(|expr| self.evaluate(expr).ok()); + let message = + message.map(|value| value.display(self.elaborator.interner).to_string()); Err(InterpreterError::FailingConstraint { location, message }) } value => { let location = self.elaborator.interner.expr_location(&constrain.0); - Err(InterpreterError::NonBoolUsedInConstrain { value, location }) + let typ = value.get_type().into_owned(); + Err(InterpreterError::NonBoolUsedInConstrain { typ, location }) } } } @@ -1498,7 +1473,10 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { *value.borrow_mut() = rhs; Ok(()) } - value => Err(InterpreterError::NonPointerDereferenced { value, location }), + value => { + let typ = value.get_type().into_owned(); + Err(InterpreterError::NonPointerDereferenced { typ, location }) + } } } HirLValue::MemberAccess { object, field_name, field_index, typ: _, location } => { @@ -1507,7 +1485,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let index = field_index.ok_or_else(|| { let value = object_value.clone(); let field_name = field_name.to_string(); - InterpreterError::ExpectedStructToHaveField { value, field_name, location } + let typ = value.get_type().into_owned(); + InterpreterError::ExpectedStructToHaveField { typ, field_name, location } })?; match object_value { @@ -1520,7 +1499,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { self.store_lvalue(*object, Value::Struct(fields, typ.follow_bindings())) } value => { - Err(InterpreterError::NonTupleOrStructInMemberAccess { value, location }) + let typ = value.get_type().into_owned(); + Err(InterpreterError::NonTupleOrStructInMemberAccess { typ, location }) } } } @@ -1552,7 +1532,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { match self.evaluate_lvalue(lvalue)? { Value::Pointer(value, _) => Ok(value.borrow().clone()), value => { - Err(InterpreterError::NonPointerDereferenced { value, location: *location }) + let typ = value.get_type().into_owned(); + Err(InterpreterError::NonPointerDereferenced { typ, location: *location }) } } } @@ -1563,14 +1544,15 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let value = object_value.clone(); let field_name = field_name.to_string(); let location = *location; - InterpreterError::ExpectedStructToHaveField { value, field_name, location } + let typ = value.get_type().into_owned(); + InterpreterError::ExpectedStructToHaveField { typ, field_name, location } })?; match object_value { Value::Tuple(mut values) => Ok(values.swap_remove(index)), Value::Struct(fields, _) => Ok(fields[&field_name.0.contents].clone()), value => Err(InterpreterError::NonTupleOrStructInMemberAccess { - value, + typ: value.get_type().into_owned(), location: *location, }), } @@ -1598,7 +1580,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Value::U64(value) => Ok((value as i128, |i| Value::U64(i as u64))), value => { let location = this.elaborator.interner.expr_location(&expr); - Err(InterpreterError::NonIntegerUsedInLoop { value, location }) + let typ = value.get_type().into_owned(); + Err(InterpreterError::NonIntegerUsedInLoop { typ, location }) } } }; @@ -1652,9 +1635,9 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let print_newline = arguments[0].0 == Value::Bool(true); if print_newline { - println!("{}", arguments[1].0); + println!("{}", arguments[1].0.display(self.elaborator.interner)); } else { - print!("{}", arguments[1].0); + print!("{}", arguments[1].0.display(self.elaborator.interner)); } Ok(Value::Unit) diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 645e4d707c0..76f85740195 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -11,11 +11,11 @@ use rustc_hash::FxHashMap as HashMap; use crate::{ ast::IntegerBitSize, - hir::comptime::{errors::IResult, InterpreterError, Value}, + hir::comptime::{errors::IResult, value::add_token_spans, InterpreterError, Value}, macros_api::{NodeInterner, Signedness}, node_interner::TraitId, parser, - token::{SpannedToken, Token, Tokens}, + token::Token, QuotedType, Shared, Type, }; @@ -53,6 +53,8 @@ impl<'local, 'context> Interpreter<'local, 'context> { "trait_def_as_trait_constraint" => { trait_def_as_trait_constraint(interner, arguments, location) } + "trait_def_eq" => trait_def_eq(interner, arguments, location), + "trait_def_hash" => trait_def_hash(interner, arguments, location), "quoted_as_trait_constraint" => quoted_as_trait_constraint(self, arguments, location), "quoted_as_type" => quoted_as_type(self, arguments, location), "type_eq" => type_eq(arguments, location), @@ -80,8 +82,7 @@ pub(super) fn check_argument_count( } fn failing_constraint(message: impl Into, location: Location) -> IResult { - let message = Some(Value::String(Rc::new(message.into()))); - Err(InterpreterError::FailingConstraint { message, location }) + Err(InterpreterError::FailingConstraint { message: Some(message.into()), location }) } pub(super) fn get_array( @@ -94,7 +95,8 @@ pub(super) fn get_array( value => { let type_var = Box::new(interner.next_type_variable()); let expected = Type::Array(type_var.clone(), type_var); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } @@ -109,7 +111,8 @@ fn get_slice( value => { let type_var = Box::new(interner.next_type_variable()); let expected = Type::Slice(type_var); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } @@ -118,7 +121,8 @@ pub(super) fn get_field(value: Value, location: Location) -> IResult Ok(value), value => { - Err(InterpreterError::TypeMismatch { expected: Type::FieldElement, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected: Type::FieldElement, actual, location }) } } } @@ -128,7 +132,8 @@ pub(super) fn get_u32(value: Value, location: Location) -> IResult { Value::U32(value) => Ok(value), value => { let expected = Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } @@ -138,7 +143,8 @@ fn get_trait_constraint(value: Value, location: Location) -> IResult<(TraitId, V Value::TraitConstraint(trait_id, generics) => Ok((trait_id, generics)), value => { let expected = Type::Quoted(QuotedType::TraitConstraint); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } @@ -148,17 +154,19 @@ fn get_trait_def(value: Value, location: Location) -> IResult { Value::TraitDefinition(id) => Ok(id), value => { let expected = Type::Quoted(QuotedType::TraitDefinition); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } -fn get_quoted(value: Value, location: Location) -> IResult> { +fn get_quoted(value: Value, location: Location) -> IResult>> { match value { - Value::Code(tokens) => Ok(tokens), + Value::Quoted(tokens) => Ok(tokens), value => { let expected = Type::Quoted(QuotedType::Quoted); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } @@ -175,7 +183,8 @@ fn array_len( value => { let type_var = Box::new(interner.next_type_variable()); let expected = Type::Array(type_var.clone(), type_var); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } @@ -193,7 +202,8 @@ fn as_slice( value => { let type_var = Box::new(interner.next_type_variable()); let expected = Type::Array(type_var.clone(), type_var); - Err(InterpreterError::TypeMismatch { expected, value, location }) + let actual = value.get_type().into_owned(); + Err(InterpreterError::TypeMismatch { expected, actual, location }) } } } @@ -223,7 +233,8 @@ fn struct_def_as_type( Value::StructDefinition(id) => id, value => { let expected = Type::Quoted(QuotedType::StructDefinition); - return Err(InterpreterError::TypeMismatch { expected, location, value }); + let actual = value.get_type().into_owned(); + return Err(InterpreterError::TypeMismatch { expected, location, actual }); } }; @@ -246,11 +257,12 @@ fn struct_def_generics( ) -> IResult { check_argument_count(1, &arguments, location)?; - let (struct_def, span) = match arguments.pop().unwrap() { - (Value::StructDefinition(id), location) => (id, location.span), + let struct_def = match arguments.pop().unwrap().0 { + Value::StructDefinition(id) => id, value => { let expected = Type::Quoted(QuotedType::StructDefinition); - return Err(InterpreterError::TypeMismatch { expected, location, value: value.0 }); + let actual = value.get_type().into_owned(); + return Err(InterpreterError::TypeMismatch { expected, location, actual }); } }; @@ -258,8 +270,8 @@ fn struct_def_generics( let struct_def = struct_def.borrow(); let generics = struct_def.generics.iter().map(|generic| { - let name = SpannedToken::new(Token::Ident(generic.type_var.borrow().to_string()), span); - Value::Code(Rc::new(Tokens(vec![name]))) + let name = Token::Ident(generic.type_var.borrow().to_string()); + Value::Quoted(Rc::new(vec![name])) }); let typ = Type::Slice(Box::new(Type::Quoted(QuotedType::Quoted))); @@ -275,24 +287,22 @@ fn struct_def_fields( ) -> IResult { check_argument_count(1, &arguments, location)?; - let (struct_def, span) = match arguments.pop().unwrap() { - (Value::StructDefinition(id), location) => (id, location.span), + let struct_def = match arguments.pop().unwrap().0 { + Value::StructDefinition(id) => id, value => { let expected = Type::Quoted(QuotedType::StructDefinition); - return Err(InterpreterError::TypeMismatch { expected, location, value: value.0 }); + let actual = value.get_type().into_owned(); + return Err(InterpreterError::TypeMismatch { expected, location, actual }); } }; let struct_def = interner.get_struct(struct_def); let struct_def = struct_def.borrow(); - let make_token = |name| SpannedToken::new(Token::Ident(name), span); - let make_quoted = |tokens| Value::Code(Rc::new(Tokens(tokens))); - let mut fields = im::Vector::new(); for (name, typ) in struct_def.get_fields_as_written() { - let name = make_quoted(vec![make_token(name)]); + let name = Value::Quoted(Rc::new(vec![Token::Ident(name)])); let typ = Value::Type(typ); fields.push_back(Value::Tuple(vec![name, typ])); } @@ -394,7 +404,7 @@ fn quoted_as_trait_constraint( check_argument_count(1, &arguments, location)?; let tokens = get_quoted(arguments.pop().unwrap().0, location)?; - let quoted = tokens.as_ref().clone(); + let quoted = add_token_spans(tokens.clone(), location.span); let trait_bound = parser::trait_bound().parse(quoted).map_err(|mut errors| { let error = errors.swap_remove(0); @@ -420,7 +430,7 @@ fn quoted_as_type( check_argument_count(1, &arguments, location)?; let tokens = get_quoted(arguments.pop().unwrap().0, location)?; - let quoted = tokens.as_ref().clone(); + let quoted = add_token_spans(tokens.clone(), location.span); let typ = parser::parse_type().parse(quoted).map_err(|mut errors| { let error = errors.swap_remove(0); @@ -481,6 +491,37 @@ fn trait_constraint_eq( Ok(Value::Bool(constraint_a == constraint_b)) } +// fn trait_def_hash(def: TraitDefinition) -> Field +fn trait_def_hash( + _interner: &mut NodeInterner, + mut arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + check_argument_count(1, &arguments, location)?; + + let id = get_trait_def(arguments.pop().unwrap().0, location)?; + + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + id.hash(&mut hasher); + let hash = hasher.finish(); + + Ok(Value::Field((hash as u128).into())) +} + +// fn trait_def_eq(def_a: TraitDefinition, def_b: TraitDefinition) -> bool +fn trait_def_eq( + _interner: &mut NodeInterner, + mut arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + check_argument_count(2, &arguments, location)?; + + let id_b = get_trait_def(arguments.pop().unwrap().0, location)?; + let id_a = get_trait_def(arguments.pop().unwrap().0, location)?; + + Ok(Value::Bool(id_a == id_b)) +} + // fn zeroed() -> T fn zeroed(return_type: Type) -> IResult { match return_type { diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs index c80d39f8df8..c7b1532c9b7 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/unquote.rs @@ -15,20 +15,20 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { &mut self, tokens: Tokens, location: Location, - ) -> IResult { + ) -> IResult> { let mut new_tokens = Vec::with_capacity(tokens.0.len()); for token in tokens.0 { - match token.token() { + match token.into_token() { Token::UnquoteMarker(id) => { - let value = self.evaluate(*id)?; + let value = self.evaluate(id)?; let tokens = value.into_tokens(self.elaborator.interner, location)?; - new_tokens.extend(tokens.0); + new_tokens.extend(tokens); } - _ => new_tokens.push(token), + token => new_tokens.push(token), } } - Ok(Tokens(new_tokens)) + Ok(new_tokens) } } diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 0959e4c17ac..61475fa60e1 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -4,7 +4,7 @@ use acvm::{AcirField, FieldElement}; use chumsky::Parser; use im::Vector; use iter_extended::{try_vecmap, vecmap}; -use noirc_errors::Location; +use noirc_errors::{Location, Span}; use crate::{ ast::{ArrayLiteral, ConstructorExpression, Ident, IntegerBitSize, Signedness}, @@ -46,7 +46,10 @@ pub enum Value { Pointer(Shared, /* auto_deref */ bool), Array(Vector, Type), Slice(Vector, Type), - Code(Rc), + /// Quoted tokens don't have spans because otherwise inserting them in the middle of other + /// tokens can cause larger spans to be before lesser spans, causing an assert. They may also + /// be inserted into separate files entirely. + Quoted(Rc>), StructDefinition(StructId), TraitConstraint(TraitId, /* trait generics */ Vec), TraitDefinition(TraitId), @@ -84,7 +87,7 @@ impl Value { Value::Struct(_, typ) => return Cow::Borrowed(typ), Value::Array(_, typ) => return Cow::Borrowed(typ), Value::Slice(_, typ) => return Cow::Borrowed(typ), - Value::Code(_) => Type::Quoted(QuotedType::Quoted), + Value::Quoted(_) => Type::Quoted(QuotedType::Quoted), Value::StructDefinition(_) => Type::Quoted(QuotedType::StructDefinition), Value::Pointer(element, auto_deref) => { if *auto_deref { @@ -204,9 +207,9 @@ impl Value { try_vecmap(elements, |element| element.into_expression(interner, location))?; ExpressionKind::Literal(Literal::Slice(ArrayLiteral::Standard(elements))) } - Value::Code(tokens) => { + Value::Quoted(tokens) => { // Wrap the tokens in '{' and '}' so that we can parse statements as well. - let mut tokens_to_parse = tokens.as_ref().clone(); + let mut tokens_to_parse = add_token_spans(tokens.clone(), location.span); tokens_to_parse.0.insert(0, SpannedToken::new(Token::LeftBrace, location.span)); tokens_to_parse.0.push(SpannedToken::new(Token::RightBrace, location.span)); @@ -228,7 +231,9 @@ impl Value { | Value::Zeroed(_) | Value::Type(_) | Value::ModuleDefinition(_) => { - return Err(InterpreterError::CannotInlineMacro { value: self, location }) + let typ = self.get_type().into_owned(); + let value = self.display(interner).to_string(); + return Err(InterpreterError::CannotInlineMacro { typ, value, location }); } }; @@ -339,7 +344,7 @@ impl Value { })?; HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements))) } - Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)), + Value::Quoted(tokens) => HirExpression::Unquote(add_token_spans(tokens, location.span)), Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(..) @@ -348,7 +353,9 @@ impl Value { | Value::Zeroed(_) | Value::Type(_) | Value::ModuleDefinition(_) => { - return Err(InterpreterError::CannotInlineMacro { value: self, location }) + let typ = self.get_type().into_owned(); + let value = self.display(interner).to_string(); + return Err(InterpreterError::CannotInlineMacro { value, typ, location }); } }; @@ -362,13 +369,13 @@ impl Value { self, interner: &mut NodeInterner, location: Location, - ) -> IResult { + ) -> IResult> { let token = match self { - Value::Code(tokens) => return Ok(unwrap_rc(tokens)), + Value::Quoted(tokens) => return Ok(unwrap_rc(tokens)), Value::Type(typ) => Token::QuotedType(interner.push_quoted_type(typ)), other => Token::UnquoteMarker(other.into_hir_expression(interner, location)?), }; - Ok(Tokens(vec![SpannedToken::new(token, location.span)])) + Ok(vec![token]) } /// Converts any unsigned `Value` into a `u128`. @@ -391,12 +398,24 @@ impl Value { pub(crate) fn into_top_level_items( self, location: Location, + interner: &NodeInterner, ) -> IResult> { match self { - Value::Code(tokens) => parse_tokens(tokens, parser::top_level_items(), location.file), - value => Err(InterpreterError::CannotInlineMacro { value, location }), + Value::Quoted(tokens) => parse_tokens(tokens, parser::top_level_items(), location), + _ => { + let typ = self.get_type().into_owned(); + let value = self.display(interner).to_string(); + Err(InterpreterError::CannotInlineMacro { value, typ, location }) + } } } + + pub fn display<'value, 'interner>( + &'value self, + interner: &'interner NodeInterner, + ) -> ValuePrinter<'value, 'interner> { + ValuePrinter { value: self, interner } + } } /// Unwraps an Rc value without cloning the inner value if the reference count is 1. Clones otherwise. @@ -404,20 +423,35 @@ pub(crate) fn unwrap_rc(rc: Rc) -> T { Rc::try_unwrap(rc).unwrap_or_else(|rc| (*rc).clone()) } -fn parse_tokens(tokens: Rc, parser: impl NoirParser, file: fm::FileId) -> IResult { - match parser.parse(tokens.as_ref().clone()) { +fn parse_tokens( + tokens: Rc>, + parser: impl NoirParser, + location: Location, +) -> IResult { + match parser.parse(add_token_spans(tokens.clone(), location.span)) { Ok(expr) => Ok(expr), Err(mut errors) => { let error = errors.swap_remove(0); let rule = "an expression"; + let file = location.file; Err(InterpreterError::FailedToParseMacro { error, file, tokens, rule }) } } } -impl Display for Value { +pub(crate) fn add_token_spans(tokens: Rc>, span: Span) -> Tokens { + let tokens = unwrap_rc(tokens); + Tokens(vecmap(tokens, |token| SpannedToken::new(token, span))) +} + +pub struct ValuePrinter<'value, 'interner> { + value: &'value Value, + interner: &'interner NodeInterner, +} + +impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { + match self.value { Value::Unit => write!(f, "()"), Value::Bool(value) => { let msg = if *value { "true" } else { "false" }; @@ -438,7 +472,7 @@ impl Display for Value { Value::Function(..) => write!(f, "(function)"), Value::Closure(_, _, _) => write!(f, "(closure)"), Value::Tuple(fields) => { - let fields = vecmap(fields, ToString::to_string); + let fields = vecmap(fields, |field| field.display(self.interner).to_string()); write!(f, "({})", fields.join(", ")) } Value::Struct(fields, typ) => { @@ -446,29 +480,48 @@ impl Display for Value { Type::Struct(def, _) => def.borrow().name.to_string(), other => other.to_string(), }; - let fields = vecmap(fields, |(name, value)| format!("{}: {}", name, value)); + let fields = vecmap(fields, |(name, value)| { + format!("{}: {}", name, value.display(self.interner)) + }); write!(f, "{typename} {{ {} }}", fields.join(", ")) } - Value::Pointer(value, _) => write!(f, "&mut {}", value.borrow()), + Value::Pointer(value, _) => write!(f, "&mut {}", value.borrow().display(self.interner)), Value::Array(values, _) => { - let values = vecmap(values, ToString::to_string); + let values = vecmap(values, |value| value.display(self.interner).to_string()); write!(f, "[{}]", values.join(", ")) } Value::Slice(values, _) => { - let values = vecmap(values, ToString::to_string); + let values = vecmap(values, |value| value.display(self.interner).to_string()); write!(f, "&[{}]", values.join(", ")) } - Value::Code(tokens) => { + Value::Quoted(tokens) => { write!(f, "quote {{")?; - for token in tokens.0.iter() { + for token in tokens.iter() { write!(f, " {token}")?; } write!(f, " }}") } - Value::StructDefinition(_) => write!(f, "(struct definition)"), - Value::TraitConstraint { .. } => write!(f, "(trait constraint)"), - Value::TraitDefinition(_) => write!(f, "(trait definition)"), - Value::FunctionDefinition(_) => write!(f, "(function definition)"), + Value::StructDefinition(id) => { + let def = self.interner.get_struct(*id); + let def = def.borrow(); + write!(f, "{}", def.name) + } + Value::TraitConstraint(trait_id, generics) => { + let trait_ = self.interner.get_trait(*trait_id); + let generic_string = vecmap(generics, ToString::to_string).join(", "); + if generics.is_empty() { + write!(f, "{}", trait_.name) + } else { + write!(f, "{}<{generic_string}>", trait_.name) + } + } + Value::TraitDefinition(trait_id) => { + let trait_ = self.interner.get_trait(*trait_id); + write!(f, "{}", trait_.name) + } + Value::FunctionDefinition(function_id) => { + write!(f, "{}", self.interner.function_name(function_id)) + } Value::ModuleDefinition(_) => write!(f, "(module)"), Value::Zeroed(typ) => write!(f, "(zeroed {typ})"), Value::Type(typ) => write!(f, "{}", typ), diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index bf6de746791..1cc1abfa495 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -114,8 +114,6 @@ pub enum ResolverError { MacroIsNotComptime { span: Span }, #[error("Annotation name must refer to a comptime function")] NonFunctionInAnnotation { span: Span }, - #[error("Unknown annotation")] - UnknownAnnotation { span: Span }, } impl ResolverError { @@ -460,13 +458,6 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) }, - ResolverError::UnknownAnnotation { span } => { - Diagnostic::simple_warning( - "Unknown annotation".into(), - "No matching comptime function found in scope".into(), - *span, - ) - }, } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 0ec975a04db..fc1af63540a 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -2092,6 +2092,13 @@ impl Type { Type::Forall(_, typ) => typ.replace_named_generics_with_type_variables(), } } + + pub fn slice_element_type(&self) -> Option<&Type> { + match self { + Type::Slice(element) => Some(element), + _ => None, + } + } } /// Wraps a given `expression` in `expression.as_slice()` diff --git a/compiler/noirc_frontend/src/lexer/errors.rs b/compiler/noirc_frontend/src/lexer/errors.rs index 387ced05258..be5180a777b 100644 --- a/compiler/noirc_frontend/src/lexer/errors.rs +++ b/compiler/noirc_frontend/src/lexer/errors.rs @@ -1,3 +1,4 @@ +use crate::hir::def_collector::dc_crate::CompilationError; use crate::parser::ParserError; use crate::parser::ParserErrorReason; use crate::token::SpannedToken; @@ -42,6 +43,12 @@ impl From for ParserError { } } +impl From for CompilationError { + fn from(error: LexerErrorKind) -> Self { + ParserError::from(error).into() + } +} + impl LexerErrorKind { pub fn span(&self) -> Span { match self { diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index c6a1d44f26b..2284991bbc0 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -635,6 +635,10 @@ impl Attributes { pub fn is_no_predicates(&self) -> bool { self.function.as_ref().map_or(false, |func_attribute| func_attribute.is_no_predicates()) } + + pub fn is_varargs(&self) -> bool { + self.secondary.iter().any(|attr| matches!(attr, SecondaryAttribute::Varargs)) + } } /// An Attribute can be either a Primary Attribute or a Secondary Attribute @@ -728,6 +732,7 @@ impl Attribute { name.trim_matches('"').to_string().into(), )) } + ["varargs"] => Attribute::Secondary(SecondaryAttribute::Varargs), tokens => { tokens.iter().try_for_each(|token| validate(token))?; Attribute::Secondary(SecondaryAttribute::Custom(word.to_owned())) @@ -825,6 +830,9 @@ pub enum SecondaryAttribute { Field(String), Custom(String), Abi(String), + + /// A variable-argument comptime function. + Varargs, } impl fmt::Display for SecondaryAttribute { @@ -839,6 +847,7 @@ impl fmt::Display for SecondaryAttribute { SecondaryAttribute::Export => write!(f, "#[export]"), SecondaryAttribute::Field(ref k) => write!(f, "#[field({k})]"), SecondaryAttribute::Abi(ref k) => write!(f, "#[abi({k})]"), + SecondaryAttribute::Varargs => write!(f, "#[varargs]"), } } } @@ -867,6 +876,7 @@ impl AsRef for SecondaryAttribute { | SecondaryAttribute::Abi(string) => string, SecondaryAttribute::ContractLibraryMethod => "", SecondaryAttribute::Export => "", + SecondaryAttribute::Varargs => "", } } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 06428fe2b01..ce26b38b639 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::marker::Copy; @@ -12,6 +11,7 @@ use noirc_errors::{Location, Span, Spanned}; use petgraph::algo::tarjan_scc; use petgraph::prelude::DiGraph; use petgraph::prelude::NodeIndex as PetGraphIndex; +use rustc_hash::FxHashMap as HashMap; use crate::ast::Ident; use crate::graph::CrateId; @@ -230,6 +230,14 @@ pub struct NodeInterner { // The module where each reference is // (ReferenceId::Reference and ReferenceId::Local aren't included here) pub(crate) reference_modules: HashMap, + + /// Each value currently in scope in the comptime interpreter. + /// Each element of the Vec represents a scope with every scope together making + /// up all currently visible definitions. The first scope is always the global scope. + /// + /// This is stored in the NodeInterner so that the Elaborator from each crate can + /// share the same global values. + pub(crate) comptime_scopes: Vec>, } /// A dependency in the dependency graph may be a type or a definition. @@ -556,44 +564,45 @@ impl Default for NodeInterner { fn default() -> Self { NodeInterner { nodes: Arena::default(), - func_meta: HashMap::new(), - function_definition_ids: HashMap::new(), - function_modifiers: HashMap::new(), - function_modules: HashMap::new(), - module_attributes: HashMap::new(), - func_id_to_trait: HashMap::new(), + func_meta: HashMap::default(), + function_definition_ids: HashMap::default(), + function_modifiers: HashMap::default(), + function_modules: HashMap::default(), + module_attributes: HashMap::default(), + func_id_to_trait: HashMap::default(), dependency_graph: petgraph::graph::DiGraph::new(), - dependency_graph_indices: HashMap::new(), - id_to_location: HashMap::new(), + dependency_graph_indices: HashMap::default(), + id_to_location: HashMap::default(), definitions: vec![], - id_to_type: HashMap::new(), - definition_to_type: HashMap::new(), - structs: HashMap::new(), - struct_attributes: HashMap::new(), + id_to_type: HashMap::default(), + definition_to_type: HashMap::default(), + structs: HashMap::default(), + struct_attributes: HashMap::default(), type_aliases: Vec::new(), - traits: HashMap::new(), - trait_implementations: HashMap::new(), + traits: HashMap::default(), + trait_implementations: HashMap::default(), next_trait_implementation_id: 0, - trait_implementation_map: HashMap::new(), - selected_trait_implementations: HashMap::new(), - infix_operator_traits: HashMap::new(), - prefix_operator_traits: HashMap::new(), + trait_implementation_map: HashMap::default(), + selected_trait_implementations: HashMap::default(), + infix_operator_traits: HashMap::default(), + prefix_operator_traits: HashMap::default(), ordering_type: None, - instantiation_bindings: HashMap::new(), - field_indices: HashMap::new(), + instantiation_bindings: HashMap::default(), + field_indices: HashMap::default(), next_type_variable_id: std::cell::Cell::new(0), globals: Vec::new(), - global_attributes: HashMap::new(), - struct_methods: HashMap::new(), - primitive_methods: HashMap::new(), + global_attributes: HashMap::default(), + struct_methods: HashMap::default(), + primitive_methods: HashMap::default(), type_alias_ref: Vec::new(), type_ref_locations: Vec::new(), quoted_types: Default::default(), lsp_mode: false, location_indices: LocationIndices::default(), reference_graph: petgraph::graph::DiGraph::new(), - reference_graph_indices: HashMap::new(), - reference_modules: HashMap::new(), + reference_graph_indices: HashMap::default(), + reference_modules: HashMap::default(), + comptime_scopes: vec![HashMap::default()], } } } diff --git a/noir_stdlib/src/cmp.nr b/noir_stdlib/src/cmp.nr index bdd5e2bc5ec..d2cf6b3836a 100644 --- a/noir_stdlib/src/cmp.nr +++ b/noir_stdlib/src/cmp.nr @@ -1,9 +1,37 @@ +use crate::meta::derive_via; + +#[derive_via(derive_eq)] // docs:start:eq-trait trait Eq { fn eq(self, other: Self) -> bool; } // docs:end:eq-trait +comptime fn derive_eq(s: StructDefinition) -> Quoted { + let typ = s.as_type(); + + let impl_generics = s.generics().join(quote {,}); + + let where_clause = s.generics().map(|name| quote { $name: Default }).join(quote {,}); + + // `(self.a == other.a) & (self.b == other.b) & ...` + let equalities = s.fields().map( + |f: (Quoted, Type)| { + let name = f.0; + quote { (self.$name == other.$name) } + } + ); + let body = equalities.join(quote { & }); + + quote { + impl<$impl_generics> Eq for $typ where $where_clause { + fn eq(self, other: Self) -> bool { + $body + } + } + } +} + impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } } impl Eq for u64 { fn eq(self, other: u64) -> bool { self == other } } diff --git a/noir_stdlib/src/default.nr b/noir_stdlib/src/default.nr index 0acb3966034..f0d98205a90 100644 --- a/noir_stdlib/src/default.nr +++ b/noir_stdlib/src/default.nr @@ -1,9 +1,37 @@ +use crate::meta::derive_via; + +#[derive_via(derive_default)] // docs:start:default-trait trait Default { fn default() -> Self; } // docs:end:default-trait +comptime fn derive_default(s: StructDefinition) -> Quoted { + let typ = s.as_type(); + + let impl_generics = s.generics().join(quote {,}); + + let where_clause = s.generics().map(|name| quote { $name: Default }).join(quote {,}); + + // `foo: Default::default(), bar: Default::default(), ...` + let fields = s.fields().map( + |f: (Quoted, Type)| { + let name = f.0; + quote { $name: Default::default() } + } + ); + let fields = fields.join(quote {,}); + + quote { + impl<$impl_generics> Default for $typ where $where_clause { + fn default() -> Self { + Self { $fields } + } + } + } +} + impl Default for Field { fn default() -> Field { 0 } } impl Default for u8 { fn default() -> u8 { 0 } } diff --git a/noir_stdlib/src/hash/poseidon2.nr b/noir_stdlib/src/hash/poseidon2.nr index 08cf68d1f82..9626da0cf97 100644 --- a/noir_stdlib/src/hash/poseidon2.nr +++ b/noir_stdlib/src/hash/poseidon2.nr @@ -1,7 +1,7 @@ use crate::hash::Hasher; use crate::default::Default; -global RATE: u32 = 3; +comptime global RATE: u32 = 3; struct Poseidon2 { cache: [Field;3], diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index 351e128fa9a..7ed5e3ff44f 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -1,7 +1,11 @@ +use crate::collections::umap::UHashMap; +use crate::hash::BuildHasherDefault; +use crate::hash::poseidon2::Poseidon2Hasher; + +mod struct_def; mod trait_constraint; mod trait_def; mod typ; -mod type_def; mod quoted; /// Calling unquote as a macro (via `unquote!(arg)`) will unquote @@ -11,6 +15,29 @@ pub comptime fn unquote(code: Quoted) -> Quoted { code } +/// Returns the type of any value #[builtin(type_of)] pub comptime fn type_of(x: T) -> Type {} +type DeriveFunction = fn(StructDefinition) -> Quoted; + +comptime mut global HANDLERS: UHashMap> = UHashMap::default(); + +#[varargs] +pub comptime fn derive(s: StructDefinition, traits: [TraitDefinition]) -> Quoted { + let mut result = quote {}; + + for trait_to_derive in traits { + let handler = HANDLERS.get(trait_to_derive); + assert(handler.is_some(), f"No derive function registered for `{trait_to_derive}`"); + + let trait_impl = handler.unwrap()(s); + result = quote { $result $trait_impl }; + } + + result +} + +unconstrained pub comptime fn derive_via(t: TraitDefinition, f: DeriveFunction) { + HANDLERS.insert(t, f); +} diff --git a/noir_stdlib/src/meta/type_def.nr b/noir_stdlib/src/meta/struct_def.nr similarity index 100% rename from noir_stdlib/src/meta/type_def.nr rename to noir_stdlib/src/meta/struct_def.nr diff --git a/noir_stdlib/src/meta/trait_def.nr b/noir_stdlib/src/meta/trait_def.nr index 5de7631e34d..ca381cb8e16 100644 --- a/noir_stdlib/src/meta/trait_def.nr +++ b/noir_stdlib/src/meta/trait_def.nr @@ -1,4 +1,25 @@ +use crate::hash::{Hash, Hasher}; +use crate::cmp::Eq; + impl TraitDefinition { #[builtin(trait_def_as_trait_constraint)] fn as_trait_constraint(_self: Self) -> TraitConstraint {} } + +impl Eq for TraitDefinition { + fn eq(self, other: Self) -> bool { + trait_def_eq(self, other) + } +} + +impl Hash for TraitDefinition { + fn hash(self, state: &mut H) where H: Hasher { + state.write(trait_def_hash(self)); + } +} + +#[builtin(trait_def_eq)] +fn trait_def_eq(_first: TraitDefinition, _second: TraitDefinition) -> bool {} + +#[builtin(trait_def_hash)] +fn trait_def_hash(_def: TraitDefinition) -> Field {} diff --git a/noir_stdlib/src/prelude.nr b/noir_stdlib/src/prelude.nr index 3244329aa4b..0d423e3556d 100644 --- a/noir_stdlib/src/prelude.nr +++ b/noir_stdlib/src/prelude.nr @@ -6,3 +6,4 @@ use crate::uint128::U128; use crate::cmp::{Eq, Ord}; use crate::default::Default; use crate::convert::{From, Into}; +use crate::meta::{derive, derive_via}; diff --git a/test_programs/compile_success_empty/attribute_args/src/main.nr b/test_programs/compile_success_empty/attribute_args/src/main.nr index 44b9c20460f..6178df5e749 100644 --- a/test_programs/compile_success_empty/attribute_args/src/main.nr +++ b/test_programs/compile_success_empty/attribute_args/src/main.nr @@ -1,9 +1,9 @@ -#[attr_with_args(a b, c d)] -#[varargs(one, two)] -#[varargs(one, two, three, four)] +#[attr_with_args(1, 2)] +#[varargs(1, 2)] +#[varargs(1, 2, 3, 4)] struct Foo {} -comptime fn attr_with_args(s: StructDefinition, a: Quoted, b: Quoted) { +comptime fn attr_with_args(s: StructDefinition, a: Field, b: Field) { // Ensure all variables are in scope. // We can't print them since that breaks the test runner. let _ = s; @@ -11,7 +11,8 @@ comptime fn attr_with_args(s: StructDefinition, a: Quoted, b: Quoted) { let _ = b; } -comptime fn varargs(s: StructDefinition, t: [Quoted]) { +#[varargs] +comptime fn varargs(s: StructDefinition, t: [Field]) { let _ = s; for _ in t {} assert(t.len() < 5); diff --git a/test_programs/execution_success/derive/src/main.nr b/test_programs/execution_success/derive/src/main.nr index f226817fbaf..e4148f2c944 100644 --- a/test_programs/execution_success/derive/src/main.nr +++ b/test_programs/execution_success/derive/src/main.nr @@ -1,37 +1,11 @@ -use std::collections::umap::UHashMap; -use std::hash::BuildHasherDefault; -use std::hash::poseidon2::Poseidon2Hasher; - -#[my_derive(DoNothing)] -struct MyStruct { my_field: u32 } - -type DeriveFunction = fn(StructDefinition) -> Quoted; - -comptime mut global HANDLERS: UHashMap> = UHashMap::default(); - -comptime fn my_derive(s: StructDefinition, traits: [Quoted]) -> Quoted { - let mut result = quote {}; - - for trait_to_derive in traits { - let handler = HANDLERS.get(trait_to_derive.as_trait_constraint()); - assert(handler.is_some(), f"No derive function registered for `{trait_to_derive}`"); - - let trait_impl = handler.unwrap()(s); - result = quote { $result $trait_impl }; - } - - result -} - -unconstrained comptime fn my_derive_via(t: TraitDefinition, f: Quoted) { - HANDLERS.insert(t.as_trait_constraint(), std::meta::unquote!(f)); -} - -#[my_derive_via(derive_do_nothing)] +#[derive_via(derive_do_nothing)] trait DoNothing { fn do_nothing(self); } +#[derive(DoNothing)] +struct MyStruct { my_field: u32 } + comptime fn derive_do_nothing(s: StructDefinition) -> Quoted { let typ = s.as_type(); let generics = s.generics().join(quote {,}); @@ -45,7 +19,17 @@ comptime fn derive_do_nothing(s: StructDefinition) -> Quoted { } } +// Test stdlib derive fns & multiple traits +#[derive(Eq, Default)] +struct MyOtherStruct { + field1: u32, + field2: u64, +} + fn main() { let s = MyStruct { my_field: 1 }; s.do_nothing(); + + let o = MyOtherStruct::default(); + assert_eq(o, o); }