perf(trie): split Parallel::multiproof workload in chunks
fgimenez committed Dec 17, 2024
1 parent 03649f2 commit 8a3d628
Showing 1 changed file with 136 additions and 44 deletions.
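
The pattern introduced in the diff below (dispatching storage-proof tasks in fixed-size chunks and draining a shared bounded channel before starting the next chunk) can be sketched in isolation. The following minimal Rust sketch assumes the rayon crate, as the original does, and uses a hypothetical compute_proof function in place of the real storage-proof work; the constants and types are illustrative only, not the reth implementation.

use std::{collections::HashMap, sync::mpsc, time::Duration};

/// Hypothetical stand-in for the per-account storage-proof computation.
fn compute_proof(account: u64) -> Result<(u64, u64), String> {
    Ok((account, account.wrapping_mul(31)))
}

fn main() -> Result<(), String> {
    // Illustrative values; the commit uses CHUNK_SIZE = 128 and a 30s receive timeout.
    const CHUNK_SIZE: usize = 128;
    let targets: Vec<u64> = (0..1_000).collect();

    // One bounded channel shared by all tasks; capacity matches the chunk size.
    let (tx, rx) = mpsc::sync_channel(CHUNK_SIZE);
    let mut proofs = HashMap::new();

    for chunk in targets.chunks(CHUNK_SIZE) {
        // Fan out: one rayon task per target in the current chunk.
        for &account in chunk {
            let tx = tx.clone();
            rayon::spawn(move || {
                // Errors are forwarded through the channel rather than panicking.
                let _ = tx.send(compute_proof(account));
            });
        }

        // Fan in: wait for exactly as many results as were dispatched,
        // bounding the amount of in-flight work before the next chunk starts.
        for _ in 0..chunk.len() {
            let (account, proof) = rx
                .recv_timeout(Duration::from_secs(30))
                .map_err(|e| format!("failed to receive proof result: {e:?}"))??;
            proofs.insert(account, proof);
        }
    }

    println!("computed {} proofs", proofs.len());
    Ok(())
}

Draining the channel once per chunk is what caps the number of in-flight tasks (and open provider transactions in the real code) at roughly CHUNK_SIZE, while the receive timeout turns a lost result into an error instead of an indefinite hang.
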
crates/trie/parallel/src/proof.rs: 180 changes (136 additions, 44 deletions)
@@ -24,8 +24,8 @@ use reth_trie::{
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::Arc;
use tracing::{debug, error};
use std::{sync::Arc, time::Instant};
use tracing::{debug, error, trace};

#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;
@@ -108,54 +108,152 @@ where
});
let prefix_sets = prefix_sets.freeze();

let storage_root_targets = StorageRootTargets::new(
let storage_root_targets: Vec<_> = StorageRootTargets::new(
prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
prefix_sets.storage_prefix_sets.clone(),
);
)
.into_iter()
.sorted_unstable_by_key(|(address, _)| *address)
.collect();

// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs");

const CHUNK_SIZE: usize = 128;
let num_chunks = storage_root_targets.len().div_ceil(CHUNK_SIZE);

debug!(
target: "trie::parallel_state_root",
total_targets = storage_root_targets.len(),
chunk_size = CHUNK_SIZE,
num_chunks,
"Starting batched proof generation"
);

// Create a single channel for all proofs with appropriate capacity
let (tx, rx) = std::sync::mpsc::sync_channel(CHUNK_SIZE);

let mut storage_proofs =
B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());
for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let view = self.view.clone();
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
for (chunk_idx, chunk) in storage_root_targets.chunks(CHUNK_SIZE).enumerate() {
let chunk_size = chunk.len();

let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
debug!(
target: "trie::parallel_state_root",
chunk_idx,
chunk_size,
"Processing proof batch"
);

let (tx, rx) = std::sync::mpsc::sync_channel(1);
// Spawn tasks for this batch
for (hashed_address, prefix_set) in chunk {
let view = self.view.clone();
let target_slots = targets.get(hashed_address).cloned().unwrap_or_default();
let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;
let tx = tx.clone();
let hashed_address = *hashed_address;
let prefix_set = prefix_set.clone();

rayon::spawn_fifo(move || {
let result = (|| -> Result<_, ParallelStateRootError> {
let provider_ro = view.provider_ro()?;
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&trie_nodes_sorted,
);
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
rayon::spawn(move || {
debug!(
target: "trie::parallel",
?hashed_address,
"Starting proof calculation"
);

StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))
})();
if let Err(err) = tx.send(result) {
error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result");
let task_start = Instant::now();
let result = (|| -> Result<_, ParallelStateRootError> {
let provider_start = Instant::now();
let provider_ro = view.provider_ro()?;
trace!(
target: "trie::parallel",
?hashed_address,
provider_time_ms = provider_start.elapsed().as_millis(),
"Got provider"
);

let cursor_start = Instant::now();
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&trie_nodes_sorted,
);
let hashed_cursor_factory = HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
trace!(
target: "trie::parallel",
?hashed_address,
cursor_time_ms = cursor_start.elapsed().as_millis(),
"Created cursors"
);

let proof_start = Instant::now();
let proof = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

trace!(
target: "trie::parallel",
?hashed_address,
proof_time_ms = proof_start.elapsed().as_millis(),
"Completed proof calculation"
);

Ok((hashed_address, proof))
})();

let task_time = task_start.elapsed();
if let Err(e) = tx.send(result) {
error!(
target: "trie::parallel",
?hashed_address,
error = ?e,
task_time_ms = task_time.as_millis(),
"Failed to send proof result"
);
}
});
}

// Wait for all proofs in this batch
for _ in 0..chunk_size {
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
Ok(result) => match result {
Ok((address, proof)) => {
storage_proofs.insert(address, proof);
}
Err(e) => {
error!(
target: "trie::parallel",
error = ?e,
chunk_idx,
"Proof calculation failed"
);
return Err(e);
}
},
Err(e) => {
error!(
target: "trie::parallel",
error = ?e,
chunk_idx,
"Failed to receive proof result"
);
return Err(ParallelStateRootError::Other(format!(
"Failed to receive proof result: {e:?}",
)));
}
}
});
storage_proofs.insert(hashed_address, rx);
}
}

let provider_ro = self.view.provider_ro()?;
@@ -199,13 +297,7 @@ where
}
TrieElement::Leaf(hashed_address, account) => {
let storage_multiproof = match storage_proofs.remove(&hashed_address) {
Some(rx) => rx.recv().map_err(|_| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(
DatabaseError::Other(format!(
"channel closed for {hashed_address}"
)),
))
})??,
Some(proof) => proof,
// Since we do not store all intermediate nodes in the database, there might
// be a possibility of re-adding a non-modified leaf to the hash builder.
None => {
