Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(root): support proof prefetch in the task #13428

Merged
merged 5 commits into from
Dec 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 99 additions & 41 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! State root task related functionality.

use alloy_primitives::map::HashSet;
use alloy_primitives::{map::HashSet, Address};
use derive_more::derive::Deref;
use rayon::iter::{ParallelBridge, ParallelIterator};
use reth_errors::ProviderError;
use reth_errors::{ProviderError, ProviderResult};
use reth_evm::system_calls::OnStateHook;
use reth_provider::{
providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory,
Expand Down Expand Up @@ -108,6 +108,8 @@ impl<Factory> StateRootConfig<Factory> {
#[derive(Debug)]
#[allow(dead_code)]
pub enum StateRootMessage<BPF: BlindedProviderFactory> {
/// Prefetch proof targets
PrefetchProofs(HashSet<Address>),
/// New state update from transaction execution
StateUpdate(EvmState),
/// Proof calculation completed for a specific state update
Expand Down Expand Up @@ -340,6 +342,29 @@ where
}
}

/// Handles request for proof prefetch.
fn on_prefetch_proof(
scope: &rayon::Scope<'env>,
config: StateRootConfig<Factory>,
targets: HashSet<Address>,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
) {
let proof_targets =
targets.into_iter().map(|address| (keccak256(address), Default::default())).collect();
extend_multi_proof_targets_ref(fetched_proof_targets, &proof_targets);

Self::spawn_multiproof(
scope,
config,
Default::default(),
proof_targets,
proof_sequence_number,
state_root_message_sender,
);
}

/// Handles state updates.
///
/// Returns proof targets derived from the state update.
Expand All @@ -356,46 +381,39 @@ where
let proof_targets = get_proof_targets(&hashed_state_update, fetched_proof_targets);
extend_multi_proof_targets_ref(fetched_proof_targets, &proof_targets);

Self::spawn_multiproof(
scope,
config,
hashed_state_update,
proof_targets,
proof_sequence_number,
state_root_message_sender,
);
}

fn spawn_multiproof(
scope: &rayon::Scope<'env>,
config: StateRootConfig<Factory>,
hashed_state_update: HashedPostState,
proof_targets: MultiProofTargets,
proof_sequence_number: u64,
state_root_message_sender: Sender<StateRootMessage<BPF>>,
) {
// Dispatch proof gathering for this state update
scope.spawn(move |_| {
let provider = match config.consistent_view.provider_ro() {
Ok(provider) => provider,
Err(error) => {
error!(target: "engine::root", ?error, "Could not get provider");
let _ = state_root_message_sender
.send(StateRootMessage::ProofCalculationError(error));
return;
}
};

// TODO: replace with parallel proof
let result = Proof::from_tx(provider.tx_ref())
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&config.nodes_sorted,
))
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&config.state_sorted,
))
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
.with_branch_node_hash_masks(true)
.multiproof(proof_targets.clone());
match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ = state_root_message_sender
.send(StateRootMessage::ProofCalculationError(error.into()));
}
scope.spawn(move |_| match calculate_multiproof(config, proof_targets.clone()) {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Box::new(ProofCalculated {
state_update: hashed_state_update,
targets: proof_targets,
proof,
sequence_number: proof_sequence_number,
}),
));
}
Err(error) => {
let _ =
state_root_message_sender.send(StateRootMessage::ProofCalculationError(error));
}
});
}
Expand Down Expand Up @@ -486,6 +504,21 @@ where
loop {
match self.rx.recv() {
Ok(message) => match message {
StateRootMessage::PrefetchProofs(targets) => {
debug!(
target: "engine::root",
len = targets.len(),
"Prefetching proofs"
);
Self::on_prefetch_proof(
scope,
self.config.clone(),
targets,
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
self.tx.clone(),
);
}
StateRootMessage::StateUpdate(update) => {
if updates_received == 0 {
first_update_time = Some(Instant::now());
Expand Down Expand Up @@ -681,6 +714,31 @@ fn get_proof_targets(
targets
}

/// Calculate multiproof for the targets.
#[inline]
fn calculate_multiproof<Factory>(
config: StateRootConfig<Factory>,
proof_targets: MultiProofTargets,
) -> ProviderResult<MultiProof>
where
Factory: DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider,
{
let provider = config.consistent_view.provider_ro()?;

Ok(Proof::from_tx(provider.tx_ref())
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&config.nodes_sorted,
))
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&config.state_sorted,
))
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
.with_branch_node_hash_masks(true)
.multiproof(proof_targets)?)
}

/// Updates the sparse trie with the given proofs and state, and returns the updated trie and the
/// time it took.
fn update_sparse_trie<
Expand Down
Loading