From d36c0f5fd93f8190c9f39990ce4ec859c2b6d567 Mon Sep 17 00:00:00 2001 From: Andrey36652 <35865938+Andrey36652@users.noreply.github.com> Date: Tue, 3 Dec 2024 19:01:58 +0300 Subject: [PATCH] perf: fix decoder streams to make pooled connections reusable (#2484) When a response body is being decompressed, and the length wasn't known, but was using chunked transfer-encoding, the remaining `0\r\n\r\n` was not consumed. That would leave the connection in a state that could be not be reused, and so the pool had to discard it. This fix makes sure the remaining end chunk is consumed, improving the amount of pooled connections that can be reused. Closes #2381 --- src/async_impl/decoder.rs | 116 ++++++++++++++++----- tests/brotli.rs | 210 +++++++++++++++++++++++++++++++++++++ tests/deflate.rs | 212 +++++++++++++++++++++++++++++++++++++ tests/gzip.rs | 213 ++++++++++++++++++++++++++++++++++++++ tests/support/server.rs | 103 ++++++++++++++++++ tests/zstd.rs | 207 ++++++++++++++++++++++++++++++++++++ 6 files changed, 1033 insertions(+), 28 deletions(-) diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index d742e6d35..96a27ac45 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -9,6 +9,14 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +#[cfg(any( + feature = "gzip", + feature = "zstd", + feature = "brotli", + feature = "deflate" +))] +use futures_util::stream::Fuse; + #[cfg(feature = "gzip")] use async_compression::tokio::bufread::GzipDecoder; @@ -108,19 +116,19 @@ enum Inner { /// A `Gzip` decoder will uncompress the gzipped response content before returning it. #[cfg(feature = "gzip")] - Gzip(Pin, BytesCodec>>>), + Gzip(Pin, BytesCodec>>>>), /// A `Brotli` decoder will uncompress the brotlied response content before returning it. #[cfg(feature = "brotli")] - Brotli(Pin, BytesCodec>>>), + Brotli(Pin, BytesCodec>>>>), /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it. #[cfg(feature = "zstd")] - Zstd(Pin, BytesCodec>>>), + Zstd(Pin, BytesCodec>>>>), /// A `Deflate` decoder will uncompress the deflated response content before returning it. #[cfg(feature = "deflate")] - Deflate(Pin, BytesCodec>>>), + Deflate(Pin, BytesCodec>>>>), /// A decoder that doesn't have a value yet. #[cfg(any( @@ -365,34 +373,74 @@ impl HttpBody for Decoder { } #[cfg(feature = "gzip")] Inner::Gzip(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after gzip stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "brotli")] Inner::Brotli(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after brotli stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "zstd")] Inner::Zstd(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after zstd stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after deflate stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } } @@ -456,25 +504,37 @@ impl Future for Pending { match self.1 { #[cfg(feature = "brotli")] - DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(FramedRead::new( - BrotliDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin( + FramedRead::new( + BrotliDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "zstd")] - DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new( - ZstdDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin( + FramedRead::new( + ZstdDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "gzip")] - DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new( - GzipDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin( + FramedRead::new( + GzipDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "deflate")] - DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(FramedRead::new( - ZlibDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin( + FramedRead::new( + ZlibDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), } } } diff --git a/tests/brotli.rs b/tests/brotli.rs index 5c2b01849..ba116ed92 100644 --- a/tests/brotli.rs +++ b/tests/brotli.rs @@ -1,6 +1,7 @@ mod support; use std::io::Read; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn brotli_response() { @@ -145,3 +146,212 @@ async fn brotli_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: br\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn brotli_compress(input: &[u8]) -> Vec { + let mut encoder = brotli_crate::CompressorReader::new(input, 4096, 5, 20); + let mut brotlied_content = Vec::new(); + encoder.read_to_end(&mut brotlied_content).unwrap(); + brotlied_content +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", brotlied_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &brotlied_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/deflate.rs b/tests/deflate.rs index ec27ba180..55331afc5 100644 --- a/tests/deflate.rs +++ b/tests/deflate.rs @@ -1,6 +1,7 @@ mod support; use std::io::Write; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn deflate_response() { @@ -148,3 +149,214 @@ async fn deflate_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: deflate\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn deflate_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::zlib::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to deflate encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", deflated_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &deflated_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/gzip.rs b/tests/gzip.rs index 57189e0ac..74ead8783 100644 --- a/tests/gzip.rs +++ b/tests/gzip.rs @@ -2,6 +2,8 @@ mod support; use support::server; use std::io::Write; +use tokio::io::AsyncWriteExt; +use tokio::time::Duration; #[tokio::test] async fn gzip_response() { @@ -149,3 +151,214 @@ async fn gzip_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: gzip\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn gzip_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::gzip::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to gzip encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", gzipped_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &gzipped_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/support/server.rs b/tests/support/server.rs index 29835ead1..79ebd2d8f 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -6,6 +6,8 @@ use std::sync::mpsc as std_mpsc; use std::thread; use std::time::Duration; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; use tokio::runtime; use tokio::sync::oneshot; @@ -240,3 +242,104 @@ where .join() .unwrap() } + +pub fn low_level_with_response(do_response: F) -> Server +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c> + + Clone + + Send + + 'static, +{ + // Spawn new runtime in thread to prevent reactor execution context conflict + let test_name = thread::current().name().unwrap_or("").to_string(); + thread::spawn(move || { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("new rt"); + let listener = rt.block_on(async move { + tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap() + }); + let addr = listener.local_addr().unwrap(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (panic_tx, panic_rx) = std_mpsc::channel(); + let (events_tx, events_rx) = std_mpsc::channel(); + let tname = format!("test({})-support-server", test_name,); + thread::Builder::new() + .name(tname) + .spawn(move || { + rt.block_on(async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => { + break; + } + accepted = listener.accept() => { + let (io, _) = accepted.expect("accepted"); + let do_response = do_response.clone(); + let events_tx = events_tx.clone(); + tokio::spawn(async move { + low_level_server_client(io, do_response).await; + let _ = events_tx.send(Event::ConnectionClosed); + }); + } + } + } + let _ = panic_tx.send(()); + }); + }) + .expect("thread spawn"); + Server { + addr, + panic_rx, + events_rx, + shutdown_tx: Some(shutdown_tx), + } + }) + .join() + .unwrap() +} + +async fn low_level_server_client(mut client_socket: TcpStream, do_response: F) +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c>, +{ + loop { + let request = low_level_read_http_request(&mut client_socket) + .await + .expect("read_http_request failed"); + if request.is_empty() { + // connection closed by client + break; + } + + Box::into_pin(do_response(&request, &mut client_socket)).await; + } +} + +async fn low_level_read_http_request( + client_socket: &mut TcpStream, +) -> core::result::Result, std::io::Error> { + let mut buf = Vec::new(); + + // Read until the delimiter "\r\n\r\n" is found + loop { + let mut temp_buffer = [0; 1024]; + let n = client_socket.read(&mut temp_buffer).await?; + + if n == 0 { + break; + } + + buf.extend_from_slice(&temp_buffer[..n]); + + if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") { + return Ok(buf.drain(..pos + 4).collect()); + } + } + + Ok(buf) +} diff --git a/tests/zstd.rs b/tests/zstd.rs index d1886ee49..ed3914e79 100644 --- a/tests/zstd.rs +++ b/tests/zstd.rs @@ -1,5 +1,6 @@ mod support; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn zstd_response() { @@ -142,3 +143,209 @@ async fn zstd_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: zstd\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn zstd_compress(input: &[u8]) -> Vec { + zstd_crate::encode_all(input, 3).unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", zstded_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &zstded_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +}