diff --git a/noir-projects/noir-protocol-circuits/crates/types/src/merkle_tree/variable_merkle_tree.nr b/noir-projects/noir-protocol-circuits/crates/types/src/merkle_tree/variable_merkle_tree.nr index 22f13811692..044f131beff 100644 --- a/noir-projects/noir-protocol-circuits/crates/types/src/merkle_tree/variable_merkle_tree.nr +++ b/noir-projects/noir-protocol-circuits/crates/types/src/merkle_tree/variable_merkle_tree.nr @@ -1,4 +1,5 @@ use crate::hash::accumulate_sha256; + // N = maximum leaves // For now we only care about the root pub struct VariableMerkleTree { @@ -50,17 +51,15 @@ fn get_prev_power_2(value: u32) -> u32 { let next_power_2 = 2 << next_power_exponent; let prev_power_2 = next_power_2 / 2; - assert((value == 0) | (value == 1) | (value > prev_power_2)); + assert(prev_power_2 < value); assert(value <= next_power_2); prev_power_2 } -// This calculates the root of the minimal size merkle tree required -// to store num_non_empty_leaves -// Since we cannot isolate branches, it doesn't cost fewer gates than using -// MerkleTree on the full array of elements N, but is slightly cheaper on-chain -// and cleaner elsewhere. +// Calculates the root of the minimal size merkle tree required to store num_non_empty_leaves. +// Since we cannot isolate branches, it doesn't cost fewer gates than using MerkleTree on the full array of elements N, +// but is slightly cheaper on-chain and cleaner elsewhere. impl VariableMerkleTree { // Example - tx_0 with 3 msgs | tx_1 with 2 msgs creates: // @@ -72,32 +71,34 @@ impl VariableMerkleTree { // | tx_0 | | tx_1 | // pub fn new_sha(leaves: [Field; N], num_non_empty_leaves: u32) -> Self { - let prev_power_2 = get_prev_power_2(num_non_empty_leaves); - - // hash base layer - // If we have no num_non_empty_leaves, we return 0 - let mut stop = num_non_empty_leaves == 0; - + let num_nodes_layer_1 = if (num_non_empty_leaves == 0) { + // For 0 leaves, there is no layer 1, no hashing happens and the root is set to 0 + 0 + } else if (num_non_empty_leaves == 1) { + // For 1 leaf, 1 round of hashing happens and root = hash([leaf, 0]) + 1 + } else { + // For more than 1 leaf, we dynamically compute num of nodes in layer 1 by finding the previous power of 2 + get_prev_power_2(num_non_empty_leaves) + }; + + // We hash the base layer let mut nodes = [0; N]; for i in 0..N / 2 { - // stop after non zero leaves - if i == prev_power_2 { - stop = true; - } - if (!stop) { + if (i < num_nodes_layer_1) { nodes[i] = accumulate_sha256([leaves[2 * i], leaves[2 * i + 1]]); } } - // hash the other layers - stop = prev_power_2 == 1; + // We hash the other layers + let mut stop = num_non_empty_leaves < 3; - let mut next_layer_end = prev_power_2 / 2; + let mut next_layer_end = num_nodes_layer_1 / 2; let mut next_layer_size = next_layer_end; let mut root = nodes[0]; for i in 0..(N - 1 - N / 2) { if !stop { - nodes[prev_power_2 + i] = accumulate_sha256([nodes[2 * i], nodes[2 * i + 1]]); + nodes[num_nodes_layer_1 + i] = accumulate_sha256([nodes[2 * i], nodes[2 * i + 1]]); if i == next_layer_end { // Reached next layer => move up one layer next_layer_size = next_layer_size / 2; @@ -105,7 +106,7 @@ impl VariableMerkleTree { } if (next_layer_size == 1) { // Reached root - root = nodes[prev_power_2 + i]; + root = nodes[num_nodes_layer_1 + i]; stop = true; } }