From 614587731d634014fc2a9278536394e288b6d099 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Tue, 27 Aug 2024 08:35:16 -0400 Subject: [PATCH 01/21] chore(perf): Update to stdlib keccak for reduced Brillig code size (#5827) # Description ## Problem\* Resolves ## Summary\* We can reduce the size of our keccak stdlib method in Brillig. There are operations that are repeated across multiple places, different operations we can perform depending on whether we are in an unconstrained or constrained runtime, and we had an extra unnecessary loop for building our `sliced_buffer` variable. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- noir_stdlib/src/hash/keccak.nr | 81 ++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/noir_stdlib/src/hash/keccak.nr b/noir_stdlib/src/hash/keccak.nr index bb8a9cc2ce2..0c31d238f66 100644 --- a/noir_stdlib/src/hash/keccak.nr +++ b/noir_stdlib/src/hash/keccak.nr @@ -1,19 +1,27 @@ +use crate::collections::vec::Vec; +use crate::runtime::is_unconstrained; + global LIMBS_PER_BLOCK = 17; //BLOCK_SIZE / 8; global NUM_KECCAK_LANES = 25; global BLOCK_SIZE = 136; //(1600 - BITS * 2) / WORD_SIZE; global WORD_SIZE = 8; -use crate::collections::vec::Vec; - #[foreign(keccakf1600)] fn keccakf1600(input: [u64; 25]) -> [u64; 25] {} #[no_predicates] -pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u8; 32] { +pub(crate) fn keccak256(input: [u8; N], message_size: u32) -> [u8; 32] { assert(N >= message_size); - for i in 0..N { - if i >= message_size { - input[i] = 0; + let mut block_bytes = [0; BLOCK_SIZE]; + if is_unconstrained() { + for i in 0..message_size { + block_bytes[i] = input[i]; + } + } else { + for i in 0..N { + if i < message_size { + block_bytes[i] = input[i]; + } } } @@ -24,11 +32,6 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u let real_max_blocks = (message_size + BLOCK_SIZE) / BLOCK_SIZE; let real_blocks_bytes = real_max_blocks * BLOCK_SIZE; - let mut block_bytes = [0; BLOCK_SIZE]; - for i in 0..N { - block_bytes[i] = input[i]; - } - block_bytes[message_size] = 1; block_bytes[real_blocks_bytes - 1] = 0x80; @@ -36,28 +39,28 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u // means we need to swap our byte ordering let num_limbs = max_blocks * LIMBS_PER_BLOCK; //max_blocks_length / WORD_SIZE; for i in 0..num_limbs { - let mut temp = [0; 8]; - for j in 0..8 { - temp[j] = block_bytes[8*i+j]; + let mut temp = [0; WORD_SIZE]; + let word_size_times_i = WORD_SIZE * i; + for j in 0..WORD_SIZE { + temp[j] = block_bytes[word_size_times_i+j]; } - for j in 0..8 { - block_bytes[8 * i + j] = temp[7 - j]; + for j in 0..WORD_SIZE { + block_bytes[word_size_times_i + j] = temp[7 - j]; } } - let byte_size = max_blocks_length; + let mut sliced_buffer = Vec::new(); - for _i in 0..num_limbs { - sliced_buffer.push(0); - } // populate a vector of 64-bit limbs from our byte array for i in 0..num_limbs { + let word_size_times_i = i * WORD_SIZE; + let ws_times_i_plus_7 = word_size_times_i + 7; let mut sliced = 0; - if (i * WORD_SIZE + WORD_SIZE > byte_size) { - let slice_size = byte_size - (i * WORD_SIZE); + if (word_size_times_i + WORD_SIZE > max_blocks_length) { + let slice_size = max_blocks_length - word_size_times_i; let byte_shift = (WORD_SIZE - slice_size) * 8; let mut v = 1; for k in 0..slice_size { - sliced += v * (block_bytes[i * WORD_SIZE+7-k] as Field); + sliced += v * (block_bytes[ws_times_i_plus_7-k] as Field); v *= 256; } let w = 1 << (byte_shift as u8); @@ -65,22 +68,20 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u } else { let mut v = 1; for k in 0..WORD_SIZE { - sliced += v * (block_bytes[i * WORD_SIZE+7-k] as Field); + sliced += v * (block_bytes[ws_times_i_plus_7-k] as Field); v *= 256; } } - sliced_buffer.set(i, sliced as u64); + + sliced_buffer.push(sliced as u64); } //2. sponge_absorb - let num_blocks = max_blocks; let mut state : [u64;NUM_KECCAK_LANES]= [0; NUM_KECCAK_LANES]; - let mut under_block = true; - for i in 0..num_blocks { - if i == real_max_blocks { - under_block = false; - } - if under_block { + // When in an unconstrained runtime we can take advantage of runtime loop bounds, + // thus allowing us to simplify the loop body. + if is_unconstrained() { + for i in 0..real_max_blocks { if (i == 0) { for j in 0..LIMBS_PER_BLOCK { state[j] = sliced_buffer.get(j); @@ -92,6 +93,22 @@ pub(crate) fn keccak256(mut input: [u8; N], message_size: u32) -> [u } state = keccakf1600(state); } + } else { + // `real_max_blocks` is guaranteed to at least be `1` + // We peel out the first block as to avoid a conditional inside of the loop. + // Otherwise, a dynamic predicate can cause a blowup in a constrained runtime. + for j in 0..LIMBS_PER_BLOCK { + state[j] = sliced_buffer.get(j); + } + state = keccakf1600(state); + for i in 1..max_blocks { + if i < real_max_blocks { + for j in 0..LIMBS_PER_BLOCK { + state[j] = state[j] ^ sliced_buffer.get(i * LIMBS_PER_BLOCK + j); + } + state = keccakf1600(state); + } + } } //3. sponge_squeeze From c52dc1c77aedf5a876a858cc5a942c29e868e9e6 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Tue, 27 Aug 2024 09:09:49 -0400 Subject: [PATCH 02/21] fix(sha256): Perform compression per block and utilize ROM instead of RAM when setting up the message block (#5760) # Description ## Problem\* Resolves #5761 Resolution to performance blow-up found with sha256_var. ## Summary\* ### Issue The crux of the blow-up was the result of calling `sha256_compression` inside of the same loop where we build the message block. In the current `sha256_var` algorithm we are looping over the entire message and conditionally checking a msg byte pointer (the pointer into the msg block) to determine whether we have filled up a msg block and should run the sha compression. However, in a circuit this leads to us calling the compression opcode `N` times where `N` is the size of the message. We also were utilize RAM to build our message block when we do not have to do so. We can instead construct our block outside of the circuit and verify that the block has been constructed as we expect with assertion that just require ROM. ### Improvements This PR produces a ~16x improvement in ACIR opcodes a >13x improvement in backend constraints for the following circuit: ```rust fn main(foo: [u8; 95], toggle: bool) { let size: Field = 93 + toggle as Field * 2; let hash = std::sha256::sha256_var(foo, size as u64); println(f"{hash}"); } ``` #### master nargo info: ``` +---------+----------------------------+----------------------+--------------+-----------------+ | Package | Function | Expression Width | ACIR Opcodes | Brillig Opcodes | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | main | Bounded { width: 4 } | 125852 | 243 | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | print_unconstrained | N/A | N/A | 230 | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | directive_integer_quotient | N/A | N/A | 6 | +---------+----------------------------+----------------------+--------------+-----------------+ | sha256 | directive_invert | N/A | N/A | 7 | +---------+----------------------------+----------------------+--------------+-----------------+ ``` bb gates: ``` {"functions": [ { "acir_opcodes": 125852, "circuit_size": 597646, ``` #### This PR Output of nargo info: ``` +----------------------------+----------------------------+----------------------+--------------+-----------------+ | Package | Function | Expression Width | ACIR Opcodes | Brillig Opcodes | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | main | Bounded { width: 4 } | 7768 | 1041 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | build_msg_block_iter | N/A | N/A | 299 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | pad_msg_block | N/A | N/A | 201 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | attach_len_to_msg_block | N/A | N/A | 298 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | print_unconstrained | N/A | N/A | 230 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | directive_integer_quotient | N/A | N/A | 6 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ | sha256_var_size_regression | directive_invert | N/A | N/A | 7 | +----------------------------+----------------------------+----------------------+--------------+-----------------+ ``` bb gates output: ``` {"functions": [ { "acir_opcodes": 7768, "circuit_size": 44663, ``` ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [ ] I have tested the changes locally. - [ ] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- noir_stdlib/src/hash/sha256.nr | 237 ++++++++++++++---- .../sha256_var_size_regression/Nargo.toml | 7 + .../sha256_var_size_regression/Prover.toml | 3 + .../sha256_var_size_regression/src/main.nr | 17 ++ .../Nargo.toml | 7 + .../Prover.toml | 2 + .../src/main.nr | 9 + 7 files changed, 235 insertions(+), 47 deletions(-) create mode 100644 test_programs/execution_success/sha256_var_size_regression/Nargo.toml create mode 100644 test_programs/execution_success/sha256_var_size_regression/Prover.toml create mode 100644 test_programs/execution_success/sha256_var_size_regression/src/main.nr create mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml create mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml create mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr diff --git a/noir_stdlib/src/hash/sha256.nr b/noir_stdlib/src/hash/sha256.nr index 5035be4b73e..55cdd984003 100644 --- a/noir_stdlib/src/hash/sha256.nr +++ b/noir_stdlib/src/hash/sha256.nr @@ -17,82 +17,224 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { sha256_var(msg, N as u64) } +// Convert 64-byte array to array of 16 u32s +fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { + let mut msg32: [u32; 16] = [0; 16]; + + for i in 0..16 { + let mut msg_field: Field = 0; + for j in 0..4 { + msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; + } + msg32[15 - i] = msg_field as u32; + } + + msg32 +} + +unconstrained fn build_msg_block_iter( + msg: [u8; N], + message_size: u64, + mut msg_block: [u8; 64], + msg_start: u32 +) -> ([u8; 64], u64) { + let mut msg_byte_ptr: u64 = 0; // Message byte pointer + for k in msg_start..N { + if k as u64 < message_size { + msg_block[msg_byte_ptr] = msg[k]; + msg_byte_ptr = msg_byte_ptr + 1; + + if msg_byte_ptr == 64 { + msg_byte_ptr = 0; + } + } + } + (msg_block, msg_byte_ptr) +} + +// Verify the block we are compressing was appropriately constructed +fn verify_msg_block( + msg: [u8; N], + message_size: u64, + msg_block: [u8; 64], + msg_start: u32 +) -> u64 { + let mut msg_byte_ptr: u64 = 0; // Message byte pointer + for k in msg_start..N { + if k as u64 < message_size { + assert_eq(msg_block[msg_byte_ptr], msg[k]); + msg_byte_ptr = msg_byte_ptr + 1; + if msg_byte_ptr == 64 { + // Enough to hash block + msg_byte_ptr = 0; + } + } else { + // Need to assert over the msg block in the else case as well + if N < 64 { + assert_eq(msg_block[msg_byte_ptr], 0); + } else { + assert_eq(msg_block[msg_byte_ptr], msg[k]); + } + } + } + msg_byte_ptr +} + +global BLOCK_SIZE = 64; + // Variable size SHA-256 hash pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { - let mut msg_block: [u8; 64] = [0; 64]; + let num_blocks = N / BLOCK_SIZE; + let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE]; let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value - let mut i: u64 = 0; // Message byte pointer - for k in 0..N { - if k as u64 < message_size { - // Populate msg_block - msg_block[i] = msg[k]; - i = i + 1; - if i == 64 { - // Enough to hash block - h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); + let mut msg_byte_ptr = 0; // Pointer into msg_block - i = 0; - } + if num_blocks == 0 { + unsafe { + let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, 0); + msg_block = new_msg_block; + msg_byte_ptr = new_msg_byte_ptr; + } + + if !crate::runtime::is_unconstrained() { + msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, 0); } } + + for i in 0..num_blocks { + unsafe { + let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, BLOCK_SIZE * i); + msg_block = new_msg_block; + msg_byte_ptr = new_msg_byte_ptr; + } + if !crate::runtime::is_unconstrained() { + // Verify the block we are compressing was appropriately constructed + msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i); + } + + // Hash the block + h = sha256_compression(msg_u8_to_u32(msg_block), h); + } + + let last_block = msg_block; // Pad the rest such that we have a [u32; 2] block at the end representing the length - // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i] = 1 << 7; - i = i + 1; + // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). + msg_block[msg_byte_ptr] = 1 << 7; + msg_byte_ptr = msg_byte_ptr + 1; + unsafe { + let (new_msg_block, new_msg_byte_ptr)= pad_msg_block(msg_block, msg_byte_ptr); + msg_block = new_msg_block; + if crate::runtime::is_unconstrained() { + msg_byte_ptr = new_msg_byte_ptr; + } + } + + if !crate::runtime::is_unconstrained() { + for i in 0..64 { + if i as u64 < msg_byte_ptr - 1 { + assert_eq(msg_block[i], last_block[i]); + } + } + assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7); + + // If i >= 57, there aren't enough bits in the current message block to accomplish this, so + // the 1 and 0s fill up the current block, which we then compress accordingly. + // Not enough bits (64) to store length. Fill up with zeros. + for _i in 57..64 { + if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 { + assert_eq(msg_block[msg_byte_ptr], 0); + msg_byte_ptr += 1; + } + } + } + + if msg_byte_ptr >= 57 { + h = sha256_compression(msg_u8_to_u32(msg_block), h); + + msg_byte_ptr = 0; + } + + unsafe { + msg_block = attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size); + } + + if !crate::runtime::is_unconstrained() { + if msg_byte_ptr != 0 { + for i in 0..64 { + if i as u64 < msg_byte_ptr - 1 { + assert_eq(msg_block[i], last_block[i]); + } + } + assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7); + } + + let len = 8 * message_size; + let len_bytes = (len as Field).to_le_bytes(8); + // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56). + for _ in 0..64 { + if msg_byte_ptr < 56 { + assert_eq(msg_block[msg_byte_ptr], 0); + msg_byte_ptr = msg_byte_ptr + 1; + } + } + + let mut block_idx = 0; + for i in 56..64 { + assert_eq(msg_block[63 - block_idx], len_bytes[i - 56]); + block_idx = block_idx + 1; + } + } + + hash_final_block(msg_block, h) +} + +unconstrained fn pad_msg_block( + mut msg_block: [u8; 64], + mut msg_byte_ptr: u64 +) -> ([u8; 64], u64) { // If i >= 57, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. - if i >= 57 { + if msg_byte_ptr >= 57 { // Not enough bits (64) to store length. Fill up with zeros. - if i < 64 { - for _i in 57..64 { - if i <= 63 { - msg_block[i] = 0; - i += 1; + if msg_byte_ptr < 64 { + for _ in 57..64 { + if msg_byte_ptr <= 63 { + msg_block[msg_byte_ptr] = 0; + msg_byte_ptr += 1; } } } - h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); - - i = 0; } + (msg_block, msg_byte_ptr) +} +unconstrained fn attach_len_to_msg_block( + mut msg_block: [u8; 64], + mut msg_byte_ptr: u64, + message_size: u64 +) -> [u8; 64] { let len = 8 * message_size; let len_bytes = (len as Field).to_le_bytes(8); for _i in 0..64 { - // In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56). - if i < 56 { - msg_block[i] = 0; - i = i + 1; - } else if i < 64 { + // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56). + if msg_byte_ptr < 56 { + msg_block[msg_byte_ptr] = 0; + msg_byte_ptr = msg_byte_ptr + 1; + } else if msg_byte_ptr < 64 { for j in 0..8 { msg_block[63 - j] = len_bytes[j]; } - i += 8; - } - } - hash_final_block(msg_block, h) -} - -// Convert 64-byte array to array of 16 u32s -fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { - let mut msg32: [u32; 16] = [0; 16]; - - for i in 0..16 { - let mut msg_field: Field = 0; - for j in 0..4 { - msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; + msg_byte_ptr += 8; } - msg32[15 - i] = msg_field as u32; } - - msg32 + msg_block } fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] { let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes // Hash final padded block - state = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), state); + state = sha256_compression(msg_u8_to_u32(msg_block), state); // Return final hash as byte array for j in 0..8 { @@ -104,3 +246,4 @@ fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] { out_h } + diff --git a/test_programs/execution_success/sha256_var_size_regression/Nargo.toml b/test_programs/execution_success/sha256_var_size_regression/Nargo.toml new file mode 100644 index 00000000000..3e141ee5d5f --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sha256_var_size_regression" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_size_regression/Prover.toml b/test_programs/execution_success/sha256_var_size_regression/Prover.toml new file mode 100644 index 00000000000..df632a42858 --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/Prover.toml @@ -0,0 +1,3 @@ +enable = [true, false] +foo = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] +toggle = false diff --git a/test_programs/execution_success/sha256_var_size_regression/src/main.nr b/test_programs/execution_success/sha256_var_size_regression/src/main.nr new file mode 100644 index 00000000000..de1c2b23c5f --- /dev/null +++ b/test_programs/execution_success/sha256_var_size_regression/src/main.nr @@ -0,0 +1,17 @@ +global NUM_HASHES = 2; + +fn main(foo: [u8; 95], toggle: bool, enable: [bool; NUM_HASHES]) { + let mut result = [[0; 32]; NUM_HASHES]; + let mut const_result = [[0; 32]; NUM_HASHES]; + let size: Field = 93 + toggle as Field * 2; + for i in 0..NUM_HASHES { + if enable[i] { + result[i] = std::sha256::sha256_var(foo, size as u64); + const_result[i] = std::sha256::sha256_var(foo, 93); + } + } + + for i in 0..NUM_HASHES { + assert_eq(result[i], const_result[i]); + } +} diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml b/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml new file mode 100644 index 00000000000..e8f3e6bbe64 --- /dev/null +++ b/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sha256_var_witness_const_regression" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml b/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml new file mode 100644 index 00000000000..7b91051c1a0 --- /dev/null +++ b/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml @@ -0,0 +1,2 @@ +input = [0, 0] +toggle = false \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr b/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr new file mode 100644 index 00000000000..97c4435d41d --- /dev/null +++ b/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr @@ -0,0 +1,9 @@ +fn main(input: [u8; 2], toggle: bool) { + let size: Field = 1 + toggle as Field; + assert(!toggle); + + let variable_sha = std::sha256::sha256_var(input, size as u64); + let constant_sha = std::sha256::sha256_var(input, 1); + + assert_eq(variable_sha, constant_sha); +} From c0c9cc9653f6dfb64ae661b94c7438f2ab1ed4bd Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Tue, 27 Aug 2024 15:16:27 +0100 Subject: [PATCH 03/21] chore: redo typo PR by nnsW3 (#5834) Thanks nnsW3 for https://github.com/noir-lang/noir/pull/5833. Our policy is to redo typo changes to dissuade metric farming. This is an automated script. --- docs/docs/noir/concepts/data_types/slices.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/noir/concepts/data_types/slices.mdx b/docs/docs/noir/concepts/data_types/slices.mdx index 95da2030843..a0c87c29259 100644 --- a/docs/docs/noir/concepts/data_types/slices.mdx +++ b/docs/docs/noir/concepts/data_types/slices.mdx @@ -20,7 +20,7 @@ fn main() -> pub u32 { } ``` -To write a slice literal, use a preceeding ampersand as in: `&[0; 2]` or +To write a slice literal, use a preceding ampersand as in: `&[0; 2]` or `&[1, 2, 3]`. It is important to note that slices are not references to arrays. In Noir, From 3c778b73d9458ab708df21c850468d708676cde4 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Tue, 27 Aug 2024 11:54:43 -0400 Subject: [PATCH 04/21] chore(perf): Simplify poseidon2 algorithm (#5811) # Description ## Problem\* Resolves Optimizations found while looking exploring other Brillig opts. ## Summary\* There are a couple optimizations here: 1. I noticed that we loop over the cache and do some resetting inside of `squeeze` of `Poseidon2`. However, we our `Hasher` always creates a fresh Poseidon2 object so it seems unnecessary to reset the cache in this way. In Brillig this leads to an extra loop that is essentially unused and blows up the code size of any programs using poseidon in an unconstrained environment. 2. We were writing into a `result` array and returning it from `perform_duplex`. This result was unused inside of `absorb` and we can directly access `self.state` inside of `squeeze`. I no longer return anything from `perform_duplex`. ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [ ] I have tested the changes locally. - [ ] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- noir_stdlib/src/hash/poseidon2.nr | 55 +++++++------------------------ 1 file changed, 12 insertions(+), 43 deletions(-) diff --git a/noir_stdlib/src/hash/poseidon2.nr b/noir_stdlib/src/hash/poseidon2.nr index 9626da0cf97..cf820f86370 100644 --- a/noir_stdlib/src/hash/poseidon2.nr +++ b/noir_stdlib/src/hash/poseidon2.nr @@ -26,7 +26,7 @@ impl Poseidon2 { result } - fn perform_duplex(&mut self) -> [Field; RATE] { + fn perform_duplex(&mut self) { // zero-pad the cache for i in 0..RATE { if i >= self.cache_size { @@ -38,61 +38,30 @@ impl Poseidon2 { self.state[i] += self.cache[i]; } self.state = crate::hash::poseidon2_permutation(self.state, 4); - // return `RATE` number of field elements from the sponge state. - let mut result = [0; RATE]; - for i in 0..RATE { - result[i] = self.state[i]; - } - result } fn absorb(&mut self, input: Field) { - if (!self.squeeze_mode) & (self.cache_size == RATE) { + assert(!self.squeeze_mode); + if self.cache_size == RATE { // If we're absorbing, and the cache is full, apply the sponge permutation to compress the cache - let _ = self.perform_duplex(); + self.perform_duplex(); self.cache[0] = input; self.cache_size = 1; - } else if (!self.squeeze_mode) & (self.cache_size != RATE) { + } else { // If we're absorbing, and the cache is not full, add the input into the cache self.cache[self.cache_size] = input; self.cache_size += 1; - } else if self.squeeze_mode { - // If we're in squeeze mode, switch to absorb mode and add the input into the cache. - // N.B. I don't think this code path can be reached?! - self.cache[0] = input; - self.cache_size = 1; - self.squeeze_mode = false; } } fn squeeze(&mut self) -> Field { - if self.squeeze_mode & (self.cache_size == 0) { - // If we're in squeze mode and the cache is empty, there is nothing left to squeeze out of the sponge! - // Switch to absorb mode. - self.squeeze_mode = false; - self.cache_size = 0; - } - if !self.squeeze_mode { - // If we're in absorb mode, apply sponge permutation to compress the cache, populate cache with compressed - // state and switch to squeeze mode. Note: this code block will execute if the previous `if` condition was - // matched - let new_output_elements = self.perform_duplex(); - self.squeeze_mode = true; - for i in 0..RATE { - self.cache[i] = new_output_elements[i]; - } - self.cache_size = RATE; - } - // By this point, we should have a non-empty cache. Pop one item off the top of the cache and return it. - let result = self.cache[0]; - for i in 1..RATE { - if i < self.cache_size { - self.cache[i - 1] = self.cache[i]; - } - } - self.cache_size -= 1; - self.cache[self.cache_size] = 0; - result + assert(!self.squeeze_mode); + // If we're in absorb mode, apply sponge permutation to compress the cache. + self.perform_duplex(); + self.squeeze_mode = true; + + // Pop one item off the top of the permutation and return it. + self.state[0] } fn hash_internal(input: [Field; N], in_len: u32, is_variable_length: bool) -> Field { From 82eb1581251faa9716d762a673fa1b871b3e7be2 Mon Sep 17 00:00:00 2001 From: jfecher Date: Tue, 27 Aug 2024 15:48:39 -0500 Subject: [PATCH 05/21] fix!: Check unused generics are bound (#5840) # Description ## Problem\* ## Summary\* Ran into this while working on arithmetic generics. It's possible to have cases like: ```rs fn foo() {} fn main() { foo(); } ``` For which `N` is unused but we don't issue a "type annotations needed" error for it currently because it isn't present in the instantiated function type of `foo`. This PR fixes this and issues the error. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../src/monomorphization/mod.rs | 8 +++++++ compiler/noirc_frontend/src/node_interner.rs | 4 ++++ noir_stdlib/src/collections/map.nr | 18 +++++++-------- noir_stdlib/src/collections/umap.nr | 22 +++++++++---------- noir_stdlib/src/hash/sha256.nr | 9 +++----- noir_stdlib/src/option.nr | 2 +- .../unspecified_generic/Nargo.toml | 7 ++++++ .../unspecified_generic/src/main.nr | 5 +++++ .../method_call_regression/src/main.nr | 10 ++++----- .../compile_success_empty/option/src/main.nr | 2 +- 10 files changed, 54 insertions(+), 33 deletions(-) create mode 100644 test_programs/compile_failure/unspecified_generic/Nargo.toml create mode 100644 test_programs/compile_failure/unspecified_generic/src/main.nr diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index edb831b2158..79ac02710d9 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -845,6 +845,14 @@ impl<'interner> Monomorphizer<'interner> { return self.resolve_trait_method_expr(expr_id, typ, method); } + // Ensure all instantiation bindings are bound. + // This ensures even unused type variables like `fn foo() {}` have concrete types + if let Some(bindings) = self.interner.try_get_instantiation_bindings(expr_id) { + for (_, binding) in bindings.values() { + Self::check_type(binding, ident.location)?; + } + } + let definition = self.interner.definition(ident.id); let ident = match &definition.kind { DefinitionKind::Function(func_id) => { diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 2c0426f6938..4837028b80f 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1289,6 +1289,10 @@ impl NodeInterner { &self.instantiation_bindings[&expr_id] } + pub fn try_get_instantiation_bindings(&self, expr_id: ExprId) -> Option<&TypeBindings> { + self.instantiation_bindings.get(&expr_id) + } + pub fn get_field_index(&self, expr_id: ExprId) -> usize { self.field_indices[&expr_id] } diff --git a/noir_stdlib/src/collections/map.nr b/noir_stdlib/src/collections/map.nr index bd50f345356..4607b06d667 100644 --- a/noir_stdlib/src/collections/map.nr +++ b/noir_stdlib/src/collections/map.nr @@ -77,10 +77,10 @@ impl Slot { // While conducting lookup, we iterate attempt from 0 to N - 1 due to heuristic, // that if we have went that far without finding desired, // it is very unlikely to be after - performance will be heavily degraded. -impl HashMap { +impl HashMap { // Creates a new instance of HashMap with specified BuildHasher. // docs:start:with_hasher - pub fn with_hasher(_build_hasher: B) -> Self + pub fn with_hasher(_build_hasher: B) -> Self where B: BuildHasher { // docs:end:with_hasher @@ -99,7 +99,7 @@ impl HashMap { // Returns true if the map contains a value for the specified key. // docs:start:contains_key - pub fn contains_key( + pub fn contains_key( self, key: K ) -> bool @@ -183,7 +183,7 @@ impl HashMap { // For each key-value entry applies mutator function. // docs:start:iter_mut - pub fn iter_mut( + pub fn iter_mut( &mut self, f: fn(K, V) -> (K, V) ) @@ -208,7 +208,7 @@ impl HashMap { // For each key applies mutator function. // docs:start:iter_keys_mut - pub fn iter_keys_mut( + pub fn iter_keys_mut( &mut self, f: fn(K) -> K ) @@ -278,7 +278,7 @@ impl HashMap { // Get the value by key. If it does not exist, returns none(). // docs:start:get - pub fn get( + pub fn get( self, key: K ) -> Option @@ -313,7 +313,7 @@ impl HashMap { // Insert key-value entry. In case key was already present, value is overridden. // docs:start:insert - pub fn insert( + pub fn insert( &mut self, key: K, value: V @@ -356,7 +356,7 @@ impl HashMap { // Removes a key-value entry. If key is not present, HashMap remains unchanged. // docs:start:remove - pub fn remove( + pub fn remove( &mut self, key: K ) @@ -388,7 +388,7 @@ impl HashMap { } // Apply HashMap's hasher onto key to obtain pre-hash for probing. - fn hash( + fn hash( self, key: K ) -> u32 diff --git a/noir_stdlib/src/collections/umap.nr b/noir_stdlib/src/collections/umap.nr index 86ae79ea644..c552c053a92 100644 --- a/noir_stdlib/src/collections/umap.nr +++ b/noir_stdlib/src/collections/umap.nr @@ -76,10 +76,10 @@ impl Slot { // While conducting lookup, we iterate attempt from 0 to N - 1 due to heuristic, // that if we have went that far without finding desired, // it is very unlikely to be after - performance will be heavily degraded. -impl UHashMap { +impl UHashMap { // Creates a new instance of UHashMap with specified BuildHasher. // docs:start:with_hasher - pub fn with_hasher(_build_hasher: B) -> Self + pub fn with_hasher(_build_hasher: B) -> Self where B: BuildHasher { // docs:end:with_hasher @@ -88,7 +88,7 @@ impl UHashMap { Self { _table, _len, _build_hasher } } - pub fn with_hasher_and_capacity(_build_hasher: B, capacity: u32) -> Self + pub fn with_hasher_and_capacity(_build_hasher: B, capacity: u32) -> Self where B: BuildHasher { // docs:end:with_hasher @@ -110,7 +110,7 @@ impl UHashMap { // Returns true if the map contains a value for the specified key. // docs:start:contains_key - pub fn contains_key( + pub fn contains_key( self, key: K ) -> bool @@ -194,7 +194,7 @@ impl UHashMap { // For each key-value entry applies mutator function. // docs:start:iter_mut - unconstrained pub fn iter_mut( + unconstrained pub fn iter_mut( &mut self, f: fn(K, V) -> (K, V) ) @@ -216,7 +216,7 @@ impl UHashMap { // For each key applies mutator function. // docs:start:iter_keys_mut - unconstrained pub fn iter_keys_mut( + unconstrained pub fn iter_keys_mut( &mut self, f: fn(K) -> K ) @@ -283,7 +283,7 @@ impl UHashMap { // Get the value by key. If it does not exist, returns none(). // docs:start:get - unconstrained pub fn get( + unconstrained pub fn get( self, key: K ) -> Option @@ -315,7 +315,7 @@ impl UHashMap { // Insert key-value entry. In case key was already present, value is overridden. // docs:start:insert - unconstrained pub fn insert( + unconstrained pub fn insert( &mut self, key: K, value: V @@ -353,7 +353,7 @@ impl UHashMap { } } - unconstrained fn try_resize(&mut self) + unconstrained fn try_resize(&mut self) where B: BuildHasher, K: Eq + Hash, H: Hasher { if self.len() + 1 >= self.capacity() / 2 { let capacity = self.capacity() * 2; @@ -368,7 +368,7 @@ impl UHashMap { // Removes a key-value entry. If key is not present, UHashMap remains unchanged. // docs:start:remove - unconstrained pub fn remove( + unconstrained pub fn remove( &mut self, key: K ) @@ -397,7 +397,7 @@ impl UHashMap { } // Apply UHashMap's hasher onto key to obtain pre-hash for probing. - fn hash( + fn hash( self, key: K ) -> u32 diff --git a/noir_stdlib/src/hash/sha256.nr b/noir_stdlib/src/hash/sha256.nr index 55cdd984003..352df656068 100644 --- a/noir_stdlib/src/hash/sha256.nr +++ b/noir_stdlib/src/hash/sha256.nr @@ -122,7 +122,7 @@ pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { msg_block[msg_byte_ptr] = 1 << 7; msg_byte_ptr = msg_byte_ptr + 1; unsafe { - let (new_msg_block, new_msg_byte_ptr)= pad_msg_block(msg_block, msg_byte_ptr); + let (new_msg_block, new_msg_byte_ptr) = pad_msg_block(msg_block, msg_byte_ptr); msg_block = new_msg_block; if crate::runtime::is_unconstrained() { msg_byte_ptr = new_msg_byte_ptr; @@ -188,10 +188,7 @@ pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { hash_final_block(msg_block, h) } -unconstrained fn pad_msg_block( - mut msg_block: [u8; 64], - mut msg_byte_ptr: u64 -) -> ([u8; 64], u64) { +unconstrained fn pad_msg_block(mut msg_block: [u8; 64], mut msg_byte_ptr: u64) -> ([u8; 64], u64) { // If i >= 57, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. if msg_byte_ptr >= 57 { @@ -208,7 +205,7 @@ unconstrained fn pad_msg_block( (msg_block, msg_byte_ptr) } -unconstrained fn attach_len_to_msg_block( +unconstrained fn attach_len_to_msg_block( mut msg_block: [u8; 64], mut msg_byte_ptr: u64, message_size: u64 diff --git a/noir_stdlib/src/option.nr b/noir_stdlib/src/option.nr index 8d6d9ef970d..5b6b36679f8 100644 --- a/noir_stdlib/src/option.nr +++ b/noir_stdlib/src/option.nr @@ -116,7 +116,7 @@ impl Option { } /// If self is Some, return self. Otherwise, return `default()`. - pub fn or_else(self, default: fn[Env]() -> Self) -> Self { + pub fn or_else(self, default: fn[Env]() -> Self) -> Self { if self._is_some { self } else { default() } } diff --git a/test_programs/compile_failure/unspecified_generic/Nargo.toml b/test_programs/compile_failure/unspecified_generic/Nargo.toml new file mode 100644 index 00000000000..15b97018f2d --- /dev/null +++ b/test_programs/compile_failure/unspecified_generic/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "unspecified_generic" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] diff --git a/test_programs/compile_failure/unspecified_generic/src/main.nr b/test_programs/compile_failure/unspecified_generic/src/main.nr new file mode 100644 index 00000000000..f26d794d567 --- /dev/null +++ b/test_programs/compile_failure/unspecified_generic/src/main.nr @@ -0,0 +1,5 @@ +fn foo() {} + +fn main() { + foo(); +} diff --git a/test_programs/compile_success_empty/method_call_regression/src/main.nr b/test_programs/compile_success_empty/method_call_regression/src/main.nr index 88b8dc57196..de58271cae6 100644 --- a/test_programs/compile_success_empty/method_call_regression/src/main.nr +++ b/test_programs/compile_success_empty/method_call_regression/src/main.nr @@ -1,14 +1,14 @@ fn main() { - // s: Struct - let s = Struct { b: () }; + // s: Struct + let s = Struct { a: 0, b: () }; // Regression for #3089 s.foo(); } -struct Struct { b: B } +struct Struct { a: A, b: B } // Before the fix, this candidate is searched first, binding ? to `u8` permanently. -impl Struct { +impl Struct { fn foo(self) {} } @@ -18,6 +18,6 @@ impl Struct { // With the fix, the type of `s` correctly no longer changes until a // method is actually selected. So this candidate is now valid since // `Struct` unifies with `Struct` with `? = u32`. -impl Struct { +impl Struct { fn foo(self) {} } diff --git a/test_programs/compile_success_empty/option/src/main.nr b/test_programs/compile_success_empty/option/src/main.nr index c5f321256b1..d135b2d88b8 100644 --- a/test_programs/compile_success_empty/option/src/main.nr +++ b/test_programs/compile_success_empty/option/src/main.nr @@ -1,6 +1,6 @@ fn main() { let ten = 10; // giving this a name, to ensure that the Option functions work with closures - let none = Option::none(); + let none: Option = Option::none(); let some = Option::some(3); assert(none.is_none()); From 716a774d3564d57282cd96238f2584d06a964617 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 28 Aug 2024 14:23:17 +0100 Subject: [PATCH 06/21] chore: don't require empty `Prover.toml` for programs with zero arguments but a return value (#5845) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ents but a return value # Description ## Problem\* Resolves ## Summary\* This is a small QOL change but we don't need to have a `Prover.toml` file in the case where the circuit has no inputs but potential return values. In this case we can tolerate a missing `Prover.toml` file. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- tooling/nargo_cli/src/cli/fs/inputs.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tooling/nargo_cli/src/cli/fs/inputs.rs b/tooling/nargo_cli/src/cli/fs/inputs.rs index dee9a00507c..4a7a81431bb 100644 --- a/tooling/nargo_cli/src/cli/fs/inputs.rs +++ b/tooling/nargo_cli/src/cli/fs/inputs.rs @@ -25,7 +25,13 @@ pub(crate) fn read_inputs_from_file>( let file_path = path.as_ref().join(file_name).with_extension(format.ext()); if !file_path.exists() { - return Err(FilesystemError::MissingTomlFile(file_name.to_owned(), file_path)); + if abi.parameters.is_empty() { + // Reading a return value from the `Prover.toml` is optional, + // so if the ABI has no parameters we can skip reading the file if it doesn't exist. + return Ok((BTreeMap::new(), None)); + } else { + return Err(FilesystemError::MissingTomlFile(file_name.to_owned(), file_path)); + } } let input_string = std::fs::read_to_string(file_path).unwrap(); From 2823ba7242db788ca1d7f6e7a48be2f1de62f278 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:11:56 +0100 Subject: [PATCH 07/21] feat: simplify constant calls to `poseidon2_permutation`, `schnorr_verify` and `embedded_curve_add` (#5140) # Description ## Problem\* Resolves ## Summary\* Now that we have rust implementations of all blackbox functions, we can perform any pedersen operations with constant inputs at compile-time. We couldn't do this before as it would have required an async initialisation step for the wasm compiler but that is no longer an issue. I've done this by using compile-time flags to select the blackbox solver we're using. If nargo is compiled with a non-bn254 field then it will not perform these optimizations. The ultimate plan is for this solver to be specified by the end user in a similar way to the way the field definition is. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- Cargo.lock | 1 + Cargo.toml | 2 +- acvm-repo/acir_field/Cargo.toml | 2 +- compiler/noirc_driver/Cargo.toml | 4 + compiler/noirc_evaluator/Cargo.toml | 7 +- .../src/ssa/acir_gen/acir_ir/acir_variable.rs | 7 +- .../src/ssa/ir/instruction/call.rs | 45 +++-- .../src/ssa/ir/instruction/call/blackbox.rs | 190 ++++++++++++++++++ compiler/noirc_frontend/Cargo.toml | 2 +- .../Nargo.toml | 7 + .../src/main.nr | 11 + .../poseidon2_simplification/Nargo.toml | 7 + .../poseidon2_simplification/src/main.nr | 7 + .../schnorr_simplification/Nargo.toml | 6 + .../schnorr_simplification/src/main.nr | 78 +++++++ .../embedded_curve_ops/src/main.nr | 1 - tooling/nargo_cli/Cargo.toml | 2 +- 17 files changed, 357 insertions(+), 22 deletions(-) create mode 100644 compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs create mode 100644 test_programs/compile_success_empty/embedded_curve_add_simplification/Nargo.toml create mode 100644 test_programs/compile_success_empty/embedded_curve_add_simplification/src/main.nr create mode 100644 test_programs/compile_success_empty/poseidon2_simplification/Nargo.toml create mode 100644 test_programs/compile_success_empty/poseidon2_simplification/src/main.nr create mode 100644 test_programs/compile_success_empty/schnorr_simplification/Nargo.toml create mode 100644 test_programs/compile_success_empty/schnorr_simplification/src/main.nr diff --git a/Cargo.lock b/Cargo.lock index f78fbfede27..2cf79c40303 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3004,6 +3004,7 @@ version = "0.33.0" dependencies = [ "acvm", "bn254_blackbox_solver", + "cfg-if 1.0.0", "chrono", "fxhash", "im", diff --git a/Cargo.toml b/Cargo.toml index 52cb1012b71..bf5739ebbe8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,7 +130,7 @@ criterion = "0.5.0" # https://github.com/tikv/pprof-rs/pull/172 pprof = { version = "0.13", features = ["flamegraph", "criterion"] } - +cfg-if = "1.0.0" dirs = "4" serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0" diff --git a/acvm-repo/acir_field/Cargo.toml b/acvm-repo/acir_field/Cargo.toml index c1cffc1334e..acc34457bc9 100644 --- a/acvm-repo/acir_field/Cargo.toml +++ b/acvm-repo/acir_field/Cargo.toml @@ -24,7 +24,7 @@ ark-bn254.workspace = true ark-bls12-381 = { workspace = true, optional = true } ark-ff.workspace = true -cfg-if = "1.0.0" +cfg-if.workspace = true [dev-dependencies] proptest.workspace = true diff --git a/compiler/noirc_driver/Cargo.toml b/compiler/noirc_driver/Cargo.toml index b244018cc71..6b200e79b89 100644 --- a/compiler/noirc_driver/Cargo.toml +++ b/compiler/noirc_driver/Cargo.toml @@ -29,3 +29,7 @@ rust-embed.workspace = true tracing.workspace = true aztec_macros = { path = "../../aztec_macros" } + +[features] +bn254 = ["noirc_frontend/bn254", "noirc_evaluator/bn254"] +bls12_381 = ["noirc_frontend/bls12_381", "noirc_evaluator/bls12_381"] diff --git a/compiler/noirc_evaluator/Cargo.toml b/compiler/noirc_evaluator/Cargo.toml index 81feb0b7154..3bc7f544170 100644 --- a/compiler/noirc_evaluator/Cargo.toml +++ b/compiler/noirc_evaluator/Cargo.toml @@ -26,6 +26,11 @@ serde_json.workspace = true serde_with = "3.2.0" tracing.workspace = true chrono = "0.4.37" +cfg-if.workspace = true [dev-dependencies] -proptest.workspace = true \ No newline at end of file +proptest.workspace = true + +[features] +bn254 = ["noirc_frontend/bn254"] +bls12_381= [] diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index a6b962a45b2..6d17484ee95 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -10,7 +10,6 @@ use crate::ssa::ir::{instruction::Endian, types::NumericType}; use acvm::acir::circuit::brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs}; use acvm::acir::circuit::opcodes::{AcirFunctionId, BlockId, BlockType, MemOp}; use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, ExpressionWidth, Opcode}; -use acvm::blackbox_solver; use acvm::brillig_vm::{MemoryValue, VMStatus, VM}; use acvm::{ acir::AcirField, @@ -2128,7 +2127,11 @@ fn execute_brillig( } // Instantiate a Brillig VM given the solved input registers and memory, along with the Brillig bytecode. - let mut vm = VM::new(calldata, code, Vec::new(), &blackbox_solver::StubbedBlackBoxSolver); + // + // We pass a stubbed solver here as a concrete solver implies a field choice which conflicts with this function + // being generic. + let solver = acvm::blackbox_solver::StubbedBlackBoxSolver; + let mut vm = VM::new(calldata, code, Vec::new(), &solver); // Run the Brillig VM on these inputs, bytecode, etc! let vm_status = vm.process_opcodes(); diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index ea2523e873e..de7ab6e532d 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -1,7 +1,10 @@ use fxhash::FxHashMap as HashMap; use std::{collections::VecDeque, rc::Rc}; -use acvm::{acir::AcirField, acir::BlackBoxFunc, BlackBoxResolutionError, FieldElement}; +use acvm::{ + acir::{AcirField, BlackBoxFunc}, + BlackBoxResolutionError, FieldElement, +}; use bn254_blackbox_solver::derive_generators; use iter_extended::vecmap; use num_bigint::BigUint; @@ -20,6 +23,8 @@ use crate::ssa::{ use super::{Binary, BinaryOp, Endian, Instruction, SimplifyResult}; +mod blackbox; + /// Try to simplify this call instruction. If the instruction can be simplified to a known value, /// that value is returned. Otherwise None is returned. /// @@ -468,11 +473,17 @@ fn simplify_black_box_func( arguments: &[ValueId], dfg: &mut DataFlowGraph, ) -> SimplifyResult { + cfg_if::cfg_if! { + if #[cfg(feature = "bn254")] { + let solver = bn254_blackbox_solver::Bn254BlackBoxSolver; + } else { + let solver = acvm::blackbox_solver::StubbedBlackBoxSolver; + } + }; match bb_func { BlackBoxFunc::SHA256 => simplify_hash(dfg, arguments, acvm::blackbox_solver::sha256), BlackBoxFunc::Blake2s => simplify_hash(dfg, arguments, acvm::blackbox_solver::blake2s), BlackBoxFunc::Blake3 => simplify_hash(dfg, arguments, acvm::blackbox_solver::blake3), - BlackBoxFunc::PedersenCommitment | BlackBoxFunc::PedersenHash => SimplifyResult::None, BlackBoxFunc::Keccakf1600 => { if let Some((array_input, _)) = dfg.get_array_constant(arguments[0]) { if array_is_constant(dfg, &array_input) { @@ -503,20 +514,26 @@ fn simplify_black_box_func( BlackBoxFunc::Keccak256 => { unreachable!("Keccak256 should have been replaced by calls to Keccakf1600") } - BlackBoxFunc::Poseidon2Permutation => SimplifyResult::None, //TODO(Guillaume) - BlackBoxFunc::EcdsaSecp256k1 => { - simplify_signature(dfg, arguments, acvm::blackbox_solver::ecdsa_secp256k1_verify) - } - BlackBoxFunc::EcdsaSecp256r1 => { - simplify_signature(dfg, arguments, acvm::blackbox_solver::ecdsa_secp256r1_verify) + BlackBoxFunc::Poseidon2Permutation => { + blackbox::simplify_poseidon2_permutation(dfg, solver, arguments) } + BlackBoxFunc::EcdsaSecp256k1 => blackbox::simplify_signature( + dfg, + arguments, + acvm::blackbox_solver::ecdsa_secp256k1_verify, + ), + BlackBoxFunc::EcdsaSecp256r1 => blackbox::simplify_signature( + dfg, + arguments, + acvm::blackbox_solver::ecdsa_secp256r1_verify, + ), + + BlackBoxFunc::PedersenCommitment + | BlackBoxFunc::PedersenHash + | BlackBoxFunc::MultiScalarMul => SimplifyResult::None, + BlackBoxFunc::EmbeddedCurveAdd => blackbox::simplify_ec_add(dfg, solver, arguments), + BlackBoxFunc::SchnorrVerify => blackbox::simplify_schnorr_verify(dfg, solver, arguments), - BlackBoxFunc::MultiScalarMul - | BlackBoxFunc::SchnorrVerify - | BlackBoxFunc::EmbeddedCurveAdd => { - // Currently unsolvable here as we rely on an implementation in the backend. - SimplifyResult::None - } BlackBoxFunc::BigIntAdd | BlackBoxFunc::BigIntSub | BlackBoxFunc::BigIntMul diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs new file mode 100644 index 00000000000..706e8891cde --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs @@ -0,0 +1,190 @@ +use std::rc::Rc; + +use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement}; +use iter_extended::vecmap; + +use crate::ssa::ir::{ + dfg::DataFlowGraph, instruction::SimplifyResult, types::Type, value::ValueId, +}; + +use super::{array_is_constant, make_constant_array, to_u8_vec}; + +pub(super) fn simplify_ec_add( + dfg: &mut DataFlowGraph, + solver: impl BlackBoxFunctionSolver, + arguments: &[ValueId], +) -> SimplifyResult { + match ( + dfg.get_numeric_constant(arguments[0]), + dfg.get_numeric_constant(arguments[1]), + dfg.get_numeric_constant(arguments[2]), + dfg.get_numeric_constant(arguments[3]), + dfg.get_numeric_constant(arguments[4]), + dfg.get_numeric_constant(arguments[5]), + ) { + ( + Some(point1_x), + Some(point1_y), + Some(point1_is_infinity), + Some(point2_x), + Some(point2_y), + Some(point2_is_infinity), + ) => { + let Ok((result_x, result_y, result_is_infinity)) = solver.ec_add( + &point1_x, + &point1_y, + &point1_is_infinity, + &point2_x, + &point2_y, + &point2_is_infinity, + ) else { + return SimplifyResult::None; + }; + + let result_x = dfg.make_constant(result_x, Type::field()); + let result_y = dfg.make_constant(result_y, Type::field()); + let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); + + let typ = Type::Array(Rc::new(vec![Type::field()]), 3); + let result_array = + dfg.make_array(im::vector![result_x, result_y, result_is_infinity], typ); + + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } +} + +pub(super) fn simplify_poseidon2_permutation( + dfg: &mut DataFlowGraph, + solver: impl BlackBoxFunctionSolver, + arguments: &[ValueId], +) -> SimplifyResult { + match (dfg.get_array_constant(arguments[0]), dfg.get_numeric_constant(arguments[1])) { + (Some((state, _)), Some(state_length)) if array_is_constant(dfg, &state) => { + let state: Vec = state + .iter() + .map(|id| { + dfg.get_numeric_constant(*id) + .expect("value id from array should point at constant") + }) + .collect(); + + let Some(state_length) = state_length.try_to_u32() else { + return SimplifyResult::None; + }; + + let Ok(new_state) = solver.poseidon2_permutation(&state, state_length) else { + return SimplifyResult::None; + }; + + let result_array = make_constant_array(dfg, new_state, Type::field()); + + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } +} + +pub(super) fn simplify_schnorr_verify( + dfg: &mut DataFlowGraph, + solver: impl BlackBoxFunctionSolver, + arguments: &[ValueId], +) -> SimplifyResult { + match ( + dfg.get_numeric_constant(arguments[0]), + dfg.get_numeric_constant(arguments[1]), + dfg.get_array_constant(arguments[2]), + dfg.get_array_constant(arguments[3]), + ) { + (Some(public_key_x), Some(public_key_y), Some((signature, _)), Some((message, _))) + if array_is_constant(dfg, &signature) && array_is_constant(dfg, &message) => + { + let signature = to_u8_vec(dfg, signature); + let signature: [u8; 64] = + signature.try_into().expect("Compiler should produce correctly sized signature"); + + let message = to_u8_vec(dfg, message); + + let Ok(valid_signature) = + solver.schnorr_verify(&public_key_x, &public_key_y, &signature, &message) + else { + return SimplifyResult::None; + }; + + let valid_signature = dfg.make_constant(valid_signature.into(), Type::bool()); + SimplifyResult::SimplifiedTo(valid_signature) + } + _ => SimplifyResult::None, + } +} + +pub(super) fn simplify_hash( + dfg: &mut DataFlowGraph, + arguments: &[ValueId], + hash_function: fn(&[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, +) -> SimplifyResult { + match dfg.get_array_constant(arguments[0]) { + Some((input, _)) if array_is_constant(dfg, &input) => { + let input_bytes: Vec = to_u8_vec(dfg, input); + + let hash = hash_function(&input_bytes) + .expect("Rust solvable black box function should not fail"); + + let hash_values = vecmap(hash, |byte| FieldElement::from_be_bytes_reduce(&[byte])); + + let result_array = make_constant_array(dfg, hash_values, Type::unsigned(8)); + SimplifyResult::SimplifiedTo(result_array) + } + _ => SimplifyResult::None, + } +} + +type ECDSASignatureVerifier = fn( + hashed_msg: &[u8], + public_key_x: &[u8; 32], + public_key_y: &[u8; 32], + signature: &[u8; 64], +) -> Result; + +pub(super) fn simplify_signature( + dfg: &mut DataFlowGraph, + arguments: &[ValueId], + signature_verifier: ECDSASignatureVerifier, +) -> SimplifyResult { + match ( + dfg.get_array_constant(arguments[0]), + dfg.get_array_constant(arguments[1]), + dfg.get_array_constant(arguments[2]), + dfg.get_array_constant(arguments[3]), + ) { + ( + Some((public_key_x, _)), + Some((public_key_y, _)), + Some((signature, _)), + Some((hashed_message, _)), + ) if array_is_constant(dfg, &public_key_x) + && array_is_constant(dfg, &public_key_y) + && array_is_constant(dfg, &signature) + && array_is_constant(dfg, &hashed_message) => + { + let public_key_x: [u8; 32] = to_u8_vec(dfg, public_key_x) + .try_into() + .expect("ECDSA public key fields are 32 bytes"); + let public_key_y: [u8; 32] = to_u8_vec(dfg, public_key_y) + .try_into() + .expect("ECDSA public key fields are 32 bytes"); + let signature: [u8; 64] = + to_u8_vec(dfg, signature).try_into().expect("ECDSA signatures are 64 bytes"); + let hashed_message: Vec = to_u8_vec(dfg, hashed_message); + + let valid_signature = + signature_verifier(&hashed_message, &public_key_x, &public_key_y, &signature) + .expect("Rust solvable black box function should not fail"); + + let valid_signature = dfg.make_constant(valid_signature.into(), Type::bool()); + SimplifyResult::SimplifiedTo(valid_signature) + } + _ => SimplifyResult::None, + } +} diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index 7ef8870eaa8..c0f6c8965fb 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -27,7 +27,7 @@ num-traits.workspace = true rustc-hash = "1.1.0" small-ord-set = "0.1.3" regex = "1.9.1" -cfg-if = "1.0.0" +cfg-if.workspace = true tracing.workspace = true petgraph = "0.6" rangemap = "1.4.0" diff --git a/test_programs/compile_success_empty/embedded_curve_add_simplification/Nargo.toml b/test_programs/compile_success_empty/embedded_curve_add_simplification/Nargo.toml new file mode 100644 index 00000000000..02586f5e926 --- /dev/null +++ b/test_programs/compile_success_empty/embedded_curve_add_simplification/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "embedded_curve_add_simplification" +type = "bin" +authors = [""] +compiler_version = ">=0.23.0" + +[dependencies] diff --git a/test_programs/compile_success_empty/embedded_curve_add_simplification/src/main.nr b/test_programs/compile_success_empty/embedded_curve_add_simplification/src/main.nr new file mode 100644 index 00000000000..39992a6454b --- /dev/null +++ b/test_programs/compile_success_empty/embedded_curve_add_simplification/src/main.nr @@ -0,0 +1,11 @@ +use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}; + +fn main() { + let zero = EmbeddedCurvePoint::point_at_infinity(); + let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + + assert(g1 + zero == g1); + assert(g1 - g1 == zero); + assert(g1 - zero == g1); + assert(zero + zero == zero); +} diff --git a/test_programs/compile_success_empty/poseidon2_simplification/Nargo.toml b/test_programs/compile_success_empty/poseidon2_simplification/Nargo.toml new file mode 100644 index 00000000000..fbf2c11b220 --- /dev/null +++ b/test_programs/compile_success_empty/poseidon2_simplification/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "poseidon2_simplification" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/poseidon2_simplification/src/main.nr b/test_programs/compile_success_empty/poseidon2_simplification/src/main.nr new file mode 100644 index 00000000000..423dfcc4d3b --- /dev/null +++ b/test_programs/compile_success_empty/poseidon2_simplification/src/main.nr @@ -0,0 +1,7 @@ +use std::hash::poseidon2; + +fn main() { + let digest = poseidon2::Poseidon2::hash([0], 1); + let expected_digest = 0x2710144414c3a5f2354f4c08d52ed655b9fe253b4bf12cb9ad3de693d9b1db11; + assert_eq(digest, expected_digest); +} diff --git a/test_programs/compile_success_empty/schnorr_simplification/Nargo.toml b/test_programs/compile_success_empty/schnorr_simplification/Nargo.toml new file mode 100644 index 00000000000..599f06ac3d2 --- /dev/null +++ b/test_programs/compile_success_empty/schnorr_simplification/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "schnorr_simplification" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/compile_success_empty/schnorr_simplification/src/main.nr b/test_programs/compile_success_empty/schnorr_simplification/src/main.nr new file mode 100644 index 00000000000..e1095cd7fe2 --- /dev/null +++ b/test_programs/compile_success_empty/schnorr_simplification/src/main.nr @@ -0,0 +1,78 @@ +use std::embedded_curve_ops; + +// Note: If main has any unsized types, then the verifier will never be able +// to figure out the circuit instance +fn main() { + let message = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let pub_key_x = 0x04b260954662e97f00cab9adb773a259097f7a274b83b113532bce27fa3fb96a; + let pub_key_y = 0x2fd51571db6c08666b0edfbfbc57d432068bccd0110a39b166ab243da0037197; + let signature = [ + 1, + 13, + 119, + 112, + 212, + 39, + 233, + 41, + 84, + 235, + 255, + 93, + 245, + 172, + 186, + 83, + 157, + 253, + 76, + 77, + 33, + 128, + 178, + 15, + 214, + 67, + 105, + 107, + 177, + 234, + 77, + 48, + 27, + 237, + 155, + 84, + 39, + 84, + 247, + 27, + 22, + 8, + 176, + 230, + 24, + 115, + 145, + 220, + 254, + 122, + 135, + 179, + 171, + 4, + 214, + 202, + 64, + 199, + 19, + 84, + 239, + 138, + 124, + 12 + ]; + + let valid_signature = std::schnorr::verify_signature(pub_key_x, pub_key_y, signature, message); + assert(valid_signature); +} diff --git a/test_programs/noir_test_success/embedded_curve_ops/src/main.nr b/test_programs/noir_test_success/embedded_curve_ops/src/main.nr index 0c2c333fa62..760df58c34a 100644 --- a/test_programs/noir_test_success/embedded_curve_ops/src/main.nr +++ b/test_programs/noir_test_success/embedded_curve_ops/src/main.nr @@ -4,7 +4,6 @@ use std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_sca fn test_infinite_point() { let zero = EmbeddedCurvePoint::point_at_infinity(); - let zero = EmbeddedCurvePoint { x: 0, y: 0, is_infinite: true }; let g1 = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; let g2 = g1 + g1; diff --git a/tooling/nargo_cli/Cargo.toml b/tooling/nargo_cli/Cargo.toml index f3d9f92caaa..4e3f3a57e87 100644 --- a/tooling/nargo_cli/Cargo.toml +++ b/tooling/nargo_cli/Cargo.toml @@ -31,7 +31,7 @@ nargo_fmt.workspace = true nargo_toml.workspace = true noir_lsp.workspace = true noir_debugger.workspace = true -noirc_driver.workspace = true +noirc_driver = { workspace = true, features = ["bn254"] } noirc_frontend = { workspace = true, features = ["bn254"] } noirc_abi.workspace = true noirc_errors.workspace = true From c7473c6fcc6cbfd118b0352229ff86001cde3a64 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:21:18 +0100 Subject: [PATCH 08/21] chore: add documentation to `to_be_bytes`, etc. (#5843) # Description ## Problem\* Resolves ## Summary\* This PR adds some proper documentation to `to_be_bytes` and similar functions. This flags up the potential security concern and failure cases. ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [x] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- noir_stdlib/src/array.nr | 1 + noir_stdlib/src/field/mod.nr | 69 ++++++++++++++++++++++++++++++------ noir_stdlib/src/slice.nr | 5 +-- 3 files changed, 63 insertions(+), 12 deletions(-) diff --git a/noir_stdlib/src/array.nr b/noir_stdlib/src/array.nr index cef79e7c7f6..23683a54e45 100644 --- a/noir_stdlib/src/array.nr +++ b/noir_stdlib/src/array.nr @@ -3,6 +3,7 @@ use crate::option::Option; use crate::convert::From; impl [T; N] { + /// Returns the length of the slice. #[builtin(array_len)] pub fn len(self) -> u32 {} diff --git a/noir_stdlib/src/field/mod.nr b/noir_stdlib/src/field/mod.nr index 4b6deaa1106..534ac012beb 100644 --- a/noir_stdlib/src/field/mod.nr +++ b/noir_stdlib/src/field/mod.nr @@ -2,35 +2,85 @@ mod bn254; use bn254::lt as bn254_lt; impl Field { + /// Asserts that `self` can be represented in `bit_size` bits. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{bit_size}`. + pub fn assert_max_bit_size(self, bit_size: u32) { + crate::assert_constant(bit_size); + assert(bit_size < modulus_num_bits() as u32); + self.__assert_max_bit_size(bit_size); + } + + #[builtin(apply_range_constraint)] + fn __assert_max_bit_size(self, bit_size: u32) {} + + /// Decomposes `self` into its little endian bit decomposition as a `[u1]` slice of length `bit_size`. + /// This slice will be zero padded should not all bits be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{bit_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `bit_size` equal to or greater than the number of bits necessary to represent the `Field` modulus + /// (e.g. 254 for the BN254 field) allow for multiple bit decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_le_bits(self: Self, bit_size: u32) -> [u1] { crate::assert_constant(bit_size); self.__to_le_bits(bit_size) } + /// Decomposes `self` into its big endian bit decomposition as a `[u1]` slice of length `bit_size`. + /// This slice will be zero padded should not all bits be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{bit_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `bit_size` equal to or greater than the number of bits necessary to represent the `Field` modulus + /// (e.g. 254 for the BN254 field) allow for multiple bit decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_be_bits(self: Self, bit_size: u32) -> [u1] { crate::assert_constant(bit_size); self.__to_be_bits(bit_size) } + /// See `Field.to_be_bits` #[builtin(to_le_bits)] fn __to_le_bits(self, _bit_size: u32) -> [u1] {} + /// See `Field.to_le_bits` #[builtin(to_be_bits)] fn __to_be_bits(self, bit_size: u32) -> [u1] {} - #[builtin(apply_range_constraint)] - fn __assert_max_bit_size(self, bit_size: u32) {} - - pub fn assert_max_bit_size(self: Self, bit_size: u32) { - crate::assert_constant(bit_size); - assert(bit_size < modulus_num_bits() as u32); - self.__assert_max_bit_size(bit_size); - } - + /// Decomposes `self` into its little endian byte decomposition as a `[u8]` slice of length `byte_size`. + /// This slice will be zero padded should not all bytes be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{8*byte_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `byte_size` equal to or greater than the number of bytes necessary to represent the `Field` modulus + /// (e.g. 32 for the BN254 field) allow for multiple byte decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_le_bytes(self: Self, byte_size: u32) -> [u8] { self.to_le_radix(256, byte_size) } + /// Decomposes `self` into its big endian byte decomposition as a `[u8]` slice of length `byte_size`. + /// This slice will be zero padded should not all bytes be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^{8*byte_size}` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// Values of `byte_size` equal to or greater than the number of bytes necessary to represent the `Field` modulus + /// (e.g. 32 for the BN254 field) allow for multiple byte decompositions. This is due to how the `Field` will + /// wrap around due to overflow when verifying the decomposition. pub fn to_be_bytes(self: Self, byte_size: u32) -> [u8] { self.to_be_radix(256, byte_size) } @@ -47,7 +97,6 @@ impl Field { self.__to_be_radix(radix, result_len) } - // decompose `_self` into a `_result_len` vector over the `_radix` basis // `_radix` must be less than 256 #[builtin(to_le_radix)] fn __to_le_radix(self, radix: u32, result_len: u32) -> [u8] {} diff --git a/noir_stdlib/src/slice.nr b/noir_stdlib/src/slice.nr index f9aa98a9ecd..66c69db65f0 100644 --- a/noir_stdlib/src/slice.nr +++ b/noir_stdlib/src/slice.nr @@ -1,6 +1,7 @@ use crate::append::Append; impl [T] { + /// Returns the length of the slice. #[builtin(array_len)] pub fn len(self) -> u32 {} @@ -37,8 +38,8 @@ impl [T] { #[builtin(slice_remove)] pub fn remove(self, index: u32) -> (Self, T) {} - // Append each element of the `other` slice to the end of `self`. - // This returns a new slice and leaves both input slices unchanged. + /// Append each element of the `other` slice to the end of `self`. + /// This returns a new slice and leaves both input slices unchanged. pub fn append(mut self, other: Self) -> Self { for elem in other { self = self.push_back(elem); From 58f855ec2124db39e5b2b08630d514d852d0e7df Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Wed, 28 Aug 2024 14:57:11 -0300 Subject: [PATCH 09/21] feat: warn on unused imports (#5847) # Description ## Problem Resolves #4762 ## Summary The compiler now produces a warning on unused imports. ## Additional Context I don't know if this is the correct approach. I added a new HashSet to track imported names, then those are removed as we import things. I thought about tracking this in `ItemScope` but it's tricky because the way things are they include imported and self things. There's another thing: eventually it would be nice to warn on unused types, like in Rust. For that maybe it would make sense to track this in `ItemScope`... but given that we currently don't have visibility modifiers for types, it can't be done right now. So maybe doing it just for imports with specific code is fine. I had to refactor a bit `resolve_path` because some paths were looked up not using that method and so some things weren't marked as used... but now `StandardPathResolver::new` is used in exactly one place in the compiler. ## Documentation Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- compiler/noirc_driver/src/lib.rs | 26 ++++++++-- compiler/noirc_frontend/src/elaborator/mod.rs | 14 ++---- .../noirc_frontend/src/elaborator/scope.rs | 43 ++++++++++------ .../noirc_frontend/src/elaborator/types.rs | 2 +- .../src/hir/def_collector/dc_crate.rs | 22 ++++++++ .../noirc_frontend/src/hir/def_map/mod.rs | 3 ++ .../src/hir/def_map/module_data.rs | 22 +++++++- .../src/hir/resolution/errors.rs | 11 ++++ compiler/noirc_frontend/src/tests.rs | 50 +++++++++++++++++-- tooling/lsp/src/notifications/mod.rs | 8 ++- tooling/nargo/src/package.rs | 7 +++ tooling/nargo_cli/src/cli/check_cmd.rs | 11 ++-- tooling/nargo_cli/src/cli/export_cmd.rs | 6 ++- tooling/nargo_cli/src/cli/test_cmd.rs | 13 +++-- 14 files changed, 195 insertions(+), 43 deletions(-) diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index cb3a4d25c9d..f95c9de7c2c 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -131,6 +131,18 @@ pub struct CompileOptions { pub skip_underconstrained_check: bool, } +#[derive(Clone, Debug, Default)] +pub struct CheckOptions { + pub compile_options: CompileOptions, + pub error_on_unused_imports: bool, +} + +impl CheckOptions { + pub fn new(compile_options: &CompileOptions, error_on_unused_imports: bool) -> Self { + Self { compile_options: compile_options.clone(), error_on_unused_imports } + } +} + pub fn parse_expression_width(input: &str) -> Result { use std::io::{Error, ErrorKind}; let width = input @@ -278,8 +290,10 @@ pub fn add_dep( pub fn check_crate( context: &mut Context, crate_id: CrateId, - options: &CompileOptions, + check_options: &CheckOptions, ) -> CompilationResult<()> { + let options = &check_options.compile_options; + let macros: &[&dyn MacroProcessor] = if options.disable_macros { &[] } else { &[&aztec_macros::AztecMacro] }; @@ -289,6 +303,7 @@ pub fn check_crate( context, options.debug_comptime_in_file.as_deref(), options.arithmetic_generics, + check_options.error_on_unused_imports, macros, ); errors.extend(diagnostics.into_iter().map(|(error, file_id)| { @@ -322,7 +337,10 @@ pub fn compile_main( options: &CompileOptions, cached_program: Option, ) -> CompilationResult { - let (_, mut warnings) = check_crate(context, crate_id, options)?; + let error_on_unused_imports = true; + let check_options = CheckOptions::new(options, error_on_unused_imports); + + let (_, mut warnings) = check_crate(context, crate_id, &check_options)?; let main = context.get_main_function(&crate_id).ok_or_else(|| { // TODO(#2155): This error might be a better to exist in Nargo @@ -357,7 +375,9 @@ pub fn compile_contract( crate_id: CrateId, options: &CompileOptions, ) -> CompilationResult { - let (_, warnings) = check_crate(context, crate_id, options)?; + let error_on_unused_imports = true; + let check_options = CheckOptions::new(options, error_on_unused_imports); + let (_, warnings) = check_crate(context, crate_id, &check_options)?; // TODO: We probably want to error if contracts is empty let contracts = context.get_all_contracts(&crate_id); diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 53b46536078..e8b38193223 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -11,7 +11,7 @@ use crate::{ UnresolvedTypeAlias, }, def_map::DefMaps, - resolution::{errors::ResolverError, path_resolver::PathResolver}, + resolution::errors::ResolverError, scope::ScopeForest as GenericScopeForest, type_check::{generics::TraitGenerics, TypeCheckError}, }, @@ -36,7 +36,7 @@ use crate::{ hir::{ def_collector::{dc_crate::CollectedItems, errors::DefCollectorErrorKind}, def_map::{LocalModuleId, ModuleDefId, ModuleId, MAIN_FUNCTION}, - resolution::{import::PathResolution, path_resolver::StandardPathResolver}, + resolution::import::PathResolution, Context, }, hir_def::function::{FuncMeta, HirFunction}, @@ -630,10 +630,8 @@ impl<'context> Elaborator<'context> { } } - pub fn resolve_module_by_path(&self, path: Path) -> Option { - let path_resolver = StandardPathResolver::new(self.module_id()); - - match path_resolver.resolve(self.def_maps, path.clone(), &mut None) { + pub fn resolve_module_by_path(&mut self, path: Path) -> Option { + match self.resolve_path(path.clone()) { Ok(PathResolution { module_def_id: ModuleDefId::ModuleId(module_id), error }) => { if error.is_some() { None @@ -646,9 +644,7 @@ impl<'context> Elaborator<'context> { } fn resolve_trait_by_path(&mut self, path: Path) -> Option { - let path_resolver = StandardPathResolver::new(self.module_id()); - - let error = match path_resolver.resolve(self.def_maps, path.clone(), &mut None) { + let error = match self.resolve_path(path.clone()) { Ok(PathResolution { module_def_id: ModuleDefId::TraitId(trait_id), error }) => { if let Some(error) = error { self.push_err(error); diff --git a/compiler/noirc_frontend/src/elaborator/scope.rs b/compiler/noirc_frontend/src/elaborator/scope.rs index 3288d10b62e..a51fd737f74 100644 --- a/compiler/noirc_frontend/src/elaborator/scope.rs +++ b/compiler/noirc_frontend/src/elaborator/scope.rs @@ -2,6 +2,7 @@ use noirc_errors::{Location, Spanned}; use crate::ast::{PathKind, ERROR_IDENT}; use crate::hir::def_map::{LocalModuleId, ModuleId}; +use crate::hir::resolution::import::{PathResolution, PathResolutionResult}; use crate::hir::resolution::path_resolver::{PathResolver, StandardPathResolver}; use crate::hir::scope::{Scope as GenericScope, ScopeTree as GenericScopeTree}; use crate::macros_api::Ident; @@ -29,7 +30,7 @@ type ScopeTree = GenericScopeTree; impl<'context> Elaborator<'context> { pub(super) fn lookup(&mut self, path: Path) -> Result { let span = path.span(); - let id = self.resolve_path(path)?; + let id = self.resolve_path_or_error(path)?; T::try_from(id).ok_or_else(|| ResolverError::Expected { expected: T::description(), got: id.as_str().to_owned(), @@ -42,15 +43,37 @@ impl<'context> Elaborator<'context> { ModuleId { krate: self.crate_id, local_id: self.local_module } } - pub(super) fn resolve_path(&mut self, path: Path) -> Result { + pub(super) fn resolve_path_or_error( + &mut self, + path: Path, + ) -> Result { + let path_resolution = self.resolve_path(path)?; + + if let Some(error) = path_resolution.error { + self.push_err(error); + } + + Ok(path_resolution.module_def_id) + } + + pub(super) fn resolve_path(&mut self, path: Path) -> PathResolutionResult { let mut module_id = self.module_id(); let mut path = path; + if path.kind == PathKind::Plain { + let def_map = self.def_maps.get_mut(&self.crate_id).unwrap(); + let module_data = &mut def_map.modules[module_id.local_id.0]; + module_data.use_import(&path.segments[0].ident); + } + if path.kind == PathKind::Plain && path.first_name() == SELF_TYPE_NAME { if let Some(Type::Struct(struct_type, _)) = &self.self_type { let struct_type = struct_type.borrow(); if path.segments.len() == 1 { - return Ok(ModuleDefId::TypeId(struct_type.id)); + return Ok(PathResolution { + module_def_id: ModuleDefId::TypeId(struct_type.id), + error: None, + }); } module_id = struct_type.id.module_id(); @@ -65,11 +88,7 @@ impl<'context> Elaborator<'context> { self.resolve_path_in_module(path, module_id) } - fn resolve_path_in_module( - &mut self, - path: Path, - module_id: ModuleId, - ) -> Result { + fn resolve_path_in_module(&mut self, path: Path, module_id: ModuleId) -> PathResolutionResult { let resolver = StandardPathResolver::new(module_id); let path_resolution; @@ -99,11 +118,7 @@ impl<'context> Elaborator<'context> { path_resolution = resolver.resolve(self.def_maps, path, &mut None)?; } - if let Some(error) = path_resolution.error { - self.push_err(error); - } - - Ok(path_resolution.module_def_id) + Ok(path_resolution) } pub(super) fn get_struct(&self, type_id: StructId) -> Shared { @@ -150,7 +165,7 @@ impl<'context> Elaborator<'context> { pub(super) fn lookup_global(&mut self, path: Path) -> Result { let span = path.span(); - let id = self.resolve_path(path)?; + let id = self.resolve_path_or_error(path)?; if let Some(function) = TryFromModuleDefId::try_from(id) { return Ok(self.interner.function_definition_id(function)); diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 44bded6b92f..3b1ffeb2fc2 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -421,7 +421,7 @@ impl<'context> Elaborator<'context> { } // If we cannot find a local generic of the same name, try to look up a global - match self.resolve_path(path.clone()) { + match self.resolve_path_or_error(path.clone()) { Ok(ModuleDefId::GlobalId(id)) => { if let Some(current_item) = self.current_item { self.interner.add_global_dependency(current_item, id); diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index a961de628a8..30c91b42b2e 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -245,6 +245,7 @@ impl DefCollector { /// Collect all of the definitions in a given crate into a CrateDefMap /// Modules which are not a part of the module hierarchy starting with /// the root module, will be ignored. + #[allow(clippy::too_many_arguments)] pub fn collect_crate_and_dependencies( mut def_map: CrateDefMap, context: &mut Context, @@ -252,6 +253,7 @@ impl DefCollector { root_file_id: FileId, debug_comptime_in_file: Option<&str>, enable_arithmetic_generics: bool, + error_on_unused_imports: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { let mut errors: Vec<(CompilationError, FileId)> = vec![]; @@ -265,11 +267,13 @@ impl DefCollector { let crate_graph = &context.crate_graph[crate_id]; for dep in crate_graph.dependencies.clone() { + let error_on_unused_imports = false; errors.extend(CrateDefMap::collect_defs( dep.crate_id, context, debug_comptime_in_file, enable_arithmetic_generics, + error_on_unused_imports, macro_processors, )); @@ -413,8 +417,26 @@ impl DefCollector { ); } + if error_on_unused_imports { + Self::check_unused_imports(context, crate_id, &mut errors); + } + errors } + + fn check_unused_imports( + context: &Context, + crate_id: CrateId, + errors: &mut Vec<(CompilationError, FileId)>, + ) { + errors.extend(context.def_maps[&crate_id].modules().iter().flat_map(|(_, module)| { + module.unused_imports().iter().map(|ident| { + let ident = ident.clone(); + let error = CompilationError::ResolverError(ResolverError::UnusedImport { ident }); + (error, module.location.file) + }) + })); + } } fn add_import_reference( diff --git a/compiler/noirc_frontend/src/hir/def_map/mod.rs b/compiler/noirc_frontend/src/hir/def_map/mod.rs index e607de52ff1..758b4cf6e5c 100644 --- a/compiler/noirc_frontend/src/hir/def_map/mod.rs +++ b/compiler/noirc_frontend/src/hir/def_map/mod.rs @@ -77,6 +77,7 @@ impl CrateDefMap { context: &mut Context, debug_comptime_in_file: Option<&str>, enable_arithmetic_generics: bool, + error_on_unused_imports: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { // Check if this Crate has already been compiled @@ -127,12 +128,14 @@ impl CrateDefMap { root_file_id, debug_comptime_in_file, enable_arithmetic_generics, + error_on_unused_imports, macro_processors, )); errors.extend( parsing_errors.iter().map(|e| (e.clone().into(), root_file_id)).collect::>(), ); + errors } diff --git a/compiler/noirc_frontend/src/hir/def_map/module_data.rs b/compiler/noirc_frontend/src/hir/def_map/module_data.rs index 8a0125cfe95..7b14db8be77 100644 --- a/compiler/noirc_frontend/src/hir/def_map/module_data.rs +++ b/compiler/noirc_frontend/src/hir/def_map/module_data.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use noirc_errors::Location; @@ -24,6 +24,10 @@ pub struct ModuleData { /// True if this module is a `contract Foo { ... }` module containing contract functions pub is_contract: bool, + + /// List of all unused imports. Each time something is imported into this module it's added + /// to this set. When it's used, it's removed. At the end of the program only unused imports remain. + unused_imports: HashSet, } impl ModuleData { @@ -35,6 +39,7 @@ impl ModuleData { definitions: ItemScope::default(), location, is_contract, + unused_imports: HashSet::new(), } } @@ -121,6 +126,11 @@ impl ModuleData { id: ModuleDefId, is_prelude: bool, ) -> Result<(), (Ident, Ident)> { + // Empty spans could come from implicitly injected imports, and we don't want to track those + if name.span().start() < name.span().end() { + self.unused_imports.insert(name.clone()); + } + self.scope.add_item_to_namespace(name, ItemVisibility::Public, id, None, is_prelude) } @@ -137,4 +147,14 @@ impl ModuleData { pub fn value_definitions(&self) -> impl Iterator + '_ { self.definitions.values().values().flat_map(|a| a.values().map(|(id, _, _)| *id)) } + + /// Marks an ident as being used by an import. + pub fn use_import(&mut self, ident: &Ident) { + self.unused_imports.remove(ident); + } + + /// Returns the list of all unused imports at this moment. + pub fn unused_imports(&self) -> &HashSet { + &self.unused_imports + } } diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 0aad50d13b2..e5a89e61fc2 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -20,6 +20,8 @@ pub enum ResolverError { DuplicateDefinition { name: String, first_span: Span, second_span: Span }, #[error("Unused variable")] UnusedVariable { ident: Ident }, + #[error("Unused import")] + UnusedImport { ident: Ident }, #[error("Could not find variable in this scope")] VariableNotDeclared { name: String, span: Span }, #[error("path is not an identifier")] @@ -152,6 +154,15 @@ impl<'a> From<&'a ResolverError> for Diagnostic { ident.span(), ) } + ResolverError::UnusedImport { ident } => { + let name = &ident.0.contents; + + Diagnostic::simple_warning( + format!("unused import {name}"), + "unused import ".to_string(), + ident.span(), + ) + } ResolverError::VariableNotDeclared { name, span } => Diagnostic::simple_error( format!("cannot find `{name}` in this scope "), "not found in this scope".to_string(), diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index cc4aae7f447..870c781b89d 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -76,15 +76,21 @@ pub(crate) fn get_program(src: &str) -> (ParsedModule, Context, Vec<(Compilation extern_prelude: BTreeMap::new(), }; + let debug_comptime_in_file = None; + let enable_arithmetic_generics = false; + let error_on_unused_imports = true; + let macro_processors = &[]; + // Now we want to populate the CrateDefMap using the DefCollector errors.extend(DefCollector::collect_crate_and_dependencies( def_map, &mut context, program.clone().into_sorted(), root_file_id, - None, // No debug_comptime_in_file - false, // Disallow arithmetic generics - &[], // No macro processors + debug_comptime_in_file, + enable_arithmetic_generics, + error_on_unused_imports, + macro_processors, )); } (program, context, errors) @@ -2424,6 +2430,10 @@ fn use_super() { mod foo { use super::some_func; + + fn bar() { + some_func(); + } } "#; assert_no_errors(src); @@ -3187,3 +3197,37 @@ fn as_trait_path_syntax_no_impl() { use CompilationError::TypeError; assert!(matches!(&errors[0].0, TypeError(TypeCheckError::NoMatchingImplFound { .. }))); } + +#[test] +fn errors_on_unused_import() { + let src = r#" + mod foo { + pub fn bar() {} + pub fn baz() {} + + trait Foo { + } + } + + use foo::bar; + use foo::baz; + use foo::Foo; + + impl Foo for Field { + } + + fn main() { + baz(); + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::ResolverError(ResolverError::UnusedImport { ident }) = &errors[0].0 + else { + panic!("Expected an unused import error"); + }; + + assert_eq!(ident.to_string(), "bar"); +} diff --git a/tooling/lsp/src/notifications/mod.rs b/tooling/lsp/src/notifications/mod.rs index 4d2186badc3..8b030c9e0aa 100644 --- a/tooling/lsp/src/notifications/mod.rs +++ b/tooling/lsp/src/notifications/mod.rs @@ -2,7 +2,7 @@ use std::ops::ControlFlow; use crate::insert_all_files_for_workspace_into_file_manager; use async_lsp::{ErrorCode, LanguageClient, ResponseError}; -use noirc_driver::{check_crate, file_manager_with_stdlib}; +use noirc_driver::{check_crate, file_manager_with_stdlib, CheckOptions}; use noirc_errors::{DiagnosticKind, FileDiagnostic}; use crate::types::{ @@ -132,7 +132,11 @@ pub(crate) fn process_workspace_for_noir_document( let (mut context, crate_id) = crate::prepare_package(&workspace_file_manager, &parsed_files, package); - let file_diagnostics = match check_crate(&mut context, crate_id, &Default::default()) { + let options = CheckOptions { + error_on_unused_imports: package.error_on_unused_imports(), + ..Default::default() + }; + let file_diagnostics = match check_crate(&mut context, crate_id, &options) { Ok(((), warnings)) => warnings, Err(errors_and_warnings) => errors_and_warnings, }; diff --git a/tooling/nargo/src/package.rs b/tooling/nargo/src/package.rs index f55ca5550a3..cde616a9e32 100644 --- a/tooling/nargo/src/package.rs +++ b/tooling/nargo/src/package.rs @@ -73,4 +73,11 @@ impl Package { pub fn is_library(&self) -> bool { self.package_type == PackageType::Library } + + pub fn error_on_unused_imports(&self) -> bool { + match self.package_type { + PackageType::Library => false, + PackageType::Binary | PackageType::Contract => true, + } + } } diff --git a/tooling/nargo_cli/src/cli/check_cmd.rs b/tooling/nargo_cli/src/cli/check_cmd.rs index 5239070b4d2..1130a82fdfc 100644 --- a/tooling/nargo_cli/src/cli/check_cmd.rs +++ b/tooling/nargo_cli/src/cli/check_cmd.rs @@ -10,7 +10,7 @@ use nargo::{ use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; use noirc_abi::{AbiParameter, AbiType, MAIN_RETURN_NAME}; use noirc_driver::{ - check_crate, compute_function_abi, file_manager_with_stdlib, CompileOptions, + check_crate, compute_function_abi, file_manager_with_stdlib, CheckOptions, CompileOptions, NOIR_ARTIFACT_VERSION_STRING, }; use noirc_frontend::{ @@ -81,7 +81,9 @@ fn check_package( allow_overwrite: bool, ) -> Result { let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate_and_report_errors(&mut context, crate_id, compile_options)?; + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(compile_options, error_on_unused_imports); + check_crate_and_report_errors(&mut context, crate_id, &check_options)?; if package.is_library() || package.is_contract() { // Libraries do not have ABIs while contracts have many, so we cannot generate a `Prover.toml` file. @@ -150,9 +152,10 @@ fn create_input_toml_template( pub(crate) fn check_crate_and_report_errors( context: &mut Context, crate_id: CrateId, - options: &CompileOptions, + check_options: &CheckOptions, ) -> Result<(), CompileError> { - let result = check_crate(context, crate_id, options); + let options = &check_options.compile_options; + let result = check_crate(context, crate_id, check_options); report_errors(result, &context.file_manager, options.deny_warnings, options.silence_warnings) } diff --git a/tooling/nargo_cli/src/cli/export_cmd.rs b/tooling/nargo_cli/src/cli/export_cmd.rs index 19add7f30dc..5721dd33e27 100644 --- a/tooling/nargo_cli/src/cli/export_cmd.rs +++ b/tooling/nargo_cli/src/cli/export_cmd.rs @@ -12,7 +12,7 @@ use nargo::workspace::Workspace; use nargo::{insert_all_files_for_workspace_into_file_manager, parse_all}; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; use noirc_driver::{ - compile_no_check, file_manager_with_stdlib, CompileOptions, CompiledProgram, + compile_no_check, file_manager_with_stdlib, CheckOptions, CompileOptions, CompiledProgram, NOIR_ARTIFACT_VERSION_STRING, }; @@ -83,7 +83,9 @@ fn compile_exported_functions( compile_options: &CompileOptions, ) -> Result<(), CliError> { let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate_and_report_errors(&mut context, crate_id, compile_options)?; + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(compile_options, error_on_unused_imports); + check_crate_and_report_errors(&mut context, crate_id, &check_options)?; let exported_functions = context.get_all_exported_functions_in_crate(&crate_id); diff --git a/tooling/nargo_cli/src/cli/test_cmd.rs b/tooling/nargo_cli/src/cli/test_cmd.rs index 0d7c8fc8bf7..2b0c0fd58db 100644 --- a/tooling/nargo_cli/src/cli/test_cmd.rs +++ b/tooling/nargo_cli/src/cli/test_cmd.rs @@ -10,7 +10,8 @@ use nargo::{ }; use nargo_toml::{get_package_manifest, resolve_workspace_from_toml, PackageSelection}; use noirc_driver::{ - check_crate, file_manager_with_stdlib, CompileOptions, NOIR_ARTIFACT_VERSION_STRING, + check_crate, file_manager_with_stdlib, CheckOptions, CompileOptions, + NOIR_ARTIFACT_VERSION_STRING, }; use noirc_frontend::{ graph::CrateName, @@ -180,7 +181,9 @@ fn run_test + Default>( // We then need to construct a separate copy for each test. let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate(&mut context, crate_id, compile_options) + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(compile_options, error_on_unused_imports); + check_crate(&mut context, crate_id, &check_options) .expect("Any errors should have occurred when collecting test functions"); let test_functions = context @@ -206,10 +209,12 @@ fn get_tests_in_package( parsed_files: &ParsedFiles, package: &Package, fn_name: FunctionNameMatch, - compile_options: &CompileOptions, + options: &CompileOptions, ) -> Result, CliError> { let (mut context, crate_id) = prepare_package(file_manager, parsed_files, package); - check_crate_and_report_errors(&mut context, crate_id, compile_options)?; + let error_on_unused_imports = package.error_on_unused_imports(); + let check_options = CheckOptions::new(options, error_on_unused_imports); + check_crate_and_report_errors(&mut context, crate_id, &check_options)?; Ok(context .get_all_test_functions_in_crate_matching(&crate_id, fn_name) From c23463e4d57e7cb4acf701e4f9485e2e8eaee8a1 Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 28 Aug 2024 14:30:39 -0500 Subject: [PATCH 10/21] chore: Add missing cases to arithmetic generics (#5841) # Description ## Problem\* ## Summary\* In the initial arithmetic generics PR we only added the one specific case for simplifying `(N + C1) - C2`. Later in the associated types PR we added another case to simplify the non-constant `(N + M) - M` This PR fills in the missing cases for each other operator. It also has somewhat better overflow handling by returning an `Option` in the operator function and removing the wrapping operations. ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [x] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [ ] I have tested the changes locally. - [ ] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: Michael J Klein --- .../noirc_frontend/src/elaborator/types.rs | 9 +- .../src/hir/resolution/errors.rs | 9 + .../src/hir/type_check/generics.rs | 1 + compiler/noirc_frontend/src/hir_def/types.rs | 291 +++++++++--------- .../src/hir_def/types/arithmetic.rs | 215 +++++++++++++ .../src/monomorphization/errors.rs | 2 +- .../src/monomorphization/mod.rs | 2 + .../arithmetic_generics/src/main.nr | 31 ++ 8 files changed, 414 insertions(+), 146 deletions(-) create mode 100644 compiler/noirc_frontend/src/hir_def/types/arithmetic.rs diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 3b1ffeb2fc2..8e4c9aa4af1 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -445,14 +445,19 @@ impl<'context> Elaborator<'context> { }) } UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), - UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); let lhs = self.convert_expression_type(*lhs); let rhs = self.convert_expression_type(*rhs); match (lhs, rhs) { (Type::Constant(lhs), Type::Constant(rhs)) => { - Type::Constant(op.function(lhs, rhs)) + if let Some(result) = op.function(lhs, rhs) { + Type::Constant(result) + } else { + self.push_err(ResolverError::OverflowInType { lhs, op, rhs, span }); + Type::Error + } } (lhs, rhs) => { if !self.enable_arithmetic_generics { diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index e5a89e61fc2..0b0d8d735eb 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -122,6 +122,8 @@ pub enum ResolverError { NamedTypeArgs { span: Span, item_kind: &'static str }, #[error("Associated constants may only be a field or integer type")] AssociatedConstantsMustBeNumeric { span: Span }, + #[error("Overflow in `{lhs} {op} {rhs}`")] + OverflowInType { lhs: u32, op: crate::BinaryTypeOperator, rhs: u32, span: Span }, } impl ResolverError { @@ -491,6 +493,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) } + ResolverError::OverflowInType { lhs, op, rhs, span } => { + Diagnostic::simple_error( + format!("Overflow in `{lhs} {op} {rhs}`"), + "Overflow here".to_string(), + *span, + ) + } } } } diff --git a/compiler/noirc_frontend/src/hir/type_check/generics.rs b/compiler/noirc_frontend/src/hir/type_check/generics.rs index 379c53944e5..697c78745f9 100644 --- a/compiler/noirc_frontend/src/hir/type_check/generics.rs +++ b/compiler/noirc_frontend/src/hir/type_check/generics.rs @@ -160,6 +160,7 @@ fn fmt_trait_generics( write!(f, "{} = {}", named.name, named.typ)?; } } + write!(f, ">")?; } Ok(()) } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 807666f9af9..c59c86b9616 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -24,7 +24,9 @@ use super::{ traits::NamedType, }; -#[derive(PartialEq, Eq, Clone, Hash, Ord, PartialOrd)] +mod arithmetic; + +#[derive(Eq, Clone, Ord, PartialOrd)] pub enum Type { /// A primitive Field type FieldElement, @@ -1657,132 +1659,6 @@ impl Type { } } - /// Try to canonicalize the representation of this type. - /// Currently the only type with a canonical representation is - /// `Type::Infix` where for each consecutive commutative operator - /// we sort the non-constant operands by `Type: Ord` and place all constant - /// operands at the end, constant folded. - /// - /// For example: - /// - `canonicalize[((1 + N) + M) + 2] = (M + N) + 3` - /// - `canonicalize[A + 2 * B + 3 - 2] = A + (B * 2) + 3 - 2` - pub fn canonicalize(&self) -> Type { - match self.follow_bindings() { - Type::InfixExpr(lhs, op, rhs) => { - // evaluate_to_u32 also calls canonicalize so if we just called - // `self.evaluate_to_u32()` we'd get infinite recursion. - if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { - return Type::Constant(op.function(lhs, rhs)); - } - - let lhs = lhs.canonicalize(); - let rhs = rhs.canonicalize(); - if let Some(result) = Self::try_simplify_addition(&lhs, op, &rhs) { - return result; - } - - if let Some(result) = Self::try_simplify_subtraction(&lhs, op, &rhs) { - return result; - } - - if op.is_commutative() { - return Self::sort_commutative(&lhs, op, &rhs); - } - - Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) - } - other => other, - } - } - - fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type { - let mut queue = vec![lhs.clone(), rhs.clone()]; - - let mut sorted = BTreeSet::new(); - - let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 }; - let mut constant = zero_value; - - // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. - while let Some(item) = queue.pop() { - match item.canonicalize() { - Type::InfixExpr(lhs, new_op, rhs) if new_op == op => { - queue.push(*lhs); - queue.push(*rhs); - } - Type::Constant(new_constant) => { - constant = op.function(constant, new_constant); - } - other => { - sorted.insert(other); - } - } - } - - if let Some(first) = sorted.pop_first() { - let mut typ = first.clone(); - - for rhs in sorted { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); - } - - if constant != zero_value { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(Type::Constant(constant))); - } - - typ - } else { - // Every type must have been a constant - Type::Constant(constant) - } - } - - /// Try to simplify an addition expression of `lhs + rhs`. - /// - /// - Simplifies `(a - b) + b` to `a`. - fn try_simplify_addition(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { - use BinaryTypeOperator::*; - match lhs { - Type::InfixExpr(l_lhs, l_op, l_rhs) => { - if op == Addition && *l_op == Subtraction { - // TODO: Propagate type bindings. Can do in another PR, this one is large enough. - let unifies = l_rhs.try_unify(rhs, &mut TypeBindings::new()); - if unifies.is_ok() { - return Some(l_lhs.as_ref().clone()); - } - } - None - } - _ => None, - } - } - - /// Try to simplify a subtraction expression of `lhs - rhs`. - /// - /// - Simplifies `(a + C1) - C2` to `a + (C1 - C2)` if C1 and C2 are constants. - fn try_simplify_subtraction(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { - use BinaryTypeOperator::*; - match lhs { - Type::InfixExpr(l_lhs, l_op, l_rhs) => { - // Simplify `(N + 2) - 1` - if op == Subtraction && *l_op == Addition { - if let (Some(lhs_const), Some(rhs_const)) = - (l_rhs.evaluate_to_u32(), rhs.evaluate_to_u32()) - { - if lhs_const > rhs_const { - let constant = Box::new(Type::Constant(lhs_const - rhs_const)); - return Some( - Type::InfixExpr(l_lhs.clone(), *l_op, constant).canonicalize(), - ); - } - } - } - None - } - _ => None, - } - } - /// Try to unify a type variable to `self`. /// This is a helper function factored out from try_unify. fn try_unify_to_type_variable( @@ -1926,7 +1802,7 @@ impl Type { Type::InfixExpr(lhs, op, rhs) => { let lhs = lhs.evaluate_to_u32()?; let rhs = rhs.evaluate_to_u32()?; - Some(op.function(lhs, rhs)) + op.function(lhs, rhs) } _ => None, } @@ -2030,17 +1906,13 @@ impl Type { Type::Forall(typevars, typ) => { assert_eq!(types.len() + implicit_generic_count, typevars.len(), "Turbofish operator used with incorrect generic count which was not caught by name resolution"); + let bindings = + (0..implicit_generic_count).map(|_| interner.next_type_variable()).chain(types); + let replacements = typevars .iter() - .enumerate() - .map(|(i, var)| { - let binding = if i < implicit_generic_count { - interner.next_type_variable() - } else { - types[i - implicit_generic_count].clone() - }; - (var.id(), (var.clone(), binding)) - }) + .zip(bindings) + .map(|(var, binding)| (var.id(), (var.clone(), binding))) .collect(); let instantiated = typ.substitute(&replacements); @@ -2457,13 +2329,13 @@ fn convert_array_expression_to_slice( impl BinaryTypeOperator { /// Perform the actual rust numeric operation associated with this operator - pub fn function(self, a: u32, b: u32) -> u32 { + pub fn function(self, a: u32, b: u32) -> Option { match self { - BinaryTypeOperator::Addition => a.wrapping_add(b), - BinaryTypeOperator::Subtraction => a.wrapping_sub(b), - BinaryTypeOperator::Multiplication => a.wrapping_mul(b), - BinaryTypeOperator::Division => a.wrapping_div(b), - BinaryTypeOperator::Modulo => a.wrapping_rem(b), + BinaryTypeOperator::Addition => a.checked_add(b), + BinaryTypeOperator::Subtraction => a.checked_sub(b), + BinaryTypeOperator::Multiplication => a.checked_mul(b), + BinaryTypeOperator::Division => a.checked_div(b), + BinaryTypeOperator::Modulo => a.checked_rem(b), } } @@ -2681,3 +2553,136 @@ impl std::fmt::Debug for StructType { write!(f, "{}", self.name) } } + +impl std::hash::Hash for Type { + fn hash(&self, state: &mut H) { + if let Some(variable) = self.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + typ.hash(state); + return; + } + } + + if !matches!(self, Type::TypeVariable(..) | Type::NamedGeneric(..)) { + std::mem::discriminant(self).hash(state); + } + + match self { + Type::FieldElement | Type::Bool | Type::Unit | Type::Error => (), + Type::Array(len, elem) => { + len.hash(state); + elem.hash(state); + } + Type::Slice(elem) => elem.hash(state), + Type::Integer(sign, bits) => { + sign.hash(state); + bits.hash(state); + } + Type::String(len) => len.hash(state), + Type::FmtString(len, env) => { + len.hash(state); + env.hash(state); + } + Type::Tuple(elems) => elems.hash(state), + Type::Struct(def, args) => { + def.hash(state); + args.hash(state); + } + Type::Alias(alias, args) => { + alias.hash(state); + args.hash(state); + } + Type::TypeVariable(var, _) | Type::NamedGeneric(var, ..) => var.hash(state), + Type::TraitAsType(trait_id, _, args) => { + trait_id.hash(state); + args.hash(state); + } + Type::Function(args, ret, env, is_unconstrained) => { + args.hash(state); + ret.hash(state); + env.hash(state); + is_unconstrained.hash(state); + } + Type::MutableReference(elem) => elem.hash(state), + Type::Forall(vars, typ) => { + vars.hash(state); + typ.hash(state); + } + Type::Constant(value) => value.hash(state), + Type::Quoted(typ) => typ.hash(state), + Type::InfixExpr(lhs, op, rhs) => { + lhs.hash(state); + op.hash(state); + rhs.hash(state); + } + } + } +} + +impl PartialEq for Type { + fn eq(&self, other: &Self) -> bool { + if let Some(variable) = self.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + return typ == other; + } + } + + if let Some(variable) = other.get_inner_type_variable() { + if let TypeBinding::Bound(typ) = &*variable.borrow() { + return self == typ; + } + } + + use Type::*; + match (self, other) { + (FieldElement, FieldElement) | (Bool, Bool) | (Unit, Unit) | (Error, Error) => true, + (Array(lhs_len, lhs_elem), Array(rhs_len, rhs_elem)) => { + lhs_len == rhs_len && lhs_elem == rhs_elem + } + (Slice(lhs_elem), Slice(rhs_elem)) => lhs_elem == rhs_elem, + (Integer(lhs_sign, lhs_bits), Integer(rhs_sign, rhs_bits)) => { + lhs_sign == rhs_sign && lhs_bits == rhs_bits + } + (String(lhs_len), String(rhs_len)) => lhs_len == rhs_len, + (FmtString(lhs_len, lhs_env), FmtString(rhs_len, rhs_env)) => { + lhs_len == rhs_len && lhs_env == rhs_env + } + (Tuple(lhs_types), Tuple(rhs_types)) => lhs_types == rhs_types, + (Struct(lhs_struct, lhs_generics), Struct(rhs_struct, rhs_generics)) => { + lhs_struct == rhs_struct && lhs_generics == rhs_generics + } + (Alias(lhs_alias, lhs_generics), Alias(rhs_alias, rhs_generics)) => { + lhs_alias == rhs_alias && lhs_generics == rhs_generics + } + (TraitAsType(lhs_trait, _, lhs_generics), TraitAsType(rhs_trait, _, rhs_generics)) => { + lhs_trait == rhs_trait && lhs_generics == rhs_generics + } + ( + Function(lhs_args, lhs_ret, lhs_env, lhs_unconstrained), + Function(rhs_args, rhs_ret, rhs_env, rhs_unconstrained), + ) => { + let args_and_ret_eq = lhs_args == rhs_args && lhs_ret == rhs_ret; + args_and_ret_eq && lhs_env == rhs_env && lhs_unconstrained == rhs_unconstrained + } + (MutableReference(lhs_elem), MutableReference(rhs_elem)) => lhs_elem == rhs_elem, + (Forall(lhs_vars, lhs_type), Forall(rhs_vars, rhs_type)) => { + lhs_vars == rhs_vars && lhs_type == rhs_type + } + (Constant(lhs), Constant(rhs)) => lhs == rhs, + (Quoted(lhs), Quoted(rhs)) => lhs == rhs, + (InfixExpr(l_lhs, l_op, l_rhs), InfixExpr(r_lhs, r_op, r_rhs)) => { + l_lhs == r_lhs && l_op == r_op && l_rhs == r_rhs + } + // Special case: we consider unbound named generics and type variables to be equal to each + // other if their type variable ids match. This is important for some corner cases in + // monomorphization where we call `replace_named_generics_with_type_variables` but + // still want them to be equal for canonicalization checks in arithmetic generics. + // Without this we'd fail the `serialize` test. + ( + NamedGeneric(lhs_var, _, _) | TypeVariable(lhs_var, _), + NamedGeneric(rhs_var, _, _) | TypeVariable(rhs_var, _), + ) => lhs_var.id() == rhs_var.id(), + _ => false, + } + } +} diff --git a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs new file mode 100644 index 00000000000..ad07185dff1 --- /dev/null +++ b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -0,0 +1,215 @@ +use std::collections::BTreeSet; + +use crate::{BinaryTypeOperator, Type}; + +impl Type { + /// Try to canonicalize the representation of this type. + /// Currently the only type with a canonical representation is + /// `Type::Infix` where for each consecutive commutative operator + /// we sort the non-constant operands by `Type: Ord` and place all constant + /// operands at the end, constant folded. + /// + /// For example: + /// - `canonicalize[((1 + N) + M) + 2] = (M + N) + 3` + /// - `canonicalize[A + 2 * B + 3 - 2] = A + (B * 2) + 3 - 2` + pub fn canonicalize(&self) -> Type { + match self.follow_bindings() { + Type::InfixExpr(lhs, op, rhs) => { + // evaluate_to_u32 also calls canonicalize so if we just called + // `self.evaluate_to_u32()` we'd get infinite recursion. + if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { + if let Some(result) = op.function(lhs, rhs) { + return Type::Constant(result); + } + } + + let lhs = lhs.canonicalize(); + let rhs = rhs.canonicalize(); + if let Some(result) = Self::try_simplify_non_constants_in_lhs(&lhs, op, &rhs) { + return result.canonicalize(); + } + + if let Some(result) = Self::try_simplify_non_constants_in_rhs(&lhs, op, &rhs) { + return result.canonicalize(); + } + + // Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + // where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) + if let Some(result) = Self::try_simplify_partial_constants(&lhs, op, &rhs) { + return result.canonicalize(); + } + + if op.is_commutative() { + return Self::sort_commutative(&lhs, op, &rhs); + } + + Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) + } + other => other, + } + } + + fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type { + let mut queue = vec![lhs.clone(), rhs.clone()]; + + let mut sorted = BTreeSet::new(); + + let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 }; + let mut constant = zero_value; + + // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. + while let Some(item) = queue.pop() { + match item.canonicalize() { + Type::InfixExpr(lhs, new_op, rhs) if new_op == op => { + queue.push(*lhs); + queue.push(*rhs); + } + Type::Constant(new_constant) => { + if let Some(result) = op.function(constant, new_constant) { + constant = result; + } else { + sorted.insert(Type::Constant(new_constant)); + } + } + other => { + sorted.insert(other); + } + } + } + + if let Some(first) = sorted.pop_first() { + let mut typ = first.clone(); + + for rhs in sorted { + typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); + } + + if constant != zero_value { + typ = Type::InfixExpr(Box::new(typ), op, Box::new(Type::Constant(constant))); + } + + typ + } else { + // Every type must have been a constant + Type::Constant(constant) + } + } + + /// Try to simplify non-constant expressions in the form `(N op1 M) op2 M` + /// where the two `M` terms are expected to cancel out. + /// Precondition: `lhs & rhs are in canonical form` + /// + /// - Simplifies `(N +/- M) -/+ M` to `N` + /// - Simplifies `(N */÷ M) ÷/* M` to `N` + fn try_simplify_non_constants_in_lhs( + lhs: &Type, + op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + let Type::InfixExpr(l_lhs, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + // Note that this is exact, syntactic equality, not unification. + // `rhs` is expected to already be in canonical form. + if l_op.inverse() != Some(op) || l_rhs.canonicalize() != *rhs { + return None; + } + + Some(*l_lhs) + } + + /// Try to simplify non-constant expressions in the form `N op1 (M op1 N)` + /// where the two `M` terms are expected to cancel out. + /// Precondition: `lhs & rhs are in canonical form` + /// + /// Unlike `try_simplify_non_constants_in_lhs` we can't simplify `N / (M * N)` + /// Since that should simplify to `1 / M` instead of `M`. + /// + /// - Simplifies `N +/- (M -/+ N)` to `M` + /// - Simplifies `N * (M ÷ N)` to `M` + fn try_simplify_non_constants_in_rhs( + lhs: &Type, + op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + let Type::InfixExpr(r_lhs, r_op, r_rhs) = rhs.follow_bindings() else { + return None; + }; + + // `N / (M * N)` should be simplified to `1 / M`, but we only handle + // simplifying to `M` in this function. + if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication { + return None; + } + + // Note that this is exact, syntactic equality, not unification. + // `lhs` is expected to already be in canonical form. + if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize() { + return None; + } + + Some(*r_lhs) + } + + /// Given: + /// lhs = `N op C1` + /// rhs = C2 + /// Returns: `(N, op, C1, C2)` if C1 and C2 are constants. + /// Note that the operator here is within the `lhs` term, the operator + /// separating lhs and rhs is not needed. + /// Precondition: `lhs & rhs are in canonical form` + fn parse_partial_constant_expr( + lhs: &Type, + rhs: &Type, + ) -> Option<(Box, BinaryTypeOperator, u32, u32)> { + let rhs = rhs.evaluate_to_u32()?; + + let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { + return None; + }; + + let l_rhs = l_rhs.evaluate_to_u32()?; + Some((l_type, l_op, l_rhs, rhs)) + } + + /// Try to simplify partially constant expressions in the form `(N op1 C1) op2 C2` + /// where C1 and C2 are constants that can be combined (e.g. N + 5 - 3 = N + 2) + /// Precondition: `lhs & rhs are in canonical form` + /// + /// - Simplifies `(N +/- C1) +/- C2` to `N +/- (C1 +/- C2)` if C1 and C2 are constants. + /// - Simplifies `(N */÷ C1) */÷ C2` to `N */÷ (C1 */÷ C2)` if C1 and C2 are constants. + fn try_simplify_partial_constants( + lhs: &Type, + mut op: BinaryTypeOperator, + rhs: &Type, + ) -> Option { + use BinaryTypeOperator::*; + let (l_type, l_op, l_const, r_const) = Type::parse_partial_constant_expr(lhs, rhs)?; + + match (l_op, op) { + (Addition | Subtraction, Addition | Subtraction) => { + // If l_op is a subtraction we want to inverse the rhs operator. + if l_op == Subtraction { + op = op.inverse()?; + } + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) + } + (Multiplication | Division, Multiplication | Division) => { + // If l_op is a division we want to inverse the rhs operator. + if l_op == Division { + op = op.inverse()?; + } + // If op is a division we need to ensure it divides evenly + if op == Division && (r_const == 0 || l_const % r_const != 0) { + None + } else { + let result = op.function(l_const, r_const)?; + Some(Type::InfixExpr(l_type, l_op, Box::new(Type::Constant(result)))) + } + } + _ => None, + } + } +} diff --git a/compiler/noirc_frontend/src/monomorphization/errors.rs b/compiler/noirc_frontend/src/monomorphization/errors.rs index 665bf26f7b9..ce8ef3572e6 100644 --- a/compiler/noirc_frontend/src/monomorphization/errors.rs +++ b/compiler/noirc_frontend/src/monomorphization/errors.rs @@ -34,7 +34,7 @@ impl MonomorphizationError { fn into_diagnostic(self) -> CustomDiagnostic { let message = match &self { MonomorphizationError::UnknownArrayLength { length, .. } => { - format!("ICE: Could not determine array length `{length}`") + format!("Could not determine array length `{length}`") } MonomorphizationError::NoDefaultType { location } => { let message = "Type annotation needed".into(); diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 79ac02710d9..87b55540bbd 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -301,6 +301,7 @@ impl<'interner> Monomorphizer<'interner> { } let meta = self.interner.function_meta(&f).clone(); + let mut func_sig = meta.function_signature(); // Follow the bindings of the function signature for entry points // which are not `main` such as foldable functions. @@ -1958,6 +1959,7 @@ pub fn resolve_trait_method( TraitImplKind::Normal(impl_id) => impl_id, TraitImplKind::Assumed { object_type, trait_generics } => { let location = interner.expr_location(&expr_id); + match interner.lookup_trait_implementation( &object_type, method.trait_id, diff --git a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr index 6cd13ab0e2f..ad8dff6c7b9 100644 --- a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr +++ b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr @@ -7,6 +7,9 @@ fn main() { let _ = split_first([1, 2, 3]); let _ = push_multiple([1, 2, 3]); + + test_constant_folding::<10>(); + test_non_constant_folding::<10, 20>(); } fn split_first(array: [T; N]) -> (T, [T; N - 1]) { @@ -101,3 +104,31 @@ fn demo_proof() -> Equiv, (Equiv, (), W, () let p3: Equiv, (), W, ()> = add_equiv_r::(p3_sub); equiv_trans(equiv_trans(p1, p2), p3) } + +fn test_constant_folding() { + // N + C1 - C2 = N + (C1 - C2) + let _: W = W:: {}; + + // N - C1 + C2 = N - (C1 - C2) + let _: W = W:: {}; + + // N * C1 / C2 = N * (C1 / C2) + let _: W = W:: {}; + + // N / C1 * C2 = N / (C1 / C2) + let _: W = W:: {}; +} + +fn test_non_constant_folding() { + // N + M - M = N + let _: W = W:: {}; + + // N - M + M = N + let _: W = W:: {}; + + // N * M / M = N + let _: W = W:: {}; + + // N / M * M = N + let _: W = W:: {}; +} From 5739904f8d9e6c00d9e140cd4926b4d149412476 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:15:49 +0100 Subject: [PATCH 11/21] feat: remove unnecessary copying of vector size during reversal (#5852) # Description ## Problem\* Resolves ## Summary\* In order to calculate `vector_size - 1 - iterator`, rather than initialising a register to `vector_size` and then decrementing it on each loop iteration, we're currently recalculating it from scratch on each iteration. I've modified this so that we can avoid these extra ops ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../src/brillig/brillig_ir/codegen_memory.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs index d20f736ee6d..ec3b080895b 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_memory.rs @@ -263,20 +263,14 @@ impl BrilligContext { let index_at_end_of_array = self.allocate_register(); let end_value_register = self.allocate_register(); - self.codegen_loop(iteration_count, |ctx, iterator_register| { - // Load both values - ctx.codegen_array_get(vector.pointer, iterator_register, start_value_register); + self.mov_instruction(index_at_end_of_array, vector.size); + self.codegen_loop(iteration_count, |ctx, iterator_register| { // The index at the end of array is size - 1 - iterator - ctx.mov_instruction(index_at_end_of_array, vector.size); ctx.codegen_usize_op_in_place(index_at_end_of_array, BrilligBinaryOp::Sub, 1); - ctx.memory_op_instruction( - index_at_end_of_array, - iterator_register.address, - index_at_end_of_array, - BrilligBinaryOp::Sub, - ); + // Load both values + ctx.codegen_array_get(vector.pointer, iterator_register, start_value_register); ctx.codegen_array_get( vector.pointer, SingleAddrVariable::new_usize(index_at_end_of_array), From 130b7b6871ad165a75df5fa5760c94a7402521f4 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Wed, 28 Aug 2024 16:17:17 -0400 Subject: [PATCH 12/21] fix(sha256): Fix upper bound when building msg block and delay final block compression under certain cases (#5838) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description ## Problem\* Resolves #5836 ## Summary\* We accept a start index based upon the current block when parsing a message. We should accurately base the upper bound to be based upon this start index. We also have special handling for building a message block but not compressing it when the message is less than the block size. We need to also do this handling for the last message block when we have a message that is larger than the block size. ## Additional Context ~~sha256_var is currently getting warnings from the under constrained check. It looks to only be happening on the new regression test added as part of this PR that uses a larger message. The old sha256 tests do not look to trigger these warnings which is strange. I am a bit unsure why I am getting these warnings as msg block and msg block pointer are being verified on each iteration of the loop.~~ Screenshot 2024-08-27 at 12 04 58 PM EDIT: This was only happening as my test was hashing constant values, thus it was a dumb circuit. e.g the following: ```rust fn main(result: pub [u8; 32]) { let headers = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116]; let hash = std::hash::sha256_var(headers, headers.len() as u64); assert_eq(hash, result); } ``` The message needs to come from the inputs and the under-constrained warnings go away. ## Documentation\* Check one: - [X] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [X] I have tested the changes locally. - [X] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../src/ssa/acir_gen/acir_ir/acir_variable.rs | 1 - noir_stdlib/src/hash/sha256.nr | 150 +++++++++--------- .../Nargo.toml | 2 +- .../sha256_regression/Prover.toml | 9 ++ .../sha256_regression/src/main.nr | 26 +++ .../Prover.toml | 2 - .../src/main.nr | 9 -- 7 files changed, 112 insertions(+), 87 deletions(-) rename test_programs/execution_success/{sha256_var_witness_const_regression => sha256_regression}/Nargo.toml (64%) create mode 100644 test_programs/execution_success/sha256_regression/Prover.toml create mode 100644 test_programs/execution_success/sha256_regression/src/main.nr delete mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml delete mode 100644 test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 6d17484ee95..317cf43669c 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -1434,7 +1434,6 @@ impl AcirContext { name, BlackBoxFunc::MultiScalarMul | BlackBoxFunc::Keccakf1600 - | BlackBoxFunc::Sha256Compression | BlackBoxFunc::Blake2s | BlackBoxFunc::Blake3 | BlackBoxFunc::AND diff --git a/noir_stdlib/src/hash/sha256.nr b/noir_stdlib/src/hash/sha256.nr index 352df656068..d0e3d5e88c5 100644 --- a/noir_stdlib/src/hash/sha256.nr +++ b/noir_stdlib/src/hash/sha256.nr @@ -1,3 +1,5 @@ +use crate::runtime::is_unconstrained; + // Implementation of SHA-256 mapping a byte array of variable length to // 32 bytes. @@ -32,21 +34,17 @@ fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { msg32 } -unconstrained fn build_msg_block_iter( - msg: [u8; N], - message_size: u64, - mut msg_block: [u8; 64], - msg_start: u32 -) -> ([u8; 64], u64) { +unconstrained fn build_msg_block_iter(msg: [u8; N], message_size: u64, msg_start: u32) -> ([u8; 64], u64) { + let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE]; let mut msg_byte_ptr: u64 = 0; // Message byte pointer - for k in msg_start..N { + let mut msg_end = msg_start + BLOCK_SIZE; + if msg_end > N { + msg_end = N; + } + for k in msg_start..msg_end { if k as u64 < message_size { msg_block[msg_byte_ptr] = msg[k]; msg_byte_ptr = msg_byte_ptr + 1; - - if msg_byte_ptr == 64 { - msg_byte_ptr = 0; - } } } (msg_block, msg_byte_ptr) @@ -60,27 +58,32 @@ fn verify_msg_block( msg_start: u32 ) -> u64 { let mut msg_byte_ptr: u64 = 0; // Message byte pointer - for k in msg_start..N { + let mut msg_end = msg_start + BLOCK_SIZE; + let mut extra_bytes = 0; + if msg_end > N { + msg_end = N; + extra_bytes = msg_end - N; + } + + for k in msg_start..msg_end { if k as u64 < message_size { - assert_eq(msg_block[msg_byte_ptr], msg[k]); msg_byte_ptr = msg_byte_ptr + 1; - if msg_byte_ptr == 64 { - // Enough to hash block - msg_byte_ptr = 0; - } + } + } + + for i in 0..BLOCK_SIZE { + if i as u64 >= msg_byte_ptr { + assert_eq(msg_block[i], 0); } else { - // Need to assert over the msg block in the else case as well - if N < 64 { - assert_eq(msg_block[msg_byte_ptr], 0); - } else { - assert_eq(msg_block[msg_byte_ptr], msg[k]); - } + assert_eq(msg_block[i], msg[msg_start + i - extra_bytes]); } } + msg_byte_ptr } global BLOCK_SIZE = 64; +global ZERO = 0; // Variable size SHA-256 hash pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { @@ -89,38 +92,55 @@ pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value let mut msg_byte_ptr = 0; // Pointer into msg_block - if num_blocks == 0 { - unsafe { - let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, 0); - msg_block = new_msg_block; + for i in 0..num_blocks { + let (new_msg_block, new_msg_byte_ptr) = unsafe { + build_msg_block_iter(msg, message_size, BLOCK_SIZE * i) + }; + msg_block = new_msg_block; + + if !is_unconstrained() { + // Verify the block we are compressing was appropriately constructed + msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i); + } else { msg_byte_ptr = new_msg_byte_ptr; } - if !crate::runtime::is_unconstrained() { - msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, 0); - } + // Compress the block + h = sha256_compression(msg_u8_to_u32(msg_block), h); } - for i in 0..num_blocks { - unsafe { - let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, BLOCK_SIZE * i); - msg_block = new_msg_block; + let modulo = N % BLOCK_SIZE; + // Handle setup of the final msg block. + // This case is only hit if the msg is less than the block size, + // or our message cannot be evenly split into blocks. + if modulo != 0 { + let (new_msg_block, new_msg_byte_ptr) = unsafe { + build_msg_block_iter(msg, message_size, BLOCK_SIZE * num_blocks) + }; + msg_block = new_msg_block; + + if !is_unconstrained() { + msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * num_blocks); + } else { msg_byte_ptr = new_msg_byte_ptr; } - if !crate::runtime::is_unconstrained() { - // Verify the block we are compressing was appropriately constructed - msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i); - } + } - // Hash the block - h = sha256_compression(msg_u8_to_u32(msg_block), h); + if msg_byte_ptr == BLOCK_SIZE as u64 { + msg_byte_ptr = 0; } - let last_block = msg_block; + // This variable is used to get around the compiler under-constrained check giving a warning. + // We want to check against a constant zero, but if it does not come from the circuit inputs + // or return values the compiler check will issue a warning. + let zero = msg_block[0] - msg_block[0]; + // Pad the rest such that we have a [u32; 2] block at the end representing the length // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). msg_block[msg_byte_ptr] = 1 << 7; + let last_block = msg_block; msg_byte_ptr = msg_byte_ptr + 1; + unsafe { let (new_msg_block, new_msg_byte_ptr) = pad_msg_block(msg_block, msg_byte_ptr); msg_block = new_msg_block; @@ -131,18 +151,15 @@ pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { if !crate::runtime::is_unconstrained() { for i in 0..64 { - if i as u64 < msg_byte_ptr - 1 { - assert_eq(msg_block[i], last_block[i]); - } + assert_eq(msg_block[i], last_block[i]); } - assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7); // If i >= 57, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. // Not enough bits (64) to store length. Fill up with zeros. for _i in 57..64 { if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 { - assert_eq(msg_block[msg_byte_ptr], 0); + assert_eq(msg_block[msg_byte_ptr], zero); msg_byte_ptr += 1; } } @@ -154,34 +171,23 @@ pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { msg_byte_ptr = 0; } - unsafe { - msg_block = attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size); - } + msg_block = unsafe { + attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size) + }; if !crate::runtime::is_unconstrained() { - if msg_byte_ptr != 0 { - for i in 0..64 { - if i as u64 < msg_byte_ptr - 1 { - assert_eq(msg_block[i], last_block[i]); - } + for i in 0..56 { + if i < msg_byte_ptr { + assert_eq(msg_block[i], last_block[i]); + } else { + assert_eq(msg_block[i], zero); } - assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7); } let len = 8 * message_size; - let len_bytes = (len as Field).to_le_bytes(8); - // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56). - for _ in 0..64 { - if msg_byte_ptr < 56 { - assert_eq(msg_block[msg_byte_ptr], 0); - msg_byte_ptr = msg_byte_ptr + 1; - } - } - - let mut block_idx = 0; + let len_bytes = (len as Field).to_be_bytes(8); for i in 56..64 { - assert_eq(msg_block[63 - block_idx], len_bytes[i - 56]); - block_idx = block_idx + 1; + assert_eq(msg_block[i], len_bytes[i - 56]); } } @@ -205,13 +211,9 @@ unconstrained fn pad_msg_block(mut msg_block: [u8; 64], mut msg_byte_ptr: u64) - (msg_block, msg_byte_ptr) } -unconstrained fn attach_len_to_msg_block( - mut msg_block: [u8; 64], - mut msg_byte_ptr: u64, - message_size: u64 -) -> [u8; 64] { +unconstrained fn attach_len_to_msg_block(mut msg_block: [u8; 64], mut msg_byte_ptr: u64, message_size: u64) -> [u8; 64] { let len = 8 * message_size; - let len_bytes = (len as Field).to_le_bytes(8); + let len_bytes = (len as Field).to_be_bytes(8); for _i in 0..64 { // In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56). if msg_byte_ptr < 56 { @@ -219,7 +221,7 @@ unconstrained fn attach_len_to_msg_block( msg_byte_ptr = msg_byte_ptr + 1; } else if msg_byte_ptr < 64 { for j in 0..8 { - msg_block[63 - j] = len_bytes[j]; + msg_block[msg_byte_ptr + j] = len_bytes[j]; } msg_byte_ptr += 8; } diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml b/test_programs/execution_success/sha256_regression/Nargo.toml similarity index 64% rename from test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml rename to test_programs/execution_success/sha256_regression/Nargo.toml index e8f3e6bbe64..ce98d000bcb 100644 --- a/test_programs/execution_success/sha256_var_witness_const_regression/Nargo.toml +++ b/test_programs/execution_success/sha256_regression/Nargo.toml @@ -1,5 +1,5 @@ [package] -name = "sha256_var_witness_const_regression" +name = "sha256_regression" type = "bin" authors = [""] compiler_version = ">=0.33.0" diff --git a/test_programs/execution_success/sha256_regression/Prover.toml b/test_programs/execution_success/sha256_regression/Prover.toml new file mode 100644 index 00000000000..ba0aadd1b75 --- /dev/null +++ b/test_programs/execution_success/sha256_regression/Prover.toml @@ -0,0 +1,9 @@ +msg_just_over_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116] +msg_multiple_of_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99] +msg_just_under_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59] +msg_big_not_block_multiple = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99, 116, 58, 77, 101, 115, 115, 97, 103, 101, 45, 73, 100, 58, 68, 97, 116, 101, 58, 116, 111, 59, 32, 98, 61] +# Results matched against ethers library +result_just_over_block = [91, 122, 146, 93, 52, 109, 133, 148, 171, 61, 156, 70, 189, 238, 153, 7, 222, 184, 94, 24, 65, 114, 192, 244, 207, 199, 87, 232, 192, 224, 171, 207] +result_multiple_of_block = [116, 90, 151, 31, 78, 22, 138, 180, 211, 189, 69, 76, 227, 200, 155, 29, 59, 123, 154, 60, 47, 153, 203, 129, 157, 251, 48, 2, 79, 11, 65, 47] +result_just_under_block = [143, 140, 76, 173, 222, 123, 102, 68, 70, 149, 207, 43, 39, 61, 34, 79, 216, 252, 213, 165, 74, 16, 110, 74, 29, 64, 138, 167, 30, 1, 9, 119] +result_big = [112, 144, 73, 182, 208, 98, 9, 238, 54, 229, 61, 145, 222, 17, 72, 62, 148, 222, 186, 55, 192, 82, 220, 35, 66, 47, 193, 200, 22, 38, 26, 186] diff --git a/test_programs/execution_success/sha256_regression/src/main.nr b/test_programs/execution_success/sha256_regression/src/main.nr new file mode 100644 index 00000000000..855931b4300 --- /dev/null +++ b/test_programs/execution_success/sha256_regression/src/main.nr @@ -0,0 +1,26 @@ +fn main( + msg_just_over_block: [u8; 68], + result_just_over_block: pub [u8; 32], + msg_multiple_of_block: [u8; 448], + result_multiple_of_block: pub [u8; 32], + // We want to make sure we are testing a message with a size >= 57 but < 64 + msg_just_under_block: [u8; 60], + result_just_under_block: pub [u8; 32], + msg_big_not_block_multiple: [u8; 472], + result_big: pub [u8; 32] +) { + let hash = std::hash::sha256_var(msg_just_over_block, msg_just_over_block.len() as u64); + assert_eq(hash, result_just_over_block); + + let hash = std::hash::sha256_var(msg_multiple_of_block, msg_multiple_of_block.len() as u64); + assert_eq(hash, result_multiple_of_block); + + let hash = std::hash::sha256_var(msg_just_under_block, msg_just_under_block.len() as u64); + assert_eq(hash, result_just_under_block); + + let hash = std::hash::sha256_var( + msg_big_not_block_multiple, + msg_big_not_block_multiple.len() as u64 + ); + assert_eq(hash, result_big); +} diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml b/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml deleted file mode 100644 index 7b91051c1a0..00000000000 --- a/test_programs/execution_success/sha256_var_witness_const_regression/Prover.toml +++ /dev/null @@ -1,2 +0,0 @@ -input = [0, 0] -toggle = false \ No newline at end of file diff --git a/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr b/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr deleted file mode 100644 index 97c4435d41d..00000000000 --- a/test_programs/execution_success/sha256_var_witness_const_regression/src/main.nr +++ /dev/null @@ -1,9 +0,0 @@ -fn main(input: [u8; 2], toggle: bool) { - let size: Field = 1 + toggle as Field; - assert(!toggle); - - let variable_sha = std::sha256::sha256_var(input, size as u64); - let constant_sha = std::sha256::sha256_var(input, 1); - - assert_eq(variable_sha, constant_sha); -} From 39b30ba2e9f13d8d99bfb1833e14e294f80773e5 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Wed, 28 Aug 2024 18:09:07 -0300 Subject: [PATCH 13/21] feat: add `FunctionDef::body` (#5825) # Description ## Problem Part of #5668 ## Summary Also allows `quote { ... }.as_expr()` to work with statements and L-values. Then adds an example that injects a `_context: Context` parameter to functions annotated with `#[aztec]`, and to calls inside those functions (any call for now). Also adds an incomplete `Expr::map` to help doing this. Oh, and allows unquoting `Value::Expr`, which wasn't possible before. ## Additional Context A lot is remaining, but I thought this was a good cutting point. Next I might work on #5828 because debugging comptime code without it is tricky. I didn't add docs for `Expr::map` yet as it's still not fully functional. I'll add them once it works for in all cases. ## Documentation Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [x] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- aztec_macros/src/utils/parse_utils.rs | 12 +- compiler/noirc_frontend/src/ast/expression.rs | 8 +- compiler/noirc_frontend/src/ast/mod.rs | 7 +- compiler/noirc_frontend/src/ast/statement.rs | 48 ++- compiler/noirc_frontend/src/debug/mod.rs | 3 + .../src/elaborator/expressions.rs | 5 + .../src/elaborator/statements.rs | 9 + .../noirc_frontend/src/elaborator/types.rs | 4 + .../src/hir/comptime/interpreter/builtin.rs | 237 +++++++++----- .../interpreter/builtin/builtin_helpers.rs | 33 +- .../noirc_frontend/src/hir/comptime/value.rs | 264 ++++++++++++++- compiler/noirc_frontend/src/lexer/token.rs | 39 ++- compiler/noirc_frontend/src/node_interner.rs | 60 ++++ compiler/noirc_frontend/src/parser/mod.rs | 2 +- compiler/noirc_frontend/src/parser/parser.rs | 27 +- .../src/parser/parser/primitives.rs | 7 + .../noirc_frontend/src/parser/parser/types.rs | 10 + docs/docs/noir/standard_library/meta/expr.md | 18 ++ .../standard_library/meta/function_def.md | 10 +- docs/docs/noir/standard_library/meta/op.md | 12 + noir_stdlib/src/meta/expr.nr | 306 ++++++++++++++++++ noir_stdlib/src/meta/function_def.nr | 7 +- noir_stdlib/src/meta/op.nr | 56 ++++ .../comptime_function_definition/src/main.nr | 2 +- .../inject_context_attribute/Nargo.toml | 6 + .../inject_context_attribute/src/main.nr | 53 +++ .../comptime_expr/src/main.nr | 256 +++++++++++++++ tooling/lsp/src/requests/completion.rs | 12 +- .../lsp/src/requests/completion/builtins.rs | 2 +- tooling/lsp/src/requests/inlay_hint.rs | 9 +- .../src/requests/signature_help/traversal.rs | 7 +- tooling/nargo_fmt/src/rewrite/expr.rs | 3 + tooling/nargo_fmt/src/rewrite/typ.rs | 2 +- tooling/nargo_fmt/src/visitor/stmt.rs | 3 + 34 files changed, 1423 insertions(+), 116 deletions(-) create mode 100644 test_programs/compile_success_empty/inject_context_attribute/Nargo.toml create mode 100644 test_programs/compile_success_empty/inject_context_attribute/src/main.nr diff --git a/aztec_macros/src/utils/parse_utils.rs b/aztec_macros/src/utils/parse_utils.rs index 4c6cbb10d9f..f2998fbaafc 100644 --- a/aztec_macros/src/utils/parse_utils.rs +++ b/aztec_macros/src/utils/parse_utils.rs @@ -218,7 +218,10 @@ fn empty_statement(statement: &mut Statement) { StatementKind::For(for_loop_statement) => empty_for_loop_statement(for_loop_statement), StatementKind::Comptime(statement) => empty_statement(statement), StatementKind::Semi(expression) => empty_expression(expression), - StatementKind::Break | StatementKind::Continue | StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -271,12 +274,15 @@ fn empty_expression(expression: &mut Expression) { ExpressionKind::Unsafe(block_expression, _span) => { empty_block_expression(block_expression); } - ExpressionKind::Quote(..) | ExpressionKind::Resolved(_) | ExpressionKind::Error => (), ExpressionKind::AsTraitPath(path) => { empty_unresolved_type(&mut path.typ); empty_path(&mut path.trait_path); empty_ident(&mut path.impl_item); } + ExpressionKind::Quote(..) + | ExpressionKind::Resolved(_) + | ExpressionKind::Interned(_) + | ExpressionKind::Error => (), } } @@ -353,6 +359,7 @@ fn empty_unresolved_type(unresolved_type: &mut UnresolvedType) { | UnresolvedTypeData::Unit | UnresolvedTypeData::Quoted(_) | UnresolvedTypeData::Resolved(_) + | UnresolvedTypeData::Interned(_) | UnresolvedTypeData::Unspecified | UnresolvedTypeData::Error => (), } @@ -531,6 +538,7 @@ fn empty_lvalue(lvalue: &mut LValue) { empty_expression(index); } LValue::Dereference(lvalue, _) => empty_lvalue(lvalue), + LValue::Interned(..) => (), } } diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index dc07f55ee33..f242180134d 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -7,7 +7,7 @@ use crate::ast::{ }; use crate::hir::def_collector::errors::DefCollectorErrorKind; use crate::macros_api::StructId; -use crate::node_interner::{ExprId, QuotedTypeId}; +use crate::node_interner::{ExprId, InternedExpressionKind, QuotedTypeId}; use crate::token::{Attributes, FunctionAttribute, Token, Tokens}; use crate::{Kind, Type}; use acvm::{acir::AcirField, FieldElement}; @@ -43,6 +43,11 @@ pub enum ExpressionKind { // code. It is used to translate function values back into the AST while // guaranteeing they have the same instantiated type and definition id without resolving again. Resolved(ExprId), + + // This is an interned ExpressionKind during comptime code. + // The actual ExpressionKind can be retrieved with a NodeInterner. + Interned(InternedExpressionKind), + Error, } @@ -603,6 +608,7 @@ impl Display for ExpressionKind { Unsafe(block, _) => write!(f, "unsafe {block}"), Error => write!(f, "Error"), Resolved(_) => write!(f, "?Resolved"), + Interned(_) => write!(f, "?Interned"), Unquote(expr) => write!(f, "$({expr})"), Quote(tokens) => { let tokens = vecmap(&tokens.0, ToString::to_string); diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index 6f6d5cbccdf..b10e58aac0c 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -22,7 +22,7 @@ pub use traits::*; pub use type_alias::*; use crate::{ - node_interner::QuotedTypeId, + node_interner::{InternedUnresolvedTypeData, QuotedTypeId}, parser::{ParserError, ParserErrorReason}, token::IntType, BinaryTypeOperator, @@ -141,6 +141,10 @@ pub enum UnresolvedTypeData { /// as a result of being spliced into a macro's token stream input. Resolved(QuotedTypeId), + // This is an interned UnresolvedTypeData during comptime code. + // The actual UnresolvedTypeData can be retrieved with a NodeInterner. + Interned(InternedUnresolvedTypeData), + Unspecified, // This is for when the user declares a variable without specifying it's type Error, } @@ -297,6 +301,7 @@ impl std::fmt::Display for UnresolvedTypeData { Unspecified => write!(f, "unspecified"), Parenthesized(typ) => write!(f, "({typ})"), Resolved(_) => write!(f, "(resolved type)"), + Interned(_) => write!(f, "?Interned"), AsTraitPath(path) => write!(f, "{path}"), } } diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index edccf545a02..c88fcba749b 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -13,6 +13,7 @@ use super::{ use crate::elaborator::types::SELF_TYPE_NAME; use crate::lexer::token::SpannedToken; use crate::macros_api::{SecondaryAttribute, UnresolvedTypeData}; +use crate::node_interner::{InternedExpressionKind, InternedStatementKind}; use crate::parser::{ParserError, ParserErrorReason}; use crate::token::Token; @@ -45,6 +46,9 @@ pub enum StatementKind { Comptime(Box), // This is an expression with a trailing semi-colon Semi(Expression), + // This is an interned StatementKind during comptime code. + // The actual StatementKind can be retrieved with a NodeInterner. + Interned(InternedStatementKind), // This statement is the result of a recovered parse error. // To avoid issuing multiple errors in later steps, it should // be skipped in any future analysis if possible. @@ -97,6 +101,9 @@ impl StatementKind { // A semicolon on a for loop is optional and does nothing StatementKind::For(_) => self, + // No semicolon needed for a resolved statement + StatementKind::Interned(_) => self, + StatementKind::Expression(expr) => { match (&expr.kind, semi, last_statement_in_block) { // Semicolons are optional for these expressions @@ -534,6 +541,7 @@ pub enum LValue { MemberAccess { object: Box, field_name: Ident, span: Span }, Index { array: Box, index: Expression, span: Span }, Dereference(Box, Span), + Interned(InternedExpressionKind, Span), } #[derive(Debug, PartialEq, Eq, Clone)] @@ -591,7 +599,7 @@ impl Recoverable for Pattern { } impl LValue { - fn as_expression(&self) -> Expression { + pub fn as_expression(&self) -> Expression { let kind = match self { LValue::Ident(ident) => ExpressionKind::Variable(Path::from_ident(ident.clone())), LValue::MemberAccess { object, field_name, span: _ } => { @@ -612,17 +620,53 @@ impl LValue { rhs: lvalue.as_expression(), })) } + LValue::Interned(id, _) => ExpressionKind::Interned(*id), }; let span = self.span(); Expression::new(kind, span) } + pub fn from_expression(expr: Expression) -> LValue { + LValue::from_expression_kind(expr.kind, expr.span) + } + + pub fn from_expression_kind(expr: ExpressionKind, span: Span) -> LValue { + match expr { + ExpressionKind::Variable(path) => LValue::Ident(path.as_ident().unwrap().clone()), + ExpressionKind::MemberAccess(member_access) => LValue::MemberAccess { + object: Box::new(LValue::from_expression(member_access.lhs)), + field_name: member_access.rhs, + span, + }, + ExpressionKind::Index(index) => LValue::Index { + array: Box::new(LValue::from_expression(index.collection)), + index: index.index, + span, + }, + ExpressionKind::Prefix(prefix) => { + if matches!( + prefix.operator, + crate::ast::UnaryOp::Dereference { implicitly_added: false } + ) { + LValue::Dereference(Box::new(LValue::from_expression(prefix.rhs)), span) + } else { + panic!("Called LValue::from_expression with an invalid prefix operator") + } + } + ExpressionKind::Interned(id) => LValue::Interned(id, span), + _ => { + panic!("Called LValue::from_expression with an invalid expression") + } + } + } + pub fn span(&self) -> Span { match self { LValue::Ident(ident) => ident.span(), LValue::MemberAccess { span, .. } | LValue::Index { span, .. } | LValue::Dereference(_, span) => *span, + LValue::Interned(_, span) => *span, } } } @@ -777,6 +821,7 @@ impl Display for StatementKind { StatementKind::Continue => write!(f, "continue"), StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind), StatementKind::Semi(semi) => write!(f, "{semi};"), + StatementKind::Interned(_) => write!(f, "(resolved);"), StatementKind::Error => write!(f, "Error"), } } @@ -809,6 +854,7 @@ impl Display for LValue { } LValue::Index { array, index, span: _ } => write!(f, "{array}[{index}]"), LValue::Dereference(lvalue, _span) => write!(f, "*{lvalue}"), + LValue::Interned(_, _) => write!(f, "?Interned"), } } } diff --git a/compiler/noirc_frontend/src/debug/mod.rs b/compiler/noirc_frontend/src/debug/mod.rs index 935acc4e6d0..fe027969473 100644 --- a/compiler/noirc_frontend/src/debug/mod.rs +++ b/compiler/noirc_frontend/src/debug/mod.rs @@ -322,6 +322,9 @@ impl DebugInstrumenter { ast::LValue::Dereference(_ref, _span) => { unimplemented![] } + ast::LValue::Interned(..) => { + unimplemented![] + } } } build_assign_member_stmt( diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index cf0b4f4071a..beede7a3a30 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -62,6 +62,11 @@ impl<'context> Elaborator<'context> { self.elaborate_unsafe_block(block_expression) } ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)), + ExpressionKind::Interned(id) => { + let expr_kind = self.interner.get_expression_kind(id); + let expr = Expression::new(expr_kind.clone(), expr.span); + return self.elaborate_expression(expr); + } ExpressionKind::Error => (HirExpression::Error, Type::Error), ExpressionKind::Unquote(_) => { self.push_err(ResolverError::UnquoteUsedOutsideQuote { span: expr.span }); diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index 0bb8641b6b3..dcbdf89391e 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -39,6 +39,11 @@ impl<'context> Elaborator<'context> { let (expr, _typ) = self.elaborate_expression(expr); (HirStatement::Semi(expr), Type::Unit) } + StatementKind::Interned(id) => { + let kind = self.interner.get_statement_kind(id); + let statement = Statement { kind: kind.clone(), span: statement.span }; + self.elaborate_statement_value(statement) + } StatementKind::Error => (HirStatement::Error, Type::Error), } } @@ -357,6 +362,10 @@ impl<'context> Elaborator<'context> { let lvalue = HirLValue::Dereference { lvalue, element_type, location }; (lvalue, typ, true) } + LValue::Interned(id, span) => { + let lvalue = self.interner.get_lvalue(id, span).clone(); + self.elaborate_lvalue(lvalue, assign_span) + } } } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 8e4c9aa4af1..e41234a5be5 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -158,6 +158,10 @@ impl<'context> Elaborator<'context> { Parenthesized(typ) => self.resolve_type_inner(*typ, kind), Resolved(id) => self.interner.get_quoted_type(id).clone(), AsTraitPath(path) => self.resolve_as_trait_path(*path), + Interned(id) => { + let typ = self.interner.get_unresolved_type_data(id).clone(); + return self.resolve_type_inner(UnresolvedType { typ, span }, kind); + } }; let location = Location::new(named_path_span.unwrap_or(typ.span), self.file); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 8aa8e92408f..4b68f82a275 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -9,9 +9,9 @@ use builtin_helpers::{ check_one_argument, check_three_arguments, check_two_arguments, get_expr, get_function_def, get_module, get_quoted, get_slice, get_struct, get_trait_constraint, get_trait_def, get_trait_impl, get_tuple, get_type, get_u32, get_unresolved_type, hir_pattern_to_tokens, - mutate_func_meta_type, parse, parse_tokens, replace_func_meta_parameters, - replace_func_meta_return_type, + mutate_func_meta_type, parse, replace_func_meta_parameters, replace_func_meta_return_type, }; +use chumsky::{prelude::choice, Parser}; use im::Vector; use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; @@ -19,19 +19,16 @@ use rustc_hash::FxHashMap as HashMap; use crate::{ ast::{ - ArrayLiteral, Expression, ExpressionKind, FunctionKind, FunctionReturnType, IntegerBitSize, - Literal, StatementKind, UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, - }, - hir::comptime::{ - errors::IResult, - value::{add_token_spans, ExprValue}, - InterpreterError, Value, + ArrayLiteral, BlockExpression, Expression, ExpressionKind, FunctionKind, + FunctionReturnType, IntegerBitSize, LValue, Literal, Statement, StatementKind, UnaryOp, + UnresolvedType, UnresolvedTypeData, Visibility, }, + hir::comptime::{errors::IResult, value::ExprValue, InterpreterError, Value}, hir_def::function::FunctionBody, - macros_api::{ModuleDefId, NodeInterner, Signedness}, + macros_api::{HirExpression, HirLiteral, ModuleDefId, NodeInterner, Signedness}, node_interner::{DefinitionKind, TraitImplKind}, parser::{self}, - token::{SpannedToken, Token}, + token::Token, QuotedType, Shared, Type, }; @@ -53,33 +50,40 @@ impl<'local, 'context> Interpreter<'local, 'context> { "array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location), "array_len" => array_len(interner, arguments, location), "as_slice" => as_slice(interner, arguments, location), - "expr_as_array" => expr_as_array(arguments, return_type, location), - "expr_as_assign" => expr_as_assign(arguments, return_type, location), - "expr_as_binary_op" => expr_as_binary_op(arguments, return_type, location), - "expr_as_block" => expr_as_block(arguments, return_type, location), - "expr_as_bool" => expr_as_bool(arguments, return_type, location), - "expr_as_cast" => expr_as_cast(arguments, return_type, location), - "expr_as_comptime" => expr_as_comptime(arguments, return_type, location), - "expr_as_function_call" => expr_as_function_call(arguments, return_type, location), - "expr_as_if" => expr_as_if(arguments, return_type, location), - "expr_as_index" => expr_as_index(arguments, return_type, location), - "expr_as_integer" => expr_as_integer(arguments, return_type, location), - "expr_as_member_access" => expr_as_member_access(arguments, return_type, location), - "expr_as_method_call" => expr_as_method_call(arguments, return_type, location), + "expr_as_array" => expr_as_array(interner, arguments, return_type, location), + "expr_as_assign" => expr_as_assign(interner, arguments, return_type, location), + "expr_as_binary_op" => expr_as_binary_op(interner, arguments, return_type, location), + "expr_as_block" => expr_as_block(interner, arguments, return_type, location), + "expr_as_bool" => expr_as_bool(interner, arguments, return_type, location), + "expr_as_cast" => expr_as_cast(interner, arguments, return_type, location), + "expr_as_comptime" => expr_as_comptime(interner, arguments, return_type, location), + "expr_as_function_call" => { + expr_as_function_call(interner, arguments, return_type, location) + } + "expr_as_if" => expr_as_if(interner, arguments, return_type, location), + "expr_as_index" => expr_as_index(interner, arguments, return_type, location), + "expr_as_integer" => expr_as_integer(interner, arguments, return_type, location), + "expr_as_member_access" => { + expr_as_member_access(interner, arguments, return_type, location) + } + "expr_as_method_call" => { + expr_as_method_call(interner, arguments, return_type, location) + } "expr_as_repeated_element_array" => { - expr_as_repeated_element_array(arguments, return_type, location) + expr_as_repeated_element_array(interner, arguments, return_type, location) } "expr_as_repeated_element_slice" => { - expr_as_repeated_element_slice(arguments, return_type, location) + expr_as_repeated_element_slice(interner, arguments, return_type, location) } - "expr_as_slice" => expr_as_slice(arguments, return_type, location), - "expr_as_tuple" => expr_as_tuple(arguments, return_type, location), - "expr_as_unary_op" => expr_as_unary_op(arguments, return_type, location), - "expr_as_unsafe" => expr_as_unsafe(arguments, return_type, location), - "expr_has_semicolon" => expr_has_semicolon(arguments, location), - "expr_is_break" => expr_is_break(arguments, location), - "expr_is_continue" => expr_is_continue(arguments, location), + "expr_as_slice" => expr_as_slice(interner, arguments, return_type, location), + "expr_as_tuple" => expr_as_tuple(interner, arguments, return_type, location), + "expr_as_unary_op" => expr_as_unary_op(interner, arguments, return_type, location), + "expr_as_unsafe" => expr_as_unsafe(interner, arguments, return_type, location), + "expr_has_semicolon" => expr_has_semicolon(interner, arguments, location), + "expr_is_break" => expr_is_break(interner, arguments, location), + "expr_is_continue" => expr_is_continue(interner, arguments, location), "is_unconstrained" => Ok(Value::Bool(true)), + "function_def_body" => function_def_body(interner, arguments, location), "function_def_name" => function_def_name(interner, arguments, location), "function_def_parameters" => function_def_parameters(interner, arguments, location), "function_def_return_type" => function_def_return_type(interner, arguments, location), @@ -135,7 +139,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "type_is_bool" => type_is_bool(arguments, location), "type_is_field" => type_is_field(arguments, location), "type_of" => type_of(arguments, location), - "unresolved_type_is_field" => unresolved_type_is_field(arguments, location), + "unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location), "zeroed" => zeroed(return_type), _ => { let item = format!("Comptime evaluation for builtin function {name}"); @@ -361,10 +365,14 @@ fn quoted_as_expr( ) -> IResult { let argument = check_one_argument(arguments, location)?; - let expr = parse(argument, parser::expression(), "an expression").ok(); - let value = expr.map(|expr| Value::expression(expr.kind)); + let expr_parser = parser::expression().map(|expr| Value::expression(expr.kind)); + let statement_parser = parser::fresh_statement().map(Value::statement); + let lvalue_parser = parser::lvalue(parser::expression()).map(Value::lvalue); + let parser = choice((expr_parser, statement_parser, lvalue_parser)); - option(return_type, value) + let expr = parse(argument, parser, "an expression").ok(); + + option(return_type, expr) } // fn as_module(quoted: Quoted) -> Option @@ -711,11 +719,12 @@ fn trait_impl_trait_generic_args( // fn is_field(self) -> bool fn unresolved_type_is_field( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let typ = get_unresolved_type(self_argument)?; + let typ = get_unresolved_type(interner, self_argument)?; Ok(Value::Bool(matches!(typ, UnresolvedTypeData::FieldElement))) } @@ -802,11 +811,12 @@ fn zeroed(return_type: Type) -> IResult { // fn as_array(self) -> Option<[Expr]> fn expr_as_array( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Array( ArrayLiteral::Standard(exprs), ))) = expr @@ -822,11 +832,12 @@ fn expr_as_array( // fn as_assign(self) -> Option<(Expr, Expr)> fn expr_as_assign( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Statement(StatementKind::Assign(assign)) = expr { let lhs = Value::lvalue(assign.lvalue); let rhs = Value::expression(assign.expression.kind); @@ -839,11 +850,12 @@ fn expr_as_assign( // fn as_binary_op(self) -> Option<(Expr, BinaryOp, Expr)> fn expr_as_binary_op( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { + expr_as(interner, arguments, return_type.clone(), location, |expr| { if let ExprValue::Expression(ExpressionKind::Infix(infix_expr)) = expr { let option_type = extract_option_generic_type(return_type); let Type::Tuple(mut tuple_types) = option_type else { @@ -872,11 +884,12 @@ fn expr_as_binary_op( // fn as_block(self) -> Option<[Expr]> fn expr_as_block( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Block(block_expr)) = expr { Some(block_expression_to_value(block_expr)) } else { @@ -887,11 +900,12 @@ fn expr_as_block( // fn as_bool(self) -> Option fn expr_as_bool( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Bool(bool))) = expr { Some(Value::Bool(bool)) } else { @@ -902,11 +916,12 @@ fn expr_as_bool( // fn as_cast(self) -> Option<(Expr, UnresolvedType)> fn expr_as_cast( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Cast(cast)) = expr { let lhs = Value::expression(cast.lhs.kind); let typ = Value::UnresolvedType(cast.r#type.typ); @@ -919,13 +934,14 @@ fn expr_as_cast( // fn as_comptime(self) -> Option<[Expr]> fn expr_as_comptime( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { use ExpressionKind::Block; - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Comptime(block_expr, _)) = expr { Some(block_expression_to_value(block_expr)) } else if let ExprValue::Statement(StatementKind::Comptime(statement)) = expr { @@ -951,11 +967,12 @@ fn expr_as_comptime( // fn as_function_call(self) -> Option<(Expr, [Expr])> fn expr_as_function_call( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Call(call_expression)) = expr { let function = Value::expression(call_expression.func.kind); let arguments = call_expression.arguments.into_iter(); @@ -971,11 +988,12 @@ fn expr_as_function_call( // fn as_if(self) -> Option<(Expr, Expr, Option)> fn expr_as_if( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { + expr_as(interner, arguments, return_type.clone(), location, |expr| { if let ExprValue::Expression(ExpressionKind::If(if_expr)) = expr { // Get the type of `Option` let option_type = extract_option_generic_type(return_type.clone()); @@ -1003,11 +1021,12 @@ fn expr_as_if( // fn as_index(self) -> Option fn expr_as_index( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Index(index_expr)) = expr { Some(Value::Tuple(vec![ Value::expression(index_expr.collection.kind), @@ -1021,27 +1040,36 @@ fn expr_as_index( // fn as_integer(self) -> Option<(Field, bool)> fn expr_as_integer( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { - if let ExprValue::Expression(ExpressionKind::Literal(Literal::Integer(field, sign))) = expr - { + expr_as(interner, arguments, return_type.clone(), location, |expr| match expr { + ExprValue::Expression(ExpressionKind::Literal(Literal::Integer(field, sign))) => { Some(Value::Tuple(vec![Value::Field(field), Value::Bool(sign)])) - } else { - None } + ExprValue::Expression(ExpressionKind::Resolved(id)) => { + if let HirExpression::Literal(HirLiteral::Integer(field, sign)) = + interner.expression(&id) + { + Some(Value::Tuple(vec![Value::Field(field), Value::Bool(sign)])) + } else { + None + } + } + _ => None, }) } // fn as_member_access(self) -> Option<(Expr, Quoted)> fn expr_as_member_access( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| match expr { + expr_as(interner, arguments, return_type, location, |expr| match expr { ExprValue::Expression(ExpressionKind::MemberAccess(member_access)) => { let tokens = Rc::new(vec![Token::Ident(member_access.rhs.0.contents.clone())]); Some(Value::Tuple(vec![ @@ -1059,11 +1087,12 @@ fn expr_as_member_access( // fn as_method_call(self) -> Option<(Expr, Quoted, [UnresolvedType], [Expr])> fn expr_as_method_call( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::MethodCall(method_call)) = expr { let object = Value::expression(method_call.object.kind); @@ -1092,11 +1121,12 @@ fn expr_as_method_call( // fn as_repeated_element_array(self) -> Option<(Expr, Expr)> fn expr_as_repeated_element_array( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Array( ArrayLiteral::Repeated { repeated_element, length }, ))) = expr @@ -1113,11 +1143,12 @@ fn expr_as_repeated_element_array( // fn as_repeated_element_slice(self) -> Option<(Expr, Expr)> fn expr_as_repeated_element_slice( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Slice( ArrayLiteral::Repeated { repeated_element, length }, ))) = expr @@ -1134,11 +1165,12 @@ fn expr_as_repeated_element_slice( // fn as_slice(self) -> Option<[Expr]> fn expr_as_slice( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Literal(Literal::Slice( ArrayLiteral::Standard(exprs), ))) = expr @@ -1154,11 +1186,12 @@ fn expr_as_slice( // fn as_tuple(self) -> Option<[Expr]> fn expr_as_tuple( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Tuple(expressions)) = expr { let expressions = expressions.into_iter().map(|expr| Value::expression(expr.kind)).collect(); @@ -1172,11 +1205,12 @@ fn expr_as_tuple( // fn as_unary_op(self) -> Option<(UnaryOp, Expr)> fn expr_as_unary_op( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type.clone(), location, |expr| { + expr_as(interner, arguments, return_type.clone(), location, |expr| { if let ExprValue::Expression(ExpressionKind::Prefix(prefix_expr)) = expr { let option_type = extract_option_generic_type(return_type); let Type::Tuple(mut tuple_types) = option_type else { @@ -1209,11 +1243,12 @@ fn expr_as_unary_op( // fn as_unsafe(self) -> Option<[Expr]> fn expr_as_unsafe( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { + expr_as(interner, arguments, return_type, location, |expr| { if let ExprValue::Expression(ExpressionKind::Unsafe(block_expr, _)) = expr { Some(block_expression_to_value(block_expr)) } else { @@ -1223,28 +1258,41 @@ fn expr_as_unsafe( } // fn as_has_semicolon(self) -> bool -fn expr_has_semicolon(arguments: Vec<(Value, Location)>, location: Location) -> IResult { +fn expr_has_semicolon( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let expr_value = get_expr(self_argument)?; + let expr_value = get_expr(interner, self_argument)?; Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Semi(..))))) } // fn is_break(self) -> bool -fn expr_is_break(arguments: Vec<(Value, Location)>, location: Location) -> IResult { +fn expr_is_break( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let expr_value = get_expr(self_argument)?; + let expr_value = get_expr(interner, self_argument)?; Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Break)))) } // fn is_continue(self) -> bool -fn expr_is_continue(arguments: Vec<(Value, Location)>, location: Location) -> IResult { +fn expr_is_continue( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { let self_argument = check_one_argument(arguments, location)?; - let expr_value = get_expr(self_argument)?; + let expr_value = get_expr(interner, self_argument)?; Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Continue)))) } // Helper function for implementing the `expr_as_...` functions. fn expr_as( + interner: &NodeInterner, arguments: Vec<(Value, Location)>, return_type: Type, location: Location, @@ -1254,7 +1302,7 @@ where F: FnOnce(ExprValue) -> Option, { let self_argument = check_one_argument(arguments, location)?; - let mut expr_value = get_expr(self_argument)?; + let mut expr_value = get_expr(interner, self_argument)?; loop { match expr_value { ExprValue::Expression(ExpressionKind::Parenthesized(expression)) => { @@ -1264,6 +1312,15 @@ where | ExprValue::Statement(StatementKind::Semi(expression)) => { expr_value = ExprValue::Expression(expression.kind); } + ExprValue::Expression(ExpressionKind::Interned(id)) => { + expr_value = ExprValue::Expression(interner.get_expression_kind(id).clone()); + } + ExprValue::Statement(StatementKind::Interned(id)) => { + expr_value = ExprValue::Statement(interner.get_statement_kind(id).clone()); + } + ExprValue::LValue(LValue::Interned(id, span)) => { + expr_value = ExprValue::LValue(interner.get_lvalue(id, span).clone()); + } _ => break, } } @@ -1272,6 +1329,22 @@ where option(return_type, option_value) } +// fn body(self) -> Expr +fn function_def_body( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let func_id = get_function_def(self_argument)?; + let func_meta = interner.function_meta(&func_id); + if let FunctionBody::Unresolved(_, block_expr, _) = &func_meta.function_body { + Ok(Value::expression(ExpressionKind::Block(block_expr.clone()))) + } else { + Err(InterpreterError::FunctionAlreadyResolved { location }) + } +} + // fn name(self) -> Quoted fn function_def_name( interner: &NodeInterner, @@ -1326,32 +1399,30 @@ fn function_def_return_type( Ok(Value::Type(func_meta.return_type().follow_bindings())) } -// fn set_body(self, body: Quoted) +// fn set_body(self, body: Expr) fn function_def_set_body( interpreter: &mut Interpreter, arguments: Vec<(Value, Location)>, location: Location, ) -> IResult { let (self_argument, body_argument) = check_two_arguments(arguments, location)?; - let body_argument_location = body_argument.1; + let body_location = body_argument.1; let func_id = get_function_def(self_argument)?; check_function_not_yet_resolved(interpreter, func_id, location)?; - let body_tokens = get_quoted(body_argument)?; - let mut body_quoted = add_token_spans(body_tokens.clone(), body_argument_location.span); - - // Surround the body in `{ ... }` so we can parse it as a block - body_quoted.0.insert(0, SpannedToken::new(Token::LeftBrace, location.span)); - body_quoted.0.push(SpannedToken::new(Token::RightBrace, location.span)); + let body_argument = get_expr(interpreter.elaborator.interner, body_argument)?; + let statement_kind = match body_argument { + ExprValue::Expression(expression_kind) => StatementKind::Expression(Expression { + kind: expression_kind, + span: body_location.span, + }), + ExprValue::Statement(statement_kind) => statement_kind, + ExprValue::LValue(lvalue) => StatementKind::Expression(lvalue.as_expression()), + }; - let body = parse_tokens( - body_tokens, - body_quoted, - body_argument_location, - parser::block(parser::fresh_statement()), - "a block", - )?; + let statement = Statement { kind: statement_kind, span: body_location.span }; + let body = BlockExpression { statements: vec![statement] }; let func_meta = interpreter.elaborator.interner.function_meta_mut(&func_id); func_meta.has_body = true; diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index a409731a5e4..809a54ecb44 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -4,7 +4,10 @@ use acvm::FieldElement; use noirc_errors::Location; use crate::{ - ast::{BlockExpression, IntegerBitSize, Signedness, UnresolvedTypeData}, + ast::{ + BlockExpression, ExpressionKind, IntegerBitSize, LValue, Signedness, StatementKind, + UnresolvedTypeData, + }, hir::{ comptime::{ errors::IResult, @@ -142,9 +145,23 @@ pub(crate) fn get_u32((value, location): (Value, Location)) -> IResult { } } -pub(crate) fn get_expr((value, location): (Value, Location)) -> IResult { +pub(crate) fn get_expr( + interner: &NodeInterner, + (value, location): (Value, Location), +) -> IResult { match value { - Value::Expr(expr) => Ok(expr), + Value::Expr(expr) => match expr { + ExprValue::Expression(ExpressionKind::Interned(id)) => { + Ok(ExprValue::Expression(interner.get_expression_kind(id).clone())) + } + ExprValue::Statement(StatementKind::Interned(id)) => { + Ok(ExprValue::Statement(interner.get_statement_kind(id).clone())) + } + ExprValue::LValue(LValue::Interned(id, _)) => { + Ok(ExprValue::LValue(interner.get_lvalue(id, location.span).clone())) + } + _ => Ok(expr), + }, value => type_mismatch(value, Type::Quoted(QuotedType::Expr), location), } } @@ -208,10 +225,18 @@ pub(crate) fn get_quoted((value, location): (Value, Location)) -> IResult IResult { match value { - Value::UnresolvedType(typ) => Ok(typ), + Value::UnresolvedType(typ) => { + if let UnresolvedTypeData::Interned(id) = typ { + let typ = interner.get_unresolved_type_data(id).clone(); + Ok(typ) + } else { + Ok(typ) + } + } value => type_mismatch(value, Type::Quoted(QuotedType::UnresolvedType), location), } } diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 18f482585ea..5b4875c8c41 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -9,8 +9,11 @@ use strum_macros::Display; use crate::{ ast::{ - ArrayLiteral, BlockExpression, ConstructorExpression, Ident, IntegerBitSize, LValue, - Signedness, Statement, StatementKind, UnresolvedTypeData, + ArrayLiteral, AssignStatement, BlockExpression, CallExpression, CastExpression, + ConstrainStatement, ConstructorExpression, ForLoopStatement, ForRange, Ident, IfExpression, + IndexExpression, InfixExpression, IntegerBitSize, LValue, Lambda, LetStatement, + MemberAccessExpression, MethodCallExpression, PrefixExpression, Signedness, Statement, + StatementKind, UnresolvedTypeData, }, hir::{def_map::ModuleId, type_check::generics::TraitGenerics}, hir_def::{ @@ -417,6 +420,18 @@ impl Value { let token = match self { Value::Quoted(tokens) => return Ok(unwrap_rc(tokens)), Value::Type(typ) => Token::QuotedType(interner.push_quoted_type(typ)), + Value::Expr(ExprValue::Expression(expr)) => { + Token::InternedExpr(interner.push_expression_kind(expr)) + } + Value::Expr(ExprValue::Statement(statement)) => { + Token::InternedStatement(interner.push_statement_kind(statement)) + } + Value::Expr(ExprValue::LValue(lvalue)) => { + Token::InternedLValue(interner.push_lvalue(lvalue)) + } + Value::UnresolvedType(typ) => { + Token::InternedUnresolvedTypeData(interner.push_unresolved_type_data(typ)) + } other => Token::UnquoteMarker(other.into_hir_expression(interner, location)?), }; Ok(vec![token]) @@ -597,10 +612,23 @@ impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { Value::ModuleDefinition(_) => write!(f, "(module)"), Value::Zeroed(typ) => write!(f, "(zeroed {typ})"), Value::Type(typ) => write!(f, "{}", typ), - Value::Expr(ExprValue::Expression(expr)) => write!(f, "{}", expr), - Value::Expr(ExprValue::Statement(statement)) => write!(f, "{}", statement), - Value::Expr(ExprValue::LValue(lvalue)) => write!(f, "{}", lvalue), - Value::UnresolvedType(typ) => write!(f, "{}", typ), + Value::Expr(ExprValue::Expression(expr)) => { + write!(f, "{}", remove_interned_in_expression_kind(self.interner, expr.clone())) + } + Value::Expr(ExprValue::Statement(statement)) => { + write!(f, "{}", remove_interned_in_statement_kind(self.interner, statement.clone())) + } + Value::Expr(ExprValue::LValue(lvalue)) => { + write!(f, "{}", remove_interned_in_lvalue(self.interner, lvalue.clone())) + } + Value::UnresolvedType(typ) => { + if let UnresolvedTypeData::Interned(id) = typ { + let typ = self.interner.get_unresolved_type_data(*id); + write!(f, "{}", typ) + } else { + write!(f, "{}", typ) + } + } } } } @@ -609,3 +637,227 @@ fn display_trait_constraint(interner: &NodeInterner, trait_constraint: &TraitCon let trait_ = interner.get_trait(trait_constraint.trait_id); format!("{}: {}{}", trait_constraint.typ, trait_.name, trait_constraint.trait_generics) } + +// Returns a new Expression where all Interned and Resolved expressions have been turned into non-interned ExpressionKind. +fn remove_interned_in_expression(interner: &NodeInterner, expr: Expression) -> Expression { + Expression { kind: remove_interned_in_expression_kind(interner, expr.kind), span: expr.span } +} + +// Returns a new ExpressionKind where all Interned and Resolved expressions have been turned into non-interned ExpressionKind. +fn remove_interned_in_expression_kind( + interner: &NodeInterner, + expr: ExpressionKind, +) -> ExpressionKind { + match expr { + ExpressionKind::Literal(literal) => { + ExpressionKind::Literal(remove_interned_in_literal(interner, literal)) + } + ExpressionKind::Block(block) => { + let statements = + vecmap(block.statements, |stmt| remove_interned_in_statement(interner, stmt)); + ExpressionKind::Block(BlockExpression { statements }) + } + ExpressionKind::Prefix(prefix) => ExpressionKind::Prefix(Box::new(PrefixExpression { + rhs: remove_interned_in_expression(interner, prefix.rhs), + ..*prefix + })), + ExpressionKind::Index(index) => ExpressionKind::Index(Box::new(IndexExpression { + collection: remove_interned_in_expression(interner, index.collection), + index: remove_interned_in_expression(interner, index.index), + })), + ExpressionKind::Call(call) => ExpressionKind::Call(Box::new(CallExpression { + func: Box::new(remove_interned_in_expression(interner, *call.func)), + arguments: vecmap(call.arguments, |arg| remove_interned_in_expression(interner, arg)), + ..*call + })), + ExpressionKind::MethodCall(call) => { + ExpressionKind::MethodCall(Box::new(MethodCallExpression { + object: remove_interned_in_expression(interner, call.object), + arguments: vecmap(call.arguments, |arg| { + remove_interned_in_expression(interner, arg) + }), + ..*call + })) + } + ExpressionKind::Constructor(constructor) => { + ExpressionKind::Constructor(Box::new(ConstructorExpression { + fields: vecmap(constructor.fields, |(name, expr)| { + (name, remove_interned_in_expression(interner, expr)) + }), + ..*constructor + })) + } + ExpressionKind::MemberAccess(member_access) => { + ExpressionKind::MemberAccess(Box::new(MemberAccessExpression { + lhs: remove_interned_in_expression(interner, member_access.lhs), + ..*member_access + })) + } + ExpressionKind::Cast(cast) => ExpressionKind::Cast(Box::new(CastExpression { + lhs: remove_interned_in_expression(interner, cast.lhs), + ..*cast + })), + ExpressionKind::Infix(infix) => ExpressionKind::Infix(Box::new(InfixExpression { + lhs: remove_interned_in_expression(interner, infix.lhs), + rhs: remove_interned_in_expression(interner, infix.rhs), + ..*infix + })), + ExpressionKind::If(if_expr) => ExpressionKind::If(Box::new(IfExpression { + condition: remove_interned_in_expression(interner, if_expr.condition), + consequence: remove_interned_in_expression(interner, if_expr.consequence), + alternative: if_expr + .alternative + .map(|alternative| remove_interned_in_expression(interner, alternative)), + })), + ExpressionKind::Variable(_) => expr, + ExpressionKind::Tuple(expressions) => ExpressionKind::Tuple(vecmap(expressions, |expr| { + remove_interned_in_expression(interner, expr) + })), + ExpressionKind::Lambda(lambda) => ExpressionKind::Lambda(Box::new(Lambda { + body: remove_interned_in_expression(interner, lambda.body), + ..*lambda + })), + ExpressionKind::Parenthesized(expr) => { + ExpressionKind::Parenthesized(Box::new(remove_interned_in_expression(interner, *expr))) + } + ExpressionKind::Quote(_) => expr, + ExpressionKind::Unquote(expr) => { + ExpressionKind::Unquote(Box::new(remove_interned_in_expression(interner, *expr))) + } + ExpressionKind::Comptime(block, span) => { + let statements = + vecmap(block.statements, |stmt| remove_interned_in_statement(interner, stmt)); + ExpressionKind::Comptime(BlockExpression { statements }, span) + } + ExpressionKind::Unsafe(block, span) => { + let statements = + vecmap(block.statements, |stmt| remove_interned_in_statement(interner, stmt)); + ExpressionKind::Unsafe(BlockExpression { statements }, span) + } + ExpressionKind::AsTraitPath(_) => expr, + ExpressionKind::Resolved(id) => { + let expr = interner.expression(&id); + expr.to_display_ast(interner, Span::default()).kind + } + ExpressionKind::Interned(id) => { + let expr = interner.get_expression_kind(id).clone(); + remove_interned_in_expression_kind(interner, expr) + } + ExpressionKind::Error => expr, + } +} + +fn remove_interned_in_literal(interner: &NodeInterner, literal: Literal) -> Literal { + match literal { + Literal::Array(array_literal) => { + Literal::Array(remove_interned_in_array_literal(interner, array_literal)) + } + Literal::Slice(array_literal) => { + Literal::Array(remove_interned_in_array_literal(interner, array_literal)) + } + Literal::Bool(_) + | Literal::Integer(_, _) + | Literal::Str(_) + | Literal::RawStr(_, _) + | Literal::FmtStr(_) + | Literal::Unit => literal, + } +} + +fn remove_interned_in_array_literal( + interner: &NodeInterner, + literal: ArrayLiteral, +) -> ArrayLiteral { + match literal { + ArrayLiteral::Standard(expressions) => { + ArrayLiteral::Standard(vecmap(expressions, |expr| { + remove_interned_in_expression(interner, expr) + })) + } + ArrayLiteral::Repeated { repeated_element, length } => ArrayLiteral::Repeated { + repeated_element: Box::new(remove_interned_in_expression(interner, *repeated_element)), + length: Box::new(remove_interned_in_expression(interner, *length)), + }, + } +} + +// Returns a new Statement where all Interned statements have been turned into non-interned StatementKind. +fn remove_interned_in_statement(interner: &NodeInterner, statement: Statement) -> Statement { + Statement { + kind: remove_interned_in_statement_kind(interner, statement.kind), + span: statement.span, + } +} + +// Returns a new StatementKind where all Interned statements have been turned into non-interned StatementKind. +fn remove_interned_in_statement_kind( + interner: &NodeInterner, + statement: StatementKind, +) -> StatementKind { + match statement { + StatementKind::Let(let_statement) => StatementKind::Let(LetStatement { + expression: remove_interned_in_expression(interner, let_statement.expression), + ..let_statement + }), + StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement( + remove_interned_in_expression(interner, constrain.0), + constrain.1.map(|expr| remove_interned_in_expression(interner, expr)), + constrain.2, + )), + StatementKind::Expression(expr) => { + StatementKind::Expression(remove_interned_in_expression(interner, expr)) + } + StatementKind::Assign(assign) => StatementKind::Assign(AssignStatement { + lvalue: assign.lvalue, + expression: remove_interned_in_expression(interner, assign.expression), + }), + StatementKind::For(for_loop) => StatementKind::For(ForLoopStatement { + range: match for_loop.range { + ForRange::Range(from, to) => ForRange::Range( + remove_interned_in_expression(interner, from), + remove_interned_in_expression(interner, to), + ), + ForRange::Array(expr) => { + ForRange::Array(remove_interned_in_expression(interner, expr)) + } + }, + block: remove_interned_in_expression(interner, for_loop.block), + ..for_loop + }), + StatementKind::Comptime(statement) => { + StatementKind::Comptime(Box::new(remove_interned_in_statement(interner, *statement))) + } + StatementKind::Semi(expr) => { + StatementKind::Semi(remove_interned_in_expression(interner, expr)) + } + StatementKind::Interned(id) => { + let statement = interner.get_statement_kind(id).clone(); + remove_interned_in_statement_kind(interner, statement) + } + StatementKind::Break | StatementKind::Continue | StatementKind::Error => statement, + } +} + +// Returns a new LValue where all Interned LValues have been turned into LValue. +fn remove_interned_in_lvalue(interner: &NodeInterner, lvalue: LValue) -> LValue { + match lvalue { + LValue::Ident(_) => lvalue, + LValue::MemberAccess { object, field_name, span } => LValue::MemberAccess { + object: Box::new(remove_interned_in_lvalue(interner, *object)), + field_name, + span, + }, + LValue::Index { array, index, span } => LValue::Index { + array: Box::new(remove_interned_in_lvalue(interner, *array)), + index: remove_interned_in_expression(interner, index), + span, + }, + LValue::Dereference(lvalue, span) => { + LValue::Dereference(Box::new(remove_interned_in_lvalue(interner, *lvalue)), span) + } + LValue::Interned(id, span) => { + let lvalue = interner.get_lvalue(id, span); + remove_interned_in_lvalue(interner, lvalue) + } + } +} diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index 8ee0fca2957..b3b6d25480f 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -4,7 +4,10 @@ use std::{fmt, iter::Map, vec::IntoIter}; use crate::{ lexer::errors::LexerErrorKind, - node_interner::{ExprId, QuotedTypeId}, + node_interner::{ + ExprId, InternedExpressionKind, InternedStatementKind, InternedUnresolvedTypeData, + QuotedTypeId, + }, }; /// Represents a token in noir's grammar - a word, number, @@ -28,6 +31,10 @@ pub enum BorrowedToken<'input> { BlockComment(&'input str, Option), Quote(&'input Tokens), QuotedType(QuotedTypeId), + InternedExpression(InternedExpressionKind), + InternedStatement(InternedStatementKind), + InternedLValue(InternedExpressionKind), + InternedUnresolvedTypeData(InternedUnresolvedTypeData), /// < Less, /// <= @@ -134,6 +141,14 @@ pub enum Token { /// to avoid having to tokenize it, re-parse it, and re-resolve it which /// may change the underlying type. QuotedType(QuotedTypeId), + /// A reference to an interned `ExpressionKind`. + InternedExpr(InternedExpressionKind), + /// A reference to an interned `StatementKind`. + InternedStatement(InternedStatementKind), + /// A reference to an interned `LValue`. + InternedLValue(InternedExpressionKind), + /// A reference to an interned `UnresolvedTypeData`. + InternedUnresolvedTypeData(InternedUnresolvedTypeData), /// < Less, /// <= @@ -233,6 +248,10 @@ pub fn token_to_borrowed_token(token: &Token) -> BorrowedToken<'_> { Token::BlockComment(ref s, _style) => BorrowedToken::BlockComment(s, *_style), Token::Quote(stream) => BorrowedToken::Quote(stream), Token::QuotedType(id) => BorrowedToken::QuotedType(*id), + Token::InternedExpr(id) => BorrowedToken::InternedExpression(*id), + Token::InternedStatement(id) => BorrowedToken::InternedStatement(*id), + Token::InternedLValue(id) => BorrowedToken::InternedLValue(*id), + Token::InternedUnresolvedTypeData(id) => BorrowedToken::InternedUnresolvedTypeData(*id), Token::IntType(ref i) => BorrowedToken::IntType(i.clone()), Token::Less => BorrowedToken::Less, Token::LessEqual => BorrowedToken::LessEqual, @@ -353,8 +372,12 @@ impl fmt::Display for Token { } write!(f, "}}") } - // Quoted types only have an ID so there is nothing to display + // Quoted types and exprs only have an ID so there is nothing to display Token::QuotedType(_) => write!(f, "(type)"), + Token::InternedExpr(_) | Token::InternedStatement(_) | Token::InternedLValue(_) => { + write!(f, "(expr)") + } + Token::InternedUnresolvedTypeData(_) => write!(f, "(type)"), Token::IntType(ref i) => write!(f, "{i}"), Token::Less => write!(f, "<"), Token::LessEqual => write!(f, "<="), @@ -407,6 +430,10 @@ pub enum TokenKind { Attribute, Quote, QuotedType, + InternedExpr, + InternedStatement, + InternedLValue, + InternedUnresolvedTypeData, UnquoteMarker, } @@ -420,6 +447,10 @@ impl fmt::Display for TokenKind { TokenKind::Attribute => write!(f, "attribute"), TokenKind::Quote => write!(f, "quote"), TokenKind::QuotedType => write!(f, "quoted type"), + TokenKind::InternedExpr => write!(f, "interned expr"), + TokenKind::InternedStatement => write!(f, "interned statement"), + TokenKind::InternedLValue => write!(f, "interned lvalue"), + TokenKind::InternedUnresolvedTypeData => write!(f, "interned unresolved type"), TokenKind::UnquoteMarker => write!(f, "macro result"), } } @@ -439,6 +470,10 @@ impl Token { Token::UnquoteMarker(_) => TokenKind::UnquoteMarker, Token::Quote(_) => TokenKind::Quote, Token::QuotedType(_) => TokenKind::QuotedType, + Token::InternedExpr(_) => TokenKind::InternedExpr, + Token::InternedStatement(_) => TokenKind::InternedStatement, + Token::InternedLValue(_) => TokenKind::InternedLValue, + Token::InternedUnresolvedTypeData(_) => TokenKind::InternedUnresolvedTypeData, tok => TokenKind::Token(tok.clone()), } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 4837028b80f..32f25790e12 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -13,7 +13,11 @@ use petgraph::prelude::DiGraph; use petgraph::prelude::NodeIndex as PetGraphIndex; use rustc_hash::FxHashMap as HashMap; +use crate::ast::ExpressionKind; use crate::ast::Ident; +use crate::ast::LValue; +use crate::ast::StatementKind; +use crate::ast::UnresolvedTypeData; use crate::graph::CrateId; use crate::hir::comptime; use crate::hir::def_collector::dc_crate::CompilationError; @@ -208,6 +212,15 @@ pub struct NodeInterner { /// the actual type since types do not implement Send or Sync. quoted_types: noirc_arena::Arena, + // Interned `ExpressionKind`s during comptime code. + interned_expression_kinds: noirc_arena::Arena, + + // Interned `StatementKind`s during comptime code. + interned_statement_kinds: noirc_arena::Arena, + + // Interned `UnresolvedTypeData`s during comptime code. + interned_unresolved_type_datas: noirc_arena::Arena, + /// Determins whether to run in LSP mode. In LSP mode references are tracked. pub(crate) lsp_mode: bool, @@ -580,6 +593,15 @@ pub struct GlobalInfo { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct QuotedTypeId(noirc_arena::Index); +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct InternedExpressionKind(noirc_arena::Index); + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct InternedStatementKind(noirc_arena::Index); + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct InternedUnresolvedTypeData(noirc_arena::Index); + impl Default for NodeInterner { fn default() -> Self { NodeInterner { @@ -617,6 +639,9 @@ impl Default for NodeInterner { type_alias_ref: Vec::new(), type_ref_locations: Vec::new(), quoted_types: Default::default(), + interned_expression_kinds: Default::default(), + interned_statement_kinds: Default::default(), + interned_unresolved_type_datas: Default::default(), lsp_mode: false, location_indices: LocationIndices::default(), reference_graph: petgraph::graph::DiGraph::new(), @@ -2042,6 +2067,41 @@ impl NodeInterner { &self.quoted_types[id.0] } + pub fn push_expression_kind(&mut self, expr: ExpressionKind) -> InternedExpressionKind { + InternedExpressionKind(self.interned_expression_kinds.insert(expr)) + } + + pub fn get_expression_kind(&self, id: InternedExpressionKind) -> &ExpressionKind { + &self.interned_expression_kinds[id.0] + } + + pub fn push_statement_kind(&mut self, statement: StatementKind) -> InternedStatementKind { + InternedStatementKind(self.interned_statement_kinds.insert(statement)) + } + + pub fn get_statement_kind(&self, id: InternedStatementKind) -> &StatementKind { + &self.interned_statement_kinds[id.0] + } + + pub fn push_lvalue(&mut self, lvalue: LValue) -> InternedExpressionKind { + self.push_expression_kind(lvalue.as_expression().kind) + } + + pub fn get_lvalue(&self, id: InternedExpressionKind, span: Span) -> LValue { + LValue::from_expression_kind(self.get_expression_kind(id).clone(), span) + } + + pub fn push_unresolved_type_data( + &mut self, + typ: UnresolvedTypeData, + ) -> InternedUnresolvedTypeData { + InternedUnresolvedTypeData(self.interned_unresolved_type_datas.insert(typ)) + } + + pub fn get_unresolved_type_data(&self, id: InternedUnresolvedTypeData) -> &UnresolvedTypeData { + &self.interned_unresolved_type_datas[id.0] + } + /// Returns the type of an operator (which is always a function), along with its return type. pub fn get_infix_operator_type( &self, diff --git a/compiler/noirc_frontend/src/parser/mod.rs b/compiler/noirc_frontend/src/parser/mod.rs index f1972bcb9b5..11944cd3304 100644 --- a/compiler/noirc_frontend/src/parser/mod.rs +++ b/compiler/noirc_frontend/src/parser/mod.rs @@ -25,7 +25,7 @@ use noirc_errors::Span; pub use parser::path::path_no_turbofish; pub use parser::traits::trait_bound; pub use parser::{ - block, expression, fresh_statement, parse_program, parse_type, pattern, top_level_items, + block, expression, fresh_statement, lvalue, parse_program, parse_type, pattern, top_level_items, }; #[derive(Debug, Clone)] diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 56c80ee1ce0..8a894ec2b83 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -73,7 +73,9 @@ mod test_helpers; use literals::literal; use path::{maybe_empty_path, path}; -use primitives::{dereference, ident, negation, not, nothing, right_shift_operator, token_kind}; +use primitives::{ + dereference, ident, interned_expr, negation, not, nothing, right_shift_operator, token_kind, +}; use traits::where_clause; /// Entry function for the parser - also handles lexing internally. @@ -487,6 +489,7 @@ where continue_statement(), return_statement(expr_parser.clone()), comptime_statement(expr_parser.clone(), expr_no_constructors, statement), + interned_statement(), expr_parser.map(StatementKind::Expression), )) }) @@ -526,6 +529,15 @@ where keyword(Keyword::Comptime).ignore_then(comptime_statement).map(StatementKind::Comptime) } +pub(super) fn interned_statement() -> impl NoirParser { + token_kind(TokenKind::InternedStatement).map(|token| match token { + Token::InternedStatement(id) => StatementKind::Interned(id), + _ => { + unreachable!("token_kind(InternedStatement) guarantees we parse an interned statement") + } + }) +} + /// Comptime in an expression position only accepts entire blocks fn comptime_expr<'a, S>(statement: S) -> impl NoirParser + 'a where @@ -642,7 +654,7 @@ enum LValueRhs { Index(Expression, Span), } -fn lvalue<'a, P>(expr_parser: P) -> impl NoirParser + 'a +pub fn lvalue<'a, P>(expr_parser: P) -> impl NoirParser + 'a where P: ExprParser + 'a, { @@ -655,7 +667,15 @@ where let parenthesized = lvalue.delimited_by(just(Token::LeftParen), just(Token::RightParen)); - let term = choice((parenthesized, dereferences, l_ident)); + let interned = + token_kind(TokenKind::InternedLValue).map_with_span(|token, span| match token { + Token::InternedLValue(id) => LValue::Interned(id, span), + _ => unreachable!( + "token_kind(InternedLValue) guarantees we parse an interned lvalue" + ), + }); + + let term = choice((parenthesized, dereferences, l_ident, interned)); let l_member_rhs = just(Token::Dot).ignore_then(field_name()).map_with_span(LValueRhs::MemberAccess); @@ -1154,6 +1174,7 @@ where literal(), as_trait_path(parse_type()).map(ExpressionKind::AsTraitPath), macro_quote_marker(), + interned_expr(), )) .map_with_span(Expression::new) .or(parenthesized(expr_parser.clone()).map_with_span(|sub_expr, span| { diff --git a/compiler/noirc_frontend/src/parser/parser/primitives.rs b/compiler/noirc_frontend/src/parser/parser/primitives.rs index 9145fb945c9..c1516e2c927 100644 --- a/compiler/noirc_frontend/src/parser/parser/primitives.rs +++ b/compiler/noirc_frontend/src/parser/parser/primitives.rs @@ -119,6 +119,13 @@ pub(super) fn macro_quote_marker() -> impl NoirParser { }) } +pub(super) fn interned_expr() -> impl NoirParser { + token_kind(TokenKind::InternedExpr).map(|token| match token { + Token::InternedExpr(id) => ExpressionKind::Interned(id), + _ => unreachable!("token_kind(InternedExpr) guarantees we parse an interned expr"), + }) +} + #[cfg(test)] mod test { use crate::parser::parser::{ diff --git a/compiler/noirc_frontend/src/parser/parser/types.rs b/compiler/noirc_frontend/src/parser/parser/types.rs index c655ab8c5a4..f83303151eb 100644 --- a/compiler/noirc_frontend/src/parser/parser/types.rs +++ b/compiler/noirc_frontend/src/parser/parser/types.rs @@ -40,6 +40,7 @@ pub(super) fn parse_type_inner<'a>( function_type(recursive_type_parser.clone()), mutable_reference_type(recursive_type_parser.clone()), as_trait_path_type(recursive_type_parser), + interned_unresolved_type(), )) } @@ -168,6 +169,15 @@ pub(super) fn resolved_type() -> impl NoirParser { }) } +pub(super) fn interned_unresolved_type() -> impl NoirParser { + token_kind(TokenKind::InternedUnresolvedTypeData).map_with_span(|token, span| match token { + Token::InternedUnresolvedTypeData(id) => UnresolvedTypeData::Interned(id).with_span(span), + _ => unreachable!( + "token_kind(InternedUnresolvedTypeData) guarantees we parse an interned unresolved type" + ), + }) +} + pub(super) fn string_type() -> impl NoirParser { keyword(Keyword::String) .ignore_then(type_expression().delimited_by(just(Token::Less), just(Token::Greater))) diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md index 0a32b2b04fc..d421e8b56a3 100644 --- a/docs/docs/noir/standard_library/meta/expr.md +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -161,3 +161,21 @@ comptime { #include_code is_continue noir_stdlib/src/meta/expr.nr rust `true` if this expression is `continue`. + +### mutate + +#include_code mutate noir_stdlib/src/meta/expr.nr rust + +Applies a mapping function to this expression and to all of its sub-expressions. +`f` will be applied to each sub-expression first, then applied to the expression itself. + +This happens recursively for every expression within `self`. + +For example, calling `mutate` on `(&[1], &[2, 3])` with an `f` that returns `Option::some` +for expressions that are integers, doubling them, would return `(&[2], &[4, 6])`. + +### quoted + +#include_code quoted noir_stdlib/src/meta/expr.nr rust + +Returns this expression as a `Quoted` value. It's the same as `quote { $self }`. \ No newline at end of file diff --git a/docs/docs/noir/standard_library/meta/function_def.md b/docs/docs/noir/standard_library/meta/function_def.md index 4b359a9d343..8a4e8c84958 100644 --- a/docs/docs/noir/standard_library/meta/function_def.md +++ b/docs/docs/noir/standard_library/meta/function_def.md @@ -7,6 +7,14 @@ a function definition in the source program. ## Methods +### body + +#include_code body noir_stdlib/src/meta/function_def.nr rust + +Returns the body of the function as an expression. This is only valid +on functions in the current crate which have not yet been resolved. +This means any functions called at compile-time are invalid targets for this method. + ### name #include_code name noir_stdlib/src/meta/function_def.nr rust @@ -33,8 +41,6 @@ Mutate the function body to a new expression. This is only valid on functions in the current crate which have not yet been resolved. This means any functions called at compile-time are invalid targets for this method. -Requires the new body to be a valid expression. - ### set_parameters #include_code set_parameters noir_stdlib/src/meta/function_def.nr rust diff --git a/docs/docs/noir/standard_library/meta/op.md b/docs/docs/noir/standard_library/meta/op.md index 37d4cb746ac..d8b154edc02 100644 --- a/docs/docs/noir/standard_library/meta/op.md +++ b/docs/docs/noir/standard_library/meta/op.md @@ -37,6 +37,12 @@ Returns `true` if this operator is `-`. `true` if this operator is `*` +#### quoted + +#include_code unary_quoted noir_stdlib/src/meta/op.nr rust + +Returns this operator as a `Quoted` value. + ### BinaryOp Represents a binary operator. One of `+`, `-`, `*`, `/`, `%`, `==`, `!=`, `<`, `<=`, `>`, `>=`, `&`, `|`, `^`, `>>`, or `<<`. @@ -132,3 +138,9 @@ Represents a binary operator. One of `+`, `-`, `*`, `/`, `%`, `==`, `!=`, `<`, ` #include_code is_shift_right noir_stdlib/src/meta/op.nr rust `true` if this operator is `<<` + +#### quoted + +#include_code binary_quoted noir_stdlib/src/meta/op.nr rust + +Returns this operator as a `Quoted` value. \ No newline at end of file diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index ee3980f8f54..c09d9b92c9b 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -110,4 +110,310 @@ impl Expr { // docs:start:is_continue fn is_continue(self) -> bool {} // docs:end:is_continue + + // docs:start:mutate + fn mutate(self, f: fn[Env](Expr) -> Option) -> Expr { + // docs:end:mutate + let result = mutate_array(self, f); + let result = result.or_else(|| mutate_assign(self, f)); + let result = result.or_else(|| mutate_binary_op(self, f)); + let result = result.or_else(|| mutate_block(self, f)); + let result = result.or_else(|| mutate_cast(self, f)); + let result = result.or_else(|| mutate_comptime(self, f)); + let result = result.or_else(|| mutate_if(self, f)); + let result = result.or_else(|| mutate_index(self, f)); + let result = result.or_else(|| mutate_function_call(self, f)); + let result = result.or_else(|| mutate_member_access(self, f)); + let result = result.or_else(|| mutate_method_call(self, f)); + let result = result.or_else(|| mutate_repeated_element_array(self, f)); + let result = result.or_else(|| mutate_repeated_element_slice(self, f)); + let result = result.or_else(|| mutate_slice(self, f)); + let result = result.or_else(|| mutate_tuple(self, f)); + let result = result.or_else(|| mutate_unary_op(self, f)); + let result = result.or_else(|| mutate_unsafe(self, f)); + if result.is_some() { + let result = result.unwrap_unchecked(); + let modified = f(result); + modified.unwrap_or(result) + } else { + f(self).unwrap_or(self) + } + } + + // docs:start:quoted + fn quoted(self) -> Quoted { + // docs:end:quoted + quote { $self } + } +} + +fn mutate_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_array().map( + |exprs: [Expr]| { + let exprs = mutate_expressions(exprs, f); + new_array(exprs) + } + ) +} + +fn mutate_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_assign().map( + |expr: (Expr, Expr)| { + let (lhs, rhs) = expr; + let lhs = lhs.mutate(f); + let rhs = rhs.mutate(f); + new_assign(lhs, rhs) + } + ) +} + +fn mutate_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_binary_op().map( + |expr: (Expr, BinaryOp, Expr)| { + let (lhs, op, rhs) = expr; + let lhs = lhs.mutate(f); + let rhs = rhs.mutate(f); + new_binary_op(lhs, op, rhs) + } + ) +} + +fn mutate_block(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_block().map( + |exprs: [Expr]| { + let exprs = mutate_expressions(exprs, f); + new_block(exprs) + } + ) +} + +fn mutate_cast(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_cast().map( + |expr: (Expr, UnresolvedType)| { + let (expr, typ) = expr; + let expr = expr.mutate(f); + new_cast(expr, typ) + } + ) +} + +fn mutate_comptime(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_comptime().map( + |exprs: [Expr]| { + let exprs = exprs.map(|expr: Expr| expr.mutate(f)); + new_comptime(exprs) + } + ) +} + +fn mutate_function_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_function_call().map( + |expr: (Expr, [Expr])| { + let (function, arguments) = expr; + let function = function.mutate(f); + let arguments = arguments.map(|arg: Expr| arg.mutate(f)); + new_function_call(function, arguments) + } + ) +} + +fn mutate_if(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_if().map( + |expr: (Expr, Expr, Option)| { + let (condition, consequence, alternative) = expr; + let condition = condition.mutate(f); + let consequence = consequence.mutate(f); + let alternative = alternative.map(|alternative: Expr| alternative.mutate(f)); + new_if(condition, consequence, alternative) + } + ) +} + +fn mutate_index(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_index().map( + |expr: (Expr, Expr)| { + let (object, index) = expr; + let object = object.mutate(f); + let index = index.mutate(f); + new_index(object, index) + } + ) +} + +fn mutate_member_access(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_member_access().map( + |expr: (Expr, Quoted)| { + let (object, name) = expr; + let object = object.mutate(f); + new_member_access(object, name) + } + ) +} + +fn mutate_method_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_method_call().map( + |expr: (Expr, Quoted, [UnresolvedType], [Expr])| { + let (object, name, generics, arguments) = expr; + let object = object.mutate(f); + let arguments = arguments.map(|arg: Expr| arg.mutate(f)); + new_method_call(object, name, generics, arguments) + } + ) +} + +fn mutate_repeated_element_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_repeated_element_array().map( + |expr: (Expr, Expr)| { + let (expr, length) = expr; + let expr = expr.mutate(f); + let length = length.mutate(f); + new_repeated_element_array(expr, length) + } + ) +} + +fn mutate_repeated_element_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_repeated_element_slice().map( + |expr: (Expr, Expr)| { + let (expr, length) = expr; + let expr = expr.mutate(f); + let length = length.mutate(f); + new_repeated_element_slice(expr, length) + } + ) +} + +fn mutate_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_slice().map( + |exprs: [Expr]| { + let exprs = mutate_expressions(exprs, f); + new_slice(exprs) + } + ) +} + +fn mutate_tuple(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_tuple().map( + |exprs: [Expr]| { + let exprs = mutate_expressions(exprs, f); + new_tuple(exprs) + } + ) +} + +fn mutate_unary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_unary_op().map( + |expr: (UnaryOp, Expr)| { + let (op, rhs) = expr; + let rhs = rhs.mutate(f); + new_unary_op(op, rhs) + } + ) +} + +fn mutate_unsafe(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_unsafe().map( + |exprs: [Expr]| { + let exprs = exprs.map(|expr: Expr| expr.mutate(f)); + new_unsafe(exprs) + } + ) +} + +fn mutate_expressions(exprs: [Expr], f: fn[Env](Expr) -> Option) -> [Expr] { + exprs.map(|expr: Expr| expr.mutate(f)) +} + +fn new_array(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { , }); + quote { [$exprs]}.as_expr().unwrap() +} + +fn new_assign(lhs: Expr, rhs: Expr) -> Expr { + quote { $lhs = $rhs }.as_expr().unwrap() +} + +fn new_binary_op(lhs: Expr, op: BinaryOp, rhs: Expr) -> Expr { + let op = op.quoted(); + quote { ($lhs) $op ($rhs) }.as_expr().unwrap() +} + +fn new_block(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { ; }); + quote { { $exprs }}.as_expr().unwrap() +} + +fn new_cast(expr: Expr, typ: UnresolvedType) -> Expr { + quote { ($expr) as $typ }.as_expr().unwrap() +} + +fn new_comptime(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { ; }); + quote { comptime { $exprs }}.as_expr().unwrap() +} + +fn new_if(condition: Expr, consequence: Expr, alternative: Option) -> Expr { + if alternative.is_some() { + let alternative = alternative.unwrap(); + quote { if $condition { $consequence } else { $alternative }}.as_expr().unwrap() + } else { + quote { if $condition { $consequence } }.as_expr().unwrap() + } +} + +fn new_index(object: Expr, index: Expr) -> Expr { + quote { $object[$index] }.as_expr().unwrap() +} + +fn new_member_access(object: Expr, name: Quoted) -> Expr { + quote { $object.$name }.as_expr().unwrap() +} + +fn new_function_call(function: Expr, arguments: [Expr]) -> Expr { + let arguments = join_expressions(arguments, quote { , }); + + quote { $function($arguments) }.as_expr().unwrap() +} + +fn new_method_call(object: Expr, name: Quoted, generics: [UnresolvedType], arguments: [Expr]) -> Expr { + let arguments = join_expressions(arguments, quote { , }); + + if generics.len() == 0 { + quote { $object.$name($arguments) }.as_expr().unwrap() + } else { + let generics = generics.map(|generic| quote { $generic }).join(quote { , }); + quote { $object.$name::<$generics>($arguments) }.as_expr().unwrap() + } +} + +fn new_repeated_element_array(expr: Expr, length: Expr) -> Expr { + quote { [$expr; $length] }.as_expr().unwrap() +} + +fn new_repeated_element_slice(expr: Expr, length: Expr) -> Expr { + quote { &[$expr; $length] }.as_expr().unwrap() +} + +fn new_slice(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { , }); + quote { &[$exprs]}.as_expr().unwrap() +} + +fn new_tuple(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { , }); + quote { ($exprs) }.as_expr().unwrap() +} + +fn new_unary_op(op: UnaryOp, rhs: Expr) -> Expr { + let op = op.quoted(); + quote { $op($rhs) }.as_expr().unwrap() +} + +fn new_unsafe(exprs: [Expr]) -> Expr { + let exprs = join_expressions(exprs, quote { ; }); + quote { unsafe { $exprs }}.as_expr().unwrap() +} + +fn join_expressions(exprs: [Expr], separator: Quoted) -> Quoted { + exprs.map(|expr: Expr| expr.quoted()).join(separator) } diff --git a/noir_stdlib/src/meta/function_def.nr b/noir_stdlib/src/meta/function_def.nr index 7ac8803e7e4..84f9c60b304 100644 --- a/noir_stdlib/src/meta/function_def.nr +++ b/noir_stdlib/src/meta/function_def.nr @@ -1,4 +1,9 @@ impl FunctionDefinition { + #[builtin(function_def_body)] + // docs:start:body + fn body(self) -> Expr {} + // docs:end:body + #[builtin(function_def_name)] // docs:start:name fn name(self) -> Quoted {} @@ -16,7 +21,7 @@ impl FunctionDefinition { #[builtin(function_def_set_body)] // docs:start:set_body - fn set_body(self, body: Quoted) {} + fn set_body(self, body: Expr) {} // docs:end:set_body #[builtin(function_def_set_parameters)] diff --git a/noir_stdlib/src/meta/op.nr b/noir_stdlib/src/meta/op.nr index 9c892c4d80b..f3060a1648b 100644 --- a/noir_stdlib/src/meta/op.nr +++ b/noir_stdlib/src/meta/op.nr @@ -26,6 +26,22 @@ impl UnaryOp { // docs:end:is_dereference self.op == 3 } + + // docs:start:unary_quoted + pub fn quoted(self) -> Quoted { + // docs:end:unary_quoted + if self.is_minus() { + quote { - } + } else if self.is_not() { + quote { ! } + } else if self.is_mutable_reference() { + quote { &mut } + } else if self.is_dereference() { + quote { * } + } else { + crate::mem::zeroed() + } + } } struct BinaryOp { @@ -128,5 +144,45 @@ impl BinaryOp { // docs:end:is_modulo self.op == 15 } + + // docs:start:binary_quoted + pub fn quoted(self) -> Quoted { + // docs:end:binary_quoted + if self.is_add() { + quote { + } + } else if self.is_subtract() { + quote { - } + } else if self.is_multiply() { + quote { * } + } else if self.is_divide() { + quote { / } + } else if self.is_equal() { + quote { == } + } else if self.is_not_equal() { + quote { != } + } else if self.is_less_than() { + quote { < } + } else if self.is_less_than_or_equal() { + quote { <= } + } else if self.is_greater_than() { + quote { > } + } else if self.is_greater_than_or_equal() { + quote { >= } + } else if self.is_and() { + quote { & } + } else if self.is_or() { + quote { | } + } else if self.is_xor() { + quote { ^ } + } else if self.is_shift_right() { + quote { >> } + } else if self.is_shift_left() { + quote { << } + } else if self.is_modulo() { + quote { % } + } else { + crate::mem::zeroed() + } + } } diff --git a/test_programs/compile_success_empty/comptime_function_definition/src/main.nr b/test_programs/compile_success_empty/comptime_function_definition/src/main.nr index ce09ba86e49..06da5a1dde5 100644 --- a/test_programs/compile_success_empty/comptime_function_definition/src/main.nr +++ b/test_programs/compile_success_empty/comptime_function_definition/src/main.nr @@ -49,7 +49,7 @@ comptime fn mutate_add_one(f: FunctionDefinition) { assert_eq(f.return_type(), type_of(0)); // fn add_one(x: Field) -> Field { x + 1 } - f.set_body(quote { x + 1 }); + f.set_body(quote { x + 1 }.as_expr().unwrap()); } fn main() { diff --git a/test_programs/compile_success_empty/inject_context_attribute/Nargo.toml b/test_programs/compile_success_empty/inject_context_attribute/Nargo.toml new file mode 100644 index 00000000000..10f9cb1f9e2 --- /dev/null +++ b/test_programs/compile_success_empty/inject_context_attribute/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "inject_context_attribute" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/compile_success_empty/inject_context_attribute/src/main.nr b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr new file mode 100644 index 00000000000..65003ed837b --- /dev/null +++ b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr @@ -0,0 +1,53 @@ +struct Context { + value: Field, +} + +#[inject_context] +fn foo(x: Field) { + if true { + // 20 + 1 => 21 + bar(qux(x + 1)); + } else { + assert(false); + } +} + +#[inject_context] +fn bar(x: Field) { + let expected = _context.value; + assert_eq(x, expected); +} + +#[inject_context] +fn qux(x: Field) -> Field { + // 21 * 2 => 42 + x * 2 +} + +fn inject_context(f: FunctionDefinition) { + // Add a `_context: Context` parameter to the function + let parameters = f.parameters(); + let parameters = parameters.push_front((quote { _context }, quote { Context }.as_type())); + f.set_parameters(parameters); + + // Create a new body where every function call has `_context` added to the list of arguments. + let body = f.body().mutate(mapping_function); + f.set_body(body); +} + +fn mapping_function(expr: Expr) -> Option { + expr.as_function_call().map( + |func_call: (Expr, [Expr])| { + let (name, arguments) = func_call; + let arguments = arguments.push_front(quote { _context }.as_expr().unwrap()); + let arguments = arguments.map(|arg: Expr| arg.quoted()).join(quote { , }); + quote { $name($arguments) }.as_expr().unwrap() + } + ) +} + +fn main() { + let context = Context { value: 42 }; + foo(context, 20); +} + diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index 329e97dc9d9..abc7a793fd1 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -15,6 +15,20 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_array() { + comptime + { + let expr = quote { [1, 2, 4] }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let elems = expr.as_array().unwrap(); + assert_eq(elems.len(), 3); + assert_eq(elems[0].as_integer().unwrap(), (2, false)); + assert_eq(elems[1].as_integer().unwrap(), (4, false)); + assert_eq(elems[2].as_integer().unwrap(), (8, false)); + } + } + #[test] fn test_expr_as_assign() { comptime @@ -26,6 +40,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_assign() { + comptime + { + let expr = quote { { a = 1; } }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let exprs = expr.as_block().unwrap(); + let (_lhs, rhs) = exprs[0].as_assign().unwrap(); + assert_eq(rhs.as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_block() { comptime @@ -43,6 +69,24 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_block() { + comptime + { + let expr = quote { { 1; 4; 23 } }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let exprs = expr.as_block().unwrap(); + assert_eq(exprs.len(), 3); + assert_eq(exprs[0].as_integer().unwrap(), (2, false)); + assert_eq(exprs[1].as_integer().unwrap(), (8, false)); + assert_eq(exprs[2].as_integer().unwrap(), (46, false)); + + assert(exprs[0].has_semicolon()); + assert(exprs[1].has_semicolon()); + assert(!exprs[2].has_semicolon()); + } + } + #[test] fn test_expr_as_method_call() { comptime @@ -61,6 +105,25 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_method_call() { + comptime + { + let expr = quote { foo.bar(3, 4) }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + + let (_object, name, generics, arguments) = expr.as_method_call().unwrap(); + + assert_eq(name, quote { bar }); + + assert_eq(generics.len(), 0); + + assert_eq(arguments.len(), 2); + assert_eq(arguments[0].as_integer().unwrap(), (6, false)); + assert_eq(arguments[1].as_integer().unwrap(), (8, false)); + } + } + #[test] fn test_expr_as_integer() { comptime @@ -73,6 +136,17 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_integer() { + comptime + { + let expr = quote { 1 }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + + assert_eq((2, false), expr.as_integer().unwrap()); + } + } + #[test] fn test_expr_as_binary_op() { comptime @@ -96,6 +170,20 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_binary_op() { + comptime + { + let expr = quote { 3 + 4 }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + + let (lhs, op, rhs) = expr.as_binary_op().unwrap(); + assert_eq(lhs.as_integer().unwrap(), (6, false)); + assert(op.is_add()); + assert_eq(rhs.as_integer().unwrap(), (8, false)); + } + } + #[test] fn test_expr_as_bool() { comptime @@ -119,6 +207,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_cast() { + comptime + { + let expr = quote { 1 as Field }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (expr, typ) = expr.as_cast().unwrap(); + assert_eq(expr.as_integer().unwrap(), (2, false)); + assert(typ.is_field()); + } + } + #[test] fn test_expr_as_comptime() { comptime @@ -129,6 +229,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_comptime() { + comptime + { + let expr = quote { comptime { 1; 4; 23 } }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let exprs = expr.as_comptime().unwrap(); + assert_eq(exprs.len(), 3); + assert_eq(exprs[0].as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_comptime_as_statement() { comptime @@ -157,6 +269,18 @@ mod tests { } // docs:end:as_expr_example + #[test] + fn test_expr_mutate_for_function_call() { + comptime + { + let expr = quote { foo(42) }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (_function, args) = expr.as_function_call().unwrap(); + assert_eq(args.len(), 1); + assert_eq(args[0].as_integer().unwrap(), (84, false)); + } + } + #[test] fn test_expr_as_if() { comptime @@ -171,6 +295,29 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_if() { + comptime + { + let expr = quote { if 1 { 2 } }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (condition, consequence, alternative) = expr.as_if().unwrap(); + assert_eq(condition.as_integer().unwrap(), (2, false)); + let consequence = consequence.as_block().unwrap()[0].as_block().unwrap()[0]; + assert_eq(consequence.as_integer().unwrap(), (4, false)); + assert(alternative.is_none()); + + let expr = quote { if 1 { 2 } else { 3 } }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (condition, consequence, alternative) = expr.as_if().unwrap(); + assert_eq(condition.as_integer().unwrap(), (2, false)); + let consequence = consequence.as_block().unwrap()[0].as_block().unwrap()[0]; + assert_eq(consequence.as_integer().unwrap(), (4, false)); + let alternative = alternative.unwrap().as_block().unwrap()[0].as_block().unwrap()[0]; + assert_eq(alternative.as_integer().unwrap(), (6, false)); + } + } + #[test] fn test_expr_as_index() { comptime @@ -180,6 +327,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_index() { + comptime + { + let expr = quote { 1[2] }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (object, index) = expr.as_index().unwrap(); + assert_eq(object.as_integer().unwrap(), (2, false)); + assert_eq(index.as_integer().unwrap(), (4, false)); + } + } + #[test] fn test_expr_as_member_access() { comptime @@ -190,6 +349,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_member_access() { + comptime + { + let expr = quote { 1.bar }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (expr, name) = expr.as_member_access().unwrap(); + assert_eq(name, quote { bar }); + assert_eq(expr.as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_member_access_with_an_lvalue() { comptime @@ -213,6 +384,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_repeated_element_array() { + comptime + { + let expr = quote { [1; 3] }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (expr, length) = expr.as_repeated_element_array().unwrap(); + assert_eq(expr.as_integer().unwrap(), (2, false)); + assert_eq(length.as_integer().unwrap(), (6, false)); + } + } + #[test] fn test_expr_as_repeated_element_slice() { comptime @@ -224,6 +407,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_repeated_element_slice() { + comptime + { + let expr = quote { &[1; 3] }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (expr, length) = expr.as_repeated_element_slice().unwrap(); + assert_eq(expr.as_integer().unwrap(), (2, false)); + assert_eq(length.as_integer().unwrap(), (6, false)); + } + } + #[test] fn test_expr_as_slice() { comptime @@ -237,6 +432,20 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_slice() { + comptime + { + let expr = quote { &[1, 3, 5] }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let elems = expr.as_slice().unwrap(); + assert_eq(elems.len(), 3); + assert_eq(elems[0].as_integer().unwrap(), (2, false)); + assert_eq(elems[1].as_integer().unwrap(), (6, false)); + assert_eq(elems[2].as_integer().unwrap(), (10, false)); + } + } + #[test] fn test_expr_as_tuple() { comptime @@ -247,6 +456,19 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_tuple() { + comptime + { + let expr = quote { (1, 2) }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let tuple_exprs = expr.as_tuple().unwrap(); + assert_eq(tuple_exprs.len(), 2); + assert_eq(tuple_exprs[0].as_integer().unwrap(), (2, false)); + assert_eq(tuple_exprs[1].as_integer().unwrap(), (4, false)); + } + } + #[test] fn test_expr_as_unary_op() { comptime @@ -258,6 +480,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_unary_op() { + comptime + { + let expr = quote { -(1) }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let (op, expr) = expr.as_unary_op().unwrap(); + assert(op.is_minus()); + assert_eq(expr.as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_unsafe() { comptime @@ -268,6 +502,18 @@ mod tests { } } + #[test] + fn test_expr_mutate_for_unsafe() { + comptime + { + let expr = quote { unsafe { 1; 4; 23 } }.as_expr().unwrap(); + let expr = expr.mutate(times_two); + let exprs = expr.as_unsafe().unwrap(); + assert_eq(exprs.len(), 3); + assert_eq(exprs[0].as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_is_break() { comptime @@ -308,6 +554,16 @@ mod tests { let (_, op, _) = expr.as_binary_op().unwrap(); op } + + comptime fn times_two(expr: Expr) -> Option { + expr.as_integer().and_then( + |integer: (Field, bool)| { + let (value, _) = integer; + let value = value * 2; + quote { $value }.as_expr() + } + ) + } } fn main() {} diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index c61f92795ad..f339ed19622 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -449,7 +449,10 @@ impl<'a> NodeFinder<'a> { StatementKind::Semi(expression) => { self.find_in_expression(expression); } - StatementKind::Break | StatementKind::Continue | StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -501,6 +504,7 @@ impl<'a> NodeFinder<'a> { self.find_in_expression(index); } LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), + LValue::Interned(..) => (), } } @@ -565,7 +569,10 @@ impl<'a> NodeFinder<'a> { ExpressionKind::AsTraitPath(as_trait_path) => { self.find_in_as_trait_path(as_trait_path); } - ExpressionKind::Quote(_) | ExpressionKind::Resolved(_) | ExpressionKind::Error => (), + ExpressionKind::Quote(_) + | ExpressionKind::Resolved(_) + | ExpressionKind::Interned(_) + | ExpressionKind::Error => (), } // "foo." (no identifier afterwards) is parsed as the expression on the left hand-side of the dot. @@ -739,6 +746,7 @@ impl<'a> NodeFinder<'a> { | UnresolvedTypeData::Bool | UnresolvedTypeData::Unit | UnresolvedTypeData::Resolved(_) + | UnresolvedTypeData::Interned(_) | UnresolvedTypeData::Error => (), } } diff --git a/tooling/lsp/src/requests/completion/builtins.rs b/tooling/lsp/src/requests/completion/builtins.rs index b9c4ce2358a..430e04aedfd 100644 --- a/tooling/lsp/src/requests/completion/builtins.rs +++ b/tooling/lsp/src/requests/completion/builtins.rs @@ -90,6 +90,7 @@ pub(super) fn keyword_builtin_type(keyword: &Keyword) -> Option<&'static str> { Keyword::Expr => Some("Expr"), Keyword::Field => Some("Field"), Keyword::FunctionDefinition => Some("FunctionDefinition"), + Keyword::Quoted => Some("Quoted"), Keyword::StructDefinition => Some("StructDefinition"), Keyword::TraitConstraint => Some("TraitConstraint"), Keyword::TraitDefinition => Some("TraitDefinition"), @@ -122,7 +123,6 @@ pub(super) fn keyword_builtin_type(keyword: &Keyword) -> Option<&'static str> { | Keyword::Module | Keyword::Mut | Keyword::Pub - | Keyword::Quoted | Keyword::Return | Keyword::ReturnData | Keyword::String diff --git a/tooling/lsp/src/requests/inlay_hint.rs b/tooling/lsp/src/requests/inlay_hint.rs index a1e083187d3..2f6e7dede5d 100644 --- a/tooling/lsp/src/requests/inlay_hint.rs +++ b/tooling/lsp/src/requests/inlay_hint.rs @@ -202,9 +202,10 @@ impl<'a> InlayHintCollector<'a> { } StatementKind::Comptime(statement) => self.collect_in_statement(statement), StatementKind::Semi(expression) => self.collect_in_expression(expression), - StatementKind::Break => (), - StatementKind::Continue => (), - StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -303,6 +304,7 @@ impl<'a> InlayHintCollector<'a> { | ExpressionKind::Variable(..) | ExpressionKind::Quote(..) | ExpressionKind::Resolved(..) + | ExpressionKind::Interned(..) | ExpressionKind::Error => (), } } @@ -692,6 +694,7 @@ fn get_expression_name(expression: &Expression) -> Option { | ExpressionKind::Unquote(..) | ExpressionKind::Comptime(..) | ExpressionKind::Resolved(..) + | ExpressionKind::Interned(..) | ExpressionKind::Literal(..) | ExpressionKind::Unsafe(..) | ExpressionKind::Error => None, diff --git a/tooling/lsp/src/requests/signature_help/traversal.rs b/tooling/lsp/src/requests/signature_help/traversal.rs index 22f92a86124..6a31a22d63a 100644 --- a/tooling/lsp/src/requests/signature_help/traversal.rs +++ b/tooling/lsp/src/requests/signature_help/traversal.rs @@ -125,7 +125,10 @@ impl<'a> SignatureFinder<'a> { StatementKind::Semi(expression) => { self.find_in_expression(expression); } - StatementKind::Break | StatementKind::Continue | StatementKind::Error => (), + StatementKind::Break + | StatementKind::Continue + | StatementKind::Interned(_) + | StatementKind::Error => (), } } @@ -160,6 +163,7 @@ impl<'a> SignatureFinder<'a> { self.find_in_expression(index); } LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), + LValue::Interned(..) => (), } } @@ -232,6 +236,7 @@ impl<'a> SignatureFinder<'a> { | ExpressionKind::AsTraitPath(_) | ExpressionKind::Quote(_) | ExpressionKind::Resolved(_) + | ExpressionKind::Interned(_) | ExpressionKind::Error => (), } } diff --git a/tooling/nargo_fmt/src/rewrite/expr.rs b/tooling/nargo_fmt/src/rewrite/expr.rs index 4fee7d3e197..caa60b17cc2 100644 --- a/tooling/nargo_fmt/src/rewrite/expr.rs +++ b/tooling/nargo_fmt/src/rewrite/expr.rs @@ -175,6 +175,9 @@ pub(crate) fn rewrite( ExpressionKind::Resolved(_) => { unreachable!("ExpressionKind::Resolved should only emitted by the comptime interpreter") } + ExpressionKind::Interned(_) => { + unreachable!("ExpressionKind::Interned should only emitted by the comptime interpreter") + } ExpressionKind::Unquote(expr) => { if matches!(&expr.kind, ExpressionKind::Variable(..)) { format!("${expr}") diff --git a/tooling/nargo_fmt/src/rewrite/typ.rs b/tooling/nargo_fmt/src/rewrite/typ.rs index 8d1e27078a8..6121f8debf6 100644 --- a/tooling/nargo_fmt/src/rewrite/typ.rs +++ b/tooling/nargo_fmt/src/rewrite/typ.rs @@ -73,6 +73,6 @@ pub(crate) fn rewrite(visitor: &FmtVisitor, _shape: Shape, typ: UnresolvedType) | UnresolvedTypeData::FormatString(_, _) | UnresolvedTypeData::Quoted(_) | UnresolvedTypeData::TraitAsType(_, _) => visitor.slice(typ.span).into(), - UnresolvedTypeData::Error => unreachable!(), + UnresolvedTypeData::Interned(_) | UnresolvedTypeData::Error => unreachable!(), } } diff --git a/tooling/nargo_fmt/src/visitor/stmt.rs b/tooling/nargo_fmt/src/visitor/stmt.rs index 8e05fe3f5c5..b5ac14a33b3 100644 --- a/tooling/nargo_fmt/src/visitor/stmt.rs +++ b/tooling/nargo_fmt/src/visitor/stmt.rs @@ -104,6 +104,9 @@ impl super::FmtVisitor<'_> { StatementKind::Break => self.push_rewrite("break;".into(), span), StatementKind::Continue => self.push_rewrite("continue;".into(), span), StatementKind::Comptime(statement) => self.visit_stmt(statement.kind, span, is_last), + StatementKind::Interned(_) => unreachable!( + "StatementKind::Resolved should only emitted by the comptime interpreter" + ), } } } From 062103ea039042e8e999b29dbb1fafc3cebd513c Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 28 Aug 2024 16:19:12 -0500 Subject: [PATCH 14/21] feat(optimization): Avoid merging identical (by ID) arrays (#5853) # Description ## Problem\* Resolves ## Summary\* Found while I was working on Jan's bug. There were a number of arrays we merged where we'd get `vN = if ... then vM else vM` which we had to go through the long merge for without this. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- compiler/noirc_evaluator/src/ssa/ir/instruction.rs | 8 ++++++-- .../src/ssa/opt/flatten_cfg/value_merger.rs | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index c3cd27bf179..36069f17933 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -733,11 +733,15 @@ impl Instruction { } } + let then_value = dfg.resolve(*then_value); + let else_value = dfg.resolve(*else_value); + if then_value == else_value { + return SimplifiedTo(then_value); + } + if matches!(&typ, Type::Numeric(_)) { let then_condition = *then_condition; - let then_value = *then_value; let else_condition = *else_condition; - let else_value = *else_value; let result = ValueMerger::merge_numeric_values( dfg, diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs index 7c2db62b0ea..75ee57dd4fa 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs @@ -58,6 +58,13 @@ impl<'a> ValueMerger<'a> { then_value: ValueId, else_value: ValueId, ) -> ValueId { + let then_value = self.dfg.resolve(then_value); + let else_value = self.dfg.resolve(else_value); + + if then_value == else_value { + return then_value; + } + match self.dfg.type_of_value(then_value) { Type::Numeric(_) => Self::merge_numeric_values( self.dfg, From 0e8becc7bccee2ae4e4e3ef373df08c3e9ef88c9 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Wed, 28 Aug 2024 18:02:57 -0400 Subject: [PATCH 15/21] feat(meta): Comptime keccak (#5854) # Description ## Problem\* Resolves Just under 1/3 of a public transfer in Aztec is due to a keccak hash of the selector. We should ideally be simplifying this through our normal compiler passes as the inputs are all known at compile time. This is difficult at the moment due to having the same compilation pipeline for SSA and ACIR and would require changing how we inline. For an easy win we can enable comptime hashing of keccak. This also just more generally provides another one of our foreign functions in the comptime environment which will be needed in the future. ## Summary\* I added functions in the interpreter for `keccakf1600` and `to_le_radix`. I also included a test under `compile_succes_empty` that calls keccak in a comptime env. The `assert_constant` builtin is also used by our byte decomposition functions. For this we can just return true as we already check that no non-comptime variables are referenced in comptime code. ## Additional Context ## Documentation\* Check one: - [] No documentation needed. - [ ] Documentation included in this PR. - [X] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [X] I have tested the changes locally. - [X] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: jfecher --- .../src/hir/comptime/interpreter/builtin.rs | 41 +++++++++++++++++-- .../interpreter/builtin/builtin_helpers.rs | 10 +++++ .../src/hir/comptime/interpreter/foreign.rs | 31 +++++++++++++- .../comptime_keccak/Nargo.toml | 7 ++++ .../comptime_keccak/src/main.nr | 31 ++++++++++++++ 5 files changed, 114 insertions(+), 6 deletions(-) create mode 100644 test_programs/compile_success_empty/comptime_keccak/Nargo.toml create mode 100644 test_programs/compile_success_empty/comptime_keccak/src/main.nr diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 4b68f82a275..852733b6ca8 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -6,15 +6,17 @@ use std::{ use acvm::{AcirField, FieldElement}; use builtin_helpers::{ block_expression_to_value, check_argument_count, check_function_not_yet_resolved, - check_one_argument, check_three_arguments, check_two_arguments, get_expr, get_function_def, - get_module, get_quoted, get_slice, get_struct, get_trait_constraint, get_trait_def, - get_trait_impl, get_tuple, get_type, get_u32, get_unresolved_type, hir_pattern_to_tokens, - mutate_func_meta_type, parse, replace_func_meta_parameters, replace_func_meta_return_type, + check_one_argument, check_three_arguments, check_two_arguments, get_expr, get_field, + get_function_def, get_module, get_quoted, get_slice, get_struct, get_trait_constraint, + get_trait_def, get_trait_impl, get_tuple, get_type, get_u32, get_unresolved_type, + hir_pattern_to_tokens, mutate_func_meta_type, parse, replace_func_meta_parameters, + replace_func_meta_return_type, }; use chumsky::{prelude::choice, Parser}; use im::Vector; use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; +use num_bigint::BigUint; use rustc_hash::FxHashMap as HashMap; use crate::{ @@ -49,6 +51,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { match name { "array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location), "array_len" => array_len(interner, arguments, location), + "assert_constant" => Ok(Value::Bool(true)), "as_slice" => as_slice(interner, arguments, location), "expr_as_array" => expr_as_array(interner, arguments, return_type, location), "expr_as_assign" => expr_as_assign(interner, arguments, return_type, location), @@ -114,6 +117,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "struct_def_as_type" => struct_def_as_type(interner, arguments, location), "struct_def_fields" => struct_def_fields(interner, arguments, location), "struct_def_generics" => struct_def_generics(interner, arguments, location), + "to_le_radix" => to_le_radix(arguments, location), "trait_constraint_eq" => trait_constraint_eq(interner, arguments, location), "trait_constraint_hash" => trait_constraint_hash(interner, arguments, location), "trait_def_as_trait_constraint" => { @@ -425,6 +429,35 @@ fn quoted_as_type( Ok(Value::Type(typ)) } +fn to_le_radix(arguments: Vec<(Value, Location)>, location: Location) -> IResult { + let (value, radix, limb_count) = check_three_arguments(arguments, location)?; + + let value = get_field(value)?; + let radix = get_u32(radix)?; + let limb_count = get_u32(limb_count)?; + + // Decompose the integer into its radix digits in little endian form. + let decomposed_integer = compute_to_radix(value, radix); + let decomposed_integer = vecmap(0..limb_count as usize, |i| match decomposed_integer.get(i) { + Some(digit) => Value::U8(*digit), + None => Value::U8(0), + }); + Ok(Value::Array( + decomposed_integer.into(), + Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + )) +} + +fn compute_to_radix(field: FieldElement, radix: u32) -> Vec { + let bit_size = u32::BITS - (radix - 1).leading_zeros(); + let radix_big = BigUint::from(radix); + assert_eq!(BigUint::from(2u128).pow(bit_size), radix_big, "ICE: Radix must be a power of 2"); + let big_integer = BigUint::from_bytes_be(&field.to_be_bytes()); + + // Decompose the integer into its radix digits in little endian form. + big_integer.to_radix_le(radix) +} + // fn as_array(self) -> Option<(Type, Type)> fn type_as_array( arguments: Vec<(Value, Location)>, diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index 809a54ecb44..2e06240e995 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -145,6 +145,16 @@ pub(crate) fn get_u32((value, location): (Value, Location)) -> IResult { } } +pub(crate) fn get_u64((value, location): (Value, Location)) -> IResult { + match value { + Value::U64(value) => Ok(value), + value => { + let expected = Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour); + type_mismatch(value, expected, location) + } + } +} + pub(crate) fn get_expr( interner: &NodeInterner, (value, location): (Value, Location), diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs index f7caf84ec42..5ae60bb4d00 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs @@ -1,5 +1,6 @@ -use acvm::BlackBoxFunctionSolver; +use acvm::blackbox_solver::BlackBoxFunctionSolver; use bn254_blackbox_solver::Bn254BlackBoxSolver; +use im::Vector; use iter_extended::try_vecmap; use noirc_errors::Location; @@ -8,7 +9,9 @@ use crate::{ macros_api::NodeInterner, }; -use super::builtin::builtin_helpers::{check_two_arguments, get_array, get_field, get_u32}; +use super::builtin::builtin_helpers::{ + check_one_argument, check_two_arguments, get_array, get_field, get_u32, get_u64, +}; pub(super) fn call_foreign( interner: &mut NodeInterner, @@ -18,6 +21,7 @@ pub(super) fn call_foreign( ) -> IResult { match name { "poseidon2_permutation" => poseidon2_permutation(interner, arguments, location), + "keccakf1600" => keccakf1600(interner, arguments, location), _ => { let item = format!("Comptime evaluation for builtin function {name}"); Err(InterpreterError::Unimplemented { item, location }) @@ -47,3 +51,26 @@ fn poseidon2_permutation( let array = fields.into_iter().map(Value::Field).collect(); Ok(Value::Array(array, typ)) } + +fn keccakf1600( + interner: &mut NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let input = check_one_argument(arguments, location)?; + let input_location = input.1; + + let (input, typ) = get_array(interner, input)?; + + let input = try_vecmap(input, |integer| get_u64((integer, input_location)))?; + + let mut state = [0u64; 25]; + for (it, input_value) in state.iter_mut().zip(input.iter()) { + *it = *input_value; + } + let result_lanes = acvm::blackbox_solver::keccakf1600(state) + .map_err(|error| InterpreterError::BlackBoxError(error, location))?; + + let array: Vector = result_lanes.into_iter().map(Value::U64).collect(); + Ok(Value::Array(array, typ)) +} diff --git a/test_programs/compile_success_empty/comptime_keccak/Nargo.toml b/test_programs/compile_success_empty/comptime_keccak/Nargo.toml new file mode 100644 index 00000000000..47c8654804d --- /dev/null +++ b/test_programs/compile_success_empty/comptime_keccak/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "comptime_keccak" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/comptime_keccak/src/main.nr b/test_programs/compile_success_empty/comptime_keccak/src/main.nr new file mode 100644 index 00000000000..3cde32b6ba9 --- /dev/null +++ b/test_programs/compile_success_empty/comptime_keccak/src/main.nr @@ -0,0 +1,31 @@ +// Tests a very simple program. +// +// The features being tested is keccak256 in brillig +fn main() { + comptime + { + let x = 0xbd; + let result = [ + 0x5a, 0x50, 0x2f, 0x9f, 0xca, 0x46, 0x7b, 0x26, 0x6d, 0x5b, 0x78, 0x33, 0x65, 0x19, 0x37, 0xe8, 0x05, 0x27, 0x0c, 0xa3, 0xf3, 0xaf, 0x1c, 0x0d, 0xd2, 0x46, 0x2d, 0xca, 0x4b, 0x3b, 0x1a, 0xbf + ]; + // We use the `as` keyword here to denote the fact that we want to take just the first byte from the x Field + // The padding is taken care of by the program + let digest = keccak256([x as u8], 1); + assert(digest == result); + //#1399: variable message size + let message_size = 4; + let hash_a = keccak256([1, 2, 3, 4], message_size); + let hash_b = keccak256([1, 2, 3, 4, 0, 0, 0, 0], message_size); + + assert(hash_a == hash_b); + + let message_size_big = 8; + let hash_c = keccak256([1, 2, 3, 4, 0, 0, 0, 0], message_size_big); + + assert(hash_a != hash_c); + } +} + +comptime fn keccak256(data: [u8; N], msg_len: u32) -> [u8; 32] { + std::hash::keccak256(data, msg_len) +} From 516833f54beb8896c179cc76095382c5ce3725ca Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:40:54 +0200 Subject: [PATCH 16/21] chore: underconstrained check in parallel (#5848) # Description ## Problem\* underconstrained check can add significant time to the compilation process ## Summary\* perform the underconstrained check in parallel. The PR parallelises the check per function. In the example I tried, it gives no benefit as all the work is done in one function. Then the PR also parallelise the 'merge sets' by doing it over chunks. In the example I tried, it reduces the duration from 77s down to 1s. ## Additional Context ## Documentation\* Check one: - [X] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [X] I have tested the changes locally. - [X] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: TomAFrench Co-authored-by: jfecher --- Cargo.lock | 1 + Cargo.toml | 1 + compiler/noirc_evaluator/Cargo.toml | 1 + .../src/brillig/brillig_gen/brillig_block.rs | 4 +- .../check_for_underconstrained_values.rs | 57 ++++++++++++------- .../src/ssa/function_builder/data_bus.rs | 7 +-- .../src/ssa/function_builder/mod.rs | 8 +-- .../src/ssa/ir/instruction/call.rs | 6 +- .../src/ssa/ir/instruction/call/blackbox.rs | 4 +- compiler/noirc_evaluator/src/ssa/ir/map.rs | 2 +- compiler/noirc_evaluator/src/ssa/ir/types.rs | 12 ++-- .../src/ssa/opt/constant_folding.rs | 10 ++-- .../src/ssa/opt/flatten_cfg.rs | 8 +-- .../noirc_evaluator/src/ssa/opt/mem2reg.rs | 8 +-- compiler/noirc_evaluator/src/ssa/opt/rc.rs | 12 ++-- .../src/ssa/opt/remove_bit_shifts.rs | 4 +- .../src/ssa/ssa_gen/context.rs | 11 ++-- tooling/lsp/Cargo.toml | 2 +- tooling/nargo/Cargo.toml | 2 +- tooling/nargo_cli/Cargo.toml | 2 +- 20 files changed, 91 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2cf79c40303..279d0b59ce1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3013,6 +3013,7 @@ dependencies = [ "noirc_frontend", "num-bigint", "proptest", + "rayon", "serde", "serde_json", "serde_with", diff --git a/Cargo.toml b/Cargo.toml index bf5739ebbe8..a903ef6fec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -154,6 +154,7 @@ color-eyre = "0.6.2" rand = "0.8.5" proptest = "1.2.0" proptest-derive = "0.4.0" +rayon = "1.8.0" im = { version = "15.1", features = ["serde"] } tracing = "0.1.40" diff --git a/compiler/noirc_evaluator/Cargo.toml b/compiler/noirc_evaluator/Cargo.toml index 3bc7f544170..1db6af2ae85 100644 --- a/compiler/noirc_evaluator/Cargo.toml +++ b/compiler/noirc_evaluator/Cargo.toml @@ -26,6 +26,7 @@ serde_json.workspace = true serde_with = "3.2.0" tracing.workspace = true chrono = "0.4.37" +rayon.workspace = true cfg-if.workspace = true [dev-dependencies] diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index a9801c8904e..1e672eeea3c 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -22,7 +22,7 @@ use acvm::{acir::AcirField, FieldElement}; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use iter_extended::vecmap; use num_bigint::BigUint; -use std::rc::Rc; +use std::sync::Arc; use super::brillig_black_box::convert_black_box_call; use super::brillig_block_variables::BlockVariables; @@ -1701,7 +1701,7 @@ impl<'block> BrilligBlock<'block> { fn initialize_constant_array_runtime( &mut self, - item_types: Rc>, + item_types: Arc>, item_to_repeat: Vec, item_count: usize, pointer: MemoryAddress, diff --git a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs index 79db4e645ee..26eab290d4b 100644 --- a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs +++ b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs @@ -1,8 +1,6 @@ //! This module defines an SSA pass that detects if the final function has any subgraphs independent from inputs and outputs. //! If this is the case, then part of the final circuit can be completely replaced by any other passing circuit, since there are no constraints ensuring connections. //! So the compiler informs the developer of this as a bug -use im::HashMap; - use crate::errors::{InternalBug, SsaReport}; use crate::ssa::ir::basic_block::BasicBlockId; use crate::ssa::ir::function::RuntimeType; @@ -10,25 +8,29 @@ use crate::ssa::ir::function::{Function, FunctionId}; use crate::ssa::ir::instruction::{Instruction, InstructionId, Intrinsic}; use crate::ssa::ir::value::{Value, ValueId}; use crate::ssa::ssa_gen::Ssa; +use im::HashMap; +use rayon::prelude::*; use std::collections::{BTreeMap, HashSet}; impl Ssa { /// Go through each top-level non-brillig function and detect if it has independent subgraphs #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn check_for_underconstrained_values(&mut self) -> Vec { - let mut warnings: Vec = Vec::new(); - for function in self.functions.values() { - match function.runtime() { - RuntimeType::Acir { .. } => { - warnings.extend(check_for_underconstrained_values_within_function( - function, + let functions_id = self.functions.values().map(|f| f.id().to_usize()).collect::>(); + functions_id + .iter() + .par_bridge() + .flat_map(|fid| { + let function_to_process = &self.functions[&FunctionId::new(*fid)]; + match function_to_process.runtime() { + RuntimeType::Acir { .. } => check_for_underconstrained_values_within_function( + function_to_process, &self.functions, - )); + ), + RuntimeType::Brillig => Vec::new(), } - RuntimeType::Brillig => (), - } - } - warnings + }) + .collect() } } @@ -88,9 +90,8 @@ impl Context { self.visited_blocks.insert(block); self.connect_value_ids_in_block(function, block, all_functions); } - // Merge ValueIds into sets, where each original small set of ValueIds is merged with another set if they intersect - self.merge_sets(); + self.value_sets = Self::merge_sets_par(&self.value_sets); } /// Find sets that contain input or output value of the function @@ -267,14 +268,13 @@ impl Context { /// Merge all small sets into larger ones based on whether the sets intersect or not /// /// If two small sets have a common ValueId, we merge them into one - fn merge_sets(&mut self) { + fn merge_sets(current: &[HashSet]) -> Vec> { let mut new_set_id: usize = 0; let mut updated_sets: HashMap> = HashMap::new(); let mut value_dictionary: HashMap = HashMap::new(); let mut parsed_value_set: HashSet = HashSet::new(); - // Go through each set - for set in self.value_sets.iter() { + for set in current.iter() { // Check if the set has any of the ValueIds we've encountered at previous iterations let intersection: HashSet = set.intersection(&parsed_value_set).copied().collect(); @@ -327,7 +327,26 @@ impl Context { } updated_sets.insert(largest_set_index, largest_set); } - self.value_sets = updated_sets.values().cloned().collect(); + updated_sets.values().cloned().collect() + } + + /// Parallel version of merge_sets + /// The sets are merged by chunks, and then the chunks are merged together + fn merge_sets_par(sets: &[HashSet]) -> Vec> { + let mut sets = sets.to_owned(); + let mut len = sets.len(); + let mut prev_len = len + 1; + + while len > 1000 && len < prev_len { + sets = sets.par_chunks(1000).flat_map(Self::merge_sets).collect(); + + prev_len = len; + len = sets.len(); + } + // TODO: if prev_len >= len, this means we cannot effectively merge the sets anymore + // We should instead partition the sets into disjoint chunks and work on those chunks, + // but for now we fallback to the non-parallel implementation + Self::merge_sets(&sets) } } #[cfg(test)] diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs index 9f964cf048d..38895bb977e 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs @@ -1,5 +1,4 @@ -use std::collections::BTreeMap; -use std::rc::Rc; +use std::{collections::BTreeMap, sync::Arc}; use crate::ssa::ir::{types::Type, value::ValueId}; use acvm::FieldElement; @@ -155,8 +154,8 @@ impl FunctionBuilder { let len = databus.values.len(); let array = if len > 0 { - let array = - self.array_constant(databus.values, Type::Array(Rc::new(vec![Type::field()]), len)); + let array = self + .array_constant(databus.values, Type::Array(Arc::new(vec![Type::field()]), len)); Some(array) } else { None diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 8cc42241d92..bf6430c36d7 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod data_bus; -use std::{borrow::Cow, collections::BTreeMap, rc::Rc}; +use std::{borrow::Cow, collections::BTreeMap, sync::Arc}; use acvm::{acir::circuit::ErrorSelector, FieldElement}; use noirc_errors::Location; @@ -189,7 +189,7 @@ impl FunctionBuilder { /// given amount of field elements. Returns the result of the allocate instruction, /// which is always a Reference to the allocated data. pub(crate) fn insert_allocate(&mut self, element_type: Type) -> ValueId { - let reference_type = Type::Reference(Rc::new(element_type)); + let reference_type = Type::Reference(Arc::new(element_type)); self.insert_instruction(Instruction::Allocate, Some(vec![reference_type])).first() } @@ -516,7 +516,7 @@ impl std::ops::Index for FunctionBuilder { #[cfg(test)] mod tests { - use std::rc::Rc; + use std::sync::Arc; use acvm::{acir::AcirField, FieldElement}; @@ -542,7 +542,7 @@ mod tests { let to_bits_id = builder.import_intrinsic_id(Intrinsic::ToBits(Endian::Little)); let input = builder.numeric_constant(FieldElement::from(7_u128), Type::field()); let length = builder.numeric_constant(FieldElement::from(8_u128), Type::field()); - let result_types = vec![Type::Array(Rc::new(vec![Type::bool()]), 8)]; + let result_types = vec![Type::Array(Arc::new(vec![Type::bool()]), 8)]; let call_results = builder.insert_call(to_bits_id, vec![input, length], result_types).into_owned(); diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index de7ab6e532d..2c6aedeca35 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -1,5 +1,5 @@ use fxhash::FxHashMap as HashMap; -use std::{collections::VecDeque, rc::Rc}; +use std::{collections::VecDeque, sync::Arc}; use acvm::{ acir::{AcirField, BlackBoxFunc}, @@ -561,7 +561,7 @@ fn simplify_black_box_func( fn make_constant_array(dfg: &mut DataFlowGraph, results: Vec, typ: Type) -> ValueId { let result_constants = vecmap(results, |element| dfg.make_constant(element, typ.clone())); - let typ = Type::Array(Rc::new(vec![typ]), result_constants.len()); + let typ = Type::Array(Arc::new(vec![typ]), result_constants.len()); dfg.make_array(result_constants.into(), typ) } @@ -572,7 +572,7 @@ fn make_constant_slice( ) -> (ValueId, ValueId) { let result_constants = vecmap(results, |element| dfg.make_constant(element, typ.clone())); - let typ = Type::Slice(Rc::new(vec![typ])); + let typ = Type::Slice(Arc::new(vec![typ])); let length = FieldElement::from(result_constants.len() as u128); (dfg.make_constant(length, Type::length_type()), dfg.make_array(result_constants.into(), typ)) } diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs index 706e8891cde..7789b212e58 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call/blackbox.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::sync::Arc; use acvm::{acir::AcirField, BlackBoxFunctionSolver, BlackBoxResolutionError, FieldElement}; use iter_extended::vecmap; @@ -45,7 +45,7 @@ pub(super) fn simplify_ec_add( let result_y = dfg.make_constant(result_y, Type::field()); let result_is_infinity = dfg.make_constant(result_is_infinity, Type::bool()); - let typ = Type::Array(Rc::new(vec![Type::field()]), 3); + let typ = Type::Array(Arc::new(vec![Type::field()]), 3); let result_array = dfg.make_array(im::vector![result_x, result_y, result_is_infinity], typ); diff --git a/compiler/noirc_evaluator/src/ssa/ir/map.rs b/compiler/noirc_evaluator/src/ssa/ir/map.rs index f1265553b83..769d52e6e65 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/map.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/map.rs @@ -27,7 +27,7 @@ impl Id { /// Constructs a new Id for the given index. /// This constructor is deliberately private to prevent /// constructing invalid IDs. - fn new(index: usize) -> Self { + pub(crate) fn new(index: usize) -> Self { Self { index, _marker: std::marker::PhantomData } } diff --git a/compiler/noirc_evaluator/src/ssa/ir/types.rs b/compiler/noirc_evaluator/src/ssa/ir/types.rs index e467fa5400d..b7ee37ba17a 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/types.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/types.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use std::rc::Rc; +use std::sync::Arc; use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; @@ -72,13 +72,13 @@ pub(crate) enum Type { Numeric(NumericType), /// A reference to some value, such as an array - Reference(Rc), + Reference(Arc), /// An immutable array value with the given element type and length - Array(Rc, usize), + Array(Arc, usize), /// An immutable slice value with a given element type - Slice(Rc), + Slice(Arc), /// A function that may be called directly Function, @@ -112,7 +112,7 @@ impl Type { /// Creates the str type, of the given length N pub(crate) fn str(length: usize) -> Type { - Type::Array(Rc::new(vec![Type::char()]), length) + Type::Array(Arc::new(vec![Type::char()]), length) } /// Creates the native field type. @@ -190,7 +190,7 @@ impl Type { } } - pub(crate) fn element_types(self) -> Rc> { + pub(crate) fn element_types(self) -> Arc> { match self { Type::Array(element_types, _) | Type::Slice(element_types) => element_types, other => panic!("element_types: Expected array or slice, found {other}"), diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index c8f6d201d86..ff9a63c8d79 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -311,7 +311,7 @@ impl Context { #[cfg(test)] mod test { - use std::rc::Rc; + use std::sync::Arc; use crate::ssa::{ function_builder::FunctionBuilder, @@ -509,7 +509,7 @@ mod test { let one = builder.field_constant(1u128); let v1 = builder.insert_binary(v0, BinaryOp::Add, one); - let array_type = Type::Array(Rc::new(vec![Type::field()]), 1); + let array_type = Type::Array(Arc::new(vec![Type::field()]), 1); let arr = builder.current_function.dfg.make_array(vec![v1].into(), array_type); builder.terminate_with_return(vec![arr]); @@ -601,7 +601,7 @@ mod test { // Compiling main let mut builder = FunctionBuilder::new("main".into(), main_id); - let v0 = builder.add_parameter(Type::Array(Rc::new(vec![Type::field()]), 4)); + let v0 = builder.add_parameter(Type::Array(Arc::new(vec![Type::field()]), 4)); let v1 = builder.add_parameter(Type::unsigned(32)); let v2 = builder.add_parameter(Type::unsigned(1)); let v3 = builder.add_parameter(Type::unsigned(1)); @@ -737,7 +737,7 @@ mod test { let zero = builder.field_constant(0u128); let one = builder.field_constant(1u128); - let typ = Type::Array(Rc::new(vec![Type::field()]), 2); + let typ = Type::Array(Arc::new(vec![Type::field()]), 2); let array = builder.array_constant(vec![zero, one].into(), typ); let _v2 = builder.insert_array_get(array, v1, Type::field()); @@ -787,7 +787,7 @@ mod test { let v0 = builder.add_parameter(Type::bool()); let v1 = builder.add_parameter(Type::bool()); - let v2 = builder.add_parameter(Type::Array(Rc::new(vec![Type::field()]), 2)); + let v2 = builder.add_parameter(Type::Array(Arc::new(vec![Type::field()]), 2)); let zero = builder.numeric_constant(0u128, Type::length_type()); let one = builder.numeric_constant(1u128, Type::length_type()); diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 72ed02b00a8..d5fb98c7adc 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -878,7 +878,7 @@ impl<'f> Context<'f> { #[cfg(test)] mod test { - use std::rc::Rc; + use std::sync::Arc; use acvm::acir::AcirField; @@ -1016,7 +1016,7 @@ mod test { let b2 = builder.insert_block(); let v0 = builder.add_parameter(Type::bool()); - let v1 = builder.add_parameter(Type::Reference(Rc::new(Type::field()))); + let v1 = builder.add_parameter(Type::Reference(Arc::new(Type::field()))); builder.terminate_with_jmpif(v0, b1, b2); @@ -1078,7 +1078,7 @@ mod test { let b3 = builder.insert_block(); let v0 = builder.add_parameter(Type::bool()); - let v1 = builder.add_parameter(Type::Reference(Rc::new(Type::field()))); + let v1 = builder.add_parameter(Type::Reference(Arc::new(Type::field()))); builder.terminate_with_jmpif(v0, b1, b2); @@ -1477,7 +1477,7 @@ mod test { let b2 = builder.insert_block(); let b3 = builder.insert_block(); - let element_type = Rc::new(vec![Type::unsigned(8)]); + let element_type = Arc::new(vec![Type::unsigned(8)]); let array_type = Type::Array(element_type.clone(), 2); let array = builder.add_parameter(array_type); diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index e5a25dcfef1..9d6582c0db7 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -425,7 +425,7 @@ impl<'f> PerFunctionContext<'f> { #[cfg(test)] mod tests { - use std::rc::Rc; + use std::sync::Arc; use acvm::{acir::AcirField, FieldElement}; use im::vector; @@ -454,11 +454,11 @@ mod tests { let func_id = Id::test_new(0); let mut builder = FunctionBuilder::new("func".into(), func_id); - let v0 = builder.insert_allocate(Type::Array(Rc::new(vec![Type::field()]), 2)); + let v0 = builder.insert_allocate(Type::Array(Arc::new(vec![Type::field()]), 2)); let one = builder.field_constant(FieldElement::one()); let two = builder.field_constant(FieldElement::one()); - let element_type = Rc::new(vec![Type::field()]); + let element_type = Arc::new(vec![Type::field()]); let array_type = Type::Array(element_type, 2); let array = builder.array_constant(vector![one, two], array_type.clone()); @@ -672,7 +672,7 @@ mod tests { let zero = builder.field_constant(0u128); builder.insert_store(v0, zero); - let v2 = builder.insert_allocate(Type::Reference(Rc::new(Type::field()))); + let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field()))); builder.insert_store(v2, v0); let v3 = builder.insert_load(v2, Type::field()); diff --git a/compiler/noirc_evaluator/src/ssa/opt/rc.rs b/compiler/noirc_evaluator/src/ssa/opt/rc.rs index 1561547e32e..4f109a27874 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/rc.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/rc.rs @@ -161,7 +161,7 @@ fn remove_instructions(to_remove: HashSet, function: &mut Functio #[cfg(test)] mod test { - use std::rc::Rc; + use std::sync::Arc; use crate::ssa::{ function_builder::FunctionBuilder, @@ -209,14 +209,14 @@ mod test { let mut builder = FunctionBuilder::new("foo".into(), main_id); builder.set_runtime(RuntimeType::Brillig); - let inner_array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let inner_array_type = Type::Array(Arc::new(vec![Type::field()]), 2); let v0 = builder.add_parameter(inner_array_type.clone()); builder.insert_inc_rc(v0); builder.insert_inc_rc(v0); builder.insert_dec_rc(v0); - let outer_array_type = Type::Array(Rc::new(vec![inner_array_type]), 1); + let outer_array_type = Type::Array(Arc::new(vec![inner_array_type]), 1); let array = builder.array_constant(vec![v0].into(), outer_array_type); builder.terminate_with_return(vec![array]); @@ -248,7 +248,7 @@ mod test { let main_id = Id::test_new(0); let mut builder = FunctionBuilder::new("mutator".into(), main_id); - let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let array_type = Type::Array(Arc::new(vec![Type::field()]), 2); let v0 = builder.add_parameter(array_type.clone()); let v1 = builder.insert_allocate(array_type.clone()); @@ -297,8 +297,8 @@ mod test { let main_id = Id::test_new(0); let mut builder = FunctionBuilder::new("mutator2".into(), main_id); - let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); - let reference_type = Type::Reference(Rc::new(array_type.clone())); + let array_type = Type::Array(Arc::new(vec![Type::field()]), 2); + let reference_type = Type::Reference(Arc::new(array_type.clone())); let v0 = builder.add_parameter(reference_type); diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index 628e1bd7410..6ca7a76d740 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, rc::Rc}; +use std::{borrow::Cow, sync::Arc}; use acvm::{acir::AcirField, FieldElement}; @@ -174,7 +174,7 @@ impl Context<'_> { let to_bits = self.function.dfg.import_intrinsic(Intrinsic::ToBits(Endian::Little)); let length = self.field_constant(FieldElement::from(bit_size as i128)); let result_types = - vec![Type::field(), Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)]; + vec![Type::field(), Type::Array(Arc::new(vec![Type::bool()]), bit_size as usize)]; let rhs_bits = self.insert_call(to_bits, vec![rhs, length], result_types); let rhs_bits = rhs_bits[1]; diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index 13e5c2445ad..fb7091a8854 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -1,5 +1,4 @@ -use std::rc::Rc; -use std::sync::{Mutex, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; @@ -198,7 +197,7 @@ impl<'a> FunctionContext<'a> { // A mutable reference wraps each element into a reference. // This can be multiple values if the element type is a tuple. ast::Type::MutableReference(element) => { - Self::map_type_helper(element, &mut |typ| f(Type::Reference(Rc::new(typ)))) + Self::map_type_helper(element, &mut |typ| f(Type::Reference(Arc::new(typ)))) } ast::Type::FmtString(len, fields) => { // A format string is represented by multiple values @@ -213,7 +212,7 @@ impl<'a> FunctionContext<'a> { let element_types = Self::convert_type(elements).flatten(); Tree::Branch(vec![ Tree::Leaf(f(Type::length_type())), - Tree::Leaf(f(Type::Slice(Rc::new(element_types)))), + Tree::Leaf(f(Type::Slice(Arc::new(element_types)))), ]) } other => Tree::Leaf(f(Self::convert_non_tuple_type(other))), @@ -237,7 +236,7 @@ impl<'a> FunctionContext<'a> { ast::Type::Field => Type::field(), ast::Type::Array(len, element) => { let element_types = Self::convert_type(element).flatten(); - Type::Array(Rc::new(element_types), *len as usize) + Type::Array(Arc::new(element_types), *len as usize) } ast::Type::Integer(Signedness::Signed, bits) => Type::signed((*bits).into()), ast::Type::Integer(Signedness::Unsigned, bits) => Type::unsigned((*bits).into()), @@ -253,7 +252,7 @@ impl<'a> FunctionContext<'a> { ast::Type::MutableReference(element) => { // Recursive call to panic if element is a tuple let element = Self::convert_non_tuple_type(element); - Type::Reference(Rc::new(element)) + Type::Reference(Arc::new(element)) } } } diff --git a/tooling/lsp/Cargo.toml b/tooling/lsp/Cargo.toml index 353a6ade904..c15895d801f 100644 --- a/tooling/lsp/Cargo.toml +++ b/tooling/lsp/Cargo.toml @@ -31,7 +31,7 @@ async-lsp = { workspace = true, features = ["omni-trait"] } serde_with = "3.2.0" thiserror.workspace = true fm.workspace = true -rayon = "1.8.0" +rayon.workspace = true fxhash.workspace = true convert_case = "0.6.0" diff --git a/tooling/nargo/Cargo.toml b/tooling/nargo/Cargo.toml index 046eca88099..c5d4bbc9788 100644 --- a/tooling/nargo/Cargo.toml +++ b/tooling/nargo/Cargo.toml @@ -23,7 +23,7 @@ noirc_printable_type.workspace = true iter-extended.workspace = true thiserror.workspace = true tracing.workspace = true -rayon = "1.8.0" +rayon.workspace = true jsonrpc.workspace = true rand.workspace = true serde.workspace = true diff --git a/tooling/nargo_cli/Cargo.toml b/tooling/nargo_cli/Cargo.toml index 4e3f3a57e87..284be56d247 100644 --- a/tooling/nargo_cli/Cargo.toml +++ b/tooling/nargo_cli/Cargo.toml @@ -42,7 +42,7 @@ toml.workspace = true serde.workspace = true serde_json.workspace = true prettytable-rs = "0.10" -rayon = "1.8.0" +rayon.workspace = true thiserror.workspace = true tower.workspace = true async-lsp = { workspace = true, features = ["client-monitor", "stdio", "tracing", "tokio"] } From 4e4ad26d56e6a487ca446ea4e1732c6af04e1410 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 29 Aug 2024 12:06:38 -0300 Subject: [PATCH 17/21] feat: add `Expr::as_assert` (#5857) # Description ## Problem Part of #5668 ## Summary Adds `Expr::as_assert` and handles it in `Expr::mutate`. ## Additional Context None. ## Documentation Check one: - [ ] No documentation needed. - [x] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../src/hir/comptime/interpreter/builtin.rs | 35 +++- docs/docs/noir/standard_library/meta/expr.md | 12 +- noir_stdlib/src/meta/expr.nr | 156 ++++++++++-------- .../inject_context_attribute/src/main.nr | 2 +- .../comptime_expr/src/main.nr | 72 +++++--- 5 files changed, 188 insertions(+), 89 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 852733b6ca8..bc8f473e08d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -21,7 +21,7 @@ use rustc_hash::FxHashMap as HashMap; use crate::{ ast::{ - ArrayLiteral, BlockExpression, Expression, ExpressionKind, FunctionKind, + ArrayLiteral, BlockExpression, ConstrainKind, Expression, ExpressionKind, FunctionKind, FunctionReturnType, IntegerBitSize, LValue, Literal, Statement, StatementKind, UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, }, @@ -54,6 +54,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "assert_constant" => Ok(Value::Bool(true)), "as_slice" => as_slice(interner, arguments, location), "expr_as_array" => expr_as_array(interner, arguments, return_type, location), + "expr_as_assert" => expr_as_assert(interner, arguments, return_type, location), "expr_as_assign" => expr_as_assign(interner, arguments, return_type, location), "expr_as_binary_op" => expr_as_binary_op(interner, arguments, return_type, location), "expr_as_block" => expr_as_block(interner, arguments, return_type, location), @@ -863,6 +864,38 @@ fn expr_as_array( }) } +// fn as_assert(self) -> Option<(Expr, Option)> +fn expr_as_assert( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(interner, arguments, return_type.clone(), location, |expr| { + if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr { + if constrain.2 == ConstrainKind::Assert { + let predicate = Value::expression(constrain.0.kind); + + let option_type = extract_option_generic_type(return_type); + let Type::Tuple(mut tuple_types) = option_type else { + panic!("Expected the return type option generic arg to be a tuple"); + }; + assert_eq!(tuple_types.len(), 2); + + let option_type = tuple_types.pop().unwrap(); + let message = constrain.1.map(|message| Value::expression(message.kind)); + let message = option(option_type, message).ok()?; + + Some(Value::Tuple(vec![predicate, message])) + } else { + None + } + } else { + None + } + }) +} + // fn as_assign(self) -> Option<(Expr, Expr)> fn expr_as_assign( interner: &NodeInterner, diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md index d421e8b56a3..7d3d1c2f453 100644 --- a/docs/docs/noir/standard_library/meta/expr.md +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -12,6 +12,12 @@ title: Expr If this expression is an array, this returns a slice of each element in the array. +### as_assert + +#include_code as_assert noir_stdlib/src/meta/expr.nr rust + +If this expression is an assert, this returns the assert expression and the optional message. + ### as_assign #include_code as_assign noir_stdlib/src/meta/expr.nr rust @@ -162,16 +168,16 @@ comptime { `true` if this expression is `continue`. -### mutate +### modify -#include_code mutate noir_stdlib/src/meta/expr.nr rust +#include_code modify noir_stdlib/src/meta/expr.nr rust Applies a mapping function to this expression and to all of its sub-expressions. `f` will be applied to each sub-expression first, then applied to the expression itself. This happens recursively for every expression within `self`. -For example, calling `mutate` on `(&[1], &[2, 3])` with an `f` that returns `Option::some` +For example, calling `modify` on `(&[1], &[2, 3])` with an `f` that returns `Option::some` for expressions that are integers, doubling them, would return `(&[2], &[4, 6])`. ### quoted diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index c09d9b92c9b..838c96570b5 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -8,6 +8,11 @@ impl Expr { fn as_array(self) -> Option<[Expr]> {} // docs:end:as_array + #[builtin(expr_as_assert)] + // docs:start:as_assert + fn as_assert(self) -> Option<(Expr, Option)> {} + // docs:end:as_assert + #[builtin(expr_as_assign)] // docs:start:as_assign fn as_assign(self) -> Option<(Expr, Expr)> {} @@ -111,26 +116,27 @@ impl Expr { fn is_continue(self) -> bool {} // docs:end:is_continue - // docs:start:mutate - fn mutate(self, f: fn[Env](Expr) -> Option) -> Expr { - // docs:end:mutate - let result = mutate_array(self, f); - let result = result.or_else(|| mutate_assign(self, f)); - let result = result.or_else(|| mutate_binary_op(self, f)); - let result = result.or_else(|| mutate_block(self, f)); - let result = result.or_else(|| mutate_cast(self, f)); - let result = result.or_else(|| mutate_comptime(self, f)); - let result = result.or_else(|| mutate_if(self, f)); - let result = result.or_else(|| mutate_index(self, f)); - let result = result.or_else(|| mutate_function_call(self, f)); - let result = result.or_else(|| mutate_member_access(self, f)); - let result = result.or_else(|| mutate_method_call(self, f)); - let result = result.or_else(|| mutate_repeated_element_array(self, f)); - let result = result.or_else(|| mutate_repeated_element_slice(self, f)); - let result = result.or_else(|| mutate_slice(self, f)); - let result = result.or_else(|| mutate_tuple(self, f)); - let result = result.or_else(|| mutate_unary_op(self, f)); - let result = result.or_else(|| mutate_unsafe(self, f)); + // docs:start:modify + fn modify(self, f: fn[Env](Expr) -> Option) -> Expr { + // docs:end:modify + let result = modify_array(self, f); + let result = result.or_else(|| modify_assert(self, f)); + let result = result.or_else(|| modify_assign(self, f)); + let result = result.or_else(|| modify_binary_op(self, f)); + let result = result.or_else(|| modify_block(self, f)); + let result = result.or_else(|| modify_cast(self, f)); + let result = result.or_else(|| modify_comptime(self, f)); + let result = result.or_else(|| modify_if(self, f)); + let result = result.or_else(|| modify_index(self, f)); + let result = result.or_else(|| modify_function_call(self, f)); + let result = result.or_else(|| modify_member_access(self, f)); + let result = result.or_else(|| modify_method_call(self, f)); + let result = result.or_else(|| modify_repeated_element_array(self, f)); + let result = result.or_else(|| modify_repeated_element_slice(self, f)); + let result = result.or_else(|| modify_slice(self, f)); + let result = result.or_else(|| modify_tuple(self, f)); + let result = result.or_else(|| modify_unary_op(self, f)); + let result = result.or_else(|| modify_unsafe(self, f)); if result.is_some() { let result = result.unwrap_unchecked(); let modified = f(result); @@ -147,181 +153,192 @@ impl Expr { } } -fn mutate_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_array().map( |exprs: [Expr]| { - let exprs = mutate_expressions(exprs, f); + let exprs = modify_expressions(exprs, f); new_array(exprs) } ) } -fn mutate_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_assert(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_assert().map( + |expr: (Expr, Option)| { + let (predicate, msg) = expr; + let predicate = predicate.modify(f); + let msg = msg.map(|msg: Expr| msg.modify(f)); + new_assert(predicate, msg) + } + ) +} + +fn modify_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_assign().map( |expr: (Expr, Expr)| { let (lhs, rhs) = expr; - let lhs = lhs.mutate(f); - let rhs = rhs.mutate(f); + let lhs = lhs.modify(f); + let rhs = rhs.modify(f); new_assign(lhs, rhs) } ) } -fn mutate_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_binary_op().map( |expr: (Expr, BinaryOp, Expr)| { let (lhs, op, rhs) = expr; - let lhs = lhs.mutate(f); - let rhs = rhs.mutate(f); + let lhs = lhs.modify(f); + let rhs = rhs.modify(f); new_binary_op(lhs, op, rhs) } ) } -fn mutate_block(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_block(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_block().map( |exprs: [Expr]| { - let exprs = mutate_expressions(exprs, f); + let exprs = modify_expressions(exprs, f); new_block(exprs) } ) } -fn mutate_cast(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_cast(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_cast().map( |expr: (Expr, UnresolvedType)| { let (expr, typ) = expr; - let expr = expr.mutate(f); + let expr = expr.modify(f); new_cast(expr, typ) } ) } -fn mutate_comptime(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_comptime(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_comptime().map( |exprs: [Expr]| { - let exprs = exprs.map(|expr: Expr| expr.mutate(f)); + let exprs = exprs.map(|expr: Expr| expr.modify(f)); new_comptime(exprs) } ) } -fn mutate_function_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_function_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_function_call().map( |expr: (Expr, [Expr])| { let (function, arguments) = expr; - let function = function.mutate(f); - let arguments = arguments.map(|arg: Expr| arg.mutate(f)); + let function = function.modify(f); + let arguments = arguments.map(|arg: Expr| arg.modify(f)); new_function_call(function, arguments) } ) } -fn mutate_if(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_if(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_if().map( |expr: (Expr, Expr, Option)| { let (condition, consequence, alternative) = expr; - let condition = condition.mutate(f); - let consequence = consequence.mutate(f); - let alternative = alternative.map(|alternative: Expr| alternative.mutate(f)); + let condition = condition.modify(f); + let consequence = consequence.modify(f); + let alternative = alternative.map(|alternative: Expr| alternative.modify(f)); new_if(condition, consequence, alternative) } ) } -fn mutate_index(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_index(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_index().map( |expr: (Expr, Expr)| { let (object, index) = expr; - let object = object.mutate(f); - let index = index.mutate(f); + let object = object.modify(f); + let index = index.modify(f); new_index(object, index) } ) } -fn mutate_member_access(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_member_access(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_member_access().map( |expr: (Expr, Quoted)| { let (object, name) = expr; - let object = object.mutate(f); + let object = object.modify(f); new_member_access(object, name) } ) } -fn mutate_method_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_method_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_method_call().map( |expr: (Expr, Quoted, [UnresolvedType], [Expr])| { let (object, name, generics, arguments) = expr; - let object = object.mutate(f); - let arguments = arguments.map(|arg: Expr| arg.mutate(f)); + let object = object.modify(f); + let arguments = arguments.map(|arg: Expr| arg.modify(f)); new_method_call(object, name, generics, arguments) } ) } -fn mutate_repeated_element_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_repeated_element_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_repeated_element_array().map( |expr: (Expr, Expr)| { let (expr, length) = expr; - let expr = expr.mutate(f); - let length = length.mutate(f); + let expr = expr.modify(f); + let length = length.modify(f); new_repeated_element_array(expr, length) } ) } -fn mutate_repeated_element_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_repeated_element_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_repeated_element_slice().map( |expr: (Expr, Expr)| { let (expr, length) = expr; - let expr = expr.mutate(f); - let length = length.mutate(f); + let expr = expr.modify(f); + let length = length.modify(f); new_repeated_element_slice(expr, length) } ) } -fn mutate_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_slice().map( |exprs: [Expr]| { - let exprs = mutate_expressions(exprs, f); + let exprs = modify_expressions(exprs, f); new_slice(exprs) } ) } -fn mutate_tuple(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_tuple(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_tuple().map( |exprs: [Expr]| { - let exprs = mutate_expressions(exprs, f); + let exprs = modify_expressions(exprs, f); new_tuple(exprs) } ) } -fn mutate_unary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_unary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_unary_op().map( |expr: (UnaryOp, Expr)| { let (op, rhs) = expr; - let rhs = rhs.mutate(f); + let rhs = rhs.modify(f); new_unary_op(op, rhs) } ) } -fn mutate_unsafe(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { +fn modify_unsafe(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_unsafe().map( |exprs: [Expr]| { - let exprs = exprs.map(|expr: Expr| expr.mutate(f)); + let exprs = exprs.map(|expr: Expr| expr.modify(f)); new_unsafe(exprs) } ) } -fn mutate_expressions(exprs: [Expr], f: fn[Env](Expr) -> Option) -> [Expr] { - exprs.map(|expr: Expr| expr.mutate(f)) +fn modify_expressions(exprs: [Expr], f: fn[Env](Expr) -> Option) -> [Expr] { + exprs.map(|expr: Expr| expr.modify(f)) } fn new_array(exprs: [Expr]) -> Expr { @@ -329,6 +346,15 @@ fn new_array(exprs: [Expr]) -> Expr { quote { [$exprs]}.as_expr().unwrap() } +fn new_assert(predicate: Expr, msg: Option) -> Expr { + if msg.is_some() { + let msg = msg.unwrap(); + quote { assert($predicate, $msg) }.as_expr().unwrap() + } else { + quote { assert($predicate) }.as_expr().unwrap() + } +} + fn new_assign(lhs: Expr, rhs: Expr) -> Expr { quote { $lhs = $rhs }.as_expr().unwrap() } diff --git a/test_programs/compile_success_empty/inject_context_attribute/src/main.nr b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr index 65003ed837b..594b37ce072 100644 --- a/test_programs/compile_success_empty/inject_context_attribute/src/main.nr +++ b/test_programs/compile_success_empty/inject_context_attribute/src/main.nr @@ -31,7 +31,7 @@ fn inject_context(f: FunctionDefinition) { f.set_parameters(parameters); // Create a new body where every function call has `_context` added to the list of arguments. - let body = f.body().mutate(mapping_function); + let body = f.body().modify(mapping_function); f.set_body(body); } diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index abc7a793fd1..e7e156ad1f1 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -20,7 +20,7 @@ mod tests { comptime { let expr = quote { [1, 2, 4] }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let elems = expr.as_array().unwrap(); assert_eq(elems.len(), 3); assert_eq(elems[0].as_integer().unwrap(), (2, false)); @@ -29,6 +29,40 @@ mod tests { } } + #[test] + fn test_expr_as_assert() { + comptime + { + let expr = quote { assert(true) }.as_expr().unwrap(); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_bool().unwrap(), true); + assert(msg.is_none()); + + let expr = quote { assert(false, "oops") }.as_expr().unwrap(); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_bool().unwrap(), false); + assert(msg.is_some()); + } + } + + #[test] + fn test_expr_mutate_for_assert() { + comptime + { + let expr = quote { assert(1) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_integer().unwrap(), (2, false)); + assert(msg.is_none()); + + let expr = quote { assert(1, 2) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (predicate, msg) = expr.as_assert().unwrap(); + assert_eq(predicate.as_integer().unwrap(), (2, false)); + assert_eq(msg.unwrap().as_integer().unwrap(), (4, false)); + } + } + #[test] fn test_expr_as_assign() { comptime @@ -45,7 +79,7 @@ mod tests { comptime { let expr = quote { { a = 1; } }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let exprs = expr.as_block().unwrap(); let (_lhs, rhs) = exprs[0].as_assign().unwrap(); assert_eq(rhs.as_integer().unwrap(), (2, false)); @@ -74,7 +108,7 @@ mod tests { comptime { let expr = quote { { 1; 4; 23 } }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let exprs = expr.as_block().unwrap(); assert_eq(exprs.len(), 3); assert_eq(exprs[0].as_integer().unwrap(), (2, false)); @@ -110,7 +144,7 @@ mod tests { comptime { let expr = quote { foo.bar(3, 4) }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (_object, name, generics, arguments) = expr.as_method_call().unwrap(); @@ -141,7 +175,7 @@ mod tests { comptime { let expr = quote { 1 }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); assert_eq((2, false), expr.as_integer().unwrap()); } @@ -175,7 +209,7 @@ mod tests { comptime { let expr = quote { 3 + 4 }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (lhs, op, rhs) = expr.as_binary_op().unwrap(); assert_eq(lhs.as_integer().unwrap(), (6, false)); @@ -212,7 +246,7 @@ mod tests { comptime { let expr = quote { 1 as Field }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (expr, typ) = expr.as_cast().unwrap(); assert_eq(expr.as_integer().unwrap(), (2, false)); assert(typ.is_field()); @@ -234,7 +268,7 @@ mod tests { comptime { let expr = quote { comptime { 1; 4; 23 } }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let exprs = expr.as_comptime().unwrap(); assert_eq(exprs.len(), 3); assert_eq(exprs[0].as_integer().unwrap(), (2, false)); @@ -274,7 +308,7 @@ mod tests { comptime { let expr = quote { foo(42) }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (_function, args) = expr.as_function_call().unwrap(); assert_eq(args.len(), 1); assert_eq(args[0].as_integer().unwrap(), (84, false)); @@ -300,7 +334,7 @@ mod tests { comptime { let expr = quote { if 1 { 2 } }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (condition, consequence, alternative) = expr.as_if().unwrap(); assert_eq(condition.as_integer().unwrap(), (2, false)); let consequence = consequence.as_block().unwrap()[0].as_block().unwrap()[0]; @@ -308,7 +342,7 @@ mod tests { assert(alternative.is_none()); let expr = quote { if 1 { 2 } else { 3 } }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (condition, consequence, alternative) = expr.as_if().unwrap(); assert_eq(condition.as_integer().unwrap(), (2, false)); let consequence = consequence.as_block().unwrap()[0].as_block().unwrap()[0]; @@ -332,7 +366,7 @@ mod tests { comptime { let expr = quote { 1[2] }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (object, index) = expr.as_index().unwrap(); assert_eq(object.as_integer().unwrap(), (2, false)); assert_eq(index.as_integer().unwrap(), (4, false)); @@ -354,7 +388,7 @@ mod tests { comptime { let expr = quote { 1.bar }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (expr, name) = expr.as_member_access().unwrap(); assert_eq(name, quote { bar }); assert_eq(expr.as_integer().unwrap(), (2, false)); @@ -389,7 +423,7 @@ mod tests { comptime { let expr = quote { [1; 3] }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (expr, length) = expr.as_repeated_element_array().unwrap(); assert_eq(expr.as_integer().unwrap(), (2, false)); assert_eq(length.as_integer().unwrap(), (6, false)); @@ -412,7 +446,7 @@ mod tests { comptime { let expr = quote { &[1; 3] }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (expr, length) = expr.as_repeated_element_slice().unwrap(); assert_eq(expr.as_integer().unwrap(), (2, false)); assert_eq(length.as_integer().unwrap(), (6, false)); @@ -437,7 +471,7 @@ mod tests { comptime { let expr = quote { &[1, 3, 5] }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let elems = expr.as_slice().unwrap(); assert_eq(elems.len(), 3); assert_eq(elems[0].as_integer().unwrap(), (2, false)); @@ -461,7 +495,7 @@ mod tests { comptime { let expr = quote { (1, 2) }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let tuple_exprs = expr.as_tuple().unwrap(); assert_eq(tuple_exprs.len(), 2); assert_eq(tuple_exprs[0].as_integer().unwrap(), (2, false)); @@ -485,7 +519,7 @@ mod tests { comptime { let expr = quote { -(1) }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let (op, expr) = expr.as_unary_op().unwrap(); assert(op.is_minus()); assert_eq(expr.as_integer().unwrap(), (2, false)); @@ -507,7 +541,7 @@ mod tests { comptime { let expr = quote { unsafe { 1; 4; 23 } }.as_expr().unwrap(); - let expr = expr.mutate(times_two); + let expr = expr.modify(times_two); let exprs = expr.as_unsafe().unwrap(); assert_eq(exprs.len(), 3); assert_eq(exprs[0].as_integer().unwrap(), (2, false)); From cfd68d4c1bd1a2319698fca99d200a5d86ffa771 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 29 Aug 2024 12:25:08 -0300 Subject: [PATCH 18/21] feat: show backtrace on comptime assertion failures (#5842) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description ## Problem Resolves #5828 ## Summary `CustomDiagnostic` didn't have a `call_stack` property because so far this wasn't needed. I only added it for `FailingConstraint`, which is, _I think_, the only case where a backtrace might be needed. In all other cases (for example "wrong number of arguments" it might be clear why it's failing without needing a backtrace). Here's an example output: ``` error: Assertion failed ┌─ std/option.nr:33:16 │ 33 │ assert(self._is_some); │ ------------- │ ┌─ src/main.nr:8:9 │ 8 │ foo(); │ ----- │ ┌─ src/other.nr:2:5 │ 2 │ bar(); │ ----- · 7 │ let _ = expr.as_integer().unwrap(); │ -------------------------- │ Aborting due to 1 previous error ``` ## Additional Context None. ## Documentation Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- Cargo.lock | 1 + aztec_macros/Cargo.toml | 1 + compiler/noirc_errors/src/reporter.rs | 18 +++++++---- .../noirc_frontend/src/elaborator/comptime.rs | 1 + compiler/noirc_frontend/src/elaborator/mod.rs | 5 ++++ .../noirc_frontend/src/hir/comptime/errors.rs | 13 ++++++-- .../src/hir/comptime/interpreter.rs | 9 ++++-- .../src/hir/comptime/interpreter/builtin.rs | 30 +++++++++++++------ 8 files changed, 59 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 279d0b59ce1..cd936e4bca2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -451,6 +451,7 @@ version = "0.33.0" dependencies = [ "acvm", "convert_case 0.6.0", + "im", "iter-extended", "noirc_errors", "noirc_frontend", diff --git a/aztec_macros/Cargo.toml b/aztec_macros/Cargo.toml index c9d88e36e28..258379cd7b8 100644 --- a/aztec_macros/Cargo.toml +++ b/aztec_macros/Cargo.toml @@ -18,5 +18,6 @@ noirc_frontend.workspace = true noirc_errors.workspace = true iter-extended.workspace = true convert_case = "0.6.0" +im.workspace = true regex = "1.10" tiny-keccak = { version = "2.0.0", features = ["keccak"] } diff --git a/compiler/noirc_errors/src/reporter.rs b/compiler/noirc_errors/src/reporter.rs index 3ce0f268715..b21dc759f14 100644 --- a/compiler/noirc_errors/src/reporter.rs +++ b/compiler/noirc_errors/src/reporter.rs @@ -46,7 +46,7 @@ impl CustomDiagnostic { ) -> CustomDiagnostic { CustomDiagnostic { message: primary_message, - secondaries: vec![CustomLabel::new(secondary_message, secondary_span)], + secondaries: vec![CustomLabel::new(secondary_message, secondary_span, None)], notes: Vec::new(), kind, } @@ -98,7 +98,7 @@ impl CustomDiagnostic { ) -> CustomDiagnostic { CustomDiagnostic { message: primary_message, - secondaries: vec![CustomLabel::new(secondary_message, secondary_span)], + secondaries: vec![CustomLabel::new(secondary_message, secondary_span, None)], notes: Vec::new(), kind: DiagnosticKind::Bug, } @@ -113,7 +113,11 @@ impl CustomDiagnostic { } pub fn add_secondary(&mut self, message: String, span: Span) { - self.secondaries.push(CustomLabel::new(message, span)); + self.secondaries.push(CustomLabel::new(message, span, None)); + } + + pub fn add_secondary_with_file(&mut self, message: String, span: Span, file: fm::FileId) { + self.secondaries.push(CustomLabel::new(message, span, Some(file))); } pub fn is_error(&self) -> bool { @@ -153,11 +157,12 @@ impl std::fmt::Display for CustomDiagnostic { pub struct CustomLabel { pub message: String, pub span: Span, + pub file: Option, } impl CustomLabel { - fn new(message: String, span: Span) -> CustomLabel { - CustomLabel { message, span } + fn new(message: String, span: Span, file: Option) -> CustomLabel { + CustomLabel { message, span, file } } } @@ -234,7 +239,8 @@ fn convert_diagnostic( .map(|sl| { let start_span = sl.span.start() as usize; let end_span = sl.span.end() as usize; - Label::secondary(file_id, start_span..end_span).with_message(&sl.message) + let file = sl.file.unwrap_or(file_id); + Label::secondary(file, start_span..end_span).with_message(&sl.message) }) .collect() } else { diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs index 01b4585640f..12099b556b7 100644 --- a/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -46,6 +46,7 @@ impl<'context> Elaborator<'context> { self.crate_id, self.debug_comptime_in_file, self.enable_arithmetic_generics, + self.interpreter_call_stack.clone(), ); elaborator.function_context.push(FunctionContext::default()); diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index e8b38193223..5bbd5f00ca8 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -168,6 +168,8 @@ pub struct Elaborator<'context> { /// Temporary flag to enable the experimental arithmetic generics feature enable_arithmetic_generics: bool, + + pub(crate) interpreter_call_stack: im::Vector, } #[derive(Default)] @@ -191,6 +193,7 @@ impl<'context> Elaborator<'context> { crate_id: CrateId, debug_comptime_in_file: Option, enable_arithmetic_generics: bool, + interpreter_call_stack: im::Vector, ) -> Self { Self { scopes: ScopeForest::default(), @@ -214,6 +217,7 @@ impl<'context> Elaborator<'context> { unresolved_globals: BTreeMap::new(), enable_arithmetic_generics, current_trait: None, + interpreter_call_stack, } } @@ -229,6 +233,7 @@ impl<'context> Elaborator<'context> { crate_id, debug_comptime_in_file, enable_arithmetic_generics, + im::Vector::new(), ) } diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index fd916485eaf..cfee6bcedac 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -56,6 +56,7 @@ pub enum InterpreterError { FailingConstraint { message: Option, location: Location, + call_stack: im::Vector, }, NoMethodFound { name: String, @@ -353,12 +354,20 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let msg = format!("Expected a `bool` but found `{typ}`"); CustomDiagnostic::simple_error(msg, String::new(), location.span) } - InterpreterError::FailingConstraint { message, location } => { + InterpreterError::FailingConstraint { message, location, call_stack } => { let (primary, secondary) = match message { Some(msg) => (msg.clone(), "Assertion failed".into()), None => ("Assertion failed".into(), String::new()), }; - CustomDiagnostic::simple_error(primary, secondary, location.span) + let mut diagnostic = + CustomDiagnostic::simple_error(primary, secondary, location.span); + + // Only take at most 3 frames starting from the top of the stack to avoid producing too much output + for frame in call_stack.iter().rev().take(3) { + diagnostic.add_secondary_with_file("".to_string(), frame.span, frame.file); + } + + diagnostic } InterpreterError::NoMethodFound { name, typ, location } => { let msg = format!("No method named `{name}` found for type `{typ}`"); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 33f8c9d8332..4980045c68d 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -70,7 +70,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { current_function: Option, ) -> Self { let bound_generics = Vec::new(); - Self { elaborator, crate_id, current_function, bound_generics, in_loop: false } + let in_loop = false; + Self { elaborator, crate_id, current_function, bound_generics, in_loop } } pub(crate) fn call_function( @@ -99,8 +100,11 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { } self.remember_bindings(&instantiation_bindings, &impl_bindings); + self.elaborator.interpreter_call_stack.push_back(location); + let result = self.call_function_inner(function, arguments, location); + self.elaborator.interpreter_call_stack.pop_back(); undo_instantiation_bindings(impl_bindings); undo_instantiation_bindings(instantiation_bindings); self.rebind_generics_from_previous_function(); @@ -1462,7 +1466,8 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let message = constrain.2.and_then(|expr| self.evaluate(expr).ok()); let message = message.map(|value| value.display(self.elaborator.interner).to_string()); - Err(InterpreterError::FailingConstraint { location, message }) + let call_stack = self.elaborator.interpreter_call_stack.clone(); + Err(InterpreterError::FailingConstraint { location, message, call_stack }) } value => { let location = self.elaborator.interner.expr_location(&constrain.0); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index bc8f473e08d..5ffc58004be 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -48,6 +48,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { location: Location, ) -> IResult { let interner = &mut self.elaborator.interner; + let call_stack = &self.elaborator.interpreter_call_stack; match name { "array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location), "array_len" => array_len(interner, arguments, location), @@ -110,11 +111,11 @@ impl<'local, 'context> Interpreter<'local, 'context> { "quoted_as_type" => quoted_as_type(self, arguments, location), "quoted_eq" => quoted_eq(arguments, location), "slice_insert" => slice_insert(interner, arguments, location), - "slice_pop_back" => slice_pop_back(interner, arguments, location), - "slice_pop_front" => slice_pop_front(interner, arguments, location), + "slice_pop_back" => slice_pop_back(interner, arguments, location, call_stack), + "slice_pop_front" => slice_pop_front(interner, arguments, location, call_stack), "slice_push_back" => slice_push_back(interner, arguments, location), "slice_push_front" => slice_push_front(interner, arguments, location), - "slice_remove" => slice_remove(interner, arguments, location), + "slice_remove" => slice_remove(interner, arguments, location, call_stack), "struct_def_as_type" => struct_def_as_type(interner, arguments, location), "struct_def_fields" => struct_def_fields(interner, arguments, location), "struct_def_generics" => struct_def_generics(interner, arguments, location), @@ -154,8 +155,16 @@ impl<'local, 'context> Interpreter<'local, 'context> { } } -fn failing_constraint(message: impl Into, location: Location) -> IResult { - Err(InterpreterError::FailingConstraint { message: Some(message.into()), location }) +fn failing_constraint( + message: impl Into, + location: Location, + call_stack: &im::Vector, +) -> IResult { + Err(InterpreterError::FailingConstraint { + message: Some(message.into()), + location, + call_stack: call_stack.clone(), + }) } fn array_len( @@ -287,6 +296,7 @@ fn slice_remove( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, + call_stack: &im::Vector, ) -> IResult { let (slice, index) = check_two_arguments(arguments, location)?; @@ -294,7 +304,7 @@ fn slice_remove( let index = get_u32(index)? as usize; if values.is_empty() { - return failing_constraint("slice_remove called on empty slice", location); + return failing_constraint("slice_remove called on empty slice", location, call_stack); } if index >= values.len() { @@ -302,7 +312,7 @@ fn slice_remove( "slice_remove: index {index} is out of bounds for a slice of length {}", values.len() ); - return failing_constraint(message, location); + return failing_constraint(message, location, call_stack); } let element = values.remove(index); @@ -325,13 +335,14 @@ fn slice_pop_front( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, + call_stack: &im::Vector, ) -> IResult { let argument = check_one_argument(arguments, location)?; let (mut values, typ) = get_slice(interner, argument)?; match values.pop_front() { Some(element) => Ok(Value::Tuple(vec![element, Value::Slice(values, typ)])), - None => failing_constraint("slice_pop_front called on empty slice", location), + None => failing_constraint("slice_pop_front called on empty slice", location, call_stack), } } @@ -339,13 +350,14 @@ fn slice_pop_back( interner: &mut NodeInterner, arguments: Vec<(Value, Location)>, location: Location, + call_stack: &im::Vector, ) -> IResult { let argument = check_one_argument(arguments, location)?; let (mut values, typ) = get_slice(interner, argument)?; match values.pop_back() { Some(element) => Ok(Value::Tuple(vec![Value::Slice(values, typ), element])), - None => failing_constraint("slice_pop_back called on empty slice", location), + None => failing_constraint("slice_pop_back called on empty slice", location, call_stack), } } From 46e266a5229dada42ee397beb0d39322451b1458 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Thu, 29 Aug 2024 14:59:44 -0400 Subject: [PATCH 19/21] fix(sha256): Add extra checks against message size when constructing msg blocks (#5861) # Description ## Problem\* Resolves Issue found in zk passport https://github.com/ocelots-app/passport-verifier/blob/47e9464e7e782b07b6d791bf1d13257fce2f486b/crates/lib/data-check/integrity/src/lib.nr#L118 when performing sha on a message with a large padding. ## Summary\* The current sha algorithm accounts for message padding, but only where ithe padding is still contained in the block we are compressing. For the case where we have a padding that extends multiple blocks past the message size we end up with a correctness error. We need to add more checks against the message size to make sure we are comrpessing the correct msg block. An increase in gate count is expected from these changes. ## Additional Context ## Documentation\* Check one: - [X] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [X] I have tested the changes locally. - [X] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: Michael J Klein --- noir_stdlib/src/hash/sha256.nr | 45 +++++++++++-------- .../sha256_regression/Prover.toml | 5 +++ .../sha256_regression/src/main.nr | 15 ++++++- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/noir_stdlib/src/hash/sha256.nr b/noir_stdlib/src/hash/sha256.nr index d0e3d5e88c5..50cf2f965f9 100644 --- a/noir_stdlib/src/hash/sha256.nr +++ b/noir_stdlib/src/hash/sha256.nr @@ -67,18 +67,11 @@ fn verify_msg_block( for k in msg_start..msg_end { if k as u64 < message_size { + assert_eq(msg_block[msg_byte_ptr], msg[k]); msg_byte_ptr = msg_byte_ptr + 1; } } - for i in 0..BLOCK_SIZE { - if i as u64 >= msg_byte_ptr { - assert_eq(msg_block[i], 0); - } else { - assert_eq(msg_block[i], msg[msg_start + i - extra_bytes]); - } - } - msg_byte_ptr } @@ -93,20 +86,29 @@ pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { let mut msg_byte_ptr = 0; // Pointer into msg_block for i in 0..num_blocks { + let msg_start = BLOCK_SIZE * i; let (new_msg_block, new_msg_byte_ptr) = unsafe { - build_msg_block_iter(msg, message_size, BLOCK_SIZE * i) + build_msg_block_iter(msg, message_size, msg_start) }; - msg_block = new_msg_block; + if msg_start as u64 < message_size { + msg_block = new_msg_block; + } if !is_unconstrained() { // Verify the block we are compressing was appropriately constructed - msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i); - } else { + let new_msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, msg_start); + if msg_start as u64 < message_size { + msg_byte_ptr = new_msg_byte_ptr; + } + } else if msg_start as u64 < message_size { msg_byte_ptr = new_msg_byte_ptr; } - // Compress the block - h = sha256_compression(msg_u8_to_u32(msg_block), h); + // If the block is filled, compress it. + // An un-filled block is handled after this loop. + if msg_byte_ptr == 64 { + h = sha256_compression(msg_u8_to_u32(msg_block), h); + } } let modulo = N % BLOCK_SIZE; @@ -114,14 +116,21 @@ pub fn sha256_var(msg: [u8; N], message_size: u64) -> [u8; 32] { // This case is only hit if the msg is less than the block size, // or our message cannot be evenly split into blocks. if modulo != 0 { + let msg_start = BLOCK_SIZE * num_blocks; let (new_msg_block, new_msg_byte_ptr) = unsafe { - build_msg_block_iter(msg, message_size, BLOCK_SIZE * num_blocks) + build_msg_block_iter(msg, message_size, msg_start) }; - msg_block = new_msg_block; + + if msg_start as u64 < message_size { + msg_block = new_msg_block; + } if !is_unconstrained() { - msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * num_blocks); - } else { + let new_msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, msg_start); + if msg_start as u64 < message_size { + msg_byte_ptr = new_msg_byte_ptr; + } + } else if msg_start as u64 < message_size { msg_byte_ptr = new_msg_byte_ptr; } } diff --git a/test_programs/execution_success/sha256_regression/Prover.toml b/test_programs/execution_success/sha256_regression/Prover.toml index ba0aadd1b75..ea0a0f2e8a7 100644 --- a/test_programs/execution_success/sha256_regression/Prover.toml +++ b/test_programs/execution_success/sha256_regression/Prover.toml @@ -2,8 +2,13 @@ msg_just_over_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114 msg_multiple_of_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99] msg_just_under_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59] msg_big_not_block_multiple = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99, 116, 58, 77, 101, 115, 115, 97, 103, 101, 45, 73, 100, 58, 68, 97, 116, 101, 58, 116, 111, 59, 32, 98, 61] +msg_big_with_padding = [48, 130, 1, 37, 2, 1, 0, 48, 11, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 48, 130, 1, 17, 48, 37, 2, 1, 1, 4, 32, 176, 223, 31, 133, 108, 84, 158, 102, 70, 11, 165, 175, 196, 12, 201, 130, 25, 131, 46, 125, 156, 194, 28, 23, 55, 133, 157, 164, 135, 136, 220, 78, 48, 37, 2, 1, 2, 4, 32, 190, 82, 180, 235, 222, 33, 79, 50, 152, 136, 142, 35, 116, 224, 6, 242, 156, 141, 128, 248, 10, 61, 98, 86, 248, 45, 207, 210, 90, 232, 175, 38, 48, 37, 2, 1, 3, 4, 32, 0, 194, 104, 108, 237, 246, 97, 230, 116, 198, 69, 110, 26, 87, 17, 89, 110, 199, 108, 250, 36, 21, 39, 87, 110, 102, 250, 213, 174, 131, 171, 174, 48, 37, 2, 1, 11, 4, 32, 136, 155, 87, 144, 111, 15, 152, 127, 85, 25, 154, 81, 20, 58, 51, 75, 193, 116, 234, 0, 60, 30, 29, 30, 183, 141, 72, 247, 255, 203, 100, 124, 48, 37, 2, 1, 12, 4, 32, 41, 234, 106, 78, 31, 11, 114, 137, 237, 17, 92, 71, 134, 47, 62, 78, 189, 233, 201, 214, 53, 4, 47, 189, 201, 133, 6, 121, 34, 131, 64, 142, 48, 37, 2, 1, 13, 4, 32, 91, 222, 210, 193, 62, 222, 104, 82, 36, 41, 138, 253, 70, 15, 148, 208, 156, 45, 105, 171, 241, 195, 185, 43, 217, 162, 146, 201, 222, 89, 238, 38, 48, 37, 2, 1, 14, 4, 32, 76, 123, 216, 13, 51, 227, 72, 245, 59, 193, 238, 166, 103, 49, 23, 164, 171, 188, 194, 197, 156, 187, 249, 28, 198, 95, 69, 15, 182, 56, 54, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +msg_big_no_padding = [48, 130, 1, 37, 2, 1, 0, 48, 11, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 48, 130, 1, 17, 48, 37, 2, 1, 1, 4, 32, 176, 223, 31, 133, 108, 84, 158, 102, 70, 11, 165, 175, 196, 12, 201, 130, 25, 131, 46, 125, 156, 194, 28, 23, 55, 133, 157, 164, 135, 136, 220, 78, 48, 37, 2, 1, 2, 4, 32, 190, 82, 180, 235, 222, 33, 79, 50, 152, 136, 142, 35, 116, 224, 6, 242, 156, 141, 128, 248, 10, 61, 98, 86, 248, 45, 207, 210, 90, 232, 175, 38, 48, 37, 2, 1, 3, 4, 32, 0, 194, 104, 108, 237, 246, 97, 230, 116, 198, 69, 110, 26, 87, 17, 89, 110, 199, 108, 250, 36, 21, 39, 87, 110, 102, 250, 213, 174, 131, 171, 174, 48, 37, 2, 1, 11, 4, 32, 136, 155, 87, 144, 111, 15, 152, 127, 85, 25, 154, 81, 20, 58, 51, 75, 193, 116, 234, 0, 60, 30, 29, 30, 183, 141, 72, 247, 255, 203, 100, 124, 48, 37, 2, 1, 12, 4, 32, 41, 234, 106, 78, 31, 11, 114, 137, 237, 17, 92, 71, 134, 47, 62, 78, 189, 233, 201, 214, 53, 4, 47, 189, 201, 133, 6, 121, 34, 131, 64, 142, 48, 37, 2, 1, 13, 4, 32, 91, 222, 210, 193, 62, 222, 104, 82, 36, 41, 138, 253, 70, 15, 148, 208, 156, 45, 105, 171, 241, 195, 185, 43, 217, 162, 146, 201, 222, 89, 238, 38, 48, 37, 2, 1, 14, 4, 32, 76, 123, 216, 13, 51, 227, 72, 245, 59, 193, 238, 166, 103, 49, 23, 164, 171, 188, 194, 197, 156, 187, 249, 28, 198, 95, 69, 15, 182, 56, 54, 38] +message_size = 297 + # Results matched against ethers library result_just_over_block = [91, 122, 146, 93, 52, 109, 133, 148, 171, 61, 156, 70, 189, 238, 153, 7, 222, 184, 94, 24, 65, 114, 192, 244, 207, 199, 87, 232, 192, 224, 171, 207] result_multiple_of_block = [116, 90, 151, 31, 78, 22, 138, 180, 211, 189, 69, 76, 227, 200, 155, 29, 59, 123, 154, 60, 47, 153, 203, 129, 157, 251, 48, 2, 79, 11, 65, 47] result_just_under_block = [143, 140, 76, 173, 222, 123, 102, 68, 70, 149, 207, 43, 39, 61, 34, 79, 216, 252, 213, 165, 74, 16, 110, 74, 29, 64, 138, 167, 30, 1, 9, 119] result_big = [112, 144, 73, 182, 208, 98, 9, 238, 54, 229, 61, 145, 222, 17, 72, 62, 148, 222, 186, 55, 192, 82, 220, 35, 66, 47, 193, 200, 22, 38, 26, 186] +result_big_with_padding = [32, 85, 108, 174, 127, 112, 178, 182, 8, 43, 134, 123, 192, 211, 131, 66, 184, 240, 212, 181, 240, 180, 106, 195, 24, 117, 54, 129, 19, 10, 250, 53] \ No newline at end of file diff --git a/test_programs/execution_success/sha256_regression/src/main.nr b/test_programs/execution_success/sha256_regression/src/main.nr index 855931b4300..83049640ac4 100644 --- a/test_programs/execution_success/sha256_regression/src/main.nr +++ b/test_programs/execution_success/sha256_regression/src/main.nr @@ -1,3 +1,4 @@ +// A bunch of different test cases for sha256_var in the stdlib fn main( msg_just_over_block: [u8; 68], result_just_over_block: pub [u8; 32], @@ -7,7 +8,13 @@ fn main( msg_just_under_block: [u8; 60], result_just_under_block: pub [u8; 32], msg_big_not_block_multiple: [u8; 472], - result_big: pub [u8; 32] + result_big: pub [u8; 32], + // This message is only 297 elements and we want to hash only a variable amount + msg_big_with_padding: [u8; 700], + // This is the same as `msg_big_with_padding` but with no padding + msg_big_no_padding: [u8; 297], + message_size: u64, + result_big_with_padding: pub [u8; 32] ) { let hash = std::hash::sha256_var(msg_just_over_block, msg_just_over_block.len() as u64); assert_eq(hash, result_just_over_block); @@ -23,4 +30,10 @@ fn main( msg_big_not_block_multiple.len() as u64 ); assert_eq(hash, result_big); + + let hash_padding = std::hash::sha256_var(msg_big_with_padding, message_size); + assert_eq(hash_padding, result_big_with_padding); + + let hash_no_padding = std::hash::sha256_var(msg_big_no_padding, message_size); + assert_eq(hash_no_padding, result_big_with_padding); } From 663e00cffcb2cd66ddc2b33c0453afca0e15f703 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 29 Aug 2024 16:09:40 -0300 Subject: [PATCH 20/21] feat: LSP signature help for assert and assert_eq (#5862) # Description ## Problem This is something minor, but it's good for completeness and also for users to easily learn that there's an optional failure message you can use in these built-in functions. ## Summary ![lsp-signature-help-assert](https://github.com/user-attachments/assets/77d621fb-360f-47df-92da-f51af00298db) ![lsp-signature-help-assert_eq](https://github.com/user-attachments/assets/6d23a051-9440-4a86-b970-e060b60ab5a1) ## Additional Context I don't know if the syntax is clear for showing that the last argument is optional, so if you have another syntax in mind let me know! ## Documentation Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../lsp/src/requests/completion/builtins.rs | 21 ++- tooling/lsp/src/requests/completion/tests.rs | 8 +- tooling/lsp/src/requests/signature_help.rs | 137 ++++++++++++++++-- .../lsp/src/requests/signature_help/tests.rs | 47 ++++++ .../src/requests/signature_help/traversal.rs | 18 +-- 5 files changed, 194 insertions(+), 37 deletions(-) diff --git a/tooling/lsp/src/requests/completion/builtins.rs b/tooling/lsp/src/requests/completion/builtins.rs index 430e04aedfd..6ccf3ae8119 100644 --- a/tooling/lsp/src/requests/completion/builtins.rs +++ b/tooling/lsp/src/requests/completion/builtins.rs @@ -3,7 +3,10 @@ use noirc_frontend::token::Keyword; use strum::IntoEnumIterator; use super::{ - completion_items::{simple_completion_item, snippet_completion_item}, + completion_items::{ + completion_item_with_trigger_parameter_hints_command, simple_completion_item, + snippet_completion_item, + }, kinds::FunctionCompletionKind, name_matches, NodeFinder, }; @@ -31,12 +34,16 @@ impl<'a> NodeFinder<'a> { } } - self.completion_items.push(snippet_completion_item( - label, - CompletionItemKind::FUNCTION, - insert_text, - description, - )); + self.completion_items.push( + completion_item_with_trigger_parameter_hints_command( + snippet_completion_item( + label, + CompletionItemKind::FUNCTION, + insert_text, + description, + ), + ), + ); } } } diff --git a/tooling/lsp/src/requests/completion/tests.rs b/tooling/lsp/src/requests/completion/tests.rs index 59e007c5a70..2d667ead6bf 100644 --- a/tooling/lsp/src/requests/completion/tests.rs +++ b/tooling/lsp/src/requests/completion/tests.rs @@ -470,19 +470,19 @@ mod completion_tests { assert_completion_excluding_auto_import( src, vec![ - snippet_completion_item( + completion_item_with_trigger_parameter_hints_command(snippet_completion_item( "assert(…)", CompletionItemKind::FUNCTION, "assert(${1:predicate})", Some("fn(T)".to_string()), - ), + )), function_completion_item("assert_constant(…)", "assert_constant(${1:x})", "fn(T)"), - snippet_completion_item( + completion_item_with_trigger_parameter_hints_command(snippet_completion_item( "assert_eq(…)", CompletionItemKind::FUNCTION, "assert_eq(${1:lhs}, ${2:rhs})", Some("fn(T, T)".to_string()), - ), + )), ], ) .await; diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index 8aa74fe9900..25676b57381 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -7,7 +7,10 @@ use lsp_types::{ }; use noirc_errors::{Location, Span}; use noirc_frontend::{ - ast::{CallExpression, Expression, FunctionReturnType, MethodCallExpression}, + ast::{ + CallExpression, ConstrainKind, ConstrainStatement, Expression, ExpressionKind, + FunctionReturnType, MethodCallExpression, + }, hir_def::{function::FuncMeta, stmt::HirPattern}, macros_api::NodeInterner, node_interner::ReferenceId, @@ -104,6 +107,55 @@ impl<'a> SignatureFinder<'a> { ); } + pub(super) fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { + self.find_in_expression(&constrain_statement.0); + + if let Some(exp) = &constrain_statement.1 { + self.find_in_expression(exp); + } + + if self.signature_help.is_some() { + return; + } + + let arguments_span = if let Some(expr) = &constrain_statement.1 { + Span::from(constrain_statement.0.span.start()..expr.span.end()) + } else { + constrain_statement.0.span + }; + + if !self.includes_span(arguments_span) { + return; + } + + match constrain_statement.2 { + ConstrainKind::Assert => { + let mut arguments = vec![constrain_statement.0.clone()]; + if let Some(expr) = &constrain_statement.1 { + arguments.push(expr.clone()); + } + + let active_parameter = self.compute_active_parameter(&arguments); + let signature_information = self.assert_signature_information(active_parameter); + self.set_signature_help(signature_information); + } + ConstrainKind::AssertEq => { + if let ExpressionKind::Infix(infix) = &constrain_statement.0.kind { + let mut arguments = vec![infix.lhs.clone(), infix.rhs.clone()]; + if let Some(expr) = &constrain_statement.1 { + arguments.push(expr.clone()); + } + + let active_parameter = self.compute_active_parameter(&arguments); + let signature_information = + self.assert_eq_signature_information(active_parameter); + self.set_signature_help(signature_information); + } + } + ConstrainKind::Constrain => (), + } + } + fn try_compute_signature_help( &mut self, arguments: &[Expression], @@ -119,18 +171,7 @@ impl<'a> SignatureFinder<'a> { return; } - let mut active_parameter = None; - for (index, arg) in arguments.iter().enumerate() { - if self.includes_span(arg.span) || arg.span.start() as usize >= self.byte_index { - active_parameter = Some(index as u32); - break; - } - } - - if active_parameter.is_none() { - active_parameter = Some(arguments.len() as u32); - } - + let active_parameter = self.compute_active_parameter(arguments); let location = Location::new(name_span, self.file); // Check if the call references a named function @@ -267,6 +308,60 @@ impl<'a> SignatureFinder<'a> { } } + fn assert_signature_information(&self, active_parameter: Option) -> SignatureInformation { + self.hardcoded_signature_information( + active_parameter, + "assert", + &["predicate: bool", "[failure_message: str]"], + ) + } + + fn assert_eq_signature_information( + &self, + active_parameter: Option, + ) -> SignatureInformation { + self.hardcoded_signature_information( + active_parameter, + "assert_eq", + &["lhs: T", "rhs: T", "[failure_message: str]"], + ) + } + + fn hardcoded_signature_information( + &self, + active_parameter: Option, + name: &str, + arguments: &[&str], + ) -> SignatureInformation { + let mut label = String::new(); + let mut parameters = Vec::new(); + + label.push_str(name); + label.push('('); + for (index, typ) in arguments.iter().enumerate() { + if index > 0 { + label.push_str(", "); + } + + let parameter_start = label.chars().count(); + label.push_str(typ); + let parameter_end = label.chars().count(); + + parameters.push(ParameterInformation { + label: ParameterLabel::LabelOffsets([parameter_start as u32, parameter_end as u32]), + documentation: None, + }); + } + label.push(')'); + + SignatureInformation { + label, + documentation: None, + parameters: Some(parameters), + active_parameter, + } + } + fn hir_pattern_to_argument(&self, pattern: &HirPattern, text: &mut String) { match pattern { HirPattern::Identifier(hir_ident) => { @@ -286,6 +381,22 @@ impl<'a> SignatureFinder<'a> { self.signature_help = Some(signature_help); } + fn compute_active_parameter(&self, arguments: &[Expression]) -> Option { + let mut active_parameter = None; + for (index, arg) in arguments.iter().enumerate() { + if self.includes_span(arg.span) || arg.span.start() as usize >= self.byte_index { + active_parameter = Some(index as u32); + break; + } + } + + if active_parameter.is_none() { + active_parameter = Some(arguments.len() as u32); + } + + active_parameter + } + fn includes_span(&self, span: Span) -> bool { span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize } diff --git a/tooling/lsp/src/requests/signature_help/tests.rs b/tooling/lsp/src/requests/signature_help/tests.rs index c48ee159084..4b3f3c38156 100644 --- a/tooling/lsp/src/requests/signature_help/tests.rs +++ b/tooling/lsp/src/requests/signature_help/tests.rs @@ -193,4 +193,51 @@ mod signature_help_tests { assert_eq!(signature.active_parameter, Some(0)); } + + #[test] + async fn test_signature_help_for_assert() { + let src = r#" + fn bar() { + assert(>|<1, "hello"); + } + "#; + + let signature_help = get_signature_help(src).await; + assert_eq!(signature_help.signatures.len(), 1); + + let signature = &signature_help.signatures[0]; + assert_eq!(signature.label, "assert(predicate: bool, [failure_message: str])"); + + let params = signature.parameters.as_ref().unwrap(); + assert_eq!(params.len(), 2); + + check_label(&signature.label, ¶ms[0].label, "predicate: bool"); + check_label(&signature.label, ¶ms[1].label, "[failure_message: str]"); + + assert_eq!(signature.active_parameter, Some(0)); + } + + #[test] + async fn test_signature_help_for_assert_eq() { + let src = r#" + fn bar() { + assert_eq(>|])"); + + let params = signature.parameters.as_ref().unwrap(); + assert_eq!(params.len(), 3); + + check_label(&signature.label, ¶ms[0].label, "lhs: T"); + check_label(&signature.label, ¶ms[1].label, "rhs: T"); + check_label(&signature.label, ¶ms[2].label, "[failure_message: str]"); + + assert_eq!(signature.active_parameter, Some(0)); + } } diff --git a/tooling/lsp/src/requests/signature_help/traversal.rs b/tooling/lsp/src/requests/signature_help/traversal.rs index 6a31a22d63a..e9b050fc965 100644 --- a/tooling/lsp/src/requests/signature_help/traversal.rs +++ b/tooling/lsp/src/requests/signature_help/traversal.rs @@ -4,11 +4,11 @@ use super::SignatureFinder; use noirc_frontend::{ ast::{ - ArrayLiteral, AssignStatement, BlockExpression, CastExpression, ConstrainStatement, - ConstructorExpression, Expression, ExpressionKind, ForLoopStatement, ForRange, - IfExpression, IndexExpression, InfixExpression, LValue, Lambda, LetStatement, Literal, - MemberAccessExpression, NoirFunction, NoirTrait, NoirTraitImpl, Statement, StatementKind, - TraitImplItem, TraitItem, TypeImpl, + ArrayLiteral, AssignStatement, BlockExpression, CastExpression, ConstructorExpression, + Expression, ExpressionKind, ForLoopStatement, ForRange, IfExpression, IndexExpression, + InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, + NoirFunction, NoirTrait, NoirTraitImpl, Statement, StatementKind, TraitImplItem, TraitItem, + TypeImpl, }, parser::{Item, ItemKind}, ParsedModule, @@ -136,14 +136,6 @@ impl<'a> SignatureFinder<'a> { self.find_in_expression(&let_statement.expression); } - pub(super) fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { - self.find_in_expression(&constrain_statement.0); - - if let Some(exp) = &constrain_statement.1 { - self.find_in_expression(exp); - } - } - pub(super) fn find_in_assign_statement(&mut self, assign_statement: &AssignStatement) { self.find_in_lvalue(&assign_statement.lvalue); self.find_in_expression(&assign_statement.expression); From bceee55cc3833978d120e194820cfae9132c8006 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 29 Aug 2024 20:02:44 -0300 Subject: [PATCH 21/21] feat: add `Expr::resolve` and `TypedExpr::as_function_definition` (#5859) # Description ## Problem Part of #5668 ## Summary ## Additional Context ## Documentation Check one: - [ ] No documentation needed. - [x] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: jfecher --- .../src/elaborator/statements.rs | 2 +- .../src/hir/comptime/interpreter/builtin.rs | 80 +++++++++++++++++-- .../interpreter/builtin/builtin_helpers.rs | 9 ++- .../noirc_frontend/src/hir/comptime/value.rs | 18 ++++- compiler/noirc_frontend/src/hir_def/types.rs | 2 + compiler/noirc_frontend/src/lexer/token.rs | 3 + .../noirc_frontend/src/parser/parser/types.rs | 7 ++ docs/docs/noir/concepts/comptime.md | 1 + docs/docs/noir/standard_library/meta/expr.md | 8 +- .../noir/standard_library/meta/typed_expr.md | 13 +++ noir_stdlib/src/meta/expr.nr | 5 ++ noir_stdlib/src/meta/mod.nr | 1 + noir_stdlib/src/meta/typed_expr.nr | 8 ++ .../comptime_expr/src/main.nr | 11 +++ .../lsp/src/requests/completion/builtins.rs | 2 + 15 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 docs/docs/noir/standard_library/meta/typed_expr.md create mode 100644 noir_stdlib/src/meta/typed_expr.nr diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index dcbdf89391e..d7d330f891a 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -48,7 +48,7 @@ impl<'context> Elaborator<'context> { } } - pub(super) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { + pub(crate) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { let span = statement.span; let (hir_statement, typ) = self.elaborate_statement_value(statement); let id = self.interner.push_stmt(hir_statement); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 5ffc58004be..0dafe408c7f 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -8,9 +8,9 @@ use builtin_helpers::{ block_expression_to_value, check_argument_count, check_function_not_yet_resolved, check_one_argument, check_three_arguments, check_two_arguments, get_expr, get_field, get_function_def, get_module, get_quoted, get_slice, get_struct, get_trait_constraint, - get_trait_def, get_trait_impl, get_tuple, get_type, get_u32, get_unresolved_type, - hir_pattern_to_tokens, mutate_func_meta_type, parse, replace_func_meta_parameters, - replace_func_meta_return_type, + get_trait_def, get_trait_impl, get_tuple, get_type, get_typed_expr, get_u32, + get_unresolved_type, hir_pattern_to_tokens, mutate_func_meta_type, parse, + replace_func_meta_parameters, replace_func_meta_return_type, }; use chumsky::{prelude::choice, Parser}; use im::Vector; @@ -25,7 +25,11 @@ use crate::{ FunctionReturnType, IntegerBitSize, LValue, Literal, Statement, StatementKind, UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, }, - hir::comptime::{errors::IResult, value::ExprValue, InterpreterError, Value}, + hir::comptime::{ + errors::IResult, + value::{ExprValue, TypedExpr}, + InterpreterError, Value, + }, hir_def::function::FunctionBody, macros_api::{HirExpression, HirLiteral, ModuleDefId, NodeInterner, Signedness}, node_interner::{DefinitionKind, TraitImplKind}, @@ -87,6 +91,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "expr_has_semicolon" => expr_has_semicolon(interner, arguments, location), "expr_is_break" => expr_is_break(interner, arguments, location), "expr_is_continue" => expr_is_continue(interner, arguments, location), + "expr_resolve" => expr_resolve(self, arguments, location), "is_unconstrained" => Ok(Value::Bool(true)), "function_def_body" => function_def_body(interner, arguments, location), "function_def_name" => function_def_name(interner, arguments, location), @@ -145,6 +150,9 @@ impl<'local, 'context> Interpreter<'local, 'context> { "type_is_bool" => type_is_bool(arguments, location), "type_is_field" => type_is_field(arguments, location), "type_of" => type_of(arguments, location), + "typed_expr_as_function_definition" => { + typed_expr_as_function_definition(interner, arguments, return_type, location) + } "unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location), "zeroed" => zeroed(return_type), _ => { @@ -763,6 +771,23 @@ fn trait_impl_trait_generic_args( Ok(Value::Slice(trait_generics, slice_type)) } +fn typed_expr_as_function_definition( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let typed_expr = get_typed_expr(self_argument)?; + let option_value = if let TypedExpr::ExprId(expr_id) = typed_expr { + let func_id = interner.lookup_function_from_expr(&expr_id); + func_id.map(Value::FunctionDefinition) + } else { + None + }; + option(return_type, option_value) +} + // fn is_field(self) -> bool fn unresolved_type_is_field( interner: &NodeInterner, @@ -1380,7 +1405,48 @@ where F: FnOnce(ExprValue) -> Option, { let self_argument = check_one_argument(arguments, location)?; - let mut expr_value = get_expr(interner, self_argument)?; + let expr_value = get_expr(interner, self_argument)?; + let expr_value = unwrap_expr_value(interner, expr_value); + + let option_value = f(expr_value); + option(return_type, option_value) +} + +// fn resolve(self) -> TypedExpr +fn expr_resolve( + interpreter: &mut Interpreter, + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let self_argument_location = self_argument.1; + let expr_value = get_expr(interpreter.elaborator.interner, self_argument)?; + let expr_value = unwrap_expr_value(interpreter.elaborator.interner, expr_value); + + let value = + interpreter.elaborate_item(interpreter.current_function, |elaborator| match expr_value { + ExprValue::Expression(expression_kind) => { + let expr = Expression { kind: expression_kind, span: self_argument_location.span }; + let (expr_id, _) = elaborator.elaborate_expression(expr); + Value::TypedExpr(TypedExpr::ExprId(expr_id)) + } + ExprValue::Statement(statement_kind) => { + let statement = + Statement { kind: statement_kind, span: self_argument_location.span }; + let (stmt_id, _) = elaborator.elaborate_statement(statement); + Value::TypedExpr(TypedExpr::StmtId(stmt_id)) + } + ExprValue::LValue(lvalue) => { + let expr = lvalue.as_expression(); + let (expr_id, _) = elaborator.elaborate_expression(expr); + Value::TypedExpr(TypedExpr::ExprId(expr_id)) + } + }); + + Ok(value) +} + +fn unwrap_expr_value(interner: &NodeInterner, mut expr_value: ExprValue) -> ExprValue { loop { match expr_value { ExprValue::Expression(ExpressionKind::Parenthesized(expression)) => { @@ -1402,9 +1468,7 @@ where _ => break, } } - - let option_value = f(expr_value); - option(return_type, option_value) + expr_value } // fn body(self) -> Expr diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index 2e06240e995..dd9ea51961e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -11,7 +11,7 @@ use crate::{ hir::{ comptime::{ errors::IResult, - value::{add_token_spans, ExprValue}, + value::{add_token_spans, ExprValue, TypedExpr}, Interpreter, InterpreterError, Value, }, def_map::ModuleId, @@ -227,6 +227,13 @@ pub(crate) fn get_type((value, location): (Value, Location)) -> IResult { } } +pub(crate) fn get_typed_expr((value, location): (Value, Location)) -> IResult { + match value { + Value::TypedExpr(typed_expr) => Ok(typed_expr), + value => type_mismatch(value, Type::Quoted(QuotedType::TypedExpr), location), + } +} + pub(crate) fn get_quoted((value, location): (Value, Location)) -> IResult>> { match value { Value::Quoted(tokens) => Ok(tokens), diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 5b4875c8c41..b96c4852931 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -24,7 +24,7 @@ use crate::{ Expression, ExpressionKind, HirExpression, HirLiteral, Literal, NodeInterner, Path, StructId, }, - node_interner::{ExprId, FuncId, TraitId, TraitImplId}, + node_interner::{ExprId, FuncId, StmtId, TraitId, TraitImplId}, parser::{self, NoirParser, TopLevelStatement}, token::{SpannedToken, Token, Tokens}, QuotedType, Shared, Type, TypeBindings, @@ -69,6 +69,7 @@ pub enum Value { Type(Type), Zeroed(Type), Expr(ExprValue), + TypedExpr(TypedExpr), UnresolvedType(UnresolvedTypeData), } @@ -79,6 +80,12 @@ pub enum ExprValue { LValue(LValue), } +#[derive(Debug, Clone, PartialEq, Eq, Display)] +pub enum TypedExpr { + ExprId(ExprId), + StmtId(StmtId), +} + impl Value { pub(crate) fn expression(expr: ExpressionKind) -> Self { Value::Expr(ExprValue::Expression(expr)) @@ -137,6 +144,7 @@ impl Value { Value::Type(_) => Type::Quoted(QuotedType::Type), Value::Zeroed(typ) => return Cow::Borrowed(typ), Value::Expr(_) => Type::Quoted(QuotedType::Expr), + Value::TypedExpr(_) => Type::Quoted(QuotedType::TypedExpr), Value::UnresolvedType(_) => Type::Quoted(QuotedType::UnresolvedType), }) } @@ -264,7 +272,8 @@ impl Value { statements: vec![Statement { kind: statement, span: location.span }], }) } - Value::Expr(ExprValue::LValue(_)) + Value::Expr(ExprValue::LValue(lvalue)) => lvalue.as_expression().kind, + Value::TypedExpr(..) | Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(..) @@ -389,7 +398,9 @@ impl Value { HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements))) } Value::Quoted(tokens) => HirExpression::Unquote(add_token_spans(tokens, location.span)), - Value::Expr(..) + Value::TypedExpr(TypedExpr::ExprId(expr_id)) => interner.expression(&expr_id), + Value::TypedExpr(TypedExpr::StmtId(..)) + | Value::Expr(..) | Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(..) @@ -621,6 +632,7 @@ impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { Value::Expr(ExprValue::LValue(lvalue)) => { write!(f, "{}", remove_interned_in_lvalue(self.interner, lvalue.clone())) } + Value::TypedExpr(_) => write!(f, "(typed expr)"), Value::UnresolvedType(typ) => { if let UnresolvedTypeData::Interned(id) = typ { let typ = self.interner.get_unresolved_type_data(*id); diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index c59c86b9616..638003d3fcd 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -153,6 +153,7 @@ pub enum QuotedType { Quoted, TopLevelItem, Type, + TypedExpr, StructDefinition, TraitConstraint, TraitDefinition, @@ -741,6 +742,7 @@ impl std::fmt::Display for QuotedType { QuotedType::Quoted => write!(f, "Quoted"), QuotedType::TopLevelItem => write!(f, "TopLevelItem"), QuotedType::Type => write!(f, "Type"), + QuotedType::TypedExpr => write!(f, "TypedExpr"), QuotedType::StructDefinition => write!(f, "StructDefinition"), QuotedType::TraitDefinition => write!(f, "TraitDefinition"), QuotedType::TraitConstraint => write!(f, "TraitConstraint"), diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index b3b6d25480f..c9bd465b6a6 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -962,6 +962,7 @@ pub enum Keyword { TraitDefinition, TraitImpl, Type, + TypedExpr, TypeType, Unchecked, Unconstrained, @@ -1017,6 +1018,7 @@ impl fmt::Display for Keyword { Keyword::TraitDefinition => write!(f, "TraitDefinition"), Keyword::TraitImpl => write!(f, "TraitImpl"), Keyword::Type => write!(f, "type"), + Keyword::TypedExpr => write!(f, "TypedExpr"), Keyword::TypeType => write!(f, "Type"), Keyword::Unchecked => write!(f, "unchecked"), Keyword::Unconstrained => write!(f, "unconstrained"), @@ -1075,6 +1077,7 @@ impl Keyword { "TraitImpl" => Keyword::TraitImpl, "type" => Keyword::Type, "Type" => Keyword::TypeType, + "TypedExpr" => Keyword::TypedExpr, "StructDefinition" => Keyword::StructDefinition, "unchecked" => Keyword::Unchecked, "unconstrained" => Keyword::Unconstrained, diff --git a/compiler/noirc_frontend/src/parser/parser/types.rs b/compiler/noirc_frontend/src/parser/parser/types.rs index f83303151eb..9dabb8b83b6 100644 --- a/compiler/noirc_frontend/src/parser/parser/types.rs +++ b/compiler/noirc_frontend/src/parser/parser/types.rs @@ -88,6 +88,7 @@ pub(super) fn comptime_type() -> impl NoirParser { type_of_quoted_types(), top_level_item_type(), quoted_type(), + typed_expr_type(), )) } @@ -159,6 +160,12 @@ fn quoted_type() -> impl NoirParser { .map_with_span(|_, span| UnresolvedTypeData::Quoted(QuotedType::Quoted).with_span(span)) } +/// This is the type of a typed/resolved expression. +fn typed_expr_type() -> impl NoirParser { + keyword(Keyword::TypedExpr) + .map_with_span(|_, span| UnresolvedTypeData::Quoted(QuotedType::TypedExpr).with_span(span)) +} + /// This is the type of an already resolved type. /// The only way this can appear in the token input is if an already resolved `Type` object /// was spliced into a macro's token stream via the `$` operator. diff --git a/docs/docs/noir/concepts/comptime.md b/docs/docs/noir/concepts/comptime.md index 2b5c29538b9..ed55a541fbd 100644 --- a/docs/docs/noir/concepts/comptime.md +++ b/docs/docs/noir/concepts/comptime.md @@ -232,6 +232,7 @@ The following is an incomplete list of some `comptime` types along with some use - `fn fields(self) -> [(Quoted, Type)]` - Return the name and type of each field - `TraitConstraint`: A trait constraint such as `From` +- `TypedExpr`: A type-checked expression. - `UnresolvedType`: A syntactic notation that refers to a Noir type that hasn't been resolved yet There are many more functions available by exploring the `std::meta` module and its submodules. diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md index 7d3d1c2f453..8f708b2359e 100644 --- a/docs/docs/noir/standard_library/meta/expr.md +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -184,4 +184,10 @@ for expressions that are integers, doubling them, would return `(&[2], &[4, 6])` #include_code quoted noir_stdlib/src/meta/expr.nr rust -Returns this expression as a `Quoted` value. It's the same as `quote { $self }`. \ No newline at end of file +Returns this expression as a `Quoted` value. It's the same as `quote { $self }`. + +### resolve + +#include_code resolve noir_stdlib/src/meta/expr.nr rust + +Resolves and type-checks this expression and returns the result as a `TypedExpr`. If any names used by this expression are not in scope or if there are any type errors, this will give compiler errors as if the expression was written directly into the current `comptime` function. \ No newline at end of file diff --git a/docs/docs/noir/standard_library/meta/typed_expr.md b/docs/docs/noir/standard_library/meta/typed_expr.md new file mode 100644 index 00000000000..eacfd9c1230 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/typed_expr.md @@ -0,0 +1,13 @@ +--- +title: TypedExpr +--- + +`std::meta::typed_expr` contains methods on the built-in `TypedExpr` type for resolved and type-checked expressions. + +## Methods + +### as_function_definition + +#include_code as_function_definition noir_stdlib/src/meta/typed_expr.nr rust + +If this expression refers to a function definitions, returns it. Otherwise returns `Option::none()`. \ No newline at end of file diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index 838c96570b5..5c677f39a9e 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -151,6 +151,11 @@ impl Expr { // docs:end:quoted quote { $self } } + + #[builtin(expr_resolve)] + // docs:start:resolve + fn resolve(self) -> TypedExpr {} + // docs:end:resolve } fn modify_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index be1b12540c9..24398054467 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -7,6 +7,7 @@ mod trait_constraint; mod trait_def; mod trait_impl; mod typ; +mod typed_expr; mod quoted; mod unresolved_type; diff --git a/noir_stdlib/src/meta/typed_expr.nr b/noir_stdlib/src/meta/typed_expr.nr new file mode 100644 index 00000000000..8daede97438 --- /dev/null +++ b/noir_stdlib/src/meta/typed_expr.nr @@ -0,0 +1,8 @@ +use crate::option::Option; + +impl TypedExpr { + #[builtin(typed_expr_as_function_definition)] + // docs:start:as_function_definition + fn as_function_definition(self) -> Option {} + // docs:end:as_function_definition +} diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index e7e156ad1f1..7248d51ca9a 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -577,6 +577,17 @@ mod tests { } } + #[test] + fn test_resolve_to_function_definition() { + comptime + { + let expr = quote { times_two }.as_expr().unwrap(); + let func = expr.resolve().as_function_definition().unwrap(); + assert_eq(func.name(), quote { times_two }); + assert_eq(func.parameters().len(), 1); + } + } + comptime fn get_unary_op(quoted: Quoted) -> UnaryOp { let expr = quoted.as_expr().unwrap(); let (op, _) = expr.as_unary_op().unwrap(); diff --git a/tooling/lsp/src/requests/completion/builtins.rs b/tooling/lsp/src/requests/completion/builtins.rs index 6ccf3ae8119..54340075b15 100644 --- a/tooling/lsp/src/requests/completion/builtins.rs +++ b/tooling/lsp/src/requests/completion/builtins.rs @@ -102,6 +102,7 @@ pub(super) fn keyword_builtin_type(keyword: &Keyword) -> Option<&'static str> { Keyword::TraitConstraint => Some("TraitConstraint"), Keyword::TraitDefinition => Some("TraitDefinition"), Keyword::TraitImpl => Some("TraitImpl"), + Keyword::TypedExpr => Some("TypedExpr"), Keyword::TypeType => Some("Type"), Keyword::UnresolvedType => Some("UnresolvedType"), @@ -207,6 +208,7 @@ pub(super) fn keyword_builtin_function(keyword: &Keyword) -> Option