Skip to content

Commit

Permalink
fix: disable P2P transaction negotiation while recovery is in progress
Browse files Browse the repository at this point in the history
Some weird behaviour was observed when a wallet would be busy with recovery and then receive transaction negotiation messages, either directly or via SAF.

The Recovery process is updating the Key Manager Indices and looking for commitments on the blockchain so to allow transaction negotiation during this time is dangerous as it might put duplicate commitments into the db and reuse spending keys.

This PR checks for the db key/value used to indicate Recovery progress before handling a transaction negotiation p2p message and if it is there the message is ignored with a log.
  • Loading branch information
philipr-za committed Aug 30, 2021
1 parent 20618b6 commit 8b91216
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 248 deletions.
25 changes: 3 additions & 22 deletions base_layer/wallet/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use crate::{
contacts_service::storage::sqlite_db::ContactsServiceSqliteDatabase,
output_manager_service::storage::sqlite_db::OutputManagerSqliteDatabase,
storage::{sqlite_db::WalletSqliteDatabase, sqlite_utilities::run_migration_and_create_sqlite_connection},
transaction_service::storage::sqlite_db::TransactionServiceSqliteDatabase,
};
use crate::storage::sqlite_utilities::{run_migration_and_create_sqlite_connection, WalletDbConnection};
use core::iter;
use rand::{distributions::Alphanumeric, rngs::OsRng, Rng};
use std::path::Path;
Expand All @@ -39,15 +34,7 @@ pub fn random_string(len: usize) -> String {
}

