Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check for max payload size when publishing messages #1211

Merged
merged 2 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use futures::StreamExt;
use once_cell::sync::Lazy;
use regex::Regex;
use std::fmt::Display;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
Expand All @@ -35,13 +35,26 @@ static VERSION_RE: Lazy<Regex> =

/// An error returned from the [`Client::publish`], [`Client::publish_with_headers`],
/// [`Client::publish_with_reply`] or [`Client::publish_with_reply_and_headers`] functions.
#[derive(Debug, Error)]
#[error("failed to publish message: {0}")]
pub struct PublishError(#[source] crate::Error);
pub type PublishError = Error<PublishErrorKind>;

impl From<tokio::sync::mpsc::error::SendError<Command>> for PublishError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
PublishError(Box::new(err))
PublishError::with_source(PublishErrorKind::Send, err)
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PublishErrorKind {
MaxPayloadExceeded,
Send,
}

impl Display for PublishErrorKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PublishErrorKind::MaxPayloadExceeded => write!(f, "max payload size exceeded"),
PublishErrorKind::Send => write!(f, "failed to send message"),
}
}
}

Expand All @@ -57,6 +70,7 @@ pub struct Client {
subscription_capacity: usize,
inbox_prefix: String,
request_timeout: Option<Duration>,
max_payload: Arc<AtomicUsize>,
}

impl Client {
Expand All @@ -67,6 +81,7 @@ impl Client {
capacity: usize,
inbox_prefix: String,
request_timeout: Option<Duration>,
max_payload: Arc<AtomicUsize>,
) -> Client {
Client {
info,
Expand All @@ -76,6 +91,7 @@ impl Client {
subscription_capacity: capacity,
inbox_prefix,
request_timeout,
max_payload,
}
}

Expand Down Expand Up @@ -154,6 +170,17 @@ impl Client {
payload: Bytes,
) -> Result<(), PublishError> {
let subject = subject.to_subject();
let max_payload = self.max_payload.load(Ordering::Relaxed);
if payload.len() > max_payload {
return Err(PublishError::with_source(
PublishErrorKind::MaxPayloadExceeded,
format!(
"Payload size limit of {} exceeded by message size of {}",
payload.len(),
max_payload
),
));
}

self.sender
.send(Command::Publish {
Expand Down
8 changes: 8 additions & 0 deletions async-nats/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use rand::thread_rng;
use std::cmp;
use std::io;
use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
Expand Down Expand Up @@ -72,6 +73,7 @@ pub(crate) struct Connector {
attempts: usize,
pub(crate) events_tx: tokio::sync::mpsc::Sender<Event>,
pub(crate) state_tx: tokio::sync::watch::Sender<State>,
pub(crate) max_payload: Arc<AtomicUsize>,
}

pub(crate) fn reconnect_delay_callback_default(attempts: usize) -> Duration {
Expand All @@ -90,6 +92,7 @@ impl Connector {
options: ConnectorOptions,
events_tx: tokio::sync::mpsc::Sender<Event>,
state_tx: tokio::sync::watch::Sender<State>,
max_payload: Arc<AtomicUsize>,
) -> Result<Connector, io::Error> {
let servers = addrs.to_server_addrs()?.map(|addr| (addr, 0)).collect();

Expand All @@ -99,6 +102,7 @@ impl Connector {
options,
events_tx,
state_tx,
max_payload,
})
}

Expand Down Expand Up @@ -282,6 +286,10 @@ impl Connector {
self.attempts = 0;
self.events_tx.send(Event::Connected).await.ok();
self.state_tx.send(State::Connected).ok();
self.max_payload.store(
server_info.max_payload,
std::sync::atomic::Ordering::Relaxed,
);
return Ok((server_info, connection));
}
None => {
Expand Down
8 changes: 7 additions & 1 deletion async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ use std::option;
use std::pin::Pin;
use std::slice;
use std::str::{self, FromStr};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::ErrorKind;
use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
Expand Down Expand Up @@ -877,6 +879,8 @@ pub async fn connect_with_options<A: ToServerAddrs>(

let (events_tx, mut events_rx) = mpsc::channel(128);
let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
// We're setting it to the default server payload size.
let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));

let mut connector = Connector::new(
addrs,
Expand All @@ -900,6 +904,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
},
events_tx,
state_tx,
max_payload.clone(),
)
.map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;

Expand All @@ -912,7 +917,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
info = info_ok;
}

let (info_sender, info_watcher) = tokio::sync::watch::channel(info);
let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
let (sender, mut receiver) = mpsc::channel(options.sender_capacity);

let client = Client::new(
Expand All @@ -922,6 +927,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
options.subscription_capacity,
options.inbox_prefix,
options.request_timeout,
max_payload,
);

task::spawn(async move {
Expand Down
17 changes: 17 additions & 0 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -866,4 +866,21 @@ mod client {
Event::ClientError(async_nats::ClientError::MaxReconnects)
);
}

#[tokio::test]
async fn publish_payload_size() {
let server = nats_server::run_server("tests/configs/max_payload.conf");

let client = async_nats::connect(server.client_url()).await.unwrap();

// this exceeds the small payload limit in server config.
let payload = vec![0u8; 1024 * 1024];

client.publish("big", payload.into()).await.unwrap_err();
client.publish("small", "data".into()).await.unwrap();
client
.publish("just_ok", vec![0u8; 1024 * 128].into())
.await
.unwrap();
}
}
1 change: 1 addition & 0 deletions async-nats/tests/configs/max_payload.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
max_payload = 128KB
Loading