Skip to content

Commit

Permalink
feat(codec): add max_message_size parameter
Browse files Browse the repository at this point in the history
resolves hyperium#1097
  • Loading branch information
aoudiamoncef committed Feb 16, 2023
1 parent 26b848b commit 3aabcac
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tonic/benches/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ macro_rules! bench {
b.iter(|| {
rt.block_on(async {
let decoder = MockDecoder::new($message_size);
let mut stream = Streaming::new_request(decoder, body.clone(), None);
let mut stream = Streaming::new_request(decoder, body.clone(), None, None);

let mut count = 0;
while let Some(msg) = stream.message().await.unwrap() {
Expand Down
11 changes: 9 additions & 2 deletions tonic/src/client/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,14 @@ impl<T> Grpc<T> {
M2: Send + Sync + 'static,
{
let request = request
.map(|s| encode_client(codec.encoder(), s, self.config.send_compression_encodings))
.map(|s| {
encode_client(
codec.encoder(),
s,
self.config.send_compression_encodings,
None,
)
})
.map(BoxBody::new);

let request = self.config.prepare_request(request, path);
Expand Down Expand Up @@ -278,7 +285,7 @@ impl<T> Grpc<T> {

let response = response.map(|body| {
if expect_additional_trailers {
Streaming::new_response(decoder, body, status_code, encoding)
Streaming::new_response(decoder, body, status_code, encoding, None)
} else {
Streaming::new_empty(decoder, body)
}
Expand Down
43 changes: 38 additions & 5 deletions tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::compression::{decompress, CompressionEncoding};
use super::{DecodeBuf, Decoder, HEADER_SIZE};
use super::{DecodeBuf, Decoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE};
use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
use bytes::{Buf, BufMut, BytesMut};
use futures_core::Stream;
Expand Down Expand Up @@ -32,6 +32,7 @@ struct StreamingInner {
trailers: Option<MetadataMap>,
decompress_buf: BytesMut,
encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
}

impl<T> Unpin for Streaming<T> {}
Expand Down Expand Up @@ -59,13 +60,20 @@ impl<T> Streaming<T> {
body: B,
status_code: StatusCode,
encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
) -> Self
where
B: Body + Send + 'static,
B::Error: Into<crate::Error>,
D: Decoder<Item = T, Error = Status> + Send + 'static,
{
Self::new(decoder, body, Direction::Response(status_code), encoding)
Self::new(
decoder,
body,
Direction::Response(status_code),
encoding,
max_message_size,
)
}

pub(crate) fn new_empty<B, D>(decoder: D, body: B) -> Self
Expand All @@ -74,24 +82,36 @@ impl<T> Streaming<T> {
B::Error: Into<crate::Error>,
D: Decoder<Item = T, Error = Status> + Send + 'static,
{
Self::new(decoder, body, Direction::EmptyResponse, None)
Self::new(decoder, body, Direction::EmptyResponse, None, None)
}

#[doc(hidden)]
pub fn new_request<B, D>(decoder: D, body: B, encoding: Option<CompressionEncoding>) -> Self
pub fn new_request<B, D>(
decoder: D,
body: B,
encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
) -> Self
where
B: Body + Send + 'static,
B::Error: Into<crate::Error>,
D: Decoder<Item = T, Error = Status> + Send + 'static,
{
Self::new(decoder, body, Direction::Request, encoding)
Self::new(
decoder,
body,
Direction::Request,
encoding,
max_message_size,
)
}

fn new<B, D>(
decoder: D,
body: B,
direction: Direction,
encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
) -> Self
where
B: Body + Send + 'static,
Expand All @@ -111,6 +131,7 @@ impl<T> Streaming<T> {
trailers: None,
decompress_buf: BytesMut::new(),
encoding,
max_message_size,
},
}
}
Expand Down Expand Up @@ -151,7 +172,19 @@ impl StreamingInner {
return Err(Status::new(Code::Internal, message));
}
};

let len = self.buf.get_u32() as usize;
let limit = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
if len > limit {
return Err(Status::new(
Code::OutOfRange,
format!(
"Error, message length too large: found {} bytes, the limit is: {} bytes",
len, limit
),
));
}

self.buf.reserve(len);

self.state = State::ReadBody {
Expand Down
32 changes: 29 additions & 3 deletions tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride};
use super::{EncodeBuf, Encoder, HEADER_SIZE};
use super::{EncodeBuf, Encoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE};
use crate::{Code, Status};
use bytes::{BufMut, Bytes, BytesMut};
use futures_core::{Stream, TryStream};
Expand All @@ -19,12 +19,20 @@ pub(crate) fn encode_server<T, U>(
source: U,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
let stream = encode(encoder, source, compression_encoding, compression_override).into_stream();
let stream = encode(
encoder,
source,
compression_encoding,
compression_override,
max_message_size,
)
.into_stream();

EncodeBody::new_server(stream)
}
Expand All @@ -33,6 +41,7 @@ pub(crate) fn encode_client<T, U>(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
where
T: Encoder<Error = Status>,
Expand All @@ -43,6 +52,7 @@ where
source.map(Ok),
compression_encoding,
SingleMessageCompressionOverride::default(),
max_message_size,
)
.into_stream();
EncodeBody::new_client(stream)
Expand All @@ -53,6 +63,7 @@ fn encode<T, U>(
source: U,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> impl TryStream<Ok = Bytes, Error = Status>
where
T: Encoder<Error = Status>,
Expand Down Expand Up @@ -81,6 +92,7 @@ where
&mut buf,
&mut uncompression_buf,
compression_encoding,
max_message_size,
item,
)
})
Expand All @@ -91,6 +103,7 @@ fn encode_item<T>(
buf: &mut BytesMut,
uncompression_buf: &mut BytesMut,
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
item: T::Item,
) -> Result<Bytes, Status>
where
Expand Down Expand Up @@ -119,14 +132,27 @@ where
}

// now that we know length, we can write the header
finish_encoding(compression_encoding, buf)
finish_encoding(compression_encoding, max_message_size, buf)
}

fn finish_encoding(
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
buf: &mut BytesMut,
) -> Result<Bytes, Status> {
let len = buf.len() - HEADER_SIZE;

let limit = max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
if len > limit {
return Err(Status::new(
Code::OutOfRange,
format!(
"Error, message length too large: found {} bytes, the limit is: {} bytes",
len, limit
),
));
}

if len > std::u32::MAX as usize {
return Err(Status::resource_exhausted(format!(
"Cannot return body with more than 4GB of data but got {len} bytes"
Expand Down
3 changes: 3 additions & 0 deletions tonic/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ const HEADER_SIZE: usize =
// data length
std::mem::size_of::<u32>();

// The default maximum uncompressed size in bytes for a message. Defaults to 4MB.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 * 1024 * 1024;

/// Trait that knows how to encode and decode gRPC messages.
pub trait Codec {
/// The encodable message.
Expand Down
73 changes: 71 additions & 2 deletions tonic/src/codec/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ mod tests {
use crate::codec::{
encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
};
use crate::Status;
use crate::{Code, Status};
use bytes::{Buf, BufMut, BytesMut};
use http_body::Body;

const LEN: usize = 10000;
// The maximum uncompressed size in bytes for a message. Set to 2MB.
const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;

#[tokio::test]
async fn decode() {
Expand All @@ -103,7 +105,7 @@ mod tests {

let body = body::MockBody::new(&buf[..], 10005, 0);

let mut stream = Streaming::new_request(decoder, body, None);
let mut stream = Streaming::new_request(decoder, body, None, None);

let mut i = 0usize;
while let Some(output_msg) = stream.message().await.unwrap() {
Expand All @@ -113,6 +115,39 @@ mod tests {
assert_eq!(i, 1);
}

#[tokio::test]
async fn decode_max_message_size_exceeded() {
let decoder = MockDecoder::default();

let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

let mut buf = BytesMut::new();

buf.reserve(msg.len() + HEADER_SIZE);
buf.put_u8(0);
buf.put_u32(msg.len() as u32);

buf.put(&msg[..]);

let body = body::MockBody::new(&buf[..], 10005, 0);

let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));

let actual = stream.message().await.unwrap_err();

let expected = Status::new(
Code::OutOfRange,
format!(
"Error, message length too large: found {} bytes, the limit is: {} bytes",
msg.len(),
MAX_MESSAGE_SIZE
),
);

assert_eq!(actual.code(), expected.code());
assert_eq!(actual.message(), expected.message());
}

#[tokio::test]
async fn encode() {
let encoder = MockEncoder::default();
Expand All @@ -127,6 +162,7 @@ mod tests {
source,
None,
SingleMessageCompressionOverride::default(),
None,
);

futures_util::pin_mut!(body);
Expand All @@ -136,6 +172,38 @@ mod tests {
}
}

#[tokio::test]
async fn encode_max_message_size_exceeded() {
let encoder = MockEncoder::default();

let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

let messages = std::iter::once(Ok::<_, Status>(msg));
let source = futures_util::stream::iter(messages);

let body = encode_server(
encoder,
source,
None,
SingleMessageCompressionOverride::default(),
Some(MAX_MESSAGE_SIZE),
);

futures_util::pin_mut!(body);

assert!(body.data().await.is_none());
assert_eq!(
body.trailers()
.await
.expect("no error polling trailers")
.expect("some trailers")
.get("grpc-status")
.expect("grpc-status header"),
"11"
);
assert!(body.is_end_stream());
}

// skip on windows because CI stumbles over our 4GB allocation
#[cfg(not(target_family = "windows"))]
#[tokio::test]
Expand All @@ -152,6 +220,7 @@ mod tests {
source,
None,
SingleMessageCompressionOverride::default(),
Some(usize::MAX),
);

futures_util::pin_mut!(body);
Expand Down
11 changes: 8 additions & 3 deletions tonic/src/server/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,12 @@ where

let (parts, body) = request.into_parts();

let stream =
Streaming::new_request(self.codec.decoder(), body, request_compression_encoding);
let stream = Streaming::new_request(
self.codec.decoder(),
body,
request_compression_encoding,
None,
);

futures_util::pin_mut!(stream);

Expand Down Expand Up @@ -309,7 +313,7 @@ where
let encoding = self.request_encoding_if_supported(&request)?;

let request =
request.map(|body| Streaming::new_request(self.codec.decoder(), body, encoding));
request.map(|body| Streaming::new_request(self.codec.decoder(), body, encoding, None));

Ok(Request::from_http(request))
}
Expand Down Expand Up @@ -349,6 +353,7 @@ where
body.into_stream(),
accept_encoding,
compression_override,
None,
);

http::Response::from_parts(parts, BoxBody::new(body))
Expand Down
Loading

0 comments on commit 3aabcac

Please sign in to comment.