/// A test helper to create a temporary wallet service databases
pub fn make_wallet_databases(
path: Option<String>,
) -> (
WalletSqliteDatabase,
TransactionServiceSqliteDatabase,
OutputManagerSqliteDatabase,
ContactsServiceSqliteDatabase,
Option<TempDir>,
) {
pub fn make_wallet_database_connection(path: Option<String>) -> (WalletDbConnection, Option<TempDir>) {
let (path_string, temp_dir): (String, Option<TempDir>) = if let Some(p) = path {
(p, None)
} else {
Expand All @@ -61,11 +48,5 @@ pub fn make_wallet_databases(

let connection =
run_migration_and_create_sqlite_connection(&db_path.to_str().expect("Should be able to make path")).unwrap();
(
WalletSqliteDatabase::new(connection.clone(), None).expect("Should be able to create wallet database"),
TransactionServiceSqliteDatabase::new(connection.clone(), None),
OutputManagerSqliteDatabase::new(connection.clone(), None),
ContactsServiceSqliteDatabase::new(connection),
temp_dir,
)
(connection, temp_dir)
}
5 changes: 5 additions & 0 deletions base_layer/wallet/src/transaction_service/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use crate::{
error::WalletStorageError,
output_manager_service::{error::OutputManagerError, TxId},
transaction_service::storage::database::DbKey,
};
Expand Down Expand Up @@ -100,6 +101,8 @@ pub enum TransactionServiceError {
TransportChannelError(#[from] TransportChannelError),
#[error("Transaction storage error: `{0}`")]
TransactionStorageError(#[from] TransactionStorageError),
#[error("Wallet storage error: `{0}`")]
WalletStorageError(#[from] WalletStorageError),
#[error("Invalid message error: `{0}`")]
InvalidMessageError(String),
#[error("Transaction error: `{0}`")]
Expand Down Expand Up @@ -140,6 +143,8 @@ pub enum TransactionServiceError {
ByteArrayError(#[from] tari_crypto::tari_utilities::ByteArrayError),
#[error("Transaction Service Error: `{0}`")]
ServiceError(String),
#[error("Wallet Recovery in progress so Transaction Service Messaging Requests ignored")]
WalletRecoveryInProgress,
}

#[derive(Debug, Error)]
Expand Down
38 changes: 27 additions & 11 deletions base_layer/wallet/src/transaction_service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod tasks;

use crate::{
output_manager_service::handle::OutputManagerHandle,
storage::database::{WalletBackend, WalletDatabase},
transaction_service::{
config::TransactionServiceConfig,
handle::TransactionServiceHandle,
Expand Down Expand Up @@ -64,32 +65,39 @@ use tokio::sync::broadcast;
const LOG_TARGET: &str = "wallet::transaction_service";
const SUBSCRIPTION_LABEL: &str = "Transaction Service";

pub struct TransactionServiceInitializer<T>
where T: TransactionBackend
pub struct TransactionServiceInitializer<T, W>
where
T: TransactionBackend,
W: WalletBackend,
{
config: TransactionServiceConfig,
subscription_factory: Arc<SubscriptionFactory>,
backend: Option<T>,
tx_backend: Option<T>,
node_identity: Arc<NodeIdentity>,
factories: CryptoFactories,
wallet_database: Option<WalletDatabase<W>>,
}

impl<T> TransactionServiceInitializer<T>
where T: TransactionBackend
impl<T, W> TransactionServiceInitializer<T, W>
where
T: TransactionBackend,
W: WalletBackend,
{
pub fn new(
config: TransactionServiceConfig,
subscription_factory: Arc<SubscriptionFactory>,
backend: T,
node_identity: Arc<NodeIdentity>,
factories: CryptoFactories,
wallet_database: WalletDatabase<W>,
) -> Self {
Self {
config,
subscription_factory,
backend: Some(backend),
tx_backend: Some(backend),
node_identity,
factories,
wallet_database: Some(wallet_database),
}
}

Expand Down Expand Up @@ -161,8 +169,10 @@ where T: TransactionBackend
}

#[async_trait]
impl<T> ServiceInitializer for TransactionServiceInitializer<T>
where T: TransactionBackend + 'static
impl<T, W> ServiceInitializer for TransactionServiceInitializer<T, W>
where
T: TransactionBackend + 'static,
W: WalletBackend + 'static,
{
async fn initialize(&mut self, context: ServiceInitializerContext) -> Result<(), ServiceInitializationError> {
let (sender, receiver) = reply_channel::unbounded();
Expand All @@ -179,11 +189,16 @@ where T: TransactionBackend + 'static
// Register handle before waiting for handles to be ready
context.register_handle(transaction_handle);

let backend = self
.backend
let tx_backend = self
.tx_backend
.take()
.expect("Cannot start Transaction Service without providing a backend");

let wallet_database = self
.wallet_database
.take()
.expect("Cannot start Transaction Service without providing a wallet database");

let node_identity = self.node_identity.clone();
let factories = self.factories.clone();
let config = self.config.clone();
Expand All @@ -195,7 +210,8 @@ where T: TransactionBackend + 'static

let result = TransactionService::new(
config,
TransactionDatabase::new(backend),
TransactionDatabase::new(tx_backend),
wallet_database,
receiver,
transaction_stream,
transaction_reply_stream,
Expand Down
48 changes: 42 additions & 6 deletions base_layer/wallet/src/transaction_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

use crate::{
output_manager_service::{handle::OutputManagerHandle, TxId},
storage::database::{WalletBackend, WalletDatabase},
transaction_service::{
config::TransactionServiceConfig,
error::{TransactionServiceError, TransactionServiceProtocolError},
Expand All @@ -44,6 +45,7 @@ use crate::{
},
},
types::{HashDigest, ValidationRetryStrategy},
utxo_scanner_service::utxo_scanning::RECOVERY_KEY,
};
use chrono::{NaiveDateTime, Utc};
use digest::Digest;
Expand Down Expand Up @@ -109,7 +111,10 @@ pub struct TransactionService<
BNResponseStream,
TBackend,
TTxCancelledStream,
> where TBackend: TransactionBackend + 'static
WBackend,
> where
TBackend: TransactionBackend + 'static,
WBackend: WalletBackend + 'static,
{
config: TransactionServiceConfig,
db: TransactionDatabase<TBackend>,
Expand All @@ -136,22 +141,33 @@ pub struct TransactionService<
timeout_update_publisher: broadcast::Sender<Duration>,
base_node_update_publisher: broadcast::Sender<CommsPublicKey>,
power_mode: PowerMode,
wallet_db: WalletDatabase<WBackend>,
}

#[allow(clippy::too_many_arguments)]
impl<TTxStream, TTxReplyStream, TTxFinalizedStream, BNResponseStream, TBackend, TTxCancelledStream>
TransactionService<TTxStream, TTxReplyStream, TTxFinalizedStream, BNResponseStream, TBackend, TTxCancelledStream>
impl<TTxStream, TTxReplyStream, TTxFinalizedStream, BNResponseStream, TBackend, TTxCancelledStream, WBackend>
TransactionService<
TTxStream,
TTxReplyStream,
TTxFinalizedStream,
BNResponseStream,
TBackend,
TTxCancelledStream,
WBackend,
>
where
TTxStream: Stream<Item = DomainMessage<proto::TransactionSenderMessage>>,
TTxReplyStream: Stream<Item = DomainMessage<proto::RecipientSignedMessage>>,
TTxFinalizedStream: Stream<Item = DomainMessage<proto::TransactionFinalizedMessage>>,
BNResponseStream: Stream<Item = DomainMessage<base_node_proto::BaseNodeServiceResponse>>,
TTxCancelledStream: Stream<Item = DomainMessage<proto::TransactionCancelledMessage>>,
TBackend: TransactionBackend + 'static,
WBackend: WalletBackend + 'static,
{
pub fn new(
config: TransactionServiceConfig,
db: TransactionDatabase<TBackend>,
wallet_db: WalletDatabase<WBackend>,
request_stream: Receiver<
TransactionServiceRequest,
Result<TransactionServiceResponse, TransactionServiceError>,
Expand Down Expand Up @@ -210,6 +226,7 @@ where
timeout_update_publisher,
base_node_update_publisher,
power_mode: PowerMode::Normal,
wallet_db,
}
}

Expand Down Expand Up @@ -318,7 +335,7 @@ where
msg.dht_header.message_tag);
}
Err(e) => {
warn!(target: LOG_TARGET, "Failed to handle incoming Transaction message: {:?} for NodeID: {}, Trace: {}",
warn!(target: LOG_TARGET, "Failed to handle incoming Transaction message: {} for NodeID: {}, Trace: {}",
e, self.node_identity.node_id().short_str(), msg.dht_header.message_tag);
let _ = self.event_publisher.send(Arc::new(TransactionEvent::Error(format!("Error handling \
Transaction Sender message: {:?}", e).to_string())));
Expand Down Expand Up @@ -348,7 +365,7 @@ where
msg.dht_header.message_tag);
},
Err(e) => {
warn!(target: LOG_TARGET, "Failed to handle incoming Transaction Reply message: {:?} \
warn!(target: LOG_TARGET, "Failed to handle incoming Transaction Reply message: {} \
for NodeId: {}, Trace: {}", e, self.node_identity.node_id().short_str(),
msg.dht_header.message_tag);
let _ = self.event_publisher.send(Arc::new(TransactionEvent::Error("Error handling \
Expand Down Expand Up @@ -386,7 +403,7 @@ where
msg.dht_header.message_tag);
},
Err(e) => {
warn!(target: LOG_TARGET, "Failed to handle incoming Transaction Finalized message: {:?} \
warn!(target: LOG_TARGET, "Failed to handle incoming Transaction Finalized message: {} \
for NodeID: {}, Trace: {}", e , self.node_identity.node_id().short_str(),
msg.dht_header.message_tag.as_value());
let _ = self.event_publisher.send(Arc::new(TransactionEvent::Error("Error handling Transaction \
Expand Down Expand Up @@ -885,6 +902,9 @@ where
source_pubkey: CommsPublicKey,
recipient_reply: proto::RecipientSignedMessage,
) -> Result<(), TransactionServiceError> {
// Check if a wallet recovery is in progress, if it is we will ignore this request
self.check_recovery_status().await?;

let recipient_reply: RecipientSignedMessage = recipient_reply
.try_into()
.map_err(TransactionServiceError::InvalidMessageError)?;
Expand Down Expand Up @@ -1187,6 +1207,9 @@ where
traced_message_tag: u64,
join_handles: &mut FuturesUnordered<JoinHandle<Result<u64, TransactionServiceProtocolError>>>,
) -> Result<(), TransactionServiceError> {
// Check if a wallet recovery is in progress, if it is we will ignore this request
self.check_recovery_status().await?;

let sender_message: TransactionSenderMessage = sender_message
.try_into()
.map_err(TransactionServiceError::InvalidMessageError)?;
Expand Down Expand Up @@ -1295,6 +1318,9 @@ where
finalized_transaction: proto::TransactionFinalizedMessage,
join_handles: &mut FuturesUnordered<JoinHandle<Result<u64, TransactionServiceProtocolError>>>,
) -> Result<(), TransactionServiceError> {
// Check if a wallet recovery is in progress, if it is we will ignore this request
self.check_recovery_status().await?;

let tx_id = finalized_transaction.tx_id;
let transaction: Transaction = finalized_transaction
.transaction
Expand Down Expand Up @@ -2031,6 +2057,16 @@ where

Ok(())
}

/// Check if a Recovery Status is currently stored in the databse, this indicates that a wallet recovery is in
/// progress
async fn check_recovery_status(&self) -> Result<(), TransactionServiceError> {
let value = self.wallet_db.get_client_key_value(RECOVERY_KEY.to_owned()).await?;
match value {
None => Ok(()),
Some(_) => Err(TransactionServiceError::WalletRecoveryInProgress),
}
}
}

/// This struct is a collection of the common resources that a protocol in the service requires.
Expand Down
1 change: 1 addition & 0 deletions base_layer/wallet/src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ where
transaction_backend,
node_identity.clone(),
factories.clone(),
wallet_database.clone(),
))
.add_initializer(ContactsServiceInitializer::new(contacts_backend))
.add_initializer(BaseNodeServiceInitializer::new(
Expand Down
Loading

0 comments on commit 8b91216

Please sign in to comment.