From f6d02b3d3397a88e0029a18af66dd54461561916 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Petar=20Vujovi=C4=87?=
Date: Mon, 22 Jul 2024 11:30:31 +0200
Subject: [PATCH] feat(host): create initial cancel handling (#316)

* [WIP](host): connect cancel API
* feat(host): create initial cancel handling
* refactor(host,task_manager): remove duplicate struct and impls
* refactor(lib): remove reference where not needed
* refactor(host,task_manager): replace struct with type
* chore(host): use get instead of post for reporting
* fix(task_manager): clear hash map for in memory db
* refactor(host,task_manager): clean up function params
* refactor(host): minimize fetching and type complexity
* refactor(host): skip task if getting task data is unsuccessful
* feat(host): pass error from proving to task actor
* feat(host): make getting task manager more ergonomic
* refactor(core,lib): clean up preflight code
* feat(host): simplify proof request config reading
* fix(lib): fix typo
* refactor(host): extract cache utils to separate module
* fix(core): handle errors for blob checking
* fix(task_manager): handle proof retrieval better
* refactor(host): make proof handler more ergonomic
* fix(tasks): pass closure instead of variable
* fix(tasks): handle inconsistency in status triplet
* chore(core,lib): clean up dependency declaration
* feat(core,host,lib,risc0,sp1): unify proof output
* fix(clippy): fix lint
* refactor(host): add error trace
* refactor(tasks): use type instead of struct
* refactor(tasks): move struct definition
* feat(host): use single channel for tasks
* feat(all): pass a store to provers to cancel a remote task
* fix(all): fix some async errors
* fix(core,host,lib): handle async boundary
* fix(provers): fix types on trait implementors
* fix(core): fix lints
* fix(provers): fix lints
* fix(core): fix test mock
* fix(core): use provers instead of drivers
* fix(lib,tasks): account for multi prover environment
* fix(tasks): fix clippy lints
* feat(risc0): add cancellation to bonsai (#320)
* feat(risc0): add cancellation to bonsai
* fix(risc0): remove id after cancellation
* fix(risc0): add prover code to proof key
* feat(sp1): add cancellation support (#322)
* feat(sp1): add cancellation support
* fix(sp1): add prover code to proof key
* refactor(risc0,sp1): rename arguments
* fix(clippy): remove unused imports
* fix(clippy): rename unused args and unwrap result
---
 Cargo.lock                             |   2 +
 Cargo.toml                             |  14 +-
 core/Cargo.toml                        |  10 +-
 core/src/interfaces.rs                 |  64 ++++--
 core/src/lib.rs                        |  36 ++-
 core/src/preflight.rs                  | 142 ++++++------
 core/src/prover.rs                     |  15 +-
 host/Cargo.toml                        |   1 +
 host/src/cache.rs                      | 158 +++++++++++++
 host/src/interfaces.rs                 |  21 ++
 host/src/lib.rs                        |  40 +++-
 host/src/proof.rs                      | 247 +++++++++++++--------
 host/src/server/api/mod.rs             |   2 +-
 host/src/server/api/v1/proof.rs        | 294 +++----------------------
 host/src/server/api/v2/mod.rs          |  22 +-
 host/src/server/api/v2/proof/cancel.rs |  26 ++-
 host/src/server/api/v2/proof/mod.rs    |  77 +++----
 host/src/server/api/v2/proof/prune.rs  |   4 +-
 host/src/server/api/v2/proof/report.rs |   8 +-
 lib/Cargo.toml                         |  13 +-
 lib/src/input.rs                       |  31 ++-
 lib/src/primitives/eip4844.rs          |   4 +-
 lib/src/prover.rs                      |  39 +++-
 provers/risc0/driver/src/bonsai.rs     |  22 +-
 provers/risc0/driver/src/lib.rs        |  35 ++-
 provers/sgx/prover/src/lib.rs          |  19 +-
 provers/sp1/driver/src/benchmark.rs    |  28 ++-
 provers/sp1/driver/src/lib.rs          | 109 ++++++++-
 script/cancel-block.sh                 | 132 +++++++++++
 tasks/src/adv_sqlite.rs                | 235 +++++++++++++++-----
 tasks/src/lib.rs                       | 181 ++++++---------
 tasks/src/mem_db.rs                    | 205 ++++++++---------
 tasks/tests/main.rs                    | 168 +++++---------
 33 files changed, 1417 insertions(+), 987 deletions(-)
 create mode 100644 host/src/cache.rs
 create mode 100755 script/cancel-block.sh

diff --git a/Cargo.lock b/Cargo.lock
index 62433c5fb..a2a0fee7c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5776,6 +5776,7 @@ dependencies = [
  "sp1-driver",
  "thiserror",
  "tokio",
+ "tokio-util",
  "tower",
  "tower-http",
  "tracing",
@@ -5826,6 +5827,7 @@ dependencies = [
  "tokio",
  "tracing",
  "url",
+ "utoipa",
 ]

 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index e9279b135..4d31db36d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,7 +39,10 @@ raiko-core = { path = "./core" }
 raiko-tasks = { path = "./tasks" }

 # reth
-reth-primitives = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false, features = ["alloy-compat", "taiko"] }
+reth-primitives = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false, features = [
+    "alloy-compat",
+    "taiko",
+] }
 reth-evm-ethereum = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false }
 reth-evm = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false }
 reth-rpc-types = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false }
@@ -77,16 +80,14 @@ alloy-network = { version = "0.1", default-features = false, features = [
     "k256",
 ] }
 alloy-contract = { version = "0.1", default-features = false }
-alloy-eips = { version = "0.1", default-features = false, features = [
-    "serde",
-] }
+alloy-eips = { version = "0.1", default-features = false, features = ["serde"] }
 alloy-provider = { version = "0.1", default-features = false, features = [
     "reqwest",
 ] }
-alloy-transport-http = { version = "0.1",default-features = false, features = [
+alloy-transport-http = { version = "0.1", default-features = false, features = [
     "reqwest",
 ] }
-alloy-signer = { version = "0.1", default-features = false}
+alloy-signer = { version = "0.1", default-features = false }
 alloy-signer-local = { version = "0.1", default-features = false }

 # ethers (TODO: remove)
@@ -135,6 +136,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 structopt = "0.3.24"
 prometheus = { version = "0.13.3", features = ["process"] }
 tokio = { version = "^1.23", features = ["full"] }
+tokio-util = { version = "0.7.11" }
 reqwest = { version = "0.11.22", features = ["json"] }
 url = "2.5.0"
 async-trait = "0.1.80"
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 3f2380ea7..8509db15e 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -14,11 +14,11 @@ sgx-prover = { path = "../provers/sgx/prover", optional = true }
 raiko-lib = { workspace = true }

 # reth
-reth-primitives.workspace = true
-reth-evm-ethereum.workspace = true
-reth-evm.workspace = true
-reth-revm.workspace = true
-reth-provider.workspace = true
+reth-primitives = { workspace = true }
+reth-evm-ethereum = { workspace = true }
+reth-evm = { workspace = true }
+reth-revm = { workspace = true }
+reth-provider = { workspace = true }

 # alloy
 alloy-rlp = { workspace = true }
diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs
index ebd86e127..eb385b385 100644
--- a/core/src/interfaces.rs
+++ b/core/src/interfaces.rs
@@ -4,7 +4,7 @@ use clap::{Args, ValueEnum};
 use raiko_lib::{
     input::{BlobProofType, GuestInput, GuestOutput},
     primitives::eip4844::{calc_kzg_proof, commitment_to_version_hash, kzg_proof_to_bytes},
-    prover::{Proof, Prover, ProverError},
+    prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverError},
 };
 use reth_primitives::hex;
 use serde::{Deserialize, Serialize};
@@ -160,14 +160,15 @@ impl ProofType {
         input: GuestInput,
         output: &GuestOutput,
         config: &Value,
+        store: Option<&mut dyn IdWrite>,
     ) -> RaikoResult<Proof> {
         let mut proof = match self {
-            ProofType::Native => NativeProver::run(input.clone(), output, config)
+            ProofType::Native => NativeProver::run(input.clone(), output, config, store)
                 .await
-                .map_err(<ProverError as Into<RaikoError>>::into),
+                .map_err(|e| e.into()),
             ProofType::Sp1 => {
                 #[cfg(feature = "sp1")]
-                return sp1_driver::Sp1Prover::run(input.clone(), output, config)
+                return sp1_driver::Sp1Prover::run(input.clone(), output, config, store)
                     .await
                     .map_err(|e| e.into());
                 #[cfg(not(feature = "sp1"))]
                 Err(RaikoError::FeatureNotSupportedError(*self))
             }
             ProofType::Risc0 => {
                 #[cfg(feature = "risc0")]
-                return risc0_driver::Risc0Prover::run(input.clone(), output, config)
+                return risc0_driver::Risc0Prover::run(input.clone(), output, config, store)
                     .await
                     .map_err(|e| e.into());
                 #[cfg(not(feature = "risc0"))]
                 Err(RaikoError::FeatureNotSupportedError(*self))
             }
             ProofType::Sgx => {
                 #[cfg(feature = "sgx")]
-                return sgx_prover::SgxProver::run(input.clone(), output, config)
+                return sgx_prover::SgxProver::run(input.clone(), output, config, store)
                     .await
                     .map_err(|e| e.into());
                 #[cfg(not(feature = "sgx"))]
                 Err(RaikoError::FeatureNotSupportedError(*self))
             }
         };
@@ -195,18 +196,55 @@
         if let Some(blob_commitment) = input.taiko.blob_commitment.clone() {
             let kzg_proof = calc_kzg_proof(
                 &input.taiko.tx_data,
-                &commitment_to_version_hash(&blob_commitment.try_into().unwrap()),
+                &commitment_to_version_hash(&blob_commitment.try_into().map_err(|_| {
+                    RaikoError::Conversion(
+                        "Could not convert blob commitment to version hash".to_owned(),
+                    )
+                })?),
             )
-            .unwrap();
-            let kzg_proof_hex = hex::encode(kzg_proof_to_bytes(&kzg_proof));
-            proof
-                .as_object_mut()
-                .unwrap()
-                .insert("kzg_proof".to_string(), Value::String(kzg_proof_hex));
+            .map_err(|e| anyhow::anyhow!(e))?;
+            proof.kzg_proof = Some(hex::encode(kzg_proof_to_bytes(&kzg_proof)));
         }

         Ok(proof)
     }
+
+    pub async fn cancel_proof(
+        &self,
+        proof_key: ProofKey,
+        read: Box<&mut dyn IdStore>,
+    ) -> RaikoResult<()> {
+        match self {
+            ProofType::Native => NativeProver::cancel(proof_key, read)
+                .await
+                .map_err(|e| e.into()),
+            ProofType::Sp1 => {
+                #[cfg(feature = "sp1")]
+                return sp1_driver::Sp1Prover::cancel(proof_key, read)
+                    .await
+                    .map_err(|e| e.into());
+                #[cfg(not(feature = "sp1"))]
+                Err(RaikoError::FeatureNotSupportedError(*self))
+            }
+            ProofType::Risc0 => {
+                #[cfg(feature = "risc0")]
+                return risc0_driver::Risc0Prover::cancel(proof_key, read)
+                    .await
+                    .map_err(|e| e.into());
+                #[cfg(not(feature = "risc0"))]
+                Err(RaikoError::FeatureNotSupportedError(*self))
+            }
+            ProofType::Sgx => {
+                #[cfg(feature = "sgx")]
+                return sgx_prover::SgxProver::cancel(proof_key, read)
+                    .await
+                    .map_err(|e| e.into());
+                #[cfg(not(feature = "sgx"))]
+                Err(RaikoError::FeatureNotSupportedError(*self))
+            }
+        }?;
+        Ok(())
+    }
 }

 #[serde_as]
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 8dd72c937..40b11989f 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -5,11 +5,20 @@ use crate::{
 };
 use alloy_primitives::Address;
 use alloy_rpc_types::EIP1186AccountProofResponse;
-use raiko_lib::builder::{create_mem_db, RethBlockBuilder};
-use raiko_lib::consts::{ChainSpec, VerifierType};
-use raiko_lib::input::{GuestInput, GuestOutput, TaikoProverData};
 use raiko_lib::protocol_instance::ProtocolInstance;
 use raiko_lib::prover::Proof;
+use raiko_lib::{
+    builder::{create_mem_db, RethBlockBuilder},
+    prover::ProofKey,
+};
+use raiko_lib::{
+    consts::{ChainSpec, VerifierType},
+    prover::IdStore,
+};
+use raiko_lib::{
+    input::{GuestInput, GuestOutput, TaikoProverData},
+    prover::IdWrite,
+};
 use reth_primitives::Header;
 use serde_json::Value;
 use std::{collections::HashMap, hint::black_box};
@@ -92,13 +101,26 @@ impl Raiko {
         }
     }

-    pub async fn prove(&self, input: GuestInput, output: &GuestOutput) -> RaikoResult<Proof> {
-        let data = serde_json::to_value(&self.request)?;
+    pub async fn prove(
+        &self,
+        input: GuestInput,
+        output: &GuestOutput,
+        store: Option<&mut dyn IdWrite>,
+    ) -> RaikoResult<Proof> {
+        let config = serde_json::to_value(&self.request)?;
         self.request
             .proof_type
-            .run_prover(input, output, &data)
+            .run_prover(input, output, &config, store)
             .await
     }
+
+    pub async fn cancel(
+        &self,
+        proof_key: ProofKey,
+        read: Box<&mut dyn IdStore>,
+    ) -> RaikoResult<()> {
+        self.request.proof_type.cancel_proof(proof_key, read).await
+    }
 }

 fn check_header(exp: &Header, header: &Header) -> Result<(), RaikoError> {
@@ -268,7 +290,7 @@ mod tests {
             .expect("input generation failed");
         let output = raiko.get_output(&input).expect("output generation failed");
         let _proof = raiko
-            .prove(input, &output)
+            .prove(input, &output, None)
             .await
             .expect("proof generation failed");
     }
diff --git a/core/src/preflight.rs b/core/src/preflight.rs
index ba1895ad0..7ecfee2e4 100644
--- a/core/src/preflight.rs
+++ b/core/src/preflight.rs
@@ -25,7 +25,6 @@ use raiko_lib::{
         eip4844::{self, commitment_to_version_hash, KZG_SETTINGS},
         mpt::proofs_to_tries,
     },
-    utils::zlib_compress_data,
     Measurement,
 };
 use reth_evm_ethereum::taiko::decode_anchor;
@@ -48,14 +47,17 @@ pub async fn preflight(
     let blocks = provider
         .get_blocks(&[(block_number, true), (block_number - 1, false)])
         .await?;
-    let (block, parent_block) = (
-        blocks.first().ok_or_else(|| {
-            RaikoError::Preflight("No block data for the requested block".to_owned())
-        })?,
-        blocks.get(1).ok_or_else(|| {
-            RaikoError::Preflight("No parent block data for the requested block".to_owned())
-        })?,
-    );
+    let mut blocks = blocks.iter();
+    let Some(block) = blocks.next() else {
+        return Err(RaikoError::Preflight(
+            "No block data for the requested block".to_owned(),
+        ));
+    };
+    let Some(parent_block) = blocks.next() else {
+        return Err(RaikoError::Preflight(
+            "No parent block data for the requested block".to_owned(),
+        ));
+    };

     info!(
         "Processing block {:?} with hash: {:?}",
@@ -67,9 +69,8 @@ pub async fn preflight(
     debug!("block transactions: {:?}", block.transactions.len());

     // Convert the alloy block to a reth block
-    let block = Block::try_from(block.clone()).map_err(|e| {
-        RaikoError::Conversion(format!("Failed converting to reth block: {}", e).to_owned())
-    })?;
+    let block = Block::try_from(block.clone())
+        .map_err(|e| RaikoError::Conversion(format!("Failed converting to reth block: {e}")))?;

     let taiko_guest_input = if taiko_chain_spec.is_taiko() {
         prepare_taiko_chain_input(
@@ -84,66 +85,66 @@
     } else {
         // For Ethereum blocks we just convert the block transactions in a tx_list
        // so that we don't have to support separate paths.
-        TaikoGuestInput {
-            tx_data: zlib_compress_data(&alloy_rlp::encode(&block.body))?,
-            ..Default::default()
-        }
+        TaikoGuestInput::try_from(block.body.clone()).map_err(|e| RaikoError::Conversion(e.0))?
     };
     measurement.stop();

     // Create the guest input
-    let input = GuestInput {
-        block: block.clone(),
-        chain_spec: taiko_chain_spec.clone(),
-        parent_state_trie: Default::default(),
-        parent_storage: Default::default(),
-        contracts: Default::default(),
-        parent_header: parent_block.header.clone().try_into().unwrap(),
-        ancestor_headers: Default::default(),
-        taiko: taiko_guest_input,
-    };
+    let input = GuestInput::from((
+        block.clone(),
+        parent_block
+            .header
+            .clone()
+            .try_into()
+            .expect("Couldn't transform alloy header to reth header"),
+        taiko_chain_spec.clone(),
+        taiko_guest_input,
+    ));

     // Create the block builder, run the transactions and extract the DB
-    let provider_db = ProviderDb::new(
-        provider,
-        taiko_chain_spec,
-        if let Some(parent_block_number) = parent_block.header.number {
-            parent_block_number
-        } else {
-            return Err(RaikoError::Preflight(
-                "No parent block number for the requested block".to_owned(),
-            ));
-        },
-    )
-    .await?;
+    let Some(parent_block_number) = parent_block.header.number else {
+        return Err(RaikoError::Preflight(
+            "No parent block number for the requested block".to_owned(),
+        ));
+    };
+    let provider_db = ProviderDb::new(provider, taiko_chain_spec, parent_block_number).await?;

     // Now re-execute the transactions in the block to collect all required data
     let mut builder = RethBlockBuilder::new(&input, provider_db);
+
     // Optimize data gathering by executing the transactions multiple times so data can be requested in batches
-    let is_local = false;
-    let max_iterations = if is_local { 1 } else { 100 };
-    let mut done = false;
-    let mut num_iterations = 0;
-    while !done {
+    let max_iterations = 100;
+    for num_iterations in 0.. {
         inplace_print(&format!("Execution iteration {num_iterations}..."));

-        let optimistic = num_iterations + 1 < max_iterations;
-        builder.db.as_mut().unwrap().optimistic = optimistic;
+        let Some(db) = builder.db.as_mut() else {
+            return Err(RaikoError::Preflight("No db in builder".to_owned()));
+        };
+        db.optimistic = num_iterations + 1 < max_iterations;
+
+        builder
+            .execute_transactions(num_iterations + 1 < max_iterations)
+            .map_err(|_| {
+                RaikoError::Preflight("Executing transactions in builder failed".to_owned())
+            })?;

-        builder.execute_transactions(optimistic).expect("execute");
-        if builder.db.as_mut().unwrap().fetch_data().await {
-            done = true;
+        let Some(db) = builder.db.as_mut() else {
+            return Err(RaikoError::Preflight("No db in builder".to_owned()));
+        };
+        if db.fetch_data().await {
+            clear_line();
+            info!("State data fetched in {num_iterations} iterations");
+            break;
         }
-        num_iterations += 1;
     }
-    clear_line();
-    info!("State data fetched in {num_iterations} iterations");

-    let provider_db = builder.db.as_mut().unwrap();
+    let Some(db) = builder.db.as_mut() else {
+        return Err(RaikoError::Preflight("No db in builder".to_owned()));
+    };

     // Gather inclusion proofs for the initial and final state
     let measurement = Measurement::start("Fetching storage proofs...", true);
-    let (parent_proofs, proofs, num_storage_proofs) = provider_db.get_proofs().await?;
+    let (parent_proofs, proofs, num_storage_proofs) = db.get_proofs().await?;
     measurement.stop_with_count(&format!(
         "[{} Account/{num_storage_proofs} Storage]",
         parent_proofs.len() + proofs.len(),
     ));

     // Construct the state trie and storage from the storage proofs.
let measurement = Measurement::start("Constructing MPT...", true); - let (state_trie, storage) = + let (parent_state_trie, parent_storage) = proofs_to_tries(input.parent_header.state_root, parent_proofs, proofs)?; measurement.stop(); // Gather proofs for block history let measurement = Measurement::start("Fetching historical block headers...", true); - let ancestor_headers = provider_db.get_ancestor_headers().await?; + let ancestor_headers = db.get_ancestor_headers().await?; measurement.stop(); // Get the contracts from the initial db. let measurement = Measurement::start("Fetching contract code...", true); - let mut contracts = HashSet::new(); - let initial_db = &provider_db.initial_db; - for account in initial_db.accounts.values() { - let code = &account.info.code; - if let Some(code) = code { - contracts.insert(code.bytecode().0.clone()); - } - } + let contracts = + HashSet::::from_iter(db.initial_db.accounts.values().filter_map(|account| { + account + .info + .code + .clone() + .map(|code| Bytes(code.bytecode().0.clone())) + })) + .into_iter() + .collect::>(); measurement.stop(); // Fill in remaining generated guest input data let input = GuestInput { - parent_state_trie: state_trie, - parent_storage: storage, - contracts: contracts.into_iter().map(Bytes).collect(), + parent_state_trie, + parent_storage, + contracts, ancestor_headers, ..input }; @@ -282,7 +285,7 @@ fn block_time_to_block_slot( genesis_time: u64, block_per_slot: u64, ) -> RaikoResult { - if genesis_time == 0u64 { + if genesis_time == 0 { Err(RaikoError::Anyhow(anyhow!( "genesis time is 0, please check chain spec" ))) @@ -296,10 +299,7 @@ fn block_time_to_block_slot( } fn blob_to_bytes(blob_str: &str) -> Vec { - match hex::decode(blob_str.to_lowercase().trim_start_matches("0x")) { - Ok(b) => b, - Err(_) => Vec::new(), - } + hex::decode(blob_str.to_lowercase().trim_start_matches("0x")).unwrap_or_default() } fn calc_blob_versioned_hash(blob_str: &str) -> [u8; 32] { diff --git a/core/src/prover.rs b/core/src/prover.rs index 188e2dd46..4ac1a512d 100644 --- a/core/src/prover.rs +++ b/core/src/prover.rs @@ -4,7 +4,7 @@ use raiko_lib::{ consts::VerifierType, input::{GuestInput, GuestOutput}, protocol_instance::ProtocolInstance, - prover::{to_proof, Proof, Prover, ProverConfig, ProverError, ProverResult}, + prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; use serde::{de::Error, Deserialize, Serialize}; use serde_with::serde_as; @@ -28,6 +28,7 @@ impl Prover for NativeProver { input: GuestInput, output: &GuestOutput, config: &ProverConfig, + _store: Option<&mut dyn IdWrite>, ) -> ProverResult { let param = config @@ -56,8 +57,14 @@ impl Prover for NativeProver { )); } - to_proof(Ok(NativeResponse { - output: output.clone(), - })) + Ok(Proof { + proof: None, + quote: None, + kzg_proof: None, + }) + } + + async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> { + Ok(()) } } diff --git a/host/Cargo.toml b/host/Cargo.toml index 32eb09207..6222369e6 100644 --- a/host/Cargo.toml +++ b/host/Cargo.toml @@ -50,6 +50,7 @@ serde = { workspace = true } serde_with = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true } +tokio-util = { workspace = true } env_logger = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/host/src/cache.rs b/host/src/cache.rs new file mode 100644 index 000000000..c61d5959a --- /dev/null +++ b/host/src/cache.rs @@ -0,0 +1,158 @@ +use std::{fs::File, 
+
+use raiko_core::{
+    interfaces::RaikoError,
+    provider::{rpc::RpcBlockDataProvider, BlockDataProvider},
+};
+use raiko_lib::input::{get_input_path, GuestInput};
+use tracing::{debug, info};
+
+use crate::interfaces::{HostError, HostResult};
+
+pub fn get_input(
+    cache_path: &Option<PathBuf>,
+    block_number: u64,
+    network: &str,
+) -> Option<GuestInput> {
+    let dir = cache_path.as_ref()?;
+
+    let path = get_input_path(dir, block_number, network);
+
+    let file = File::open(path).ok()?;
+
+    bincode::deserialize_from(file).ok()
+}
+
+pub fn set_input(
+    cache_path: &Option<PathBuf>,
+    block_number: u64,
+    network: &str,
+    input: &GuestInput,
+) -> HostResult<()> {
+    let Some(dir) = cache_path.as_ref() else {
+        return Ok(());
+    };
+
+    let path = get_input_path(dir, block_number, network);
+    info!("caching input for {path:?}");
+
+    let file = File::create(&path).map_err(<std::io::Error as Into<HostError>>::into)?;
+    bincode::serialize_into(file, input).map_err(|e| HostError::Anyhow(e.into()))
+}
+
+pub async fn validate_input(
+    cached_input: Option<GuestInput>,
+    provider: &RpcBlockDataProvider,
+) -> HostResult<GuestInput> {
+    if let Some(cache_input) = cached_input {
+        debug!("Using cached input");
+        let blocks = provider
+            .get_blocks(&[(cache_input.block.number, false)])
+            .await?;
+        let block = blocks
+            .first()
+            .ok_or_else(|| RaikoError::RPC("No block data for the requested block".to_owned()))?;
+
+        let cached_block_hash = cache_input.block.header.hash_slow();
+        let real_block_hash = block.header.hash.unwrap();
+        debug!(
+            "cache_block_hash={:?}, real_block_hash={:?}",
+            cached_block_hash, real_block_hash
+        );
+
+        // double check if cache is valid
+        if cached_block_hash == real_block_hash {
+            Ok(cache_input)
+        } else {
+            Err(HostError::InvalidRequestConfig(
+                "Cached input is not valid".to_owned(),
+            ))
+        }
+    } else {
+        Err(HostError::InvalidRequestConfig(
+            "Cached input is not enabled".to_owned(),
+        ))
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::cache;
+
+    use alloy_primitives::{Address, B256};
+    use raiko_core::{
+        interfaces::{ProofRequest, ProofType},
+        provider::rpc::RpcBlockDataProvider,
+        Raiko,
+    };
+    use raiko_lib::input::BlobProofType;
+    use raiko_lib::{
+        consts::{Network, SupportedChainSpecs},
+        input::GuestInput,
+    };
+
+    async fn create_cache_input(
+        l1_network: &String,
+        network: &String,
+        block_number: u64,
+    ) -> (GuestInput, RpcBlockDataProvider) {
+        let l1_chain_spec = SupportedChainSpecs::default()
+            .get_chain_spec(l1_network)
+            .unwrap();
+        let taiko_chain_spec = SupportedChainSpecs::default()
+            .get_chain_spec(network)
+            .unwrap();
+        let proof_request = ProofRequest {
+            block_number,
+            network: network.to_string(),
+            l1_network: l1_network.to_string(),
+            graffiti: B256::ZERO,
+            prover: Address::ZERO,
+            proof_type: ProofType::Native,
+            blob_proof_type: BlobProofType::ProofOfCommitment,
+            prover_args: Default::default(),
+        };
+        let raiko = Raiko::new(
+            l1_chain_spec.clone(),
+            taiko_chain_spec.clone(),
+            proof_request.clone(),
+        );
+        let provider = RpcBlockDataProvider::new(
+            &taiko_chain_spec.rpc.clone(),
+            proof_request.block_number - 1,
+        )
+        .expect("provider init ok");
+
+        let input = raiko
+            .generate_input(provider.clone())
+            .await
+            .expect("input generation failed");
+        (input, provider.clone())
+    }
+
+    #[tokio::test]
+    async fn test_generate_input_from_cache() {
+        let l1 = &Network::Holesky.to_string();
+        let l2 = &Network::TaikoA7.to_string();
+        let block_number: u64 = 123456;
+        let (input, provider) = create_cache_input(l1, l2, block_number).await;
+        let cache_path = Some("./".into());
+        assert!(cache::set_input(&cache_path, block_number, l2, &input).is_ok());
+        let cached_input = cache::get_input(&cache_path, block_number, l2).expect("load cache");
+        assert!(cache::validate_input(Some(cached_input), &provider)
+            .await
+            .is_ok());
+
+        let new_l1 = &Network::Ethereum.to_string();
+        let new_l2 = &Network::TaikoMainnet.to_string();
+        let (new_input, _) = create_cache_input(new_l1, new_l2, block_number).await;
+        // save to old l2 cache slot
+        assert!(cache::set_input(&cache_path, block_number, l2, &new_input).is_ok());
+        let inv_cached_input = cache::get_input(&cache_path, block_number, l2).expect("load cache");
+
+        // should fail with old provider
+        assert!(cache::validate_input(Some(inv_cached_input), &provider)
+            .await
+            .is_err());
+    }
+}
diff --git a/host/src/interfaces.rs b/host/src/interfaces.rs
index 0d78cec10..95f385f0b 100644
--- a/host/src/interfaces.rs
+++ b/host/src/interfaces.rs
@@ -129,3 +129,24 @@ impl From<HostError> for TaskStatus {
         }
     }
 }
+
+impl From<&HostError> for TaskStatus {
+    fn from(value: &HostError) -> Self {
+        match value {
+            HostError::HandleDropped
+            | HostError::CapacityFull
+            | HostError::JoinHandle(_)
+            | HostError::InvalidAddress(_)
+            | HostError::InvalidRequestConfig(_) => unreachable!(),
+            HostError::Conversion(_)
+            | HostError::Serde(_)
+            | HostError::Core(_)
+            | HostError::Anyhow(_)
+            | HostError::FeatureNotSupportedError(_)
+            | HostError::Io(_) => TaskStatus::UnspecifiedFailureReason,
+            HostError::RPC(_) => TaskStatus::NetworkFailure,
+            HostError::Guest(_) => TaskStatus::ProofFailure_Generic,
+            HostError::TaskManager(_) => TaskStatus::SqlDbCorruption,
+        }
+    }
+}
diff --git a/host/src/lib.rs b/host/src/lib.rs
index a4efb905e..d46a4849a 100644
--- a/host/src/lib.rs
+++ b/host/src/lib.rs
@@ -8,13 +8,14 @@ use raiko_core::{
     merge,
 };
 use raiko_lib::consts::SupportedChainSpecs;
-use raiko_tasks::TaskManagerOpts;
+use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManagerOpts, TaskManagerWrapper};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use tokio::sync::mpsc;

 use crate::{interfaces::HostResult, proof::ProofActor};

+pub mod cache;
 pub mod interfaces;
 pub mod metrics;
 pub mod proof;
@@ -134,13 +135,29 @@ impl From<&Opts> for TaskManagerOpts {
     }
 }

-pub type TaskChannelOpts = (ProofRequest, Opts, SupportedChainSpecs);
-
 #[derive(Debug, Clone)]
 pub struct ProverState {
     pub opts: Opts,
     pub chain_specs: SupportedChainSpecs,
-    pub task_channel: mpsc::Sender<TaskChannelOpts>,
+    pub task_channel: mpsc::Sender<Message>,
+}
+
+#[derive(Debug, Serialize)]
+pub enum Message {
+    Cancel(TaskDescriptor),
+    Task(ProofRequest),
+}
+
+impl From<&ProofRequest> for Message {
+    fn from(value: &ProofRequest) -> Self {
+        Self::Task(value.clone())
+    }
+}
+
+impl From<&TaskDescriptor> for Message {
+    fn from(value: &TaskDescriptor) -> Self {
+        Self::Cancel(value.clone())
+    }
 }

 impl ProverState {
@@ -163,10 +180,13 @@ impl ProverState {
             }
         }

-        let (task_channel, receiver) = mpsc::channel::<TaskChannelOpts>(opts.concurrency_limit);
+        let (task_channel, receiver) = mpsc::channel::<Message>(opts.concurrency_limit);
+
+        let opts_clone = opts.clone();
+        let chain_specs_clone = chain_specs.clone();

         tokio::spawn(async move {
-            ProofActor::new(receiver, opts.concurrency_limit)
+            ProofActor::new(receiver, opts_clone, chain_specs_clone)
                 .run()
                 .await;
         });
@@ -177,6 +197,14 @@
             task_channel,
         })
     }
+
+    pub fn task_manager(&self) -> TaskManagerWrapper {
+        get_task_manager(&(&self.opts).into())
+    }
+
+    pub fn request_config(&self) -> ProofRequestOpt {
+        self.opts.proof_request_opt.clone()
+    }
 }
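The host now funnels both proving and cancellation through a single `mpsc` channel of `Message` values owned by `ProverState`. A minimal, self-contained sketch of that dispatch pattern (assuming tokio with its full feature set, and using `String` stand-ins for the real `ProofRequest` and `TaskDescriptor` payloads defined elsewhere in the workspace):

use tokio::sync::mpsc;

// Stand-ins for raiko_core::interfaces::ProofRequest and
// raiko_tasks::TaskDescriptor; only the dispatch shape matters here.
#[derive(Debug)]
enum Message {
    Cancel(String),
    Task(String),
}

#[tokio::main]
async fn main() {
    let (sender, mut receiver) = mpsc::channel::<Message>(16);

    // A handler enqueues work, and can later ask for it to be cancelled.
    sender.try_send(Message::Task("proof request".into())).unwrap();
    sender.try_send(Message::Cancel("task descriptor".into())).unwrap();

    // The actor drains the one channel and branches on the message kind.
    while let Ok(message) = receiver.try_recv() {
        match message {
            Message::Task(request) => println!("spawn prover for {request}"),
            Message::Cancel(key) => println!("cancel running task {key}"),
        }
    }
}

Compared with the previous `(ProofRequest, Opts, SupportedChainSpecs)` tuple channel, the enum keeps one ordered queue per actor while letting new message kinds be added without changing the channel type.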

 #[global_allocator]
diff --git a/host/src/proof.rs b/host/src/proof.rs
index f49fcbc89..46ade302d 100644
--- a/host/src/proof.rs
+++ b/host/src/proof.rs
@@ -1,70 +1,152 @@
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};

 use raiko_core::{
     interfaces::{ProofRequest, RaikoError},
     provider::{get_task_data, rpc::RpcBlockDataProvider},
     Raiko,
 };
-use raiko_lib::{consts::SupportedChainSpecs, Measurement};
-use raiko_tasks::{get_task_manager, TaskManager, TaskStatus};
-use tokio::sync::{mpsc::Receiver, Semaphore};
-use tracing::{error, info};
+use raiko_lib::{
+    consts::SupportedChainSpecs,
+    prover::{IdWrite, Proof},
+    Measurement,
+};
+use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrapper, TaskStatus};
+use tokio::{
+    select,
+    sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore},
+};
+use tokio_util::sync::CancellationToken;
+use tracing::{error, info, warn};

 use crate::{
+    cache,
     interfaces::{HostError, HostResult},
     memory,
     metrics::{
         inc_guest_error, inc_guest_success, inc_host_error, observe_guest_time,
         observe_prepare_input_time, observe_total_time,
     },
-    server::api::v1::{
-        proof::{get_cached_input, set_cached_input, validate_cache_input},
-        ProofResponse,
-    },
-    Opts, TaskChannelOpts,
+    Message, Opts,
 };

 pub struct ProofActor {
-    rx: Receiver<TaskChannelOpts>,
-    task_count: usize,
+    opts: Opts,
+    chain_specs: SupportedChainSpecs,
+    tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
+    receiver: Receiver<Message>,
 }

 impl ProofActor {
-    pub fn new(rx: Receiver<TaskChannelOpts>, task_count: usize) -> Self {
-        Self { rx, task_count }
+    pub fn new(receiver: Receiver<Message>, opts: Opts, chain_specs: SupportedChainSpecs) -> Self {
+        let tasks = Arc::new(Mutex::new(
+            HashMap::<TaskDescriptor, CancellationToken>::new(),
+        ));
+
+        Self {
+            tasks,
+            opts,
+            chain_specs,
+            receiver,
+        }
+    }
+
+    pub async fn cancel_task(&mut self, key: TaskDescriptor) -> HostResult<()> {
+        let tasks_map = self.tasks.lock().await;
+        let Some(task) = tasks_map.get(&key) else {
+            warn!("No task with those keys to cancel");
+            return Ok(());
+        };
+
+        let mut manager = get_task_manager(&self.opts.clone().into());
+        key.proof_system
+            .cancel_proof(
+                (key.chain_id, key.blockhash, key.proof_system as u8),
+                Box::new(&mut manager),
+            )
+            .await?;
+        task.cancel();
+        Ok(())
+    }
+
+    pub async fn run_task(&mut self, proof_request: ProofRequest, _permit: OwnedSemaphorePermit) {
+        let cancel_token = CancellationToken::new();
+
+        let Ok((chain_id, blockhash)) = get_task_data(
+            &proof_request.network,
+            proof_request.block_number,
+            &self.chain_specs,
+        )
+        .await
+        else {
+            error!("Could not get task data for {proof_request:?}");
+            return;
+        };
+
+        let key = TaskDescriptor::from((
+            chain_id,
+            blockhash,
+            proof_request.proof_type,
+            proof_request.prover.clone().to_string(),
+        ));
+
+        let mut tasks = self.tasks.lock().await;
+        tasks.insert(key.clone(), cancel_token.clone());
+
+        let tasks = self.tasks.clone();
+        let opts = self.opts.clone();
+        let chain_specs = self.chain_specs.clone();
+
+        tokio::spawn(async move {
+            select!
{ + _ = cancel_token.cancelled() => { + info!("Task cancelled"); + } + result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => { + match result { + Ok(()) => { + info!("Proof generated"); + } + Err(error) => { + error!("Worker failed due to: {error:?}"); + } + }; + } + } + let mut tasks = tasks.lock().await; + tasks.remove(&key); + }); } pub async fn run(&mut self) { - let semaphore = Arc::new(Semaphore::new(self.task_count)); - while let Some(message) = self.rx.recv().await { - let permit = Arc::clone(&semaphore).acquire_owned().await; - tokio::spawn(async move { - let _permit = permit; - if let Err(error) = Self::handle_message(message).await { - error!("Worker failed due to: {error:?}"); + let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit)); + + while let Some(message) = self.receiver.recv().await { + match message { + Message::Cancel(key) => { + if let Err(error) = self.cancel_task(key).await { + error!("Failed to cancel task: {error}") + } } - }); + Message::Task(proof_request) => { + let permit = Arc::clone(&semaphore) + .acquire_owned() + .await + .expect("Couldn't acquire permit"); + self.run_task(proof_request, permit).await; + } + } } } pub async fn handle_message( - (proof_request, opts, chain_specs): TaskChannelOpts, + proof_request: ProofRequest, + key: TaskDescriptor, + opts: &Opts, + chain_specs: &SupportedChainSpecs, ) -> HostResult<()> { - let (chain_id, blockhash) = get_task_data( - &proof_request.network, - proof_request.block_number, - &chain_specs, - ) - .await?; let mut manager = get_task_manager(&opts.clone().into()); - let status = manager - .get_task_proving_status( - chain_id, - blockhash, - proof_request.proof_type, - Some(proof_request.prover.clone().to_string()), - ) - .await?; + + let status = manager.get_task_proving_status(&key).await?; if let Some(latest_status) = status.iter().last() { if !matches!(latest_status.0, TaskStatus::Registered) { @@ -73,47 +155,22 @@ impl ProofActor { } manager - .update_task_progress( - chain_id, - blockhash, - proof_request.proof_type, - Some(proof_request.prover.to_string()), - TaskStatus::WorkInProgress, - None, - ) + .update_task_progress(key.clone(), TaskStatus::WorkInProgress, None) .await?; - match handle_proof(&proof_request, &opts, &chain_specs).await { - Ok(result) => { - let proof_string = result.proof.unwrap_or_default(); - let proof = proof_string.as_bytes(); - - manager - .update_task_progress( - chain_id, - blockhash, - proof_request.proof_type, - Some(proof_request.prover.to_string()), - TaskStatus::Success, - Some(proof), - ) - .await?; - } - Err(error) => { - manager - .update_task_progress( - chain_id, - blockhash, - proof_request.proof_type, - Some(proof_request.prover.to_string()), - error.into(), - None, - ) - .await?; - } - } + let (status, proof) = + match handle_proof(&proof_request, opts, chain_specs, Some(&mut manager)).await { + Err(error) => { + error!("{error}"); + (error.into(), None) + } + Ok(proof) => (TaskStatus::Success, Some(serde_json::to_vec(&proof)?)), + }; - Ok(()) + manager + .update_task_progress(key, status, proof.as_deref()) + .await + .map_err(|e| e.into()) } } @@ -121,14 +178,15 @@ pub async fn handle_proof( proof_request: &ProofRequest, opts: &Opts, chain_specs: &SupportedChainSpecs, -) -> HostResult { + store: Option<&mut TaskManagerWrapper>, +) -> HostResult { info!( "# Generating proof for block {} on {}", proof_request.block_number, proof_request.network ); // Check for a cached input for the given request config. 
- let cached_input = get_cached_input( + let cached_input = cache::get_input( &opts.cache_path, proof_request.block_number, &proof_request.network.to_string(), @@ -154,7 +212,7 @@ pub async fn handle_proof( &taiko_chain_spec.rpc.clone(), proof_request.block_number - 1, )?; - let input = match validate_cache_input(cached_input, &provider).await { + let input = match cache::validate_input(cached_input, &provider).await { Ok(cache_input) => cache_input, Err(_) => { // no valid cache @@ -173,20 +231,23 @@ pub async fn handle_proof( memory::reset_stats(); let measurement = Measurement::start("Generating proof...", false); - let proof = raiko.prove(input.clone(), &output).await.map_err(|e| { - let total_time = total_time.stop_with("====> Proof generation failed"); - observe_total_time(proof_request.block_number, total_time, false); - match e { - RaikoError::Guest(e) => { - inc_guest_error(&proof_request.proof_type, proof_request.block_number); - HostError::Core(e.into()) - } - e => { - inc_host_error(proof_request.block_number); - e.into() + let proof = raiko + .prove(input.clone(), &output, store.map(|s| s as &mut dyn IdWrite)) + .await + .map_err(|e| { + let total_time = total_time.stop_with("====> Proof generation failed"); + observe_total_time(proof_request.block_number, total_time, false); + match e { + RaikoError::Guest(e) => { + inc_guest_error(&proof_request.proof_type, proof_request.block_number); + HostError::Core(e.into()) + } + e => { + inc_host_error(proof_request.block_number); + e.into() + } } - } - })?; + })?; let guest_time = measurement.stop_with("=> Proof generated"); observe_guest_time( &proof_request.proof_type, @@ -201,12 +262,12 @@ pub async fn handle_proof( observe_total_time(proof_request.block_number, total_time, true); // Cache the input for future use. 
- set_cached_input( + cache::set_input( &opts.cache_path, proof_request.block_number, &proof_request.network.to_string(), &input, )?; - ProofResponse::try_from(proof) + Ok(proof) } diff --git a/host/src/server/api/mod.rs b/host/src/server/api/mod.rs index 806698a95..4aa8e0981 100644 --- a/host/src/server/api/mod.rs +++ b/host/src/server/api/mod.rs @@ -58,7 +58,7 @@ pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Rout } pub fn create_docs() -> utoipa::openapi::OpenApi { - v1::create_docs() + v2::create_docs() } async fn check_max_body_size(req: Request, next: Next) -> Response { diff --git a/host/src/server/api/v1/proof.rs b/host/src/server/api/v1/proof.rs index e4a960705..dabdcbc0a 100644 --- a/host/src/server/api/v1/proof.rs +++ b/host/src/server/api/v1/proof.rs @@ -1,204 +1,16 @@ -use std::{fs::File, path::PathBuf}; - use axum::{debug_handler, extract::State, routing::post, Json, Router}; -use raiko_core::{ - interfaces::{ProofRequest, RaikoError}, - provider::{rpc::RpcBlockDataProvider, BlockDataProvider}, - Raiko, -}; -use raiko_lib::{ - input::{get_input_path, GuestInput}, - Measurement, -}; +use raiko_core::interfaces::ProofRequest; +use raiko_lib::prover::Proof; use serde_json::Value; -use tracing::{debug, info}; use utoipa::OpenApi; use crate::{ - interfaces::{HostError, HostResult}, - memory, - metrics::{ - dec_current_req, inc_current_req, inc_guest_error, inc_guest_req_count, inc_guest_success, - inc_host_error, inc_host_req_count, observe_guest_time, observe_prepare_input_time, - observe_total_time, - }, - server::api::v1::ProofResponse, + interfaces::HostResult, + metrics::{dec_current_req, inc_current_req, inc_guest_req_count, inc_host_req_count}, + proof::handle_proof, ProverState, }; -pub fn get_cached_input( - cache_path: &Option, - block_number: u64, - network: &str, -) -> Option { - let dir = cache_path.as_ref()?; - - let path = get_input_path(dir, block_number, network); - - let file = File::open(path).ok()?; - - bincode::deserialize_from(file).ok() -} - -pub fn set_cached_input( - cache_path: &Option, - block_number: u64, - network: &str, - input: &GuestInput, -) -> HostResult<()> { - let Some(dir) = cache_path.as_ref() else { - return Ok(()); - }; - - let path = get_input_path(dir, block_number, network); - info!("caching input for {path:?}"); - - let file = File::create(&path).map_err(>::into)?; - bincode::serialize_into(file, input).map_err(|e| HostError::Anyhow(e.into())) -} - -pub async fn validate_cache_input( - cached_input: Option, - provider: &RpcBlockDataProvider, -) -> HostResult { - if let Some(cache_input) = cached_input { - debug!("Using cached input"); - let blocks = provider - .get_blocks(&[(cache_input.block.number, false)]) - .await?; - let block = blocks - .first() - .ok_or_else(|| RaikoError::RPC("No block data for the requested block".to_owned()))?; - - let cached_block_hash = cache_input.block.header.hash_slow(); - let real_block_hash = block.header.hash.unwrap(); - debug!( - "cache_block_hash={:?}, real_block_hash={:?}", - cached_block_hash, real_block_hash - ); - - // double check if cache is valid - if cached_block_hash == real_block_hash { - Ok(cache_input) - } else { - Err(HostError::InvalidRequestConfig( - "Cached input is not valid".to_owned(), - )) - } - } else { - Err(HostError::InvalidRequestConfig( - "Cached input is not enabled".to_owned(), - )) - } -} - -pub async fn handle_proof( - ProverState { - opts, - chain_specs: support_chain_specs, - .. 
- }: ProverState, - req: Value, -) -> HostResult { - // Override the existing proof request config from the config file and command line - // options with the request from the client. - let mut config = opts.proof_request_opt.clone(); - config.merge(&req)?; - - // Construct the actual proof request from the available configs. - let proof_request = ProofRequest::try_from(config)?; - inc_host_req_count(proof_request.block_number); - inc_guest_req_count(&proof_request.proof_type, proof_request.block_number); - - info!( - "# Generating proof for block {} on {}", - proof_request.block_number, proof_request.network - ); - - // Check for a cached input for the given request config. - let cached_input = get_cached_input( - &opts.cache_path, - proof_request.block_number, - &proof_request.network.to_string(), - ); - - let l1_chain_spec = support_chain_specs - .get_chain_spec(&proof_request.l1_network.to_string()) - .ok_or_else(|| HostError::InvalidRequestConfig("Unsupported l1 network".to_string()))?; - - let taiko_chain_spec = support_chain_specs - .get_chain_spec(&proof_request.network.to_string()) - .ok_or_else(|| HostError::InvalidRequestConfig("Unsupported raiko network".to_string()))?; - - // Execute the proof generation. - let total_time = Measurement::start("", false); - - let raiko = Raiko::new( - l1_chain_spec.clone(), - taiko_chain_spec.clone(), - proof_request.clone(), - ); - let provider = RpcBlockDataProvider::new( - &taiko_chain_spec.rpc.clone(), - proof_request.block_number - 1, - )?; - let input = match validate_cache_input(cached_input, &provider).await { - Ok(cache_input) => cache_input, - Err(_) => { - // no valid cache - memory::reset_stats(); - let measurement = Measurement::start("Generating input...", false); - let input = raiko.generate_input(provider).await?; - let input_time = measurement.stop_with("=> Input generated"); - observe_prepare_input_time(proof_request.block_number, input_time, true); - memory::print_stats("Input generation peak memory used: "); - input - } - }; - memory::reset_stats(); - let output = raiko.get_output(&input)?; - memory::print_stats("Guest program peak memory used: "); - - memory::reset_stats(); - let measurement = Measurement::start("Generating proof...", false); - let proof = raiko.prove(input.clone(), &output).await.map_err(|e| { - let total_time = total_time.stop_with("====> Proof generation failed"); - observe_total_time(proof_request.block_number, total_time, false); - match e { - RaikoError::Guest(e) => { - inc_guest_error(&proof_request.proof_type, proof_request.block_number); - HostError::Core(e.into()) - } - e => { - inc_host_error(proof_request.block_number); - e.into() - } - } - })?; - let guest_time = measurement.stop_with("=> Proof generated"); - observe_guest_time( - &proof_request.proof_type, - proof_request.block_number, - guest_time, - true, - ); - memory::print_stats("Prover peak memory used: "); - - inc_guest_success(&proof_request.proof_type, proof_request.block_number); - let total_time = total_time.stop_with("====> Complete proof generated"); - observe_total_time(proof_request.block_number, total_time, true); - - // Cache the input for future use. 
- set_cached_input( - &opts.cache_path, - proof_request.block_number, - &proof_request.network.to_string(), - &input, - )?; - - ProofResponse::try_from(proof) -} - #[utoipa::path(post, path = "/proof", tag = "Proving", request_body = ProofRequestOpt, @@ -218,12 +30,30 @@ pub async fn handle_proof( async fn proof_handler( State(prover_state): State, Json(req): Json, -) -> HostResult { +) -> HostResult> { inc_current_req(); - handle_proof(prover_state, req).await.map_err(|e| { + // Override the existing proof request config from the config file and command line + // options with the request from the client. + let mut config = prover_state.request_config(); + config.merge(&req)?; + + // Construct the actual proof request from the available configs. + let proof_request = ProofRequest::try_from(config)?; + inc_host_req_count(proof_request.block_number); + inc_guest_req_count(&proof_request.proof_type, proof_request.block_number); + + handle_proof( + &proof_request, + &prover_state.opts, + &prover_state.chain_specs, + None, + ) + .await + .map_err(|e| { dec_current_req(); e }) + .map(Json) } #[derive(OpenApi)] @@ -237,77 +67,3 @@ pub fn create_docs() -> utoipa::openapi::OpenApi { pub fn create_router() -> Router { Router::new().route("/", post(proof_handler)) } - -#[cfg(test)] -mod test { - use super::*; - use alloy_primitives::{Address, B256}; - use raiko_core::interfaces::ProofType; - use raiko_lib::consts::{Network, SupportedChainSpecs}; - use raiko_lib::input::BlobProofType; - - async fn create_cache_input( - l1_network: &String, - network: &String, - block_number: u64, - ) -> (GuestInput, RpcBlockDataProvider) { - let l1_chain_spec = SupportedChainSpecs::default() - .get_chain_spec(l1_network) - .unwrap(); - let taiko_chain_spec = SupportedChainSpecs::default() - .get_chain_spec(network) - .unwrap(); - let proof_request = ProofRequest { - block_number, - network: network.to_string(), - l1_network: l1_network.to_string(), - graffiti: B256::ZERO, - prover: Address::ZERO, - proof_type: ProofType::Native, - blob_proof_type: BlobProofType::ProofOfCommitment, - prover_args: Default::default(), - }; - let raiko = Raiko::new( - l1_chain_spec.clone(), - taiko_chain_spec.clone(), - proof_request.clone(), - ); - let provider = RpcBlockDataProvider::new( - &taiko_chain_spec.rpc.clone(), - proof_request.block_number - 1, - ) - .expect("provider init ok"); - - let input = raiko - .generate_input(provider.clone()) - .await - .expect("input generation failed"); - (input, provider.clone()) - } - - #[tokio::test] - async fn test_generate_input_from_cache() { - let l1 = &Network::Holesky.to_string(); - let l2 = &Network::TaikoA7.to_string(); - let block_number: u64 = 123456; - let (input, provider) = create_cache_input(l1, l2, block_number).await; - let cache_path = Some("./".into()); - assert!(set_cached_input(&cache_path, block_number, l2, &input).is_ok()); - let cached_input = get_cached_input(&cache_path, block_number, l2).expect("load cache"); - assert!(validate_cache_input(Some(cached_input), &provider) - .await - .is_ok()); - - let new_l1 = &Network::Ethereum.to_string(); - let new_l2 = &Network::TaikoMainnet.to_string(); - let (new_input, _) = create_cache_input(new_l1, new_l2, block_number).await; - // save to old l2 cache slot - assert!(set_cached_input(&cache_path, block_number, l2, &new_input).is_ok()); - let inv_cached_input = get_cached_input(&cache_path, block_number, l2).expect("load cache"); - - // should fail with old provider - assert!(validate_cache_input(Some(inv_cached_input), &provider) 
-            .await
-            .is_err());
-    }
-}
diff --git a/host/src/server/api/v2/mod.rs b/host/src/server/api/v2/mod.rs
index 29091dbea..fe8003bcc 100644
--- a/host/src/server/api/v2/mod.rs
+++ b/host/src/server/api/v2/mod.rs
@@ -1,6 +1,7 @@
 use axum::{response::IntoResponse, Json, Router};
+use raiko_lib::prover::Proof;
 use raiko_tasks::TaskStatus;
-use serde::{Serialize, Serializer};
+use serde::Serialize;
 use utoipa::{OpenApi, ToSchema};
 use utoipa_scalar::{Scalar, Servable};
 use utoipa_swagger_ui::SwaggerUi;
@@ -58,24 +59,11 @@ pub enum ProofResponse {
         status: TaskStatus,
     },
     Proof {
-        #[serde(serialize_with = "ProofResponse::serialize_proof")]
         /// The proof.
-        proof: Option<Vec<u8>>,
+        proof: Proof,
     },
 }

-impl ProofResponse {
-    fn serialize_proof<S>(proof: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        match proof {
-            Some(value) => serializer.serialize_str(&String::from_utf8(value.clone()).unwrap()),
-            None => serializer.serialize_str(""),
-        }
-    }
-}
-
 #[derive(Debug, Serialize, ToSchema)]
 #[serde(tag = "status", rename_all = "lowercase")]
 pub enum Status {
@@ -86,7 +74,9 @@ pub enum Status {
 impl From<Vec<u8>> for Status {
     fn from(proof: Vec<u8>) -> Self {
         Self::Ok {
-            data: ProofResponse::Proof { proof: Some(proof) },
+            data: ProofResponse::Proof {
+                proof: serde_json::from_slice(&proof).unwrap_or_default(),
+            },
         }
     }
 }
diff --git a/host/src/server/api/v2/proof/cancel.rs b/host/src/server/api/v2/proof/cancel.rs
index 62c892ac9..0f6b1bb61 100644
--- a/host/src/server/api/v2/proof/cancel.rs
+++ b/host/src/server/api/v2/proof/cancel.rs
@@ -1,10 +1,10 @@
 use axum::{debug_handler, extract::State, routing::post, Json, Router};
 use raiko_core::{interfaces::ProofRequest, provider::get_task_data};
-use raiko_tasks::{get_task_manager, TaskManager, TaskStatus};
+use raiko_tasks::{TaskDescriptor, TaskManager, TaskStatus};
 use serde_json::Value;
 use utoipa::OpenApi;

-use crate::{interfaces::HostResult, server::api::v2::CancelStatus, ProverState};
+use crate::{interfaces::HostResult, server::api::v2::CancelStatus, Message, ProverState};

 #[utoipa::path(post, path = "/proof/cancel",
     tag = "Proving",
@@ -28,7 +28,7 @@ async fn cancel_handler(
 ) -> HostResult<CancelStatus> {
     // Override the existing proof request config from the config file and command line
     // options with the request from the client.
-    let mut config = prover_state.opts.proof_request_opt.clone();
+    let mut config = prover_state.request_config();
     config.merge(&req)?;

     // Construct the actual proof request from the available configs.
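Behind this endpoint, cancellation is cooperative: the `ProofActor` races each proving future against a `CancellationToken` (from the `tokio-util` dependency this patch adds) inside `select!`. A self-contained sketch of that pattern, with a `sleep` standing in for the long-running proving future:

use tokio::{select, time::{sleep, Duration}};
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let child = token.clone();

    let task = tokio::spawn(async move {
        select! {
            // The branch that Message::Cancel ultimately triggers.
            _ = child.cancelled() => println!("task cancelled"),
            // Stand-in for the proving work.
            _ = sleep(Duration::from_secs(60)) => println!("proof generated"),
        }
    });

    token.cancel();          // what cancel_task() does for a running task
    task.await.unwrap();     // prints "task cancelled" almost immediately
}

Because `cancelled()` resolves immediately once the token is cancelled, the spawned task unwinds promptly and can still run its cleanup (here, removing the key from the actor's task map) after the `select!` returns.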
@@ -41,17 +41,19 @@ async fn cancel_handler( ) .await?; - let mut manager = get_task_manager(&(&prover_state.opts).into()); + let key = TaskDescriptor::from(( + chain_id, + block_hash, + proof_request.proof_type, + proof_request.prover.clone().to_string(), + )); + + prover_state.task_channel.try_send(Message::from(&key))?; + + let mut manager = prover_state.task_manager(); manager - .update_task_progress( - chain_id, - block_hash, - proof_request.proof_type, - Some(proof_request.prover.to_string()), - TaskStatus::Cancelled, - None, - ) + .update_task_progress(key, TaskStatus::Cancelled, None) .await?; Ok(CancelStatus::Ok) diff --git a/host/src/server/api/v2/proof/mod.rs b/host/src/server/api/v2/proof/mod.rs index 6e66e5506..4ecd9ce8b 100644 --- a/host/src/server/api/v2/proof/mod.rs +++ b/host/src/server/api/v2/proof/mod.rs @@ -1,8 +1,6 @@ use axum::{debug_handler, extract::State, routing::post, Json, Router}; use raiko_core::{interfaces::ProofRequest, provider::get_task_data}; -use raiko_tasks::{ - get_task_manager, EnqueueTaskParams, TaskManager, TaskProvingStatus, TaskStatus, -}; +use raiko_tasks::{TaskDescriptor, TaskManager, TaskStatus}; use serde_json::Value; use utoipa::OpenApi; @@ -10,7 +8,7 @@ use crate::{ interfaces::HostResult, metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count}, server::api::v2::Status, - ProverState, + Message, ProverState, }; mod cancel; @@ -40,7 +38,7 @@ async fn proof_handler( inc_current_req(); // Override the existing proof request config from the config file and command line // options with the request from the client. - let mut config = prover_state.opts.proof_request_opt.clone(); + let mut config = prover_state.request_config(); config.merge(&req)?; // Construct the actual proof request from the available configs. @@ -48,40 +46,30 @@ async fn proof_handler( inc_host_req_count(proof_request.block_number); inc_guest_req_count(&proof_request.proof_type, proof_request.block_number); - let (chain_id, block_hash) = get_task_data( + let (chain_id, blockhash) = get_task_data( &proof_request.network, proof_request.block_number, &prover_state.chain_specs, ) .await?; - let mut manager = get_task_manager(&(&prover_state.opts).into()); - let status = manager - .get_task_proving_status( - chain_id, - block_hash, - proof_request.proof_type, - Some(proof_request.prover.to_string()), - ) - .await?; - - let Some(TaskProvingStatus(latest_status, ..)) = status.last() else { + let key = TaskDescriptor::from(( + chain_id, + blockhash, + proof_request.proof_type, + proof_request.prover.to_string(), + )); + + let mut manager = prover_state.task_manager(); + let status = manager.get_task_proving_status(&key).await?; + + let Some((latest_status, ..)) = status.last() else { // If there are no tasks with provided config, create a new one. 
- manager - .enqueue_task(&EnqueueTaskParams { - chain_id, - blockhash: block_hash, - proof_type: proof_request.proof_type, - prover: proof_request.prover.to_string(), - block_number: proof_request.block_number, - }) - .await?; - - prover_state.task_channel.try_send(( - proof_request.clone(), - prover_state.opts, - prover_state.chain_specs, - ))?; + manager.enqueue_task(&key).await?; + + prover_state + .task_channel + .try_send(Message::from(&proof_request))?; return Ok(TaskStatus::Registered.into()); }; @@ -93,33 +81,18 @@ async fn proof_handler( | TaskStatus::Cancelled_NeverStarted | TaskStatus::CancellationInProgress => { manager - .enqueue_task(&EnqueueTaskParams { - chain_id, - blockhash: block_hash, - proof_type: proof_request.proof_type, - prover: proof_request.prover.to_string(), - block_number: proof_request.block_number, - }) + .update_task_progress(key, TaskStatus::Registered, None) .await?; - prover_state.task_channel.try_send(( - proof_request.clone(), - prover_state.opts, - prover_state.chain_specs, - ))?; + prover_state + .task_channel + .try_send(Message::from(&proof_request))?; Ok(TaskStatus::Registered.into()) } // If the task has succeeded, return the proof. TaskStatus::Success => { - let proof = manager - .get_task_proof( - chain_id, - block_hash, - proof_request.proof_type, - Some(proof_request.prover.to_string()), - ) - .await?; + let proof = manager.get_task_proof(&key).await?; Ok(proof.into()) } diff --git a/host/src/server/api/v2/proof/prune.rs b/host/src/server/api/v2/proof/prune.rs index 9e14b3f1a..166fa4413 100644 --- a/host/src/server/api/v2/proof/prune.rs +++ b/host/src/server/api/v2/proof/prune.rs @@ -1,5 +1,5 @@ use axum::{debug_handler, extract::State, routing::post, Router}; -use raiko_tasks::{get_task_manager, TaskManager}; +use raiko_tasks::TaskManager; use utoipa::OpenApi; use crate::{interfaces::HostResult, server::api::v2::PruneStatus, ProverState}; @@ -13,7 +13,7 @@ use crate::{interfaces::HostResult, server::api::v2::PruneStatus, ProverState}; #[debug_handler(state = ProverState)] /// Prune all tasks. async fn prune_handler(State(prover_state): State) -> HostResult { - let mut manager = get_task_manager(&(&prover_state.opts).into()); + let mut manager = prover_state.task_manager(); manager.prune_db().await?; diff --git a/host/src/server/api/v2/proof/report.rs b/host/src/server/api/v2/proof/report.rs index c5a3cb095..92388268f 100644 --- a/host/src/server/api/v2/proof/report.rs +++ b/host/src/server/api/v2/proof/report.rs @@ -1,5 +1,5 @@ -use axum::{debug_handler, extract::State, routing::post, Json, Router}; -use raiko_tasks::{get_task_manager, TaskManager}; +use axum::{debug_handler, extract::State, routing::get, Json, Router}; +use raiko_tasks::TaskManager; use serde_json::Value; use utoipa::OpenApi; @@ -16,7 +16,7 @@ use crate::{interfaces::HostResult, ProverState}; /// /// Retrieve a list of `{ chain_id, blockhash, prover_type, prover, status }` items. 
 async fn report_handler(State(prover_state): State<ProverState>) -> HostResult<Json<Value>> {
-    let mut manager = get_task_manager(&(&prover_state.opts).into());
+    let mut manager = prover_state.task_manager();

     let task_report = manager.list_all_tasks().await?;

@@ -32,5 +32,5 @@ pub fn create_docs() -> utoipa::openapi::OpenApi {
 }

 pub fn create_router() -> Router<ProverState> {
-    Router::new().route("/", post(report_handler))
+    Router::new().route("/", get(report_handler))
 }
diff --git a/lib/Cargo.toml b/lib/Cargo.toml
index 0190ed6b0..75cb0d04f 100644
--- a/lib/Cargo.toml
+++ b/lib/Cargo.toml
@@ -5,10 +5,10 @@ edition = "2021"

 [dependencies]

 # reth
-reth-primitives.workspace = true
-reth-evm-ethereum.workspace = true
-reth-evm.workspace = true
-reth-chainspec.workspace = true
+reth-primitives = { workspace = true }
+reth-evm-ethereum = { workspace = true }
+reth-evm = { workspace = true }
+reth-chainspec = { workspace = true }

 # alloy
 alloy-rlp = { workspace = true }
@@ -39,6 +39,9 @@ sha2 = { workspace = true }
 sha3 = { workspace = true }
 rlp = { workspace = true, features = ["std"] }

+# docs
+utoipa = { workspace = true }
+
 # misc
 cfg-if = { workspace = true }
 tracing = { workspace = true }
@@ -72,4 +75,4 @@ sgx = []
 sp1 = []
 risc0 = []
 sp1-cycle-tracker = []
-proof_of_equivalence = []
\ No newline at end of file
+proof_of_equivalence = []
diff --git a/lib/src/input.rs b/lib/src/input.rs
index 51e034677..09ac13bf2 100644
--- a/lib/src/input.rs
+++ b/lib/src/input.rs
@@ -15,7 +15,7 @@ use reth_primitives::{Block, Header};
 #[cfg(not(feature = "std"))]
 use crate::no_std::*;
-use crate::{consts::ChainSpec, primitives::mpt::MptNode};
+use crate::{consts::ChainSpec, primitives::mpt::MptNode, utils::zlib_compress_data};

 /// Represents the state of an account's storage.
 /// The storage trie together with the used storage slots allow us to reconstruct all the
@@ -44,6 +44,20 @@ pub struct GuestInput {
     pub taiko: TaikoGuestInput,
 }

+impl From<(Block, Header, ChainSpec, TaikoGuestInput)> for GuestInput {
+    fn from(
+        (block, parent_header, chain_spec, taiko): (Block, Header, ChainSpec, TaikoGuestInput),
+    ) -> Self {
+        Self {
+            block,
+            chain_spec,
+            taiko,
+            parent_header,
+            ..Self::default()
+        }
+    }
+}
+
 #[serde_as]
 #[derive(Clone, Debug, Default, Serialize, Deserialize)]
 pub struct TaikoGuestInput {
@@ -57,6 +71,21 @@ pub struct TaikoGuestInput {
     pub blob_proof_type: BlobProofType,
 }

+pub struct ZlibCompressError(pub String);
+
+impl TryFrom<Vec<TransactionSigned>> for TaikoGuestInput {
+    type Error = ZlibCompressError;
+
+    fn try_from(value: Vec<TransactionSigned>) -> Result<Self, Self::Error> {
+        let tx_data = zlib_compress_data(&alloy_rlp::encode(&value))
+            .map_err(|e| ZlibCompressError(e.to_string()))?;
+        Ok(Self {
+            tx_data,
+            ..Self::default()
+        })
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize, Default)]
 pub enum BlobProofType {
     /// Guest runs through the entire computation from blob to Kzg commitment
diff --git a/lib/src/primitives/eip4844.rs b/lib/src/primitives/eip4844.rs
index 9dd31ae9d..d62851da1 100644
--- a/lib/src/primitives/eip4844.rs
+++ b/lib/src/primitives/eip4844.rs
@@ -117,13 +117,13 @@ mod test {
         // The input is encoded as follows:
         // | versioned_hash |  z  |  y  | commitment | proof |
         // |     32         | 32  | 32  |     48     |  48   |
-        let version_hash = commitment_to_version_hash(&commitment);
+        let version_hash = commitment_to_version_hash(commitment);
         let mut input = [0u8; 192];
         input[..32].copy_from_slice(&(*version_hash));
         input[32..64].copy_from_slice(&z.to_bytes());
         input[64..96].copy_from_slice(&y.to_bytes());
         input[96..144].copy_from_slice(commitment);
input[144..192].copy_from_slice(&kzg_proof_to_bytes(&proof)); + input[144..192].copy_from_slice(&kzg_proof_to_bytes(proof)); Ok(reth_primitives::revm_precompile::kzg_point_evaluation::run( &Bytes::copy_from_slice(&input), diff --git a/lib/src/prover.rs b/lib/src/prover.rs index e5d93343d..e43b511ce 100644 --- a/lib/src/prover.rs +++ b/lib/src/prover.rs @@ -1,9 +1,10 @@ -use serde::Serialize; -use thiserror::Error as ThisError; +use reth_primitives::{ChainId, B256}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; use crate::input::{GuestInput, GuestOutput}; -#[derive(ThisError, Debug)] +#[derive(thiserror::Error, Debug)] pub enum ProverError { #[error("ProverError::GuestError `{0}`")] GuestError(String), @@ -11,6 +12,8 @@ pub enum ProverError { FileIo(#[from] std::io::Error), #[error("ProverError::Param `{0}`")] Param(#[from] serde_json::Error), + #[error("Store error `{0}`")] + StoreError(String), } impl From for ProverError { @@ -21,7 +24,28 @@ impl From for ProverError { pub type ProverResult = core::result::Result; pub type ProverConfig = serde_json::Value; -pub type Proof = serde_json::Value; +pub type ProofKey = (ChainId, B256, u8); + +#[derive(Debug, Serialize, ToSchema, Deserialize, Default)] +/// The response body of a proof request. +pub struct Proof { + /// The ZK proof. + pub proof: Option, + /// The TEE quote. + pub quote: Option, + /// The kzg proof. + pub kzg_proof: Option, +} + +pub trait IdWrite: Send { + fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>; + + fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>; +} + +pub trait IdStore: IdWrite { + fn read_id(&self, key: ProofKey) -> ProverResult; +} #[allow(async_fn_in_trait)] pub trait Prover { @@ -29,11 +53,8 @@ pub trait Prover { input: GuestInput, output: &GuestOutput, config: &ProverConfig, + store: Option<&mut dyn IdWrite>, ) -> ProverResult; -} -pub fn to_proof(proof: ProverResult) -> ProverResult { - proof.and_then(|res| { - serde_json::to_value(res).map_err(|err| ProverError::GuestError(err.to_string())) - }) + async fn cancel(proof_key: ProofKey, read: Box<&mut dyn IdStore>) -> ProverResult<()>; } diff --git a/provers/risc0/driver/src/bonsai.rs b/provers/risc0/driver/src/bonsai.rs index c2f4f08a2..7be4ea5d4 100644 --- a/provers/risc0/driver/src/bonsai.rs +++ b/provers/risc0/driver/src/bonsai.rs @@ -1,5 +1,8 @@ use log::{debug, error, info, warn}; -use raiko_lib::primitives::keccak::keccak; +use raiko_lib::{ + primitives::keccak::keccak, + prover::{IdWrite, ProofKey}, +}; use risc0_zkvm::{ compute_image_id, is_dev_mode, serde::to_vec, sha::Digest, Assumption, ExecutorEnv, ExecutorImpl, Receipt, @@ -89,6 +92,8 @@ pub async fn maybe_prove, Vec), + proof_key: ProofKey, + id_store: &mut Option<&mut dyn IdWrite>, ) -> Option<(String, Receipt)> { let (assumption_instances, assumption_uuids) = assumptions; @@ -115,6 +120,8 @@ pub async fn maybe_prove anyhow::Result { Ok(client.upload_receipt(bincode::serialize(receipt)?)?) } +pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> { + let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?; + let session = bonsai_sdk::alpha::SessionId { uuid }; + session.stop(&client)?; + Ok(()) +} + pub async fn prove_bonsai( encoded_input: Vec, elf: &[u8], expected_output: &O, assumption_uuids: Vec, + proof_key: ProofKey, + id_store: &mut Option<&mut dyn IdWrite>, ) -> anyhow::Result<(String, Receipt)> { info!("Proving on Bonsai"); // Compute the image_id, then upload the ELF with the image_id as its key. 
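The `IdWrite`/`IdStore` split is the core of the cancellation design: a prover records the remote session id under its `ProofKey` at submission time, and `cancel` later resolves that id, aborts the remote job, and removes the mapping. A self-contained test double under the trait shapes introduced above (`B256`, `ProverResult`, and `abort` semantics are simplified stand-ins):

```rust
use std::collections::HashMap;

// Stand-ins for the raiko_lib types.
type ChainId = u64;
type B256 = [u8; 32];
type ProofKey = (ChainId, B256, u8); // (chain_id, block_hash, prover code)
type ProverResult<T> = Result<T, String>;

trait IdWrite: Send {
    fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>;
    fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>;
}

trait IdStore: IdWrite {
    fn read_id(&self, key: ProofKey) -> ProverResult<String>;
}

#[derive(Default)]
struct MemStore(HashMap<ProofKey, String>);

impl IdWrite for MemStore {
    fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
        self.0.insert(key, id);
        Ok(())
    }
    fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
        self.0.remove(&key);
        Ok(())
    }
}

impl IdStore for MemStore {
    fn read_id(&self, key: ProofKey) -> ProverResult<String> {
        self.0.get(&key).cloned().ok_or_else(|| "no id".to_owned())
    }
}

// The cancellation sequence every backend follows: resolve, abort, forget.
// The println stands in for Bonsai's session stop or SP1's unclaim call.
fn cancel(key: ProofKey, store: &mut MemStore) -> ProverResult<()> {
    let id = store.read_id(key)?;
    println!("aborting remote job {id}");
    store.remove_id(key)
}

fn main() -> ProverResult<()> {
    let mut store = MemStore::default();
    let key = (1, [0u8; 32], 3); // chain 1, zero hash, RISC0 code
    store.store_id(key, "session-123".to_owned())?;
    cancel(key, &mut store)
}
```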
@@ -202,6 +218,10 @@ pub async fn prove_bonsai( assumption_uuids.clone(), )?; + if let Some(id_store) = id_store { + id_store.store_id(proof_key, session.uuid.clone())?; + } + verify_bonsai_receipt(image_id, expected_output, session.uuid.clone(), 8).await } diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs index c50145a1e..f9237fe77 100644 --- a/provers/risc0/driver/src/lib.rs +++ b/provers/risc0/driver/src/lib.rs @@ -4,7 +4,7 @@ use alloy_primitives::B256; use hex::ToHex; use raiko_lib::{ input::{GuestInput, GuestOutput}, - prover::{to_proof, Proof, Prover, ProverConfig, ProverError, ProverResult}, + prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; use risc0_zkvm::{serde::to_vec, sha::Digest}; use serde::{Deserialize, Serialize}; @@ -36,15 +36,35 @@ pub struct Risc0Param { pub struct Risc0Response { pub proof: String, } + +impl From for Proof { + fn from(value: Risc0Response) -> Self { + Self { + proof: Some(value.proof), + quote: None, + kzg_proof: None, + } + } +} + pub struct Risc0Prover; +const RISC0_PROVER_CODE: u8 = 3; + impl Prover for Risc0Prover { async fn run( input: GuestInput, output: &GuestOutput, config: &ProverConfig, + id_store: Option<&mut dyn IdWrite>, ) -> ProverResult { + let mut id_store = id_store; let config = Risc0Param::deserialize(config.get("risc0").unwrap()).unwrap(); + let proof_key = ( + input.chain_spec.chain_id, + output.hash.clone(), + RISC0_PROVER_CODE, + ); debug!("elf code length: {}", RISC0_GUEST_ELF.len()); let encoded_input = to_vec(&input).expect("Could not serialize proving input!"); @@ -55,6 +75,8 @@ impl Prover for Risc0Prover { RISC0_GUEST_ELF, &output.hash, Default::default(), + proof_key, + &mut id_store, ) .await; @@ -80,7 +102,16 @@ impl Prover for Risc0Prover { .map_err(|err| format!("Failed to verify SNARK: {err:?}"))?; } - to_proof(Ok(Risc0Response { proof: journal })) + Ok(Risc0Response { proof: journal }.into()) + } + + async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> { + let uuid = id_store.read_id(key)?; + cancel_proof(uuid) + .await + .map_err(|e| ProverError::GuestError(e.to_string()))?; + id_store.remove_id(key)?; + Ok(()) } } diff --git a/provers/sgx/prover/src/lib.rs b/provers/sgx/prover/src/lib.rs index f4582cf28..7b82f0e2c 100644 --- a/provers/sgx/prover/src/lib.rs +++ b/provers/sgx/prover/src/lib.rs @@ -11,7 +11,7 @@ use std::{ use once_cell::sync::Lazy; use raiko_lib::{ input::{GuestInput, GuestOutput}, - prover::{to_proof, Proof, Prover, ProverConfig, ProverError, ProverResult}, + prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -44,6 +44,16 @@ pub struct SgxResponse { pub quote: String, } +impl From for Proof { + fn from(value: SgxResponse) -> Self { + Self { + proof: Some(value.proof), + quote: Some(value.quote), + kzg_proof: None, + } + } +} + pub const ELF_NAME: &str = "sgx-guest"; pub const CONFIG: &str = if cfg!(feature = "docker_build") { "../provers/sgx/config" @@ -61,6 +71,7 @@ impl Prover for SgxProver { input: GuestInput, _output: &GuestOutput, config: &ProverConfig, + _store: Option<&mut dyn IdWrite>, ) -> ProverResult { let sgx_param = SgxParam::deserialize(config.get("sgx").unwrap()).unwrap(); @@ -134,7 +145,11 @@ impl Prover for SgxProver { sgx_proof = prove(gramine_cmd(), input.clone(), sgx_param.instance_id).await } - to_proof(sgx_proof) + sgx_proof.map(|r| r.into()) + } + + async fn 
cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> { + Ok(()) } } diff --git a/provers/sp1/driver/src/benchmark.rs b/provers/sp1/driver/src/benchmark.rs index 7c0434f81..61b950e16 100644 --- a/provers/sp1/driver/src/benchmark.rs +++ b/provers/sp1/driver/src/benchmark.rs @@ -20,33 +20,37 @@ fn prove(elf: &[u8]) { } #[bench] -fn bench_sha256(b: &mut Bencher) { - run_once(|b| { +fn bench_sha256(_: &mut Bencher) { + run_once(|_| { prove(SHA256_ELF); Ok(()) - }); + }) + .unwrap(); } #[bench] -fn bench_ecdsa(b: &mut Bencher) { - run_once(|b| { +fn bench_ecdsa(_: &mut Bencher) { + run_once(|_| { prove(ECDSA_ELF); Ok(()) - }); + }) + .unwrap(); } #[bench] -fn bench_bn254_add(b: &mut Bencher) { - run_once(|b| { +fn bench_bn254_add(_: &mut Bencher) { + run_once(|_| { prove(BN254_ADD_ELF); Ok(()) - }); + }) + .unwrap(); } #[bench] -fn bench_bn254_mul(b: &mut Bencher) { - run_once(|b| { +fn bench_bn254_mul(_: &mut Bencher) { + run_once(|_| { prove(BN254_MUL_ELF); Ok(()) - }); + }) + .unwrap(); } diff --git a/provers/sp1/driver/src/lib.rs b/provers/sp1/driver/src/lib.rs index 1b46adbbe..42937bc42 100644 --- a/provers/sp1/driver/src/lib.rs +++ b/provers/sp1/driver/src/lib.rs @@ -1,11 +1,15 @@ #![cfg(feature = "enable")] use raiko_lib::{ input::{GuestInput, GuestOutput}, - prover::{to_proof, Proof, Prover, ProverConfig, ProverError, ProverResult}, + prover::{IdStore, IdWrite, Proof, ProofKey, Prover, ProverConfig, ProverError, ProverResult}, }; use serde::{Deserialize, Serialize}; -use sp1_sdk::{ProverClient, SP1Stdin}; -use std::env; +use sp1_sdk::{ + network::client::NetworkClient, + proto::network::{ProofMode, ProofStatus, UnclaimReason}, + ProverClient, SP1Stdin, +}; +use std::{env, thread::sleep, time::Duration}; use tracing::info as tracing_info; const ELF: &[u8] = include_bytes!("../../guest/elf/sp1-guest"); @@ -15,13 +19,26 @@ pub struct Sp1Response { pub proof: String, } +impl From for Proof { + fn from(value: Sp1Response) -> Self { + Self { + proof: Some(value.proof), + quote: None, + kzg_proof: None, + } + } +} + pub struct Sp1Prover; +const SP1_PROVER_CODE: u8 = 1; + impl Prover for Sp1Prover { async fn run( input: GuestInput, - _output: &GuestOutput, + output: &GuestOutput, _config: &ProverConfig, + id_store: Option<&mut dyn IdWrite>, ) -> ProverResult { // Write the input. let mut stdin = SP1Stdin::new(); @@ -30,9 +47,68 @@ impl Prover for Sp1Prover { // Generate the proof for the given program. 
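The third element of `ProofKey` is a per-backend discriminator, so ids from different provers cannot collide under the same `(chain_id, block_hash)`: this patch declares `SP1_PROVER_CODE = 1` and `RISC0_PROVER_CODE = 3`, while SGX stores no ids and keeps a no-op `cancel`. A hypothetical lookup helper, purely for illustration:

```rust
// Hypothetical helper, not part of the patch; the codes themselves are
// the constants declared in the sp1 and risc0 drivers.
fn prover_name(code: u8) -> Option<&'static str> {
    match code {
        1 => Some("sp1"),
        3 => Some("risc0"),
        _ => None, // other backends (e.g. SGX) don't register ids here
    }
}

fn main() {
    assert_eq!(prover_name(3), Some("risc0"));
    assert_eq!(prover_name(2), None);
}
```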
let client = ProverClient::new(); let (pk, vk) = client.setup(ELF); - let proof = client - .prove(&pk, stdin) - .map_err(|_| ProverError::GuestError("Sp1: proving failed".to_owned()))?; + let local = true; + let proof = match local { + true => { + let proof = client + .prove(&pk, stdin) + .map_err(|_| ProverError::GuestError("Sp1: proving failed".to_owned()))?; + Ok::<_, ProverError>(proof) + } + false => { + let private_key = env::var("SP1_PRIVATE_KEY").map_err(|_| { + ProverError::GuestError( + "SP1_PRIVATE_KEY must be set for remote proving".to_owned(), + ) + })?; + let network_client = NetworkClient::new(&private_key); + let proof_id = network_client + .create_proof(&pk.elf, &stdin, ProofMode::Core, "v1.0.8-testnet") + .await + .map_err(|_| { + ProverError::GuestError("Sp1: creating proof failed".to_owned()) + })?; + if let Some(id_store) = id_store { + id_store.store_id( + (input.chain_spec.chain_id, output.hash, SP1_PROVER_CODE), + proof_id.clone(), + )?; + } + let proof = { + let mut is_claimed = false; + loop { + let (status, maybe_proof) = network_client + .get_proof_status(&proof_id) + .await + .map_err(|_| { + ProverError::GuestError( + "Sp1: getting proof status failed".to_owned(), + ) + })?; + + match status.status() { + ProofStatus::ProofFulfilled => { + break Ok(maybe_proof.unwrap()); + } + ProofStatus::ProofClaimed => { + if !is_claimed { + is_claimed = true; + } + } + ProofStatus::ProofUnclaimed => { + break Err(ProverError::GuestError(format!( + "Proof generation failed: {}", + status.unclaim_description() + ))); + } + _ => {} + } + sleep(Duration::from_secs(2)); + } + }?; + Ok::<_, ProverError>(proof) + } + }?; // Verify proof. client @@ -53,9 +129,24 @@ impl Prover for Sp1Prover { .map_err(|_| ProverError::GuestError("Sp1: saving proof failed".to_owned()))?; tracing_info!("successfully generated and verified proof for the program!"); - to_proof(Ok(Sp1Response { + Ok(Sp1Response { proof: serde_json::to_string(&proof).unwrap(), - })) + } + .into()) + } + + async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> { + let proof_id = id_store.read_id(key)?; + let private_key = env::var("SP1_PRIVATE_KEY").map_err(|_| { + ProverError::GuestError("SP1_PRIVATE_KEY must be set for remote proving".to_owned()) + })?; + let network_client = NetworkClient::new(&private_key); + network_client + .unclaim_proof(proof_id, UnclaimReason::Abandoned, "".to_owned()) + .await + .map_err(|_| ProverError::GuestError("Sp1: couldn't unclaim proof".to_owned()))?; + id_store.remove_id(key)?; + Ok(()) } } diff --git a/script/cancel-block.sh b/script/cancel-block.sh new file mode 100755 index 000000000..bc6d3efba --- /dev/null +++ b/script/cancel-block.sh @@ -0,0 +1,132 @@ +#!/usr/bin/env bash + +getBlockNumber() { + # Get the latest block number from the node + output=$(curl $rpc -s -X POST -H "Content-Type: application/json" --data '{"method":"eth_blockNumber","params":[],"id":1,"jsonrpc":"2.0"}') + + # Extract the hexadecimal number using jq and remove the surrounding quotes + hex_number=$(echo $output | jq -r '.result') + + # Convert the hexadecimal to decimal + block_number=$(echo $((${hex_number}))) + + # Return the block number by echoing it + echo "$block_number" +} + +# Use the first command line argument as the chain name +chain="$1" +# Use the second command line argument as the proof type +proof="$2" +# Use the third(/fourth) parameter(s) as the block number as a range +# Use the special value "sync" as the third parameter to follow the tip of the chain 
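Two details of the SP1 network path are worth flagging. First, `local` is hard-coded to `true`, so the remote branch is unreachable until the flag becomes configurable (presumably a placeholder). Second, the polling loop delays with `std::thread::sleep` inside an async fn, which parks the executor thread for the full two seconds; under a tokio runtime (as the host uses) a non-blocking delay would look like the sketch below, where `fetch_status` is a hypothetical stand-in for `NetworkClient::get_proof_status`:

```rust
use std::time::Duration;

enum Status {
    Fulfilled,
    Claimed,
    Unclaimed,
    Pending,
}

// Hypothetical stand-in for NetworkClient::get_proof_status.
async fn fetch_status() -> (Status, Option<Vec<u8>>) {
    (Status::Fulfilled, Some(vec![]))
}

async fn wait_for_proof() -> Result<Vec<u8>, String> {
    loop {
        let (status, maybe_proof) = fetch_status().await;
        match status {
            Status::Fulfilled => break Ok(maybe_proof.unwrap_or_default()),
            Status::Unclaimed => break Err("proof generation failed".to_owned()),
            // Claimed / Pending: yield instead of blocking the worker thread.
            _ => tokio::time::sleep(Duration::from_secs(2)).await,
        }
    }
}

#[tokio::main]
async fn main() {
    assert!(wait_for_proof().await.is_ok());
}
```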
+rangeStart="$3" +rangeEnd="$4" + +# Check the chain name and set the corresponding RPC values +if [ "$chain" == "ethereum" ]; then + l1_network="ethereum" +elif [ "$chain" == "holesky" ]; then + l1_network="holesky" +elif [ "$chain" == "taiko_mainnet" ]; then + l1_network="ethereum" +elif [ "$chain" == "taiko_a7" ]; then + l1_network="holesky" +else + echo "Using customized chain name $1. Please double check the RPCs." + l1_network="holesky" +fi + +if [ "$proof" == "native" ]; then + proofParam=' + "proof_type": "native", + "native" : { + "write_guest_input_path": null + } + ' +elif [ "$proof" == "sp1" ]; then + proofParam=' + "proof_type": "sp1" + ' +elif [ "$proof" == "sgx" ]; then + proofParam=' + "proof_type": "sgx", + "sgx" : { + "instance_id": 123, + "setup": false, + "bootstrap": false, + "prove": true, + "input_path": null + } + ' +elif [ "$proof" == "risc0" ]; then + proofParam=' + "proof_type": "risc0", + "risc0": { + "bonsai": false, + "snark": false, + "profile": true, + "execution_po2": 18 + } + ' +elif [ "$proof" == "risc0-bonsai" ]; then + proofParam=' + "proof_type": "risc0", + "risc0": { + "bonsai": true, + "snark": true, + "profile": false, + "execution_po2": 20 + } + ' +else + echo "Invalid proof name. Please use 'native', 'risc0[-bonsai]', 'sp1', or 'sgx'." + exit 1 +fi + +if [ "$rangeStart" == "sync" ]; then + sync="true" + rangeStart=$(getBlockNumber) + rangeEnd=$((rangeStart + 1000000)) + sleep 1.0 +fi + +if [ "$rangeStart" == "" ]; then + echo "Please specify a valid block range like \"10\" or \"10 20\"" + exit 1 +fi + +if [ "$rangeEnd" == "" ]; then + rangeEnd=$rangeStart +fi + +prover="0x70997970C51812dc3A010C7d01b50e0d17dc79C8" +graffiti="8008500000000000000000000000000000000000000000000000000000000000" + +for block in $(eval echo {$rangeStart..$rangeEnd}); do + # Special sync logic to follow the tip of the chain + if [ "$sync" == "true" ]; then + block_number=$(getBlockNumber) + # While the current block is greater than the block number from the blockchain + while [ "$block" -gt "$block_number" ]; do + sleep 0.1 # Wait for 100ms + block_number=$(getBlockNumber) # Query again to get the updated block number + done + # Sleep a bit longer because sometimes the block data isn't available yet + sleep 1.0 + fi + + echo "- proving block $block" + curl --location --request POST 'http://localhost:8080/proof/cancel' \ + --header 'Content-Type: application/json' \ + --header 'Authorization: Bearer 4cbd753fbcbc2639de804f8ce425016a50e0ecd53db00cb5397912e83f5e570e' \ + --data-raw "{ + \"network\": \"$chain\", + \"l1_network\": \"$l1_network\", + \"block_number\": $block, + \"prover\": \"$prover\", + \"graffiti\": \"$graffiti\", + $proofParam + }" + echo "" +done diff --git a/tasks/src/adv_sqlite.rs b/tasks/src/adv_sqlite.rs index 692cb543f..f362bbb74 100644 --- a/tasks/src/adv_sqlite.rs +++ b/tasks/src/adv_sqlite.rs @@ -159,16 +159,18 @@ use std::{ }; use chrono::{DateTime, Utc}; -use raiko_core::interfaces::ProofType; -use raiko_lib::primitives::{ChainId, B256}; +use raiko_lib::{ + primitives::B256, + prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}, +}; use rusqlite::{ named_params, {Connection, OpenFlags}, }; -use tokio::sync::Mutex; +use tokio::{runtime::Builder, sync::Mutex}; use crate::{ - EnqueueTaskParams, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, - TaskManagerResult, TaskProvingStatus, TaskProvingStatusRecords, TaskReport, TaskStatus, + TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, + 
TaskProvingStatus, TaskProvingStatusRecords, TaskReport, TaskStatus, }; // Types @@ -230,6 +232,17 @@ impl TaskDb { // and introduce a migration on DB opening ... if conserving history is important. conn.execute_batch( r#" + -- Key value store + ----------------------------------------------- + CREATE TABLE store( + chain_id INTEGER NOT NULL, + blockhash BLOB NOT NULL, + proofsys_id INTEGER NOT NULL, + id TEXT NOT NULL, + FOREIGN KEY(proofsys_id) REFERENCES proofsys(id), + UNIQUE (chain_id, blockhash, proofsys_id) + ); + -- Metadata and mappings ----------------------------------------------- CREATE TABLE metadata( @@ -298,7 +311,7 @@ impl TaskDb { -- Proofs might also be large, so we isolate them in a dedicated table CREATE TABLE task_proofs( task_id INTEGER UNIQUE NOT NULL PRIMARY KEY, - proof BLOB NOT NULL, + proof TEXT, FOREIGN KEY(task_id) REFERENCES tasks(id) ); @@ -485,13 +498,12 @@ impl TaskDb { pub fn enqueue_task( &self, - EnqueueTaskParams { + TaskDescriptor { chain_id, blockhash, - proof_type, + proof_system, prover, - .. - }: &EnqueueTaskParams, + }: &TaskDescriptor, ) -> TaskManagerResult> { let mut statement = self.conn.prepare_cached( r#" @@ -514,11 +526,11 @@ impl TaskDb { statement.execute(named_params! { ":chain_id": chain_id, ":blockhash": blockhash.to_vec(), - ":proofsys_id": *proof_type as u8, + ":proofsys_id": *proof_system as u8, ":prover": prover, })?; - Ok(vec![TaskProvingStatus( + Ok(vec![( TaskStatus::Registered, Some(prover.clone()), Utc::now(), @@ -527,10 +539,12 @@ impl TaskDb { pub fn update_task_progress( &self, - chain_id: ChainId, - blockhash: B256, - proof_type: ProofType, - prover: Option, + TaskDescriptor { + chain_id, + blockhash, + proof_system, + prover, + }: TaskDescriptor, status: TaskStatus, proof: Option<&[u8]>, ) -> TaskManagerResult<()> { @@ -559,10 +573,10 @@ impl TaskDb { statement.execute(named_params! { ":chain_id": chain_id, ":blockhash": blockhash.to_vec(), - ":proofsys_id": proof_type as u8, + ":proofsys_id": proof_system as u8, + ":prover": prover, ":status_id": status as i32, - ":prover": prover.unwrap_or_default(), - ":proof": proof + ":proof": proof.map(hex::encode) })?; Ok(()) @@ -570,20 +584,23 @@ impl TaskDb { pub fn get_task_proving_status( &self, - chain_id: ChainId, - blockhash: B256, - proof_type: ProofType, - prover: Option, + TaskDescriptor { + chain_id, + blockhash, + proof_system, + prover, + }: &TaskDescriptor, ) -> TaskManagerResult { let mut statement = self.conn.prepare_cached( r#" SELECT ts.status_id, - t.prover, + tp.proof, timestamp FROM task_status ts LEFT JOIN tasks t ON ts.task_id = t.id + LEFT JOIN task_proofs tp ON tp.task_id = t.id WHERE t.chain_id = :chain_id AND t.blockhash = :blockhash @@ -597,13 +614,13 @@ impl TaskDb { named_params! { ":chain_id": chain_id, ":blockhash": blockhash.to_vec(), - ":proofsys_id": proof_type as u8, - ":prover": prover.unwrap_or_default(), + ":proofsys_id": *proof_system as u8, + ":prover": prover, }, |row| { - Ok(TaskProvingStatus( + Ok(( TaskStatus::from(row.get::<_, i32>(0)?), - Some(row.get::<_, String>(1)?), + row.get::<_, Option>(1)?, row.get::<_, DateTime>(2)?, )) }, @@ -614,10 +631,12 @@ impl TaskDb { pub fn get_task_proof( &self, - chain_id: ChainId, - blockhash: B256, - proof_type: ProofType, - prover: Option, + TaskDescriptor { + chain_id, + blockhash, + proof_system, + prover, + }: &TaskDescriptor, ) -> TaskManagerResult> { let mut statement = self.conn.prepare_cached( r#" @@ -639,13 +658,18 @@ impl TaskDb { named_params! 
{ ":chain_id": chain_id, ":blockhash": blockhash.to_vec(), - ":proofsys_id": proof_type as u8, - ":prover": prover.unwrap_or_default(), + ":proofsys_id": *proof_system as u8, + ":prover": prover, }, - |row| row.get(0), + |row| row.get::<_, Option>(0), )?; - Ok(query) + let Some(proof) = query else { + return Ok(vec![]); + }; + + hex::decode(proof) + .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) } pub fn get_db_size(&self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { @@ -713,7 +737,7 @@ impl TaskDb { )?; let query = statement .query_map([], |row| { - Ok(TaskReport( + Ok(( TaskDescriptor { chain_id: row.get(0)?, blockhash: B256::from_slice(&row.get::<_, Vec>(1)?), @@ -727,6 +751,117 @@ impl TaskDb { Ok(query) } + + fn store_id( + &self, + (chain_id, blockhash, proof_key): ProofKey, + id: String, + ) -> TaskManagerResult<()> { + let mut statement = self.conn.prepare_cached( + r#" + INSERT INTO + store( + chain_id, + blockhash, + proofsys_id, + id + ) + VALUES + ( + :chain_id, + :blockhash, + :proofsys_id, + :id + ); + "#, + )?; + statement.execute(named_params! { + ":chain_id": chain_id, + ":blockhash": blockhash.to_vec(), + ":proofsys_id": proof_key, + ":id": id, + })?; + + Ok(()) + } + + fn remove_id(&self, (chain_id, blockhash, proof_key): ProofKey) -> TaskManagerResult<()> { + let mut statement = self.conn.prepare_cached( + r#" + DELETE FROM + store + WHERE + chain_id = :chain_id + AND blockhash = :blockhash + AND proofsys_id = :proofsys_id; + "#, + )?; + statement.execute(named_params! { + ":chain_id": chain_id, + ":blockhash": blockhash.to_vec(), + ":proofsys_id": proof_key, + })?; + + Ok(()) + } + + fn read_id(&self, (chain_id, blockhash, proof_key): ProofKey) -> TaskManagerResult { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + id + FROM + store + WHERE + chain_id = :chain_id + AND blockhash = :blockhash + AND proofsys_id = :proofsys_id + LIMIT + 1; + "#, + )?; + let query = statement.query_row( + named_params! 
{ + ":chain_id": chain_id, + ":blockhash": blockhash.to_vec(), + ":proofsys_id": proof_key, + }, + |row| row.get::<_, String>(0), + )?; + + Ok(query) + } +} + +impl IdWrite for SqliteTaskManager { + fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { + let rt = Builder::new_current_thread().enable_all().build()?; + rt.block_on(async move { + let task_db = self.arc_task_db.lock().await; + task_db.store_id(key, id) + }) + .map_err(|e| ProverError::StoreError(e.to_string())) + } + + fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { + let rt = Builder::new_current_thread().enable_all().build()?; + rt.block_on(async move { + let task_db = self.arc_task_db.lock().await; + task_db.remove_id(key) + }) + .map_err(|e| ProverError::StoreError(e.to_string())) + } +} + +impl IdStore for SqliteTaskManager { + fn read_id(&self, key: ProofKey) -> ProverResult { + let rt = Builder::new_current_thread().enable_all().build()?; + rt.block_on(async move { + let task_db = self.arc_task_db.lock().await; + task_db.read_id(key) + }) + .map_err(|e| ProverError::StoreError(e.to_string())) + } } #[async_trait::async_trait] @@ -750,7 +885,7 @@ impl TaskManager for SqliteTaskManager { async fn enqueue_task( &mut self, - params: &EnqueueTaskParams, + params: &TaskDescriptor, ) -> Result, TaskManagerError> { let task_db = self.arc_task_db.lock().await; task_db.enqueue_task(params) @@ -758,38 +893,26 @@ impl TaskManager for SqliteTaskManager { async fn update_task_progress( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_type: ProofType, - prover: Option, + key: TaskDescriptor, status: TaskStatus, proof: Option<&[u8]>, ) -> TaskManagerResult<()> { let task_db = self.arc_task_db.lock().await; - task_db.update_task_progress(chain_id, blockhash, proof_type, prover, status, proof) + task_db.update_task_progress(key, status, proof) } /// Returns the latest triplet (submitter or fulfiller, status, last update time) async fn get_task_proving_status( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_type: ProofType, - prover: Option, + key: &TaskDescriptor, ) -> TaskManagerResult { let task_db = self.arc_task_db.lock().await; - task_db.get_task_proving_status(chain_id, blockhash, proof_type, prover) + task_db.get_task_proving_status(key) } - async fn get_task_proof( - &mut self, - chain_id: ChainId, - blockhash: B256, - proof_type: ProofType, - prover: Option, - ) -> TaskManagerResult> { + async fn get_task_proof(&mut self, key: &TaskDescriptor) -> TaskManagerResult> { let task_db = self.arc_task_db.lock().await; - task_db.get_task_proof(chain_id, blockhash, proof_type, prover) + task_db.get_task_proof(key) } /// Returns the total and detailed database size diff --git a/tasks/src/lib.rs b/tasks/src/lib.rs index d0e7fdbb5..728ff7475 100644 --- a/tasks/src/lib.rs +++ b/tasks/src/lib.rs @@ -6,7 +6,10 @@ use std::{ use chrono::{DateTime, Utc}; use num_enum::{FromPrimitive, IntoPrimitive}; use raiko_core::interfaces::ProofType; -use raiko_lib::primitives::{ChainId, B256}; +use raiko_lib::{ + primitives::{ChainId, B256}, + prover::{IdStore, IdWrite, ProofKey, ProverResult}, +}; use rusqlite::Error as SqlError; use serde::Serialize; use utoipa::ToSchema; @@ -76,15 +79,6 @@ pub enum TaskStatus { SqlDbCorruption = -99999, } -#[derive(Debug, Clone, Default)] -pub struct EnqueueTaskParams { - pub chain_id: ChainId, - pub blockhash: B256, - pub proof_type: ProofType, - pub prover: String, - pub block_number: u64, -} - #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash)] pub struct 
TaskDescriptor { pub chain_id: ChainId, @@ -93,108 +87,79 @@ pub struct TaskDescriptor { pub prover: String, } -impl TaskDescriptor { - pub fn to_vec(self) -> Vec { - self.into() - } -} - -impl From for Vec { - fn from(val: TaskDescriptor) -> Self { - let mut v = Vec::new(); - v.extend_from_slice(&val.chain_id.to_be_bytes()); - v.extend_from_slice(val.blockhash.as_ref()); - v.extend_from_slice(&(val.proof_system as u8).to_be_bytes()); - v.extend_from_slice(val.prover.as_bytes()); - v - } -} - -// Taskkey from EnqueueTaskParams -impl From<&EnqueueTaskParams> for TaskDescriptor { - fn from(params: &EnqueueTaskParams) -> TaskDescriptor { +impl From<(ChainId, B256, ProofType, String)> for TaskDescriptor { + fn from( + (chain_id, blockhash, proof_system, prover): (ChainId, B256, ProofType, String), + ) -> Self { TaskDescriptor { - chain_id: params.chain_id, - blockhash: params.blockhash, - proof_system: params.proof_type, - prover: params.prover.clone(), + chain_id, + blockhash, + proof_system, + prover, } } } -impl From<(ChainId, B256, ProofType, Option)> for TaskDescriptor { +impl From for (ChainId, B256) { fn from( - (chain_id, blockhash, proof_system, prover): (ChainId, B256, ProofType, Option), - ) -> Self { TaskDescriptor { chain_id, blockhash, - proof_system, - prover: prover.unwrap_or_default(), - } + .. + }: TaskDescriptor, + ) -> Self { + (chain_id, blockhash) } } -#[derive(Debug, Clone)] -pub struct TaskProvingStatus(pub TaskStatus, pub Option, pub DateTime); +/// Task status triplet (status, proof, timestamp). +pub type TaskProvingStatus = (TaskStatus, Option, DateTime); pub type TaskProvingStatusRecords = Vec; +pub type TaskReport = (TaskDescriptor, TaskStatus); + #[derive(Debug, Clone)] pub struct TaskManagerOpts { pub sqlite_file: PathBuf, pub max_db_size: usize, } -#[derive(Debug, Serialize, ToSchema)] -pub struct TaskReport(pub TaskDescriptor, pub TaskStatus); - #[async_trait::async_trait] -pub trait TaskManager { - /// new a task manager +pub trait TaskManager: IdStore + IdWrite { + /// Create a new task manager. fn new(opts: &TaskManagerOpts) -> Self; - /// enqueue_task + /// Enqueue a new task to the tasks database. async fn enqueue_task( &mut self, - request: &EnqueueTaskParams, + request: &TaskDescriptor, ) -> TaskManagerResult; - /// Update the task progress + /// Update a specific tasks progress. async fn update_task_progress( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: TaskDescriptor, status: TaskStatus, proof: Option<&[u8]>, ) -> TaskManagerResult<()>; - /// Returns the latest triplet (submitter or fulfiller, status, last update time) + /// Returns the latest triplet (status, proof - if any, last update time). async fn get_task_proving_status( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: &TaskDescriptor, ) -> TaskManagerResult; - /// Returns the proof for the given task - async fn get_task_proof( - &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, - ) -> TaskManagerResult>; + /// Returns the proof for the given task. + async fn get_task_proof(&mut self, key: &TaskDescriptor) -> TaskManagerResult>; - /// Returns the total and detailed database size + /// Returns the total and detailed database size. async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)>; - /// Prune old tasks + /// Prune old tasks. 
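With `TaskProvingStatus` reduced to a plain `(TaskStatus, Option<String>, DateTime<Utc>)` triplet, the `Option<String>` in the middle carries the hex-encoded proof, matching the TEXT column the SQLite backend now uses. A sketch of consuming the latest record, with simplified stand-in types (assumes the `hex` crate, already a dependency here):

```rust
// Stand-ins for the aliases above; the chrono timestamp is simplified.
#[derive(Clone, Copy, PartialEq, Debug)]
enum TaskStatus {
    Registered,
    Success,
}
type Timestamp = u64;
type TaskProvingStatus = (TaskStatus, Option<String>, Timestamp);

/// Latest record wins; a Success record carries a hex-encoded proof.
fn latest_proof(records: &[TaskProvingStatus]) -> Option<Vec<u8>> {
    match records.last() {
        Some((TaskStatus::Success, Some(hex_proof), _)) => hex::decode(hex_proof).ok(),
        _ => None,
    }
}

fn main() {
    let records = vec![
        (TaskStatus::Registered, None, 1),
        (TaskStatus::Success, Some(hex::encode([0xde, 0xad])), 2),
    ];
    assert_eq!(latest_proof(&records), Some(vec![0xde, 0xad]));
}
```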
async fn prune_db(&mut self) -> TaskManagerResult<()>; + /// List all tasks in the db. async fn list_all_tasks(&mut self) -> TaskManagerResult>; } @@ -214,6 +179,31 @@ pub struct TaskManagerWrapper { manager: TaskManagerInstance, } +impl IdWrite for TaskManagerWrapper { + fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => manager.store_id(key, id), + TaskManagerInstance::Sqlite(ref mut manager) => manager.store_id(key, id), + } + } + + fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => manager.remove_id(key), + TaskManagerInstance::Sqlite(ref mut manager) => manager.remove_id(key), + } + } +} + +impl IdStore for TaskManagerWrapper { + fn read_id(&self, key: ProofKey) -> ProverResult { + match &self.manager { + TaskManagerInstance::InMemory(manager) => manager.read_id(key), + TaskManagerInstance::Sqlite(manager) => manager.read_id(key), + } + } +} + #[async_trait::async_trait] impl TaskManager for TaskManagerWrapper { fn new(opts: &TaskManagerOpts) -> Self { @@ -228,7 +218,7 @@ impl TaskManager for TaskManagerWrapper { async fn enqueue_task( &mut self, - request: &EnqueueTaskParams, + request: &TaskDescriptor, ) -> TaskManagerResult { match &mut self.manager { TaskManagerInstance::InMemory(ref mut manager) => manager.enqueue_task(request).await, @@ -238,66 +228,38 @@ impl TaskManager for TaskManagerWrapper { async fn update_task_progress( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: TaskDescriptor, status: TaskStatus, proof: Option<&[u8]>, ) -> TaskManagerResult<()> { match &mut self.manager { TaskManagerInstance::InMemory(ref mut manager) => { - manager - .update_task_progress(chain_id, blockhash, proof_system, prover, status, proof) - .await + manager.update_task_progress(key, status, proof).await } TaskManagerInstance::Sqlite(ref mut manager) => { - manager - .update_task_progress(chain_id, blockhash, proof_system, prover, status, proof) - .await + manager.update_task_progress(key, status, proof).await } } } async fn get_task_proving_status( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: &TaskDescriptor, ) -> TaskManagerResult { match &mut self.manager { TaskManagerInstance::InMemory(ref mut manager) => { - manager - .get_task_proving_status(chain_id, blockhash, proof_system, prover) - .await + manager.get_task_proving_status(key).await } TaskManagerInstance::Sqlite(ref mut manager) => { - manager - .get_task_proving_status(chain_id, blockhash, proof_system, prover) - .await + manager.get_task_proving_status(key).await } } } - async fn get_task_proof( - &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, - ) -> TaskManagerResult> { + async fn get_task_proof(&mut self, key: &TaskDescriptor) -> TaskManagerResult> { match &mut self.manager { - TaskManagerInstance::InMemory(ref mut manager) => { - manager - .get_task_proof(chain_id, blockhash, proof_system, prover) - .await - } - TaskManagerInstance::Sqlite(ref mut manager) => { - manager - .get_task_proof(chain_id, blockhash, proof_system, prover) - .await - } + TaskManagerInstance::InMemory(ref mut manager) => manager.get_task_proof(key).await, + TaskManagerInstance::Sqlite(ref mut manager) => manager.get_task_proof(key).await, } } @@ -348,12 +310,11 @@ mod test { assert_eq!( task_manager - 
.enqueue_task(&EnqueueTaskParams { + .enqueue_task(&TaskDescriptor { chain_id: 1, blockhash: B256::default(), - proof_type: ProofType::Native, + proof_system: ProofType::Native, prover: "test".to_string(), - block_number: 1 }) .await .unwrap() diff --git a/tasks/src/mem_db.rs b/tasks/src/mem_db.rs index f0496d3b1..13137760f 100644 --- a/tasks/src/mem_db.rs +++ b/tasks/src/mem_db.rs @@ -12,17 +12,16 @@ use std::{ sync::{Arc, Once}, }; -use crate::{ - ensure, EnqueueTaskParams, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, - TaskManagerResult, TaskProvingStatus, TaskProvingStatusRecords, TaskReport, TaskStatus, -}; - use chrono::Utc; -use raiko_core::interfaces::ProofType; -use raiko_lib::primitives::{ChainId, B256}; -use tokio::sync::Mutex; +use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; +use tokio::{runtime::Builder, sync::Mutex}; use tracing::{debug, info}; +use crate::{ + ensure, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, + TaskProvingStatusRecords, TaskReport, TaskStatus, +}; + #[derive(Debug)] pub struct InMemoryTaskManager { db: Arc>, @@ -31,24 +30,21 @@ pub struct InMemoryTaskManager { #[derive(Debug)] pub struct InMemoryTaskDb { enqueue_task: HashMap, + store: HashMap, } impl InMemoryTaskDb { fn new() -> InMemoryTaskDb { InMemoryTaskDb { enqueue_task: HashMap::new(), + store: HashMap::new(), } } - fn enqueue_task(&mut self, params: &EnqueueTaskParams) { - let key = TaskDescriptor::from(params); - let task_status = TaskProvingStatus( - TaskStatus::Registered, - Some(params.prover.clone()), - Utc::now(), - ); + fn enqueue_task(&mut self, key: &TaskDescriptor) { + let task_status = (TaskStatus::Registered, None, Utc::now()); - match self.enqueue_task.get(&key) { + match self.enqueue_task.get(key) { Some(task_proving_records) => { debug!( "Task already exists: {:?}", @@ -56,71 +52,58 @@ impl InMemoryTaskDb { ); } // do nothing None => { - info!("Enqueue new task: {:?}", params); - self.enqueue_task.insert(key, vec![task_status]); + info!("Enqueue new task: {key:?}"); + self.enqueue_task.insert(key.clone(), vec![task_status]); } } } fn update_task_progress( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: TaskDescriptor, status: TaskStatus, proof: Option<&[u8]>, ) -> TaskManagerResult<()> { - let key = TaskDescriptor::from((chain_id, blockhash, proof_system, prover.clone())); ensure(self.enqueue_task.contains_key(&key), "no task found")?; self.enqueue_task.entry(key).and_modify(|entry| { if let Some(latest) = entry.last() { if latest.0 != status { - entry.push(TaskProvingStatus( - status, - proof.map(hex::encode), - Utc::now(), - )); + entry.push((status, proof.map(hex::encode), Utc::now())); } } }); + Ok(()) } fn get_task_proving_status( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: &TaskDescriptor, ) -> TaskManagerResult { - let key = TaskDescriptor::from((chain_id, blockhash, proof_system, prover.clone())); - - match self.enqueue_task.get(&key) { - Some(proving_status_records) => Ok(proving_status_records.clone()), - None => Ok(vec![]), - } + Ok(self.enqueue_task.get(key).cloned().unwrap_or_default()) } - fn get_task_proof( - &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, - ) -> TaskManagerResult> { - let key = TaskDescriptor::from((chain_id, blockhash, proof_system, prover.clone())); - ensure(self.enqueue_task.contains_key(&key), "no 
task found")?; + fn get_task_proof(&mut self, key: &TaskDescriptor) -> TaskManagerResult> { + ensure(self.enqueue_task.contains_key(key), "no task found")?; - let Some(proving_status_records) = self.enqueue_task.get(&key) else { - return Err(TaskManagerError::SqlError("no task in db".to_owned())); - }; + let proving_status_records = self + .enqueue_task + .get(key) + .ok_or_else(|| TaskManagerError::SqlError("no task in db".to_owned()))?; - proving_status_records + let (_, proof, ..) = proving_status_records + .iter() + .filter(|(status, ..)| (status == &TaskStatus::Success)) .last() - .map(|status| hex::decode(status.1.clone().unwrap()).unwrap()) - .ok_or_else(|| TaskManagerError::SqlError("working in progress".to_owned())) + .ok_or_else(|| TaskManagerError::SqlError("no successful task in db".to_owned()))?; + + let Some(proof) = proof else { + return Ok(vec![]); + }; + + hex::decode(proof) + .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) } fn size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { @@ -128,6 +111,7 @@ impl InMemoryTaskDb { } fn prune(&mut self) -> TaskManagerResult<()> { + self.enqueue_task.clear(); Ok(()) } @@ -136,13 +120,58 @@ impl InMemoryTaskDb { .enqueue_task .iter() .flat_map(|(descriptor, statuses)| { - // Return the latest status - statuses - .last() - .map(|status| TaskReport(descriptor.clone(), status.0)) + statuses.last().map(|status| (descriptor.clone(), status.0)) }) .collect()) } + + fn store_id(&mut self, key: ProofKey, id: String) -> TaskManagerResult<()> { + self.store.insert(key, id); + Ok(()) + } + + fn remove_id(&mut self, key: ProofKey) -> TaskManagerResult<()> { + self.store.remove(&key); + Ok(()) + } + + fn read_id(&mut self, key: ProofKey) -> TaskManagerResult { + self.store + .get(&key) + .cloned() + .ok_or_else(|| TaskManagerError::SqlError("no id found".to_owned())) + } +} + +impl IdWrite for InMemoryTaskManager { + fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { + let rt = Builder::new_current_thread().enable_all().build()?; + rt.block_on(async move { + let mut db = self.db.lock().await; + db.store_id(key, id) + }) + .map_err(|e| ProverError::StoreError(e.to_string())) + } + + fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { + let rt = Builder::new_current_thread().enable_all().build()?; + rt.block_on(async move { + let mut db = self.db.lock().await; + db.remove_id(key) + }) + .map_err(|e| ProverError::StoreError(e.to_string())) + } +} + +impl IdStore for InMemoryTaskManager { + fn read_id(&self, key: ProofKey) -> ProverResult { + let rt = Builder::new_current_thread().enable_all().build()?; + rt.block_on(async move { + let mut db = self.db.lock().await; + db.read_id(key) + }) + .map_err(|e| ProverError::StoreError(e.to_string())) + } } #[async_trait::async_trait] @@ -166,62 +195,40 @@ impl TaskManager for InMemoryTaskManager { async fn enqueue_task( &mut self, - params: &EnqueueTaskParams, + params: &TaskDescriptor, ) -> TaskManagerResult { let mut db = self.db.lock().await; - let status = db.get_task_proving_status( - params.chain_id, - params.blockhash, - params.proof_type, - Some(params.prover.to_string()), - )?; - if status.is_empty() { - db.enqueue_task(params); - db.get_task_proving_status( - params.chain_id, - params.blockhash, - params.proof_type, - Some(params.prover.clone()), - ) - } else { - Ok(status) + let status = db.get_task_proving_status(params)?; + if !status.is_empty() { + return Ok(status); } + + db.enqueue_task(params); + 
db.get_task_proving_status(params) } async fn update_task_progress( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: TaskDescriptor, status: TaskStatus, proof: Option<&[u8]>, ) -> TaskManagerResult<()> { let mut db = self.db.lock().await; - db.update_task_progress(chain_id, blockhash, proof_system, prover, status, proof) + db.update_task_progress(key, status, proof) } /// Returns the latest triplet (submitter or fulfiller, status, last update time) async fn get_task_proving_status( &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, + key: &TaskDescriptor, ) -> TaskManagerResult { let mut db = self.db.lock().await; - db.get_task_proving_status(chain_id, blockhash, proof_system, prover) + db.get_task_proving_status(key) } - async fn get_task_proof( - &mut self, - chain_id: ChainId, - blockhash: B256, - proof_system: ProofType, - prover: Option, - ) -> TaskManagerResult> { + async fn get_task_proof(&mut self, key: &TaskDescriptor) -> TaskManagerResult> { let mut db = self.db.lock().await; - db.get_task_proof(chain_id, blockhash, proof_system, prover) + db.get_task_proof(key) } /// Returns the total and detailed database size @@ -243,6 +250,8 @@ impl TaskManager for InMemoryTaskManager { #[cfg(test)] mod tests { + use alloy_primitives::B256; + use super::*; use crate::ProofType; @@ -254,20 +263,14 @@ mod tests { #[test] fn test_db_enqueue() { let mut db = InMemoryTaskDb::new(); - let params = EnqueueTaskParams { + let params = TaskDescriptor { chain_id: 1, blockhash: B256::default(), - proof_type: ProofType::Native, + proof_system: ProofType::Native, prover: "0x1234".to_owned(), - ..Default::default() }; db.enqueue_task(¶ms); - let status = db.get_task_proving_status( - params.chain_id, - params.blockhash, - params.proof_type, - Some(params.prover.clone()), - ); + let status = db.get_task_proving_status(¶ms); assert!(status.is_ok()); } } diff --git a/tasks/tests/main.rs b/tasks/tests/main.rs index 3b0099181..7214a546c 100644 --- a/tasks/tests/main.rs +++ b/tasks/tests/main.rs @@ -15,9 +15,7 @@ mod tests { use rand_chacha::ChaCha8Rng; use raiko_lib::{input::BlobProofType, primitives::B256}; - use raiko_tasks::{ - get_task_manager, EnqueueTaskParams, TaskManager, TaskManagerOpts, TaskStatus, - }; + use raiko_tasks::{get_task_manager, TaskManager, TaskManagerOpts, TaskStatus}; fn create_random_task(rng: &mut ChaCha8Rng) -> (u64, B256, ProofRequest) { let chain_id = 100; @@ -68,15 +66,17 @@ mod tests { max_db_size: 1_000_000, }); - let (chain_id, block_hash, request) = + let (chain_id, blockhash, request) = create_random_task(&mut ChaCha8Rng::seed_from_u64(123)); - tama.enqueue_task(&EnqueueTaskParams { - chain_id, - blockhash: block_hash, - proof_type: request.proof_type, - prover: request.prover.to_string(), - block_number: request.block_number, - }) + tama.enqueue_task( + &( + chain_id, + blockhash, + request.proof_type, + request.prover.to_string(), + ) + .into(), + ) .await .unwrap(); } @@ -106,24 +106,29 @@ mod tests { let mut tasks = vec![]; for _ in 0..5 { - let (chain_id, block_hash, request) = create_random_task(&mut rng); + let (chain_id, blockhash, request) = create_random_task(&mut rng); - tama.enqueue_task(&EnqueueTaskParams { - chain_id, - blockhash: block_hash, - proof_type: request.proof_type, - prover: request.prover.to_string(), - block_number: request.block_number, - }) + tama.enqueue_task( + &( + chain_id, + blockhash, + request.proof_type, + request.prover.to_string(), + ) + 
.into(), + ) .await .unwrap(); let task_status = tama .get_task_proving_status( - chain_id, - block_hash, - request.proof_type, - Some(request.prover.to_string()), + &( + chain_id, + blockhash, + request.proof_type, + request.prover.to_string(), + ) + .into(), ) .await .unwrap(); @@ -135,7 +140,7 @@ mod tests { tasks.push(( chain_id, - block_hash, + blockhash, request.block_number, request.proof_type, request.prover, @@ -147,19 +152,13 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[0].0, - tasks[0].1, - tasks[0].3, - Some(tasks[0].4.to_string()), + &(tasks[0].0, tasks[0].1, tasks[0].3, tasks[0].4.to_string()).into(), ) .await .unwrap(); println!("{task_status:?}"); tama.update_task_progress( - tasks[0].0, - tasks[0].1, - tasks[0].3, - Some(tasks[0].4.to_string()), + (tasks[0].0, tasks[0].1, tasks[0].3, tasks[0].4.to_string()).into(), TaskStatus::Cancelled_NeverStarted, None, ) @@ -168,10 +167,7 @@ mod tests { let task_status = tama .get_task_proving_status( - tasks[0].0, - tasks[0].1, - tasks[0].3, - Some(tasks[0].4.to_string()), + &(tasks[0].0, tasks[0].1, tasks[0].3, tasks[0].4.to_string()).into(), ) .await .unwrap(); @@ -183,10 +179,7 @@ mod tests { // ----------------------- { tama.update_task_progress( - tasks[1].0, - tasks[1].1, - tasks[1].3, - Some(tasks[1].4.to_string()), + (tasks[1].0, tasks[1].1, tasks[1].3, tasks[1].4.to_string()).into(), TaskStatus::WorkInProgress, None, ) @@ -196,10 +189,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[1].0, - tasks[1].1, - tasks[1].3, - Some(tasks[1].4.to_string()), + &(tasks[1].0, tasks[1].1, tasks[1].3, tasks[1].4.to_string()).into(), ) .await .unwrap(); @@ -211,10 +201,7 @@ mod tests { std::thread::sleep(Duration::from_millis(1)); tama.update_task_progress( - tasks[1].0, - tasks[1].1, - tasks[1].3, - Some(tasks[1].4.to_string()), + (tasks[1].0, tasks[1].1, tasks[1].3, tasks[1].4.to_string()).into(), TaskStatus::CancellationInProgress, None, ) @@ -224,10 +211,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[1].0, - tasks[1].1, - tasks[1].3, - Some(tasks[1].4.to_string()), + &(tasks[1].0, tasks[1].1, tasks[1].3, tasks[1].4.to_string()).into(), ) .await .unwrap(); @@ -240,10 +224,7 @@ mod tests { std::thread::sleep(Duration::from_millis(1)); tama.update_task_progress( - tasks[1].0, - tasks[1].1, - tasks[1].3, - Some(tasks[1].4.to_string()), + (tasks[1].0, tasks[1].1, tasks[1].3, tasks[1].4.to_string()).into(), TaskStatus::Cancelled, None, ) @@ -253,10 +234,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[1].0, - tasks[1].1, - tasks[1].3, - Some(tasks[1].4.to_string()), + &(tasks[1].0, tasks[1].1, tasks[1].3, tasks[1].4.to_string()).into(), ) .await .unwrap(); @@ -271,10 +249,7 @@ mod tests { // ----------------------- { tama.update_task_progress( - tasks[2].0, - tasks[2].1, - tasks[2].3, - Some(tasks[2].4.to_string()), + (tasks[2].0, tasks[2].1, tasks[2].3, tasks[2].4.to_string()).into(), TaskStatus::WorkInProgress, None, ) @@ -284,10 +259,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[2].0, - tasks[2].1, - tasks[2].3, - Some(tasks[2].4.to_string()), + &(tasks[2].0, tasks[2].1, tasks[2].3, tasks[2].4.to_string()).into(), ) .await .unwrap(); @@ -300,10 +272,7 @@ mod tests { let proof: Vec<_> = (&mut rng).gen_iter::().take(128).collect(); tama.update_task_progress( - tasks[2].0, - tasks[2].1, - tasks[2].3, - Some(tasks[2].4.to_string()), + (tasks[2].0, tasks[2].1, tasks[2].3, tasks[2].4.to_string()).into(), 
TaskStatus::Success, Some(&proof), ) @@ -313,10 +282,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[2].0, - tasks[2].1, - tasks[2].3, - Some(tasks[2].4.to_string()), + &(tasks[2].0, tasks[2].1, tasks[2].3, tasks[2].4.to_string()).into(), ) .await .unwrap(); @@ -329,10 +295,7 @@ mod tests { assert_eq!( proof, tama.get_task_proof( - tasks[2].0, - tasks[2].1, - tasks[2].3, - Some(tasks[2].4.to_string()) + &(tasks[2].0, tasks[2].1, tasks[2].3, tasks[2].4.to_string()).into() ) .await .unwrap() @@ -342,10 +305,7 @@ mod tests { // ----------------------- { tama.update_task_progress( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + (tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), TaskStatus::WorkInProgress, None, ) @@ -355,10 +315,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + &(tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), ) .await .unwrap(); @@ -370,10 +327,7 @@ mod tests { std::thread::sleep(Duration::from_millis(1)); tama.update_task_progress( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + (tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), TaskStatus::NetworkFailure, None, ) @@ -383,10 +337,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + &(tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), ) .await .unwrap(); @@ -399,10 +350,7 @@ mod tests { std::thread::sleep(Duration::from_millis(1)); tama.update_task_progress( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + (tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), TaskStatus::WorkInProgress, None, ) @@ -412,10 +360,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + &(tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), ) .await .unwrap(); @@ -430,10 +375,7 @@ mod tests { let proof: Vec<_> = (&mut rng).gen_iter::().take(128).collect(); tama.update_task_progress( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + (tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), TaskStatus::Success, Some(proof.as_slice()), ) @@ -443,10 +385,7 @@ mod tests { { let task_status = tama .get_task_proving_status( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()), + &(tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into(), ) .await .unwrap(); @@ -461,10 +400,7 @@ mod tests { assert_eq!( proof, tama.get_task_proof( - tasks[3].0, - tasks[3].1, - tasks[3].3, - Some(tasks[3].4.to_string()) + &(tasks[3].0, tasks[3].1, tasks[3].3, tasks[3].4.to_string()).into() ) .await .unwrap()
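These integration tests repeatedly exercise an invariant of the in-memory backend above: `update_task_progress` appends a new triplet only when the status actually changed, so Registered → WorkInProgress → WorkInProgress → Success yields three records, not four. A compact stand-alone check of that rule:

```rust
// Stand-in reduction of InMemoryTaskDb::update_task_progress: a record
// is appended only when the status differs from the latest one.
#[derive(Clone, Copy, PartialEq, Debug)]
enum TaskStatus {
    Registered,
    WorkInProgress,
    Success,
}

fn update(records: &mut Vec<TaskStatus>, status: TaskStatus) {
    if records.last() != Some(&status) {
        records.push(status);
    }
}

fn main() {
    let mut records = vec![TaskStatus::Registered];
    update(&mut records, TaskStatus::WorkInProgress);
    update(&mut records, TaskStatus::WorkInProgress); // deduplicated
    update(&mut records, TaskStatus::Success);
    assert_eq!(records.len(), 3);
}
```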