Skip to content

Commit

Permalink
fix: ban peers if they send a bad protobuf message (tari-project#5693)
Browse files Browse the repository at this point in the history
Description
---
Propagate errors in decoding protobuf messages and conversions higher up
so that nodes can ban them

Motivation and Context
---
Bad messages were previously ignored and could just fill up a node's
buffers. This PR allows the nodes to ban more messages from those peers.

How Has This Been Tested?
---
Tested locally and cargo test

What process can a PR reviewer use to test or verify this change?
---
This is pretty tough to test, you might have to try connect an old node
to this peer

<!-- Checklist -->
<!-- 1. Is the title of your PR in the form that would make nice release
notes? The title, excluding the conventional commit
tag, will be included exactly as is in the CHANGELOG, so please think
about it carefully. -->


Breaking Changes
---

- [x] None
- [ ] Requires data directory on base node to be deleted
- [ ] Requires hard fork
- [ ] Other - Please specify

<!-- Does this include a breaking change? If so, include this line as a
footer -->
<!-- BREAKING CHANGE: Description what the user should do, e.g. delete a
database, resync the chain -->

---------

Co-authored-by: SW van Heerden <[email protected]>
  • Loading branch information
stringhandler and SWvheerden authored Aug 31, 2023
1 parent 826473d commit 58cbfe6
Show file tree
Hide file tree
Showing 16 changed files with 355 additions and 164 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions base_layer/contacts/src/contacts_service/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

use diesel::result::Error as DieselError;
use tari_common_sqlite::error::SqliteStorageError;
use tari_common_types::tari_address::TariAddressError;
use tari_comms::connectivity::ConnectivityError;
use tari_comms_dht::outbound::DhtOutboundError;
use tari_p2p::services::liveness::error::LivenessError;
Expand All @@ -47,6 +48,12 @@ pub enum ContactsServiceError {
ConnectivityError(#[from] ConnectivityError),
#[error("Outbound comms error: `{0}`")]
OutboundCommsError(#[from] DhtOutboundError),
#[error("Error parsing address: `{source}`")]
MessageParsingError { source: TariAddressError },
#[error("Error decoding message: `{0}`")]
MalformedMessageError(#[from] prost::DecodeError),
#[error("Message source does not match authenticated origin")]
MessageSourceDoesNotMatchOrigin,
}

#[derive(Debug, Error)]
Expand Down
41 changes: 27 additions & 14 deletions base_layer/contacts/src/contacts_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ use tari_p2p::{
domain_message::DomainMessage,
services::{
liveness::{LivenessEvent, LivenessHandle, MetadataKey, PingPongEvent},
utils::{map_decode, ok_or_skip_result},
utils::map_decode,
},
tari_message::TariMessageType,
};
use tari_service_framework::reply_channel;
use tari_shutdown::ShutdownSignal;
use tari_utilities::ByteArray;
use tari_utilities::{epoch_time::EpochTime, ByteArray};
use tokio::sync::broadcast;

use crate::contacts_service::{
Expand Down Expand Up @@ -189,8 +189,8 @@ where T: ContactsBackend + 'static
let chat_messages = self
.subscription_factory
.get_subscription(TariMessageType::Chat, SUBSCRIPTION_LABEL)
.map(map_decode::<proto::Message>)
.filter_map(ok_or_skip_result);
.map(map_decode::<proto::Message>);

pin_mut!(chat_messages);

let shutdown = self
Expand Down Expand Up @@ -406,22 +406,35 @@ where T: ContactsBackend + 'static

async fn handle_incoming_message(
&mut self,
msg: DomainMessage<crate::contacts_service::proto::Message>,
msg: DomainMessage<Result<crate::contacts_service::proto::Message, prost::DecodeError>>,
) -> Result<(), ContactsServiceError> {
let msg_inner = match &msg.inner {
Ok(msg) => msg.clone(),
Err(e) => {
self.connectivity
.ban_peer(
msg.source_peer.node_id.clone(),
"Peer sent illformed message".to_string(),
)
.await?;
return Err(ContactsServiceError::MalformedMessageError(e.clone()));
},
};
if let Some(source_public_key) = msg.authenticated_origin {
let DomainMessage::<_> { inner: msg, .. } = msg;
let message =
Message::try_from(msg_inner).map_err(|ta| ContactsServiceError::MessageParsingError { source: ta })?;

let message = Message::from(msg.clone());
let message = Message {
let our_message = Message {
address: TariAddress::from_public_key(&source_public_key, message.address.network()),
stored_at: Utc::now().naive_utc().timestamp() as u64,
..msg.into()
stored_at: EpochTime::now().as_u64(),
..message
};

self.db
.save_message(message.clone())
.expect("Couldn't save the message");
let _msg = self.message_publisher.send(Arc::new(message));
self.db.save_message(our_message.clone())?;

let _msg = self.message_publisher.send(Arc::new(our_message));
} else {
return Err(ContactsServiceError::MessageSourceDoesNotMatchOrigin);
}

Ok(())
Expand Down
16 changes: 10 additions & 6 deletions base_layer/contacts/src/contacts_service/types/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::convert::TryFrom;

use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use tari_common_types::tari_address::TariAddress;
use tari_common_types::tari_address::{TariAddress, TariAddressError};
use tari_comms_dht::domain_message::OutboundDomainMessage;
use tari_p2p::tari_message::TariMessageType;
use tari_utilities::ByteArray;
Expand Down Expand Up @@ -61,16 +63,18 @@ impl Default for Direction {
}
}

impl From<proto::Message> for Message {
fn from(message: proto::Message) -> Self {
Self {
impl TryFrom<proto::Message> for Message {
type Error = TariAddressError;

fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
Ok(Self {
body: message.body,
address: TariAddress::from_bytes(&message.address).expect("Couldn't parse address"),
address: TariAddress::from_bytes(&message.address)?,
// A Message from a proto::Message will always be an inbound message
direction: Direction::Inbound,
stored_at: message.stored_at,
message_id: message.message_id,
}
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion base_layer/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ once_cell = "1.8.0"
prost = "0.9"
prost-types = "0.9"
rand = "0.8"
randomx-rs = { version = "1.2", optional = true }
randomx-rs = { version = "1.2.1", optional = true }
serde = { version = "1.0.106", features = ["derive"] }
serde_json = "1.0"
serde_repr = "0.1.8"
Expand Down
78 changes: 77 additions & 1 deletion base_layer/core/src/base_node/service/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::time::Duration;

use tari_comms_dht::outbound::DhtOutboundError;
use thiserror::Error;

use crate::base_node::comms_interface::CommsInterfaceError;
use crate::base_node::{comms_interface::CommsInterfaceError, service::initializer::ExtractBlockError};

#[derive(Debug, Error)]
pub enum BaseNodeServiceError {
Expand All @@ -35,4 +37,78 @@ pub enum BaseNodeServiceError {
InvalidRequest(String),
#[error("Invalid response error: `{0}`")]
InvalidResponse(String),
#[error("Invalid block error: `{0}`")]
InvalidBlockMessage(#[from] ExtractBlockError),
}

impl BaseNodeServiceError {
pub fn get_ban_reason(&self) -> Option<BanReason> {
match self {
BaseNodeServiceError::CommsInterfaceError(comms) => match comms {
CommsInterfaceError::UnexpectedApiResponse => Some(BanReason {
reason: "Unexpected API response".to_string(),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::RequestTimedOut => Some(BanReason {
reason: "Request timed out".to_string(),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::InvalidPeerResponse(e) => Some(BanReason {
reason: format!("Invalid peer response: {}", e),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::InvalidBlockHeader(e) => Some(BanReason {
reason: format!("Invalid block header: {}", e),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::InvalidRequest { request, details } => Some(BanReason {
reason: format!("Invalid request: {} ({})", request, details),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::NoBootstrapNodesConfigured |
CommsInterfaceError::TransportChannelError(_) |
CommsInterfaceError::ChainStorageError(_) |
CommsInterfaceError::OutboundMessageError(_) |
CommsInterfaceError::MempoolError(_) |
CommsInterfaceError::BroadcastFailed |
CommsInterfaceError::InternalChannelError(_) |
CommsInterfaceError::DifficultyAdjustmentManagerError(_) |
CommsInterfaceError::InternalError(_) |
CommsInterfaceError::ApiError(_) |
CommsInterfaceError::BlockHeaderNotFound(_) |
CommsInterfaceError::BlockError(_) |
CommsInterfaceError::InvalidFullBlock { .. } |
CommsInterfaceError::MergeMineError(_) |
CommsInterfaceError::DifficultyError(_) => None,
},
BaseNodeServiceError::DhtOutboundError(_) => None,
BaseNodeServiceError::InvalidRequest(e) => Some(BanReason {
reason: format!("Invalid request: {}", e),
ban_duration: Duration::from_secs(60),
}),
BaseNodeServiceError::InvalidResponse(e) => Some(BanReason {
reason: format!("Invalid response: {}", e),
ban_duration: Duration::from_secs(60),
}),
BaseNodeServiceError::InvalidBlockMessage(e) => Some(BanReason {
reason: format!("Invalid block message: {}", e),
ban_duration: Duration::from_secs(60),
}),
}
}
}

pub struct BanReason {
reason: String,
ban_duration: Duration,
}

impl BanReason {
pub fn reason(&self) -> &str {
&self.reason
}

pub fn ban_duration(&self) -> Duration {
self.ban_duration
}
}
65 changes: 33 additions & 32 deletions base_layer/core/src/base_node/service/initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use tari_comms_dht::Dht;
use tari_p2p::{
comms_connector::{PeerMessage, SubscriptionFactory},
domain_message::DomainMessage,
services::utils::{map_decode, ok_or_skip_result},
services::utils::map_decode,
tari_message::TariMessageType,
};
use tari_service_framework::{
Expand All @@ -39,6 +39,7 @@ use tari_service_framework::{
ServiceInitializer,
ServiceInitializerContext,
};
use thiserror::Error;
use tokio::sync::{broadcast, mpsc};

use crate::{
Expand Down Expand Up @@ -92,58 +93,57 @@ where T: BlockchainBackend
}

/// Get a stream for inbound Base Node request messages
fn inbound_request_stream(&self) -> impl Stream<Item = DomainMessage<proto::BaseNodeServiceRequest>> {
fn inbound_request_stream(
&self,
) -> impl Stream<Item = DomainMessage<Result<proto::BaseNodeServiceRequest, prost::DecodeError>>> {
self.inbound_message_subscription_factory
.get_subscription(TariMessageType::BaseNodeRequest, SUBSCRIPTION_LABEL)
.map(map_decode::<proto::BaseNodeServiceRequest>)
.filter_map(ok_or_skip_result)
}

/// Get a stream for inbound Base Node response messages
fn inbound_response_stream(&self) -> impl Stream<Item = DomainMessage<proto::BaseNodeServiceResponse>> {
fn inbound_response_stream(
&self,
) -> impl Stream<Item = DomainMessage<Result<proto::BaseNodeServiceResponse, prost::DecodeError>>> {
self.inbound_message_subscription_factory
.get_subscription(TariMessageType::BaseNodeResponse, SUBSCRIPTION_LABEL)
.map(map_decode::<proto::BaseNodeServiceResponse>)
.filter_map(ok_or_skip_result)
}

/// Create a stream of 'New Block` messages
fn inbound_block_stream(&self) -> impl Stream<Item = DomainMessage<NewBlock>> {
fn inbound_block_stream(&self) -> impl Stream<Item = DomainMessage<Result<NewBlock, ExtractBlockError>>> {
self.inbound_message_subscription_factory
.get_subscription(TariMessageType::NewBlock, SUBSCRIPTION_LABEL)
.filter_map(extract_block)
.map(extract_block)
}
}

async fn extract_block(msg: Arc<PeerMessage>) -> Option<DomainMessage<NewBlock>> {
match msg.decode_message::<shared_protos::core::NewBlock>() {
#[derive(Error, Debug)]
pub enum ExtractBlockError {
#[error("Could not decode inbound block message. {0}")]
DecodeError(#[from] prost::DecodeError),
#[error("Inbound block message was ill-formed. {0}")]
MalformedMessage(String),
}

fn extract_block(msg: Arc<PeerMessage>) -> DomainMessage<Result<NewBlock, ExtractBlockError>> {
let new_block = match msg.decode_message::<shared_protos::core::NewBlock>() {
Ok(block) => block,
Err(e) => {
warn!(
target: LOG_TARGET,
"Could not decode inbound block message. {}",
e.to_string()
);
None
},
Ok(new_block) => {
let block = match NewBlock::try_from(new_block) {
Err(e) => {
let origin = &msg.source_peer.node_id;
warn!(
target: LOG_TARGET,
"Inbound block message from {} was ill-formed. {}", origin, e
);
return None;
},
Ok(b) => b,
};
Some(DomainMessage {
return DomainMessage {
source_peer: msg.source_peer.clone(),
dht_header: msg.dht_header.clone(),
authenticated_origin: msg.authenticated_origin.clone(),
inner: block,
})
inner: Err(e.into()),
}
},
};
let block = NewBlock::try_from(new_block).map_err(ExtractBlockError::MalformedMessage);
DomainMessage {
source_peer: msg.source_peer.clone(),
dht_header: msg.dht_header.clone(),
authenticated_origin: msg.authenticated_origin.clone(),
inner: block,
}
}

Expand Down Expand Up @@ -194,7 +194,7 @@ where T: BlockchainBackend + 'static
mempool,
consensus_manager,
outbound_nci.clone(),
connectivity,
connectivity.clone(),
randomx_factory,
);

Expand All @@ -212,6 +212,7 @@ where T: BlockchainBackend + 'static
inbound_nch,
service_request_timeout,
state_machine,
connectivity,
)
.start(streams);
futures::pin_mut!(service);
Expand Down
Loading

0 comments on commit 58cbfe6

Please sign in to comment.