diff --git a/.changelog/unreleased/improvements/3405-phase-1-shielded-sync-refactor.md b/.changelog/unreleased/improvements/3405-phase-1-shielded-sync-refactor.md new file mode 100644 index 0000000000..44f51930d7 --- /dev/null +++ b/.changelog/unreleased/improvements/3405-phase-1-shielded-sync-refactor.md @@ -0,0 +1,7 @@ + - Implements phase 1 of Issue [\#3385](https://github.com/anoma/namada/issues/3385) + - When fetching notes, connections and related failures should not halt shielded sync. Instead, the process + should be restarted + - If fetching is interrupted, the data fetched should be persisted locally so that progress isn't lost. + - A trait for fetching behavior should be added to provide modularity + + ([\#3498](https://github.com/anoma/namada/pull/3498)) diff --git a/Cargo.lock b/Cargo.lock index 1018527284..ccfca3eceb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2663,6 +2663,18 @@ dependencies = [ "paste", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -5257,6 +5269,7 @@ dependencies = [ "ethers", "eyre", "fd-lock", + "flume", "futures", "itertools 0.12.1", "jubjub", @@ -5582,6 +5595,15 @@ dependencies = [ "sha2 0.9.9", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom 0.2.15", +] + [[package]] name = "native-tls" version = "0.2.11" diff --git a/crates/apps_lib/src/cli/client.rs b/crates/apps_lib/src/cli/client.rs index 24ed0981ff..23be54895f 100644 --- a/crates/apps_lib/src/cli/client.rs +++ b/crates/apps_lib/src/cli/client.rs @@ -12,7 +12,7 @@ use crate::cli::cmds::*; use crate::client::{rpc, tx, utils}; impl CliApi { - pub async fn handle_client_command( + pub async fn handle_client_command( client: Option, cmd: cli::NamadaClient, io: IO, diff --git a/crates/apps_lib/src/client/masp.rs b/crates/apps_lib/src/client/masp.rs index 983be6d718..c21b446c11 100644 --- a/crates/apps_lib/src/client/masp.rs +++ b/crates/apps_lib/src/client/masp.rs @@ -1,14 +1,16 @@ use std::fmt::Debug; +use std::sync::{Arc, Mutex}; use color_eyre::owo_colors::OwoColorize; use masp_primitives::sapling::ViewingKey; use masp_primitives::zip32::ExtendedSpendingKey; use namada_sdk::error::Error; use namada_sdk::io::Io; -use namada_sdk::masp::{ - IndexedNoteEntry, ProgressLogger, ProgressType, ShieldedContext, - ShieldedUtils, +use namada_sdk::masp::utils::{ + LedgerMaspClient, PeekableIter, ProgressTracker, ProgressType, + RetryStrategy, }; +use namada_sdk::masp::{IndexedNoteEntry, ShieldedContext, ShieldedUtils}; use namada_sdk::queries::Client; use namada_sdk::storage::BlockHeight; use namada_sdk::{display, display_line, MaybeSend, MaybeSync}; @@ -17,7 +19,7 @@ use namada_sdk::{display, display_line, MaybeSend, MaybeSync}; pub async fn syncing< U: ShieldedUtils + MaybeSend + MaybeSync, C: Client + Sync, - IO: Io, + IO: Io + Send + Sync, >( mut shielded: ShieldedContext, client: &C, @@ -28,137 +30,238 @@ pub async fn syncing< sks: &[ExtendedSpendingKey], fvks: &[ViewingKey], ) -> Result, Error> { - let shutdown_signal = async { - let (tx, rx) = tokio::sync::oneshot::channel(); - 
namada_sdk::control_flow::shutdown_send(tx).await; - rx.await - }; - display_line!(io, "{}", "==== Shielded sync started ====".on_white()); display_line!(io, "\n\n"); - let logger = CliLogger::new(io); - let sync = async move { - shielded - .fetch( - client, - &logger, - start_query_height, - last_query_height, - batch_size, - sks, - fvks, - ) - .await - .map(|_| shielded) - }; - tokio::select! { - sync = sync => { - let shielded = sync?; - display!(io, "Syncing finished\n"); - Ok(shielded) - }, - sig = shutdown_signal => { - sig.map_err(|e| Error::Other(e.to_string()))?; - display!(io, "\n"); - Ok(ShieldedContext::default()) - }, - } + let tracker = CliProgressTracker::new(io); + + let shielded = shielded + .fetch( + LedgerMaspClient::new(client), + &tracker, + start_query_height, + last_query_height, + RetryStrategy::Forever, + batch_size, + sks, + fvks, + ) + .await + .map(|_| shielded)?; + + display!(io, "Syncing finished\n"); + Ok(shielded) } -pub struct CliLogging<'io, T, IO: Io> { - items: Vec, +/// The amount of progress a shielded sync sub-process has made +#[derive(Default, Copy, Clone, Debug)] +struct IterProgress { index: usize, length: usize, +} + +pub struct LoggingIterator<'io, T, I, IO> +where + T: Debug, + I: Iterator, + IO: Io, +{ + items: I, + progress: Arc>, io: &'io IO, r#type: ProgressType, + peeked: Option, + #[cfg(not(unix))] + num_logs_counter: usize, } -impl<'io, T: Debug, IO: Io> CliLogging<'io, T, IO> { - fn new(items: I, io: &'io IO, r#type: ProgressType) -> Self - where - I: IntoIterator, - { - let items: Vec<_> = items.into_iter().collect(); +impl<'io, T, I, IO> LoggingIterator<'io, T, I, IO> +where + T: Debug, + I: Iterator, + IO: Io, +{ + fn new( + items: I, + io: &'io IO, + r#type: ProgressType, + progress: Arc>, + ) -> Self { + let (size, _) = items.size_hint(); + { + let mut locked = progress.lock().unwrap(); + locked.length = size; + } Self { - length: items.len(), items, - index: 0, + progress, io, r#type, + peeked: None, + #[cfg(not(unix))] + num_logs_counter: 0, } } -} -impl<'io, T: Debug, IO: Io> Iterator for CliLogging<'io, T, IO> { - type Item = T; + fn advance_index(&mut self) { + let mut locked = self.progress.lock().unwrap(); + locked.index += 1; + if let ProgressType::Scan = self.r#type { + locked.length = self.items.size_hint().0; + } + } +} - fn next(&mut self) -> Option { - if self.index == 0 { - self.items = { - let mut new_items = vec![]; - std::mem::swap(&mut new_items, &mut self.items); - new_items.into_iter().rev().collect() - }; +impl<'io, T, I, IO> PeekableIter for LoggingIterator<'io, T, I, IO> +where + T: Debug, + I: Iterator, + IO: Io, +{ + fn peek(&mut self) -> Option<&T> { + if self.peeked.is_none() { + self.peeked = self.items.next(); } - if self.items.is_empty() { - return None; + self.peeked.as_ref() + } + + fn next(&mut self) -> Option { + self.peek(); + let next_item = self.peeked.take()?; + self.advance_index(); + + #[cfg(not(unix))] + { + if self.num_logs_counter % 20 != 0 { + return Some(next_item); + } } - self.index += 1; - let percent = (100 * self.index) / self.length; + + let (index, length) = { + let locked = self.progress.lock().unwrap(); + (locked.length, locked.index) + }; + + let percent = std::cmp::min(100, (100 * index) / length); let completed: String = vec!['#'; percent].iter().collect(); let incomplete: String = vec!['.'; 100 - percent].iter().collect(); - display_line!(self.io, "\x1b[2A\x1b[J"); + + #[cfg(unix)] + { + clear_last_lines(self.io, 2); + } + match self.r#type { ProgressType::Fetch => 
display_line!( self.io, "Fetched block {:?} of {:?}", - self.items.last().unwrap(), - self.items[0] - ), - ProgressType::Scan => display_line!( - self.io, - "Scanning {} of {}", - self.index, - self.length + index, + length ), + ProgressType::Scan => { + display_line!(self.io, "Scanning {} of {}", index, length) + } } display!(self.io, "[{}{}] ~~ {} %", completed, incomplete, percent); + + #[cfg(not(unix))] + { + self.num_logs_counter += 1; + display_line!(self.io, "\n"); + } + self.io.flush(); - self.items.pop() + Some(next_item) + } +} + +#[cfg(unix)] +fn clear_last_lines(io: &IO, num_lines: usize) { + display_line!(io, "\x1b[{num_lines}A\x1b[J"); +} + +impl<'io, T, I, IO> Drop for LoggingIterator<'io, T, I, IO> +where + T: Debug, + I: Iterator, + IO: Io, +{ + fn drop(&mut self) { + display_line!(self.io, "\x1b[2A\x1b[J"); + } +} + +impl<'io, T, I, IO> Iterator for LoggingIterator<'io, T, I, IO> +where + T: Debug, + I: Iterator, + IO: Io, +{ + type Item = T; + + fn next(&mut self) -> Option { + >::next(self) } } /// A progress logger for the CLI #[derive(Debug, Clone)] -pub struct CliLogger<'io, IO: Io> { +pub struct CliProgressTracker<'io, IO: Io> { io: &'io IO, + fetch: Arc>, + scan: Arc>, } -impl<'io, IO: Io> CliLogger<'io, IO> { +impl<'io, IO: Io> CliProgressTracker<'io, IO> { pub fn new(io: &'io IO) -> Self { - Self { io } + Self { + io, + fetch: Arc::new(Mutex::new(IterProgress::default())), + scan: Arc::new(Mutex::new(IterProgress::default())), + } } } -impl<'io, IO: Io> ProgressLogger for CliLogger<'io, IO> { - type Fetch = CliLogging<'io, u64, IO>; - type Scan = CliLogging<'io, IndexedNoteEntry, IO>; - +impl<'io, IO: Io + Send + Sync> ProgressTracker + for CliProgressTracker<'io, IO> +{ fn io(&self) -> &IO { self.io } - fn fetch(&self, items: I) -> Self::Fetch + fn fetch(&self, items: I) -> impl PeekableIter where - I: IntoIterator, + I: Iterator, { - CliLogging::new(items, self.io, ProgressType::Fetch) + LoggingIterator::new( + items, + self.io, + ProgressType::Fetch, + self.fetch.clone(), + ) } - fn scan(&self, items: I) -> Self::Scan + fn scan(&self, items: I) -> impl Iterator + Send where - I: IntoIterator, + I: Iterator + Send, { - CliLogging::new(items, self.io, ProgressType::Scan) + { + let mut locked = self.scan.lock().unwrap(); + *locked = IterProgress::default(); + } + LoggingIterator::new( + items, + self.io, + ProgressType::Scan, + self.scan.clone(), + ) + } + + fn left_to_fetch(&self) -> usize { + let locked = self.fetch.lock().unwrap(); + if locked.index > locked.length { + 0 + } else { + locked.length - locked.index + } } } diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 622866bd66..949ca60193 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -102,6 +102,7 @@ ethbridge-bridge-contract.workspace = true ethers.workspace = true eyre.workspace = true fd-lock = { workspace = true, optional = true } +flume = "0.11.0" futures.workspace = true itertools.workspace = true jubjub = { workspace = true, optional = true } diff --git a/crates/sdk/src/control_flow/mod.rs b/crates/sdk/src/control_flow/mod.rs index 76327304b9..7d77529c3e 100644 --- a/crates/sdk/src/control_flow/mod.rs +++ b/crates/sdk/src/control_flow/mod.rs @@ -10,6 +10,7 @@ use std::task::{Context, Poll}; use futures::future::FutureExt; #[cfg(any(unix, windows))] use tokio::sync::oneshot; +use tokio::sync::oneshot::error::TryRecvError; /// A shutdown signal receiver. 
pub struct ShutdownSignal { @@ -19,6 +20,18 @@ pub struct ShutdownSignal { rx: oneshot::Receiver<()>, } +impl ShutdownSignal { + /// Checks if an interrupt signal was received. + #[cfg(any(unix, windows))] + pub fn received(&mut self) -> bool { + match self.rx.try_recv() { + Ok(_) => true, + Err(TryRecvError::Empty) => false, + Err(TryRecvError::Closed) => true, + } + } +} + #[cfg(any(unix, windows))] impl Future for ShutdownSignal { type Output = (); @@ -65,6 +78,13 @@ pub fn install_shutdown_signal() -> ShutdownSignal { } } +/// A manually triggerable shutdown signal used for testing +#[cfg(any(test, feature = "testing"))] +pub fn testing_shutdown_signal() -> (oneshot::Sender<()>, ShutdownSignal) { + let (tx, rx) = oneshot::channel(); + (tx, ShutdownSignal { rx }) +} + /// Shutdown signal receiver #[cfg(unix)] pub async fn shutdown_send(tx: oneshot::Sender<()>) { diff --git a/crates/sdk/src/error.rs b/crates/sdk/src/error.rs index f0c22adfca..fc93e2c94f 100644 --- a/crates/sdk/src/error.rs +++ b/crates/sdk/src/error.rs @@ -22,7 +22,7 @@ pub type Result = std::result::Result; pub enum Error { /// Key Retrieval Errors #[error("Key Error: {0}")] - KeyRetrival(#[from] storage::Error), + KeyRetrieval(#[from] storage::Error), /// Transaction Errors #[error("{0}")] Tx(#[from] TxSubmitError), @@ -44,6 +44,9 @@ pub enum Error { /// Any Other errors that are uncategorized #[error("{0}")] Other(String), + /// An interrupt was called + #[error("Process {0} received an interrupt signal")] + Interrupt(String), } /// Errors that deal with querying some kind of data diff --git a/crates/sdk/src/masp.rs b/crates/sdk/src/masp.rs index bcf28db17b..58c90eae0b 100644 --- a/crates/sdk/src/masp.rs +++ b/crates/sdk/src/masp.rs @@ -1,12 +1,18 @@ //! MASP verification wrappers. +#[cfg(test)] +mod test_utils; +pub mod utils; use std::cmp::Ordering; use std::collections::{btree_map, BTreeMap, BTreeSet}; use std::env; use std::fmt::Debug; +use std::io::{Read, Write}; use std::path::PathBuf; +use std::sync::{Arc, Mutex}; use borsh::{BorshDeserialize, BorshSerialize}; +use borsh_ext::BorshSerializeExt; use itertools::Itertools; use masp_primitives::asset_type::AssetType; #[cfg(feature = "mainnet")] @@ -66,13 +72,18 @@ use rand_core::{CryptoRng, OsRng, RngCore, SeedableRng}; use smooth_operator::checked; use thiserror::Error; +use crate::control_flow::ShutdownSignal; use crate::error::{Error, QueryError}; use crate::io::Io; +use crate::masp::utils::{ + fetch_channel, FetchQueueSender, MaspClient, ProgressTracker, RetryStrategy, +}; use crate::queries::Client; -use crate::rpc::{ - query_block, query_conversion, query_denom, query_native_token, +use crate::rpc::{query_block, query_conversion, query_denom}; +use crate::{ + control_flow, display_line, edisplay_line, query_native_token, rpc, + MaybeSend, MaybeSync, Namada, }; -use crate::{display_line, edisplay_line, rpc, MaybeSend, MaybeSync, Namada}; /// Randomness seed for MASP integration tests to build proofs with /// deterministic rng. @@ -349,34 +360,93 @@ pub type TransferDelta = HashMap; /// Represents the changes that were made to a list of shielded accounts pub type TransactionDelta = HashMap; + /// A cache of fetched indexed transactions. /// -/// The cache is designed so that it either contains -/// all transactions from a given height, or none. 
-#[derive(
-    BorshSerialize, BorshDeserialize, BorshDeserializer, Debug, Default, Clone,
-)]
+/// An invariant that shielded-sync maintains is that
+/// this cache either contains all transactions from
+/// a given height, or none.
+#[derive(Debug, Default, Clone)]
 pub struct Unscanned {
-    txs: IndexedNoteData,
+    txs: Arc<Mutex<IndexedNoteData>>,
+}
+
+impl BorshSerialize for Unscanned {
+    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
+        let locked = self.txs.lock().unwrap();
+        let bytes = locked.serialize_to_vec();
+        writer.write(&bytes).map(|_| ())
+    }
+}
+
+impl BorshDeserialize for Unscanned {
+    fn deserialize_reader<R: Read>(reader: &mut R) -> std::io::Result<Self> {
+        let unscanned = IndexedNoteData::deserialize_reader(reader)?;
+        Ok(Self {
+            txs: Arc::new(Mutex::new(unscanned)),
+        })
+    }
 }
 
 impl Unscanned {
-    fn extend<I>(&mut self, items: I)
+    /// Append elements to the cache from an iterator.
+    pub fn extend<I>(&self, items: I)
     where
         I: IntoIterator<Item = IndexedNoteEntry>,
     {
-        self.txs.extend(items);
+        let mut locked = self.txs.lock().unwrap();
+        locked.extend(items);
+    }
+
+    /// Add a single entry to the cache.
+    pub fn insert(&self, (k, v): IndexedNoteEntry) {
+        let mut locked = self.txs.lock().unwrap();
+        locked.insert(k, v);
     }
 
-    fn contains_height(&self, height: u64) -> bool {
-        self.txs.keys().any(|k| k.height.0 == height)
+    /// Check if this cache has already been populated for a given
+    /// block height.
+    pub fn contains_height(&self, height: u64) -> bool {
+        let locked = self.txs.lock().unwrap();
+        locked.keys().any(|k| k.height.0 == height)
     }
 
     /// We remove all indices from blocks that have been entirely scanned.
     /// If a block is only partially scanned, we leave all the events in the
     /// cache.
-    fn scanned(&mut self, ix: &IndexedTx) {
-        self.txs.retain(|i, _| i.height >= ix.height);
+    pub fn scanned(&self, ix: &IndexedTx) {
+        let mut locked = self.txs.lock().unwrap();
+        locked.retain(|i, _| i.height >= ix.height);
+    }
+
+    /// Gets the latest block height present in the cache
+    pub fn latest_height(&self) -> BlockHeight {
+        let txs = self.txs.lock().unwrap();
+        txs.keys()
+            .max_by_key(|ix| ix.height)
+            .map(|ix| ix.height)
+            .unwrap_or_default()
+    }
+
+    /// Gets the first block height present in the cache
+    pub fn first_height(&self) -> BlockHeight {
+        let txs = self.txs.lock().unwrap();
+        txs.keys()
+            .min_by_key(|ix| ix.height)
+            .map(|ix| ix.height)
+            .unwrap_or_default()
+    }
+
+    /// Remove the first entry from the cache and return it.
+ pub fn pop_first(&self) -> Option { + let mut locked = self.txs.lock().unwrap(); + locked.pop_first() + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + let locked = self.txs.lock().unwrap(); + locked.is_empty() } } @@ -385,10 +455,13 @@ impl IntoIterator for Unscanned { type Item = IndexedNoteEntry; fn into_iter(self) -> Self::IntoIter { - self.txs.into_iter() + let txs = { + let mut locked = self.txs.lock().unwrap(); + std::mem::take(&mut *locked) + }; + txs.into_iter() } } - #[derive(BorshSerialize, BorshDeserialize, Debug)] /// The possible sync states of the shielded context pub enum ContextSyncStatus { @@ -523,18 +596,61 @@ impl ShieldedContext { /// Fetch the current state of the multi-asset shielded pool into a /// ShieldedContext #[allow(clippy::too_many_arguments)] - pub async fn fetch( + #[cfg(not(target_family = "wasm"))] + pub async fn fetch<'client, C, IO, M>( &mut self, - client: &C, - logger: &impl ProgressLogger, + client: M, + progress: &impl ProgressTracker, start_query_height: Option, last_query_height: Option, + retry: RetryStrategy, // NOTE: do not remove this argument, it will be used once the indexer // is ready _batch_size: u64, sks: &[ExtendedSpendingKey], fvks: &[ViewingKey], - ) -> Result<(), Error> { + ) -> Result<(), Error> + where + C: Client + Sync, + IO: Io, + M: MaspClient<'client, C> + 'client, + { + let shutdown_signal = control_flow::install_shutdown_signal(); + self.fetch_aux( + client, + progress, + start_query_height, + last_query_height, + retry, + _batch_size, + sks, + fvks, + shutdown_signal, + ) + .await + } + + #[allow(clippy::too_many_arguments)] + #[cfg(not(target_family = "wasm"))] + async fn fetch_aux<'client, C, IO, M>( + &mut self, + client: M, + progress: &impl ProgressTracker, + start_query_height: Option, + last_query_height: Option, + retry: RetryStrategy, + // NOTE: do not remove this argument, it will be used once the indexer + // is ready + _batch_size: u64, + sks: &[ExtendedSpendingKey], + fvks: &[ViewingKey], + mut shutdown_signal: ShutdownSignal, + ) -> Result<(), Error> + where + C: Client + Sync, + IO: Io, + M: MaspClient<'client, C> + 'client, + { // add new viewing keys // Reload the state from file to get the last confirmed state and // discard any speculative data, we cannot fetch on top of a @@ -549,7 +665,6 @@ impl ShieldedContext { ..Default::default() }; } - for esk in sks { let vk = to_viewing_key(esk).vk; self.vk_heights.entry(vk).or_default(); @@ -557,179 +672,131 @@ impl ShieldedContext { for vk in fvks { self.vk_heights.entry(*vk).or_default(); } + + // Save the context to persist newly added keys let _ = self.save().await; - // the latest block height which has been added to the witness Merkle - // tree + + // the height of the key that is least synced let Some(least_idx) = self.vk_heights.values().min().cloned() else { return Ok(()); }; + // the latest block height which has been added to the witness Merkle + // tree let last_witnessed_tx = self.tx_note_map.keys().max().cloned(); // get the bounds on the block heights to fetch - let start_idx = + let start_height = std::cmp::min(last_witnessed_tx.as_ref(), least_idx.as_ref()) .map(|ix| ix.height); - let start_idx = start_query_height.or(start_idx); - // Load all transactions accepted until this point - // N.B. the cache is a hash map - self.unscanned.extend( - self.fetch_shielded_transfers( - client, - logger, - start_idx, - last_query_height, - ) - .await?, - ); - // persist the cache in case of interruptions. 
- let _ = self.save().await; - - let txs = logger.scan(self.unscanned.clone()); - for (ref indexed_tx, ref stx) in txs { - if Some(indexed_tx) > last_witnessed_tx.as_ref() { - self.update_witness_map(indexed_tx.to_owned(), stx)?; + let start_height = start_query_height.or(start_height); + // Query for the last produced block height + let last_block_height = query_block(client.rpc_client()) + .await? + .map(|b| b.height) + .unwrap_or_else(BlockHeight::first); + let last_query_height = last_query_height.unwrap_or(last_block_height); + let last_query_height = + std::cmp::min(last_query_height, last_block_height); + + for _ in retry { + // a stateful channel that communicates notes fetched to the trial + // decryption process + let (fetch_send, fetch_recv) = + fetch_channel::new(self.unscanned.clone()); + let fetch_res = self + .fetch_shielded_transfers( + &client, + progress, + &mut shutdown_signal, + fetch_send, + start_height, + last_query_height, + ) + .await; + // if fetching errored, log it. But this is recoverable. + match fetch_res { + Err(e @ Error::Interrupt(_)) => { + display_line!(progress.io(), "{}", e.to_string(),); + return Err(e); + } + Err(e) => display_line!( + progress.io(), + "Error encountered while fetching: {}", + e.to_string(), + ), + _ => {} } - let mut vk_heights = BTreeMap::new(); - std::mem::swap(&mut vk_heights, &mut self.vk_heights); - for (vk, h) in vk_heights - .iter_mut() - .filter(|(_vk, h)| h.as_ref() < Some(indexed_tx)) - { - self.scan_tx(indexed_tx.to_owned(), stx, vk)?; - *h = Some(indexed_tx.to_owned()); + let txs = progress.scan(fetch_recv); + for (ref indexed_tx, ref stx) in txs { + if Some(indexed_tx) > last_witnessed_tx.as_ref() { + self.update_witness_map(indexed_tx.to_owned(), stx)?; + } + let mut vk_heights = BTreeMap::new(); + std::mem::swap(&mut vk_heights, &mut self.vk_heights); + for (vk, h) in vk_heights + .iter_mut() + .filter(|(_vk, h)| h.as_ref() < Some(indexed_tx)) + { + self.scan_tx(indexed_tx.to_owned(), stx, vk)?; + *h = Some(indexed_tx.to_owned()); + } + // possibly remove unneeded elements from the cache. + self.unscanned.scanned(indexed_tx); + std::mem::swap(&mut vk_heights, &mut self.vk_heights); + if shutdown_signal.received() { + let _ = self.save().await; + return Err(Error::Interrupt( + "[ShieldedSync::Scanning]".to_string(), + )); + } + } + // if fetching failed for before completing, we restart + // the fetch process. Otherwise, we can break the loop. + if progress.left_to_fetch() == 0 { + break; } - // possibly remove unneeded elements from the cache. - self.unscanned.scanned(indexed_tx); - std::mem::swap(&mut vk_heights, &mut self.vk_heights); - let _ = self.save().await; } + _ = self.save().await; - Ok(()) + if progress.left_to_fetch() != 0 { + Err(Error::Other( + "After retrying, could not fetch all MASP txs.".to_string(), + )) + } else { + Ok(()) + } } /// Obtain a chronologically-ordered list of all accepted shielded /// transactions from a node. - pub async fn fetch_shielded_transfers( + async fn fetch_shielded_transfers< + 'client, + C: Client + Sync, + IO: Io, + M: MaspClient<'client, C> + 'client, + >( &self, - client: &C, - logger: &impl ProgressLogger, + client: &M, + progress: &impl ProgressTracker, + shutdown_signal: &mut ShutdownSignal, + block_sender: FetchQueueSender, last_indexed_tx: Option, - last_query_height: Option, - ) -> Result { - // Query for the last produced block height - let last_block_height = query_block(client) - .await? 
- .map_or_else(BlockHeight::first, |block| block.height); - let last_query_height = last_query_height.unwrap_or(last_block_height); - - let mut shielded_txs = BTreeMap::new(); + last_query_height: BlockHeight, + ) -> Result<(), Error> { // Fetch all the transactions we do not have yet let first_height_to_query = last_indexed_tx.map_or_else(|| 1, |last| last.0); - let heights = logger.fetch(first_height_to_query..=last_query_height.0); - for height in heights { - if self.unscanned.contains_height(height) { - continue; - } - - let txs_results = match get_indexed_masp_events_at_height( - client, - height.into(), - None, + let res = client + .fetch_shielded_transfers( + progress, + shutdown_signal, + block_sender, + first_height_to_query, + last_query_height.0, ) - .await? - { - Some(events) => events, - None => continue, - }; - - // Query the actual block to get the txs bytes. If we only need one - // tx it might be slightly better to query the /tx endpoint to - // reduce the amount of data sent over the network, but this is a - // minimal improvement and it's even hard to tell how many times - // we'd need a single masp tx to make this worth it - let block = client - .block(height) - .await - .map_err(|e| Error::from(QueryError::General(e.to_string())))? - .block - .data; - - for (idx, masp_sections_refs) in txs_results { - let tx = Tx::try_from(block[idx.0 as usize].as_ref()) - .map_err(|e| Error::Other(e.to_string()))?; - let extracted_masp_txs = - if let Some(masp_sections_refs) = masp_sections_refs { - Self::extract_masp_tx(&tx, &masp_sections_refs).await? - } else { - Self::extract_masp_tx_from_ibc_message(&tx)? - }; - // Collect the current transactions - shielded_txs.insert( - IndexedTx { - height: height.into(), - index: idx, - }, - extracted_masp_txs, - ); - } - } - - Ok(shielded_txs) - } - - /// Extract the relevant shield portions of a [`Tx`], if any. - async fn extract_masp_tx( - tx: &Tx, - masp_section_refs: &MaspTxRefs, - ) -> Result, Error> { - // NOTE: simply looking for masp sections attached to the tx - // is not safe. We don't validate the sections attached to a - // transaction se we could end up with transactions carrying - // an unnecessary masp section. We must instead look for the - // required masp sections coming from the events - - masp_section_refs - .0 - .iter() - .try_fold(vec![], |mut acc, hash| { - match tx.get_masp_section(hash).cloned().ok_or_else(|| { - Error::Other( - "Missing expected masp transaction".to_string(), - ) - }) { - Ok(transaction) => { - acc.push(transaction); - Ok(acc) - } - Err(e) => Err(e), - } - }) - } - - /// Extract the relevant shield portions from the IBC messages in [`Tx`] - fn extract_masp_tx_from_ibc_message( - tx: &Tx, - ) -> Result, Error> { - let mut masp_txs = Vec::new(); - for cmt in &tx.header.batch { - let tx_data = tx.data(cmt).ok_or_else(|| { - Error::Other("Missing transaction data".to_string()) - })?; - let ibc_msg = decode_message(&tx_data) - .map_err(|_| Error::Other("Invalid IBC message".to_string()))?; - if let IbcMessage::Envelope(ref envelope) = ibc_msg { - if let Some(masp_tx) = extract_masp_tx_from_envelope(envelope) { - masp_txs.push(masp_tx); - } - } - } - if !masp_txs.is_empty() { - Ok(masp_txs) - } else { - Err(Error::Other( - "IBC message doesn't have masp transaction".to_string(), - )) - } + .await; + // persist fetched notes + _ = self.save().await; + res } /// Applies the given transaction to the supplied context. 
More precisely, @@ -746,6 +813,13 @@ impl ShieldedContext { shielded: &[Transaction], vk: &ViewingKey, ) -> Result<(), Error> { + type Proof = OutputDescription< + < + ::SaplingAuth + as masp_primitives::transaction::components::sapling::Authorization + >::Proof + >; + // For tracking the account changes caused by this Transaction let mut transaction_delta = TransactionDelta::new(); if let ContextSyncStatus::Confirmed = self.sync_status { @@ -759,12 +833,12 @@ impl ShieldedContext { // Let's try to see if this viewing key can decrypt latest // note let notes = self.pos_map.entry(*vk).or_default(); - let decres = try_sapling_note_decryption::<_, OutputDescription<<::SaplingAuth as masp_primitives::transaction::components::sapling::Authorization>::Proof>>( - &NETWORK, - 1.into(), - &PreparedIncomingViewingKey::new(&vk.ivk()), - so, - ); + let decres = try_sapling_note_decryption::<_, Proof>( + &NETWORK, + 1.into(), + &PreparedIncomingViewingKey::new(&vk.ivk()), + so, + ); // So this current viewing key does decrypt this current // note... if let Some((note, pa, memo)) = decres { @@ -2356,6 +2430,59 @@ impl ShieldedContext { } } +/// Extract the relevant shield portions of a [`Tx`], if any. +async fn extract_masp_tx( + tx: &Tx, + masp_section_refs: &MaspTxRefs, +) -> Result, Error> { + // NOTE: simply looking for masp sections attached to the tx + // is not safe. We don't validate the sections attached to a + // transaction se we could end up with transactions carrying + // an unnecessary masp section. We must instead look for the + // required masp sections coming from the events + + masp_section_refs + .0 + .iter() + .try_fold(vec![], |mut acc, hash| { + match tx.get_masp_section(hash).cloned().ok_or_else(|| { + Error::Other("Missing expected masp transaction".to_string()) + }) { + Ok(transaction) => { + acc.push(transaction); + Ok(acc) + } + Err(e) => Err(e), + } + }) +} + +/// Extract the relevant shield portions from the IBC messages in [`Tx`] +fn extract_masp_tx_from_ibc_message( + tx: &Tx, +) -> Result, Error> { + let mut masp_txs = Vec::new(); + for cmt in &tx.header.batch { + let tx_data = tx.data(cmt).ok_or_else(|| { + Error::Other("Missing transaction data".to_string()) + })?; + let ibc_msg = decode_message(&tx_data) + .map_err(|_| Error::Other("Invalid IBC message".to_string()))?; + if let IbcMessage::Envelope(ref envelope) = ibc_msg { + if let Some(masp_tx) = extract_masp_tx_from_envelope(envelope) { + masp_txs.push(masp_tx); + } + } + } + if !masp_txs.is_empty() { + Ok(masp_txs) + } else { + Err(Error::Other( + "IBC message doesn't have masp transaction".to_string(), + )) + } +} + // Retrieves all the indexes at the specified height which refer // to a valid masp transaction. If an index is given, it filters only the // transactions with an index equal or greater to the provided one. @@ -3434,66 +3561,499 @@ pub mod fs { } } -/// A enum to indicate how to log sync progress depending on -/// whether sync is currently fetch or scanning blocks. -#[derive(Debug, Copy, Clone)] -pub enum ProgressType { - /// Fetch - Fetch, - /// Scan - Scan, -} - -#[allow(missing_docs)] -pub trait ProgressLogger { - type Fetch: Iterator; - type Scan: Iterator; - - fn io(&self) -> &IO; - - fn fetch(&self, items: I) -> Self::Fetch - where - I: IntoIterator; - - fn scan(&self, items: I) -> Self::Scan - where - I: IntoIterator; -} - -/// The default type for logging sync progress. 
-#[derive(Debug, Clone)] -pub struct DefaultLogger<'io, IO: Io> { - io: &'io IO, -} - -impl<'io, IO: Io> DefaultLogger<'io, IO> { - /// Initialize default logger - pub fn new(io: &'io IO) -> Self { - Self { io } - } -} - -impl<'io, IO: Io> ProgressLogger for DefaultLogger<'io, IO> { - type Fetch = as IntoIterator>::IntoIter; - type Scan = as IntoIterator>::IntoIter; +#[cfg(test)] +mod test_shielded_sync { + use core::str::FromStr; + use std::collections::BTreeSet; + + use borsh::BorshDeserialize; + use masp_primitives::transaction::Transaction; + use masp_primitives::zip32::ExtendedFullViewingKey; + use namada_core::masp::ExtendedViewingKey; + use namada_core::storage::TxIndex; + use namada_tx::IndexedTx; + use tempfile::tempdir; + + use crate::control_flow::testing_shutdown_signal; + use crate::error::Error; + use crate::io::StdIo; + use crate::masp::fs::FsShieldedUtils; + use crate::masp::test_utils::{ + test_client, TestUnscannedTracker, TestingMaspClient, + }; + use crate::masp::utils::{DefaultTracker, ProgressTracker, RetryStrategy}; + + // A viewing key derived from A_SPENDING_KEY + pub const AA_VIEWING_KEY: &str = "zvknam1qqqqqqqqqqqqqq9v0sls5r5de7njx8ehu49pqgmqr9ygelg87l5x8y4s9r0pjlvu6x74w9gjpw856zcu826qesdre628y6tjc26uhgj6d9zqur9l5u3p99d9ggc74ald6s8y3sdtka74qmheyqvdrasqpwyv2fsmxlz57lj4grm2pthzj3sflxc0jx0edrakx3vdcngrfjmru8ywkguru8mxss2uuqxdlglaz6undx5h8w7g70t2es850g48xzdkqay5qs0yw06rtxcpjdve6"; + + /// A serialized transaction that will work for testing. + /// Would love to do this in a less opaque fashion, but + /// making these things is a misery not worth my time. + /// + /// This a tx sending 1 BTC from Albert to Albert's PA, + /// that was extracted from a masp integration test. + /// + /// ```ignore + /// vec![ + /// "shield", + /// "--source", + /// ALBERT, + /// "--target", + /// AA_PAYMENT_ADDRESS, + /// "--token", + /// BTC, + /// "--amount", + /// "1", + /// "--node", + /// validator_one_rpc, + /// ] + /// ``` + fn arbitrary_masp_tx() -> Transaction { + Transaction::try_from_slice(&[ + 2, 0, 0, 0, 10, 39, 167, 38, 166, 117, 255, 233, 0, 0, 0, 0, 255, + 255, 255, 255, 1, 162, 120, 217, 193, 173, 117, 92, 126, 107, 199, + 182, 72, 95, 60, 122, 52, 9, 134, 72, 4, 167, 41, 187, 171, 17, + 124, 114, 84, 191, 75, 37, 2, 0, 225, 245, 5, 0, 0, 0, 0, 93, 213, + 181, 21, 38, 32, 230, 52, 155, 4, 203, 26, 70, 63, 59, 179, 142, 7, + 72, 76, 0, 0, 0, 1, 132, 100, 41, 23, 128, 97, 116, 40, 195, 40, + 46, 55, 79, 106, 234, 32, 4, 216, 106, 88, 173, 65, 140, 99, 239, + 71, 103, 201, 111, 149, 166, 13, 73, 224, 253, 98, 27, 199, 11, + 142, 56, 214, 4, 96, 35, 72, 83, 86, 194, 107, 163, 194, 238, 37, + 19, 171, 8, 129, 53, 246, 64, 220, 155, 47, 177, 165, 109, 232, 84, + 247, 128, 184, 40, 26, 113, 196, 190, 181, 57, 213, 45, 144, 46, + 12, 145, 128, 169, 116, 65, 51, 208, 239, 50, 217, 224, 98, 179, + 53, 18, 130, 183, 114, 225, 21, 34, 175, 144, 125, 239, 240, 82, + 100, 174, 1, 192, 32, 187, 208, 205, 31, 108, 59, 87, 201, 148, + 214, 244, 255, 8, 150, 100, 225, 11, 245, 221, 170, 85, 241, 110, + 50, 90, 151, 210, 169, 41, 3, 23, 160, 196, 117, 211, 217, 121, 9, + 42, 236, 19, 149, 94, 62, 163, 222, 172, 128, 197, 56, 100, 233, + 227, 239, 60, 182, 191, 55, 148, 17, 0, 168, 198, 84, 87, 191, 89, + 229, 9, 129, 165, 98, 200, 127, 225, 192, 58, 0, 92, 104, 97, 26, + 125, 169, 209, 40, 170, 29, 93, 16, 114, 174, 23, 233, 218, 112, + 26, 175, 196, 198, 197, 159, 167, 157, 16, 232, 247, 193, 44, 82, + 143, 238, 179, 77, 87, 153, 3, 33, 207, 215, 142, 104, 179, 17, + 252, 148, 215, 150, 76, 
56, 169, 13, 240, 4, 195, 221, 45, 250, 24, + 51, 243, 174, 176, 47, 117, 38, 1, 124, 193, 191, 55, 11, 164, 97, + 83, 188, 92, 202, 229, 106, 236, 165, 85, 236, 95, 255, 28, 71, 18, + 173, 202, 47, 63, 226, 129, 203, 154, 54, 155, 177, 161, 106, 210, + 220, 193, 142, 44, 105, 46, 164, 83, 136, 63, 24, 172, 157, 117, 9, + 202, 99, 223, 144, 36, 26, 154, 84, 175, 119, 12, 102, 71, 33, 14, + 131, 250, 86, 215, 153, 18, 94, 213, 61, 196, 67, 132, 204, 89, + 235, 241, 188, 147, 236, 92, 46, 83, 169, 236, 12, 34, 33, 65, 243, + 18, 23, 29, 41, 252, 207, 17, 196, 55, 56, 141, 158, 116, 227, 195, + 159, 233, 72, 26, 69, 72, 213, 50, 101, 161, 127, 213, 35, 210, + 223, 201, 219, 198, 192, 125, 129, 222, 178, 241, 116, 59, 255, 72, + 163, 46, 21, 222, 74, 202, 117, 217, 22, 188, 203, 2, 150, 38, 78, + 78, 250, 45, 36, 225, 240, 227, 115, 33, 114, 189, 25, 9, 219, 239, + 57, 103, 19, 109, 11, 5, 156, 43, 35, 53, 219, 250, 215, 185, 173, + 11, 101, 221, 29, 130, 74, 110, 225, 183, 77, 13, 52, 90, 183, 93, + 212, 175, 132, 21, 229, 109, 188, 124, 103, 3, 39, 174, 140, 115, + 67, 49, 100, 231, 129, 32, 24, 201, 196, 247, 33, 155, 20, 139, 34, + 3, 183, 12, 164, 6, 10, 219, 207, 151, 160, 4, 201, 160, 12, 156, + 82, 142, 226, 19, 134, 144, 53, 220, 140, 61, 74, 151, 129, 102, + 214, 73, 107, 147, 4, 98, 68, 79, 225, 103, 242, 187, 170, 102, + 225, 114, 4, 87, 96, 7, 212, 150, 127, 211, 158, 54, 86, 15, 191, + 21, 116, 202, 195, 60, 65, 134, 22, 2, 44, 133, 64, 181, 121, 66, + 218, 227, 72, 148, 63, 108, 227, 33, 66, 239, 77, 127, 139, 31, 16, + 150, 119, 198, 119, 229, 88, 188, 113, 80, 222, 86, 122, 181, 142, + 186, 130, 125, 236, 166, 95, 134, 243, 128, 65, 169, 33, 65, 73, + 182, 183, 156, 248, 39, 46, 199, 181, 85, 96, 126, 155, 189, 10, + 211, 145, 230, 94, 69, 232, 74, 87, 211, 46, 216, 30, 24, 38, 104, + 192, 165, 28, 73, 36, 227, 194, 41, 168, 5, 181, 176, 112, 67, 92, + 158, 212, 129, 207, 182, 223, 59, 185, 84, 210, 147, 32, 29, 61, + 56, 185, 21, 156, 114, 34, 115, 29, 25, 89, 152, 56, 55, 238, 43, + 0, 114, 89, 79, 95, 104, 143, 180, 51, 53, 108, 223, 236, 59, 47, + 188, 174, 196, 101, 180, 207, 162, 198, 104, 52, 67, 132, 178, 9, + 40, 10, 88, 206, 25, 132, 60, 136, 13, 213, 223, 81, 196, 131, 118, + 15, 53, 125, 165, 177, 170, 170, 17, 94, 53, 151, 51, 16, 170, 23, + 118, 255, 26, 46, 47, 37, 73, 165, 26, 43, 10, 221, 4, 132, 15, 78, + 214, 161, 3, 220, 10, 87, 139, 85, 61, 39, 131, 242, 216, 235, 52, + 93, 46, 180, 196, 151, 54, 207, 80, 223, 90, 252, 77, 10, 122, 175, + 229, 7, 144, 41, 1, 162, 120, 217, 193, 173, 117, 92, 126, 107, + 199, 182, 72, 95, 60, 122, 52, 9, 134, 72, 4, 167, 41, 187, 171, + 17, 124, 114, 84, 191, 75, 37, 2, 0, 31, 10, 250, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 151, 241, 211, 167, + 49, 151, 215, 148, 38, 149, 99, 140, 79, 169, 172, 15, 195, 104, + 140, 79, 151, 116, 185, 5, 161, 78, 58, 63, 23, 27, 172, 88, 108, + 85, 232, 63, 249, 122, 26, 239, 251, 58, 240, 10, 219, 34, 198, + 187, 147, 224, 43, 96, 82, 113, 159, 96, 125, 172, 211, 160, 136, + 39, 79, 101, 89, 107, 208, 208, 153, 32, 182, 26, 181, 218, 97, + 187, 220, 127, 80, 73, 51, 76, 241, 18, 19, 148, 93, 87, 229, 172, + 125, 5, 93, 4, 43, 126, 2, 74, 162, 178, 240, 143, 10, 145, 38, 8, + 5, 39, 45, 197, 16, 81, 198, 228, 122, 212, 250, 64, 59, 2, 180, + 81, 11, 100, 122, 227, 209, 119, 11, 172, 3, 38, 168, 5, 187, 239, + 212, 128, 86, 200, 193, 33, 189, 184, 151, 241, 211, 167, 49, 151, + 215, 148, 38, 149, 99, 140, 79, 169, 172, 15, 195, 104, 140, 79, + 151, 116, 185, 5, 161, 
78, 58, 63, 23, 27, 172, 88, 108, 85, 232, + 63, 249, 122, 26, 239, 251, 58, 240, 10, 219, 34, 198, 187, 37, + 197, 248, 90, 113, 62, 149, 117, 145, 118, 42, 241, 60, 208, 83, + 57, 96, 143, 17, 128, 92, 118, 158, 188, 77, 37, 184, 164, 135, + 246, 196, 57, 198, 106, 139, 33, 15, 207, 0, 101, 143, 92, 178, + 132, 19, 106, 221, 246, 176, 100, 20, 114, 26, 55, 163, 14, 173, + 255, 121, 181, 58, 121, 140, 3, + ]) + .expect("Test failed") + } + + /// Test that if fetching fails before finishing, + /// we re-establish the fetching process + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_retry_fetch() { + let temp_dir = tempdir().unwrap(); + let mut shielded_ctx = + FsShieldedUtils::new(temp_dir.path().to_path_buf()); + let (client, masp_tx_sender) = test_client(2.into()); + let io = StdIo; + let progress = DefaultTracker::new(&io); + let vk = ExtendedFullViewingKey::from( + ExtendedViewingKey::from_str(AA_VIEWING_KEY).expect("Test failed"), + ) + .fvk + .vk; + masp_tx_sender.send(None).expect("Test failed"); + + // we first test that with no retries, a fetching failure + // stops process + let result = shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(1), + 0, + &[], + &[vk], + ) + .await + .unwrap_err(); + match result { + Error::Other(msg) => assert_eq!( + msg.as_str(), + "After retrying, could not fetch all MASP txs." + ), + other => panic!("{:?} does not match Error::Other(_)", other), + } - fn io(&self) -> &IO { - self.io - } + // We now have a fetch failure followed by two successful + // masp txs from the same block. + let masp_tx = arbitrary_masp_tx(); + masp_tx_sender.send(None).expect("Test failed"); + masp_tx_sender + .send(Some(( + IndexedTx { + height: 1.into(), + index: TxIndex(1), + }, + vec![masp_tx.clone()], + ))) + .expect("Test failed"); + masp_tx_sender + .send(Some(( + IndexedTx { + height: 1.into(), + index: TxIndex(2), + }, + vec![masp_tx.clone()], + ))) + .expect("Test failed"); + + // This should complete successfully + shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(2), + 0, + &[], + &[vk], + ) + .await + .expect("Test failed"); + + shielded_ctx.load_confirmed().await.expect("Test failed"); + let keys = shielded_ctx + .tx_note_map + .keys() + .cloned() + .collect::>(); + let expected = BTreeSet::from([ + IndexedTx { + height: 1.into(), + index: TxIndex(1), + }, + IndexedTx { + height: 1.into(), + index: TxIndex(2), + }, + ]); + + assert_eq!(keys, expected); + assert_eq!( + *shielded_ctx.vk_heights[&vk].as_ref().unwrap(), + IndexedTx { + height: 1.into(), + index: TxIndex(2), + } + ); + assert_eq!(shielded_ctx.note_map.len(), 2); + } + + /// Test that the progress tracker correctly keeps + /// track of how many blocks there are left to fetch + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_left_to_fetch() { + let temp_dir = tempdir().unwrap(); + let mut shielded_ctx = + FsShieldedUtils::new(temp_dir.path().to_path_buf()); + let (client, masp_tx_sender) = test_client(2.into()); + let io = StdIo; + let progress = DefaultTracker::new(&io); + let vk = ExtendedFullViewingKey::from( + ExtendedViewingKey::from_str(AA_VIEWING_KEY).expect("Test failed"), + ) + .fvk + .vk; + let masp_tx = arbitrary_masp_tx(); + + // first fetch no blocks + masp_tx_sender.send(None).expect("Test failed"); + shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(1), + 
0, + &[], + &[vk], + ) + .await + .unwrap_err(); + assert_eq!(progress.left_to_fetch(), 2); - fn fetch(&self, items: I) -> Self::Fetch - where - I: IntoIterator, - { - let items: Vec<_> = items.into_iter().collect(); - items.into_iter() - } + // fetch one of the two blocks + masp_tx_sender + .send(Some(( + IndexedTx { + height: 1.into(), + index: Default::default(), + }, + vec![masp_tx.clone()], + ))) + .expect("Test failed"); + masp_tx_sender.send(None).expect("Test failed"); + shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(1), + 0, + &[], + &[vk], + ) + .await + .unwrap_err(); + assert_eq!(progress.left_to_fetch(), 1); + + // fetch no blocks + masp_tx_sender.send(None).expect("Test failed"); + shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(1), + 0, + &[], + &[vk], + ) + .await + .unwrap_err(); + assert_eq!(progress.left_to_fetch(), 1); + + // fetch no blocks, but increase the latest block height + // thus the amount left to fetch should increase + let (client, masp_tx_sender) = test_client(3.into()); + masp_tx_sender.send(None).expect("Test failed"); + shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(1), + 0, + &[], + &[vk], + ) + .await + .unwrap_err(); + assert_eq!(progress.left_to_fetch(), 2); - fn scan(&self, items: I) -> Self::Scan - where - I: IntoIterator, - { - let items: Vec<_> = items.into_iter().collect(); - items.into_iter() + // fetch remaining block + masp_tx_sender + .send(Some(( + IndexedTx { + height: 2.into(), + index: Default::default(), + }, + vec![masp_tx.clone()], + ))) + .expect("Test failed"); + masp_tx_sender + .send(Some(( + IndexedTx { + height: 3.into(), + index: Default::default(), + }, + vec![masp_tx.clone()], + ))) + .expect("Test failed"); + // this should not produce an error since we have fetched + // all expected blocks + masp_tx_sender.send(None).expect("Test failed"); + shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(1), + 0, + &[], + &[vk], + ) + .await + .expect("Test failed"); + assert_eq!(progress.left_to_fetch(), 0); + } + + /// Test that if we don't scan all fetched notes, they + /// are persisted in a cache + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_unscanned_cache() { + let (client, masp_tx_sender) = test_client(2.into()); + let temp_dir = tempdir().unwrap(); + let mut shielded_ctx = + FsShieldedUtils::new(temp_dir.path().to_path_buf()); + + let io = StdIo; + let progress = TestUnscannedTracker::new(&io); + let vk = ExtendedFullViewingKey::from( + ExtendedViewingKey::from_str(AA_VIEWING_KEY).expect("Test failed"), + ) + .fvk + .vk; + + // the fetched txs + let masp_tx = arbitrary_masp_tx(); + masp_tx_sender + .send(Some(( + IndexedTx { + height: 1.into(), + index: TxIndex(1), + }, + vec![masp_tx.clone()], + ))) + .expect("Test failed"); + masp_tx_sender + .send(Some(( + IndexedTx { + height: 1.into(), + index: TxIndex(2), + }, + vec![masp_tx.clone()], + ))) + .expect("Test failed"); + + shielded_ctx + .fetch( + TestingMaspClient::new(&client), + &progress, + None, + None, + RetryStrategy::Times(2), + 0, + &[], + &[vk], + ) + .await + .expect("Test failed"); + + shielded_ctx.load_confirmed().await.expect("Test failed"); + let keys = shielded_ctx + .unscanned + .txs + .lock() + .unwrap() + .keys() + .cloned() + .collect::>(); + let expected = vec![IndexedTx { + 
height: 1.into(),
+            index: TxIndex(2),
+        }];
+        assert_eq!(keys, expected);
+    }
+
+    /// Test that if fetching gets interrupted,
+    /// we persist the fetched notes in a cache
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn test_fetching_interrupt() {
+        let temp_dir = tempdir().unwrap();
+        let mut shielded_ctx =
+            FsShieldedUtils::new(temp_dir.path().to_path_buf());
+        let (client, masp_tx_sender) = test_client(2.into());
+        let io = StdIo;
+        let progress = DefaultTracker::new(&io);
+        let vk = ExtendedFullViewingKey::from(
+            ExtendedViewingKey::from_str(AA_VIEWING_KEY).expect("Test failed"),
+        )
+        .fvk
+        .vk;
+        let (shutdown_send, shutdown_signal) = testing_shutdown_signal();
+        // the fetched txs
+        let masp_tx = arbitrary_masp_tx();
+        // mock that we have already fetched a note
+        let expected = (
+            IndexedTx {
+                height: 1.into(),
+                index: TxIndex(1),
+            },
+            vec![masp_tx],
+        );
+        masp_tx_sender
+            .send(Some(expected.clone()))
+            .expect("Test failed");
+        shutdown_send.send(()).expect("Test failed");
+        let Error::Interrupt(ref proc) = shielded_ctx
+            .fetch_aux(
+                TestingMaspClient::new(&client),
+                &progress,
+                None,
+                None,
+                RetryStrategy::Forever,
+                0,
+                &[],
+                &[vk],
+                shutdown_signal,
+            )
+            .await
+            .expect_err("Test failed")
+        else {
+            panic!("Test failed")
+        };
+        assert_eq!(proc, "[Testing::Fetch]");
+        shielded_ctx.load_confirmed().await.expect("Test failed");
+        let entry = shielded_ctx.unscanned.pop_first().expect("Test failed");
+        assert_eq!(entry, expected);
+        assert!(shielded_ctx.unscanned.is_empty());
     }
 }
diff --git a/crates/sdk/src/masp/test_utils.rs b/crates/sdk/src/masp/test_utils.rs
new file mode 100644
index 0000000000..bc21329793
--- /dev/null
+++ b/crates/sdk/src/masp/test_utils.rs
@@ -0,0 +1,214 @@
+use std::ops::{Deref, DerefMut};
+use std::sync::{Arc, Mutex};
+
+use namada_core::storage::BlockHeight;
+use namada_state::LastBlock;
+use tendermint_rpc::SimpleRequest;
+
+use crate::control_flow::ShutdownSignal;
+use crate::error::Error;
+use crate::io::Io;
+use crate::masp::utils::{
+    FetchQueueSender, IterProgress, MaspClient, PeekableIter, ProgressTracker,
+};
+use crate::masp::IndexedNoteEntry;
+use crate::queries::testing::TestClient;
+use crate::queries::{Client, EncodedResponseQuery, Rpc, RPC};
+
+/// A client for testing the shielded-sync functionality
+pub struct TestingClient {
+    /// An actual mocked client for querying
+    inner: TestClient<Rpc>,
+    /// Used to inject a channel that we control into
+    /// the fetch algorithm. The option is to mock connection
+    /// failures.
+    next_masp_txs: flume::Receiver<Option<IndexedNoteEntry>>,
+}
+
+impl Deref for TestingClient {
+    type Target = TestClient<Rpc>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+impl DerefMut for TestingClient {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.inner
+    }
+}
+
+#[cfg(any(test, feature = "async-client"))]
+#[cfg_attr(feature = "async-send", async_trait::async_trait)]
+#[cfg_attr(not(feature = "async-send"), async_trait::async_trait(?Send))]
+impl Client for TestingClient {
+    type Error = std::io::Error;
+
+    async fn request(
+        &self,
+        path: String,
+        data: Option<Vec<u8>>,
+        height: Option<BlockHeight>,
+        prove: bool,
+    ) -> Result<EncodedResponseQuery, Self::Error> {
+        self.inner.request(path, data, height, prove).await
+    }
+
+    async fn perform<R>(
+        &self,
+        request: R,
+    ) -> Result
+    where
+        R: SimpleRequest,
+    {
+        self.inner.perform(request).await
+    }
+}
+
+/// Create a test client for unit testing as well
+/// as a channel for communicating with it.
+pub fn test_client( + last_height: BlockHeight, +) -> (TestingClient, flume::Sender>) { + let (sender, recv) = flume::unbounded(); + let mut client = TestClient::new(RPC); + client.state.in_mem_mut().last_block = Some(LastBlock { + height: last_height, + time: Default::default(), + }); + ( + TestingClient { + inner: client, + next_masp_txs: recv, + }, + sender, + ) +} + +/// A client for unit tests. It "fetches" a new note +/// when a channel controlled by the unit test sends +/// it one. +#[derive(Clone)] +pub struct TestingMaspClient<'a> { + client: &'a TestingClient, +} + +impl<'client> TestingMaspClient<'client> { + /// Create a new [`TestingMaspClient`] given an rpc client + /// [`TestingClient`]. + pub const fn new(client: &'client TestingClient) -> Self { + Self { client } + } +} + +impl<'a> MaspClient<'a, TestingClient> for TestingMaspClient<'a> { + fn rpc_client(&self) -> &TestingClient { + self.client + } + + async fn fetch_shielded_transfers( + &self, + progress: &impl ProgressTracker, + shutdown_signal: &mut ShutdownSignal, + mut tx_sender: FetchQueueSender, + from: u64, + to: u64, + ) -> Result<(), Error> { + // N.B. this assumes one masp tx per block + let mut fetch_iter = progress.fetch(from..=to); + + while fetch_iter.peek().is_some() { + let next_tx = self + .client + .next_masp_txs + .recv() + .expect("Test failed") + .ok_or_else(|| { + Error::Other( + "Connection to fetch MASP txs failed".to_string(), + ) + })?; + tx_sender.send(next_tx); + if shutdown_signal.received() { + return Err(Error::Interrupt("[Testing::Fetch]".to_string())); + } + fetch_iter.next(); + } + Ok(()) + } +} + +/// An iterator that yields its first element only +struct YieldOnceIterator { + first: Option, +} + +impl YieldOnceIterator { + fn new(mut iter: T) -> Self + where + T: Iterator, + { + let first = iter.next(); + Self { first } + } +} + +impl Iterator for YieldOnceIterator { + type Item = IndexedNoteEntry; + + fn next(&mut self) -> Option { + self.first.take() + } +} + +/// A progress tracker that only scans the first fetched +/// block. The rest are left in the unscanned cache +/// for the purposes of testing the persistence of +/// this cache. +pub(super) struct TestUnscannedTracker<'io, IO> { + io: &'io IO, + progress: Arc>, +} + +impl<'io, IO: Io> TestUnscannedTracker<'io, IO> { + pub fn new(io: &'io IO) -> Self { + Self { + io, + progress: Arc::new(Mutex::new(Default::default())), + } + } +} + +impl<'io, IO: Io> ProgressTracker for TestUnscannedTracker<'io, IO> { + fn io(&self) -> &IO { + self.io + } + + fn fetch(&self, items: I) -> impl PeekableIter + where + I: Iterator, + { + { + let mut locked = self.progress.lock().unwrap(); + locked.length = items.size_hint().0; + } + crate::masp::utils::DefaultFetchIterator { + inner: items, + progress: self.progress.clone(), + peeked: None, + } + } + + fn scan(&self, items: I) -> impl Iterator + Send + where + I: Iterator + Send, + { + YieldOnceIterator::new(items) + } + + fn left_to_fetch(&self) -> usize { + let locked = self.progress.lock().unwrap(); + locked.length - locked.index + } +} diff --git a/crates/sdk/src/masp/utils.rs b/crates/sdk/src/masp/utils.rs index 6dc740abab..29d4cfeb4f 100644 --- a/crates/sdk/src/masp/utils.rs +++ b/crates/sdk/src/masp/utils.rs @@ -1,12 +1,17 @@ //! 
Helper functions and types use std::sync::{Arc, Mutex}; + use namada_core::storage::BlockHeight; use namada_tx::{IndexedTx, Tx}; +use crate::control_flow::ShutdownSignal; use crate::error::{Error, QueryError}; use crate::io::Io; -use crate::masp::{extract_masp_tx, get_indexed_masp_events_at_height, IndexedNoteEntry, Unscanned}; +use crate::masp::{ + extract_masp_tx, extract_masp_tx_from_ibc_message, + get_indexed_masp_events_at_height, IndexedNoteEntry, Unscanned, +}; use crate::queries::Client; /// When retrying to fetch all notes in a @@ -25,7 +30,7 @@ impl Iterator for RetryStrategy { fn next(&mut self) -> Option { match self { Self::Forever => Some(()), - Self::Times(ref mut count) => { + Self::Times(count) => { if *count == 0 { None } else { @@ -41,16 +46,17 @@ impl Iterator for RetryStrategy { /// of how shielded-sync fetches the necessary data /// from a remote server. pub trait MaspClient<'client, C: Client> { - /// Create a new [`MaspClient`] given an rpc client. - fn new(client: &'client C) -> Self - where - Self: 'client; + /// Return the wrapped client. + fn rpc_client(&self) -> &C; - /// Fetches shielded transfers + /// Fetch shielded transfers from blocks heights in the range `[from, to]`, + /// keeping track of progress through `progress`. The fetched transfers + /// are sent over to a separate worker through `tx_sender`. #[allow(async_fn_in_trait)] - async fn fetch_shielded_transfer( + async fn fetch_shielded_transfers( &self, progress: &impl ProgressTracker, + shutdown_signal: &mut ShutdownSignal, tx_sender: FetchQueueSender, from: u64, to: u64, @@ -59,26 +65,31 @@ pub trait MaspClient<'client, C: Client> { /// An inefficient MASP client which simply uses a /// client to the blockchain to query it directly. -pub struct LedgerMaspClient<'client, C: Client> { +pub struct LedgerMaspClient<'client, C> { client: &'client C, } -#[cfg(not(target_family = "wasm"))] -impl<'client, C: Client + Sync> MaspClient<'client, C> for LedgerMaspClient<'client, C> - where - LedgerMaspClient<'client, C>: 'client, -{ - fn new(client: &'client C) -> Self - where - Self: 'client, - { +impl<'client, C> LedgerMaspClient<'client, C> { + /// Create a new [`MaspClient`] given an rpc client. + pub const fn new(client: &'client C) -> Self { Self { client } } +} +#[cfg(not(target_family = "wasm"))] +impl<'client, C: Client + Sync> MaspClient<'client, C> + for LedgerMaspClient<'client, C> +where + LedgerMaspClient<'client, C>: 'client, +{ + fn rpc_client(&self) -> &C { + self.client + } - async fn fetch_shielded_transfer( + async fn fetch_shielded_transfers( &self, progress: &impl ProgressTracker, + shutdown_signal: &mut ShutdownSignal, mut tx_sender: FetchQueueSender, from: u64, to: u64, @@ -87,6 +98,11 @@ impl<'client, C: Client + Sync> MaspClient<'client, C> for LedgerMaspClient<'cli let mut fetch_iter = progress.fetch(from..=to); while let Some(height) = fetch_iter.peek() { + if shutdown_signal.received() { + return Err(Error::Interrupt( + "[ShieldedSync::Fetching]".to_string(), + )); + } let height = *height; if tx_sender.contains_height(height) { fetch_iter.next(); @@ -124,8 +140,11 @@ impl<'client, C: Client + Sync> MaspClient<'client, C> for LedgerMaspClient<'cli let tx = Tx::try_from(block[idx.0 as usize].as_ref()) .map_err(|e| Error::Other(e.to_string()))?; let extracted_masp_txs = - extract_masp_tx(&tx, &masp_sections_refs).await?; - + if let Some(masp_sections_refs) = masp_sections_refs { + extract_masp_tx(&tx, &masp_sections_refs).await? + } else { + extract_masp_tx_from_ibc_message(&tx)? 
+ }; tx_sender.send(( IndexedTx { height: height.into(), @@ -141,8 +160,6 @@ impl<'client, C: Client + Sync> MaspClient<'client, C> for LedgerMaspClient<'cli } } - - /// A channel-like struct for "sending" newly fetched blocks /// to the scanning algorithm. /// @@ -252,8 +269,8 @@ pub trait PeekableIter { } impl PeekableIter for std::iter::Peekable - where - I: Iterator, +where + I: Iterator, { fn peek(&mut self) -> Option<&J> { self.peek() @@ -278,16 +295,16 @@ pub trait ProgressTracker { /// Return an iterator to fetched shielded transfers fn fetch(&self, items: I) -> impl PeekableIter - where - I: Iterator; + where + I: Iterator; /// Return an iterator over MASP transactions to be scanned fn scan( &self, items: I, ) -> impl Iterator + Send - where - I: Iterator + Send; + where + I: Iterator + Send; /// The number of blocks that need to be fetched fn left_to_fetch(&self) -> usize; @@ -317,8 +334,8 @@ pub(super) struct IterProgress { } pub(super) struct DefaultFetchIterator - where - I: Iterator, +where + I: Iterator, { pub inner: I, pub progress: Arc>, @@ -326,8 +343,8 @@ pub(super) struct DefaultFetchIterator } impl PeekableIter for DefaultFetchIterator - where - I: Iterator, +where + I: Iterator, { fn peek(&mut self) -> Option<&u64> { if self.peeked.is_none() { @@ -351,8 +368,8 @@ impl<'io, IO: Io> ProgressTracker for DefaultTracker<'io, IO> { } fn fetch(&self, items: I) -> impl PeekableIter - where - I: Iterator, + where + I: Iterator, { { let mut locked = self.progress.lock().unwrap(); @@ -366,8 +383,8 @@ impl<'io, IO: Io> ProgressTracker for DefaultTracker<'io, IO> { } fn scan(&self, items: I) -> impl Iterator + Send - where - I: IntoIterator, + where + I: IntoIterator, { let items: Vec<_> = items.into_iter().collect(); items.into_iter() diff --git a/crates/sdk/src/queries/mod.rs b/crates/sdk/src/queries/mod.rs index ed5fdb1836..ce338c89fa 100644 --- a/crates/sdk/src/queries/mod.rs +++ b/crates/sdk/src/queries/mod.rs @@ -98,7 +98,7 @@ pub fn require_no_data(request: &RequestQuery) -> namada_storage::Result<()> { /// Queries testing helpers #[cfg(any(test, feature = "testing"))] -mod testing { +pub(crate) mod testing { use borsh_ext::BorshSerializeExt; use namada_state::testing::TestState; use tendermint_rpc::Response; diff --git a/wasm/Cargo.lock b/wasm/Cargo.lock index 94a4eac8da..ea4908e98e 100644 --- a/wasm/Cargo.lock +++ b/wasm/Cargo.lock @@ -2103,6 +2103,18 @@ dependencies = [ "paste", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -4025,6 +4037,7 @@ dependencies = [ "ethers", "eyre", "fd-lock", + "flume", "futures", "itertools 0.12.1", "jubjub", @@ -4317,6 +4330,15 @@ dependencies = [ "sha2 0.9.9", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom 0.2.15", +] + [[package]] name = "nonempty" version = "0.7.0" @@ -5992,6 +6014,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "spki" diff --git a/wasm_for_tests/Cargo.lock 
b/wasm_for_tests/Cargo.lock index 2e7778bbe4..f844985eac 100644 --- a/wasm_for_tests/Cargo.lock +++ b/wasm_for_tests/Cargo.lock @@ -2167,6 +2167,18 @@ dependencies = [ "paste", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -4050,6 +4062,7 @@ dependencies = [ "ethers", "eyre", "fd-lock", + "flume", "futures", "itertools 0.12.1", "jubjub", @@ -4332,6 +4345,15 @@ dependencies = [ "sha2 0.9.9", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom 0.2.15", +] + [[package]] name = "nonempty" version = "0.7.0" @@ -5999,6 +6021,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "spki"
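
For reference, a rough sketch of how the refactored entry point introduced above is driven, mirroring `syncing` in `crates/apps_lib/src/client/masp.rs`; the `DefaultTracker` used for progress reporting and the exact re-export paths are assumptions taken from this diff rather than a verified public API:

```rust
// Sketch only: drives ShieldedContext::fetch with the new MaspClient +
// RetryStrategy surface. DefaultTracker and the import paths below are
// assumed from the diff, not confirmed against the released SDK.
use namada_sdk::error::Error;
use namada_sdk::io::Io;
use namada_sdk::masp::utils::{DefaultTracker, LedgerMaspClient, RetryStrategy};
use namada_sdk::masp::{ShieldedContext, ShieldedUtils};
use namada_sdk::queries::Client;
use namada_sdk::storage::BlockHeight;
use namada_sdk::{MaybeSend, MaybeSync};

async fn sync_once<U, C, IO>(
    mut shielded: ShieldedContext<U>,
    client: &C,
    io: &IO,
    last_query_height: Option<BlockHeight>,
) -> Result<ShieldedContext<U>, Error>
where
    U: ShieldedUtils + MaybeSend + MaybeSync,
    C: Client + Sync,
    IO: Io + Send + Sync,
{
    let tracker = DefaultTracker::new(io);
    shielded
        .fetch(
            // connection failures no longer abort the sync: fetched notes are
            // cached and fetching restarts according to the retry strategy
            LedgerMaspClient::new(client),
            &tracker,
            None,                    // start_query_height
            last_query_height,       // last_query_height
            RetryStrategy::Times(3), // give up after three failed attempts
            1,                       // batch size (currently unused)
            &[],                     // spending keys to sync
            &[],                     // viewing keys to sync
        )
        .await?;
    Ok(shielded)
}
```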