diff --git a/src/hash_builder/mod.rs b/src/hash_builder/mod.rs index 6273aa4..157b1b9 100644 --- a/src/hash_builder/mod.rs +++ b/src/hash_builder/mod.rs @@ -311,7 +311,7 @@ impl HashBuilder { let state_mask = self.groups[len]; let hash_mask = self.hash_masks[len]; let branch_node = BranchNodeRef::new(&self.stack, &state_mask); - let children = branch_node.child_hashes(hash_mask); + let children = branch_node.child_hashes(hash_mask).collect(); self.rlp_buf.clear(); let rlp = branch_node.rlp(&mut self.rlp_buf); diff --git a/src/nodes/branch.rs b/src/nodes/branch.rs index 70cc5d9..25139d3 100644 --- a/src/nodes/branch.rs +++ b/src/nodes/branch.rs @@ -1,7 +1,8 @@ use super::{super::TrieMask, rlp_node, CHILD_INDEX_RANGE}; use alloy_primitives::{hex, B256}; use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE}; -use core::fmt; +use core::{fmt, ops::Range, slice::Iter}; +use nybbles::Nibbles; #[allow(unused_imports)] use alloc::vec::Vec; @@ -153,20 +154,27 @@ impl<'a> BranchNodeRef<'a> { self.stack.len().checked_sub(self.state_mask.count_ones() as usize).unwrap() } - /// Given the hash and state mask of children present, return an iterator over the stack items + /// Given the hash mask of children, return an iterator over stack items /// that match the mask. - pub fn child_hashes(&self, hash_mask: TrieMask) -> Vec { - let mut stack_ptr = self.first_child_index(); - let mut children = Vec::with_capacity(hash_mask.count_ones() as usize); - for index in CHILD_INDEX_RANGE { - if self.state_mask.is_bit_set(index) { - if hash_mask.is_bit_set(index) { - children.push(B256::from_slice(&self.stack[stack_ptr][1..])); - } - stack_ptr += 1; - } - } - children + pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator + '_ { + BranchChildrenIter::new(self) + .filter(move |(index, _)| hash_mask.is_bit_set(*index)) + .map(|(_, child)| B256::from_slice(&child[1..])) + } + + /// Return an iterator over stack items and corresponding indices that match the state mask. + pub fn indexed_children(&self) -> impl Iterator + '_ { + BranchChildrenIter::new(self).map(|(index, child)| (index, B256::from_slice(&child[1..]))) + } + + /// Given the prefix, return an iterator over stack items that match the + /// state mask and their corresponding full paths. + pub fn prefixed_children(&self, prefix: Nibbles) -> impl Iterator + '_ { + self.indexed_children().map(move |(index, hash)| { + let mut path = prefix.clone(); + path.push(index); + (path, hash) + }) } /// Returns the RLP encoding of the branch node given the state mask of children present. @@ -193,6 +201,38 @@ impl<'a> BranchNodeRef<'a> { } } +/// Iterator over branch node children. +#[derive(Debug)] +struct BranchChildrenIter<'a> { + range: Range, + state_mask: &'a TrieMask, + stack_iter: Iter<'a, Vec>, +} + +impl<'a> BranchChildrenIter<'a> { + /// Create new iterator over branch node children. + fn new(node: &BranchNodeRef<'a>) -> Self { + Self { + range: CHILD_INDEX_RANGE, + state_mask: node.state_mask, + stack_iter: node.stack[node.first_child_index()..].iter(), + } + } +} + +impl<'a> Iterator for BranchChildrenIter<'a> { + type Item = (u8, &'a [u8]); + + fn next(&mut self) -> Option { + loop { + let current = self.range.next()?; + if self.state_mask.is_bit_set(current) { + return Some((current, self.stack_iter.next()?)); + } + } + } +} + /// A struct representing a branch node in an Ethereum trie. /// /// A branch node can have up to 16 children, each corresponding to one of the possible nibble