Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ban peers if they send a bad protobuf message #5693

Merged
merged 14 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions base_layer/contacts/src/contacts_service/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

use diesel::result::Error as DieselError;
use tari_common_sqlite::error::SqliteStorageError;
use tari_common_types::tari_address::TariAddressError;
use tari_comms::connectivity::ConnectivityError;
use tari_comms_dht::outbound::DhtOutboundError;
use tari_p2p::services::liveness::error::LivenessError;
Expand All @@ -47,6 +48,12 @@ pub enum ContactsServiceError {
ConnectivityError(#[from] ConnectivityError),
#[error("Outbound comms error: `{0}`")]
OutboundCommsError(#[from] DhtOutboundError),
#[error("Error parsing address: `{source}`")]
MessageParsingError { source: TariAddressError },
#[error("Error decoding message: `{0}`")]
IllFormedMessageError(#[from] prost::DecodeError),
#[error("Message source does not match authenticated origin")]
MessageSourceDoesNotMatchOrigin,
}

#[derive(Debug, Error)]
Expand Down
41 changes: 27 additions & 14 deletions base_layer/contacts/src/contacts_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ use tari_p2p::{
domain_message::DomainMessage,
services::{
liveness::{LivenessEvent, LivenessHandle, MetadataKey, PingPongEvent},
utils::{map_decode, ok_or_skip_result},
utils::map_decode,
},
tari_message::TariMessageType,
};
use tari_service_framework::reply_channel;
use tari_shutdown::ShutdownSignal;
use tari_utilities::ByteArray;
use tari_utilities::{epoch_time::EpochTime, ByteArray};
use tokio::sync::broadcast;

use crate::contacts_service::{
Expand Down Expand Up @@ -189,8 +189,8 @@ where T: ContactsBackend + 'static
let chat_messages = self
.subscription_factory
.get_subscription(TariMessageType::Chat, SUBSCRIPTION_LABEL)
.map(map_decode::<proto::Message>)
.filter_map(ok_or_skip_result);
.map(map_decode::<proto::Message>);

pin_mut!(chat_messages);

let shutdown = self
Expand Down Expand Up @@ -406,22 +406,35 @@ where T: ContactsBackend + 'static

async fn handle_incoming_message(
&mut self,
msg: DomainMessage<crate::contacts_service::proto::Message>,
msg: DomainMessage<Result<crate::contacts_service::proto::Message, prost::DecodeError>>,
) -> Result<(), ContactsServiceError> {
let msg_inner = match &msg.inner {
Ok(msg) => msg.clone(),
Err(e) => {
self.connectivity
.ban_peer(
msg.source_peer.node_id.clone(),
"Peer sent illformed message".to_string(),
)
.await?;
return Err(ContactsServiceError::IllFormedMessageError(e.clone()));
},
};
if let Some(source_public_key) = msg.authenticated_origin {
let DomainMessage::<_> { inner: msg, .. } = msg;
let message =
Message::try_from(msg_inner).map_err(|ta| ContactsServiceError::MessageParsingError { source: ta })?;

let message = Message::from(msg.clone());
let message = Message {
let our_message = Message {
address: TariAddress::from_public_key(&source_public_key, message.address.network()),
stored_at: Utc::now().naive_utc().timestamp() as u64,
..msg.into()
stored_at: EpochTime::now().as_u64(),
..message
};

self.db
.save_message(message.clone())
.expect("Couldn't save the message");
let _msg = self.message_publisher.send(Arc::new(message));
self.db.save_message(our_message.clone())?;

let _msg = self.message_publisher.send(Arc::new(our_message));
} else {
return Err(ContactsServiceError::MessageSourceDoesNotMatchOrigin);
}

Ok(())
Expand Down
16 changes: 10 additions & 6 deletions base_layer/contacts/src/contacts_service/types/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::convert::TryFrom;

use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use tari_common_types::tari_address::TariAddress;
use tari_common_types::tari_address::{TariAddress, TariAddressError};
use tari_comms_dht::domain_message::OutboundDomainMessage;
use tari_p2p::tari_message::TariMessageType;
use tari_utilities::ByteArray;
Expand Down Expand Up @@ -61,16 +63,18 @@ impl Default for Direction {
}
}

impl From<proto::Message> for Message {
fn from(message: proto::Message) -> Self {
Self {
impl TryFrom<proto::Message> for Message {
type Error = TariAddressError;

fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
Ok(Self {
body: message.body,
address: TariAddress::from_bytes(&message.address).expect("Couldn't parse address"),
address: TariAddress::from_bytes(&message.address)?,
// A Message from a proto::Message will always be an inbound message
direction: Direction::Inbound,
stored_at: message.stored_at,
message_id: message.message_id,
}
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion base_layer/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ once_cell = "1.8.0"
prost = "0.9"
prost-types = "0.9"
rand = "0.8"
randomx-rs = { version = "1.2", optional = true }
randomx-rs = { version = "1.2.1", optional = true }
serde = { version = "1.0.106", features = ["derive"] }
serde_json = "1.0"
serde_repr = "0.1.8"
Expand Down
81 changes: 80 additions & 1 deletion base_layer/core/src/base_node/service/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::time::Duration;

use tari_comms_dht::outbound::DhtOutboundError;
use thiserror::Error;

use crate::base_node::comms_interface::CommsInterfaceError;
use crate::base_node::{comms_interface::CommsInterfaceError, service::initializer::ExtractBlockError};

#[derive(Debug, Error)]
pub enum BaseNodeServiceError {
Expand All @@ -35,4 +37,81 @@ pub enum BaseNodeServiceError {
InvalidRequest(String),
#[error("Invalid response error: `{0}`")]
InvalidResponse(String),
#[error("Invalid block error: `{0}`")]
InvalidBlockMessage(#[from] ExtractBlockError),
}

impl BaseNodeServiceError {
pub fn get_ban_reason(&self) -> Option<BanReason> {
match self {
BaseNodeServiceError::CommsInterfaceError(comms) => {
match comms {
CommsInterfaceError::UnexpectedApiResponse => Some(BanReason {
reason: "Unexpected API response".to_string(),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::RequestTimedOut => Some(BanReason {
reason: "Request timed out".to_string(),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::InvalidPeerResponse(e) => Some(BanReason {
reason: format!("Invalid peer response: {}", e),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::InvalidBlockHeader(e) => Some(BanReason {
reason: format!("Invalid block header: {}", e),
ban_duration: Duration::from_secs(60),
}),
CommsInterfaceError::InvalidRequest { request, details } => Some(BanReason {
reason: format!("Invalid request: {} ({})", request, details),
ban_duration: Duration::from_secs(60),
}),
// CommsInterfaceError::NoBootstrapNodesConfigured |
// CommsInterfaceError::TransportChannelError(_) |
// CommsInterfaceError::ChainStorageError(_) |
// CommsInterfaceError::OutboundMessageError(_) |
// CommsInterfaceError::MempoolError(_) |
// CommsInterfaceError::BroadcastFailed |
// CommsInterfaceError::InternalChannelError(_) => {}
// CommsInterfaceError::DifficultyAdjustmentManagerError(_) => {}
// CommsInterfaceError::InternalError(_) => {}
// CommsInterfaceError::ApiError(_) => {}
// CommsInterfaceError::BlockHeaderNotFound(_) => {}
// CommsInterfaceError::BlockError(_) => {}
// CommsInterfaceError::InvalidFullBlock { .. } => {}
// CommsInterfaceError::MergeMineError(_) => {}
// CommsInterfaceError::DifficultyError(_) => {}
_ => None,
}
},
BaseNodeServiceError::DhtOutboundError(_) => None,
BaseNodeServiceError::InvalidRequest(e) => Some(BanReason {
reason: format!("Invalid request: {}", e),
ban_duration: Duration::from_secs(60),
}),
BaseNodeServiceError::InvalidResponse(e) => Some(BanReason {
reason: format!("Invalid response: {}", e),
ban_duration: Duration::from_secs(60),
}),
BaseNodeServiceError::InvalidBlockMessage(e) => Some(BanReason {
reason: format!("Invalid block message: {}", e),
ban_duration: Duration::from_secs(60),
}),
}
}
}

pub struct BanReason {
reason: String,
ban_duration: Duration,
}

impl BanReason {
pub fn reason(&self) -> &str {
&self.reason
}

pub fn ban_duration(&self) -> Duration {
self.ban_duration
}
}
66 changes: 34 additions & 32 deletions base_layer/core/src/base_node/service/initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use tari_comms_dht::Dht;
use tari_p2p::{
comms_connector::{PeerMessage, SubscriptionFactory},
domain_message::DomainMessage,
services::utils::{map_decode, ok_or_skip_result},
services::utils::map_decode,
tari_message::TariMessageType,
};
use tari_service_framework::{
Expand All @@ -39,6 +39,7 @@ use tari_service_framework::{
ServiceInitializer,
ServiceInitializerContext,
};
use thiserror::Error;
use tokio::sync::{broadcast, mpsc};

use crate::{
Expand Down Expand Up @@ -92,58 +93,58 @@ where T: BlockchainBackend
}

/// Get a stream for inbound Base Node request messages
fn inbound_request_stream(&self) -> impl Stream<Item = DomainMessage<proto::BaseNodeServiceRequest>> {
fn inbound_request_stream(
&self,
) -> impl Stream<Item = DomainMessage<Result<proto::BaseNodeServiceRequest, prost::DecodeError>>> {
self.inbound_message_subscription_factory
.get_subscription(TariMessageType::BaseNodeRequest, SUBSCRIPTION_LABEL)
.map(map_decode::<proto::BaseNodeServiceRequest>)
.filter_map(ok_or_skip_result)
// .filter_map(ok_or_skip_result)
}

/// Get a stream for inbound Base Node response messages
fn inbound_response_stream(&self) -> impl Stream<Item = DomainMessage<proto::BaseNodeServiceResponse>> {
fn inbound_response_stream(
&self,
) -> impl Stream<Item = DomainMessage<Result<proto::BaseNodeServiceResponse, prost::DecodeError>>> {
self.inbound_message_subscription_factory
.get_subscription(TariMessageType::BaseNodeResponse, SUBSCRIPTION_LABEL)
.map(map_decode::<proto::BaseNodeServiceResponse>)
.filter_map(ok_or_skip_result)
}

/// Create a stream of 'New Block` messages
fn inbound_block_stream(&self) -> impl Stream<Item = DomainMessage<NewBlock>> {
fn inbound_block_stream(&self) -> impl Stream<Item = DomainMessage<Result<NewBlock, ExtractBlockError>>> {
self.inbound_message_subscription_factory
.get_subscription(TariMessageType::NewBlock, SUBSCRIPTION_LABEL)
.filter_map(extract_block)
.map(extract_block)
}
}

async fn extract_block(msg: Arc<PeerMessage>) -> Option<DomainMessage<NewBlock>> {
match msg.decode_message::<shared_protos::core::NewBlock>() {
#[derive(Error, Debug)]
pub enum ExtractBlockError {
#[error("Could not decode inbound block message. {0}")]
DecodeError(#[from] prost::DecodeError),
#[error("Inbound block message was ill-formed. {0}")]
IllFormedMessage(String),
}

fn extract_block(msg: Arc<PeerMessage>) -> DomainMessage<Result<NewBlock, ExtractBlockError>> {
let new_block = match msg.decode_message::<shared_protos::core::NewBlock>() {
Ok(block) => block,
Err(e) => {
warn!(
target: LOG_TARGET,
"Could not decode inbound block message. {}",
e.to_string()
);
None
},
Ok(new_block) => {
let block = match NewBlock::try_from(new_block) {
Err(e) => {
let origin = &msg.source_peer.node_id;
warn!(
target: LOG_TARGET,
"Inbound block message from {} was ill-formed. {}", origin, e
);
return None;
},
Ok(b) => b,
};
Some(DomainMessage {
return DomainMessage {
source_peer: msg.source_peer.clone(),
dht_header: msg.dht_header.clone(),
authenticated_origin: msg.authenticated_origin.clone(),
inner: block,
})
inner: Err(e.into()),
}
},
};
let block = NewBlock::try_from(new_block).map_err(ExtractBlockError::IllFormedMessage);
DomainMessage {
source_peer: msg.source_peer.clone(),
dht_header: msg.dht_header.clone(),
authenticated_origin: msg.authenticated_origin.clone(),
inner: block,
}
}

Expand Down Expand Up @@ -194,7 +195,7 @@ where T: BlockchainBackend + 'static
mempool,
consensus_manager,
outbound_nci.clone(),
connectivity,
connectivity.clone(),
randomx_factory,
);

Expand All @@ -212,6 +213,7 @@ where T: BlockchainBackend + 'static
inbound_nch,
service_request_timeout,
state_machine,
connectivity,
)
.start(streams);
futures::pin_mut!(service);
Expand Down
Loading