diff --git a/Cargo.lock b/Cargo.lock index f78fbfede27..cd936e4bca2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,7 @@ version = "0.33.0" dependencies = [ "acvm", "convert_case 0.6.0", + "im", "iter-extended", "noirc_errors", "noirc_frontend", @@ -3004,6 +3005,7 @@ version = "0.33.0" dependencies = [ "acvm", "bn254_blackbox_solver", + "cfg-if 1.0.0", "chrono", "fxhash", "im", @@ -3012,6 +3014,7 @@ dependencies = [ "noirc_frontend", "num-bigint", "proptest", + "rayon", "serde", "serde_json", "serde_with", diff --git a/Cargo.toml b/Cargo.toml index 52cb1012b71..a903ef6fec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,7 +130,7 @@ criterion = "0.5.0" # https://github.com/tikv/pprof-rs/pull/172 pprof = { version = "0.13", features = ["flamegraph", "criterion"] } - +cfg-if = "1.0.0" dirs = "4" serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0" @@ -154,6 +154,7 @@ color-eyre = "0.6.2" rand = "0.8.5" proptest = "1.2.0" proptest-derive = "0.4.0" +rayon = "1.8.0" im = { version = "15.1", features = ["serde"] } tracing = "0.1.40" diff --git a/acvm-repo/acir_field/Cargo.toml b/acvm-repo/acir_field/Cargo.toml index c1cffc1334e..acc34457bc9 100644 --- a/acvm-repo/acir_field/Cargo.toml +++ b/acvm-repo/acir_field/Cargo.toml @@ -24,7 +24,7 @@ ark-bn254.workspace = true ark-bls12-381 = { workspace = true, optional = true } ark-ff.workspace = true -cfg-if = "1.0.0" +cfg-if.workspace = true [dev-dependencies] proptest.workspace = true diff --git a/aztec_macros/Cargo.toml b/aztec_macros/Cargo.toml index c9d88e36e28..258379cd7b8 100644 --- a/aztec_macros/Cargo.toml +++ b/aztec_macros/Cargo.toml @@ -18,5 +18,6 @@ noirc_frontend.workspace = true noirc_errors.workspace = true iter-extended.workspace = true convert_case = "0.6.0" +im.workspace = true regex = "1.10" tiny-keccak = { version = "2.0.0", features = ["keccak"] } diff --git a/aztec_macros/src/utils/parse_utils.rs b/aztec_macros/src/utils/parse_utils.rs index 4c6cbb10d9f..f2998fbaafc 100644 --- a/aztec_macros/src/utils/parse_utils.rs +++ b/aztec_macros/src/utils/parse_utils.rs @@ -218,7 +218,10 @@ fn empty_statement(statement: &mut Statement) { StatementKind::For(for_loop_statement) => empty_for_loop_statement(for_loop_statement), StatementKind::Comptime(statement) => empty_statement(statement), StatementKind::Semi(expression) => empty_expression(expression), - StatementKind::Break | StatementKind::Continue | StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -271,12 +274,15 @@ fn empty_expression(expression: &mut Expression) { ExpressionKind::Unsafe(block_expression, _span) => { empty_block_expression(block_expression); } - ExpressionKind::Quote(..) | ExpressionKind::Resolved(_) | ExpressionKind::Error => (), ExpressionKind::AsTraitPath(path) => { empty_unresolved_type(&mut path.typ); empty_path(&mut path.trait_path); empty_ident(&mut path.impl_item); } + ExpressionKind::Quote(..) + | ExpressionKind::Resolved(_) + | ExpressionKind::Interned(_) + | ExpressionKind::Error => (), } } @@ -353,6 +359,7 @@ fn empty_unresolved_type(unresolved_type: &mut UnresolvedType) { | UnresolvedTypeData::Unit | UnresolvedTypeData::Quoted(_) | UnresolvedTypeData::Resolved(_) + | UnresolvedTypeData::Interned(_) | UnresolvedTypeData::Unspecified | UnresolvedTypeData::Error => (), } @@ -531,6 +538,7 @@ fn empty_lvalue(lvalue: &mut LValue) { empty_expression(index); } LValue::Dereference(lvalue, _) => empty_lvalue(lvalue), + LValue::Interned(..) => (), } } diff --git a/compiler/noirc_driver/Cargo.toml b/compiler/noirc_driver/Cargo.toml index b244018cc71..6b200e79b89 100644 --- a/compiler/noirc_driver/Cargo.toml +++ b/compiler/noirc_driver/Cargo.toml @@ -29,3 +29,7 @@ rust-embed.workspace = true tracing.workspace = true aztec_macros = { path = "../../aztec_macros" } + +[features] +bn254 = ["noirc_frontend/bn254", "noirc_evaluator/bn254"] +bls12_381 = ["noirc_frontend/bls12_381", "noirc_evaluator/bls12_381"] diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index cb3a4d25c9d..f95c9de7c2c 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -131,6 +131,18 @@ pub struct CompileOptions { pub skip_underconstrained_check: bool, } +#[derive(Clone, Debug, Default)] +pub struct CheckOptions { + pub compile_options: CompileOptions, + pub error_on_unused_imports: bool, +} + +impl CheckOptions { + pub fn new(compile_options: &CompileOptions, error_on_unused_imports: bool) -> Self { + Self { compile_options: compile_options.clone(), error_on_unused_imports } + } +} + pub fn parse_expression_width(input: &str) -> Result { use std::io::{Error, ErrorKind}; let width = input @@ -278,8 +290,10 @@ pub fn add_dep( pub fn check_crate( context: &mut Context, crate_id: CrateId, - options: &CompileOptions, + check_options: &CheckOptions, ) -> CompilationResult<()> { + let options = &check_options.compile_options; + let macros: &[&dyn MacroProcessor] = if options.disable_macros { &[] } else { &[&aztec_macros::AztecMacro] }; @@ -289,6 +303,7 @@ pub fn check_crate( context, options.debug_comptime_in_file.as_deref(), options.arithmetic_generics, + check_options.error_on_unused_imports, macros, ); errors.extend(diagnostics.into_iter().map(|(error, file_id)| { @@ -322,7 +337,10 @@ pub fn compile_main( options: &CompileOptions, cached_program: Option, ) -> CompilationResult { - let (_, mut warnings) = check_crate(context, crate_id, options)?; + let error_on_unused_imports = true; + let check_options = CheckOptions::new(options, error_on_unused_imports); + + let (_, mut warnings) = check_crate(context, crate_id, &check_options)?; let main = context.get_main_function(&crate_id).ok_or_else(|| { // TODO(#2155): This error might be a better to exist in Nargo @@ -357,7 +375,9 @@ pub fn compile_contract( crate_id: CrateId, options: &CompileOptions, ) -> CompilationResult { - let (_, warnings) = check_crate(context, crate_id, options)?; + let error_on_unused_imports = true; + let check_options = CheckOptions::new(options, error_on_unused_imports); + let (_, warnings) = check_crate(context, crate_id, &check_options)?; // TODO: We probably want to error if contracts is empty let contracts = context.get_all_contracts(&crate_id); diff --git a/compiler/noirc_errors/src/reporter.rs b/compiler/noirc_errors/src/reporter.rs index 3ce0f268715..b21dc759f14 100644 --- a/compiler/noirc_errors/src/reporter.rs +++ b/compiler/noirc_errors/src/reporter.rs @@ -46,7 +46,7 @@ impl CustomDiagnostic { ) -> CustomDiagnostic { CustomDiagnostic { message: primary_message, - secondaries: vec![CustomLabel::new(secondary_message, secondary_span)], + secondaries: vec![CustomLabel::new(secondary_message, secondary_span, None)], notes: Vec::new(), kind, } @@ -98,7 +98,7 @@ impl CustomDiagnostic { ) -> CustomDiagnostic { CustomDiagnostic { message: primary_message, - secondaries: vec![CustomLabel::new(secondary_message, secondary_span)], + secondaries: vec![CustomLabel::new(secondary_message, secondary_span, None)], notes: Vec::new(), kind: DiagnosticKind::Bug, } @@ -113,7 +113,11 @@ impl CustomDiagnostic { } pub fn add_secondary(&mut self, message: String, span: Span) { - self.secondaries.push(CustomLabel::new(message, span)); + self.secondaries.push(CustomLabel::new(message, span, None)); + } + + pub fn add_secondary_with_file(&mut self, message: String, span: Span, file: fm::FileId) { + self.secondaries.push(CustomLabel::new(message, span, Some(file))); } pub fn is_error(&self) -> bool { @@ -153,11 +157,12 @@ impl std::fmt::Display for CustomDiagnostic { pub struct CustomLabel { pub message: String, pub span: Span, + pub file: Option, } impl CustomLabel { - fn new(message: String, span: Span) -> CustomLabel { - CustomLabel { message, span } + fn new(message: String, span: Span, file: Option) -> CustomLabel { + CustomLabel { message, span, file } } } @@ -234,7 +239,8 @@ fn convert_diagnostic( .map(|sl| { let start_span = sl.span.start() as usize; let end_span = sl.span.end() as usize; - Label::secondary(file_id, start_span..end_span).with_message(&sl.message) + let file = sl.file.unwrap_or(file_id); + Label::secondary(file, start_span..end_span).with_message(&sl.message) }) .collect() } else { diff --git a/compiler/noirc_evaluator/Cargo.toml b/compiler/noirc_evaluator/Cargo.toml index 81feb0b7154..1db6af2ae85 100644 --- a/compiler/noirc_evaluator/Cargo.toml +++ b/compiler/noirc_evaluator/Cargo.toml @@ -26,6 +26,12 @@ serde_json.workspace = true serde_with = "3.2.0" tracing.workspace = true chrono = "0.4.37" +rayon.workspace = true +cfg-if.workspace = true [dev-dependencies] -proptest.workspace = true \ No newline at end of file +proptest.workspace = true + +[features] +bn254 = ["noirc_frontend/bn254"] +bls12_381= [] diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index 54dbd4716a9..eeaa60b4323 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -24,7 +24,7 @@ use acvm::{acir::AcirField, FieldElement}; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use iter_extended::vecmap; use num_bigint::BigUint; -use std::rc::Rc; +use std::sync::Arc; use super::brillig_black_box::convert_black_box_call; use super::brillig_block_variables::BlockVariables; @@ -1643,7 +1643,7 @@ impl<'block> BrilligBlock<'block> { fn initialize_constant_array_runtime( &mut self, - item_types: Rc>, + item_types: Arc>, item_to_repeat: Vec, item_count: usize, pointer: MemoryAddress, diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs index 88f7e35865e..112f9c908f9 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs @@ -276,20 +276,14 @@ impl BrilligContext< let index_at_end_of_array = self.allocate_register(); let end_value_register = self.allocate_register(); - self.codegen_loop(iteration_count, |ctx, iterator_register| { - // Load both values - ctx.codegen_array_get(pointer, iterator_register, start_value_register); + self.mov_instruction(index_at_end_of_array, vector.size); + self.codegen_loop(iteration_count, |ctx, iterator_register| { // The index at the end of array is size - 1 - iterator - ctx.mov_instruction(index_at_end_of_array, size); ctx.codegen_usize_op_in_place(index_at_end_of_array, BrilligBinaryOp::Sub, 1); - ctx.memory_op_instruction( - index_at_end_of_array, - iterator_register.address, - index_at_end_of_array, - BrilligBinaryOp::Sub, - ); + // Load both values + ctx.codegen_array_get(vector.pointer, iterator_register, start_value_register); ctx.codegen_array_get( pointer, SingleAddrVariable::new_usize(index_at_end_of_array), diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index a27354d7cb6..d3d936385b0 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -10,7 +10,6 @@ use crate::ssa::ir::{instruction::Endian, types::NumericType}; use acvm::acir::circuit::brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs}; use acvm::acir::circuit::opcodes::{AcirFunctionId, BlockId, BlockType, MemOp}; use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, ExpressionWidth, Opcode}; -use acvm::blackbox_solver; use acvm::brillig_vm::{MemoryValue, VMStatus, VM}; use acvm::{ acir::AcirField, @@ -1459,7 +1458,6 @@ impl AcirContext { name, BlackBoxFunc::MultiScalarMul | BlackBoxFunc::Keccakf1600 - | BlackBoxFunc::Sha256Compression | BlackBoxFunc::Blake2s | BlackBoxFunc::Blake3 | BlackBoxFunc::AND @@ -2152,7 +2150,11 @@ fn execute_brillig( } // Instantiate a Brillig VM given the solved input registers and memory, along with the Brillig bytecode. - let mut vm = VM::new(calldata, code, Vec::new(), &blackbox_solver::StubbedBlackBoxSolver); + // + // We pass a stubbed solver here as a concrete solver implies a field choice which conflicts with this function + // being generic. + let solver = acvm::blackbox_solver::StubbedBlackBoxSolver; + let mut vm = VM::new(calldata, code, Vec::new(), &solver); // Run the Brillig VM on these inputs, bytecode, etc! let vm_status = vm.process_opcodes(); diff --git a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs index 79db4e645ee..26eab290d4b 100644 --- a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs +++ b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs @@ -1,8 +1,6 @@ //! This module defines an SSA pass that detects if the final function has any subgraphs independent from inputs and outputs. //! If this is the case, then part of the final circuit can be completely replaced by any other passing circuit, since there are no constraints ensuring connections. //! So the compiler informs the developer of this as a bug -use im::HashMap; - use crate::errors::{InternalBug, SsaReport}; use crate::ssa::ir::basic_block::BasicBlockId; use crate::ssa::ir::function::RuntimeType; @@ -10,25 +8,29 @@ use crate::ssa::ir::function::{Function, FunctionId}; use crate::ssa::ir::instruction::{Instruction, InstructionId, Intrinsic}; use crate::ssa::ir::value::{Value, ValueId}; use crate::ssa::ssa_gen::Ssa; +use im::HashMap; +use rayon::prelude::*; use std::collections::{BTreeMap, HashSet}; impl Ssa { /// Go through each top-level non-brillig function and detect if it has independent subgraphs #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn check_for_underconstrained_values(&mut self) -> Vec { - let mut warnings: Vec = Vec::new(); - for function in self.functions.values() { - match function.runtime() { - RuntimeType::Acir { .. } => { - warnings.extend(check_for_underconstrained_values_within_function( - function, + let functions_id = self.functions.values().map(|f| f.id().to_usize()).collect::>(); + functions_id + .iter() + .par_bridge() + .flat_map(|fid| { + let function_to_process = &self.functions[&FunctionId::new(*fid)]; + match function_to_process.runtime() { + RuntimeType::Acir { .. } => check_for_underconstrained_values_within_function( + function_to_process, &self.functions, - )); + ), + RuntimeType::Brillig => Vec::new(), } - RuntimeType::Brillig => (), - } - } - warnings + }) + .collect() } } @@ -88,9 +90,8 @@ impl Context { self.visited_blocks.insert(block); self.connect_value_ids_in_block(function, block, all_functions); } - // Merge ValueIds into sets, where each original small set of ValueIds is merged with another set if they intersect - self.merge_sets(); + self.value_sets = Self::merge_sets_par(&self.value_sets); } /// Find sets that contain input or output value of the function @@ -267,14 +268,13 @@ impl Context { /// Merge all small sets into larger ones based on whether the sets intersect or not /// /// If two small sets have a common ValueId, we merge them into one - fn merge_sets(&mut self) { + fn merge_sets(current: &[HashSet]) -> Vec> { let mut new_set_id: usize = 0; let mut updated_sets: HashMap> = HashMap::new(); let mut value_dictionary: HashMap = HashMap::new(); let mut parsed_value_set: HashSet = HashSet::new(); - // Go through each set - for set in self.value_sets.iter() { + for set in current.iter() { // Check if the set has any of the ValueIds we've encountered at previous iterations let intersection: HashSet = set.intersection(&parsed_value_set).copied().collect(); @@ -327,7 +327,26 @@ impl Context { } updated_sets.insert(largest_set_index, largest_set); } - self.value_sets = updated_sets.values().cloned().collect(); + updated_sets.values().cloned().collect() + } + + /// Parallel version of merge_sets + /// The sets are merged by chunks, and then the chunks are merged together + fn merge_sets_par(sets: &[HashSet]) -> Vec> { + let mut sets = sets.to_owned(); + let mut len = sets.len(); + let mut prev_len = len + 1; + + while len > 1000 && len < prev_len { + sets = sets.par_chunks(1000).flat_map(Self::merge_sets).collect(); + + prev_len = len; + len = sets.len(); + } + // TODO: if prev_len >= len, this means we cannot effectively merge the sets anymore + // We should instead partition the sets into disjoint chunks and work on those chunks, + // but for now we fallback to the non-parallel implementation + Self::merge_sets(&sets) } } #[cfg(test)] diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs index 9f964cf048d..38895bb977e 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs @@ -1,5 +1,4 @@ -use std::collections::BTreeMap; -use std::rc::Rc; +use std::{collections::BTreeMap, sync::Arc}; use crate::ssa::ir::{types::Type, value::ValueId}; use acvm::FieldElement; @@ -155,8 +154,8 @@ impl FunctionBuilder { let len = databus.values.len(); let array = if len > 0 { - let array = - self.array_constant(databus.values, Type::Array(Rc::new(vec![Type::field()]), len)); + let array = self + .array_constant(databus.values, Type::Array(Arc::new(vec![Type::field()]), len)); Some(array) } else { None diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 8cc42241d92..bf6430c36d7 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod data_bus; -use std::{borrow::Cow, collections::BTreeMap, rc::Rc}; +use std::{borrow::Cow, collections::BTreeMap, sync::Arc}; use acvm::{acir::circuit::ErrorSelector, FieldElement}; use noirc_errors::Location; @@ -189,7 +189,7 @@ impl FunctionBuilder { /// given amount of field elements. Returns the result of the allocate instruction, /// which is always a Reference to the allocated data. pub(crate) fn insert_allocate(&mut self, element_type: Type) -> ValueId { - let reference_type = Type::Reference(Rc::new(element_type)); + let reference_type = Type::Reference(Arc::new(element_type)); self.insert_instruction(Instruction::Allocate, Some(vec![reference_type])).first() } @@ -516,7 +516,7 @@ impl std::ops::Index for FunctionBuilder { #[cfg(test)] mod tests { - use std::rc::Rc; + use std::sync::Arc; use acvm::{acir::AcirField, FieldElement}; @@ -542,7 +542,7 @@ mod tests { let to_bits_id = builder.import_intrinsic_id(Intrinsic::ToBits(Endian::Little)); let input = builder.numeric_constant(FieldElement::from(7_u128), Type::field()); let length = builder.numeric_constant(FieldElement::from(8_u128), Type::field()); - let result_types = vec![Type::Array(Rc::new(vec![Type::bool()]), 8)]; + let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), 8)]; let call_results = builder.insert_call(to_bits_id, vec![input, length], result_types).into_owned(); diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index c3cd27bf179..36069f17933 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -733,11 +733,15 @@ impl Instruction { } } + let then_value = dfg.resolve(*then_value); + let else_value = dfg.resolve(*else_value); + if then_value == else_value { + return SimplifiedTo(then_value); + } + if matches!(&typ, Type::Numeric(_)) { let then_condition = *then_condition; - let then_value = *then_value; let else_condition = *else_condition; - let else_value = *else_value; let result = ValueMerger::merge_numeric_values( dfg, diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index ea2523e873e..2c6aedeca35 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -1,7 +1,10 @@ use fxhash::FxHashMap as HashMap; -use std::{collections::VecDeque, rc::Rc}; +use std::{collections::VecDeque, sync::Arc}; -use acvm::{acir::AcirField, acir::BlackBoxFunc, BlackBoxResolutionError, FieldElement}; +use acvm::{ + acir::{AcirField, BlackBoxFunc}, + BlackBoxResolutionError, FieldElement, +}; use bn254_blackbox_solver::derive_generators; use iter_extended::vecmap; use num_bigint::BigUint; @@ -20,6 +23,8 @@ use crate::ssa::{ use super::{Binary, BinaryOp, Endian, Instruction, SimplifyResult}; +mod blackbox; + /// Try to simplify this call instruction. If the instruction can be simplified to a known value, /// that value is returned. Otherwise None is returned. /// @@ -468,11 +473,17 @@ fn simplify_black_box_func( arguments: &[ValueId], dfg: &mut DataFlowGraph, ) -> SimplifyResult { + cfg_if::cfg_if! { + if #[cfg(feature = "bn254")] { + let solver = bn254_blackbox_solver::Bn254BlackBoxSolver; + } else { + let solver = acvm::blackbox_solver::StubbedBlackBoxSolver; + } + }; match bb_func { BlackBoxFunc::SHA256 => simplify_hash(dfg, arguments, acvm::blackbox_solver::sha256), BlackBoxFunc::Blake2s => simplify_hash(dfg, arguments, acvm::blackbox_solver::blake2s), BlackBoxFunc::Blake3 => simplify_hash(dfg, arguments, acvm::blackbox_solver::blake3), - BlackBoxFunc::PedersenCommitment | BlackBoxFunc::PedersenHash => SimplifyResult::None, BlackBoxFunc::Keccakf1600 => { if let Some((array_input, _)) = dfg.get_array_constant(arguments[0]) { if array_is_constant(dfg, &array_input) { @@ -503,20 +514,26 @@ fn simplify_black_box_func( BlackBoxFunc::Keccak256 => { unreachable!("Keccak256 should have been replaced by calls to Keccakf1600") } - BlackBoxFunc::Poseidon2Permutation => SimplifyResult::None, //TODO(Guillaume) - BlackBoxFunc::EcdsaSecp256k1 => { - simplify_signature(dfg, arguments, acvm::blackbox_solver::ecdsa_secp256k1_verify) - } - BlackBoxFunc::EcdsaSecp256r1 => { - simplify_signature(dfg, arguments, acvm::blackbox_solver::ecdsa_secp256r1_verify) + BlackBoxFunc::Poseidon2Permutation => { + blackbox::simplify_poseidon2_permutation(dfg, solver, arguments) } + BlackBoxFunc::EcdsaSecp256k1 => blackbox::simplify_signature( + dfg, + arguments, + acvm::blackbox_solver::ecdsa_secp256k1_verify, + ), + BlackBoxFunc::EcdsaSecp256r1 => blackbox::simplify_signature( + dfg, + arguments, + acvm::blackbox_solver::ecdsa_secp256r1_verify, + ), + + BlackBoxFunc::PedersenCommitment + | BlackBoxFunc::PedersenHash + | BlackBoxFunc::MultiScalarMul => SimplifyResult::None, + BlackBoxFunc::EmbeddedCurveAdd => blackbox::simplify_ec_add(dfg, solver, arguments), + BlackBoxFunc::SchnorrVerify => blackbox::simplify_schnorr_verify(dfg, solver, arguments), - BlackBoxFunc::MultiScalarMul - | BlackBoxFunc::SchnorrVerify - | BlackBoxFunc::EmbeddedCurveAdd => { - // Currently unsolvable here as we rely on an implementation in the backend. - SimplifyResult::None - } BlackBoxFunc::BigIntAdd | BlackBoxFunc::BigIntSub | BlackBoxFunc::BigIntMul @@ -544,7 +561,7 @@ fn simplify_black_box_func( fn make_constant_array(dfg: &mut DataFlowGraph, results: Vec, typ: Type) -> ValueId { let result_constants = vecmap(results, |element| dfg.make_constant(element, typ.clone())); - let typ = Type::Array(Rc::new(vec![typ]), result_constants.len()); + let typ = Type::Array(Arc::new(vec![typ]), result_constants.len()); dfg.make_array(result_constants.into(), typ) } @@ -555,7 +572,7 @@ fn make_constant_slice( ) -> (ValueId, ValueId) { let result_constants = vecmap(results, |element| dfg.make_constant(element, typ.clone())); - let typ = Type::Slice(Rc::new(vec![typ])); + let typ = Type::Slice(Arc::new(vec![typ])); let length = FieldElement::from(result_constants.len() as u128); (dfg.make_constant(length, Type::length_type()), dfg.make_array(result_constants.into(), typ)) } diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs new file mode 100644 index 00000000000..7789b212e58 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs @@ -0,0 +1,190 @@ +use std::sync::Arc; + +use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement}; +use iter_extended::vecmap; + +use crate::ssa::ir::{ + dfg::DataFlowGraph, instruction::SimplifyResult, types::Type, value::ValueId, +}; + +use super::{array_is_constant, make_constant_array, to_u8_vec}; + +pub(super) fn simplify_ec_add( + dfg: &mut DataFlowGraph, + solver: impl BlackBoxFunctionSolver, + arguments: &[ValueId], +) -> SimplifyResult { + match ( + dfg.get_numeric_constant(arguments[0]), + dfg.get_numeric_constant(arguments[1]), + dfg.get_numeric_constant(arguments[2]), + dfg.get_numeric_constant(arguments[3]), + dfg.get_numeric_constant(arguments[4]), + dfg.get_numeric_constant(arguments[5]), + ) { + ( + Some(point1_x), + Some(point1_y), + Some(point1_is_infinity), + Some(point2_x), + Some(point2_y), + Some(point2_is_infinity), + ) => { + let Ok((result_x, result_y, result_is_infinity)) = solver.ec_add( + &point1_x, + &point1_y, + &point1_is_infinity, + &point2_x, + &point2_y, + &point2_is_infinity, + ) else { + return SimplifyResult::None; + }; + + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + + let typ = Type::Array(Arc::new(vec![Type::field()]), 3); + let result_array = + dfg.make_array(im::vector![result_x, result_y, result_is_infinity], typ); + + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } +} + +pub(super) fn simplify_poseidon2_permutation( + dfg: &mut DataFlowGraph, + solver: impl BlackBoxFunctionSolver, + arguments: &[ValueId], +) -> SimplifyResult { + match (dfg.get_array_constant(arguments[0]), dfg.get_numeric_constant(arguments[1])) { + (Some((state, _)), Some(state_length)) if array_is_constant(dfg, &state) => { + let state: Vec = state + .iter() + .map(|id| { + dfg.get_numeric_constant(*id) + .expect("value id from array should point at constant") + }) + .collect(); + + let Some(state_length) = state_length.try_to_u32() else { + return SimplifyResult::None; + }; + + let Ok(new_state) = solver.poseidon2_permutation(&state, state_length) else { + return SimplifyResult::None; + }; + + let result_array = make_constant_array(dfg, new_state, Type::field()); + + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } +} + +pub(super) fn simplify_schnorr_verify( + dfg: &mut DataFlowGraph, + solver: impl BlackBoxFunctionSolver, + arguments: &[ValueId], +) -> SimplifyResult { + match ( + dfg.get_numeric_constant(arguments[0]), + dfg.get_numeric_constant(arguments[1]), + dfg.get_array_constant(arguments[2]), + dfg.get_array_constant(arguments[3]), + ) { + (Some(public_key_x), Some(public_key_y), Some((signature, _)), Some((message, _))) + if array_is_constant(dfg, &signature) && array_is_constant(dfg, &message) => + { + let signature = to_u8_vec(dfg, signature); + let signature: [u8; 64] = + signature.try_into().expect("Compiler should produce correctly sized signature"); + + let message = to_u8_vec(dfg, message); + + let Ok(valid_signature) = + solver.schnorr_verify(&public_key_x, &public_key_y, &signature, &message) + else { + return SimplifyResult::None; + }; + + let valid_signature = dfg.make_constant(valid_signature.into(), Type::bool()); + SimplifyResult::SimplifiedTo(valid_signature) + } + _ => SimplifyResult::None, + } +} + +pub(super) fn simplify_hash( + dfg: &mut DataFlowGraph, + arguments: &[ValueId], + hash_function: fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, +) -> SimplifyResult { + match dfg.get_array_constant(arguments[0]) { + Some((input, _)) if array_is_constant(dfg, &input) => { + let input_bytes: Vec = to_u8_vec(dfg, input); + + let hash = hash_function(&input_bytes) + .expect("Rust solvable black box function should not fail"); + + let hash_values = vecmap(hash, |byte| FieldElement::from_be_bytes_reduce(&[byte])); + + let result_array = make_constant_array(dfg, hash_values, Type::unsigned(8)); + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } +} + +type ECDSASignatureVerifier = fn( + hashed_msg: &[u8], + public_key_x: &[u8; 32], + public_key_y: &[u8; 32], + signature: &[u8; 64], +) -> Result; + +pub(super) fn simplify_signature( + dfg: &mut DataFlowGraph, + arguments: &[ValueId], + signature_verifier: ECDSASignatureVerifier, +) -> SimplifyResult { + match ( + dfg.get_array_constant(arguments[0]), + dfg.get_array_constant(arguments[1]), + dfg.get_array_constant(arguments[2]), + dfg.get_array_constant(arguments[3]), + ) { + ( + Some((public_key_x, _)), + Some((public_key_y, _)), + Some((signature, _)), + Some((hashed_message, _)), + ) if array_is_constant(dfg, &public_key_x) + && array_is_constant(dfg, &public_key_y) + && array_is_constant(dfg, &signature) + && array_is_constant(dfg, &hashed_message) => + { + let public_key_x: [u8; 32] = to_u8_vec(dfg, public_key_x) + .try_into() + .expect("ECDSA public key fields are 32 bytes"); + let public_key_y: [u8; 32] = to_u8_vec(dfg, public_key_y) + .try_into() + .expect("ECDSA public key fields are 32 bytes"); + let signature: [u8; 64] = + to_u8_vec(dfg, signature).try_into().expect("ECDSA signatures are 64 bytes"); + let hashed_message: Vec = to_u8_vec(dfg, hashed_message); + + let valid_signature = + signature_verifier(&hashed_message, &public_key_x, &public_key_y, &signature) + .expect("Rust solvable black box function should not fail"); + + let valid_signature = dfg.make_constant(valid_signature.into(), Type::bool()); + SimplifyResult::SimplifiedTo(valid_signature) + } + _ => SimplifyResult::None, + } +} diff --git a/compiler/noirc_evaluator/src/ssa/ir/map.rs b/compiler/noirc_evaluator/src/ssa/ir/map.rs index f1265553b83..769d52e6e65 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/map.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/map.rs @@ -27,7 +27,7 @@ impl Id { /// Constructs a new Id for the given index. /// This constructor is deliberately private to prevent /// constructing invalid IDs. - fn new(index: usize) -> Self { + pub(crate) fn new(index: usize) -> Self { Self { index, _marker: std::marker::PhantomData } } diff --git a/compiler/noirc_evaluator/src/ssa/ir/types.rs b/compiler/noirc_evaluator/src/ssa/ir/types.rs index e467fa5400d..b7ee37ba17a 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/types.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/types.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use std::rc::Rc; +use std::sync::Arc; use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; @@ -72,13 +72,13 @@ pub(crate) enum Type { Numeric(NumericType), /// A reference to some value, such as an array - Reference(Rc), + Reference(Arc), /// An immutable array value with the given element type and length - Array(Rc, usize), + Array(Arc, usize), /// An immutable slice value with a given element type - Slice(Rc), + Slice(Arc), /// A function that may be called directly Function, @@ -112,7 +112,7 @@ impl Type { /// Creates the str type, of the given length N pub(crate) fn str(length: usize) -> Type { - Type::Array(Rc::new(vec![Type::char()]), length) + Type::Array(Arc::new(vec![Type::char()]), length) } /// Creates the native field type. @@ -190,7 +190,7 @@ impl Type { } } - pub(crate) fn element_types(self) -> Rc> { + pub(crate) fn element_types(self) -> Arc> { match self { Type::Array(element_types, _) | Type::Slice(element_types) => element_types, other => panic!("element_types: Expected array or slice, found {other}"), diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index c8f6d201d86..ff9a63c8d79 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -311,7 +311,7 @@ impl Context { #[cfg(test)] mod test { - use std::rc::Rc; + use std::sync::Arc; use crate::ssa::{ function_builder::FunctionBuilder, @@ -509,7 +509,7 @@ mod test { let one = builder.field_constant(1u128); let v1 = builder.insert_binary(v0, BinaryOp::Add, one); - let array_type = Type::Array(Rc::new(vec![Type::field()]), 1); + let array_type = Type::Array(Arc::new(vec![Type::field()]), 1); let arr = builder.current_function.dfg.make_array(vec![v1].into(), array_type); builder.terminate_with_return(vec![arr]); @@ -601,7 +601,7 @@ mod test { // Compiling main let mut builder = FunctionBuilder::new("main".into(), main_id); - let v0 = builder.add_parameter(Type::Array(Rc::new(vec![Type::field()]), 4)); + let v0 = builder.add_parameter(Type::Array(Arc::new(vec![Type::field()]), 4)); let v1 = builder.add_parameter(Type::unsigned(32)); let v2 = builder.add_parameter(Type::unsigned(1)); let v3 = builder.add_parameter(Type::unsigned(1)); @@ -737,7 +737,7 @@ mod test { let zero = builder.field_constant(0u128); let one = builder.field_constant(1u128); - let typ = Type::Array(Rc::new(vec![Type::field()]), 2); + let typ = Type::Array(Arc::new(vec![Type::field()]), 2); let array = builder.array_constant(vec![zero, one].into(), typ); let _v2 = builder.insert_array_get(array, v1, Type::field()); @@ -787,7 +787,7 @@ mod test { let v0 = builder.add_parameter(Type::bool()); let v1 = builder.add_parameter(Type::bool()); - let v2 = builder.add_parameter(Type::Array(Rc::new(vec![Type::field()]), 2)); + let v2 = builder.add_parameter(Type::Array(Arc::new(vec![Type::field()]), 2)); let zero = builder.numeric_constant(0u128, Type::length_type()); let one = builder.numeric_constant(1u128, Type::length_type()); diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 72ed02b00a8..d5fb98c7adc 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -878,7 +878,7 @@ impl<'f> Context<'f> { #[cfg(test)] mod test { - use std::rc::Rc; + use std::sync::Arc; use acvm::acir::AcirField; @@ -1016,7 +1016,7 @@ mod test { let b2 = builder.insert_block(); let v0 = builder.add_parameter(Type::bool()); - let v1 = builder.add_parameter(Type::Reference(Rc::new(Type::field()))); + let v1 = builder.add_parameter(Type::Reference(Arc::new(Type::field()))); builder.terminate_with_jmpif(v0, b1, b2); @@ -1078,7 +1078,7 @@ mod test { let b3 = builder.insert_block(); let v0 = builder.add_parameter(Type::bool()); - let v1 = builder.add_parameter(Type::Reference(Rc::new(Type::field()))); + let v1 = builder.add_parameter(Type::Reference(Arc::new(Type::field()))); builder.terminate_with_jmpif(v0, b1, b2); @@ -1477,7 +1477,7 @@ mod test { let b2 = builder.insert_block(); let b3 = builder.insert_block(); - let element_type = Rc::new(vec![Type::unsigned(8)]); + let element_type = Arc::new(vec![Type::unsigned(8)]); let array_type = Type::Array(element_type.clone(), 2); let array = builder.add_parameter(array_type); diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs index 7c2db62b0ea..75ee57dd4fa 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs @@ -58,6 +58,13 @@ impl<'a> ValueMerger<'a> { then_value: ValueId, else_value: ValueId, ) -> ValueId { + let then_value = self.dfg.resolve(then_value); + let else_value = self.dfg.resolve(else_value); + + if then_value == else_value { + return then_value; + } + match self.dfg.type_of_value(then_value) { Type::Numeric(_) => Self::merge_numeric_values( self.dfg, diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index e5a25dcfef1..9d6582c0db7 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -425,7 +425,7 @@ impl<'f> PerFunctionContext<'f> { #[cfg(test)] mod tests { - use std::rc::Rc; + use std::sync::Arc; use acvm::{acir::AcirField, FieldElement}; use im::vector; @@ -454,11 +454,11 @@ mod tests { let func_id = Id::test_new(0); let mut builder = FunctionBuilder::new("func".into(), func_id); - let v0 = builder.insert_allocate(Type::Array(Rc::new(vec![Type::field()]), 2)); + let v0 = builder.insert_allocate(Type::Array(Arc::new(vec![Type::field()]), 2)); let one = builder.field_constant(FieldElement::one()); let two = builder.field_constant(FieldElement::one()); - let element_type = Rc::new(vec![Type::field()]); + let element_type = Arc::new(vec![Type::field()]); let array_type = Type::Array(element_type, 2); let array = builder.array_constant(vector![one, two], array_type.clone()); @@ -672,7 +672,7 @@ mod tests { let zero = builder.field_constant(0u128); builder.insert_store(v0, zero); - let v2 = builder.insert_allocate(Type::Reference(Rc::new(Type::field()))); + let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field()))); builder.insert_store(v2, v0); let v3 = builder.insert_load(v2, Type::field()); diff --git a/compiler/noirc_evaluator/src/ssa/opt/rc.rs b/compiler/noirc_evaluator/src/ssa/opt/rc.rs index 1561547e32e..4f109a27874 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/rc.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/rc.rs @@ -161,7 +161,7 @@ fn remove_instructions(to_remove: HashSet, function: &mut Functio #[cfg(test)] mod test { - use std::rc::Rc; + use std::sync::Arc; use crate::ssa::{ function_builder::FunctionBuilder, @@ -209,14 +209,14 @@ mod test { let mut builder = FunctionBuilder::new("foo".into(), main_id); builder.set_runtime(RuntimeType::Brillig); - let inner_array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let inner_array_type = Type::Array(Arc::new(vec![Type::field()]), 2); let v0 = builder.add_parameter(inner_array_type.clone()); builder.insert_inc_rc(v0); builder.insert_inc_rc(v0); builder.insert_dec_rc(v0); - let outer_array_type = Type::Array(Rc::new(vec![inner_array_type]), 1); + let outer_array_type = Type::Array(Arc::new(vec![inner_array_type]), 1); let array = builder.array_constant(vec![v0].into(), outer_array_type); builder.terminate_with_return(vec![array]); @@ -248,7 +248,7 @@ mod test { let main_id = Id::test_new(0); let mut builder = FunctionBuilder::new("mutator".into(), main_id); - let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let array_type = Type::Array(Arc::new(vec![Type::field()]), 2); let v0 = builder.add_parameter(array_type.clone()); let v1 = builder.insert_allocate(array_type.clone()); @@ -297,8 +297,8 @@ mod test { let main_id = Id::test_new(0); let mut builder = FunctionBuilder::new("mutator2".into(), main_id); - let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); - let reference_type = Type::Reference(Rc::new(array_type.clone())); + let array_type = Type::Array(Arc::new(vec![Type::field()]), 2); + let reference_type = Type::Reference(Arc::new(array_type.clone())); let v0 = builder.add_parameter(reference_type); diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index 628e1bd7410..6ca7a76d740 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, rc::Rc}; +use std::{borrow::Cow, sync::Arc}; use acvm::{acir::AcirField, FieldElement}; @@ -174,7 +174,7 @@ impl Context<'_> { let to_bits = self.function.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little)); let length = self.field_constant(FieldElement::from(bit_size as i128)); let result_types = - vec![Type::field(), Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)]; + vec![Type::field(), Type::Array(Arc::new(vec![Type::bool()]), bit_size as usize)]; let rhs_bits = self.insert_call(to_bits, vec![rhs, length], result_types); let rhs_bits = rhs_bits[1]; diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index 13e5c2445ad..fb7091a8854 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -1,5 +1,4 @@ -use std::rc::Rc; -use std::sync::{Mutex, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; @@ -198,7 +197,7 @@ impl<'a> FunctionContext<'a> { // A mutable reference wraps each element into a reference. // This can be multiple values if the element type is a tuple. ast::Type::MutableReference(element) => { - Self::map_type_helper(element, &mut |typ| f(Type::Reference(Rc::new(typ)))) + Self::map_type_helper(element, &mut |typ| f(Type::Reference(Arc::new(typ)))) } ast::Type::FmtString(len, fields) => { // A format string is represented by multiple values @@ -213,7 +212,7 @@ impl<'a> FunctionContext<'a> { let element_types = Self::convert_type(elements).flatten(); Tree::Branch(vec![ Tree::Leaf(f(Type::length_type())), - Tree::Leaf(f(Type::Slice(Rc::new(element_types)))), + Tree::Leaf(f(Type::Slice(Arc::new(element_types)))), ]) } other => Tree::Leaf(f(Self::convert_non_tuple_type(other))), @@ -237,7 +236,7 @@ impl<'a> FunctionContext<'a> { ast::Type::Field => Type::field(), ast::Type::Array(len, element) => { let element_types = Self::convert_type(element).flatten(); - Type::Array(Rc::new(element_types), *len as usize) + Type::Array(Arc::new(element_types), *len as usize) } ast::Type::Integer(Signedness::Signed, bits) => Type::signed((*bits).into()), ast::Type::Integer(Signedness::Unsigned, bits) => Type::unsigned((*bits).into()), @@ -253,7 +252,7 @@ impl<'a> FunctionContext<'a> { ast::Type::MutableReference(element) => { // Recursive call to panic if element is a tuple let element = Self::convert_non_tuple_type(element); - Type::Reference(Rc::new(element)) + Type::Reference(Arc::new(element)) } } } diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index 7ef8870eaa8..c0f6c8965fb 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -27,7 +27,7 @@ num-traits.workspace = true rustc-hash = "1.1.0" small-ord-set = "0.1.3" regex = "1.9.1" -cfg-if = "1.0.0" +cfg-if.workspace = true tracing.workspace = true petgraph = "0.6" rangemap = "1.4.0" diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index dc07f55ee33..f242180134d 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -7,7 +7,7 @@ use crate::ast::{ }; use crate::hir::def_collector::errors::DefCollectorErrorKind; use crate::macros_api::StructId; -use crate::node_interner::{ExprId, QuotedTypeId}; +use crate::node_interner::{ExprId, InternedExpressionKind, QuotedTypeId}; use crate::token::{Attributes, FunctionAttribute, Token, Tokens}; use crate::{Kind, Type}; use acvm::{acir::AcirField, FieldElement}; @@ -43,6 +43,11 @@ pub enum ExpressionKind { // code. It is used to translate function values back into the AST while // guaranteeing they have the same instantiated type and definition id without resolving again. Resolved(ExprId), + + // This is an interned ExpressionKind during comptime code. + // The actual ExpressionKind can be retrieved with a NodeInterner. + Interned(InternedExpressionKind), + Error, } @@ -603,6 +608,7 @@ impl Display for ExpressionKind { Unsafe(block, _) => write!(f, "unsafe {block}"), Error => write!(f, "Error"), Resolved(_) => write!(f, "?Resolved"), + Interned(_) => write!(f, "?Interned"), Unquote(expr) => write!(f, "$({expr})"), Quote(tokens) => { let tokens = vecmap(&tokens.0, ToString::to_string); diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index 6f6d5cbccdf..b10e58aac0c 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -22,7 +22,7 @@ pub use traits::*; pub use type_alias::*; use crate::{ - node_interner::QuotedTypeId, + node_interner::{InternedUnresolvedTypeData, QuotedTypeId}, parser::{ParserError, ParserErrorReason}, token::IntType, BinaryTypeOperator, @@ -141,6 +141,10 @@ pub enum UnresolvedTypeData { /// as a result of being spliced into a macro's token stream input. Resolved(QuotedTypeId), + // This is an interned UnresolvedTypeData during comptime code. + // The actual UnresolvedTypeData can be retrieved with a NodeInterner. + Interned(InternedUnresolvedTypeData), + Unspecified, // This is for when the user declares a variable without specifying it's type Error, } @@ -297,6 +301,7 @@ impl std::fmt::Display for UnresolvedTypeData { Unspecified => write!(f, "unspecified"), Parenthesized(typ) => write!(f, "({typ})"), Resolved(_) => write!(f, "(resolved type)"), + Interned(_) => write!(f, "?Interned"), AsTraitPath(path) => write!(f, "{path}"), } } diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index edccf545a02..c88fcba749b 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -13,6 +13,7 @@ use super::{ use crate::elaborator::types::SELF_TYPE_NAME; use crate::lexer::token::SpannedToken; use crate::macros_api::{SecondaryAttribute, UnresolvedTypeData}; +use crate::node_interner::{InternedExpressionKind, InternedStatementKind}; use crate::parser::{ParserError, ParserErrorReason}; use crate::token::Token; @@ -45,6 +46,9 @@ pub enum StatementKind { Comptime(Box), // This is an expression with a trailing semi-colon Semi(Expression), + // This is an interned StatementKind during comptime code. + // The actual StatementKind can be retrieved with a NodeInterner. + Interned(InternedStatementKind), // This statement is the result of a recovered parse error. // To avoid issuing multiple errors in later steps, it should // be skipped in any future analysis if possible. @@ -97,6 +101,9 @@ impl StatementKind { // A semicolon on a for loop is optional and does nothing StatementKind::For(_) => self, + // No semicolon needed for a resolved statement + StatementKind::Interned(_) => self, + StatementKind::Expression(expr) => { match (&expr.kind, semi, last_statement_in_block) { // Semicolons are optional for these expressions @@ -534,6 +541,7 @@ pub enum LValue { MemberAccess { object: Box, field_name: Ident, span: Span }, Index { array: Box, index: Expression, span: Span }, Dereference(Box, Span), + Interned(InternedExpressionKind, Span), } #[derive(Debug, PartialEq, Eq, Clone)] @@ -591,7 +599,7 @@ impl Recoverable for Pattern { } impl LValue { - fn as_expression(&self) -> Expression { + pub fn as_expression(&self) -> Expression { let kind = match self { LValue::Ident(ident) => ExpressionKind::Variable(Path::from_ident(ident.clone())), LValue::MemberAccess { object, field_name, span: _ } => { @@ -612,17 +620,53 @@ impl LValue { rhs: lvalue.as_expression(), })) } + LValue::Interned(id, _) => ExpressionKind::Interned(*id), }; let span = self.span(); Expression::new(kind, span) } + pub fn from_expression(expr: Expression) -> LValue { + LValue::from_expression_kind(expr.kind, expr.span) + } + + pub fn from_expression_kind(expr: ExpressionKind, span: Span) -> LValue { + match expr { + ExpressionKind::Variable(path) => LValue::Ident(path.as_ident().unwrap().clone()), + ExpressionKind::MemberAccess(member_access) => LValue::MemberAccess { + object: Box::new(LValue::from_expression(member_access.lhs)), + field_name: member_access.rhs, + span, + }, + ExpressionKind::Index(index) => LValue::Index { + array: Box::new(LValue::from_expression(index.collection)), + index: index.index, + span, + }, + ExpressionKind::Prefix(prefix) => { + if matches!( + prefix.operator, + crate::ast::UnaryOp::Dereference { implicitly_added: false } + ) { + LValue::Dereference(Box::new(LValue::from_expression(prefix.rhs)), span) + } else { + panic!("Called LValue::from_expression with an invalid prefix operator") + } + } + ExpressionKind::Interned(id) => LValue::Interned(id, span), + _ => { + panic!("Called LValue::from_expression with an invalid expression") + } + } + } + pub fn span(&self) -> Span { match self { LValue::Ident(ident) => ident.span(), LValue::MemberAccess { span, .. } | LValue::Index { span, .. } | LValue::Dereference(_, span) => *span, + LValue::Interned(_, span) => *span, } } } @@ -777,6 +821,7 @@ impl Display for StatementKind { StatementKind::Continue => write!(f, "continue"), StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind), StatementKind::Semi(semi) => write!(f, "{semi};"), + StatementKind::Interned(_) => write!(f, "(resolved);"), StatementKind::Error => write!(f, "Error"), } } @@ -809,6 +854,7 @@ impl Display for LValue { } LValue::Index { array, index, span: _ } => write!(f, "{array}[{index}]"), LValue::Dereference(lvalue, _span) => write!(f, "*{lvalue}"), + LValue::Interned(_, _) => write!(f, "?Interned"), } } } diff --git a/compiler/noirc_frontend/src/debug/mod.rs b/compiler/noirc_frontend/src/debug/mod.rs index 935acc4e6d0..fe027969473 100644 --- a/compiler/noirc_frontend/src/debug/mod.rs +++ b/compiler/noirc_frontend/src/debug/mod.rs @@ -322,6 +322,9 @@ impl DebugInstrumenter { ast::LValue::Dereference(_ref, _span) => { unimplemented![] } + ast::LValue::Interned(..) => { + unimplemented![] + } } } build_assign_member_stmt( diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs index 01b4585640f..12099b556b7 100644 --- a/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -46,6 +46,7 @@ impl<'context> Elaborator<'context> { self.crate_id, self.debug_comptime_in_file, self.enable_arithmetic_generics, + self.interpreter_call_stack.clone(), ); elaborator.function_context.push(FunctionContext::default()); diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index cf0b4f4071a..beede7a3a30 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -62,6 +62,11 @@ impl<'context> Elaborator<'context> { self.elaborate_unsafe_block(block_expression) } ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)), + ExpressionKind::Interned(id) => { + let expr_kind = self.interner.get_expression_kind(id); + let expr = Expression::new(expr_kind.clone(), expr.span); + return self.elaborate_expression(expr); + } ExpressionKind::Error => (HirExpression::Error, Type::Error), ExpressionKind::Unquote(_) => { self.push_err(ResolverError::UnquoteUsedOutsideQuote { span: expr.span }); diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 53b46536078..5bbd5f00ca8 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -11,7 +11,7 @@ use crate::{ UnresolvedTypeAlias, }, def_map::DefMaps, - resolution::{errors::ResolverError, path_resolver::PathResolver}, + resolution::errors::ResolverError, scope::ScopeForest as GenericScopeForest, type_check::{generics::TraitGenerics, TypeCheckError}, }, @@ -36,7 +36,7 @@ use crate::{ hir::{ def_collector::{dc_crate::CollectedItems, errors::DefCollectorErrorKind}, def_map::{LocalModuleId, ModuleDefId, ModuleId, MAIN_FUNCTION}, - resolution::{import::PathResolution, path_resolver::StandardPathResolver}, + resolution::import::PathResolution, Context, }, hir_def::function::{FuncMeta, HirFunction}, @@ -168,6 +168,8 @@ pub struct Elaborator<'context> { /// Temporary flag to enable the experimental arithmetic generics feature enable_arithmetic_generics: bool, + + pub(crate) interpreter_call_stack: im::Vector, } #[derive(Default)] @@ -191,6 +193,7 @@ impl<'context> Elaborator<'context> { crate_id: CrateId, debug_comptime_in_file: Option, enable_arithmetic_generics: bool, + interpreter_call_stack: im::Vector, ) -> Self { Self { scopes: ScopeForest::default(), @@ -214,6 +217,7 @@ impl<'context> Elaborator<'context> { unresolved_globals: BTreeMap::new(), enable_arithmetic_generics, current_trait: None, + interpreter_call_stack, } } @@ -229,6 +233,7 @@ impl<'context> Elaborator<'context> { crate_id, debug_comptime_in_file, enable_arithmetic_generics, + im::Vector::new(), ) } @@ -630,10 +635,8 @@ impl<'context> Elaborator<'context> { } } - pub fn resolve_module_by_path(&self, path: Path) -> Option { - let path_resolver = StandardPathResolver::new(self.module_id()); - - match path_resolver.resolve(self.def_maps, path.clone(), &mut None) { + pub fn resolve_module_by_path(&mut self, path: Path) -> Option { + match self.resolve_path(path.clone()) { Ok(PathResolution { module_def_id: ModuleDefId::ModuleId(module_id), error }) => { if error.is_some() { None @@ -646,9 +649,7 @@ impl<'context> Elaborator<'context> { } fn resolve_trait_by_path(&mut self, path: Path) -> Option { - let path_resolver = StandardPathResolver::new(self.module_id()); - - let error = match path_resolver.resolve(self.def_maps, path.clone(), &mut None) { + let error = match self.resolve_path(path.clone()) { Ok(PathResolution { module_def_id: ModuleDefId::TraitId(trait_id), error }) => { if let Some(error) = error { self.push_err(error); diff --git a/compiler/noirc_frontend/src/elaborator/scope.rs b/compiler/noirc_frontend/src/elaborator/scope.rs index 3288d10b62e..a51fd737f74 100644 --- a/compiler/noirc_frontend/src/elaborator/scope.rs +++ b/compiler/noirc_frontend/src/elaborator/scope.rs @@ -2,6 +2,7 @@ use noirc_errors::{Location, Spanned}; use crate::ast::{PathKind, ERROR_IDENT}; use crate::hir::def_map::{LocalModuleId, ModuleId}; +use crate::hir::resolution::import::{PathResolution, PathResolutionResult}; use crate::hir::resolution::path_resolver::{PathResolver, StandardPathResolver}; use crate::hir::scope::{Scope as GenericScope, ScopeTree as GenericScopeTree}; use crate::macros_api::Ident; @@ -29,7 +30,7 @@ type ScopeTree = GenericScopeTree; impl<'context> Elaborator<'context> { pub(super) fn lookup(&mut self, path: Path) -> Result { let span = path.span(); - let id = self.resolve_path(path)?; + let id = self.resolve_path_or_error(path)?; T::try_from(id).ok_or_else(|| ResolverError::Expected { expected: T::description(), got: id.as_str().to_owned(), @@ -42,15 +43,37 @@ impl<'context> Elaborator<'context> { ModuleId { krate: self.crate_id, local_id: self.local_module } } - pub(super) fn resolve_path(&mut self, path: Path) -> Result { + pub(super) fn resolve_path_or_error( + &mut self, + path: Path, + ) -> Result { + let path_resolution = self.resolve_path(path)?; + + if let Some(error) = path_resolution.error { + self.push_err(error); + } + + Ok(path_resolution.module_def_id) + } + + pub(super) fn resolve_path(&mut self, path: Path) -> PathResolutionResult { let mut module_id = self.module_id(); let mut path = path; + if path.kind == PathKind::Plain { + let def_map = self.def_maps.get_mut(&self.crate_id).unwrap(); + let module_data = &mut def_map.modules[module_id.local_id.0]; + module_data.use_import(&path.segments[0].ident); + } + if path.kind == PathKind::Plain && path.first_name() == SELF_TYPE_NAME { if let Some(Type::Struct(struct_type, _)) = &self.self_type { let struct_type = struct_type.borrow(); if path.segments.len() == 1 { - return Ok(ModuleDefId::TypeId(struct_type.id)); + return Ok(PathResolution { + module_def_id: ModuleDefId::TypeId(struct_type.id), + error: None, + }); } module_id = struct_type.id.module_id(); @@ -65,11 +88,7 @@ impl<'context> Elaborator<'context> { self.resolve_path_in_module(path, module_id) } - fn resolve_path_in_module( - &mut self, - path: Path, - module_id: ModuleId, - ) -> Result { + fn resolve_path_in_module(&mut self, path: Path, module_id: ModuleId) -> PathResolutionResult { let resolver = StandardPathResolver::new(module_id); let path_resolution; @@ -99,11 +118,7 @@ impl<'context> Elaborator<'context> { path_resolution = resolver.resolve(self.def_maps, path, &mut None)?; } - if let Some(error) = path_resolution.error { - self.push_err(error); - } - - Ok(path_resolution.module_def_id) + Ok(path_resolution) } pub(super) fn get_struct(&self, type_id: StructId) -> Shared { @@ -150,7 +165,7 @@ impl<'context> Elaborator<'context> { pub(super) fn lookup_global(&mut self, path: Path) -> Result { let span = path.span(); - let id = self.resolve_path(path)?; + let id = self.resolve_path_or_error(path)?; if let Some(function) = TryFromModuleDefId::try_from(id) { return Ok(self.interner.function_definition_id(function)); diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index 0bb8641b6b3..d7d330f891a 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -39,11 +39,16 @@ impl<'context> Elaborator<'context> { let (expr, _typ) = self.elaborate_expression(expr); (HirStatement::Semi(expr), Type::Unit) } + StatementKind::Interned(id) => { + let kind = self.interner.get_statement_kind(id); + let statement = Statement { kind: kind.clone(), span: statement.span }; + self.elaborate_statement_value(statement) + } StatementKind::Error => (HirStatement::Error, Type::Error), } } - pub(super) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { + pub(crate) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { let span = statement.span; let (hir_statement, typ) = self.elaborate_statement_value(statement); let id = self.interner.push_stmt(hir_statement); @@ -357,6 +362,10 @@ impl<'context> Elaborator<'context> { let lvalue = HirLValue::Dereference { lvalue, element_type, location }; (lvalue, typ, true) } + LValue::Interned(id, span) => { + let lvalue = self.interner.get_lvalue(id, span).clone(); + self.elaborate_lvalue(lvalue, assign_span) + } } } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 44bded6b92f..e41234a5be5 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -158,6 +158,10 @@ impl<'context> Elaborator<'context> { Parenthesized(typ) => self.resolve_type_inner(*typ, kind), Resolved(id) => self.interner.get_quoted_type(id).clone(), AsTraitPath(path) => self.resolve_as_trait_path(*path), + Interned(id) => { + let typ = self.interner.get_unresolved_type_data(id).clone(); + return self.resolve_type_inner(UnresolvedType { typ, span }, kind); + } }; let location = Location::new(named_path_span.unwrap_or(typ.span), self.file); @@ -421,7 +425,7 @@ impl<'context> Elaborator<'context> { } // If we cannot find a local generic of the same name, try to look up a global - match self.resolve_path(path.clone()) { + match self.resolve_path_or_error(path.clone()) { Ok(ModuleDefId::GlobalId(id)) => { if let Some(current_item) = self.current_item { self.interner.add_global_dependency(current_item, id); @@ -445,14 +449,19 @@ impl<'context> Elaborator<'context> { }) } UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), - UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); let lhs = self.convert_expression_type(*lhs); let rhs = self.convert_expression_type(*rhs); match (lhs, rhs) { (Type::Constant(lhs), Type::Constant(rhs)) => { - Type::Constant(op.function(lhs, rhs)) + if let Some(result) = op.function(lhs, rhs) { + Type::Constant(result) + } else { + self.push_err(ResolverError::OverflowInType { lhs, op, rhs, span }); + Type::Error + } } (lhs, rhs) => { if !self.enable_arithmetic_generics { diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index fd916485eaf..cfee6bcedac 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -56,6 +56,7 @@ pub enum InterpreterError { FailingConstraint { message: Option, location: Location, + call_stack: im::Vector, }, NoMethodFound { name: String, @@ -353,12 +354,20 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let msg = format!("Expected a `bool` but found `{typ}`"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::FailingConstraint { message, location } => { + InterpreterError::FailingConstraint { message, location, call_stack } => { let (primary, secondary) = match message { Some(msg) => (msg.clone(), "Assertion failed".into()), None => ("Assertion failed".into(), String::new()), }; - CustomDiagnostic::simple_error(primary, secondary, location.span) + let mut diagnostic = + CustomDiagnostic::simple_error(primary, secondary, location.span); + + // Only take at most 3 frames starting from the top of the stack to avoid producing too much output + for frame in call_stack.iter().rev().take(3) { + diagnostic.add_secondary_with_file("".to_string(), frame.span, frame.file); + } + + diagnostic } InterpreterError::NoMethodFound { name, typ, location } => { let msg = format!("No method named `{name}` found for type `{typ}`"); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 33f8c9d8332..4980045c68d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -70,7 +70,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { current_function: Option, ) -> Self { let bound_generics = Vec::new(); - Self { elaborator, crate_id, current_function, bound_generics, in_loop: false } + let in_loop = false; + Self { elaborator, crate_id, current_function, bound_generics, in_loop } } pub(crate) fn call_function( @@ -99,8 +100,11 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } self.remember_bindings(&instantiation_bindings, &impl_bindings); + self.elaborator.interpreter_call_stack.push_back(location); + let result = self.call_function_inner(function, arguments, location); + self.elaborator.interpreter_call_stack.pop_back(); undo_instantiation_bindings(impl_bindings); undo_instantiation_bindings(instantiation_bindings); self.rebind_generics_from_previous_function(); @@ -1462,7 +1466,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let message = constrain.2.and_then(|expr| self.evaluate(expr).ok()); let message = message.map(|value| value.display(self.elaborator.interner).to_string()); - Err(InterpreterError::FailingConstraint { location, message }) + let call_stack = self.elaborator.interpreter_call_stack.clone(); + Err(InterpreterError::FailingConstraint { location, message, call_stack }) } value => { let location = self.elaborator.interner.expr_location(&constrain.0); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index aeac4a87de0..0dafe408c7f 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -6,32 +6,35 @@ use std::{ use acvm::{AcirField, FieldElement}; use builtin_helpers::{ block_expression_to_value, check_argument_count, check_function_not_yet_resolved, - check_one_argument, check_three_arguments, check_two_arguments, get_expr, get_function_def, - get_module, get_quoted, get_slice, get_struct, get_trait_constraint, get_trait_def, - get_trait_impl, get_tuple, get_type, get_u32, get_unresolved_type, hir_pattern_to_tokens, - mutate_func_meta_type, parse, parse_tokens, replace_func_meta_parameters, - replace_func_meta_return_type, + check_one_argument, check_three_arguments, check_two_arguments, get_expr, get_field, + get_function_def, get_module, get_quoted, get_slice, get_struct, get_trait_constraint, + get_trait_def, get_trait_impl, get_tuple, get_type, get_typed_expr, get_u32, + get_unresolved_type, hir_pattern_to_tokens, mutate_func_meta_type, parse, + replace_func_meta_parameters, replace_func_meta_return_type, }; +use chumsky::{prelude::choice, Parser}; use im::Vector; use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; +use num_bigint::BigUint; use rustc_hash::FxHashMap as HashMap; use crate::{ ast::{ - ArrayLiteral, Expression, ExpressionKind, FunctionKind, FunctionReturnType, IntegerBitSize, - Literal, UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, + ArrayLiteral, BlockExpression, ConstrainKind, Expression, ExpressionKind, FunctionKind, + FunctionReturnType, IntegerBitSize, LValue, Literal, Statement, StatementKind, UnaryOp, + UnresolvedType, UnresolvedTypeData, Visibility, }, hir::comptime::{ errors::IResult, - value::{add_token_spans, ExprValue}, + value::{ExprValue, TypedExpr}, InterpreterError, Value, }, hir_def::function::FunctionBody, - macros_api::{ModuleDefId, NodeInterner, Signedness, StatementKind}, + macros_api::{HirExpression, HirLiteral, ModuleDefId, NodeInterner, Signedness}, node_interner::{DefinitionKind, TraitImplKind}, parser::{self}, - token::{SpannedToken, Token}, + token::Token, QuotedType, Shared, Type, }; @@ -49,37 +52,48 @@ impl<'local, 'context> Interpreter<'local, 'context> { location: Location, ) -> IResult { let interner = &mut self.elaborator.interner; + let call_stack = &self.elaborator.interpreter_call_stack; match name { "array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location), "array_len" => array_len(interner, arguments, location), + "assert_constant" => Ok(Value::Bool(true)), "as_slice" => as_slice(interner, arguments, location), - "expr_as_array" => expr_as_array(arguments, return_type, location), - "expr_as_assign" => expr_as_assign(arguments, return_type, location), - "expr_as_binary_op" => expr_as_binary_op(arguments, return_type, location), - "expr_as_block" => expr_as_block(arguments, return_type, location), - "expr_as_bool" => expr_as_bool(arguments, return_type, location), - "expr_as_cast" => expr_as_cast(arguments, return_type, location), - "expr_as_comptime" => expr_as_comptime(arguments, return_type, location), - "expr_as_function_call" => expr_as_function_call(arguments, return_type, location), - "expr_as_if" => expr_as_if(arguments, return_type, location), - "expr_as_index" => expr_as_index(arguments, return_type, location), - "expr_as_integer" => expr_as_integer(arguments, return_type, location), - "expr_as_member_access" => expr_as_member_access(arguments, return_type, location), - "expr_as_method_call" => expr_as_method_call(arguments, return_type, location), + "expr_as_array" => expr_as_array(interner, arguments, return_type, location), + "expr_as_assert" => expr_as_assert(interner, arguments, return_type, location), + "expr_as_assign" => expr_as_assign(interner, arguments, return_type, location), + "expr_as_binary_op" => expr_as_binary_op(interner, arguments, return_type, location), + "expr_as_block" => expr_as_block(interner, arguments, return_type, location), + "expr_as_bool" => expr_as_bool(interner, arguments, return_type, location), + "expr_as_cast" => expr_as_cast(interner, arguments, return_type, location), + "expr_as_comptime" => expr_as_comptime(interner, arguments, return_type, location), + "expr_as_function_call" => { + expr_as_function_call(interner, arguments, return_type, location) + } + "expr_as_if" => expr_as_if(interner, arguments, return_type, location), + "expr_as_index" => expr_as_index(interner, arguments, return_type, location), + "expr_as_integer" => expr_as_integer(interner, arguments, return_type, location), + "expr_as_member_access" => { + expr_as_member_access(interner, arguments, return_type, location) + } + "expr_as_method_call" => { + expr_as_method_call(interner, arguments, return_type, location) + } "expr_as_repeated_element_array" => { - expr_as_repeated_element_array(arguments, return_type, location) + expr_as_repeated_element_array(interner, arguments, return_type, location) } "expr_as_repeated_element_slice" => { - expr_as_repeated_element_slice(arguments, return_type, location) + expr_as_repeated_element_slice(interner, arguments, return_type, location) } - "expr_as_slice" => expr_as_slice(arguments, return_type, location), - "expr_as_tuple" => expr_as_tuple(arguments, return_type, location), - "expr_as_unary_op" => expr_as_unary_op(arguments, return_type, location), - "expr_as_unsafe" => expr_as_unsafe(arguments, return_type, location), - "expr_has_semicolon" => expr_has_semicolon(arguments, location), - "expr_is_break" => expr_is_break(arguments, location), - "expr_is_continue" => expr_is_continue(arguments, location), + "expr_as_slice" => expr_as_slice(interner, arguments, return_type, location), + "expr_as_tuple" => expr_as_tuple(interner, arguments, return_type, location), + "expr_as_unary_op" => expr_as_unary_op(interner, arguments, return_type, location), + "expr_as_unsafe" => expr_as_unsafe(interner, arguments, return_type, location), + "expr_has_semicolon" => expr_has_semicolon(interner, arguments, location), + "expr_is_break" => expr_is_break(interner, arguments, location), + "expr_is_continue" => expr_is_continue(interner, arguments, location), + "expr_resolve" => expr_resolve(self, arguments, location), "is_unconstrained" => Ok(Value::Bool(true)), + "function_def_body" => function_def_body(interner, arguments, location), "function_def_name" => function_def_name(interner, arguments, location), "function_def_parameters" => function_def_parameters(interner, arguments, location), "function_def_return_type" => function_def_return_type(interner, arguments, location), @@ -102,14 +116,15 @@ impl<'local, 'context> Interpreter<'local, 'context> { "quoted_as_type" => quoted_as_type(self, arguments, location), "quoted_eq" => quoted_eq(arguments, location), "slice_insert" => slice_insert(interner, arguments, location), - "slice_pop_back" => slice_pop_back(interner, arguments, location), - "slice_pop_front" => slice_pop_front(interner, arguments, location), + "slice_pop_back" => slice_pop_back(interner, arguments, location, call_stack), + "slice_pop_front" => slice_pop_front(interner, arguments, location, call_stack), "slice_push_back" => slice_push_back(interner, arguments, location), "slice_push_front" => slice_push_front(interner, arguments, location), - "slice_remove" => slice_remove(interner, arguments, location), + "slice_remove" => slice_remove(interner, arguments, location, call_stack), "struct_def_as_type" => struct_def_as_type(interner, arguments, location), "struct_def_fields" => struct_def_fields(interner, arguments, location), "struct_def_generics" => struct_def_generics(interner, arguments, location), + "to_le_radix" => to_le_radix(arguments, location), "trait_constraint_eq" => trait_constraint_eq(interner, arguments, location), "trait_constraint_hash" => trait_constraint_hash(interner, arguments, location), "trait_def_as_trait_constraint" => { @@ -135,7 +150,10 @@ impl<'local, 'context> Interpreter<'local, 'context> { "type_is_bool" => type_is_bool(arguments, location), "type_is_field" => type_is_field(arguments, location), "type_of" => type_of(arguments, location), - "unresolved_type_is_field" => unresolved_type_is_field(arguments, location), + "typed_expr_as_function_definition" => { + typed_expr_as_function_definition(interner, arguments, return_type, location) + } + "unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location), "zeroed" => zeroed(return_type), _ => { let item = format!("Comptime evaluation for builtin function {name}"); @@ -145,8 +163,16 @@ impl<'local, 'context> Interpreter<'local, 'context> { } } -fn failing_constraint(message: impl Into, location: Location) -> IResult { - Err(InterpreterError::FailingConstraint { message: Some(message.into()), location }) +fn failing_constraint( + message: impl Into, + location: Location, + call_stack: &im::Vector, +) -> IResult { + Err(InterpreterError::FailingConstraint { + message: Some(message.into()), + location, + call_stack: call_stack.clone(), + }) } fn array_len( @@ -278,6 +304,7 @@ fn slice_remove( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, + call_stack: &im::Vector, ) -> IResult { let (slice, index) = check_two_arguments(arguments, location)?; @@ -285,7 +312,7 @@ fn slice_remove( let index = get_u32(index)? as usize; if values.is_empty() { - return failing_constraint("slice_remove called on empty slice", location); + return failing_constraint("slice_remove called on empty slice", location, call_stack); } if index >= values.len() { @@ -293,7 +320,7 @@ fn slice_remove( "slice_remove: index {index} is out of bounds for a slice of length {}", values.len() ); - return failing_constraint(message, location); + return failing_constraint(message, location, call_stack); } let element = values.remove(index); @@ -316,13 +343,14 @@ fn slice_pop_front( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, + call_stack: &im::Vector, ) -> IResult { let argument = check_one_argument(arguments, location)?; let (mut values, typ) = get_slice(interner, argument)?; match values.pop_front() { Some(element) => Ok(Value::Tuple(vec![element, Value::Slice(values, typ)])), - None => failing_constraint("slice_pop_front called on empty slice", location), + None => failing_constraint("slice_pop_front called on empty slice", location, call_stack), } } @@ -330,13 +358,14 @@ fn slice_pop_back( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, + call_stack: &im::Vector, ) -> IResult { let argument = check_one_argument(arguments, location)?; let (mut values, typ) = get_slice(interner, argument)?; match values.pop_back() { Some(element) => Ok(Value::Tuple(vec![Value::Slice(values, typ), element])), - None => failing_constraint("slice_pop_back called on empty slice", location), + None => failing_constraint("slice_pop_back called on empty slice", location, call_stack), } } @@ -361,10 +390,14 @@ fn quoted_as_expr( ) -> IResult { let argument = check_one_argument(arguments, location)?; - let expr = parse(argument, parser::expression(), "an expression").ok(); - let value = expr.map(|expr| Value::expression(expr.kind)); + let expr_parser = parser::expression().map(|expr| Value::expression(expr.kind)); + let statement_parser = parser::fresh_statement().map(Value::statement); + let lvalue_parser = parser::lvalue(parser::expression()).map(Value::lvalue); + let parser = choice((expr_parser, statement_parser, lvalue_parser)); - option(return_type, value) + let expr = parse(argument, parser, "an expression").ok(); + + option(return_type, expr) } // fn as_module(quoted: Quoted) -> Option @@ -417,6 +450,35 @@ fn quoted_as_type( Ok(Value::Type(typ)) } +fn to_le_radix(arguments: Vec<(Value, Location)>, location: Location) -> IResult { + let (value, radix, limb_count) = check_three_arguments(arguments, location)?; + + let value = get_field(value)?; + let radix = get_u32(radix)?; + let limb_count = get_u32(limb_count)?; + + // Decompose the integer into its radix digits in little endian form. + let decomposed_integer = compute_to_radix(value, radix); + let decomposed_integer = vecmap(0..limb_count as usize, |i| match decomposed_integer.get(i) { + Some(digit) => Value::U8(*digit), + None => Value::U8(0), + }); + Ok(Value::Array( + decomposed_integer.into(), + Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + )) +} + +fn compute_to_radix(field: FieldElement, radix: u32) -> Vec { + let bit_size = u32::BITS - (radix - 1).leading_zeros(); + let radix_big = BigUint::from(radix); + assert_eq!(BigUint::from(2u128).pow(bit_size), radix_big, "ICE: Radix must be a power of 2"); + let big_integer = BigUint::from_bytes_be(&field.to_be_bytes()); + + // Decompose the integer into its radix digits in little endian form. + big_integer.to_radix_le(radix) +} + // fn as_array(self) -> Option<(Type, Type)> fn type_as_array( arguments: Vec<(Value, Location)>, @@ -709,13 +771,31 @@ fn trait_impl_trait_generic_args( Ok(Value::Slice(trait_generics, slice_type)) } +fn typed_expr_as_function_definition( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let typed_expr = get_typed_expr(self_argument)?; + let option_value = if let TypedExpr::ExprId(expr_id) = typed_expr { + let func_id = interner.lookup_function_from_expr(&expr_id); + func_id.map(Value::FunctionDefinition) + } else { + None + }; + option(return_type, option_value) +} + // fn is_field(self) -> bool fn unresolved_type_is_field( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let typ = get_unresolved_type(self_argument)?; + let typ = get_unresolved_type(interner, self_argument)?; Ok(Value::Bool(matches!(typ, UnresolvedTypeData::FieldElement))) } @@ -802,11 +882,12 @@ fn zeroed(return_type: Type) -> IResult { // fn as_array(self) -> Option<[Expr]> fn expr_as_array( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Array( ArrayLiteral::Standard(exprs), ))) = expr @@ -820,13 +901,46 @@ fn expr_as_array( }) } +// fn as_assert(self) -> Option<(Expr, Option)> +fn expr_as_assert( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(interner, arguments, return_type.clone(), location, |expr| { + if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr { + if constrain.2 == ConstrainKind::Assert { + let predicate = Value::expression(constrain.0.kind); + + let option_type = extract_option_generic_type(return_type); + let Type::Tuple(mut tuple_types) = option_type else { + panic!("Expected the return type option generic arg to be a tuple"); + }; + assert_eq!(tuple_types.len(), 2); + + let option_type = tuple_types.pop().unwrap(); + let message = constrain.1.map(|message| Value::expression(message.kind)); + let message = option(option_type, message).ok()?; + + Some(Value::Tuple(vec![predicate, message])) + } else { + None + } + } else { + None + } + }) +} + // fn as_assign(self) -> Option<(Expr, Expr)> fn expr_as_assign( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Statement(StatementKind::Assign(assign)) = expr { let lhs = Value::lvalue(assign.lvalue); let rhs = Value::expression(assign.expression.kind); @@ -839,11 +953,12 @@ fn expr_as_assign( // fn as_binary_op(self) -> Option<(Expr, BinaryOp, Expr)> fn expr_as_binary_op( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { + expr_as(interner, arguments, return_type.clone(), location, |expr| { if let ExprValue::Expression(ExpressionKind::Infix(infix_expr)) = expr { let option_type = extract_option_generic_type(return_type); let Type::Tuple(mut tuple_types) = option_type else { @@ -872,11 +987,12 @@ fn expr_as_binary_op( // fn as_block(self) -> Option<[Expr]> fn expr_as_block( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Block(block_expr)) = expr { Some(block_expression_to_value(block_expr)) } else { @@ -887,11 +1003,12 @@ fn expr_as_block( // fn as_bool(self) -> Option fn expr_as_bool( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Bool(bool))) = expr { Some(Value::Bool(bool)) } else { @@ -902,11 +1019,12 @@ fn expr_as_bool( // fn as_cast(self) -> Option<(Expr, UnresolvedType)> fn expr_as_cast( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Cast(cast)) = expr { let lhs = Value::expression(cast.lhs.kind); let typ = Value::UnresolvedType(cast.r#type.typ); @@ -919,13 +1037,14 @@ fn expr_as_cast( // fn as_comptime(self) -> Option<[Expr]> fn expr_as_comptime( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { use ExpressionKind::Block; - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Comptime(block_expr, _)) = expr { Some(block_expression_to_value(block_expr)) } else if let ExprValue::Statement(StatementKind::Comptime(statement)) = expr { @@ -951,11 +1070,12 @@ fn expr_as_comptime( // fn as_function_call(self) -> Option<(Expr, [Expr])> fn expr_as_function_call( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Call(call_expression)) = expr { let function = Value::expression(call_expression.func.kind); let arguments = call_expression.arguments.into_iter(); @@ -971,11 +1091,12 @@ fn expr_as_function_call( // fn as_if(self) -> Option<(Expr, Expr, Option)> fn expr_as_if( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { + expr_as(interner, arguments, return_type.clone(), location, |expr| { if let ExprValue::Expression(ExpressionKind::If(if_expr)) = expr { // Get the type of `Option` let option_type = extract_option_generic_type(return_type.clone()); @@ -1003,11 +1124,12 @@ fn expr_as_if( // fn as_index(self) -> Option fn expr_as_index( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Index(index_expr)) = expr { Some(Value::Tuple(vec![ Value::expression(index_expr.collection.kind), @@ -1021,27 +1143,36 @@ fn expr_as_index( // fn as_integer(self) -> Option<(Field, bool)> fn expr_as_integer( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { - if let ExprValue::Expression(ExpressionKind::Literal(Literal::Integer(field, sign))) = expr - { + expr_as(interner, arguments, return_type.clone(), location, |expr| match expr { + ExprValue::Expression(ExpressionKind::Literal(Literal::Integer(field, sign))) => { Some(Value::Tuple(vec![Value::Field(field), Value::Bool(sign)])) - } else { - None } + ExprValue::Expression(ExpressionKind::Resolved(id)) => { + if let HirExpression::Literal(HirLiteral::Integer(field, sign)) = + interner.expression(&id) + { + Some(Value::Tuple(vec![Value::Field(field), Value::Bool(sign)])) + } else { + None + } + } + _ => None, }) } // fn as_member_access(self) -> Option<(Expr, Quoted)> fn expr_as_member_access( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| match expr { + expr_as(interner, arguments, return_type, location, |expr| match expr { ExprValue::Expression(ExpressionKind::MemberAccess(member_access)) => { let tokens = Rc::new(vec![Token::Ident(member_access.rhs.0.contents.clone())]); Some(Value::Tuple(vec![ @@ -1059,11 +1190,12 @@ fn expr_as_member_access( // fn as_method_call(self) -> Option<(Expr, Quoted, [UnresolvedType], [Expr])> fn expr_as_method_call( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::MethodCall(method_call)) = expr { let object = Value::expression(method_call.object.kind); @@ -1092,11 +1224,12 @@ fn expr_as_method_call( // fn as_repeated_element_array(self) -> Option<(Expr, Expr)> fn expr_as_repeated_element_array( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Array( ArrayLiteral::Repeated { repeated_element, length }, ))) = expr @@ -1113,11 +1246,12 @@ fn expr_as_repeated_element_array( // fn as_repeated_element_slice(self) -> Option<(Expr, Expr)> fn expr_as_repeated_element_slice( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Slice( ArrayLiteral::Repeated { repeated_element, length }, ))) = expr @@ -1134,11 +1268,12 @@ fn expr_as_repeated_element_slice( // fn as_slice(self) -> Option<[Expr]> fn expr_as_slice( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Slice( ArrayLiteral::Standard(exprs), ))) = expr @@ -1154,11 +1289,12 @@ fn expr_as_slice( // fn as_tuple(self) -> Option<[Expr]> fn expr_as_tuple( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Tuple(expressions)) = expr { let expressions = expressions.into_iter().map(|expr| Value::expression(expr.kind)).collect(); @@ -1172,11 +1308,12 @@ fn expr_as_tuple( // fn as_unary_op(self) -> Option<(UnaryOp, Expr)> fn expr_as_unary_op( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { + expr_as(interner, arguments, return_type.clone(), location, |expr| { if let ExprValue::Expression(ExpressionKind::Prefix(prefix_expr)) = expr { let option_type = extract_option_generic_type(return_type); let Type::Tuple(mut tuple_types) = option_type else { @@ -1209,11 +1346,12 @@ fn expr_as_unary_op( // fn as_unsafe(self) -> Option<[Expr]> fn expr_as_unsafe( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Unsafe(block_expr, _)) = expr { Some(block_expression_to_value(block_expr)) } else { @@ -1223,28 +1361,41 @@ fn expr_as_unsafe( } // fn as_has_semicolon(self) -> bool -fn expr_has_semicolon(arguments: Vec<(Value, Location)>, location: Location) -> IResult { +fn expr_has_semicolon( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let expr_value = get_expr(self_argument)?; + let expr_value = get_expr(interner, self_argument)?; Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Semi(..))))) } // fn is_break(self) -> bool -fn expr_is_break(arguments: Vec<(Value, Location)>, location: Location) -> IResult { +fn expr_is_break( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let expr_value = get_expr(self_argument)?; + let expr_value = get_expr(interner, self_argument)?; Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Break)))) } // fn is_continue(self) -> bool -fn expr_is_continue(arguments: Vec<(Value, Location)>, location: Location) -> IResult { +fn expr_is_continue( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let expr_value = get_expr(self_argument)?; + let expr_value = get_expr(interner, self_argument)?; Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Continue)))) } // Helper function for implementing the `expr_as_...` functions. fn expr_as( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, @@ -1254,15 +1405,88 @@ where F: FnOnce(ExprValue) -> Option, { let self_argument = check_one_argument(arguments, location)?; - let mut expression_kind = get_expr(self_argument)?; - while let ExprValue::Expression(ExpressionKind::Parenthesized(expression)) = expression_kind { - expression_kind = ExprValue::Expression(expression.kind); - } + let expr_value = get_expr(interner, self_argument)?; + let expr_value = unwrap_expr_value(interner, expr_value); - let option_value = f(expression_kind); + let option_value = f(expr_value); option(return_type, option_value) } +// fn resolve(self) -> TypedExpr +fn expr_resolve( + interpreter: &mut Interpreter, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let self_argument_location = self_argument.1; + let expr_value = get_expr(interpreter.elaborator.interner, self_argument)?; + let expr_value = unwrap_expr_value(interpreter.elaborator.interner, expr_value); + + let value = + interpreter.elaborate_item(interpreter.current_function, |elaborator| match expr_value { + ExprValue::Expression(expression_kind) => { + let expr = Expression { kind: expression_kind, span: self_argument_location.span }; + let (expr_id, _) = elaborator.elaborate_expression(expr); + Value::TypedExpr(TypedExpr::ExprId(expr_id)) + } + ExprValue::Statement(statement_kind) => { + let statement = + Statement { kind: statement_kind, span: self_argument_location.span }; + let (stmt_id, _) = elaborator.elaborate_statement(statement); + Value::TypedExpr(TypedExpr::StmtId(stmt_id)) + } + ExprValue::LValue(lvalue) => { + let expr = lvalue.as_expression(); + let (expr_id, _) = elaborator.elaborate_expression(expr); + Value::TypedExpr(TypedExpr::ExprId(expr_id)) + } + }); + + Ok(value) +} + +fn unwrap_expr_value(interner: &NodeInterner, mut expr_value: ExprValue) -> ExprValue { + loop { + match expr_value { + ExprValue::Expression(ExpressionKind::Parenthesized(expression)) => { + expr_value = ExprValue::Expression(expression.kind); + } + ExprValue::Statement(StatementKind::Expression(expression)) + | ExprValue::Statement(StatementKind::Semi(expression)) => { + expr_value = ExprValue::Expression(expression.kind); + } + ExprValue::Expression(ExpressionKind::Interned(id)) => { + expr_value = ExprValue::Expression(interner.get_expression_kind(id).clone()); + } + ExprValue::Statement(StatementKind::Interned(id)) => { + expr_value = ExprValue::Statement(interner.get_statement_kind(id).clone()); + } + ExprValue::LValue(LValue::Interned(id, span)) => { + expr_value = ExprValue::LValue(interner.get_lvalue(id, span).clone()); + } + _ => break, + } + } + expr_value +} + +// fn body(self) -> Expr +fn function_def_body( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let func_id = get_function_def(self_argument)?; + let func_meta = interner.function_meta(&func_id); + if let FunctionBody::Unresolved(_, block_expr, _) = &func_meta.function_body { + Ok(Value::expression(ExpressionKind::Block(block_expr.clone()))) + } else { + Err(InterpreterError::FunctionAlreadyResolved { location }) + } +} + // fn name(self) -> Quoted fn function_def_name( interner: &NodeInterner, @@ -1317,32 +1541,30 @@ fn function_def_return_type( Ok(Value::Type(func_meta.return_type().follow_bindings())) } -// fn set_body(self, body: Quoted) +// fn set_body(self, body: Expr) fn function_def_set_body( interpreter: &mut Interpreter, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { let (self_argument, body_argument) = check_two_arguments(arguments, location)?; - let body_argument_location = body_argument.1; + let body_location = body_argument.1; let func_id = get_function_def(self_argument)?; check_function_not_yet_resolved(interpreter, func_id, location)?; - let body_tokens = get_quoted(body_argument)?; - let mut body_quoted = add_token_spans(body_tokens.clone(), body_argument_location.span); - - // Surround the body in `{ ... }` so we can parse it as a block - body_quoted.0.insert(0, SpannedToken::new(Token::LeftBrace, location.span)); - body_quoted.0.push(SpannedToken::new(Token::RightBrace, location.span)); + let body_argument = get_expr(interpreter.elaborator.interner, body_argument)?; + let statement_kind = match body_argument { + ExprValue::Expression(expression_kind) => StatementKind::Expression(Expression { + kind: expression_kind, + span: body_location.span, + }), + ExprValue::Statement(statement_kind) => statement_kind, + ExprValue::LValue(lvalue) => StatementKind::Expression(lvalue.as_expression()), + }; - let body = parse_tokens( - body_tokens, - body_quoted, - body_argument_location, - parser::block(parser::fresh_statement()), - "a block", - )?; + let statement = Statement { kind: statement_kind, span: body_location.span }; + let body = BlockExpression { statements: vec![statement] }; let func_meta = interpreter.elaborator.interner.function_meta_mut(&func_id); func_meta.has_body = true; diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index a409731a5e4..dd9ea51961e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -4,11 +4,14 @@ use acvm::FieldElement; use noirc_errors::Location; use crate::{ - ast::{BlockExpression, IntegerBitSize, Signedness, UnresolvedTypeData}, + ast::{ + BlockExpression, ExpressionKind, IntegerBitSize, LValue, Signedness, StatementKind, + UnresolvedTypeData, + }, hir::{ comptime::{ errors::IResult, - value::{add_token_spans, ExprValue}, + value::{add_token_spans, ExprValue, TypedExpr}, Interpreter, InterpreterError, Value, }, def_map::ModuleId, @@ -142,9 +145,33 @@ pub(crate) fn get_u32((value, location): (Value, Location)) -> IResult { } } -pub(crate) fn get_expr((value, location): (Value, Location)) -> IResult { +pub(crate) fn get_u64((value, location): (Value, Location)) -> IResult { match value { - Value::Expr(expr) => Ok(expr), + Value::U64(value) => Ok(value), + value => { + let expected = Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour); + type_mismatch(value, expected, location) + } + } +} + +pub(crate) fn get_expr( + interner: &NodeInterner, + (value, location): (Value, Location), +) -> IResult { + match value { + Value::Expr(expr) => match expr { + ExprValue::Expression(ExpressionKind::Interned(id)) => { + Ok(ExprValue::Expression(interner.get_expression_kind(id).clone())) + } + ExprValue::Statement(StatementKind::Interned(id)) => { + Ok(ExprValue::Statement(interner.get_statement_kind(id).clone())) + } + ExprValue::LValue(LValue::Interned(id, _)) => { + Ok(ExprValue::LValue(interner.get_lvalue(id, location.span).clone())) + } + _ => Ok(expr), + }, value => type_mismatch(value, Type::Quoted(QuotedType::Expr), location), } } @@ -200,6 +227,13 @@ pub(crate) fn get_type((value, location): (Value, Location)) -> IResult { } } +pub(crate) fn get_typed_expr((value, location): (Value, Location)) -> IResult { + match value { + Value::TypedExpr(typed_expr) => Ok(typed_expr), + value => type_mismatch(value, Type::Quoted(QuotedType::TypedExpr), location), + } +} + pub(crate) fn get_quoted((value, location): (Value, Location)) -> IResult>> { match value { Value::Quoted(tokens) => Ok(tokens), @@ -208,10 +242,18 @@ pub(crate) fn get_quoted((value, location): (Value, Location)) -> IResult IResult { match value { - Value::UnresolvedType(typ) => Ok(typ), + Value::UnresolvedType(typ) => { + if let UnresolvedTypeData::Interned(id) = typ { + let typ = interner.get_unresolved_type_data(id).clone(); + Ok(typ) + } else { + Ok(typ) + } + } value => type_mismatch(value, Type::Quoted(QuotedType::UnresolvedType), location), } } diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs index f7caf84ec42..5ae60bb4d00 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs @@ -1,5 +1,6 @@ -use acvm::BlackBoxFunctionSolver; +use acvm::blackbox_solver::BlackBoxFunctionSolver; use bn254_blackbox_solver::Bn254BlackBoxSolver; +use im::Vector; use iter_extended::try_vecmap; use noirc_errors::Location; @@ -8,7 +9,9 @@ use crate::{ macros_api::NodeInterner, }; -use super::builtin::builtin_helpers::{check_two_arguments, get_array, get_field, get_u32}; +use super::builtin::builtin_helpers::{ + check_one_argument, check_two_arguments, get_array, get_field, get_u32, get_u64, +}; pub(super) fn call_foreign( interner: &mut NodeInterner, @@ -18,6 +21,7 @@ pub(super) fn call_foreign( ) -> IResult { match name { "poseidon2_permutation" => poseidon2_permutation(interner, arguments, location), + "keccakf1600" => keccakf1600(interner, arguments, location), _ => { let item = format!("Comptime evaluation for builtin function {name}"); Err(InterpreterError::Unimplemented { item, location }) @@ -47,3 +51,26 @@ fn poseidon2_permutation( let array = fields.into_iter().map(Value::Field).collect(); Ok(Value::Array(array, typ)) } + +fn keccakf1600( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let input = check_one_argument(arguments, location)?; + let input_location = input.1; + + let (input, typ) = get_array(interner, input)?; + + let input = try_vecmap(input, |integer| get_u64((integer, input_location)))?; + + let mut state = [0u64; 25]; + for (it, input_value) in state.iter_mut().zip(input.iter()) { + *it = *input_value; + } + let result_lanes = acvm::blackbox_solver::keccakf1600(state) + .map_err(|error| InterpreterError::BlackBoxError(error, location))?; + + let array: Vector = result_lanes.into_iter().map(Value::U64).collect(); + Ok(Value::Array(array, typ)) +} diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 18f482585ea..b96c4852931 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -9,8 +9,11 @@ use strum_macros::Display; use crate::{ ast::{ - ArrayLiteral, BlockExpression, ConstructorExpression, Ident, IntegerBitSize, LValue, - Signedness, Statement, StatementKind, UnresolvedTypeData, + ArrayLiteral, AssignStatement, BlockExpression, CallExpression, CastExpression, + ConstrainStatement, ConstructorExpression, ForLoopStatement, ForRange, Ident, IfExpression, + IndexExpression, InfixExpression, IntegerBitSize, LValue, Lambda, LetStatement, + MemberAccessExpression, MethodCallExpression, PrefixExpression, Signedness, Statement, + StatementKind, UnresolvedTypeData, }, hir::{def_map::ModuleId, type_check::generics::TraitGenerics}, hir_def::{ @@ -21,7 +24,7 @@ use crate::{ Expression, ExpressionKind, HirExpression, HirLiteral, Literal, NodeInterner, Path, StructId, }, - node_interner::{ExprId, FuncId, TraitId, TraitImplId}, + node_interner::{ExprId, FuncId, StmtId, TraitId, TraitImplId}, parser::{self, NoirParser, TopLevelStatement}, token::{SpannedToken, Token, Tokens}, QuotedType, Shared, Type, TypeBindings, @@ -66,6 +69,7 @@ pub enum Value { Type(Type), Zeroed(Type), Expr(ExprValue), + TypedExpr(TypedExpr), UnresolvedType(UnresolvedTypeData), } @@ -76,6 +80,12 @@ pub enum ExprValue { LValue(LValue), } +#[derive(Debug, Clone, PartialEq, Eq, Display)] +pub enum TypedExpr { + ExprId(ExprId), + StmtId(StmtId), +} + impl Value { pub(crate) fn expression(expr: ExpressionKind) -> Self { Value::Expr(ExprValue::Expression(expr)) @@ -134,6 +144,7 @@ impl Value { Value::Type(_) => Type::Quoted(QuotedType::Type), Value::Zeroed(typ) => return Cow::Borrowed(typ), Value::Expr(_) => Type::Quoted(QuotedType::Expr), + Value::TypedExpr(_) => Type::Quoted(QuotedType::TypedExpr), Value::UnresolvedType(_) => Type::Quoted(QuotedType::UnresolvedType), }) } @@ -261,7 +272,8 @@ impl Value { statements: vec![Statement { kind: statement, span: location.span }], }) } - Value::Expr(ExprValue::LValue(_)) + Value::Expr(ExprValue::LValue(lvalue)) => lvalue.as_expression().kind, + Value::TypedExpr(..) | Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(..) @@ -386,7 +398,9 @@ impl Value { HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements))) } Value::Quoted(tokens) => HirExpression::Unquote(add_token_spans(tokens, location.span)), - Value::Expr(..) + Value::TypedExpr(TypedExpr::ExprId(expr_id)) => interner.expression(&expr_id), + Value::TypedExpr(TypedExpr::StmtId(..)) + | Value::Expr(..) | Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(..) @@ -417,6 +431,18 @@ impl Value { let token = match self { Value::Quoted(tokens) => return Ok(unwrap_rc(tokens)), Value::Type(typ) => Token::QuotedType(interner.push_quoted_type(typ)), + Value::Expr(ExprValue::Expression(expr)) => { + Token::InternedExpr(interner.push_expression_kind(expr)) + } + Value::Expr(ExprValue::Statement(statement)) => { + Token::InternedStatement(interner.push_statement_kind(statement)) + } + Value::Expr(ExprValue::LValue(lvalue)) => { + Token::InternedLValue(interner.push_lvalue(lvalue)) + } + Value::UnresolvedType(typ) => { + Token::InternedUnresolvedTypeData(interner.push_unresolved_type_data(typ)) + } other => Token::UnquoteMarker(other.into_hir_expression(interner, location)?), }; Ok(vec![token]) @@ -597,10 +623,24 @@ impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { Value::ModuleDefinition(_) => write!(f, "(module)"), Value::Zeroed(typ) => write!(f, "(zeroed {typ})"), Value::Type(typ) => write!(f, "{}", typ), - Value::Expr(ExprValue::Expression(expr)) => write!(f, "{}", expr), - Value::Expr(ExprValue::Statement(statement)) => write!(f, "{}", statement), - Value::Expr(ExprValue::LValue(lvalue)) => write!(f, "{}", lvalue), - Value::UnresolvedType(typ) => write!(f, "{}", typ), + Value::Expr(ExprValue::Expression(expr)) => { + write!(f, "{}", remove_interned_in_expression_kind(self.interner, expr.clone())) + } + Value::Expr(ExprValue::Statement(statement)) => { + write!(f, "{}", remove_interned_in_statement_kind(self.interner, statement.clone())) + } + Value::Expr(ExprValue::LValue(lvalue)) => { + write!(f, "{}", remove_interned_in_lvalue(self.interner, lvalue.clone())) + } + Value::TypedExpr(_) => write!(f, "(typed expr)"), + Value::UnresolvedType(typ) => { + if let UnresolvedTypeData::Interned(id) = typ { + let typ = self.interner.get_unresolved_type_data(*id); + write!(f, "{}", typ) + } else { + write!(f, "{}", typ) + } + } } } } @@ -609,3 +649,227 @@ fn display_trait_constraint(interner: &NodeInterner, trait_constraint: &TraitCon let trait_ = interner.get_trait(trait_constraint.trait_id); format!("{}: {}{}", trait_constraint.typ, trait_.name, trait_constraint.trait_generics) } + +// Returns a new Expression where all Interned and Resolved expressions have been turned into non-interned ExpressionKind. +fn remove_interned_in_expression(interner: &NodeInterner, expr: Expression) -> Expression { + Expression { kind: remove_interned_in_expression_kind(interner, expr.kind), span: expr.span } +} + +// Returns a new ExpressionKind where all Interned and Resolved expressions have been turned into non-interned ExpressionKind. +fn remove_interned_in_expression_kind( + interner: &NodeInterner, + expr: ExpressionKind, +) -> ExpressionKind { + match expr { + ExpressionKind::Literal(literal) => { + ExpressionKind::Literal(remove_interned_in_literal(interner, literal)) + } + ExpressionKind::Block(block) => { + let statements = + vecmap(block.statements, |stmt| remove_interned_in_statement(interner, stmt)); + ExpressionKind::Block(BlockExpression { statements }) + } + ExpressionKind::Prefix(prefix) => ExpressionKind::Prefix(Box::new(PrefixExpression { + rhs: remove_interned_in_expression(interner, prefix.rhs), + ..*prefix + })), + ExpressionKind::Index(index) => ExpressionKind::Index(Box::new(IndexExpression { + collection: remove_interned_in_expression(interner, index.collection), + index: remove_interned_in_expression(interner, index.index), + })), + ExpressionKind::Call(call) => ExpressionKind::Call(Box::new(CallExpression { + func: Box::new(remove_interned_in_expression(interner, *call.func)), + arguments: vecmap(call.arguments, |arg| remove_interned_in_expression(interner, arg)), + ..*call + })), + ExpressionKind::MethodCall(call) => { + ExpressionKind::MethodCall(Box::new(MethodCallExpression { + object: remove_interned_in_expression(interner, call.object), + arguments: vecmap(call.arguments, |arg| { + remove_interned_in_expression(interner, arg) + }), + ..*call + })) + } + ExpressionKind::Constructor(constructor) => { + ExpressionKind::Constructor(Box::new(ConstructorExpression { + fields: vecmap(constructor.fields, |(name, expr)| { + (name, remove_interned_in_expression(interner, expr)) + }), + ..*constructor + })) + } + ExpressionKind::MemberAccess(member_access) => { + ExpressionKind::MemberAccess(Box::new(MemberAccessExpression { + lhs: remove_interned_in_expression(interner, member_access.lhs), + ..*member_access + })) + } + ExpressionKind::Cast(cast) => ExpressionKind::Cast(Box::new(CastExpression { + lhs: remove_interned_in_expression(interner, cast.lhs), + ..*cast + })), + ExpressionKind::Infix(infix) => ExpressionKind::Infix(Box::new(InfixExpression { + lhs: remove_interned_in_expression(interner, infix.lhs), + rhs: remove_interned_in_expression(interner, infix.rhs), + ..*infix + })), + ExpressionKind::If(if_expr) => ExpressionKind::If(Box::new(IfExpression { + condition: remove_interned_in_expression(interner, if_expr.condition), + consequence: remove_interned_in_expression(interner, if_expr.consequence), + alternative: if_expr + .alternative + .map(|alternative| remove_interned_in_expression(interner, alternative)), + })), + ExpressionKind::Variable(_) => expr, + ExpressionKind::Tuple(expressions) => ExpressionKind::Tuple(vecmap(expressions, |expr| { + remove_interned_in_expression(interner, expr) + })), + ExpressionKind::Lambda(lambda) => ExpressionKind::Lambda(Box::new(Lambda { + body: remove_interned_in_expression(interner, lambda.body), + ..*lambda + })), + ExpressionKind::Parenthesized(expr) => { + ExpressionKind::Parenthesized(Box::new(remove_interned_in_expression(interner, *expr))) + } + ExpressionKind::Quote(_) => expr, + ExpressionKind::Unquote(expr) => { + ExpressionKind::Unquote(Box::new(remove_interned_in_expression(interner, *expr))) + } + ExpressionKind::Comptime(block, span) => { + let statements = + vecmap(block.statements, |stmt| remove_interned_in_statement(interner, stmt)); + ExpressionKind::Comptime(BlockExpression { statements }, span) + } + ExpressionKind::Unsafe(block, span) => { + let statements = + vecmap(block.statements, |stmt| remove_interned_in_statement(interner, stmt)); + ExpressionKind::Unsafe(BlockExpression { statements }, span) + } + ExpressionKind::AsTraitPath(_) => expr, + ExpressionKind::Resolved(id) => { + let expr = interner.expression(&id); + expr.to_display_ast(interner, Span::default()).kind + } + ExpressionKind::Interned(id) => { + let expr = interner.get_expression_kind(id).clone(); + remove_interned_in_expression_kind(interner, expr) + } + ExpressionKind::Error => expr, + } +} + +fn remove_interned_in_literal(interner: &NodeInterner, literal: Literal) -> Literal { + match literal { + Literal::Array(array_literal) => { + Literal::Array(remove_interned_in_array_literal(interner, array_literal)) + } + Literal::Slice(array_literal) => { + Literal::Array(remove_interned_in_array_literal(interner, array_literal)) + } + Literal::Bool(_) + | Literal::Integer(_, _) + | Literal::Str(_) + | Literal::RawStr(_, _) + | Literal::FmtStr(_) + | Literal::Unit => literal, + } +} + +fn remove_interned_in_array_literal( + interner: &NodeInterner, + literal: ArrayLiteral, +) -> ArrayLiteral { + match literal { + ArrayLiteral::Standard(expressions) => { + ArrayLiteral::Standard(vecmap(expressions, |expr| { + remove_interned_in_expression(interner, expr) + })) + } + ArrayLiteral::Repeated { repeated_element, length } => ArrayLiteral::Repeated { + repeated_element: Box::new(remove_interned_in_expression(interner, *repeated_element)), + length: Box::new(remove_interned_in_expression(interner, *length)), + }, + } +} + +// Returns a new Statement where all Interned statements have been turned into non-interned StatementKind. +fn remove_interned_in_statement(interner: &NodeInterner, statement: Statement) -> Statement { + Statement { + kind: remove_interned_in_statement_kind(interner, statement.kind), + span: statement.span, + } +} + +// Returns a new StatementKind where all Interned statements have been turned into non-interned StatementKind. +fn remove_interned_in_statement_kind( + interner: &NodeInterner, + statement: StatementKind, +) -> StatementKind { + match statement { + StatementKind::Let(let_statement) => StatementKind::Let(LetStatement { + expression: remove_interned_in_expression(interner, let_statement.expression), + ..let_statement + }), + StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement( + remove_interned_in_expression(interner, constrain.0), + constrain.1.map(|expr| remove_interned_in_expression(interner, expr)), + constrain.2, + )), + StatementKind::Expression(expr) => { + StatementKind::Expression(remove_interned_in_expression(interner, expr)) + } + StatementKind::Assign(assign) => StatementKind::Assign(AssignStatement { + lvalue: assign.lvalue, + expression: remove_interned_in_expression(interner, assign.expression), + }), + StatementKind::For(for_loop) => StatementKind::For(ForLoopStatement { + range: match for_loop.range { + ForRange::Range(from, to) => ForRange::Range( + remove_interned_in_expression(interner, from), + remove_interned_in_expression(interner, to), + ), + ForRange::Array(expr) => { + ForRange::Array(remove_interned_in_expression(interner, expr)) + } + }, + block: remove_interned_in_expression(interner, for_loop.block), + ..for_loop + }), + StatementKind::Comptime(statement) => { + StatementKind::Comptime(Box::new(remove_interned_in_statement(interner, *statement))) + } + StatementKind::Semi(expr) => { + StatementKind::Semi(remove_interned_in_expression(interner, expr)) + } + StatementKind::Interned(id) => { + let statement = interner.get_statement_kind(id).clone(); + remove_interned_in_statement_kind(interner, statement) + } + StatementKind::Break | StatementKind::Continue | StatementKind::Error => statement, + } +} + +// Returns a new LValue where all Interned LValues have been turned into LValue. +fn remove_interned_in_lvalue(interner: &NodeInterner, lvalue: LValue) -> LValue { + match lvalue { + LValue::Ident(_) => lvalue, + LValue::MemberAccess { object, field_name, span } => LValue::MemberAccess { + object: Box::new(remove_interned_in_lvalue(interner, *object)), + field_name, + span, + }, + LValue::Index { array, index, span } => LValue::Index { + array: Box::new(remove_interned_in_lvalue(interner, *array)), + index: remove_interned_in_expression(interner, index), + span, + }, + LValue::Dereference(lvalue, span) => { + LValue::Dereference(Box::new(remove_interned_in_lvalue(interner, *lvalue)), span) + } + LValue::Interned(id, span) => { + let lvalue = interner.get_lvalue(id, span); + remove_interned_in_lvalue(interner, lvalue) + } + } +} diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index a961de628a8..30c91b42b2e 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -245,6 +245,7 @@ impl DefCollector { /// Collect all of the definitions in a given crate into a CrateDefMap /// Modules which are not a part of the module hierarchy starting with /// the root module, will be ignored. + #[allow(clippy::too_many_arguments)] pub fn collect_crate_and_dependencies( mut def_map: CrateDefMap, context: &mut Context, @@ -252,6 +253,7 @@ impl DefCollector { root_file_id: FileId, debug_comptime_in_file: Option<&str>, enable_arithmetic_generics: bool, + error_on_unused_imports: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { let mut errors: Vec<(CompilationError, FileId)> = vec![]; @@ -265,11 +267,13 @@ impl DefCollector { let crate_graph = &context.crate_graph[crate_id]; for dep in crate_graph.dependencies.clone() { + let error_on_unused_imports = false; errors.extend(CrateDefMap::collect_defs( dep.crate_id, context, debug_comptime_in_file, enable_arithmetic_generics, + error_on_unused_imports, macro_processors, )); @@ -413,8 +417,26 @@ impl DefCollector { ); } + if error_on_unused_imports { + Self::check_unused_imports(context, crate_id, &mut errors); + } + errors } + + fn check_unused_imports( + context: &Context, + crate_id: CrateId, + errors: &mut Vec<(CompilationError, FileId)>, + ) { + errors.extend(context.def_maps[&crate_id].modules().iter().flat_map(|(_, module)| { + module.unused_imports().iter().map(|ident| { + let ident = ident.clone(); + let error = CompilationError::ResolverError(ResolverError::UnusedImport { ident }); + (error, module.location.file) + }) + })); + } } fn add_import_reference( diff --git a/compiler/noirc_frontend/src/hir/def_map/mod.rs b/compiler/noirc_frontend/src/hir/def_map/mod.rs index e607de52ff1..758b4cf6e5c 100644 --- a/compiler/noirc_frontend/src/hir/def_map/mod.rs +++ b/compiler/noirc_frontend/src/hir/def_map/mod.rs @@ -77,6 +77,7 @@ impl CrateDefMap { context: &mut Context, debug_comptime_in_file: Option<&str>, enable_arithmetic_generics: bool, + error_on_unused_imports: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { // Check if this Crate has already been compiled @@ -127,12 +128,14 @@ impl CrateDefMap { root_file_id, debug_comptime_in_file, enable_arithmetic_generics, + error_on_unused_imports, macro_processors, )); errors.extend( parsing_errors.iter().map(|e| (e.clone().into(), root_file_id)).collect::>(), ); + errors } diff --git a/compiler/noirc_frontend/src/hir/def_map/module_data.rs b/compiler/noirc_frontend/src/hir/def_map/module_data.rs index 8a0125cfe95..7b14db8be77 100644 --- a/compiler/noirc_frontend/src/hir/def_map/module_data.rs +++ b/compiler/noirc_frontend/src/hir/def_map/module_data.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use noirc_errors::Location; @@ -24,6 +24,10 @@ pub struct ModuleData { /// True if this module is a `contract Foo { ... }` module containing contract functions pub is_contract: bool, + + /// List of all unused imports. Each time something is imported into this module it's added + /// to this set. When it's used, it's removed. At the end of the program only unused imports remain. + unused_imports: HashSet, } impl ModuleData { @@ -35,6 +39,7 @@ impl ModuleData { definitions: ItemScope::default(), location, is_contract, + unused_imports: HashSet::new(), } } @@ -121,6 +126,11 @@ impl ModuleData { id: ModuleDefId, is_prelude: bool, ) -> Result<(), (Ident, Ident)> { + // Empty spans could come from implicitly injected imports, and we don't want to track those + if name.span().start() < name.span().end() { + self.unused_imports.insert(name.clone()); + } + self.scope.add_item_to_namespace(name, ItemVisibility::Public, id, None, is_prelude) } @@ -137,4 +147,14 @@ impl ModuleData { pub fn value_definitions(&self) -> impl Iterator + '_ { self.definitions.values().values().flat_map(|a| a.values().map(|(id, _, _)| *id)) } + + /// Marks an ident as being used by an import. + pub fn use_import(&mut self, ident: &Ident) { + self.unused_imports.remove(ident); + } + + /// Returns the list of all unused imports at this moment. + pub fn unused_imports(&self) -> &HashSet { + &self.unused_imports + } } diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 0aad50d13b2..0b0d8d735eb 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -20,6 +20,8 @@ pub enum ResolverError { DuplicateDefinition { name: String, first_span: Span, second_span: Span }, #[error("Unused variable")] UnusedVariable { ident: Ident }, + #[error("Unused import")] + UnusedImport { ident: Ident }, #[error("Could not find variable in this scope")] VariableNotDeclared { name: String, span: Span }, #[error("path is not an identifier")] @@ -120,6 +122,8 @@ pub enum ResolverError { NamedTypeArgs { span: Span, item_kind: &'static str }, #[error("Associated constants may only be a field or integer type")] AssociatedConstantsMustBeNumeric { span: Span }, + #[error("Overflow in `{lhs} {op} {rhs}`")] + OverflowInType { lhs: u32, op: crate::BinaryTypeOperator, rhs: u32, span: Span }, } impl ResolverError { @@ -152,6 +156,15 @@ impl<'a> From<&'a ResolverError> for Diagnostic { ident.span(), ) } + ResolverError::UnusedImport { ident } => { + let name = &ident.0.contents; + + Diagnostic::simple_warning( + format!("unused import {name}"), + "unused import ".to_string(), + ident.span(), + ) + } ResolverError::VariableNotDeclared { name, span } => Diagnostic::simple_error( format!("cannot find `{name}` in this scope "), "not found in this scope".to_string(), @@ -480,6 +493,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) } + ResolverError::OverflowInType { lhs, op, rhs, span } => { + Diagnostic::simple_error( + format!("Overflow in `{lhs} {op} {rhs}`"), + "Overflow here".to_string(), + *span, + ) + } } } } diff --git a/compiler/noirc_frontend/src/hir/type_check/generics.rs b/compiler/noirc_frontend/src/hir/type_check/generics.rs index 379c53944e5..697c78745f9 100644 --- a/compiler/noirc_frontend/src/hir/type_check/generics.rs +++ b/compiler/noirc_frontend/src/hir/type_check/generics.rs @@ -160,6 +160,7 @@ fn fmt_trait_generics( write!(f, "{} = {}", named.name, named.typ)?; } } + write!(f, ">")?; } Ok(()) } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 807666f9af9..638003d3fcd 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -24,7 +24,9 @@ use super::{ traits::NamedType, }; -#[derive(PartialEq, Eq, Clone, Hash, Ord, PartialOrd)] +mod arithmetic; + +#[derive(Eq, Clone, Ord, PartialOrd)] pub enum Type { /// A primitive Field type FieldElement, @@ -151,6 +153,7 @@ pub enum QuotedType { Quoted, TopLevelItem, Type, + TypedExpr, StructDefinition, TraitConstraint, TraitDefinition, @@ -739,6 +742,7 @@ impl std::fmt::Display for QuotedType { QuotedType::Quoted => write!(f, "Quoted"), QuotedType::TopLevelItem => write!(f, "TopLevelItem"), QuotedType::Type => write!(f, "Type"), + QuotedType::TypedExpr => write!(f, "TypedExpr"), QuotedType::StructDefinition => write!(f, "StructDefinition"), QuotedType::TraitDefinition => write!(f, "TraitDefinition"), QuotedType::TraitConstraint => write!(f, "TraitConstraint"), @@ -1657,132 +1661,6 @@ impl Type { } } - /// Try to canonicalize the representation of this type. - /// Currently the only type with a canonical representation is - /// `Type::Infix` where for each consecutive commutative operator - /// we sort the non-constant operands by `Type: Ord` and place all constant - /// operands at the end, constant folded. - /// - /// For example: - /// - `canonicalize[((1 + N) + M) + 2] = (M + N) + 3` - /// - `canonicalize[A + 2 * B + 3 - 2] = A + (B * 2) + 3 - 2` - pub fn canonicalize(&self) -> Type { - match self.follow_bindings() { - Type::InfixExpr(lhs, op, rhs) => { - // evaluate_to_u32 also calls canonicalize so if we just called - // `self.evaluate_to_u32()` we'd get infinite recursion. - if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { - return Type::Constant(op.function(lhs, rhs)); - } - - let lhs = lhs.canonicalize(); - let rhs = rhs.canonicalize(); - if let Some(result) = Self::try_simplify_addition(&lhs, op, &rhs) { - return result; - } - - if let Some(result) = Self::try_simplify_subtraction(&lhs, op, &rhs) { - return result; - } - - if op.is_commutative() { - return Self::sort_commutative(&lhs, op, &rhs); - } - - Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) - } - other => other, - } - } - - fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type { - let mut queue = vec![lhs.clone(), rhs.clone()]; - - let mut sorted = BTreeSet::new(); - - let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 }; - let mut constant = zero_value; - - // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. - while let Some(item) = queue.pop() { - match item.canonicalize() { - Type::InfixExpr(lhs, new_op, rhs) if new_op == op => { - queue.push(*lhs); - queue.push(*rhs); - } - Type::Constant(new_constant) => { - constant = op.function(constant, new_constant); - } - other => { - sorted.insert(other); - } - } - } - - if let Some(first) = sorted.pop_first() { - let mut typ = first.clone(); - - for rhs in sorted { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); - } - - if constant != zero_value { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(Type::Constant(constant))); - } - - typ - } else { - // Every type must have been a constant - Type::Constant(constant) - } - } - - /// Try to simplify an addition expression of `lhs + rhs`. - /// - /// - Simplifies `(a - b) + b` to `a`. - fn try_simplify_addition(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { - use BinaryTypeOperator::*; - match lhs { - Type::InfixExpr(l_lhs, l_op, l_rhs) => { - if op == Addition && *l_op == Subtraction { - // TODO: Propagate type bindings. Can do in another PR, this one is large enough. - let unifies = l_rhs.try_unify(rhs, &mut TypeBindings::new()); - if unifies.is_ok() { - return Some(l_lhs.as_ref().clone()); - } - } - None - } - _ => None, - } - } - - /// Try to simplify a subtraction expression of `lhs - rhs`. - /// - /// - Simplifies `(a + C1) - C2` to `a + (C1 - C2)` if C1 and C2 are constants. - fn try_simplify_subtraction(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { - use BinaryTypeOperator::*; - match lhs { - Type::InfixExpr(l_lhs, l_op, l_rhs) => { - // Simplify `(N + 2) - 1` - if op == Subtraction && *l_op == Addition { - if let (Some(lhs_const), Some(rhs_const)) = - (l_rhs.evaluate_to_u32(), rhs.evaluate_to_u32()) - { - if lhs_const > rhs_const { - let constant = Box::new(Type::Constant(lhs_const - rhs_const)); - return Some( - Type::InfixExpr(l_lhs.clone(), *l_op, constant).canonicalize(), - ); - } - } - } - None - } - _ => None, - } - } - /// Try to unify a type variable to `self`. /// This is a helper function factored out from try_unify. fn try_unify_to_type_variable( @@ -1926,7 +1804,7 @@ impl Type { Type::InfixExpr(lhs, op, rhs) => { let lhs = lhs.evaluate_to_u32()?; let rhs = rhs.evaluate_to_u32()?; - Some(op.function(lhs, rhs)) + op.function(lhs, rhs) } _ => None, } @@ -2030,17 +1908,13 @@ impl Type { Type::Forall(typevars, typ) => { assert_eq!(types.len() + implicit_generic_count, typevars.len(), "Turbofish operator used with incorrect generic count which was not caught by name resolution"); + let bindings = + (0..implicit_generic_count).map(|_| interner.next_type_variable()).chain(types); + let replacements = typevars .iter() - .enumerate() - .map(|(i, var)| { - let binding = if i < implicit_generic_count { - interner.next_type_variable() - } else { - types[i - implicit_generic_count].clone() - }; - (var.id(), (var.clone(), binding)) - }) + .zip(bindings) + .map(|(var, binding)| (var.id(), (var.clone(), binding))) .collect(); let instantiated = typ.substitute(&replacements); @@ -2457,13 +2331,13 @@ fn convert_array_expression_to_slice( impl BinaryTypeOperator { /// Perform the actual rust numeric operation associated with this operator - pub fn function(self, a: u32, b: u32) -> u32 { + pub fn function(self, a: u32, b: u32) -> Option { match self { - BinaryTypeOperator::Addition => a.wrapping_add(b), - BinaryTypeOperator::Subtraction => a.wrapping_sub(b), - BinaryTypeOperator::Multiplication => a.wrapping_mul(b), - BinaryTypeOperator::Division => a.wrapping_div(b), - BinaryTypeOperator::Modulo => a.wrapping_rem(b), + BinaryTypeOperator::Addition => a.checked_add(b), + BinaryTypeOperator::Subtraction => a.checked_sub(b), + BinaryTypeOperator::Multiplication => a.checked_mul(b), + BinaryTypeOperator::Division => a.checked_div(b), + BinaryTypeOperator::Modulo => a.checked_rem(b), } } @@ -2681,3 +2555,136 @@ impl std::fmt::Debug for StructType { write!(f, "{}", self.name) } } + +impl std::hash::Hash for Type { + fn hash(&self, state: &mut H) { + if let Some(variable) = self.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + typ.hash(state); + return; + } + } + + if !matches!(self, Type::TypeVariable(..) | Type::NamedGeneric(..)) { + std::mem::discriminant(self).hash(state); + } + + match self { + Type::FieldElement | Type::Bool | Type::Unit | Type::Error => (), + Type::Array(len, elem) => { + len.hash(state); + elem.hash(state); + } + Type::Slice(elem) => elem.hash(state), + Type::Integer(sign, bits) => { + sign.hash(state); + bits.hash(state); + } + Type::String(len) => len.hash(state), + Type::FmtString(len, env) => { + len.hash(state); + env.hash(state); + } + Type::Tuple(elems) => elems.hash(state), + Type::Struct(def, args) => { + def.hash(state); + args.hash(state); + } + Type::Alias(alias, args) => { + alias.hash(state); + args.hash(state); + } + Type::TypeVariable(var, _) | Type::NamedGeneric(var, ..) => var.hash(state), + Type::TraitAsType(trait_id, _, args) => { + trait_id.hash(state); + args.hash(state); + } + Type::Function(args, ret, env, is_unconstrained) => { + args.hash(state); + ret.hash(state); + env.hash(state); + is_unconstrained.hash(state); + } + Type::MutableReference(elem) => elem.hash(state), + Type::Forall(vars, typ) => { + vars.hash(state); + typ.hash(state); + } + Type::Constant(value) => value.hash(state), + Type::Quoted(typ) => typ.hash(state), + Type::InfixExpr(lhs, op, rhs) => { + lhs.hash(state); + op.hash(state); + rhs.hash(state); + } + } + } +} + +impl PartialEq for Type { + fn eq(&self, other: &Self) -> bool { + if let Some(variable) = self.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + return typ == other; + } + } + + if let Some(variable) = other.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + return self == typ; + } + } + + use Type::*; + match (self, other) { + (FieldElement, FieldElement) | (Bool, Bool) | (Unit, Unit) | (Error, Error) => true, + (Array(lhs_len, lhs_elem), Array(rhs_len, rhs_elem)) => { + lhs_len == rhs_len && lhs_elem == rhs_elem + } + (Slice(lhs_elem), Slice(rhs_elem)) => lhs_elem == rhs_elem, + (Integer(lhs_sign, lhs_bits), Integer(rhs_sign, rhs_bits)) => { + lhs_sign == rhs_sign && lhs_bits == rhs_bits + } + (String(lhs_len), String(rhs_len)) => lhs_len == rhs_len, + (FmtString(lhs_len, lhs_env), FmtString(rhs_len, rhs_env)) => { + lhs_len == rhs_len && lhs_env == rhs_env + } + (Tuple(lhs_types), Tuple(rhs_types)) => lhs_types == rhs_types, + (Struct(lhs_struct, lhs_generics), Struct(rhs_struct, rhs_generics)) => { + lhs_struct == rhs_struct && lhs_generics == rhs_generics + } + (Alias(lhs_alias, lhs_generics), Alias(rhs_alias, rhs_generics)) => { + lhs_alias == rhs_alias && lhs_generics == rhs_generics + } + (TraitAsType(lhs_trait, _, lhs_generics), TraitAsType(rhs_trait, _, rhs_generics)) => { + lhs_trait == rhs_trait && lhs_generics == rhs_generics + } + ( + Function(lhs_args, lhs_ret, lhs_env, lhs_unconstrained), + Function(rhs_args, rhs_ret, rhs_env, rhs_unconstrained), + ) => { + let args_and_ret_eq = lhs_args == rhs_args && lhs_ret == rhs_ret; + args_and_ret_eq && lhs_env == rhs_env && lhs_unconstrained == rhs_unconstrained + } + (MutableReference(lhs_elem), MutableReference(rhs_elem)) => lhs_elem == rhs_elem, + (Forall(lhs_vars, lhs_type), Forall(rhs_vars, rhs_type)) => { + lhs_vars == rhs_vars && lhs_type == rhs_type + } + (Constant(lhs), Constant(rhs)) => lhs == rhs, + (Quoted(lhs), Quoted(rhs)) => lhs == rhs, + (InfixExpr(l_lhs, l_op, l_rhs), InfixExpr(r_lhs, r_op, r_rhs)) => { + l_lhs == r_lhs && l_op == r_op && l_rhs == r_rhs + } + // Special case: we consider unbound named generics and type variables to be equal to each + // other if their type variable ids match. This is important for some corner cases in + // monomorphization where we call `replace_named_generics_with_type_variables` but + // still want them to be equal for canonicalization checks in arithmetic generics. + // Without this we'd fail the `serialize` test. + ( + NamedGeneric(lhs_var, _, _) | TypeVariable(lhs_var, _), + NamedGeneric(rhs_var, _, _) | TypeVariable(rhs_var, _), + ) => lhs_var.id() == rhs_var.id(), + _ => false, + } + } +} diff --git a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs new file mode 100644 index 00000000000..ad07185dff1 --- /dev/null +++ b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -0,0 +1,215 @@ +use std::collections::BTreeSet; + +use crate::{BinaryTypeOperator, Type}; + +impl Type { + /// Try to canonicalize the representation of this type. + /// Currently the only type with a canonical representation is + /// `Type::Infix` where for each consecutive commutative operator + /// we sort the non-constant operands by `Type: Ord` and place all constant + /// operands at the end, constant folded. + /// + /// For example: + /// - `canonicalize[((1 + N) + M) + 2] = (M + N) + 3` + /// - `canonicalize[A + 2 * B + 3 - 2] = A + (B * 2) + 3 - 2` + pub fn canonicalize(&self) -> Type { + match self.follow_bindings() { + Type::InfixExpr(lhs, op, rhs) => { + // evaluate_to_u32 also calls canonicalize so if we just called + // `self.evaluate_to_u32()` we'd get infinite recursion. + if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { + if let Some(result) = op.function(lhs, rhs) { + return Type::Constant(result); + } + } + + let lhs = lhs.canonicalize(); + let rhs = rhs.canonicalize(); + if let Some(result) = Self::try_simplify_non_constants_in_lhs(&lhs, op, &rhs) { + return result.canonicalize(); + } + + if let Some(result) = Self::try_simplify_non_constants_in_rhs(&lhs, op, &rhs) { + return result.canonicalize(); + } + + // Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + // where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) + if let Some(result) = Self::try_simplify_partial_constants(&lhs, op, &rhs) { + return result.canonicalize(); + } + + if op.is_commutative() { + return Self::sort_commutative(&lhs, op, &rhs); + } + + Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) + } + other => other, + } + } + + fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type { + let mut queue = vec![lhs.clone(), rhs.clone()]; + + let mut sorted = BTreeSet::new(); + + let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 }; + let mut constant = zero_value; + + // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. + while let Some(item) = queue.pop() { + match item.canonicalize() { + Type::InfixExpr(lhs, new_op, rhs) if new_op == op => { + queue.push(*lhs); + queue.push(*rhs); + } + Type::Constant(new_constant) => { + if let Some(result) = op.function(constant, new_constant) { + constant = result; + } else { + sorted.insert(Type::Constant(new_constant)); + } + } + other => { + sorted.insert(other); + } + } + } + + if let Some(first) = sorted.pop_first() { + let mut typ = first.clone(); + + for rhs in sorted { + typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); + } + + if constant != zero_value { + typ = Type::InfixExpr(Box::new(typ), op, Box::new(Type::Constant(constant))); + } + + typ + } else { + // Every type must have been a constant + Type::Constant(constant) + } + } + + /// Try to simplify non-constant expressions in the form `(N op1 M) op2 M` + /// where the two `M` terms are expected to cancel out. + /// Precondition: `lhs & rhs are in canonical form` + /// + /// - Simplifies `(N +/- M) -/+ M` to `N` + /// - Simplifies `(N */÷ M) ÷/* M` to `N` + fn try_simplify_non_constants_in_lhs( + lhs: &Type, + op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + let Type::InfixExpr(l_lhs, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + // Note that this is exact, syntactic equality, not unification. + // `rhs` is expected to already be in canonical form. + if l_op.inverse() != Some(op) || l_rhs.canonicalize() != *rhs { + return None; + } + + Some(*l_lhs) + } + + /// Try to simplify non-constant expressions in the form `N op1 (M op1 N)` + /// where the two `M` terms are expected to cancel out. + /// Precondition: `lhs & rhs are in canonical form` + /// + /// Unlike `try_simplify_non_constants_in_lhs` we can't simplify `N / (M * N)` + /// Since that should simplify to `1 / M` instead of `M`. + /// + /// - Simplifies `N +/- (M -/+ N)` to `M` + /// - Simplifies `N * (M ÷ N)` to `M` + fn try_simplify_non_constants_in_rhs( + lhs: &Type, + op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + let Type::InfixExpr(r_lhs, r_op, r_rhs) = rhs.follow_bindings() else { + return None; + }; + + // `N / (M * N)` should be simplified to `1 / M`, but we only handle + // simplifying to `M` in this function. + if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication { + return None; + } + + // Note that this is exact, syntactic equality, not unification. + // `lhs` is expected to already be in canonical form. + if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize() { + return None; + } + + Some(*r_lhs) + } + + /// Given: + /// lhs = `N op C1` + /// rhs = C2 + /// Returns: `(N, op, C1, C2)` if C1 and C2 are constants. + /// Note that the operator here is within the `lhs` term, the operator + /// separating lhs and rhs is not needed. + /// Precondition: `lhs & rhs are in canonical form` + fn parse_partial_constant_expr( + lhs: &Type, + rhs: &Type, + ) -> Option<(Box, BinaryTypeOperator, u32, u32)> { + let rhs = rhs.evaluate_to_u32()?; + + let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + let l_rhs = l_rhs.evaluate_to_u32()?; + Some((l_type, l_op, l_rhs, rhs)) + } + + /// Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + /// where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) + /// Precondition: `lhs & rhs are in canonical form` + /// + /// - Simplifies `(N +/- C1) +/- C2` to `N +/- (C1 +/- C2)` if C1 and C2 are constants. + /// - Simplifies `(N */÷ C1) */÷ C2` to `N */÷ (C1 */÷ C2)` if C1 and C2 are constants. + fn try_simplify_partial_constants( + lhs: &Type, + mut op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + use BinaryTypeOperator::*; + let (l_type, l_op, l_const, r_const) = Type::parse_partial_constant_expr(lhs, rhs)?; + + match (l_op, op) { + (Addition | Subtraction, Addition | Subtraction) => { + // If l_op is a subtraction we want to inverse the rhs operator. + if l_op == Subtraction { + op = op.inverse()?; + } + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) + } + (Multiplication | Division, Multiplication | Division) => { + // If l_op is a division we want to inverse the rhs operator. + if l_op == Division { + op = op.inverse()?; + } + // If op is a division we need to ensure it divides evenly + if op == Division && (r_const == 0 || l_const % r_const != 0) { + None + } else { + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) + } + } + _ => None, + } + } +} diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index 8ee0fca2957..c9bd465b6a6 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -4,7 +4,10 @@ use std::{fmt, iter::Map, vec::IntoIter}; use crate::{ lexer::errors::LexerErrorKind, - node_interner::{ExprId, QuotedTypeId}, + node_interner::{ + ExprId, InternedExpressionKind, InternedStatementKind, InternedUnresolvedTypeData, + QuotedTypeId, + }, }; /// Represents a token in noir's grammar - a word, number, @@ -28,6 +31,10 @@ pub enum BorrowedToken<'input> { BlockComment(&'input str, Option), Quote(&'input Tokens), QuotedType(QuotedTypeId), + InternedExpression(InternedExpressionKind), + InternedStatement(InternedStatementKind), + InternedLValue(InternedExpressionKind), + InternedUnresolvedTypeData(InternedUnresolvedTypeData), /// < Less, /// <= @@ -134,6 +141,14 @@ pub enum Token { /// to avoid having to tokenize it, re-parse it, and re-resolve it which /// may change the underlying type. QuotedType(QuotedTypeId), + /// A reference to an interned `ExpressionKind`. + InternedExpr(InternedExpressionKind), + /// A reference to an interned `StatementKind`. + InternedStatement(InternedStatementKind), + /// A reference to an interned `LValue`. + InternedLValue(InternedExpressionKind), + /// A reference to an interned `UnresolvedTypeData`. + InternedUnresolvedTypeData(InternedUnresolvedTypeData), /// < Less, /// <= @@ -233,6 +248,10 @@ pub fn token_to_borrowed_token(token: &Token) -> BorrowedToken<'_> { Token::BlockComment(ref s, _style) => BorrowedToken::BlockComment(s, *_style), Token::Quote(stream) => BorrowedToken::Quote(stream), Token::QuotedType(id) => BorrowedToken::QuotedType(*id), + Token::InternedExpr(id) => BorrowedToken::InternedExpression(*id), + Token::InternedStatement(id) => BorrowedToken::InternedStatement(*id), + Token::InternedLValue(id) => BorrowedToken::InternedLValue(*id), + Token::InternedUnresolvedTypeData(id) => BorrowedToken::InternedUnresolvedTypeData(*id), Token::IntType(ref i) => BorrowedToken::IntType(i.clone()), Token::Less => BorrowedToken::Less, Token::LessEqual => BorrowedToken::LessEqual, @@ -353,8 +372,12 @@ impl fmt::Display for Token { } write!(f, "}}") } - // Quoted types only have an ID so there is nothing to display + // Quoted types and exprs only have an ID so there is nothing to display Token::QuotedType(_) => write!(f, "(type)"), + Token::InternedExpr(_) | Token::InternedStatement(_) | Token::InternedLValue(_) => { + write!(f, "(expr)") + } + Token::InternedUnresolvedTypeData(_) => write!(f, "(type)"), Token::IntType(ref i) => write!(f, "{i}"), Token::Less => write!(f, "<"), Token::LessEqual => write!(f, "<="), @@ -407,6 +430,10 @@ pub enum TokenKind { Attribute, Quote, QuotedType, + InternedExpr, + InternedStatement, + InternedLValue, + InternedUnresolvedTypeData, UnquoteMarker, } @@ -420,6 +447,10 @@ impl fmt::Display for TokenKind { TokenKind::Attribute => write!(f, "attribute"), TokenKind::Quote => write!(f, "quote"), TokenKind::QuotedType => write!(f, "quoted type"), + TokenKind::InternedExpr => write!(f, "interned expr"), + TokenKind::InternedStatement => write!(f, "interned statement"), + TokenKind::InternedLValue => write!(f, "interned lvalue"), + TokenKind::InternedUnresolvedTypeData => write!(f, "interned unresolved type"), TokenKind::UnquoteMarker => write!(f, "macro result"), } } @@ -439,6 +470,10 @@ impl Token { Token::UnquoteMarker(_) => TokenKind::UnquoteMarker, Token::Quote(_) => TokenKind::Quote, Token::QuotedType(_) => TokenKind::QuotedType, + Token::InternedExpr(_) => TokenKind::InternedExpr, + Token::InternedStatement(_) => TokenKind::InternedStatement, + Token::InternedLValue(_) => TokenKind::InternedLValue, + Token::InternedUnresolvedTypeData(_) => TokenKind::InternedUnresolvedTypeData, tok => TokenKind::Token(tok.clone()), } } @@ -927,6 +962,7 @@ pub enum Keyword { TraitDefinition, TraitImpl, Type, + TypedExpr, TypeType, Unchecked, Unconstrained, @@ -982,6 +1018,7 @@ impl fmt::Display for Keyword { Keyword::TraitDefinition => write!(f, "TraitDefinition"), Keyword::TraitImpl => write!(f, "TraitImpl"), Keyword::Type => write!(f, "type"), + Keyword::TypedExpr => write!(f, "TypedExpr"), Keyword::TypeType => write!(f, "Type"), Keyword::Unchecked => write!(f, "unchecked"), Keyword::Unconstrained => write!(f, "unconstrained"), @@ -1040,6 +1077,7 @@ impl Keyword { "TraitImpl" => Keyword::TraitImpl, "type" => Keyword::Type, "Type" => Keyword::TypeType, + "TypedExpr" => Keyword::TypedExpr, "StructDefinition" => Keyword::StructDefinition, "unchecked" => Keyword::Unchecked, "unconstrained" => Keyword::Unconstrained, diff --git a/compiler/noirc_frontend/src/monomorphization/errors.rs b/compiler/noirc_frontend/src/monomorphization/errors.rs index 665bf26f7b9..ce8ef3572e6 100644 --- a/compiler/noirc_frontend/src/monomorphization/errors.rs +++ b/compiler/noirc_frontend/src/monomorphization/errors.rs @@ -34,7 +34,7 @@ impl MonomorphizationError { fn into_diagnostic(self) -> CustomDiagnostic { let message = match &self { MonomorphizationError::UnknownArrayLength { length, .. } => { - format!("ICE: Could not determine array length `{length}`") + format!("Could not determine array length `{length}`") } MonomorphizationError::NoDefaultType { location } => { let message = "Type annotation needed".into(); diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index edb831b2158..87b55540bbd 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -301,6 +301,7 @@ impl<'interner> Monomorphizer<'interner> { } let meta = self.interner.function_meta(&f).clone(); + let mut func_sig = meta.function_signature(); // Follow the bindings of the function signature for entry points // which are not `main` such as foldable functions. @@ -845,6 +846,14 @@ impl<'interner> Monomorphizer<'interner> { return self.resolve_trait_method_expr(expr_id, typ, method); } + // Ensure all instantiation bindings are bound. + // This ensures even unused type variables like `fn foo() {}` have concrete types + if let Some(bindings) = self.interner.try_get_instantiation_bindings(expr_id) { + for (_, binding) in bindings.values() { + Self::check_type(binding, ident.location)?; + } + } + let definition = self.interner.definition(ident.id); let ident = match &definition.kind { DefinitionKind::Function(func_id) => { @@ -1950,6 +1959,7 @@ pub fn resolve_trait_method( TraitImplKind::Normal(impl_id) => impl_id, TraitImplKind::Assumed { object_type, trait_generics } => { let location = interner.expr_location(&expr_id); + match interner.lookup_trait_implementation( &object_type, method.trait_id, diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 2c0426f6938..32f25790e12 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -13,7 +13,11 @@ use petgraph::prelude::DiGraph; use petgraph::prelude::NodeIndex as PetGraphIndex; use rustc_hash::FxHashMap as HashMap; +use crate::ast::ExpressionKind; use crate::ast::Ident; +use crate::ast::LValue; +use crate::ast::StatementKind; +use crate::ast::UnresolvedTypeData; use crate::graph::CrateId; use crate::hir::comptime; use crate::hir::def_collector::dc_crate::CompilationError; @@ -208,6 +212,15 @@ pub struct NodeInterner { /// the actual type since types do not implement Send or Sync. quoted_types: noirc_arena::Arena, + // Interned `ExpressionKind`s during comptime code. + interned_expression_kinds: noirc_arena::Arena, + + // Interned `StatementKind`s during comptime code. + interned_statement_kinds: noirc_arena::Arena, + + // Interned `UnresolvedTypeData`s during comptime code. + interned_unresolved_type_datas: noirc_arena::Arena, + /// Determins whether to run in LSP mode. In LSP mode references are tracked. pub(crate) lsp_mode: bool, @@ -580,6 +593,15 @@ pub struct GlobalInfo { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct QuotedTypeId(noirc_arena::Index); +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct InternedExpressionKind(noirc_arena::Index); + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct InternedStatementKind(noirc_arena::Index); + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct InternedUnresolvedTypeData(noirc_arena::Index); + impl Default for NodeInterner { fn default() -> Self { NodeInterner { @@ -617,6 +639,9 @@ impl Default for NodeInterner { type_alias_ref: Vec::new(), type_ref_locations: Vec::new(), quoted_types: Default::default(), + interned_expression_kinds: Default::default(), + interned_statement_kinds: Default::default(), + interned_unresolved_type_datas: Default::default(), lsp_mode: false, location_indices: LocationIndices::default(), reference_graph: petgraph::graph::DiGraph::new(), @@ -1289,6 +1314,10 @@ impl NodeInterner { &self.instantiation_bindings[&expr_id] } + pub fn try_get_instantiation_bindings(&self, expr_id: ExprId) -> Option<&TypeBindings> { + self.instantiation_bindings.get(&expr_id) + } + pub fn get_field_index(&self, expr_id: ExprId) -> usize { self.field_indices[&expr_id] } @@ -2038,6 +2067,41 @@ impl NodeInterner { &self.quoted_types[id.0] } + pub fn push_expression_kind(&mut self, expr: ExpressionKind) -> InternedExpressionKind { + InternedExpressionKind(self.interned_expression_kinds.insert(expr)) + } + + pub fn get_expression_kind(&self, id: InternedExpressionKind) -> &ExpressionKind { + &self.interned_expression_kinds[id.0] + } + + pub fn push_statement_kind(&mut self, statement: StatementKind) -> InternedStatementKind { + InternedStatementKind(self.interned_statement_kinds.insert(statement)) + } + + pub fn get_statement_kind(&self, id: InternedStatementKind) -> &StatementKind { + &self.interned_statement_kinds[id.0] + } + + pub fn push_lvalue(&mut self, lvalue: LValue) -> InternedExpressionKind { + self.push_expression_kind(lvalue.as_expression().kind) + } + + pub fn get_lvalue(&self, id: InternedExpressionKind, span: Span) -> LValue { + LValue::from_expression_kind(self.get_expression_kind(id).clone(), span) + } + + pub fn push_unresolved_type_data( + &mut self, + typ: UnresolvedTypeData, + ) -> InternedUnresolvedTypeData { + InternedUnresolvedTypeData(self.interned_unresolved_type_datas.insert(typ)) + } + + pub fn get_unresolved_type_data(&self, id: InternedUnresolvedTypeData) -> &UnresolvedTypeData { + &self.interned_unresolved_type_datas[id.0] + } + /// Returns the type of an operator (which is always a function), along with its return type. pub fn get_infix_operator_type( &self, diff --git a/compiler/noirc_frontend/src/parser/mod.rs b/compiler/noirc_frontend/src/parser/mod.rs index f1972bcb9b5..11944cd3304 100644 --- a/compiler/noirc_frontend/src/parser/mod.rs +++ b/compiler/noirc_frontend/src/parser/mod.rs @@ -25,7 +25,7 @@ use noirc_errors::Span; pub use parser::path::path_no_turbofish; pub use parser::traits::trait_bound; pub use parser::{ - block, expression, fresh_statement, parse_program, parse_type, pattern, top_level_items, + block, expression, fresh_statement, lvalue, parse_program, parse_type, pattern, top_level_items, }; #[derive(Debug, Clone)] diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 56c80ee1ce0..8a894ec2b83 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -73,7 +73,9 @@ mod test_helpers; use literals::literal; use path::{maybe_empty_path, path}; -use primitives::{dereference, ident, negation, not, nothing, right_shift_operator, token_kind}; +use primitives::{ + dereference, ident, interned_expr, negation, not, nothing, right_shift_operator, token_kind, +}; use traits::where_clause; /// Entry function for the parser - also handles lexing internally. @@ -487,6 +489,7 @@ where continue_statement(), return_statement(expr_parser.clone()), comptime_statement(expr_parser.clone(), expr_no_constructors, statement), + interned_statement(), expr_parser.map(StatementKind::Expression), )) }) @@ -526,6 +529,15 @@ where keyword(Keyword::Comptime).ignore_then(comptime_statement).map(StatementKind::Comptime) } +pub(super) fn interned_statement() -> impl NoirParser { + token_kind(TokenKind::InternedStatement).map(|token| match token { + Token::InternedStatement(id) => StatementKind::Interned(id), + _ => { + unreachable!("token_kind(InternedStatement) guarantees we parse an interned statement") + } + }) +} + /// Comptime in an expression position only accepts entire blocks fn comptime_expr<'a, S>(statement: S) -> impl NoirParser + 'a where @@ -642,7 +654,7 @@ enum LValueRhs { Index(Expression, Span), } -fn lvalue<'a, P>(expr_parser: P) -> impl NoirParser + 'a +pub fn lvalue<'a, P>(expr_parser: P) -> impl NoirParser + 'a where P: ExprParser + 'a, { @@ -655,7 +667,15 @@ where let parenthesized = lvalue.delimited_by(just(Token::LeftParen), just(Token::RightParen)); - let term = choice((parenthesized, dereferences, l_ident)); + let interned = + token_kind(TokenKind::InternedLValue).map_with_span(|token, span| match token { + Token::InternedLValue(id) => LValue::Interned(id, span), + _ => unreachable!( + "token_kind(InternedLValue) guarantees we parse an interned lvalue" + ), + }); + + let term = choice((parenthesized, dereferences, l_ident, interned)); let l_member_rhs = just(Token::Dot).ignore_then(field_name()).map_with_span(LValueRhs::MemberAccess); @@ -1154,6 +1174,7 @@ where literal(), as_trait_path(parse_type()).map(ExpressionKind::AsTraitPath), macro_quote_marker(), + interned_expr(), )) .map_with_span(Expression::new) .or(parenthesized(expr_parser.clone()).map_with_span(|sub_expr, span| { diff --git a/compiler/noirc_frontend/src/parser/parser/primitives.rs b/compiler/noirc_frontend/src/parser/parser/primitives.rs index 9145fb945c9..c1516e2c927 100644 --- a/compiler/noirc_frontend/src/parser/parser/primitives.rs +++ b/compiler/noirc_frontend/src/parser/parser/primitives.rs @@ -119,6 +119,13 @@ pub(super) fn macro_quote_marker() -> impl NoirParser { }) } +pub(super) fn interned_expr() -> impl NoirParser { + token_kind(TokenKind::InternedExpr).map(|token| match token { + Token::InternedExpr(id) => ExpressionKind::Interned(id), + _ => unreachable!("token_kind(InternedExpr) guarantees we parse an interned expr"), + }) +} + #[cfg(test)] mod test { use crate::parser::parser::{ diff --git a/compiler/noirc_frontend/src/parser/parser/types.rs b/compiler/noirc_frontend/src/parser/parser/types.rs index c655ab8c5a4..9dabb8b83b6 100644 --- a/compiler/noirc_frontend/src/parser/parser/types.rs +++ b/compiler/noirc_frontend/src/parser/parser/types.rs @@ -40,6 +40,7 @@ pub(super) fn parse_type_inner<'a>( function_type(recursive_type_parser.clone()), mutable_reference_type(recursive_type_parser.clone()), as_trait_path_type(recursive_type_parser), + interned_unresolved_type(), )) } @@ -87,6 +88,7 @@ pub(super) fn comptime_type() -> impl NoirParser { type_of_quoted_types(), top_level_item_type(), quoted_type(), + typed_expr_type(), )) } @@ -158,6 +160,12 @@ fn quoted_type() -> impl NoirParser { .map_with_span(|_, span| UnresolvedTypeData::Quoted(QuotedType::Quoted).with_span(span)) } +/// This is the type of a typed/resolved expression. +fn typed_expr_type() -> impl NoirParser { + keyword(Keyword::TypedExpr) + .map_with_span(|_, span| UnresolvedTypeData::Quoted(QuotedType::TypedExpr).with_span(span)) +} + /// This is the type of an already resolved type. /// The only way this can appear in the token input is if an already resolved `Type` object /// was spliced into a macro's token stream via the `$` operator. @@ -168,6 +176,15 @@ pub(super) fn resolved_type() -> impl NoirParser { }) } +pub(super) fn interned_unresolved_type() -> impl NoirParser { + token_kind(TokenKind::InternedUnresolvedTypeData).map_with_span(|token, span| match token { + Token::InternedUnresolvedTypeData(id) => UnresolvedTypeData::Interned(id).with_span(span), + _ => unreachable!( + "token_kind(InternedUnresolvedTypeData) guarantees we parse an interned unresolved type" + ), + }) +} + pub(super) fn string_type() -> impl NoirParser { keyword(Keyword::String) .ignore_then(type_expression().delimited_by(just(Token::Less), just(Token::Greater))) diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index cc4aae7f447..870c781b89d 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -76,15 +76,21 @@ pub(crate) fn get_program(src: &str) -> (ParsedModule, Context, Vec<(Compilation extern_prelude: BTreeMap::new(), }; + let debug_comptime_in_file = None; + let enable_arithmetic_generics = false; + let error_on_unused_imports = true; + let macro_processors = &[]; + // Now we want to populate the CrateDefMap using the DefCollector errors.extend(DefCollector::collect_crate_and_dependencies( def_map, &mut context, program.clone().into_sorted(), root_file_id, - None, // No debug_comptime_in_file - false, // Disallow arithmetic generics - &[], // No macro processors + debug_comptime_in_file, + enable_arithmetic_generics, + error_on_unused_imports, + macro_processors, )); } (program, context, errors) @@ -2424,6 +2430,10 @@ fn use_super() { mod foo { use super::some_func; + + fn bar() { + some_func(); + } } "#; assert_no_errors(src); @@ -3187,3 +3197,37 @@ fn as_trait_path_syntax_no_impl() { use CompilationError::TypeError; assert!(matches!(&errors[0].0, TypeError(TypeCheckError::NoMatchingImplFound { .. }))); } + +#[test] +fn errors_on_unused_import() { + let src = r#" + mod foo { + pub fn bar() {} + pub fn baz() {} + + trait Foo { + } + } + + use foo::bar; + use foo::baz; + use foo::Foo; + + impl Foo for Field { + } + + fn main() { + baz(); + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::ResolverError(ResolverError::UnusedImport { ident }) = &errors[0].0 + else { + panic!("Expected an unused import error"); + }; + + assert_eq!(ident.to_string(), "bar"); +} diff --git a/docs/docs/noir/concepts/comptime.md b/docs/docs/noir/concepts/comptime.md index 2b5c29538b9..ed55a541fbd 100644 --- a/docs/docs/noir/concepts/comptime.md +++ b/docs/docs/noir/concepts/comptime.md @@ -232,6 +232,7 @@ The following is an incomplete list of some `comptime` types along with some use - `fn fields(self) -> [(Quoted, Type)]` - Return the name and type of each field - `TraitConstraint`: A trait constraint such as `From` +- `TypedExpr`: A type-checked expression. - `UnresolvedType`: A syntactic notation that refers to a Noir type that hasn't been resolved yet There are many more functions available by exploring the `std::meta` module and its submodules. diff --git a/docs/docs/noir/concepts/data_types/slices.mdx b/docs/docs/noir/concepts/data_types/slices.mdx index 95da2030843..a0c87c29259 100644 --- a/docs/docs/noir/concepts/data_types/slices.mdx +++ b/docs/docs/noir/concepts/data_types/slices.mdx @@ -20,7 +20,7 @@ fn main() -> pub u32 { } ``` -To write a slice literal, use a preceeding ampersand as in: `&[0; 2]` or +To write a slice literal, use a preceding ampersand as in: `&[0; 2]` or `&[1, 2, 3]`. It is important to note that slices are not references to arrays. In Noir, diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md index 0a32b2b04fc..8f708b2359e 100644 --- a/docs/docs/noir/standard_library/meta/expr.md +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -12,6 +12,12 @@ title: Expr If this expression is an array, this returns a slice of each element in the array. +### as_assert + +#include_code as_assert noir_stdlib/src/meta/expr.nr rust + +If this expression is an assert, this returns the assert expression and the optional message. + ### as_assign #include_code as_assign noir_stdlib/src/meta/expr.nr rust @@ -161,3 +167,27 @@ comptime { #include_code is_continue noir_stdlib/src/meta/expr.nr rust `true` if this expression is `continue`. + +### modify + +#include_code modify noir_stdlib/src/meta/expr.nr rust + +Applies a mapping function to this expression and to all of its sub-expressions. +`f` will be applied to each sub-expression first, then applied to the expression itself. + +This happens recursively for every expression within `self`. + +For example, calling `modify` on `(&[1], &[2, 3])` with an `f` that returns `Option::some` +for expressions that are integers, doubling them, would return `(&[2], &[4, 6])`. + +### quoted + +#include_code quoted noir_stdlib/src/meta/expr.nr rust + +Returns this expression as a `Quoted` value. It's the same as `quote { $self }`. + +### resolve + +#include_code resolve noir_stdlib/src/meta/expr.nr rust + +Resolves and type-checks this expression and returns the result as a `TypedExpr`. If any names used by this expression are not in scope or if there are any type errors, this will give compiler errors as if the expression was written directly into the current `comptime` function. \ No newline at end of file diff --git a/docs/docs/noir/standard_library/meta/function_def.md b/docs/docs/noir/standard_library/meta/function_def.md index 4b359a9d343..8a4e8c84958 100644 --- a/docs/docs/noir/standard_library/meta/function_def.md +++ b/docs/docs/noir/standard_library/meta/function_def.md @@ -7,6 +7,14 @@ a function definition in the source program. ## Methods +### body + +#include_code body noir_stdlib/src/meta/function_def.nr rust + +Returns the body of the function as an expression. This is only valid +on functions in the current crate which have not yet been resolved. +This means any functions called at compile-time are invalid targets for this method. + ### name #include_code name noir_stdlib/src/meta/function_def.nr rust @@ -33,8 +41,6 @@ Mutate the function body to a new expression. This is only valid on functions in the current crate which have not yet been resolved. This means any functions called at compile-time are invalid targets for this method. -Requires the new body to be a valid expression. - ### set_parameters #include_code set_parameters noir_stdlib/src/meta/function_def.nr rust diff --git a/docs/docs/noir/standard_library/meta/op.md b/docs/docs/noir/standard_library/meta/op.md index 37d4cb746ac..d8b154edc02 100644 --- a/docs/docs/noir/standard_library/meta/op.md +++ b/docs/docs/noir/standard_library/meta/op.md @@ -37,6 +37,12 @@ Returns `true` if this operator is `-`. `true` if this operator is `*` +#### quoted + +#include_code unary_quoted noir_stdlib/src/meta/op.nr rust + +Returns this operator as a `Quoted` value. + ### BinaryOp Represents a binary operator. One of `+`, `-`, `*`, `/`, `%`, `==`, `!=`, `<`, `<=`, `>`, `>=`, `&`, `|`, `^`, `>>`, or `<<`. @@ -132,3 +138,9 @@ Represents a binary operator. One of `+`, `-`, `*`, `/`, `%`, `==`, `!=`, `<`, ` #include_code is_shift_right noir_stdlib/src/meta/op.nr rust `true` if this operator is `<<` + +#### quoted + +#include_code binary_quoted noir_stdlib/src/meta/op.nr rust + +Returns this operator as a `Quoted` value. \ No newline at end of file diff --git a/docs/docs/noir/standard_library/meta/typed_expr.md b/docs/docs/noir/standard_library/meta/typed_expr.md new file mode 100644 index 00000000000..eacfd9c1230 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/typed_expr.md @@ -0,0 +1,13 @@ +--- +title: TypedExpr +--- + +`std::meta::typed_expr` contains methods on the built-in `TypedExpr` type for resolved and type-checked expressions. + +## Methods + +### as_function_definition + +#include_code as_function_definition noir_stdlib/src/meta/typed_expr.nr rust + +If this expression refers to a function definitions, returns it. Otherwise returns `Option::none()`. \ No newline at end of file diff --git a/noir_stdlib/src/array.nr b/noir_stdlib/src/array.nr index cef79e7c7f6..23683a54e45 100644 --- a/noir_stdlib/src/array.nr +++ b/noir_stdlib/src/array.nr @@ -3,6 +3,7 @@ use crate::option::Option; use crate::convert::From; impl [T; N] { + /// Returns the length of the slice. #[builtin(array_len)] pub fn len(self) -> u32 {} diff --git a/noir_stdlib/src/collections/map.nr b/noir_stdlib/src/collections/map.nr index bd50f345356..4607b06d667 100644 --- a/noir_stdlib/src/collections/map.nr +++ b/noir_stdlib/src/collections/map.nr @@ -77,10 +77,10 @@ impl Slot { // While conducting lookup, we iterate attempt from 0 to N - 1 due to heuristic, // that if we have went that far without finding desired, // it is very unlikely to be after - performance will be heavily degraded. -impl HashMap { +impl HashMap { // Creates a new instance of HashMap with specified BuildHasher. // docs:start:with_hasher - pub fn with_hasher(_build_hasher: B) -> Self + pub fn with_hasher(_build_hasher: B) -> Self where B: BuildHasher { // docs:end:with_hasher @@ -99,7 +99,7 @@ impl HashMap { // Returns true if the map contains a value for the specified key. // docs:start:contains_key - pub fn contains_key( + pub fn contains_key( self, key: K ) -> bool @@ -183,7 +183,7 @@ impl HashMap { // For each key-value entry applies mutator function. // docs:start:iter_mut - pub fn iter_mut( + pub fn iter_mut( &mut self, f: fn(K, V) -> (K, V) ) @@ -208,7 +208,7 @@ impl HashMap { // For each key applies mutator function. // docs:start:iter_keys_mut - pub fn iter_keys_mut( + pub fn iter_keys_mut( &mut self, f: fn(K) -> K ) @@ -278,7 +278,7 @@ impl HashMap { // Get the value by key. If it does not exist, returns none(). // docs:start:get - pub fn get( + pub fn get( self, key: K ) -> Option @@ -313,7 +313,7 @@ impl HashMap { // Insert key-value entry. In case key was already present, value is overridden. // docs:start:insert - pub fn insert( + pub fn insert( &mut self, key: K, value: V @@ -356,7 +356,7 @@ impl HashMap { // Removes a key-value entry. If key is not present, HashMap remains unchanged. // docs:start:remove - pub fn remove( + pub fn remove( &mut self, key: K ) @@ -388,7 +388,7 @@ impl HashMap { } // Apply HashMap's hasher onto key to obtain pre-hash for probing. - fn hash( + fn hash( self, key: K ) -> u32 diff --git a/noir_stdlib/src/collections/umap.nr b/noir_stdlib/src/collections/umap.nr index 86ae79ea644..c552c053a92 100644 --- a/noir_stdlib/src/collections/umap.nr +++ b/noir_stdlib/src/collections/umap.nr @@ -76,10 +76,10 @@ impl Slot { // While conducting lookup, we iterate attempt from 0 to N - 1 due to heuristic, // that if we have went that far without finding desired, // it is very unlikely to be after - performance will be heavily degraded. -impl UHashMap { +impl UHashMap { // Creates a new instance of UHashMap with specified BuildHasher. // docs:start:with_hasher - pub fn with_hasher(_build_hasher: B) -> Self + pub fn with_hasher(_build_hasher: B) -> Self where B: BuildHasher { // docs:end:with_hasher @@ -88,7 +88,7 @@ impl UHashMap { Self { _table, _len, _build_hasher } } - pub fn with_hasher_and_capacity(_build_hasher: B, capacity: u32) -> Self + pub fn with_hasher_and_capacity(_build_hasher: B, capacity: u32) -> Self where B: BuildHasher { // docs:end:with_hasher @@ -110,7 +110,7 @@ impl UHashMap { // Returns true if the map contains a value for the specified key. // docs:start:contains_key - pub fn contains_key( + pub fn contains_key( self, key: K ) -> bool @@ -194,7 +194,7 @@ impl UHashMap { // For each key-value entry applies mutator function. // docs:start:iter_mut - unconstrained pub fn iter_mut( + unconstrained pub fn iter_mut( &mut self, f: fn(K, V) -> (K, V) ) @@ -216,7 +216,7 @@ impl UHashMap { // For each key applies mutator function. // docs:start:iter_keys_mut - unconstrained pub fn iter_keys_mut( + unconstrained pub fn iter_keys_mut( &mut self, f: fn(K) -> K ) @@ -283,7 +283,7 @@ impl UHashMap { // Get the value by key. If it does not exist, returns none(). // docs:start:get - unconstrained pub fn get( + unconstrained pub fn get( self, key: K ) -> Option @@ -315,7 +315,7 @@ impl UHashMap { // Insert key-value entry. In case key was already present, value is overridden. // docs:start:insert - unconstrained pub fn insert( + unconstrained pub fn insert( &mut self, key: K, value: V @@ -353,7 +353,7 @@ impl UHashMap { } } - unconstrained fn try_resize(&mut self) + unconstrained fn try_resize(&mut self) where B: BuildHasher, K: Eq + Hash, H: Hasher { if self.len() + 1 >= self.capacity() / 2 { let capacity = self.capacity() * 2; @@ -368,7 +368,7 @@ impl UHashMap { // Removes a key-value entry. If key is not present, UHashMap remains unchanged. // docs:start:remove - unconstrained pub fn remove( + unconstrained pub fn remove( &mut self, key: K ) @@ -397,7 +397,7 @@ impl UHashMap { } // Apply UHashMap's hasher onto key to obtain pre-hash for probing. - fn hash( + fn hash( self, key: K ) -> u32 diff --git a/noir_stdlib/src/field/mod.nr b/noir_stdlib/src/field/mod.nr index 4b6deaa1106..534ac012beb 100644 --- a/noir_stdlib/src/field/mod.nr +++ b/noir_stdlib/src/field/mod.nr @@ -2,35 +2,85 @@ mod bn254; use bn254::lt as bn254_lt; impl Field { + /// Asserts that `self` can be represented in `bit_size` bits. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{bit_size}`. + pub fn assert_max_bit_size(self, bit_size: u32) { + crate::assert_constant(bit_size); + assert(bit_size < modulus_num_bits() as u32); + self.__assert_max_bit_size(bit_size); + } + + #[builtin(apply_range_constraint)] + fn __assert_max_bit_size(self, bit_size: u32) {} + + /// Decomposes `self` into its little endian bit decomposition as a `[u1]` slice of length `bit_size`. + /// This slice will be zero padded should not all bits be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{bit_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `bit_size` equal to or greater than the number of bits necessary to represent the `Field` modulus + /// (e.g. 254 for the BN254 field) allow for multiple bit decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_le_bits(self: Self, bit_size: u32) -> [u1] { crate::assert_constant(bit_size); self.__to_le_bits(bit_size) } + /// Decomposes `self` into its big endian bit decomposition as a `[u1]` slice of length `bit_size`. + /// This slice will be zero padded should not all bits be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{bit_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `bit_size` equal to or greater than the number of bits necessary to represent the `Field` modulus + /// (e.g. 254 for the BN254 field) allow for multiple bit decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_be_bits(self: Self, bit_size: u32) -> [u1] { crate::assert_constant(bit_size); self.__to_be_bits(bit_size) } + /// See `Field.to_be_bits` #[builtin(to_le_bits)] fn __to_le_bits(self, _bit_size: u32) -> [u1] {} + /// See `Field.to_le_bits` #[builtin(to_be_bits)] fn __to_be_bits(self, bit_size: u32) -> [u1] {} - #[builtin(apply_range_constraint)] - fn __assert_max_bit_size(self, bit_size: u32) {} - - pub fn assert_max_bit_size(self: Self, bit_size: u32) { - crate::assert_constant(bit_size); - assert(bit_size < modulus_num_bits() as u32); - self.__assert_max_bit_size(bit_size); - } - + /// Decomposes `self` into its little endian byte decomposition as a `[u8]` slice of length `byte_size`. + /// This slice will be zero padded should not all bytes be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{8*byte_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `byte_size` equal to or greater than the number of bytes necessary to represent the `Field` modulus + /// (e.g. 32 for the BN254 field) allow for multiple byte decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_le_bytes(self: Self, byte_size: u32) -> [u8] { self.to_le_radix(256, byte_size) } + /// Decomposes `self` into its big endian byte decomposition as a `[u8]` slice of length `byte_size`. + /// This slice will be zero padded should not all bytes be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{8*byte_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `byte_size` equal to or greater than the number of bytes necessary to represent the `Field` modulus + /// (e.g. 32 for the BN254 field) allow for multiple byte decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_be_bytes(self: Self, byte_size: u32) -> [u8] { self.to_be_radix(256, byte_size) } @@ -47,7 +97,6 @@ impl Field { self.__to_be_radix(radix, result_len) } - // decompose `_self` into a `_result_len` vector over the `_radix` basis // `_radix` must be less than 256 #[builtin(to_le_radix)] fn __to_le_radix(self, radix: u32, result_len: u32) -> [u8] {} diff --git a/noir_stdlib/src/hash/keccak.nr b/noir_stdlib/src/hash/keccak.nr index bb8a9cc2ce2..0c31d238f66 100644 --- a/noir_stdlib/src/hash/keccak.nr +++ b/noir_stdlib/src/hash/keccak.nr @@ -1,19 +1,27 @@ +use crate::collections::vec::Vec; +use crate::runtime::is_unconstrained; + global LIMBS_PER_BLOCK = 17; //BLOCK_SIZE / 8; global NUM_KECCAK_LANES = 25; global BLOCK_SIZE = 136; //(1600 - BITS * 2) / WORD_SIZE; global WORD_SIZE = 8; -use crate::collections::vec::Vec; - #[foreign(keccakf1600)] fn keccakf1600(input: [u64; 25]) -> [u64; 25] {} #[no_predicates] -pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u8; 32] { +pub(crate) fn keccak256(input: [u8; N], message_size: u32) -> [u8; 32] { assert(N >= message_size); - for i in 0..N { - if i >= message_size { - input[i] = 0; + let mut block_bytes = [0; BLOCK_SIZE]; + if is_unconstrained() { + for i in 0..message_size { + block_bytes[i] = input[i]; + } + } else { + for i in 0..N { + if i < message_size { + block_bytes[i] = input[i]; + } } } @@ -24,11 +32,6 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u let real_max_blocks = (message_size + BLOCK_SIZE) / BLOCK_SIZE; let real_blocks_bytes = real_max_blocks * BLOCK_SIZE; - let mut block_bytes = [0; BLOCK_SIZE]; - for i in 0..N { - block_bytes[i] = input[i]; - } - block_bytes[message_size] = 1; block_bytes[real_blocks_bytes - 1] = 0x80; @@ -36,28 +39,28 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u // means we need to swap our byte ordering let num_limbs = max_blocks * LIMBS_PER_BLOCK; //max_blocks_length / WORD_SIZE; for i in 0..num_limbs { - let mut temp = [0; 8]; - for j in 0..8 { - temp[j] = block_bytes[8*i+j]; + let mut temp = [0; WORD_SIZE]; + let word_size_times_i = WORD_SIZE * i; + for j in 0..WORD_SIZE { + temp[j] = block_bytes[word_size_times_i+j]; } - for j in 0..8 { - block_bytes[8 * i + j] = temp[7 - j]; + for j in 0..WORD_SIZE { + block_bytes[word_size_times_i + j] = temp[7 - j]; } } - let byte_size = max_blocks_length; + let mut sliced_buffer = Vec::new(); - for _i in 0..num_limbs { - sliced_buffer.push(0); - } // populate a vector of 64-bit limbs from our byte array for i in 0..num_limbs { + let word_size_times_i = i * WORD_SIZE; + let ws_times_i_plus_7 = word_size_times_i + 7; let mut sliced = 0; - if (i * WORD_SIZE + WORD_SIZE > byte_size) { - let slice_size = byte_size - (i * WORD_SIZE); + if (word_size_times_i + WORD_SIZE > max_blocks_length) { + let slice_size = max_blocks_length - word_size_times_i; let byte_shift = (WORD_SIZE - slice_size) * 8; let mut v = 1; for k in 0..slice_size { - sliced += v * (block_bytes[i * WORD_SIZE+7-k] as Field); + sliced += v * (block_bytes[ws_times_i_plus_7-k] as Field); v *= 256; } let w = 1 << (byte_shift as u8); @@ -65,22 +68,20 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u } else { let mut v = 1; for k in 0..WORD_SIZE { - sliced += v * (block_bytes[i * WORD_SIZE+7-k] as Field); + sliced += v * (block_bytes[ws_times_i_plus_7-k] as Field); v *= 256; } } - sliced_buffer.set(i, sliced as u64); + + sliced_buffer.push(sliced as u64); } //2. sponge_absorb - let num_blocks = max_blocks; let mut state : [u64;NUM_KECCAK_LANES]= [0; NUM_KECCAK_LANES]; - let mut under_block = true; - for i in 0..num_blocks { - if i == real_max_blocks { - under_block = false; - } - if under_block { + // When in an unconstrained runtime we can take advantage of runtime loop bounds, + // thus allowing us to simplify the loop body. + if is_unconstrained() { + for i in 0..real_max_blocks { if (i == 0) { for j in 0..LIMBS_PER_BLOCK { state[j] = sliced_buffer.get(j); @@ -92,6 +93,22 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u } state = keccakf1600(state); } + } else { + // `real_max_blocks` is guaranteed to at least be `1` + // We peel out the first block as to avoid a conditional inside of the loop. + // Otherwise, a dynamic predicate can cause a blowup in a constrained runtime. + for j in 0..LIMBS_PER_BLOCK { + state[j] = sliced_buffer.get(j); + } + state = keccakf1600(state); + for i in 1..max_blocks { + if i < real_max_blocks { + for j in 0..LIMBS_PER_BLOCK { + state[j] = state[j] ^ sliced_buffer.get(i * LIMBS_PER_BLOCK + j); + } + state = keccakf1600(state); + } + } } //3. sponge_squeeze diff --git a/noir_stdlib/src/hash/poseidon2.nr b/noir_stdlib/src/hash/poseidon2.nr index 9626da0cf97..cf820f86370 100644 --- a/noir_stdlib/src/hash/poseidon2.nr +++ b/noir_stdlib/src/hash/poseidon2.nr @@ -26,7 +26,7 @@ impl Poseidon2 { result } - fn perform_duplex(&mut self) -> [Field; RATE] { + fn perform_duplex(&mut self) { // zero-pad the cache for i in 0..RATE { if i >= self.cache_size { @@ -38,61 +38,30 @@ impl Poseidon2 { self.state[i] += self.cache[i]; } self.state = crate::hash::poseidon2_permutation(self.state, 4); - // return `RATE` number of field elements from the sponge state. - let mut result = [0; RATE]; - for i in 0..RATE { - result[i] = self.state[i]; - } - result } fn absorb(&mut self, input: Field) { - if (!self.squeeze_mode) & (self.cache_size == RATE) { + assert(!self.squeeze_mode); + if self.cache_size == RATE { // If we're absorbing, and the cache is full, apply the sponge permutation to compress the cache - let _ = self.perform_duplex(); + self.perform_duplex(); self.cache[0] = input; self.cache_size = 1; - } else if (!self.squeeze_mode) & (self.cache_size != RATE) { + } else { // If we're absorbing, and the cache is not full, add the input into the cache self.cache[self.cache_size] = input; self.cache_size += 1; - } else if self.squeeze_mode { - // If we're in squeeze mode, switch to absorb mode and add the input into the cache. - // N.B. I don't think this code path can be reached?! - self.cache[0] = input; - self.cache_size = 1; - self.squeeze_mode = false; } } fn squeeze(&mut self) -> Field { - if self.squeeze_mode & (self.cache_size == 0) { - // If we're in squeze mode and the cache is empty, there is nothing left to squeeze out of the sponge! - // Switch to absorb mode. - self.squeeze_mode = false; - self.cache_size = 0; - } - if !self.squeeze_mode { - // If we're in absorb mode, apply sponge permutation to compress the cache, populate cache with compressed - // state and switch to squeeze mode. Note: this code block will execute if the previous `if` condition was - // matched - let new_output_elements = self.perform_duplex(); - self.squeeze_mode = true; - for i in 0..RATE { - self.cache[i] = new_output_elements[i]; - } - self.cache_size = RATE; - } - // By this point, we should have a non-empty cache. Pop one item off the top of the cache and return it. - let result = self.cache[0]; - for i in 1..RATE { - if i < self.cache_size { - self.cache[i - 1] = self.cache[i]; - } - } - self.cache_size -= 1; - self.cache[self.cache_size] = 0; - result + assert(!self.squeeze_mode); + // If we're in absorb mode, apply sponge permutation to compress the cache. + self.perform_duplex(); + self.squeeze_mode = true; + + // Pop one item off the top of the permutation and return it. + self.state[0] } fn hash_internal(input: [Field; N], in_len: u32, is_variable_length: bool) -> Field { diff --git a/noir_stdlib/src/hash/sha256.nr b/noir_stdlib/src/hash/sha256.nr index 5035be4b73e..50cf2f965f9 100644 --- a/noir_stdlib/src/hash/sha256.nr +++ b/noir_stdlib/src/hash/sha256.nr @@ -1,3 +1,5 @@ +use crate::runtime::is_unconstrained; + // Implementation of SHA-256 mapping a byte array of variable length to // 32 bytes. @@ -17,82 +19,230 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { sha256_var(msg, N as u64) } +// Convert 64-byte array to array of 16 u32s +fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { + let mut msg32: [u32; 16] = [0; 16]; + + for i in 0..16 { + let mut msg_field: Field = 0; + for j in 0..4 { + msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; + } + msg32[15 - i] = msg_field as u32; + } + + msg32 +} + +unconstrained fn build_msg_block_iter(msg: [u8; N], message_size: u64, msg_start: u32) -> ([u8; 64], u64) { + let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE]; + let mut msg_byte_ptr: u64 = 0; // Message byte pointer + let mut msg_end = msg_start + BLOCK_SIZE; + if msg_end > N { + msg_end = N; + } + for k in msg_start..msg_end { + if k as u64 < message_size { + msg_block[msg_byte_ptr] = msg[k]; + msg_byte_ptr = msg_byte_ptr + 1; + } + } + (msg_block, msg_byte_ptr) +} + +// Verify the block we are compressing was appropriately constructed +fn verify_msg_block( + msg: [u8; N], + message_size: u64, + msg_block: [u8; 64], + msg_start: u32 +) -> u64 { + let mut msg_byte_ptr: u64 = 0; // Message byte pointer + let mut msg_end = msg_start + BLOCK_SIZE; + let mut extra_bytes = 0; + if msg_end > N { + msg_end = N; + extra_bytes = msg_end - N; + } + + for k in msg_start..msg_end { + if k as u64 < message_size { + assert_eq(msg_block[msg_byte_ptr], msg[k]); + msg_byte_ptr = msg_byte_ptr + 1; + } + } + + msg_byte_ptr +} + +global BLOCK_SIZE = 64; +global ZERO = 0; + // Variable size SHA-256 hash pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { - let mut msg_block: [u8; 64] = [0; 64]; + let num_blocks = N / BLOCK_SIZE; + let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE]; let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value - let mut i: u64 = 0; // Message byte pointer - for k in 0..N { - if k as u64 < message_size { - // Populate msg_block - msg_block[i] = msg[k]; - i = i + 1; - if i == 64 { - // Enough to hash block - h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); - - i = 0; + let mut msg_byte_ptr = 0; // Pointer into msg_block + + for i in 0..num_blocks { + let msg_start = BLOCK_SIZE * i; + let (new_msg_block, new_msg_byte_ptr) = unsafe { + build_msg_block_iter(msg, message_size, msg_start) + }; + if msg_start as u64 < message_size { + msg_block = new_msg_block; + } + + if !is_unconstrained() { + // Verify the block we are compressing was appropriately constructed + let new_msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, msg_start); + if msg_start as u64 < message_size { + msg_byte_ptr = new_msg_byte_ptr; } + } else if msg_start as u64 < message_size { + msg_byte_ptr = new_msg_byte_ptr; + } + + // If the block is filled, compress it. + // An un-filled block is handled after this loop. + if msg_byte_ptr == 64 { + h = sha256_compression(msg_u8_to_u32(msg_block), h); } } + + let modulo = N % BLOCK_SIZE; + // Handle setup of the final msg block. + // This case is only hit if the msg is less than the block size, + // or our message cannot be evenly split into blocks. + if modulo != 0 { + let msg_start = BLOCK_SIZE * num_blocks; + let (new_msg_block, new_msg_byte_ptr) = unsafe { + build_msg_block_iter(msg, message_size, msg_start) + }; + + if msg_start as u64 < message_size { + msg_block = new_msg_block; + } + + if !is_unconstrained() { + let new_msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, msg_start); + if msg_start as u64 < message_size { + msg_byte_ptr = new_msg_byte_ptr; + } + } else if msg_start as u64 < message_size { + msg_byte_ptr = new_msg_byte_ptr; + } + } + + if msg_byte_ptr == BLOCK_SIZE as u64 { + msg_byte_ptr = 0; + } + + // This variable is used to get around the compiler under-constrained check giving a warning. + // We want to check against a constant zero, but if it does not come from the circuit inputs + // or return values the compiler check will issue a warning. + let zero = msg_block[0] - msg_block[0]; + // Pad the rest such that we have a [u32; 2] block at the end representing the length - // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i] = 1 << 7; - i = i + 1; - // If i >= 57, there aren't enough bits in the current message block to accomplish this, so - // the 1 and 0s fill up the current block, which we then compress accordingly. - if i >= 57 { + // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). + msg_block[msg_byte_ptr] = 1 << 7; + let last_block = msg_block; + msg_byte_ptr = msg_byte_ptr + 1; + + unsafe { + let (new_msg_block, new_msg_byte_ptr) = pad_msg_block(msg_block, msg_byte_ptr); + msg_block = new_msg_block; + if crate::runtime::is_unconstrained() { + msg_byte_ptr = new_msg_byte_ptr; + } + } + + if !crate::runtime::is_unconstrained() { + for i in 0..64 { + assert_eq(msg_block[i], last_block[i]); + } + + // If i >= 57, there aren't enough bits in the current message block to accomplish this, so + // the 1 and 0s fill up the current block, which we then compress accordingly. // Not enough bits (64) to store length. Fill up with zeros. - if i < 64 { - for _i in 57..64 { - if i <= 63 { - msg_block[i] = 0; - i += 1; - } + for _i in 57..64 { + if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 { + assert_eq(msg_block[msg_byte_ptr], zero); + msg_byte_ptr += 1; } } - h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); + } + + if msg_byte_ptr >= 57 { + h = sha256_compression(msg_u8_to_u32(msg_block), h); - i = 0; + msg_byte_ptr = 0; } - let len = 8 * message_size; - let len_bytes = (len as Field).to_le_bytes(8); - for _i in 0..64 { - // In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56). - if i < 56 { - msg_block[i] = 0; - i = i + 1; - } else if i < 64 { - for j in 0..8 { - msg_block[63 - j] = len_bytes[j]; + msg_block = unsafe { + attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size) + }; + + if !crate::runtime::is_unconstrained() { + for i in 0..56 { + if i < msg_byte_ptr { + assert_eq(msg_block[i], last_block[i]); + } else { + assert_eq(msg_block[i], zero); } - i += 8; + } + + let len = 8 * message_size; + let len_bytes = (len as Field).to_be_bytes(8); + for i in 56..64 { + assert_eq(msg_block[i], len_bytes[i - 56]); } } + hash_final_block(msg_block, h) } -// Convert 64-byte array to array of 16 u32s -fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { - let mut msg32: [u32; 16] = [0; 16]; - - for i in 0..16 { - let mut msg_field: Field = 0; - for j in 0..4 { - msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; +unconstrained fn pad_msg_block(mut msg_block: [u8; 64], mut msg_byte_ptr: u64) -> ([u8; 64], u64) { + // If i >= 57, there aren't enough bits in the current message block to accomplish this, so + // the 1 and 0s fill up the current block, which we then compress accordingly. + if msg_byte_ptr >= 57 { + // Not enough bits (64) to store length. Fill up with zeros. + if msg_byte_ptr < 64 { + for _ in 57..64 { + if msg_byte_ptr <= 63 { + msg_block[msg_byte_ptr] = 0; + msg_byte_ptr += 1; + } + } } - msg32[15 - i] = msg_field as u32; } + (msg_block, msg_byte_ptr) +} - msg32 +unconstrained fn attach_len_to_msg_block(mut msg_block: [u8; 64], mut msg_byte_ptr: u64, message_size: u64) -> [u8; 64] { + let len = 8 * message_size; + let len_bytes = (len as Field).to_be_bytes(8); + for _i in 0..64 { + // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56). + if msg_byte_ptr < 56 { + msg_block[msg_byte_ptr] = 0; + msg_byte_ptr = msg_byte_ptr + 1; + } else if msg_byte_ptr < 64 { + for j in 0..8 { + msg_block[msg_byte_ptr + j] = len_bytes[j]; + } + msg_byte_ptr += 8; + } + } + msg_block } fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] { let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes // Hash final padded block - state = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), state); + state = sha256_compression(msg_u8_to_u32(msg_block), state); // Return final hash as byte array for j in 0..8 { @@ -104,3 +254,4 @@ fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] { out_h } + diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index f04b06c4642..7d9b6b6da34 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -8,6 +8,11 @@ impl Expr { fn as_array(self) -> Option<[Expr]> {} // docs:end:as_array + #[builtin(expr_as_assert)] + // docs:start:as_assert + fn as_assert(self) -> Option<(Expr, Option)> {} + // docs:end:as_assert + #[builtin(expr_as_assign)] // docs:start:as_assign fn as_assign(self) -> Option<(Expr, Expr)> {} @@ -110,6 +115,338 @@ impl Expr { // docs:start:is_continue fn is_continue(self) -> bool {} // docs:end:is_continue + + // docs:start:modify + fn modify(self, f: fn[Env](Expr) -> Option) -> Expr { + // docs:end:modify + let result = modify_array(self, f); + let result = result.or_else(|| modify_assert(self, f)); + let result = result.or_else(|| modify_assign(self, f)); + let result = result.or_else(|| modify_binary_op(self, f)); + let result = result.or_else(|| modify_block(self, f)); + let result = result.or_else(|| modify_cast(self, f)); + let result = result.or_else(|| modify_comptime(self, f)); + let result = result.or_else(|| modify_if(self, f)); + let result = result.or_else(|| modify_index(self, f)); + let result = result.or_else(|| modify_function_call(self, f)); + let result = result.or_else(|| modify_member_access(self, f)); + let result = result.or_else(|| modify_method_call(self, f)); + let result = result.or_else(|| modify_repeated_element_array(self, f)); + let result = result.or_else(|| modify_repeated_element_slice(self, f)); + let result = result.or_else(|| modify_slice(self, f)); + let result = result.or_else(|| modify_tuple(self, f)); + let result = result.or_else(|| modify_unary_op(self, f)); + let result = result.or_else(|| modify_unsafe(self, f)); + if result.is_some() { + let result = result.unwrap_unchecked(); + let modified = f(result); + modified.unwrap_or(result) + } else { + f(self).unwrap_or(self) + } + } + + // docs:start:quoted + fn quoted(self) -> Quoted { + // docs:end:quoted + quote { $self } + } + + #[builtin(expr_resolve)] + // docs:start:resolve + fn resolve(self) -> TypedExpr {} + // docs:end:resolve +} + +fn modify_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_array().map( + |exprs: [Expr]| { + let exprs = modify_expressions(exprs, f); + new_array(exprs) + } + ) +} + +fn modify_assert(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_assert().map( + |expr: (Expr, Option)| { + let (predicate, msg) = expr; + let predicate = predicate.modify(f); + let msg = msg.map(|msg: Expr| msg.modify(f)); + new_assert(predicate, msg) + } + ) +} + +fn modify_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_assign().map( + |expr: (Expr, Expr)| { + let (lhs, rhs) = expr; + let lhs = lhs.modify(f); + let rhs = rhs.modify(f); + new_assign(lhs, rhs) + } + ) +} + +fn modify_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_binary_op().map( + |expr: (Expr, BinaryOp, Expr)| { + let (lhs, op, rhs) = expr; + let lhs = lhs.modify(f); + let rhs = rhs.modify(f); + new_binary_op(lhs, op, rhs) + } + ) +} + +fn modify_block(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_block().map( + |exprs: [Expr]| { + let exprs = modify_expressions(exprs, f); + new_block(exprs) + } + ) +} + +fn modify_cast(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_cast().map( + |expr: (Expr, UnresolvedType)| { + let (expr, typ) = expr; + let expr = expr.modify(f); + new_cast(expr, typ) + } + ) +} + +fn modify_comptime(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_comptime().map( + |exprs: [Expr]| { + let exprs = exprs.map(|expr: Expr| expr.modify(f)); + new_comptime(exprs) + } + ) +} + +fn modify_function_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_function_call().map( + |expr: (Expr, [Expr])| { + let (function, arguments) = expr; + let function = function.modify(f); + let arguments = arguments.map(|arg: Expr| arg.modify(f)); + new_function_call(function, arguments) + } + ) +} + +fn modify_if(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_if().map( + |expr: (Expr, Expr, Option)| { + let (condition, consequence, alternative) = expr; + let condition = condition.modify(f); + let consequence = consequence.modify(f); + let alternative = alternative.map(|alternative: Expr| alternative.modify(f)); + new_if(condition, consequence, alternative) + } + ) +} + +fn modify_index(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_index().map( + |expr: (Expr, Expr)| { + let (object, index) = expr; + let object = object.modify(f); + let index = index.modify(f); + new_index(object, index) + } + ) +} + +fn modify_member_access(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_member_access().map( + |expr: (Expr, Quoted)| { + let (object, name) = expr; + let object = object.modify(f); + new_member_access(object, name) + } + ) +} + +fn modify_method_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_method_call().map( + |expr: (Expr, Quoted, [UnresolvedType], [Expr])| { + let (object, name, generics, arguments) = expr; + let object = object.modify(f); + let arguments = arguments.map(|arg: Expr| arg.modify(f)); + new_method_call(object, name, generics, arguments) + } + ) +} + +fn modify_repeated_element_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_repeated_element_array().map( + |expr: (Expr, Expr)| { + let (expr, length) = expr; + let expr = expr.modify(f); + let length = length.modify(f); + new_repeated_element_array(expr, length) + } + ) +} + +fn modify_repeated_element_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_repeated_element_slice().map( + |expr: (Expr, Expr)| { + let (expr, length) = expr; + let expr = expr.modify(f); + let length = length.modify(f); + new_repeated_element_slice(expr, length) + } + ) +} + +fn modify_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_slice().map( + |exprs: [Expr]| { + let exprs = modify_expressions(exprs, f); + new_slice(exprs) + } + ) +} + +fn modify_tuple(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_tuple().map( + |exprs: [Expr]| { + let exprs = modify_expressions(exprs, f); + new_tuple(exprs) + } + ) +} + +fn modify_unary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_unary_op().map( + |expr: (UnaryOp, Expr)| { + let (op, rhs) = expr; + let rhs = rhs.modify(f); + new_unary_op(op, rhs) + } + ) +} + +fn modify_unsafe(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_unsafe().map( + |exprs: [Expr]| { + let exprs = exprs.map(|expr: Expr| expr.modify(f)); + new_unsafe(exprs) + } + ) +} + +fn modify_expressions(exprs: [Expr], f: fn[Env](Expr) -> Option) -> [Expr] { + exprs.map(|expr: Expr| expr.modify(f)) +} + +fn new_array(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { , }); + quote { [$exprs]}.as_expr().unwrap() +} + +fn new_assert(predicate: Expr, msg: Option) -> Expr { + if msg.is_some() { + let msg = msg.unwrap(); + quote { assert($predicate, $msg) }.as_expr().unwrap() + } else { + quote { assert($predicate) }.as_expr().unwrap() + } +} + +fn new_assign(lhs: Expr, rhs: Expr) -> Expr { + quote { $lhs = $rhs }.as_expr().unwrap() +} + +fn new_binary_op(lhs: Expr, op: BinaryOp, rhs: Expr) -> Expr { + let op = op.quoted(); + quote { ($lhs) $op ($rhs) }.as_expr().unwrap() +} + +fn new_block(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { ; }); + quote { { $exprs }}.as_expr().unwrap() +} + +fn new_cast(expr: Expr, typ: UnresolvedType) -> Expr { + quote { ($expr) as $typ }.as_expr().unwrap() +} + +fn new_comptime(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { ; }); + quote { comptime { $exprs }}.as_expr().unwrap() +} + +fn new_if(condition: Expr, consequence: Expr, alternative: Option) -> Expr { + if alternative.is_some() { + let alternative = alternative.unwrap(); + quote { if $condition { $consequence } else { $alternative }}.as_expr().unwrap() + } else { + quote { if $condition { $consequence } }.as_expr().unwrap() + } +} + +fn new_index(object: Expr, index: Expr) -> Expr { + quote { $object[$index] }.as_expr().unwrap() +} + +fn new_member_access(object: Expr, name: Quoted) -> Expr { + quote { $object.$name }.as_expr().unwrap() +} + +fn new_function_call(function: Expr, arguments: [Expr]) -> Expr { + let arguments = join_expressions(arguments, quote { , }); + + quote { $function($arguments) }.as_expr().unwrap() +} + +fn new_method_call(object: Expr, name: Quoted, generics: [UnresolvedType], arguments: [Expr]) -> Expr { + let arguments = join_expressions(arguments, quote { , }); + + if generics.len() == 0 { + quote { $object.$name($arguments) }.as_expr().unwrap() + } else { + let generics = generics.map(|generic| quote { $generic }).join(quote { , }); + quote { $object.$name::<$generics>($arguments) }.as_expr().unwrap() + } +} + +fn new_repeated_element_array(expr: Expr, length: Expr) -> Expr { + quote { [$expr; $length] }.as_expr().unwrap() +} + +fn new_repeated_element_slice(expr: Expr, length: Expr) -> Expr { + quote { &[$expr; $length] }.as_expr().unwrap() +} + +fn new_slice(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { , }); + quote { &[$exprs]}.as_expr().unwrap() +} + +fn new_tuple(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { , }); + quote { ($exprs) }.as_expr().unwrap() +} + +fn new_unary_op(op: UnaryOp, rhs: Expr) -> Expr { + let op = op.quoted(); + quote { $op($rhs) }.as_expr().unwrap() +} + +fn new_unsafe(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { ; }); + quote { unsafe { $exprs }}.as_expr().unwrap() +} + +fn join_expressions(exprs: [Expr], separator: Quoted) -> Quoted { + exprs.map(|expr: Expr| expr.quoted()).join(separator) } mod tests { diff --git a/noir_stdlib/src/meta/function_def.nr b/noir_stdlib/src/meta/function_def.nr index 7ac8803e7e4..84f9c60b304 100644 --- a/noir_stdlib/src/meta/function_def.nr +++ b/noir_stdlib/src/meta/function_def.nr @@ -1,4 +1,9 @@ impl FunctionDefinition { + #[builtin(function_def_body)] + // docs:start:body + fn body(self) -> Expr {} + // docs:end:body + #[builtin(function_def_name)] // docs:start:name fn name(self) -> Quoted {} @@ -16,7 +21,7 @@ impl FunctionDefinition { #[builtin(function_def_set_body)] // docs:start:set_body - fn set_body(self, body: Quoted) {} + fn set_body(self, body: Expr) {} // docs:end:set_body #[builtin(function_def_set_parameters)] diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index be1b12540c9..24398054467 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -7,6 +7,7 @@ mod trait_constraint; mod trait_def; mod trait_impl; mod typ; +mod typed_expr; mod quoted; mod unresolved_type; diff --git a/noir_stdlib/src/meta/op.nr b/noir_stdlib/src/meta/op.nr index 9c892c4d80b..f3060a1648b 100644 --- a/noir_stdlib/src/meta/op.nr +++ b/noir_stdlib/src/meta/op.nr @@ -26,6 +26,22 @@ impl UnaryOp { // docs:end:is_dereference self.op == 3 } + + // docs:start:unary_quoted + pub fn quoted(self) -> Quoted { + // docs:end:unary_quoted + if self.is_minus() { + quote { - } + } else if self.is_not() { + quote { ! } + } else if self.is_mutable_reference() { + quote { &mut } + } else if self.is_dereference() { + quote { * } + } else { + crate::mem::zeroed() + } + } } struct BinaryOp { @@ -128,5 +144,45 @@ impl BinaryOp { // docs:end:is_modulo self.op == 15 } + + // docs:start:binary_quoted + pub fn quoted(self) -> Quoted { + // docs:end:binary_quoted + if self.is_add() { + quote { + } + } else if self.is_subtract() { + quote { - } + } else if self.is_multiply() { + quote { * } + } else if self.is_divide() { + quote { / } + } else if self.is_equal() { + quote { == } + } else if self.is_not_equal() { + quote { != } + } else if self.is_less_than() { + quote { < } + } else if self.is_less_than_or_equal() { + quote { <= } + } else if self.is_greater_than() { + quote { > } + } else if self.is_greater_than_or_equal() { + quote { >= } + } else if self.is_and() { + quote { & } + } else if self.is_or() { + quote { | } + } else if self.is_xor() { + quote { ^ } + } else if self.is_shift_right() { + quote { >> } + } else if self.is_shift_left() { + quote { << } + } else if self.is_modulo() { + quote { % } + } else { + crate::mem::zeroed() + } + } } diff --git a/noir_stdlib/src/meta/typed_expr.nr b/noir_stdlib/src/meta/typed_expr.nr new file mode 100644 index 00000000000..8daede97438 --- /dev/null +++ b/noir_stdlib/src/meta/typed_expr.nr @@ -0,0 +1,8 @@ +use crate::option::Option; + +impl TypedExpr { + #[builtin(typed_expr_as_function_definition)] + // docs:start:as_function_definition + fn as_function_definition(self) -> Option {} + // docs:end:as_function_definition +} diff --git a/noir_stdlib/src/option.nr b/noir_stdlib/src/option.nr index 8d6d9ef970d..5b6b36679f8 100644 --- a/noir_stdlib/src/option.nr +++ b/noir_stdlib/src/option.nr @@ -116,7 +116,7 @@ impl Option { } /// If self is Some, return self. Otherwise, return `default()`. - pub fn or_else(self, default: fn[Env]() -> Self) -> Self { + pub fn or_else(self, default: fn[Env]() -> Self) -> Self { if self._is_some { self } else { default() } } diff --git a/noir_stdlib/src/slice.nr b/noir_stdlib/src/slice.nr index f9aa98a9ecd..66c69db65f0 100644 --- a/noir_stdlib/src/slice.nr +++ b/noir_stdlib/src/slice.nr @@ -1,6 +1,7 @@ use crate::append::Append; impl [T] { + /// Returns the length of the slice. #[builtin(array_len)] pub fn len(self) -> u32 {} @@ -37,8 +38,8 @@ impl [T] { #[builtin(slice_remove)] pub fn remove(self, index: u32) -> (Self, T) {} - // Append each element of the `other` slice to the end of `self`. - // This returns a new slice and leaves both input slices unchanged. + /// Append each element of the `other` slice to the end of `self`. + /// This returns a new slice and leaves both input slices unchanged. pub fn append(mut self, other: Self) -> Self { for elem in other { self = self.push_back(elem); diff --git a/test_programs/compile_failure/unspecified_generic/Nargo.toml b/test_programs/compile_failure/unspecified_generic/Nargo.toml new file mode 100644 index 00000000000..15b97018f2d --- /dev/null +++ b/test_programs/compile_failure/unspecified_generic/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "unspecified_generic" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] diff --git a/test_programs/compile_failure/unspecified_generic/src/main.nr b/test_programs/compile_failure/unspecified_generic/src/main.nr new file mode 100644 index 00000000000..f26d794d567 --- /dev/null +++ b/test_programs/compile_failure/unspecified_generic/src/main.nr @@ -0,0 +1,5 @@ +fn foo() {} + +fn main() { + foo(); +} diff --git a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr index 6cd13ab0e2f..ad8dff6c7b9 100644 --- a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr +++ b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr @@ -7,6 +7,9 @@ fn main() { let _ = split_first([1, 2, 3]); let _ = push_multiple([1, 2, 3]); + + test_constant_folding::<10>(); + test_non_constant_folding::<10, 20>(); } fn split_first(array: [T; N]) -> (T, [T; N - 1]) { @@ -101,3 +104,31 @@ fn demo_proof() -> Equiv, (Equiv, (), W, () let p3: Equiv, (), W, ()> = add_equiv_r::(p3_sub); equiv_trans(equiv_trans(p1, p2), p3) } + +fn test_constant_folding() { + // N + C1 - C2 = N + (C1 - C2) + let _: W = W:: {}; + + // N - C1 + C2 = N - (C1 - C2) + let _: W = W:: {}; + + // N * C1 / C2 = N * (C1 / C2) + let _: W = W:: {}; + + // N / C1 * C2 = N / (C1 / C2) + let _: W = W:: {}; +} + +fn test_non_constant_folding() { + // N + M - M = N + let _: W = W:: {}; + + // N - M + M = N + let _: W = W:: {}; + + // N * M / M = N + let _: W = W:: {}; + + // N / M * M = N + let _: W = W:: {}; +} diff --git a/test_programs/compile_success_empty/comptime_function_definition/src/main.nr b/test_programs/compile_success_empty/comptime_function_definition/src/main.nr index ce09ba86e49..06da5a1dde5 100644 --- a/test_programs/compile_success_empty/comptime_function_definition/src/main.nr +++ b/test_programs/compile_success_empty/comptime_function_definition/src/main.nr @@ -49,7 +49,7 @@ comptime fn mutate_add_one(f: FunctionDefinition) { assert_eq(f.return_type(), type_of(0)); // fn add_one(x: Field) -> Field { x + 1 } - f.set_body(quote { x + 1 }); + f.set_body(quote { x + 1 }.as_expr().unwrap()); } fn main() { diff --git a/test_programs/compile_success_empty/comptime_keccak/Nargo.toml b/test_programs/compile_success_empty/comptime_keccak/Nargo.toml new file mode 100644 index 00000000000..47c8654804d --- /dev/null +++ b/test_programs/compile_success_empty/comptime_keccak/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "comptime_keccak" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/comptime_keccak/src/main.nr b/test_programs/compile_success_empty/comptime_keccak/src/main.nr new file mode 100644 index 00000000000..3cde32b6ba9 --- /dev/null +++ b/test_programs/compile_success_empty/comptime_keccak/src/main.nr @@ -0,0 +1,31 @@ +// Tests a very simple program. +// +// The features being tested is keccak256 in brillig +fn main() { + comptime + { + let x = 0xbd; + let result = [ + 0x5a, 0x50, 0x2f, 0x9f, 0xca, 0x46, 0x7b, 0x26, 0x6d, 0x5b, 0x78, 0x33, 0x65, 0x19, 0x37, 0xe8, 0x05, 0x27, 0x0c, 0xa3, 0xf3, 0xaf, 0x1c, 0x0d, 0xd2, 0x46, 0x2d, 0xca, 0x4b, 0x3b, 0x1a, 0xbf + ]; + // We use the `as` keyword here to denote the fact that we want to take just the first byte from the x Field + // The padding is taken care of by the program + let digest = keccak256([x as u8], 1); + assert(digest == result); + //#1399: variable message size + let message_size = 4; + let hash_a = keccak256([1, 2, 3, 4], message_size); + let hash_b = keccak256([1, 2, 3, 4, 0, 0, 0, 0], message_size); + + assert(hash_a == hash_b); + + let message_size_big = 8; + let hash_c = keccak256([1, 2, 3, 4, 0, 0, 0, 0], message_size_big); + + assert(hash_a != hash_c); + } +} + +comptime fn keccak256(data: [u8; N], msg_len: u32) -> [u8; 32] { + std::hash::keccak256(data, msg_len) +} diff --git a/test_programs/compile_success_empty/embedded_curve_add_simplification/Nargo.toml b/test_programs/compile_success_empty/embedded_curve_add_simplification/Nargo.toml new file mode 100644 index 00000000000..02586f5e926 --- /dev/null +++ b/test_programs/compile_success_empty/embedded_curve_add_simplification/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "embedded_curve_add_simplification" +type = "bin" +authors = [""] +compiler_version = ">=0.23.0" + +[dependencies] diff --git a/test_programs/compile_success_empty/embedded_curve_add_simplification/src/main.nr b/test_programs/compile_success_empty/embedded_curve_add_simplification/src/main.nr new file mode 100644 index 00000000000..39992a6454b --- /dev/null +++ b/test_programs/compile_success_empty/embedded_curve_add_simplification/src/main.nr @@ -0,0 +1,11 @@ +use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}; + +fn main() { + let zero = EmbeddedCurvePoint::point_at_infinity(); + let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + + assert(g1 + zero == g1); + assert(g1 - g1 == zero); + assert(g1 - zero == g1); + assert(zero + zero == zero); +} diff --git a/test_programs/compile_success_empty/inject_context_attribute/Nargo.toml b/test_programs/compile_success_empty/inject_context_attribute/Nargo.toml new file mode 100644 index 00000000000..10f9cb1f9e2 --- /dev/null +++ b/test_programs/compile_success_empty/inject_context_attribute/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "inject_context_attribute" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/compile_success_empty/inject_context_attribute/src/main.nr b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr new file mode 100644 index 00000000000..594b37ce072 --- /dev/null +++ b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr @@ -0,0 +1,53 @@ +struct Context { + value: Field, +} + +#[inject_context] +fn foo(x: Field) { + if true { + // 20 + 1 => 21 + bar(qux(x + 1)); + } else { + assert(false); + } +} + +#[inject_context] +fn bar(x: Field) { + let expected = _context.value; + assert_eq(x, expected); +} + +#[inject_context] +fn qux(x: Field) -> Field { + // 21 * 2 => 42 + x * 2 +} + +fn inject_context(f: FunctionDefinition) { + // Add a `_context: Context` parameter to the function + let parameters = f.parameters(); + let parameters = parameters.push_front((quote { _context }, quote { Context }.as_type())); + f.set_parameters(parameters); + + // Create a new body where every function call has `_context` added to the list of arguments. + let body = f.body().modify(mapping_function); + f.set_body(body); +} + +fn mapping_function(expr: Expr) -> Option { + expr.as_function_call().map( + |func_call: (Expr, [Expr])| { + let (name, arguments) = func_call; + let arguments = arguments.push_front(quote { _context }.as_expr().unwrap()); + let arguments = arguments.map(|arg: Expr| arg.quoted()).join(quote { , }); + quote { $name($arguments) }.as_expr().unwrap() + } + ) +} + +fn main() { + let context = Context { value: 42 }; + foo(context, 20); +} + diff --git a/test_programs/compile_success_empty/method_call_regression/src/main.nr b/test_programs/compile_success_empty/method_call_regression/src/main.nr index 88b8dc57196..de58271cae6 100644 --- a/test_programs/compile_success_empty/method_call_regression/src/main.nr +++ b/test_programs/compile_success_empty/method_call_regression/src/main.nr @@ -1,14 +1,14 @@ fn main() { - // s: Struct - let s = Struct { b: () }; + // s: Struct + let s = Struct { a: 0, b: () }; // Regression for #3089 s.foo(); } -struct Struct { b: B } +struct Struct { a: A, b: B } // Before the fix, this candidate is searched first, binding ? to `u8` permanently. -impl Struct { +impl Struct { fn foo(self) {} } @@ -18,6 +18,6 @@ impl Struct { // With the fix, the type of `s` correctly no longer changes until a // method is actually selected. So this candidate is now valid since // `Struct` unifies with `Struct` with `? = u32`. -impl Struct { +impl Struct { fn foo(self) {} } diff --git a/test_programs/compile_success_empty/option/src/main.nr b/test_programs/compile_success_empty/option/src/main.nr index c5f321256b1..d135b2d88b8 100644 --- a/test_programs/compile_success_empty/option/src/main.nr +++ b/test_programs/compile_success_empty/option/src/main.nr @@ -1,6 +1,6 @@ fn main() { let ten = 10; // giving this a name, to ensure that the Option functions work with closures - let none = Option::none(); + let none: Option = Option::none(); let some = Option::some(3); assert(none.is_none()); diff --git a/test_programs/compile_success_empty/poseidon2_simplification/Nargo.toml b/test_programs/compile_success_empty/poseidon2_simplification/Nargo.toml new file mode 100644 index 00000000000..fbf2c11b220 --- /dev/null +++ b/test_programs/compile_success_empty/poseidon2_simplification/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "poseidon2_simplification" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/poseidon2_simplification/src/main.nr b/test_programs/compile_success_empty/poseidon2_simplification/src/main.nr new file mode 100644 index 00000000000..423dfcc4d3b --- /dev/null +++ b/test_programs/compile_success_empty/poseidon2_simplification/src/main.nr @@ -0,0 +1,7 @@ +use std::hash::poseidon2; + +fn main() { + let digest = poseidon2::Poseidon2::hash([0], 1); + let expected_digest = 0x2710144414c3a5f2354f4c08d52ed655b9fe253b4bf12cb9ad3de693d9b1db11; + assert_eq(digest, expected_digest); +} diff --git a/test_programs/compile_success_empty/schnorr_simplification/Nargo.toml b/test_programs/compile_success_empty/schnorr_simplification/Nargo.toml new file mode 100644 index 00000000000..599f06ac3d2 --- /dev/null +++ b/test_programs/compile_success_empty/schnorr_simplification/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "schnorr_simplification" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/compile_success_empty/schnorr_simplification/src/main.nr b/test_programs/compile_success_empty/schnorr_simplification/src/main.nr new file mode 100644 index 00000000000..e1095cd7fe2 --- /dev/null +++ b/test_programs/compile_success_empty/schnorr_simplification/src/main.nr @@ -0,0 +1,78 @@ +use std::embedded_curve_ops; + +// Note: If main has any unsized types, then the verifier will never be able +// to figure out the circuit instance +fn main() { + let message = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let pub_key_x = 0x04b260954662e97f00cab9adb773a259097f7a274b83b113532bce27fa3fb96a; + let pub_key_y = 0x2fd51571db6c08666b0edfbfbc57d432068bccd0110a39b166ab243da0037197; + let signature = [ + 1, + 13, + 119, + 112, + 212, + 39, + 233, + 41, + 84, + 235, + 255, + 93, + 245, + 172, + 186, + 83, + 157, + 253, + 76, + 77, + 33, + 128, + 178, + 15, + 214, + 67, + 105, + 107, + 177, + 234, + 77, + 48, + 27, + 237, + 155, + 84, + 39, + 84, + 247, + 27, + 22, + 8, + 176, + 230, + 24, + 115, + 145, + 220, + 254, + 122, + 135, + 179, + 171, + 4, + 214, + 202, + 64, + 199, + 19, + 84, + 239, + 138, + 124, + 12 + ]; + + let valid_signature = std::schnorr::verify_signature(pub_key_x, pub_key_y, signature, message); + assert(valid_signature); +} diff --git a/test_programs/execution_success/sha256_regression/Nargo.toml b/test_programs/execution_success/sha256_regression/Nargo.toml new file mode 100644 index 00000000000..ce98d000bcb --- /dev/null +++ b/test_programs/execution_success/sha256_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sha256_regression" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_regression/Prover.toml b/test_programs/execution_success/sha256_regression/Prover.toml new file mode 100644 index 00000000000..ea0a0f2e8a7 --- /dev/null +++ b/test_programs/execution_success/sha256_regression/Prover.toml @@ -0,0 +1,14 @@ +msg_just_over_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116] +msg_multiple_of_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99] +msg_just_under_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59] +msg_big_not_block_multiple = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99, 116, 58, 77, 101, 115, 115, 97, 103, 101, 45, 73, 100, 58, 68, 97, 116, 101, 58, 116, 111, 59, 32, 98, 61] +msg_big_with_padding = [48, 130, 1, 37, 2, 1, 0, 48, 11, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 48, 130, 1, 17, 48, 37, 2, 1, 1, 4, 32, 176, 223, 31, 133, 108, 84, 158, 102, 70, 11, 165, 175, 196, 12, 201, 130, 25, 131, 46, 125, 156, 194, 28, 23, 55, 133, 157, 164, 135, 136, 220, 78, 48, 37, 2, 1, 2, 4, 32, 190, 82, 180, 235, 222, 33, 79, 50, 152, 136, 142, 35, 116, 224, 6, 242, 156, 141, 128, 248, 10, 61, 98, 86, 248, 45, 207, 210, 90, 232, 175, 38, 48, 37, 2, 1, 3, 4, 32, 0, 194, 104, 108, 237, 246, 97, 230, 116, 198, 69, 110, 26, 87, 17, 89, 110, 199, 108, 250, 36, 21, 39, 87, 110, 102, 250, 213, 174, 131, 171, 174, 48, 37, 2, 1, 11, 4, 32, 136, 155, 87, 144, 111, 15, 152, 127, 85, 25, 154, 81, 20, 58, 51, 75, 193, 116, 234, 0, 60, 30, 29, 30, 183, 141, 72, 247, 255, 203, 100, 124, 48, 37, 2, 1, 12, 4, 32, 41, 234, 106, 78, 31, 11, 114, 137, 237, 17, 92, 71, 134, 47, 62, 78, 189, 233, 201, 214, 53, 4, 47, 189, 201, 133, 6, 121, 34, 131, 64, 142, 48, 37, 2, 1, 13, 4, 32, 91, 222, 210, 193, 62, 222, 104, 82, 36, 41, 138, 253, 70, 15, 148, 208, 156, 45, 105, 171, 241, 195, 185, 43, 217, 162, 146, 201, 222, 89, 238, 38, 48, 37, 2, 1, 14, 4, 32, 76, 123, 216, 13, 51, 227, 72, 245, 59, 193, 238, 166, 103, 49, 23, 164, 171, 188, 194, 197, 156, 187, 249, 28, 198, 95, 69, 15, 182, 56, 54, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +msg_big_no_padding = [48, 130, 1, 37, 2, 1, 0, 48, 11, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 48, 130, 1, 17, 48, 37, 2, 1, 1, 4, 32, 176, 223, 31, 133, 108, 84, 158, 102, 70, 11, 165, 175, 196, 12, 201, 130, 25, 131, 46, 125, 156, 194, 28, 23, 55, 133, 157, 164, 135, 136, 220, 78, 48, 37, 2, 1, 2, 4, 32, 190, 82, 180, 235, 222, 33, 79, 50, 152, 136, 142, 35, 116, 224, 6, 242, 156, 141, 128, 248, 10, 61, 98, 86, 248, 45, 207, 210, 90, 232, 175, 38, 48, 37, 2, 1, 3, 4, 32, 0, 194, 104, 108, 237, 246, 97, 230, 116, 198, 69, 110, 26, 87, 17, 89, 110, 199, 108, 250, 36, 21, 39, 87, 110, 102, 250, 213, 174, 131, 171, 174, 48, 37, 2, 1, 11, 4, 32, 136, 155, 87, 144, 111, 15, 152, 127, 85, 25, 154, 81, 20, 58, 51, 75, 193, 116, 234, 0, 60, 30, 29, 30, 183, 141, 72, 247, 255, 203, 100, 124, 48, 37, 2, 1, 12, 4, 32, 41, 234, 106, 78, 31, 11, 114, 137, 237, 17, 92, 71, 134, 47, 62, 78, 189, 233, 201, 214, 53, 4, 47, 189, 201, 133, 6, 121, 34, 131, 64, 142, 48, 37, 2, 1, 13, 4, 32, 91, 222, 210, 193, 62, 222, 104, 82, 36, 41, 138, 253, 70, 15, 148, 208, 156, 45, 105, 171, 241, 195, 185, 43, 217, 162, 146, 201, 222, 89, 238, 38, 48, 37, 2, 1, 14, 4, 32, 76, 123, 216, 13, 51, 227, 72, 245, 59, 193, 238, 166, 103, 49, 23, 164, 171, 188, 194, 197, 156, 187, 249, 28, 198, 95, 69, 15, 182, 56, 54, 38] +message_size = 297 + +# Results matched against ethers library +result_just_over_block = [91, 122, 146, 93, 52, 109, 133, 148, 171, 61, 156, 70, 189, 238, 153, 7, 222, 184, 94, 24, 65, 114, 192, 244, 207, 199, 87, 232, 192, 224, 171, 207] +result_multiple_of_block = [116, 90, 151, 31, 78, 22, 138, 180, 211, 189, 69, 76, 227, 200, 155, 29, 59, 123, 154, 60, 47, 153, 203, 129, 157, 251, 48, 2, 79, 11, 65, 47] +result_just_under_block = [143, 140, 76, 173, 222, 123, 102, 68, 70, 149, 207, 43, 39, 61, 34, 79, 216, 252, 213, 165, 74, 16, 110, 74, 29, 64, 138, 167, 30, 1, 9, 119] +result_big = [112, 144, 73, 182, 208, 98, 9, 238, 54, 229, 61, 145, 222, 17, 72, 62, 148, 222, 186, 55, 192, 82, 220, 35, 66, 47, 193, 200, 22, 38, 26, 186] +result_big_with_padding = [32, 85, 108, 174, 127, 112, 178, 182, 8, 43, 134, 123, 192, 211, 131, 66, 184, 240, 212, 181, 240, 180, 106, 195, 24, 117, 54, 129, 19, 10, 250, 53] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_regression/src/main.nr b/test_programs/execution_success/sha256_regression/src/main.nr new file mode 100644 index 00000000000..83049640ac4 --- /dev/null +++ b/test_programs/execution_success/sha256_regression/src/main.nr @@ -0,0 +1,39 @@ +// A bunch of different test cases for sha256_var in the stdlib +fn main( + msg_just_over_block: [u8; 68], + result_just_over_block: pub [u8; 32], + msg_multiple_of_block: [u8; 448], + result_multiple_of_block: pub [u8; 32], + // We want to make sure we are testing a message with a size >= 57 but < 64 + msg_just_under_block: [u8; 60], + result_just_under_block: pub [u8; 32], + msg_big_not_block_multiple: [u8; 472], + result_big: pub [u8; 32], + // This message is only 297 elements and we want to hash only a variable amount + msg_big_with_padding: [u8; 700], + // This is the same as `msg_big_with_padding` but with no padding + msg_big_no_padding: [u8; 297], + message_size: u64, + result_big_with_padding: pub [u8; 32] +) { + let hash = std::hash::sha256_var(msg_just_over_block, msg_just_over_block.len() as u64); + assert_eq(hash, result_just_over_block); + + let hash = std::hash::sha256_var(msg_multiple_of_block, msg_multiple_of_block.len() as u64); + assert_eq(hash, result_multiple_of_block); + + let hash = std::hash::sha256_var(msg_just_under_block, msg_just_under_block.len() as u64); + assert_eq(hash, result_just_under_block); + + let hash = std::hash::sha256_var( + msg_big_not_block_multiple, + msg_big_not_block_multiple.len() as u64 + ); + assert_eq(hash, result_big); + + let hash_padding = std::hash::sha256_var(msg_big_with_padding, message_size); + assert_eq(hash_padding, result_big_with_padding); + + let hash_no_padding = std::hash::sha256_var(msg_big_no_padding, message_size); + assert_eq(hash_no_padding, result_big_with_padding); +} diff --git a/test_programs/execution_success/sha256_var_size_regression/Nargo.toml b/test_programs/execution_success/sha256_var_size_regression/Nargo.toml new file mode 100644 index 00000000000..3e141ee5d5f --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sha256_var_size_regression" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_size_regression/Prover.toml b/test_programs/execution_success/sha256_var_size_regression/Prover.toml new file mode 100644 index 00000000000..df632a42858 --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/Prover.toml @@ -0,0 +1,3 @@ +enable = [true, false] +foo = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] +toggle = false diff --git a/test_programs/execution_success/sha256_var_size_regression/src/main.nr b/test_programs/execution_success/sha256_var_size_regression/src/main.nr new file mode 100644 index 00000000000..de1c2b23c5f --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/src/main.nr @@ -0,0 +1,17 @@ +global NUM_HASHES = 2; + +fn main(foo: [u8; 95], toggle: bool, enable: [bool; NUM_HASHES]) { + let mut result = [[0; 32]; NUM_HASHES]; + let mut const_result = [[0; 32]; NUM_HASHES]; + let size: Field = 93 + toggle as Field * 2; + for i in 0..NUM_HASHES { + if enable[i] { + result[i] = std::sha256::sha256_var(foo, size as u64); + const_result[i] = std::sha256::sha256_var(foo, 93); + } + } + + for i in 0..NUM_HASHES { + assert_eq(result[i], const_result[i]); + } +} diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index 329e97dc9d9..7248d51ca9a 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -15,6 +15,54 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_array() { + comptime + { + let expr = quote { [1, 2, 4] }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let elems = expr.as_array().unwrap(); + assert_eq(elems.len(), 3); + assert_eq(elems[0].as_integer().unwrap(), (2, false)); + assert_eq(elems[1].as_integer().unwrap(), (4, false)); + assert_eq(elems[2].as_integer().unwrap(), (8, false)); + } + } + + #[test] + fn test_expr_as_assert() { + comptime + { + let expr = quote { assert(true) }.as_expr().unwrap(); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_bool().unwrap(), true); + assert(msg.is_none()); + + let expr = quote { assert(false, "oops") }.as_expr().unwrap(); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_bool().unwrap(), false); + assert(msg.is_some()); + } + } + + #[test] + fn test_expr_mutate_for_assert() { + comptime + { + let expr = quote { assert(1) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_integer().unwrap(), (2, false)); + assert(msg.is_none()); + + let expr = quote { assert(1, 2) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_integer().unwrap(), (2, false)); + assert_eq(msg.unwrap().as_integer().unwrap(), (4, false)); + } + } + #[test] fn test_expr_as_assign() { comptime @@ -26,6 +74,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_assign() { + comptime + { + let expr = quote { { a = 1; } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let exprs = expr.as_block().unwrap(); + let (_lhs, rhs) = exprs[0].as_assign().unwrap(); + assert_eq(rhs.as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_block() { comptime @@ -43,6 +103,24 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_block() { + comptime + { + let expr = quote { { 1; 4; 23 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let exprs = expr.as_block().unwrap(); + assert_eq(exprs.len(), 3); + assert_eq(exprs[0].as_integer().unwrap(), (2, false)); + assert_eq(exprs[1].as_integer().unwrap(), (8, false)); + assert_eq(exprs[2].as_integer().unwrap(), (46, false)); + + assert(exprs[0].has_semicolon()); + assert(exprs[1].has_semicolon()); + assert(!exprs[2].has_semicolon()); + } + } + #[test] fn test_expr_as_method_call() { comptime @@ -61,6 +139,25 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_method_call() { + comptime + { + let expr = quote { foo.bar(3, 4) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + + let (_object, name, generics, arguments) = expr.as_method_call().unwrap(); + + assert_eq(name, quote { bar }); + + assert_eq(generics.len(), 0); + + assert_eq(arguments.len(), 2); + assert_eq(arguments[0].as_integer().unwrap(), (6, false)); + assert_eq(arguments[1].as_integer().unwrap(), (8, false)); + } + } + #[test] fn test_expr_as_integer() { comptime @@ -73,6 +170,17 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_integer() { + comptime + { + let expr = quote { 1 }.as_expr().unwrap(); + let expr = expr.modify(times_two); + + assert_eq((2, false), expr.as_integer().unwrap()); + } + } + #[test] fn test_expr_as_binary_op() { comptime @@ -96,6 +204,20 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_binary_op() { + comptime + { + let expr = quote { 3 + 4 }.as_expr().unwrap(); + let expr = expr.modify(times_two); + + let (lhs, op, rhs) = expr.as_binary_op().unwrap(); + assert_eq(lhs.as_integer().unwrap(), (6, false)); + assert(op.is_add()); + assert_eq(rhs.as_integer().unwrap(), (8, false)); + } + } + #[test] fn test_expr_as_bool() { comptime @@ -119,6 +241,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_cast() { + comptime + { + let expr = quote { 1 as Field }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (expr, typ) = expr.as_cast().unwrap(); + assert_eq(expr.as_integer().unwrap(), (2, false)); + assert(typ.is_field()); + } + } + #[test] fn test_expr_as_comptime() { comptime @@ -129,6 +263,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_comptime() { + comptime + { + let expr = quote { comptime { 1; 4; 23 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let exprs = expr.as_comptime().unwrap(); + assert_eq(exprs.len(), 3); + assert_eq(exprs[0].as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_comptime_as_statement() { comptime @@ -157,6 +303,18 @@ mod tests { } // docs:end:as_expr_example + #[test] + fn test_expr_mutate_for_function_call() { + comptime + { + let expr = quote { foo(42) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (_function, args) = expr.as_function_call().unwrap(); + assert_eq(args.len(), 1); + assert_eq(args[0].as_integer().unwrap(), (84, false)); + } + } + #[test] fn test_expr_as_if() { comptime @@ -171,6 +329,29 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_if() { + comptime + { + let expr = quote { if 1 { 2 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (condition, consequence, alternative) = expr.as_if().unwrap(); + assert_eq(condition.as_integer().unwrap(), (2, false)); + let consequence = consequence.as_block().unwrap()[0].as_block().unwrap()[0]; + assert_eq(consequence.as_integer().unwrap(), (4, false)); + assert(alternative.is_none()); + + let expr = quote { if 1 { 2 } else { 3 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (condition, consequence, alternative) = expr.as_if().unwrap(); + assert_eq(condition.as_integer().unwrap(), (2, false)); + let consequence = consequence.as_block().unwrap()[0].as_block().unwrap()[0]; + assert_eq(consequence.as_integer().unwrap(), (4, false)); + let alternative = alternative.unwrap().as_block().unwrap()[0].as_block().unwrap()[0]; + assert_eq(alternative.as_integer().unwrap(), (6, false)); + } + } + #[test] fn test_expr_as_index() { comptime @@ -180,6 +361,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_index() { + comptime + { + let expr = quote { 1[2] }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (object, index) = expr.as_index().unwrap(); + assert_eq(object.as_integer().unwrap(), (2, false)); + assert_eq(index.as_integer().unwrap(), (4, false)); + } + } + #[test] fn test_expr_as_member_access() { comptime @@ -190,6 +383,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_member_access() { + comptime + { + let expr = quote { 1.bar }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (expr, name) = expr.as_member_access().unwrap(); + assert_eq(name, quote { bar }); + assert_eq(expr.as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_member_access_with_an_lvalue() { comptime @@ -213,6 +418,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_repeated_element_array() { + comptime + { + let expr = quote { [1; 3] }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (expr, length) = expr.as_repeated_element_array().unwrap(); + assert_eq(expr.as_integer().unwrap(), (2, false)); + assert_eq(length.as_integer().unwrap(), (6, false)); + } + } + #[test] fn test_expr_as_repeated_element_slice() { comptime @@ -224,6 +441,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_repeated_element_slice() { + comptime + { + let expr = quote { &[1; 3] }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (expr, length) = expr.as_repeated_element_slice().unwrap(); + assert_eq(expr.as_integer().unwrap(), (2, false)); + assert_eq(length.as_integer().unwrap(), (6, false)); + } + } + #[test] fn test_expr_as_slice() { comptime @@ -237,6 +466,20 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_slice() { + comptime + { + let expr = quote { &[1, 3, 5] }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let elems = expr.as_slice().unwrap(); + assert_eq(elems.len(), 3); + assert_eq(elems[0].as_integer().unwrap(), (2, false)); + assert_eq(elems[1].as_integer().unwrap(), (6, false)); + assert_eq(elems[2].as_integer().unwrap(), (10, false)); + } + } + #[test] fn test_expr_as_tuple() { comptime @@ -247,6 +490,19 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_tuple() { + comptime + { + let expr = quote { (1, 2) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let tuple_exprs = expr.as_tuple().unwrap(); + assert_eq(tuple_exprs.len(), 2); + assert_eq(tuple_exprs[0].as_integer().unwrap(), (2, false)); + assert_eq(tuple_exprs[1].as_integer().unwrap(), (4, false)); + } + } + #[test] fn test_expr_as_unary_op() { comptime @@ -258,6 +514,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_unary_op() { + comptime + { + let expr = quote { -(1) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (op, expr) = expr.as_unary_op().unwrap(); + assert(op.is_minus()); + assert_eq(expr.as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_unsafe() { comptime @@ -268,6 +536,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_unsafe() { + comptime + { + let expr = quote { unsafe { 1; 4; 23 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let exprs = expr.as_unsafe().unwrap(); + assert_eq(exprs.len(), 3); + assert_eq(exprs[0].as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_is_break() { comptime @@ -297,6 +577,17 @@ mod tests { } } + #[test] + fn test_resolve_to_function_definition() { + comptime + { + let expr = quote { times_two }.as_expr().unwrap(); + let func = expr.resolve().as_function_definition().unwrap(); + assert_eq(func.name(), quote { times_two }); + assert_eq(func.parameters().len(), 1); + } + } + comptime fn get_unary_op(quoted: Quoted) -> UnaryOp { let expr = quoted.as_expr().unwrap(); let (op, _) = expr.as_unary_op().unwrap(); @@ -308,6 +599,16 @@ mod tests { let (_, op, _) = expr.as_binary_op().unwrap(); op } + + comptime fn times_two(expr: Expr) -> Option { + expr.as_integer().and_then( + |integer: (Field, bool)| { + let (value, _) = integer; + let value = value * 2; + quote { $value }.as_expr() + } + ) + } } fn main() {} diff --git a/test_programs/noir_test_success/embedded_curve_ops/src/main.nr b/test_programs/noir_test_success/embedded_curve_ops/src/main.nr index 0c2c333fa62..760df58c34a 100644 --- a/test_programs/noir_test_success/embedded_curve_ops/src/main.nr +++ b/test_programs/noir_test_success/embedded_curve_ops/src/main.nr @@ -4,7 +4,6 @@ use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_sca fn test_infinite_point() { let zero = EmbeddedCurvePoint::point_at_infinity(); - let zero = EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true }; let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; let g2 = g1 + g1; diff --git a/tooling/lsp/Cargo.toml b/tooling/lsp/Cargo.toml index 353a6ade904..c15895d801f 100644 --- a/tooling/lsp/Cargo.toml +++ b/tooling/lsp/Cargo.toml @@ -31,7 +31,7 @@ async-lsp = { workspace = true, features = ["omni-trait"] } serde_with = "3.2.0" thiserror.workspace = true fm.workspace = true -rayon = "1.8.0" +rayon.workspace = true fxhash.workspace = true convert_case = "0.6.0" diff --git a/tooling/lsp/src/notifications/mod.rs b/tooling/lsp/src/notifications/mod.rs index 4d2186badc3..8b030c9e0aa 100644 --- a/tooling/lsp/src/notifications/mod.rs +++ b/tooling/lsp/src/notifications/mod.rs @@ -2,7 +2,7 @@ use std::ops::ControlFlow; use crate::insert_all_files_for_workspace_into_file_manager; use async_lsp::{ErrorCode, LanguageClient, ResponseError}; -use noirc_driver::{check_crate, file_manager_with_stdlib}; +use noirc_driver::{check_crate, file_manager_with_stdlib, CheckOptions}; use noirc_errors::{DiagnosticKind, FileDiagnostic}; use crate::types::{ @@ -132,7 +132,11 @@ pub(crate) fn process_workspace_for_noir_document( let (mut context, crate_id) = crate::prepare_package(&workspace_file_manager, &parsed_files, package); - let file_diagnostics = match check_crate(&mut context, crate_id, &Default::default()) { + let options = CheckOptions { + error_on_unused_imports: package.error_on_unused_imports(), + ..Default::default() + }; + let file_diagnostics = match check_crate(&mut context, crate_id, &options) { Ok(((), warnings)) => warnings, Err(errors_and_warnings) => errors_and_warnings, }; diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index c61f92795ad..f339ed19622 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -449,7 +449,10 @@ impl<'a> NodeFinder<'a> { StatementKind::Semi(expression) => { self.find_in_expression(expression); } - StatementKind::Break | StatementKind::Continue | StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -501,6 +504,7 @@ impl<'a> NodeFinder<'a> { self.find_in_expression(index); } LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), + LValue::Interned(..) => (), } } @@ -565,7 +569,10 @@ impl<'a> NodeFinder<'a> { ExpressionKind::AsTraitPath(as_trait_path) => { self.find_in_as_trait_path(as_trait_path); } - ExpressionKind::Quote(_) | ExpressionKind::Resolved(_) | ExpressionKind::Error => (), + ExpressionKind::Quote(_) + | ExpressionKind::Resolved(_) + | ExpressionKind::Interned(_) + | ExpressionKind::Error => (), } // "foo." (no identifier afterwards) is parsed as the expression on the left hand-side of the dot. @@ -739,6 +746,7 @@ impl<'a> NodeFinder<'a> { | UnresolvedTypeData::Bool | UnresolvedTypeData::Unit | UnresolvedTypeData::Resolved(_) + | UnresolvedTypeData::Interned(_) | UnresolvedTypeData::Error => (), } } diff --git a/tooling/lsp/src/requests/completion/builtins.rs b/tooling/lsp/src/requests/completion/builtins.rs index b9c4ce2358a..54340075b15 100644 --- a/tooling/lsp/src/requests/completion/builtins.rs +++ b/tooling/lsp/src/requests/completion/builtins.rs @@ -3,7 +3,10 @@ use noirc_frontend::token::Keyword; use strum::IntoEnumIterator; use super::{ - completion_items::{simple_completion_item, snippet_completion_item}, + completion_items::{ + completion_item_with_trigger_parameter_hints_command, simple_completion_item, + snippet_completion_item, + }, kinds::FunctionCompletionKind, name_matches, NodeFinder, }; @@ -31,12 +34,16 @@ impl<'a> NodeFinder<'a> { } } - self.completion_items.push(snippet_completion_item( - label, - CompletionItemKind::FUNCTION, - insert_text, - description, - )); + self.completion_items.push( + completion_item_with_trigger_parameter_hints_command( + snippet_completion_item( + label, + CompletionItemKind::FUNCTION, + insert_text, + description, + ), + ), + ); } } } @@ -90,10 +97,12 @@ pub(super) fn keyword_builtin_type(keyword: &Keyword) -> Option<&'static str> { Keyword::Expr => Some("Expr"), Keyword::Field => Some("Field"), Keyword::FunctionDefinition => Some("FunctionDefinition"), + Keyword::Quoted => Some("Quoted"), Keyword::StructDefinition => Some("StructDefinition"), Keyword::TraitConstraint => Some("TraitConstraint"), Keyword::TraitDefinition => Some("TraitDefinition"), Keyword::TraitImpl => Some("TraitImpl"), + Keyword::TypedExpr => Some("TypedExpr"), Keyword::TypeType => Some("Type"), Keyword::UnresolvedType => Some("UnresolvedType"), @@ -122,7 +131,6 @@ pub(super) fn keyword_builtin_type(keyword: &Keyword) -> Option<&'static str> { | Keyword::Module | Keyword::Mut | Keyword::Pub - | Keyword::Quoted | Keyword::Return | Keyword::ReturnData | Keyword::String @@ -200,6 +208,7 @@ pub(super) fn keyword_builtin_function(keyword: &Keyword) -> Option InlayHintCollector<'a> { } StatementKind::Comptime(statement) => self.collect_in_statement(statement), StatementKind::Semi(expression) => self.collect_in_expression(expression), - StatementKind::Break => (), - StatementKind::Continue => (), - StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -303,6 +304,7 @@ impl<'a> InlayHintCollector<'a> { | ExpressionKind::Variable(..) | ExpressionKind::Quote(..) | ExpressionKind::Resolved(..) + | ExpressionKind::Interned(..) | ExpressionKind::Error => (), } } @@ -692,6 +694,7 @@ fn get_expression_name(expression: &Expression) -> Option { | ExpressionKind::Unquote(..) | ExpressionKind::Comptime(..) | ExpressionKind::Resolved(..) + | ExpressionKind::Interned(..) | ExpressionKind::Literal(..) | ExpressionKind::Unsafe(..) | ExpressionKind::Error => None, diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index 8aa74fe9900..25676b57381 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -7,7 +7,10 @@ use lsp_types::{ }; use noirc_errors::{Location, Span}; use noirc_frontend::{ - ast::{CallExpression, Expression, FunctionReturnType, MethodCallExpression}, + ast::{ + CallExpression, ConstrainKind, ConstrainStatement, Expression, ExpressionKind, + FunctionReturnType, MethodCallExpression, + }, hir_def::{function::FuncMeta, stmt::HirPattern}, macros_api::NodeInterner, node_interner::ReferenceId, @@ -104,6 +107,55 @@ impl<'a> SignatureFinder<'a> { ); } + pub(super) fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { + self.find_in_expression(&constrain_statement.0); + + if let Some(exp) = &constrain_statement.1 { + self.find_in_expression(exp); + } + + if self.signature_help.is_some() { + return; + } + + let arguments_span = if let Some(expr) = &constrain_statement.1 { + Span::from(constrain_statement.0.span.start()..expr.span.end()) + } else { + constrain_statement.0.span + }; + + if !self.includes_span(arguments_span) { + return; + } + + match constrain_statement.2 { + ConstrainKind::Assert => { + let mut arguments = vec![constrain_statement.0.clone()]; + if let Some(expr) = &constrain_statement.1 { + arguments.push(expr.clone()); + } + + let active_parameter = self.compute_active_parameter(&arguments); + let signature_information = self.assert_signature_information(active_parameter); + self.set_signature_help(signature_information); + } + ConstrainKind::AssertEq => { + if let ExpressionKind::Infix(infix) = &constrain_statement.0.kind { + let mut arguments = vec![infix.lhs.clone(), infix.rhs.clone()]; + if let Some(expr) = &constrain_statement.1 { + arguments.push(expr.clone()); + } + + let active_parameter = self.compute_active_parameter(&arguments); + let signature_information = + self.assert_eq_signature_information(active_parameter); + self.set_signature_help(signature_information); + } + } + ConstrainKind::Constrain => (), + } + } + fn try_compute_signature_help( &mut self, arguments: &[Expression], @@ -119,18 +171,7 @@ impl<'a> SignatureFinder<'a> { return; } - let mut active_parameter = None; - for (index, arg) in arguments.iter().enumerate() { - if self.includes_span(arg.span) || arg.span.start() as usize >= self.byte_index { - active_parameter = Some(index as u32); - break; - } - } - - if active_parameter.is_none() { - active_parameter = Some(arguments.len() as u32); - } - + let active_parameter = self.compute_active_parameter(arguments); let location = Location::new(name_span, self.file); // Check if the call references a named function @@ -267,6 +308,60 @@ impl<'a> SignatureFinder<'a> { } } + fn assert_signature_information(&self, active_parameter: Option) -> SignatureInformation { + self.hardcoded_signature_information( + active_parameter, + "assert", + &["predicate: bool", "[failure_message: str]"], + ) + } + + fn assert_eq_signature_information( + &self, + active_parameter: Option, + ) -> SignatureInformation { + self.hardcoded_signature_information( + active_parameter, + "assert_eq", + &["lhs: T", "rhs: T", "[failure_message: str]"], + ) + } + + fn hardcoded_signature_information( + &self, + active_parameter: Option, + name: &str, + arguments: &[&str], + ) -> SignatureInformation { + let mut label = String::new(); + let mut parameters = Vec::new(); + + label.push_str(name); + label.push('('); + for (index, typ) in arguments.iter().enumerate() { + if index > 0 { + label.push_str(", "); + } + + let parameter_start = label.chars().count(); + label.push_str(typ); + let parameter_end = label.chars().count(); + + parameters.push(ParameterInformation { + label: ParameterLabel::LabelOffsets([parameter_start as u32, parameter_end as u32]), + documentation: None, + }); + } + label.push(')'); + + SignatureInformation { + label, + documentation: None, + parameters: Some(parameters), + active_parameter, + } + } + fn hir_pattern_to_argument(&self, pattern: &HirPattern, text: &mut String) { match pattern { HirPattern::Identifier(hir_ident) => { @@ -286,6 +381,22 @@ impl<'a> SignatureFinder<'a> { self.signature_help = Some(signature_help); } + fn compute_active_parameter(&self, arguments: &[Expression]) -> Option { + let mut active_parameter = None; + for (index, arg) in arguments.iter().enumerate() { + if self.includes_span(arg.span) || arg.span.start() as usize >= self.byte_index { + active_parameter = Some(index as u32); + break; + } + } + + if active_parameter.is_none() { + active_parameter = Some(arguments.len() as u32); + } + + active_parameter + } + fn includes_span(&self, span: Span) -> bool { span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize } diff --git a/tooling/lsp/src/requests/signature_help/tests.rs b/tooling/lsp/src/requests/signature_help/tests.rs index c48ee159084..4b3f3c38156 100644 --- a/tooling/lsp/src/requests/signature_help/tests.rs +++ b/tooling/lsp/src/requests/signature_help/tests.rs @@ -193,4 +193,51 @@ mod signature_help_tests { assert_eq!(signature.active_parameter, Some(0)); } + + #[test] + async fn test_signature_help_for_assert() { + let src = r#" + fn bar() { + assert(>|<1, "hello"); + } + "#; + + let signature_help = get_signature_help(src).await; + assert_eq!(signature_help.signatures.len(), 1); + + let signature = &signature_help.signatures[0]; + assert_eq!(signature.label, "assert(predicate: bool, [failure_message: str])"); + + let params = signature.parameters.as_ref().unwrap(); + assert_eq!(params.len(), 2); + + check_label(&signature.label, ¶ms[0].label, "predicate: bool"); + check_label(&signature.label, ¶ms[1].label, "[failure_message: str]"); + + assert_eq!(signature.active_parameter, Some(0)); + } + + #[test] + async fn test_signature_help_for_assert_eq() { + let src = r#" + fn bar() { + assert_eq(>|])"); + + let params = signature.parameters.as_ref().unwrap(); + assert_eq!(params.len(), 3); + + check_label(&signature.label, ¶ms[0].label, "lhs: T"); + check_label(&signature.label, ¶ms[1].label, "rhs: T"); + check_label(&signature.label, ¶ms[2].label, "[failure_message: str]"); + + assert_eq!(signature.active_parameter, Some(0)); + } } diff --git a/tooling/lsp/src/requests/signature_help/traversal.rs b/tooling/lsp/src/requests/signature_help/traversal.rs index 22f92a86124..e9b050fc965 100644 --- a/tooling/lsp/src/requests/signature_help/traversal.rs +++ b/tooling/lsp/src/requests/signature_help/traversal.rs @@ -4,11 +4,11 @@ use super::SignatureFinder; use noirc_frontend::{ ast::{ - ArrayLiteral, AssignStatement, BlockExpression, CastExpression, ConstrainStatement, - ConstructorExpression, Expression, ExpressionKind, ForLoopStatement, ForRange, - IfExpression, IndexExpression, InfixExpression, LValue, Lambda, LetStatement, Literal, - MemberAccessExpression, NoirFunction, NoirTrait, NoirTraitImpl, Statement, StatementKind, - TraitImplItem, TraitItem, TypeImpl, + ArrayLiteral, AssignStatement, BlockExpression, CastExpression, ConstructorExpression, + Expression, ExpressionKind, ForLoopStatement, ForRange, IfExpression, IndexExpression, + InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, + NoirFunction, NoirTrait, NoirTraitImpl, Statement, StatementKind, TraitImplItem, TraitItem, + TypeImpl, }, parser::{Item, ItemKind}, ParsedModule, @@ -125,7 +125,10 @@ impl<'a> SignatureFinder<'a> { StatementKind::Semi(expression) => { self.find_in_expression(expression); } - StatementKind::Break | StatementKind::Continue | StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -133,14 +136,6 @@ impl<'a> SignatureFinder<'a> { self.find_in_expression(&let_statement.expression); } - pub(super) fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { - self.find_in_expression(&constrain_statement.0); - - if let Some(exp) = &constrain_statement.1 { - self.find_in_expression(exp); - } - } - pub(super) fn find_in_assign_statement(&mut self, assign_statement: &AssignStatement) { self.find_in_lvalue(&assign_statement.lvalue); self.find_in_expression(&assign_statement.expression); @@ -160,6 +155,7 @@ impl<'a> SignatureFinder<'a> { self.find_in_expression(index); } LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), + LValue::Interned(..) => (), } } @@ -232,6 +228,7 @@ impl<'a> SignatureFinder<'a> { | ExpressionKind::AsTraitPath(_) | ExpressionKind::Quote(_) | ExpressionKind::Resolved(_) + | ExpressionKind::Interned(_) | ExpressionKind::Error => (), } } diff --git a/tooling/nargo/Cargo.toml b/tooling/nargo/Cargo.toml index 046eca88099..c5d4bbc9788 100644 --- a/tooling/nargo/Cargo.toml +++ b/tooling/nargo/Cargo.toml @@ -23,7 +23,7 @@ noirc_printable_type.workspace = true iter-extended.workspace = true thiserror.workspace = true tracing.workspace = true -rayon = "1.8.0" +rayon.workspace = true jsonrpc.workspace = true rand.workspace = true serde.workspace = true diff --git a/tooling/nargo/src/package.rs b/tooling/nargo/src/package.rs index f55ca5550a3..cde616a9e32 100644 --- a/tooling/nargo/src/package.rs +++ b/tooling/nargo/src/package.rs @@ -73,4 +73,11 @@ impl Package { pub fn is_library(&self) -> bool { self.package_type == PackageType::Library } + + pub fn error_on_unused_imports(&self) -> bool { + match self.package_type { + PackageType::Library => false, + PackageType::Binary | PackageType::Contract => true, + } + } } diff --git a/tooling/nargo_cli/Cargo.toml b/tooling/nargo_cli/Cargo.toml index f3d9f92caaa..284be56d247 100644 --- a/tooling/nargo_cli/Cargo.toml +++ b/tooling/nargo_cli/Cargo.toml @@ -31,7 +31,7 @@ nargo_fmt.workspace = true nargo_toml.workspace = true noir_lsp.workspace = true noir_debugger.workspace = true -noirc_driver.workspace = true +noirc_driver = { workspace = true, features = ["bn254"] } noirc_frontend = { workspace = true, features = ["bn254"] } noirc_abi.workspace = true noirc_errors.workspace = true @@ -42,7 +42,7 @@ toml.workspace = true serde.workspace = true serde_json.workspace = true prettytable-rs = "0.10" -rayon = "1.8.0" +rayon.workspace = true thiserror.workspace = true tower.workspace = true async-lsp = { workspace = true, features = ["client-monitor", "stdio", "tracing", "tokio"] } diff --git a/tooling/nargo_cli/src/cli/check_cmd.rs b/tooling/nargo_cli/src/cli/check_cmd.rs index 5239070b4d2..1130a82fdfc 100644 --- a/tooling/nargo_cli/src/cli/check_cmd.rs +++ b/tooling/nargo_cli/src/cli/check_cmd.rs @@ -10,7 +10,7 @@ use nargo::{ use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; use noirc_abi::{AbiParameter, AbiType, MAIN_RETURN_NAME}; use noirc_driver::{ - check_crate, compute_function_abi, file_manager_with_stdlib, CompileOptions, + check_crate, compute_function_abi, file_manager_with_stdlib, CheckOptions, CompileOptions, NOIR_ARTIFACT_VERSION_STRING, }; use noirc_frontend::{ @@ -81,7 +81,9 @@ fn check_package( allow_overwrite: bool, ) -> Result { let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate_and_report_errors(&mut context, crate_id, compile_options)?; + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(compile_options, error_on_unused_imports); + check_crate_and_report_errors(&mut context, crate_id, &check_options)?; if package.is_library() || package.is_contract() { // Libraries do not have ABIs while contracts have many, so we cannot generate a `Prover.toml` file. @@ -150,9 +152,10 @@ fn create_input_toml_template( pub(crate) fn check_crate_and_report_errors( context: &mut Context, crate_id: CrateId, - options: &CompileOptions, + check_options: &CheckOptions, ) -> Result<(), CompileError> { - let result = check_crate(context, crate_id, options); + let options = &check_options.compile_options; + let result = check_crate(context, crate_id, check_options); report_errors(result, &context.file_manager, options.deny_warnings, options.silence_warnings) } diff --git a/tooling/nargo_cli/src/cli/export_cmd.rs b/tooling/nargo_cli/src/cli/export_cmd.rs index 19add7f30dc..5721dd33e27 100644 --- a/tooling/nargo_cli/src/cli/export_cmd.rs +++ b/tooling/nargo_cli/src/cli/export_cmd.rs @@ -12,7 +12,7 @@ use nargo::workspace::Workspace; use nargo::{insert_all_files_for_workspace_into_file_manager, parse_all}; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; use noirc_driver::{ - compile_no_check, file_manager_with_stdlib, CompileOptions, CompiledProgram, + compile_no_check, file_manager_with_stdlib, CheckOptions, CompileOptions, CompiledProgram, NOIR_ARTIFACT_VERSION_STRING, }; @@ -83,7 +83,9 @@ fn compile_exported_functions( compile_options: &CompileOptions, ) -> Result<(), CliError> { let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate_and_report_errors(&mut context, crate_id, compile_options)?; + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(compile_options, error_on_unused_imports); + check_crate_and_report_errors(&mut context, crate_id, &check_options)?; let exported_functions = context.get_all_exported_functions_in_crate(&crate_id); diff --git a/tooling/nargo_cli/src/cli/fs/inputs.rs b/tooling/nargo_cli/src/cli/fs/inputs.rs index dee9a00507c..4a7a81431bb 100644 --- a/tooling/nargo_cli/src/cli/fs/inputs.rs +++ b/tooling/nargo_cli/src/cli/fs/inputs.rs @@ -25,7 +25,13 @@ pub(crate) fn read_inputs_from_file>( let file_path = path.as_ref().join(file_name).with_extension(format.ext()); if !file_path.exists() { - return Err(FilesystemError::MissingTomlFile(file_name.to_owned(), file_path)); + if abi.parameters.is_empty() { + // Reading a return value from the `Prover.toml` is optional, + // so if the ABI has no parameters we can skip reading the file if it doesn't exist. + return Ok((BTreeMap::new(), None)); + } else { + return Err(FilesystemError::MissingTomlFile(file_name.to_owned(), file_path)); + } } let input_string = std::fs::read_to_string(file_path).unwrap(); diff --git a/tooling/nargo_cli/src/cli/test_cmd.rs b/tooling/nargo_cli/src/cli/test_cmd.rs index 0d7c8fc8bf7..2b0c0fd58db 100644 --- a/tooling/nargo_cli/src/cli/test_cmd.rs +++ b/tooling/nargo_cli/src/cli/test_cmd.rs @@ -10,7 +10,8 @@ use nargo::{ }; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; use noirc_driver::{ - check_crate, file_manager_with_stdlib, CompileOptions, NOIR_ARTIFACT_VERSION_STRING, + check_crate, file_manager_with_stdlib, CheckOptions, CompileOptions, + NOIR_ARTIFACT_VERSION_STRING, }; use noirc_frontend::{ graph::CrateName, @@ -180,7 +181,9 @@ fn run_test + Default>( // We then need to construct a separate copy for each test. let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate(&mut context, crate_id, compile_options) + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(compile_options, error_on_unused_imports); + check_crate(&mut context, crate_id, &check_options) .expect("Any errors should have occurred when collecting test functions"); let test_functions = context @@ -206,10 +209,12 @@ fn get_tests_in_package( parsed_files: &ParsedFiles, package: &Package, fn_name: FunctionNameMatch, - compile_options: &CompileOptions, + options: &CompileOptions, ) -> Result, CliError> { let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate_and_report_errors(&mut context, crate_id, compile_options)?; + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(options, error_on_unused_imports); + check_crate_and_report_errors(&mut context, crate_id, &check_options)?; Ok(context .get_all_test_functions_in_crate_matching(&crate_id, fn_name) diff --git a/tooling/nargo_fmt/src/rewrite/expr.rs b/tooling/nargo_fmt/src/rewrite/expr.rs index 4fee7d3e197..caa60b17cc2 100644 --- a/tooling/nargo_fmt/src/rewrite/expr.rs +++ b/tooling/nargo_fmt/src/rewrite/expr.rs @@ -175,6 +175,9 @@ pub(crate) fn rewrite( ExpressionKind::Resolved(_) => { unreachable!("ExpressionKind::Resolved should only emitted by the comptime interpreter") } + ExpressionKind::Interned(_) => { + unreachable!("ExpressionKind::Interned should only emitted by the comptime interpreter") + } ExpressionKind::Unquote(expr) => { if matches!(&expr.kind, ExpressionKind::Variable(..)) { format!("${expr}") diff --git a/tooling/nargo_fmt/src/rewrite/typ.rs b/tooling/nargo_fmt/src/rewrite/typ.rs index 8d1e27078a8..6121f8debf6 100644 --- a/tooling/nargo_fmt/src/rewrite/typ.rs +++ b/tooling/nargo_fmt/src/rewrite/typ.rs @@ -73,6 +73,6 @@ pub(crate) fn rewrite(visitor: &FmtVisitor, _shape: Shape, typ: UnresolvedType) | UnresolvedTypeData::FormatString(_, _) | UnresolvedTypeData::Quoted(_) | UnresolvedTypeData::TraitAsType(_, _) => visitor.slice(typ.span).into(), - UnresolvedTypeData::Error => unreachable!(), + UnresolvedTypeData::Interned(_) | UnresolvedTypeData::Error => unreachable!(), } } diff --git a/tooling/nargo_fmt/src/visitor/stmt.rs b/tooling/nargo_fmt/src/visitor/stmt.rs index 8e05fe3f5c5..b5ac14a33b3 100644 --- a/tooling/nargo_fmt/src/visitor/stmt.rs +++ b/tooling/nargo_fmt/src/visitor/stmt.rs @@ -104,6 +104,9 @@ impl super::FmtVisitor<'_> { StatementKind::Break => self.push_rewrite("break;".into(), span), StatementKind::Continue => self.push_rewrite("continue;".into(), span), StatementKind::Comptime(statement) => self.visit_stmt(statement.kind, span, is_last), + StatementKind::Interned(_) => unreachable!( + "StatementKind::Resolved should only emitted by the comptime interpreter" + ), } } }