Skip to content

Commit

Permalink
fix: reinsert transactions from failed block (#4675)
Browse files Browse the repository at this point in the history
Description
---
After removing transactions from the mempool from a failed validation, reinsert them to keep the valid ones

Motivation and Context
---
Currently, the implementation discards all transactions when validation fails. This is a pretty heavy approach, because the block may be incorrect in the header. There may even be an attack where a malicious user crafts a bad block and removes all transactions in the mempool.

In this approach, the transactions are reinserted into the mempool. If any of them are now invalid (e.g. double spends) they will be discarded

How Has This Been Tested?
---
existing tests, CI
  • Loading branch information
stringhandler authored Sep 16, 2022
1 parent a709282 commit 8030364
Show file tree
Hide file tree
Showing 27 changed files with 190 additions and 142 deletions.
27 changes: 25 additions & 2 deletions applications/tari_base_node/log4rs_sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,24 @@ appenders:
count: 5
pattern: "log/base-node/network.{}.log"
encoder:
pattern: "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] [Thread:{I}] {l:5} {m}{n} // {f}:{L}"
pattern: "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] [Thread:{I}] {l:5} {m} // {f}:{L}{n}"
# An appender named "network" that writes to a file with a custom pattern encoder
message_logging:
kind: rolling_file
path: "log/base-node/messages.log"
policy:
kind: compound
trigger:
kind: size
limit: 10mb
roller:
kind: fixed_window
base: 1
count: 5
pattern: "log/base-node/messages.{}.log"
encoder:
pattern: "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] [Thread:{I}] {l:5} {m} // {f}:{L}{n}"


# An appender named "base_layer" that writes to a file with a custom pattern encoder
base_layer:
Expand All @@ -53,7 +70,7 @@ appenders:
count: 5
pattern: "log/base-node/base_layer.{}.log"
encoder:
pattern: "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] [{X(node-public-key)},{X(node-id)}] {l:5} {m}{n} // {f}:{L} "
pattern: "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] [{X(node-public-key)},{X(node-id)}] {l:5} {m} // {f}:{L}{n}"

# An appender named "other" that writes to a file with a custom pattern encoder
other:
Expand Down Expand Up @@ -152,3 +169,9 @@ loggers:
appenders:
- other
additive: false

comms::middleware::message_logging:
# Set to `trace` to retrieve message logging
level: warn
appenders:
- message_logging
11 changes: 11 additions & 0 deletions base_layer/core/src/base_node/service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ async fn handle_incoming_request<B: BlockchainBackend + 'static>(
.send_direct(
origin_public_key,
OutboundDomainMessage::new(&TariMessageType::BaseNodeResponse, message),
"Outbound response message from base node".to_string(),
)
.await?;

Expand Down Expand Up @@ -473,13 +474,22 @@ async fn handle_outbound_request(
node_id: Option<NodeId>,
service_request_timeout: Duration,
) -> Result<(), CommsInterfaceError> {
let debug_info = format!(
"Node request:{} to {}",
&request,
node_id
.as_ref()
.map(|n| n.short_str())
.unwrap_or_else(|| "random".to_string())
);
let request_key = generate_request_key(&mut OsRng);
let service_request = proto::BaseNodeServiceRequest {
request_key,
request: Some(request.try_into().map_err(CommsInterfaceError::InternalError)?),
};

let mut send_msg_params = SendMessageParams::new();
send_msg_params.with_debug_info(debug_info);
match node_id {
Some(node_id) => send_msg_params.direct_node_id(node_id),
None => send_msg_params.random(1),
Expand Down Expand Up @@ -565,6 +575,7 @@ async fn handle_outbound_block(
&TariMessageType::NewBlock,
shared_protos::core::NewBlock::from(new_block),
),
"Outbound new block from base node".to_string(),
)
.await;
if let Err(e) = result {
Expand Down
4 changes: 2 additions & 2 deletions base_layer/core/src/mempool/mempool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ impl Mempool {

/// Insert an unconfirmed transaction into the Mempool.
pub async fn insert(&self, tx: Arc<Transaction>) -> Result<TxStorageResponse, MempoolError> {
self.with_write_access(|storage| storage.insert(tx)).await
self.with_write_access(|storage| Ok(storage.insert(tx))).await
}

/// Inserts all transactions into the mempool.
pub async fn insert_all(&self, transactions: Vec<Arc<Transaction>>) -> Result<(), MempoolError> {
self.with_write_access(|storage| {
for tx in transactions {
storage.insert(tx)?;
storage.insert(tx);
}

Ok(())
Expand Down
41 changes: 22 additions & 19 deletions base_layer/core/src/mempool/mempool_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl MempoolStorage {

/// Insert an unconfirmed transaction into the Mempool. The transaction *MUST* have passed through the validation
/// pipeline already and will thus always be internally consistent by this stage
pub fn insert(&mut self, tx: Arc<Transaction>) -> Result<TxStorageResponse, MempoolError> {
pub fn insert(&mut self, tx: Arc<Transaction>) -> TxStorageResponse {
let tx_id = tx
.body
.kernels()
Expand All @@ -87,41 +87,41 @@ impl MempoolStorage {
"Transaction {} is VALID, inserting in unconfirmed pool", tx_id
);
let weight = self.get_transaction_weighting(0);
self.unconfirmed_pool.insert(tx, None, &weight)?;
Ok(TxStorageResponse::UnconfirmedPool)
self.unconfirmed_pool.insert(tx, None, &weight);
TxStorageResponse::UnconfirmedPool
},
Err(ValidationError::UnknownInputs(dependent_outputs)) => {
if self.unconfirmed_pool.contains_all_outputs(&dependent_outputs) {
let weight = self.get_transaction_weighting(0);
self.unconfirmed_pool.insert(tx, Some(dependent_outputs), &weight)?;
Ok(TxStorageResponse::UnconfirmedPool)
self.unconfirmed_pool.insert(tx, Some(dependent_outputs), &weight);
TxStorageResponse::UnconfirmedPool
} else {
warn!(target: LOG_TARGET, "Validation failed due to unknown inputs");
Ok(TxStorageResponse::NotStoredOrphan)
TxStorageResponse::NotStoredOrphan
}
},
Err(ValidationError::ContainsSTxO) => {
warn!(target: LOG_TARGET, "Validation failed due to already spent input");
Ok(TxStorageResponse::NotStoredAlreadySpent)
TxStorageResponse::NotStoredAlreadySpent
},
Err(ValidationError::MaturityError) => {
warn!(target: LOG_TARGET, "Validation failed due to maturity error");
Ok(TxStorageResponse::NotStoredTimeLocked)
TxStorageResponse::NotStoredTimeLocked
},
Err(ValidationError::ConsensusError(msg)) => {
warn!(target: LOG_TARGET, "Validation failed due to consensus rule: {}", msg);
Ok(TxStorageResponse::NotStoredConsensus)
TxStorageResponse::NotStoredConsensus
},
Err(ValidationError::DuplicateKernelError(msg)) => {
debug!(
target: LOG_TARGET,
"Validation failed due to already mined kernel: {}", msg
);
Ok(TxStorageResponse::NotStoredConsensus)
TxStorageResponse::NotStoredConsensus
},
Err(e) => {
warn!(target: LOG_TARGET, "Validation failed due to error: {}", e);
Ok(TxStorageResponse::NotStored)
TxStorageResponse::NotStored
},
}
}
Expand All @@ -131,11 +131,10 @@ impl MempoolStorage {
}

// Insert a set of new transactions into the UTxPool.
fn insert_txs(&mut self, txs: Vec<Arc<Transaction>>) -> Result<(), MempoolError> {
fn insert_txs(&mut self, txs: Vec<Arc<Transaction>>) {
for tx in txs {
self.insert(tx)?;
self.insert(tx);
}
Ok(())
}

/// Update the Mempool based on the received published block.
Expand Down Expand Up @@ -168,10 +167,14 @@ impl MempoolStorage {
failed_block.header.height,
failed_block.hash().to_hex()
);
self.unconfirmed_pool
let txs = self
.unconfirmed_pool
.remove_published_and_discard_deprecated_transactions(failed_block);

// Reinsert them to validate if they are still valid
self.insert_txs(txs);
self.unconfirmed_pool.compact();
debug!(target: LOG_TARGET, "{}", self.stats());

Ok(())
}

Expand All @@ -190,12 +193,12 @@ impl MempoolStorage {
// validation. This is important as invalid transactions that have not been mined yet may remain in the mempool
// after a reorg.
let removed_txs = self.unconfirmed_pool.drain_all_mempool_transactions();
self.insert_txs(removed_txs)?;
self.insert_txs(removed_txs);
// Remove re-orged transactions from reorg pool and re-submit them to the unconfirmed mempool
let removed_txs = self
.reorg_pool
.remove_reorged_txs_and_discard_double_spends(removed_blocks, new_blocks);
self.insert_txs(removed_txs)?;
self.insert_txs(removed_txs);
// Update the Mempool based on the received set of new blocks.
for block in new_blocks {
self.process_published_block(block)?;
Expand Down Expand Up @@ -235,7 +238,7 @@ impl MempoolStorage {
/// Will only return transactions that will fit into the given weight
pub fn retrieve_and_revalidate(&mut self, total_weight: u64) -> Result<Vec<Arc<Transaction>>, MempoolError> {
let results = self.unconfirmed_pool.fetch_highest_priority_txs(total_weight)?;
self.insert_txs(results.transactions_to_insert)?;
self.insert_txs(results.transactions_to_insert);
Ok(results.retrieved_transactions)
}

Expand Down
5 changes: 2 additions & 3 deletions base_layer/core/src/mempool/service/initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use tari_service_framework::{
use tokio::sync::mpsc;

use crate::{
base_node::{comms_interface::LocalNodeCommsInterface, StateMachineHandle},
base_node::comms_interface::LocalNodeCommsInterface,
mempool::{
mempool::Mempool,
service::{
Expand Down Expand Up @@ -135,7 +135,6 @@ impl ServiceInitializer for MempoolServiceInitializer {

context.spawn_until_shutdown(move |handles| {
let outbound_message_service = handles.expect_handle::<Dht>().outbound_requester();
let state_machine = handles.expect_handle::<StateMachineHandle>();
let base_node = handles.expect_handle::<LocalNodeCommsInterface>();

let streams = MempoolStreams {
Expand All @@ -146,7 +145,7 @@ impl ServiceInitializer for MempoolServiceInitializer {
request_receiver,
};
debug!(target: LOG_TARGET, "Mempool service started");
MempoolService::new(outbound_message_service, inbound_handlers, state_machine).start(streams)
MempoolService::new(outbound_message_service, inbound_handlers).start(streams)
});

Ok(())
Expand Down
72 changes: 21 additions & 51 deletions base_layer/core/src/mempool/service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ use tari_utilities::hex::Hex;
use tokio::{sync::mpsc, task};

use crate::{
base_node::{
comms_interface::{BlockEvent, BlockEventReceiver},
StateMachineHandle,
},
base_node::comms_interface::{BlockEvent, BlockEventReceiver},
mempool::service::{
error::MempoolServiceError,
inbound_handlers::MempoolInboundHandlers,
Expand All @@ -66,19 +63,13 @@ pub struct MempoolStreams<STxIn, SLocalReq> {
pub struct MempoolService {
outbound_message_service: OutboundMessageRequester,
inbound_handlers: MempoolInboundHandlers,
state_machine: StateMachineHandle,
}

impl MempoolService {
pub fn new(
outbound_message_service: OutboundMessageRequester,
inbound_handlers: MempoolInboundHandlers,
state_machine: StateMachineHandle,
) -> Self {
pub fn new(outbound_message_service: OutboundMessageRequester, inbound_handlers: MempoolInboundHandlers) -> Self {
Self {
outbound_message_service,
inbound_handlers,
state_machine,
}
}

Expand Down Expand Up @@ -108,12 +99,20 @@ impl MempoolService {

// Outbound tx messages from the OutboundMempoolServiceInterface
Some((txn, excluded_peers)) = outbound_tx_stream.recv() => {
self.spawn_handle_outbound_tx(txn, excluded_peers);
let _res = handle_outbound_tx(&mut self.outbound_message_service, txn, excluded_peers).await.map_err(|e|
error!(target: LOG_TARGET, "Error sending outbound tx message: {}", e)
);
},

// Incoming transaction messages from the Comms layer
Some(transaction_msg) = inbound_transaction_stream.next() => {
self.spawn_handle_incoming_tx(transaction_msg);
let result = handle_incoming_tx(&mut self.inbound_handlers, transaction_msg).await;
if let Err(e) = result {
error!(
target: LOG_TARGET,
"Failed to handle incoming transaction message: {:?}", e
);
}
}

// Incoming local request messages from the LocalMempoolServiceInterface and other local services
Expand Down Expand Up @@ -144,41 +143,6 @@ impl MempoolService {
self.inbound_handlers.handle_request(request).await
}

fn spawn_handle_outbound_tx(&self, tx: Arc<Transaction>, excluded_peers: Vec<NodeId>) {
let outbound_message_service = self.outbound_message_service.clone();
task::spawn(async move {
let result = handle_outbound_tx(outbound_message_service, tx, excluded_peers).await;
if let Err(e) = result {
error!(target: LOG_TARGET, "Failed to handle outbound tx message {:?}", e);
}
});
}

fn spawn_handle_incoming_tx(&self, tx_msg: DomainMessage<Transaction>) {
// Determine if we are bootstrapped
let status_watch = self.state_machine.get_status_info_watch();

if !(*status_watch.borrow()).bootstrapped {
debug!(
target: LOG_TARGET,
"Transaction with Message {} from peer `{}` not processed while busy with initial sync.",
tx_msg.dht_header.message_tag,
tx_msg.source_peer.node_id.short_str(),
);
return;
}
let inbound_handlers = self.inbound_handlers.clone();
task::spawn(async move {
let result = handle_incoming_tx(inbound_handlers, tx_msg).await;
if let Err(e) = result {
error!(
target: LOG_TARGET,
"Failed to handle incoming transaction message: {:?}", e
);
}
});
}

fn spawn_handle_local_request(
&self,
request_context: RequestContext<MempoolRequest, Result<MempoolResponse, MempoolServiceError>>,
Expand Down Expand Up @@ -209,7 +173,7 @@ impl MempoolService {
}

async fn handle_incoming_tx(
mut inbound_handlers: MempoolInboundHandlers,
inbound_handlers: &mut MempoolInboundHandlers,
domain_transaction_msg: DomainMessage<Transaction>,
) -> Result<(), MempoolServiceError> {
let DomainMessage::<_> { source_peer, inner, .. } = domain_transaction_msg;
Expand All @@ -236,7 +200,7 @@ async fn handle_incoming_tx(
}

async fn handle_outbound_tx(
mut outbound_message_service: OutboundMessageRequester,
outbound_message_service: &mut OutboundMessageRequester,
tx: Arc<Transaction>,
exclude_peers: Vec<NodeId>,
) -> Result<(), MempoolServiceError> {
Expand All @@ -247,7 +211,13 @@ async fn handle_outbound_tx(
exclude_peers,
OutboundDomainMessage::new(
&TariMessageType::NewTransaction,
proto::types::Transaction::try_from(tx).map_err(MempoolServiceError::ConversionError)?,
proto::types::Transaction::try_from(tx.clone()).map_err(MempoolServiceError::ConversionError)?,
),
format!(
"Outbound mempool tx: {}",
tx.first_kernel_excess_sig()
.map(|s| s.get_signature().to_hex())
.unwrap_or_else(|| "No kernels!".to_string())
),
)
.await;
Expand Down
Loading

0 comments on commit 8030364

Please sign in to comment.