Skip to content

Commit

Permalink
chore: stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes authored and rkrasiuk committed Oct 1, 2024
1 parent 8721810 commit da23a9a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/hash_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ impl HashBuilder {
fn push_branch_node(&mut self, current: &Nibbles, len: usize) -> Vec<B256> {
let state_mask = self.groups[len];
let hash_mask = self.hash_masks[len];
let branch_node = BranchNodeRef::new(&self.stack, &state_mask);
let branch_node = BranchNodeRef::new(&self.stack, state_mask);
// Avoid calculating this value if it's not needed.
let children = if self.updated_branch_nodes.is_some() {
branch_node.child_hashes(hash_mask).collect()
Expand Down
69 changes: 43 additions & 26 deletions src/nodes/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl BranchNode {

/// Return branch node as [BranchNodeRef].
pub fn as_ref(&self) -> BranchNodeRef<'_> {
BranchNodeRef::new(&self.stack, &self.state_mask)
BranchNodeRef::new(&self.stack, self.state_mask)
}
}

Expand All @@ -100,7 +100,7 @@ pub struct BranchNodeRef<'a> {
pub stack: &'a [RlpNode],
/// Reference to bitmask indicating the presence of children at
/// the respective nibble positions.
pub state_mask: &'a TrieMask,
pub state_mask: TrieMask,
}

impl fmt::Debug for BranchNodeRef<'_> {
Expand All @@ -122,12 +122,9 @@ impl Encodable for BranchNodeRef<'_> {
Header { list: true, payload_length: self.rlp_payload_length() }.encode(out);

// Extend the RLP buffer with the present children
let mut stack_ptr = self.first_child_index();
for index in CHILD_INDEX_RANGE {
if self.state_mask.is_bit_set(index) {
out.put_slice(&self.stack[stack_ptr]);
// Advance the pointer to the next child.
stack_ptr += 1;
for (_, child) in self.children() {
if let Some(child) = child {
out.put_slice(child);
} else {
out.put_u8(EMPTY_STRING_CODE);
}
Expand All @@ -146,7 +143,7 @@ impl Encodable for BranchNodeRef<'_> {
impl<'a> BranchNodeRef<'a> {
/// Create a new branch node from the stack of nodes.
#[inline]
pub const fn new(stack: &'a [RlpNode], state_mask: &'a TrieMask) -> Self {
pub const fn new(stack: &'a [RlpNode], state_mask: TrieMask) -> Self {
Self { stack, state_mask }
}

Expand All @@ -157,16 +154,21 @@ impl<'a> BranchNodeRef<'a> {
/// If the stack length is less than number of children specified in state mask.
/// Means that the node is in inconsistent state.
#[inline]
#[cfg_attr(debug_assertions, track_caller)]
pub fn first_child_index(&self) -> usize {
self.stack.len().checked_sub(self.state_mask.count_ones() as usize).unwrap()
}

#[inline]
fn children(&self) -> impl Iterator<Item = (u8, Option<&RlpNode>)> + '_ {
BranchChildrenIter::new(self)
}

/// Given the hash mask of children, return an iterator over stack items
/// that match the mask.
#[inline]
pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator<Item = B256> + '_ {
BranchChildrenIter::new(self)
self.children()
.filter_map(|(i, c)| c.map(|c| (i, c)))
.filter(move |(index, _)| hash_mask.is_bit_set(*index))
.map(|(_, child)| B256::from_slice(&child[1..]))
}
Expand All @@ -182,13 +184,9 @@ impl<'a> BranchNodeRef<'a> {
#[inline]
fn rlp_payload_length(&self) -> usize {
let mut payload_length = 1;
let mut stack = self.stack[self.first_child_index()..].iter();
for digit in CHILD_INDEX_RANGE {
if self.state_mask.is_bit_set(digit) {
// SAFETY: `first_child_index` guarantees that `stack` is exactly
// `state_mask.count_ones()` long.
let stack_item = unsafe { stack.next().unwrap_unchecked() };
payload_length += stack_item.len();
for (_, child) in self.children() {
if let Some(child) = child {
payload_length += child.len();
} else {
payload_length += 1;
}
Expand All @@ -201,7 +199,7 @@ impl<'a> BranchNodeRef<'a> {
#[derive(Debug)]
struct BranchChildrenIter<'a> {
range: Range<u8>,
state_mask: &'a TrieMask,
state_mask: TrieMask,
stack_iter: Iter<'a, RlpNode>,
}

Expand All @@ -217,15 +215,34 @@ impl<'a> BranchChildrenIter<'a> {
}

impl<'a> Iterator for BranchChildrenIter<'a> {
type Item = (u8, &'a [u8]);
type Item = (u8, Option<&'a RlpNode>);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
loop {
let current = self.range.next()?;
if self.state_mask.is_bit_set(current) {
return Some((current, self.stack_iter.next()?));
}
}
let i = self.range.next()?;
let value = if self.state_mask.is_bit_set(i) {
// SAFETY: `first_child_index` guarantees that `stack` is exactly
// `state_mask.count_ones()` long.
Some(unsafe { self.stack_iter.next().unwrap_unchecked() })
} else {
None
};
Some((i, value))
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len();
(len, Some(len))
}
}

impl core::iter::FusedIterator for BranchChildrenIter<'_> {}

impl ExactSizeIterator for BranchChildrenIter<'_> {
#[inline]
fn len(&self) -> usize {
self.range.len()
}
}

Expand Down
31 changes: 28 additions & 3 deletions src/nodes/rlp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ use alloy_rlp::EMPTY_STRING_CODE;
use arrayvec::ArrayVec;
use core::fmt;

const MAX: usize = 33;

/// An RLP-encoded node.
#[derive(Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RlpNode(ArrayVec<u8, 33>);
pub struct RlpNode(ArrayVec<u8, MAX>);

impl alloy_rlp::Decodable for RlpNode {
fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
let bytes = alloy_rlp::Header::decode_bytes(buf, false)?;
Self::from_raw(bytes).ok_or_else(|| alloy_rlp::Error::Custom("RLP node too large"))
Self::from_raw_rlp(bytes)
}
}

Expand Down Expand Up @@ -58,7 +60,7 @@ impl RlpNode {
/// Creates a new RLP-encoded node from the given data.
#[inline]
pub fn from_raw_rlp(data: &[u8]) -> alloy_rlp::Result<Self> {
Self::from_raw(data).ok_or_else(|| alloy_rlp::Error::Custom("RLP node too large"))
Self::from_raw(data).ok_or(alloy_rlp::Error::Custom("RLP node too large"))
}

/// Given an RLP-encoded node, returns it either as `rlp(node)` or `rlp(keccak(rlp(node)))`.
Expand Down Expand Up @@ -88,3 +90,26 @@ impl RlpNode {
&self.0
}
}

#[cfg(feature = "arbitrary")]
impl<'u> arbitrary::Arbitrary<'u> for RlpNode {
fn arbitrary(g: &mut arbitrary::Unstructured<'u>) -> arbitrary::Result<Self> {
let len = g.int_in_range(0..=MAX)?;
let mut arr = ArrayVec::new();
arr.try_extend_from_slice(g.bytes(len)?).unwrap();
Ok(Self(arr))
}
}

#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for RlpNode {
type Parameters = ();
type Strategy = proptest::strategy::BoxedStrategy<Self>;

fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
use proptest::prelude::*;
proptest::collection::vec(proptest::prelude::any::<u8>(), 0..=MAX)
.prop_map(|vec| Self::from_raw(&vec).unwrap())
.boxed()
}
}

0 comments on commit da23a9a

Please sign in to comment.