Skip to content

Commit

Permalink
perf: store RLP-encoded nodes using ArrayVec (#51)
Browse files Browse the repository at this point in the history
* perf: store RLP-encoded nodes using ArrayVec

* chore: stuff
  • Loading branch information
DaniPopes authored Oct 1, 2024
1 parent 9e00d0e commit 296cb8d
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 122 deletions.
10 changes: 9 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ alloy-primitives = { version = "0.8.5", default-features = false, features = [
alloy-rlp = { version = "0.3.8", default-features = false, features = [
"derive",
] }

arrayvec = { version = "0.7", default-features = false }
derive_more = { version = "1", default-features = false, features = [
"add",
"add_assign",
Expand Down Expand Up @@ -62,12 +64,18 @@ default = ["std", "alloy-primitives/default"]
std = [
"alloy-primitives/std",
"alloy-rlp/std",
"arrayvec/std",
"derive_more/std",
"nybbles/std",
"tracing/std",
"serde?/std",
]
serde = ["dep:serde", "alloy-primitives/serde", "nybbles/serde"]
serde = [
"dep:serde",
"alloy-primitives/serde",
"arrayvec/serde",
"nybbles/serde",
]
arbitrary = [
"std",
"dep:arbitrary",
Expand Down
37 changes: 18 additions & 19 deletions src/hash_builder/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! The implementation of the hash builder.
use super::{
nodes::{word_rlp, BranchNodeRef, ExtensionNodeRef, LeafNodeRef},
nodes::{BranchNodeRef, ExtensionNodeRef, LeafNodeRef},
proof::ProofRetainer,
BranchNodeCompact, Nibbles, TrieMask, EMPTY_ROOT_HASH,
};
use crate::{proof::ProofNodes, HashMap};
use crate::{nodes::RlpNode, proof::ProofNodes, HashMap};
use alloy_primitives::{hex, keccak256, B256};
use alloy_rlp::EMPTY_STRING_CODE;
use core::cmp;
Expand Down Expand Up @@ -45,7 +45,7 @@ pub use value::HashBuilderValue;
#[allow(missing_docs)]
pub struct HashBuilder {
pub key: Nibbles,
pub stack: Vec<Vec<u8>>,
pub stack: Vec<RlpNode>,
pub value: HashBuilderValue,

pub groups: Vec<TrieMask>,
Expand Down Expand Up @@ -131,7 +131,7 @@ impl HashBuilder {
if !self.key.is_empty() {
self.update(&key);
} else if key.is_empty() {
self.stack.push(word_rlp(&value));
self.stack.push(RlpNode::word_rlp(&value));
}
self.set_key_value(key, value);
self.stored_in_database = stored_in_database;
Expand Down Expand Up @@ -250,7 +250,7 @@ impl HashBuilder {
}
HashBuilderValue::Hash(hash) => {
trace!(target: "trie::hash_builder", ?hash, "pushing branch node hash");
self.stack.push(word_rlp(hash));
self.stack.push(RlpNode::word_rlp(hash));

if self.stored_in_database {
self.tree_masks[current.len() - 1] |=
Expand All @@ -266,17 +266,17 @@ impl HashBuilder {

if build_extensions && !short_node_key.is_empty() {
self.update_masks(&current, len_from);
let stack_last =
self.stack.pop().expect("there should be at least one stack item; qed");
let stack_last = self.stack.pop().expect("there should be at least one stack item");
let extension_node = ExtensionNodeRef::new(&short_node_key, &stack_last);
trace!(target: "trie::hash_builder", ?extension_node, "pushing extension node");
trace!(target: "trie::hash_builder", rlp = {
self.rlp_buf.clear();
hex::encode(extension_node.rlp(&mut self.rlp_buf))
}, "extension node rlp");

self.rlp_buf.clear();
self.stack.push(extension_node.rlp(&mut self.rlp_buf));
let rlp = extension_node.rlp(&mut self.rlp_buf);
trace!(target: "trie::hash_builder",
?extension_node,
?rlp,
"pushing extension node",
);
self.stack.push(rlp);
self.retain_proof_from_buf(&current.slice(..len_from));
self.resize_masks(len_from);
}
Expand Down Expand Up @@ -325,7 +325,7 @@ impl HashBuilder {
fn push_branch_node(&mut self, current: &Nibbles, len: usize) -> Vec<B256> {
let state_mask = self.groups[len];
let hash_mask = self.hash_masks[len];
let branch_node = BranchNodeRef::new(&self.stack, &state_mask);
let branch_node = BranchNodeRef::new(&self.stack, state_mask);
// Avoid calculating this value if it's not needed.
let children = if self.updated_branch_nodes.is_some() {
branch_node.child_hashes(hash_mask).collect()
Expand All @@ -345,10 +345,9 @@ impl HashBuilder {
old_len = self.stack.len(),
"resizing stack to prepare branch node"
);
self.stack.resize(first_child_idx, vec![]);
self.stack.resize_with(first_child_idx, Default::default);

trace!(target: "trie::hash_builder", "pushing branch node with {:?} mask from stack", state_mask);
trace!(target: "trie::hash_builder", rlp = hex::encode(&rlp), "branch node rlp");
trace!(target: "trie::hash_builder", ?rlp, "pushing branch node with {state_mask:?} mask from stack");
self.stack.push(rlp);
children
}
Expand Down Expand Up @@ -570,8 +569,8 @@ mod tests {
#[test]
fn manual_branch_node_ok() {
let raw_input = vec![
(hex!("646f").to_vec(), hex!("76657262").to_vec()),
(hex!("676f6f64").to_vec(), hex!("7075707079").to_vec()),
(hex!("646f").to_vec(), RlpNode::from_raw(&hex!("76657262")).unwrap()),
(hex!("676f6f64").to_vec(), RlpNode::from_raw(&hex!("7075707079")).unwrap()),
];
let expected = triehash_trie_root(raw_input.clone());

Expand Down
107 changes: 68 additions & 39 deletions src/nodes/branch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{super::TrieMask, rlp_node, CHILD_INDEX_RANGE};
use super::{super::TrieMask, RlpNode, CHILD_INDEX_RANGE};
use alloy_primitives::{hex, B256};
use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE};
use core::{fmt, ops::Range, slice::Iter};
Expand All @@ -14,7 +14,7 @@ use alloc::vec::Vec;
#[derive(PartialEq, Eq, Clone, Default)]
pub struct BranchNode {
/// The collection of RLP encoded children.
pub stack: Vec<Vec<u8>>,
pub stack: Vec<RlpNode>,
/// The bitmask indicating the presence of children at the respective nibble positions
pub state_mask: TrieMask,
}
Expand Down Expand Up @@ -61,7 +61,7 @@ impl Decodable for BranchNode {
// Decode without advancing
let Header { payload_length, .. } = Header::decode(&mut &bytes[..])?;
let len = payload_length + length_of_length(payload_length);
stack.push(Vec::from(&bytes[..len]));
stack.push(RlpNode::from_raw_rlp(&bytes[..len])?);
bytes.advance(len);
state_mask.set_bit(index);
}
Expand All @@ -79,13 +79,13 @@ impl Decodable for BranchNode {

impl BranchNode {
/// Creates a new branch node with the given stack and state mask.
pub const fn new(stack: Vec<Vec<u8>>, state_mask: TrieMask) -> Self {
pub const fn new(stack: Vec<RlpNode>, state_mask: TrieMask) -> Self {
Self { stack, state_mask }
}

/// Return branch node as [BranchNodeRef].
pub fn as_ref(&self) -> BranchNodeRef<'_> {
BranchNodeRef::new(&self.stack, &self.state_mask)
BranchNodeRef::new(&self.stack, self.state_mask)
}
}

Expand All @@ -97,10 +97,10 @@ pub struct BranchNodeRef<'a> {
/// NOTE: The referenced stack might have more items than the number of children
/// for this node. We should only ever access items starting from
/// [BranchNodeRef::first_child_index].
pub stack: &'a [Vec<u8>],
pub stack: &'a [RlpNode],
/// Bitmask indicating the presence of children at
/// the respective nibble positions.
pub state_mask: &'a TrieMask,
pub state_mask: TrieMask,
}

impl fmt::Debug for BranchNodeRef<'_> {
Expand All @@ -122,12 +122,9 @@ impl Encodable for BranchNodeRef<'_> {
Header { list: true, payload_length: self.rlp_payload_length() }.encode(out);

// Extend the RLP buffer with the present children
let mut stack_ptr = self.first_child_index();
for index in CHILD_INDEX_RANGE {
if self.state_mask.is_bit_set(index) {
out.put_slice(&self.stack[stack_ptr]);
// Advance the pointer to the next child.
stack_ptr += 1;
for (_, child) in self.children() {
if let Some(child) = child {
out.put_slice(child);
} else {
out.put_u8(EMPTY_STRING_CODE);
}
Expand All @@ -145,7 +142,8 @@ impl Encodable for BranchNodeRef<'_> {

impl<'a> BranchNodeRef<'a> {
/// Create a new branch node from the stack of nodes.
pub const fn new(stack: &'a [Vec<u8>], state_mask: &'a TrieMask) -> Self {
#[inline]
pub const fn new(stack: &'a [RlpNode], state_mask: TrieMask) -> Self {
Self { stack, state_mask }
}

Expand All @@ -155,34 +153,40 @@ impl<'a> BranchNodeRef<'a> {
///
/// If the stack length is less than the number of children specified in the state mask,
/// the node is in an inconsistent state.
#[inline]
pub fn first_child_index(&self) -> usize {
    // Children sit at the tail of the stack, one item per bit set in the state mask.
    let child_count = self.state_mask.count_ones() as usize;
    // A stack shorter than `child_count` makes the subtraction underflow, so `unwrap` panics.
    self.stack.len().checked_sub(child_count).unwrap()
}

/// Returns an iterator built by [`BranchChildrenIter::new`] that yields `(index, child)` pairs;
/// `child` is `None` for an index whose bit is not set in the state mask.
#[inline]
fn children(&self) -> impl Iterator<Item = (u8, Option<&RlpNode>)> + '_ {
BranchChildrenIter::new(self)
}

/// Given the hash mask of children, return an iterator over stack items
/// that match the mask.
#[inline]
pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator<Item = B256> + '_ {
BranchChildrenIter::new(self)
self.children()
.filter_map(|(i, c)| c.map(|c| (i, c)))
.filter(move |(index, _)| hash_mask.is_bit_set(*index))
.map(|(_, child)| B256::from_slice(&child[1..]))
}

/// Returns the RLP encoding of the branch node given the state mask of children present.
pub fn rlp(&self, out: &mut Vec<u8>) -> Vec<u8> {
self.encode(out);
rlp_node(out)
/// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`.
#[inline]
pub fn rlp(&self, rlp: &mut Vec<u8>) -> RlpNode {
self.encode(rlp);
RlpNode::from_rlp(rlp)
}

/// Returns the length of RLP encoded fields of branch node.
#[inline]
fn rlp_payload_length(&self) -> usize {
let mut payload_length = 1;

let mut stack_ptr = self.first_child_index();
for digit in CHILD_INDEX_RANGE {
if self.state_mask.is_bit_set(digit) {
payload_length += self.stack[stack_ptr].len();
// Advance the pointer to the next child.
stack_ptr += 1;
for (_, child) in self.children() {
if let Some(child) = child {
payload_length += child.len();
} else {
payload_length += 1;
}
Expand All @@ -195,8 +199,8 @@ impl<'a> BranchNodeRef<'a> {
#[derive(Debug)]
struct BranchChildrenIter<'a> {
range: Range<u8>,
state_mask: &'a TrieMask,
stack_iter: Iter<'a, Vec<u8>>,
state_mask: TrieMask,
stack_iter: Iter<'a, RlpNode>,
}

impl<'a> BranchChildrenIter<'a> {
Expand All @@ -211,15 +215,34 @@ impl<'a> BranchChildrenIter<'a> {
}

impl<'a> Iterator for BranchChildrenIter<'a> {
type Item = (u8, &'a [u8]);
type Item = (u8, Option<&'a RlpNode>);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
loop {
let current = self.range.next()?;
if self.state_mask.is_bit_set(current) {
return Some((current, self.stack_iter.next()?));
}
}
let i = self.range.next()?;
let value = if self.state_mask.is_bit_set(i) {
// SAFETY: `first_child_index` guarantees that `stack` is exactly
// `state_mask.count_ones()` long.
Some(unsafe { self.stack_iter.next().unwrap_unchecked() })
} else {
None
};
Some((i, value))
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
    // The count from `ExactSizeIterator::len` is exact, so both bounds coincide.
    let remaining = self.len();
    (remaining, Some(remaining))
}
}

// `next` stops only when `self.range.next()` returns `None`, and `Range<u8>` keeps returning
// `None` after that, so this iterator never resumes once it has finished.
impl core::iter::FusedIterator for BranchChildrenIter<'_> {}

// `next` yields one item per index taken from `range`, so the number of indices left in
// `range` is the exact number of items left in this iterator.
impl ExactSizeIterator for BranchChildrenIter<'_> {
#[inline]
fn len(&self) -> usize {
self.range.len()
}
}

Expand Down Expand Up @@ -292,7 +315,7 @@ impl BranchNodeCompact {
#[cfg(test)]
mod tests {
use super::*;
use crate::nodes::{word_rlp, ExtensionNode, LeafNode};
use crate::nodes::{ExtensionNode, LeafNode};
use nybbles::Nibbles;

#[test]
Expand All @@ -302,13 +325,19 @@ mod tests {
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), empty);

let sparse_node = BranchNode::new(
vec![word_rlp(&B256::repeat_byte(1)), word_rlp(&B256::repeat_byte(2))],
vec![
RlpNode::word_rlp(&B256::repeat_byte(1)),
RlpNode::word_rlp(&B256::repeat_byte(2)),
],
TrieMask::new(0b1000100),
);
let encoded = alloy_rlp::encode(&sparse_node);
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), sparse_node);

let leaf_child = LeafNode::new(Nibbles::from_nibbles(hex!("0203")), hex!("1234").to_vec());
let leaf_child = LeafNode::new(
Nibbles::from_nibbles(hex!("0203")),
RlpNode::from_raw(&hex!("1234")).unwrap(),
);
let mut buf = vec![];
let leaf_rlp = leaf_child.as_ref().rlp(&mut buf);
let branch_with_leaf = BranchNode::new(vec![leaf_rlp.clone()], TrieMask::new(0b0010));
Expand All @@ -323,7 +352,7 @@ mod tests {
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), branch_with_ext);

let full = BranchNode::new(
core::iter::repeat(word_rlp(&B256::repeat_byte(23))).take(16).collect(),
core::iter::repeat(RlpNode::word_rlp(&B256::repeat_byte(23))).take(16).collect(),
TrieMask::new(u16::MAX),
);
let encoded = alloy_rlp::encode(&full);
Expand Down
Loading

0 comments on commit 296cb8d

Please sign in to comment.