diff --git a/crates/engine/tree/src/tree/root.rs b/crates/engine/tree/src/tree/root.rs index 79d4941c8d31..cb64d95d8f92 100644 --- a/crates/engine/tree/src/tree/root.rs +++ b/crates/engine/tree/src/tree/root.rs @@ -1,9 +1,9 @@ //! State root task related functionality. -use alloy_primitives::map::HashSet; +use alloy_primitives::{map::HashSet, Address}; use derive_more::derive::Deref; use rayon::iter::{ParallelBridge, ParallelIterator}; -use reth_errors::ProviderError; +use reth_errors::{ProviderError, ProviderResult}; use reth_evm::system_calls::OnStateHook; use reth_provider::{ providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, @@ -108,6 +108,8 @@ impl StateRootConfig { #[derive(Debug)] #[allow(dead_code)] pub enum StateRootMessage { + /// Prefetch proof targets + PrefetchProofs(HashSet
), /// New state update from transaction execution StateUpdate(EvmState), /// Proof calculation completed for a specific state update @@ -340,6 +342,29 @@ where } } + /// Handles request for proof prefetch. + fn on_prefetch_proof( + scope: &rayon::Scope<'env>, + config: StateRootConfig, + targets: HashSet
, + fetched_proof_targets: &mut MultiProofTargets, + proof_sequence_number: u64, + state_root_message_sender: Sender>, + ) { + let proof_targets = + targets.into_iter().map(|address| (keccak256(address), Default::default())).collect(); + extend_multi_proof_targets_ref(fetched_proof_targets, &proof_targets); + + Self::spawn_multiproof( + scope, + config, + Default::default(), + proof_targets, + proof_sequence_number, + state_root_message_sender, + ); + } + /// Handles state updates. /// /// Returns proof targets derived from the state update. @@ -356,46 +381,39 @@ where let proof_targets = get_proof_targets(&hashed_state_update, fetched_proof_targets); extend_multi_proof_targets_ref(fetched_proof_targets, &proof_targets); + Self::spawn_multiproof( + scope, + config, + hashed_state_update, + proof_targets, + proof_sequence_number, + state_root_message_sender, + ); + } + + fn spawn_multiproof( + scope: &rayon::Scope<'env>, + config: StateRootConfig, + hashed_state_update: HashedPostState, + proof_targets: MultiProofTargets, + proof_sequence_number: u64, + state_root_message_sender: Sender>, + ) { // Dispatch proof gathering for this state update - scope.spawn(move |_| { - let provider = match config.consistent_view.provider_ro() { - Ok(provider) => provider, - Err(error) => { - error!(target: "engine::root", ?error, "Could not get provider"); - let _ = state_root_message_sender - .send(StateRootMessage::ProofCalculationError(error)); - return; - } - }; - - // TODO: replace with parallel proof - let result = Proof::from_tx(provider.tx_ref()) - .with_trie_cursor_factory(InMemoryTrieCursorFactory::new( - DatabaseTrieCursorFactory::new(provider.tx_ref()), - &config.nodes_sorted, - )) - .with_hashed_cursor_factory(HashedPostStateCursorFactory::new( - DatabaseHashedCursorFactory::new(provider.tx_ref()), - &config.state_sorted, - )) - .with_prefix_sets_mut(config.prefix_sets.as_ref().clone()) - .with_branch_node_hash_masks(true) - .multiproof(proof_targets.clone()); - match result { - Ok(proof) => { - let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( - Box::new(ProofCalculated { - state_update: hashed_state_update, - targets: proof_targets, - proof, - sequence_number: proof_sequence_number, - }), - )); - } - Err(error) => { - let _ = state_root_message_sender - .send(StateRootMessage::ProofCalculationError(error.into())); - } + scope.spawn(move |_| match calculate_multiproof(config, proof_targets.clone()) { + Ok(proof) => { + let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated( + Box::new(ProofCalculated { + state_update: hashed_state_update, + targets: proof_targets, + proof, + sequence_number: proof_sequence_number, + }), + )); + } + Err(error) => { + let _ = + state_root_message_sender.send(StateRootMessage::ProofCalculationError(error)); } }); } @@ -486,6 +504,21 @@ where loop { match self.rx.recv() { Ok(message) => match message { + StateRootMessage::PrefetchProofs(targets) => { + debug!( + target: "engine::root", + len = targets.len(), + "Prefetching proofs" + ); + Self::on_prefetch_proof( + scope, + self.config.clone(), + targets, + &mut self.fetched_proof_targets, + self.proof_sequencer.next_sequence(), + self.tx.clone(), + ); + } StateRootMessage::StateUpdate(update) => { if updates_received == 0 { first_update_time = Some(Instant::now()); @@ -681,6 +714,31 @@ fn get_proof_targets( targets } +/// Calculate multiproof for the targets. +#[inline] +fn calculate_multiproof( + config: StateRootConfig, + proof_targets: MultiProofTargets, +) -> ProviderResult +where + Factory: DatabaseProviderFactory + StateCommitmentProvider, +{ + let provider = config.consistent_view.provider_ro()?; + + Ok(Proof::from_tx(provider.tx_ref()) + .with_trie_cursor_factory(InMemoryTrieCursorFactory::new( + DatabaseTrieCursorFactory::new(provider.tx_ref()), + &config.nodes_sorted, + )) + .with_hashed_cursor_factory(HashedPostStateCursorFactory::new( + DatabaseHashedCursorFactory::new(provider.tx_ref()), + &config.state_sorted, + )) + .with_prefix_sets_mut(config.prefix_sets.as_ref().clone()) + .with_branch_node_hash_masks(true) + .multiproof(proof_targets)?) +} + /// Updates the sparse trie with the given proofs and state, and returns the updated trie and the /// time it took. fn update_sparse_trie<