diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 39d5e0ea92..12ef93cd27 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -22,45 +22,3 @@ message = "Make `BehaviorVersion` be future-proof by disallowing it to be constr references = ["aws-sdk-rust#1111", "smithy-rs#3513"] meta = { "breaking" = true, "tada" = false, "bug" = true, "target" = "client" } author = "Ten0" - -[[smithy-rs]] -message = """ -Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following: - -```rust -let config = my_service::Config::builder() - .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) - // ... - .build(); -``` -""" -references = ["smithy-rs#3485"] -meta = { "breaking" = false, "tada" = true, "bug" = false } -authors = ["jdisanti"] - -[[aws-sdk-rust]] -message = """ -Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following: - -```rust -let config = aws_config::defaults(BehaviorVersion::latest()) - .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) - .load() - .await; -``` -""" -references = ["smithy-rs#3485"] -meta = { "breaking" = false, "tada" = true, "bug" = false } -author = "jdisanti" - -[[smithy-rs]] -message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit." -references = ["smithy-rs#3485"] -meta = { "breaking" = false, "tada" = false, "bug" = true } -authors = ["jdisanti"] - -[[aws-sdk-rust]] -message = "Stalled stream protection on downloads will now only trigger if the upstream source is too slow. Previously, stalled stream protection could be erroneously triggered if the user was slowly consuming the stream slower than the minimum speed limit." -references = ["smithy-rs#3485"] -meta = { "breaking" = false, "tada" = false, "bug" = true } -author = "jdisanti" diff --git a/aws/sdk/integration-tests/s3/Cargo.toml b/aws/sdk/integration-tests/s3/Cargo.toml index 0d8ec0a9bc..50ce1ae5c0 100644 --- a/aws/sdk/integration-tests/s3/Cargo.toml +++ b/aws/sdk/integration-tests/s3/Cargo.toml @@ -48,6 +48,3 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter", "json"] } # If you're writing a test with this, take heed! `no-env-filter` means you'll be capturing # logs from everything that speaks, so be specific with your asserts. tracing-test = { version = "0.2.4", features = ["no-env-filter"] } - -[dependencies] -pin-project-lite = "0.2.13" diff --git a/aws/sdk/integration-tests/s3/tests/body_size_hint.rs b/aws/sdk/integration-tests/s3/tests/body_size_hint.rs deleted file mode 100644 index 97e9ac7234..0000000000 --- a/aws/sdk/integration-tests/s3/tests/body_size_hint.rs +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -//! Body wrappers must pass through size_hint - -use aws_config::SdkConfig; -use aws_sdk_s3::{ - config::{Credentials, Region, SharedCredentialsProvider}, - primitives::{ByteStream, SdkBody}, - Client, -}; -use aws_smithy_runtime::client::http::test_util::{capture_request, infallible_client_fn}; -use http_body::Body; - -#[tokio::test] -async fn download_body_size_hint_check() { - let test_body_content = b"hello"; - let test_body = || SdkBody::from(&test_body_content[..]); - assert_eq!( - Some(test_body_content.len() as u64), - (test_body)().size_hint().exact(), - "pre-condition check" - ); - - let http_client = infallible_client_fn(move |_| { - http::Response::builder() - .status(200) - .body((test_body)()) - .unwrap() - }); - let sdk_config = SdkConfig::builder() - .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) - .region(Region::new("us-east-1")) - .http_client(http_client) - .build(); - let client = Client::new(&sdk_config); - let response = client - .get_object() - .bucket("foo") - .key("foo") - .send() - .await - .unwrap(); - assert_eq!( - ( - test_body_content.len() as u64, - Some(test_body_content.len() as u64), - ), - response.body.size_hint(), - "the size hint should be passed through all the default body wrappers" - ); -} - -#[tokio::test] -async fn upload_body_size_hint_check() { - let test_body_content = b"hello"; - - let (http_client, rx) = capture_request(None); - let sdk_config = SdkConfig::builder() - .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) - .region(Region::new("us-east-1")) - .http_client(http_client) - .build(); - let client = Client::new(&sdk_config); - let body = ByteStream::from_static(test_body_content); - assert_eq!( - ( - test_body_content.len() as u64, - Some(test_body_content.len() as u64), - ), - body.size_hint(), - "pre-condition check" - ); - let _response = client - .put_object() - .bucket("foo") - .key("foo") - .body(body) - .send() - .await; - let captured_request = rx.expect_request(); - assert_eq!( - Some(test_body_content.len() as u64), - captured_request.body().size_hint().exact(), - "the size hint should be passed through all the default body wrappers" - ); -} diff --git a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs index 21a224adfa..25008a415e 100644 --- a/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs +++ b/aws/sdk/integration-tests/s3/tests/stalled-stream-protection.rs @@ -4,90 +4,27 @@ */ use aws_credential_types::Credentials; -use aws_sdk_s3::{ - config::{Region, StalledStreamProtectionConfig}, - error::BoxError, -}; -use aws_sdk_s3::{error::DisplayErrorContext, primitives::ByteStream}; +use aws_sdk_s3::config::{Region, StalledStreamProtectionConfig}; +use aws_sdk_s3::primitives::ByteStream; use aws_sdk_s3::{Client, Config}; -use aws_smithy_runtime::{assert_str_contains, test_util::capture_test_logs::capture_test_logs}; -use aws_smithy_types::body::SdkBody; -use bytes::{Bytes, BytesMut}; -use http_body::Body; +use bytes::BytesMut; use std::error::Error; +use std::future::Future; +use std::net::SocketAddr; use std::time::Duration; -use std::{future::Future, task::Poll}; -use std::{net::SocketAddr, pin::Pin, task::Context}; -use tokio::{ - net::{TcpListener, TcpStream}, - time::sleep, -}; use tracing::debug; -enum SlowBodyState { - Wait(Pin + Send + Sync + 'static>>), - Send, - Taken, -} - -struct SlowBody { - state: SlowBodyState, -} - -impl SlowBody { - fn new() -> Self { - Self { - state: SlowBodyState::Send, - } - } -} - -impl Body for SlowBody { - type Data = Bytes; - type Error = BoxError; - - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - loop { - let mut state = SlowBodyState::Taken; - std::mem::swap(&mut state, &mut self.state); - match state { - SlowBodyState::Wait(mut fut) => match fut.as_mut().poll(cx) { - Poll::Ready(_) => self.state = SlowBodyState::Send, - Poll::Pending => { - self.state = SlowBodyState::Wait(fut); - return Poll::Pending; - } - }, - SlowBodyState::Send => { - self.state = SlowBodyState::Wait(Box::pin(sleep(Duration::from_micros(100)))); - return Poll::Ready(Some(Ok(Bytes::from_static( - b"data_data_data_data_data_data_data_data_data_data_data_data_\ - data_data_data_data_data_data_data_data_data_data_data_data_\ - data_data_data_data_data_data_data_data_data_data_data_data_\ - data_data_data_data_data_data_data_data_data_data_data_data_", - )))); - } - SlowBodyState::Taken => unreachable!(), - } - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } -} - +// This test doesn't work because we can't count on `hyper` to poll the body, +// regardless of whether we schedule a wake. To make this functionality work, +// we'd have to integrate more closely with the orchestrator. +// +// I'll leave this test here because we do eventually want to support stalled +// stream protection for uploads. +#[ignore] #[tokio::test] async fn test_stalled_stream_protection_defaults_for_upload() { - let _logs = capture_test_logs(); - - // We spawn a faulty server that will stop all request processing after reading half of the request body. + // We spawn a faulty server that will close the connection after + // writing half of the response body. let (server, server_addr) = start_faulty_upload_server().await; let _ = tokio::spawn(server); @@ -95,8 +32,7 @@ async fn test_stalled_stream_protection_defaults_for_upload() { .credentials_provider(Credentials::for_tests()) .region(Region::new("us-east-1")) .endpoint_url(format!("http://{server_addr}")) - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3510): make stalled stream protection enabled by default with BMV and remove this line - .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) + // .stalled_stream_protection(StalledStreamProtectionConfig::enabled().build()) .build(); let client = Client::from_conf(conf); @@ -104,19 +40,22 @@ async fn test_stalled_stream_protection_defaults_for_upload() { .put_object() .bucket("a-test-bucket") .key("stalled-stream-test.txt") - .body(ByteStream::new(SdkBody::from_body_0_4(SlowBody::new()))) + .body(ByteStream::from_static(b"Hello")) .send() .await .expect_err("upload stream stalled out"); - let err_msg = DisplayErrorContext(&err).to_string(); - assert_str_contains!( - err_msg, + let err = err.source().expect("inner error exists"); + assert_eq!( + err.to_string(), "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" ); } async fn start_faulty_upload_server() -> (impl Future, SocketAddr) { + use tokio::net::{TcpListener, TcpStream}; + use tokio::time::sleep; + let listener = TcpListener::bind("0.0.0.0:0") .await .expect("socket is free"); @@ -126,7 +65,12 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) let mut buf = BytesMut::new(); let mut time_to_stall = false; - while !time_to_stall { + loop { + if time_to_stall { + debug!("faulty server has read partial request, now getting stuck"); + break; + } + match socket.try_read_buf(&mut buf) { Ok(0) => { unreachable!( @@ -135,7 +79,12 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) } Ok(n) => { debug!("read {n} bytes from the socket"); + + // Check to see if we've received some headers if buf.len() >= 128 { + let s = String::from_utf8_lossy(&buf); + debug!("{s}"); + time_to_stall = true; } } @@ -149,7 +98,6 @@ async fn start_faulty_upload_server() -> (impl Future, SocketAddr) } } - debug!("faulty server has read partial request, now getting stuck"); loop { tokio::task::yield_now().await } @@ -281,11 +229,14 @@ async fn test_stalled_stream_protection_for_downloads_is_enabled_by_default() { err.to_string(), "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" ); - // the 1s check interval is included in the 5s grace period - assert_eq!(start.elapsed().as_secs(), 5); + // 1s check interval + 5s grace period + assert_eq!(start.elapsed().as_secs(), 6); } async fn start_faulty_download_server() -> (impl Future, SocketAddr) { + use tokio::net::{TcpListener, TcpStream}; + use tokio::time::sleep; + let listener = TcpListener::bind("0.0.0.0:0") .await .expect("socket is free"); diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt index 8304efc2c4..83c3b6dd6b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt @@ -120,12 +120,15 @@ class StalledStreamProtectionOperationCustomization( is OperationSection.AdditionalInterceptors -> { val stalledStreamProtectionModule = RuntimeType.smithyRuntime(rc).resolve("client::stalled_stream_protection") section.registerInterceptor(rc, this) { + // Currently, only response bodies are protected/supported because + // we can't count on hyper to poll a request body on wake. rustTemplate( """ - #{StalledStreamProtectionInterceptor}::default() + #{StalledStreamProtectionInterceptor}::new(#{Kind}::ResponseBody) """, *preludeScope, "StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"), + "Kind" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptorKind"), ) } } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs index f90f886592..25c9c5c67d 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs @@ -20,17 +20,15 @@ const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(5); /// When enabled, download streams that stall out will be cancelled. #[derive(Clone, Debug)] pub struct StalledStreamProtectionConfig { - upload_enabled: bool, - download_enabled: bool, + is_enabled: bool, grace_period: Duration, } impl StalledStreamProtectionConfig { - /// Create a new config that enables stalled stream protection for both uploads and downloads. + /// Create a new config that enables stalled stream protection. pub fn enabled() -> Builder { Builder { - upload_enabled: Some(true), - download_enabled: Some(true), + is_enabled: Some(true), grace_period: None, } } @@ -38,25 +36,14 @@ impl StalledStreamProtectionConfig { /// Create a new config that disables stalled stream protection. pub fn disabled() -> Self { Self { - upload_enabled: false, - download_enabled: false, + is_enabled: false, grace_period: DEFAULT_GRACE_PERIOD, } } - /// Return whether stalled stream protection is enabled for either uploads or downloads. + /// Return whether stalled stream protection is enabled. pub fn is_enabled(&self) -> bool { - self.upload_enabled || self.download_enabled - } - - /// True if stalled stream protection is enabled for upload streams. - pub fn upload_enabled(&self) -> bool { - self.upload_enabled - } - - /// True if stalled stream protection is enabled for download streams. - pub fn download_enabled(&self) -> bool { - self.download_enabled + self.is_enabled } /// Return the grace period for stalled stream protection. @@ -70,8 +57,7 @@ impl StalledStreamProtectionConfig { #[derive(Clone, Debug)] pub struct Builder { - upload_enabled: Option, - download_enabled: Option, + is_enabled: Option, grace_period: Option, } @@ -88,48 +74,22 @@ impl Builder { self } - /// Set whether stalled stream protection is enabled for both uploads and downloads. - pub fn is_enabled(mut self, enabled: bool) -> Self { - self.set_is_enabled(Some(enabled)); - self - } - - /// Set whether stalled stream protection is enabled for both uploads and downloads. - pub fn set_is_enabled(&mut self, enabled: Option) -> &mut Self { - self.set_upload_enabled(enabled); - self.set_download_enabled(enabled); - self - } - - /// Set whether stalled stream protection is enabled for upload streams. - pub fn upload_enabled(mut self, enabled: bool) -> Self { - self.set_upload_enabled(Some(enabled)); - self - } - - /// Set whether stalled stream protection is enabled for upload streams. - pub fn set_upload_enabled(&mut self, enabled: Option) -> &mut Self { - self.upload_enabled = enabled; - self - } - - /// Set whether stalled stream protection is enabled for download streams. - pub fn download_enabled(mut self, enabled: bool) -> Self { - self.set_download_enabled(Some(enabled)); + /// Set whether stalled stream protection is enabled. + pub fn is_enabled(mut self, is_enabled: bool) -> Self { + self.is_enabled = Some(is_enabled); self } - /// Set whether stalled stream protection is enabled for download streams. - pub fn set_download_enabled(&mut self, enabled: Option) -> &mut Self { - self.download_enabled = enabled; + /// Set whether stalled stream protection is enabled. + pub fn set_is_enabled(&mut self, is_enabled: Option) -> &mut Self { + self.is_enabled = is_enabled; self } /// Build the config. pub fn build(self) -> StalledStreamProtectionConfig { StalledStreamProtectionConfig { - upload_enabled: self.upload_enabled.unwrap_or_default(), - download_enabled: self.download_enabled.unwrap_or_default(), + is_enabled: self.is_enabled.unwrap_or_default(), grace_period: self.grace_period.unwrap_or(DEFAULT_GRACE_PERIOD), } } @@ -138,8 +98,7 @@ impl Builder { impl From for Builder { fn from(config: StalledStreamProtectionConfig) -> Self { Builder { - upload_enabled: Some(config.upload_enabled), - download_enabled: Some(config.download_enabled), + is_enabled: Some(config.is_enabled), grace_period: Some(config.grace_period), } } diff --git a/rust-runtime/aws-smithy-runtime/Cargo.toml b/rust-runtime/aws-smithy-runtime/Cargo.toml index 179cac47eb..276add22ef 100644 --- a/rust-runtime/aws-smithy-runtime/Cargo.toml +++ b/rust-runtime/aws-smithy-runtime/Cargo.toml @@ -43,7 +43,7 @@ serde_json = { version = "1", features = ["preserve_order"], optional = true } indexmap = { version = "2", optional = true, features = ["serde"] } tokio = { version = "1.25", features = [] } tracing = "0.1.37" -tracing-subscriber = { version = "0.3.16", optional = true, features = ["env-filter", "fmt", "json"] } +tracing-subscriber = { version = "0.3.16", optional = true, features = ["fmt", "json"] } [dev-dependencies] approx = "0.5.1" diff --git a/rust-runtime/aws-smithy-runtime/src/client/defaults.rs b/rust-runtime/aws-smithy-runtime/src/client/defaults.rs index 99f549e542..5a23e3f3db 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/defaults.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/defaults.rs @@ -171,16 +171,7 @@ pub fn default_identity_cache_plugin() -> Option { /// /// By default, when throughput falls below 1/Bs for more than 5 seconds, the /// stream is cancelled. -#[deprecated( - since = "1.2.0", - note = "This function wasn't intended to be public, and didn't take the behavior major version as an argument, so it couldn't be evolved over time." -)] pub fn default_stalled_stream_protection_config_plugin() -> Option { - default_stalled_stream_protection_config_plugin_v2(BehaviorVersion::v2023_11_09()) -} -fn default_stalled_stream_protection_config_plugin_v2( - _behavior_version: BehaviorVersion, -) -> Option { Some( default_plugin( "default_stalled_stream_protection_config_plugin", @@ -193,8 +184,6 @@ fn default_stalled_stream_protection_config_plugin_v2( .with_config(layer("default_stalled_stream_protection_config", |layer| { layer.store_put( StalledStreamProtectionConfig::enabled() - // TODO(https://github.com/smithy-lang/smithy-rs/issues/3510): enable behind new behavior version - .upload_enabled(false) .grace_period(Duration::from_secs(5)) .build(), ); @@ -270,10 +259,6 @@ impl DefaultPluginParams { pub fn default_plugins( params: DefaultPluginParams, ) -> impl IntoIterator { - let behavior_version = params - .behavior_version - .unwrap_or_else(BehaviorVersion::latest); - [ default_http_client_plugin(), default_identity_cache_plugin(), @@ -287,7 +272,7 @@ pub fn default_plugins( default_timeout_config_plugin(), // TODO(https://github.com/smithy-lang/smithy-rs/issues/3523): Reenable this /* enforce_content_length_runtime_plugin(), */ - default_stalled_stream_protection_config_plugin_v2(behavior_version), + default_stalled_stream_protection_config_plugin(), ] .into_iter() .flatten() diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs index 59c8a3c64c..c576a34afa 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs @@ -15,46 +15,25 @@ pub mod options; pub use throughput::Throughput; mod throughput; -use crate::client::http::body::minimum_throughput::throughput::ThroughputReport; use aws_smithy_async::rt::sleep::Sleep; use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep}; use aws_smithy_async::time::{SharedTimeSource, TimeSource}; -use aws_smithy_runtime_api::{ - box_error::BoxError, - client::{ - http::HttpConnectorFuture, result::ConnectorError, runtime_components::RuntimeComponents, - stalled_stream_protection::StalledStreamProtectionConfig, - }, -}; -use aws_smithy_runtime_api::{client::orchestrator::HttpResponse, shared::IntoShared}; -use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace}; +use aws_smithy_runtime_api::box_error::BoxError; +use aws_smithy_runtime_api::shared::IntoShared; use options::MinimumThroughputBodyOptions; -use std::{ - fmt, - sync::{Arc, Mutex}, - task::Poll, -}; -use std::{future::Future, pin::Pin}; -use std::{ - task::Context, - time::{Duration, SystemTime}, -}; +use std::fmt; +use std::time::SystemTime; use throughput::ThroughputLogs; -/// Use [`MinimumThroughputDownloadBody`] instead. -#[deprecated(note = "Renamed to MinimumThroughputDownloadBody since it doesn't work for uploads")] -pub type MinimumThroughputBody = MinimumThroughputDownloadBody; - pin_project_lite::pin_project! { /// A body-wrapping type that ensures data is being streamed faster than some lower limit. /// /// If data is being streamed too slowly, this body type will emit an error next time it's polled. - pub struct MinimumThroughputDownloadBody { + pub struct MinimumThroughputBody { async_sleep: SharedAsyncSleep, time_source: SharedTimeSource, options: MinimumThroughputBodyOptions, throughput_logs: ThroughputLogs, - resolution: Duration, #[pin] sleep_fut: Option, #[pin] @@ -64,7 +43,10 @@ pin_project_lite::pin_project! { } } -impl MinimumThroughputDownloadBody { +const SIZE_OF_ONE_LOG: usize = std::mem::size_of::<(SystemTime, u64)>(); // 24 bytes per log +const NUMBER_OF_LOGS_IN_ONE_KB: f64 = 1024.0 / SIZE_OF_ONE_LOG as f64; + +impl MinimumThroughputBody { /// Create a new minimum throughput body. pub fn new( time_source: impl TimeSource + 'static, @@ -72,15 +54,14 @@ impl MinimumThroughputDownloadBody { body: B, options: MinimumThroughputBodyOptions, ) -> Self { - let time_source: SharedTimeSource = time_source.into_shared(); - let now = time_source.now(); - let throughput_logs = ThroughputLogs::new(options.check_window(), now); - let resolution = throughput_logs.resolution(); Self { - throughput_logs, - resolution, + throughput_logs: ThroughputLogs::new( + // Never keep more than 10KB of logs in memory. This currently + // equates to 426 logs. + (NUMBER_OF_LOGS_IN_ONE_KB * 10.0) as usize, + ), async_sleep: async_sleep.into_shared(), - time_source, + time_source: time_source.into_shared(), inner: body, sleep_fut: None, grace_period_fut: None, @@ -112,286 +93,4 @@ impl fmt::Display for Error { impl std::error::Error for Error {} -/// Used to store the upload throughput in the interceptor context. -#[derive(Clone, Debug)] -pub(crate) struct UploadThroughput { - logs: Arc>, -} - -impl UploadThroughput { - pub(crate) fn new(time_window: Duration, now: SystemTime) -> Self { - Self { - logs: Arc::new(Mutex::new(ThroughputLogs::new(time_window, now))), - } - } - - pub(crate) fn resolution(&self) -> Duration { - self.logs.lock().unwrap().resolution() - } - - pub(crate) fn push_pending(&self, now: SystemTime) { - self.logs.lock().unwrap().push_pending(now); - } - pub(crate) fn push_bytes_transferred(&self, now: SystemTime, bytes: u64) { - self.logs.lock().unwrap().push_bytes_transferred(now, bytes); - } - - pub(crate) fn report(&self, now: SystemTime) -> ThroughputReport { - self.logs.lock().unwrap().report(now) - } -} - -impl Storable for UploadThroughput { - type Storer = StoreReplace; -} - -pin_project_lite::pin_project! { - pub(crate) struct ThroughputReadingBody { - time_source: SharedTimeSource, - throughput: UploadThroughput, - #[pin] - inner: B, - } -} - -impl ThroughputReadingBody { - pub(crate) fn new( - time_source: SharedTimeSource, - throughput: UploadThroughput, - body: B, - ) -> Self { - Self { - time_source, - throughput, - inner: body, - } - } -} - -const ZERO_THROUGHPUT: Throughput = Throughput::new_bytes_per_second(0); - -// Helper trait for interpretting the throughput report. -trait UploadReport { - fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput); -} -impl UploadReport for ThroughputReport { - fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) { - let throughput = match self { - // If the report is incomplete, then we don't have enough data yet to - // decide if minimum throughput was violated. - ThroughputReport::Incomplete => { - tracing::trace!( - "not enough data to decide if minimum throughput has been violated" - ); - return (false, ZERO_THROUGHPUT); - } - // If most of the datapoints are Poll::Pending, then the user has stalled. - // In this case, we don't want to say minimum throughput was violated. - ThroughputReport::Pending => { - tracing::debug!( - "the user has stalled; this will not become a minimum throughput violation" - ); - return (false, ZERO_THROUGHPUT); - } - // If there has been no polling, then the server has stalled. Alternatively, - // if we're transferring data, but it's too slow, then we also want to say - // that the minimum throughput has been violated. - ThroughputReport::NoPolling => ZERO_THROUGHPUT, - ThroughputReport::Transferred(tp) => tp, - }; - if throughput < minimum_throughput { - tracing::debug!( - "current throughput: {throughput} is below minimum: {minimum_throughput}" - ); - (true, throughput) - } else { - (false, throughput) - } - } -} - -pin_project_lite::pin_project! { - /// Future that pairs with [`UploadThroughput`] to add a minimum throughput - /// requirement to a request upload stream. - struct UploadThroughputCheckFuture { - #[pin] - response: HttpConnectorFuture, - #[pin] - check_interval: Option, - #[pin] - grace_period: Option, - - time_source: SharedTimeSource, - sleep_impl: SharedAsyncSleep, - upload_throughput: UploadThroughput, - resolution: Duration, - options: MinimumThroughputBodyOptions, - - failing_throughput: Option, - } -} - -impl UploadThroughputCheckFuture { - fn new( - response: HttpConnectorFuture, - time_source: SharedTimeSource, - sleep_impl: SharedAsyncSleep, - upload_throughput: UploadThroughput, - options: MinimumThroughputBodyOptions, - ) -> Self { - let resolution = upload_throughput.resolution(); - Self { - response, - check_interval: Some(sleep_impl.sleep(resolution)), - grace_period: None, - time_source, - sleep_impl, - upload_throughput, - resolution, - options, - failing_throughput: None, - } - } -} - -impl Future for UploadThroughputCheckFuture { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - if let Poll::Ready(output) = this.response.poll(cx) { - return Poll::Ready(output); - } else { - let mut below_minimum_throughput = false; - let check_interval_expired = this - .check_interval - .as_mut() - .as_pin_mut() - .expect("always set") - .poll(cx) - .is_ready(); - if check_interval_expired { - // Set up the next check interval - *this.check_interval = Some(this.sleep_impl.sleep(*this.resolution)); - - // Wake so that the check interval future gets polled - // next time this poll method is called. If it never gets polled, - // then this task won't be woken to check again. - cx.waker().wake_by_ref(); - } - - let should_check = check_interval_expired || this.grace_period.is_some(); - if should_check { - let now = this.time_source.now(); - let report = this.upload_throughput.report(now); - let (violated, current_throughput) = - report.minimum_throughput_violated(this.options.minimum_throughput()); - below_minimum_throughput = violated; - if below_minimum_throughput && !this.failing_throughput.is_some() { - *this.failing_throughput = Some(current_throughput); - } else if !below_minimum_throughput { - *this.failing_throughput = None; - } - } - - // If we kicked off a grace period and are now satisfied, clear out the grace period - if !below_minimum_throughput && this.grace_period.is_some() { - tracing::debug!("upload minimum throughput recovered during grace period"); - *this.grace_period = None; - } - if below_minimum_throughput { - // Start a grace period if below minimum throughput - if this.grace_period.is_none() { - tracing::debug!( - grace_period=?this.options.grace_period(), - "upload minimum throughput below configured minimum; starting grace period" - ); - *this.grace_period = Some(this.sleep_impl.sleep(this.options.grace_period())); - } - // Check the grace period if one is already set and we're not satisfied - if let Some(grace_period) = this.grace_period.as_pin_mut() { - if grace_period.poll(cx).is_ready() { - tracing::debug!("grace period ended; timing out request"); - return Poll::Ready(Err(ConnectorError::timeout( - Error::ThroughputBelowMinimum { - expected: this.options.minimum_throughput(), - actual: this - .failing_throughput - .expect("always set if there's a grace period"), - } - .into(), - ))); - } - } - } - } - Poll::Pending - } -} - -pin_project_lite::pin_project! { - #[project = EnumProj] - pub(crate) enum MaybeUploadThroughputCheckFuture { - Direct { #[pin] future: HttpConnectorFuture }, - Checked { #[pin] future: UploadThroughputCheckFuture }, - } -} - -impl MaybeUploadThroughputCheckFuture { - pub(crate) fn new( - cfg: &mut ConfigBag, - components: &RuntimeComponents, - connector_future: HttpConnectorFuture, - ) -> Self { - if let Some(sspcfg) = cfg.load::().cloned() { - if sspcfg.is_enabled() { - let options = MinimumThroughputBodyOptions::from(sspcfg); - return Self::new_inner( - connector_future, - components.time_source(), - components.sleep_impl(), - cfg.interceptor_state().load::().cloned(), - Some(options), - ); - } - } - tracing::debug!("no minimum upload throughput checks"); - Self::new_inner(connector_future, None, None, None, None) - } - - fn new_inner( - response: HttpConnectorFuture, - time_source: Option, - sleep_impl: Option, - upload_throughput: Option, - options: Option, - ) -> Self { - match (time_source, sleep_impl, upload_throughput, options) { - (Some(time_source), Some(sleep_impl), Some(upload_throughput), Some(options)) => { - tracing::debug!(options=?options, "applying minimum upload throughput check future"); - Self::Checked { - future: UploadThroughputCheckFuture::new( - response, - time_source, - sleep_impl, - upload_throughput, - options, - ), - } - } - _ => Self::Direct { future: response }, - } - } -} - -impl Future for MaybeUploadThroughputCheckFuture { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project() { - EnumProj::Direct { future } => future.poll(cx), - EnumProj::Checked { future } => future.poll(cx), - } - } -} +// Tests are implemented per HTTP body type. diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs index a8f2fe9c4b..075ef39d63 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs @@ -3,58 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -use super::{BoxError, Error, MinimumThroughputDownloadBody}; -use crate::client::http::body::minimum_throughput::{ - throughput::ThroughputReport, Throughput, ThroughputReadingBody, -}; +use super::{BoxError, Error, MinimumThroughputBody}; use aws_smithy_async::rt::sleep::AsyncSleep; use http_body_0_4::Body; use std::future::Future; use std::pin::{pin, Pin}; use std::task::{Context, Poll}; -const ZERO_THROUGHPUT: Throughput = Throughput::new_bytes_per_second(0); - -// Helper trait for interpretting the throughput report. -trait DownloadReport { - fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput); -} -impl DownloadReport for ThroughputReport { - fn minimum_throughput_violated(self, minimum_throughput: Throughput) -> (bool, Throughput) { - let throughput = match self { - // If the report is incomplete, then we don't have enough data yet to - // decide if minimum throughput was violated. - ThroughputReport::Incomplete => { - tracing::trace!( - "not enough data to decide if minimum throughput has been violated" - ); - return (false, ZERO_THROUGHPUT); - } - // If no polling is taking place, then the user has stalled. - // In this case, we don't want to say minimum throughput was violated. - ThroughputReport::NoPolling => { - tracing::debug!( - "the user has stalled; this will not become a minimum throughput violation" - ); - return (false, ZERO_THROUGHPUT); - } - // If we're stuck in Poll::Pending, then the server has stalled. Alternatively, - // if we're transferring data, but it's too slow, then we also want to say - // that the minimum throughput has been violated. - ThroughputReport::Pending => ZERO_THROUGHPUT, - ThroughputReport::Transferred(tp) => tp, - }; - let violated = throughput < minimum_throughput; - if violated { - tracing::debug!( - "current throughput: {throughput} is below minimum: {minimum_throughput}" - ); - } - (violated, throughput) - } -} - -impl Body for MinimumThroughputDownloadBody +impl Body for MinimumThroughputBody where B: Body, { @@ -74,13 +30,12 @@ where let poll_res = match this.inner.poll_data(cx) { Poll::Ready(Some(Ok(bytes))) => { tracing::trace!("received data: {}", bytes.len()); - this.throughput_logs - .push_bytes_transferred(now, bytes.len() as u64); + this.throughput_logs.push((now, bytes.len() as u64)); Poll::Ready(Some(Ok(bytes))) } Poll::Pending => { tracing::trace!("received poll pending"); - this.throughput_logs.push_pending(now); + this.throughput_logs.push((now, 0)); Poll::Pending } // If we've read all the data or an error occurred, then return that result. @@ -91,27 +46,44 @@ where let mut sleep_fut = this .sleep_fut .take() - .unwrap_or_else(|| this.async_sleep.sleep(*this.resolution)); + .unwrap_or_else(|| this.async_sleep.sleep(this.options.check_interval())); if let Poll::Ready(()) = pin!(&mut sleep_fut).poll(cx) { tracing::trace!("sleep future triggered—triggering a wakeup"); // Whenever the sleep future expires, we replace it. - sleep_fut = this.async_sleep.sleep(*this.resolution); + sleep_fut = this.async_sleep.sleep(this.options.check_interval()); // We also schedule a wake up for current task to ensure that // it gets polled at least one more time. cx.waker().wake_by_ref(); }; this.sleep_fut.replace(sleep_fut); + let calculated_tpt = match this + .throughput_logs + .calculate_throughput(now, this.options.check_window()) + { + Some(tpt) => tpt, + None => { + tracing::trace!("calculated throughput is None!"); + return poll_res; + } + }; + tracing::trace!( + "calculated throughput {:?} (window: {:?})", + calculated_tpt, + this.options.check_window() + ); // Calculate the current throughput and emit an error if it's too low and // the grace period has elapsed. - let report = this.throughput_logs.report(now); - let (violated, current_throughput) = - report.minimum_throughput_violated(this.options.minimum_throughput()); - if violated { - if this.grace_period_fut.is_none() { - tracing::debug!("entering minimum throughput grace period"); - } + let is_below_minimum_throughput = calculated_tpt <= this.options.minimum_throughput(); + if is_below_minimum_throughput { + // Check the grace period future to see if it needs creating. + tracing::trace!( + in_grace_period = this.grace_period_fut.is_some(), + observed_throughput = ?calculated_tpt, + minimum_throughput = ?this.options.minimum_throughput(), + "below minimum throughput" + ); let mut grace_period_fut = this .grace_period_fut .take() @@ -120,16 +92,13 @@ where // The grace period has ended! return Poll::Ready(Some(Err(Box::new(Error::ThroughputBelowMinimum { expected: self.options.minimum_throughput(), - actual: current_throughput, + actual: calculated_tpt, })))); }; this.grace_period_fut.replace(grace_period_fut); } else { // Ensure we don't have an active grace period future if we're not // currently below the minimum throughput. - if this.grace_period_fut.is_some() { - tracing::debug!("throughput recovered; exiting grace period"); - } let _ = this.grace_period_fut.take(); } @@ -143,63 +112,290 @@ where let this = self.as_mut().project(); this.inner.poll_trailers(cx) } +} + +// These tests use `hyper::body::Body::wrap_stream` +#[cfg(all(test, feature = "connector-hyper-0-14-x", feature = "test-util"))] +mod test { + use super::{super::Throughput, Error, MinimumThroughputBody}; + use crate::client::http::body::minimum_throughput::options::MinimumThroughputBodyOptions; + use crate::test_util::capture_test_logs::capture_test_logs; + use aws_smithy_async::rt::sleep::AsyncSleep; + use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep, ManualTimeSource}; + use aws_smithy_types::body::SdkBody; + use aws_smithy_types::byte_stream::{AggregatedBytes, ByteStream}; + use aws_smithy_types::error::display::DisplayErrorContext; + use bytes::{BufMut, Bytes, BytesMut}; + use http::HeaderMap; + use http_body_0_4::Body; + use once_cell::sync::Lazy; + use pretty_assertions::assert_eq; + use std::convert::Infallible; + use std::error::Error as StdError; + use std::future::{poll_fn, Future}; + use std::pin::{pin, Pin}; + use std::task::{Context, Poll}; + use std::time::{Duration, UNIX_EPOCH}; + + struct NeverBody; + + impl Body for NeverBody { + type Data = Bytes; + type Error = Box<(dyn StdError + Send + Sync + 'static)>; + + fn poll_data( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + Poll::Pending + } - fn size_hint(&self) -> http_body_0_4::SizeHint { - self.inner.size_hint() + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + unreachable!("body can't be read, so this won't be called") + } } - fn is_end_stream(&self) -> bool { - self.inner.is_end_stream() + #[tokio::test()] + async fn test_self_waking() { + let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH); + let mut body = MinimumThroughputBody::new( + time_source.clone(), + async_sleep.clone(), + NeverBody, + Default::default(), + ); + time_source.advance(Duration::from_secs(1)); + let actual_err = body.data().await.expect("next chunk exists").unwrap_err(); + let expected_err = Error::ThroughputBelowMinimum { + expected: (1, Duration::from_secs(1)).into(), + actual: (0, Duration::from_secs(1)).into(), + }; + + assert_eq!(expected_err.to_string(), actual_err.to_string()); } -} -impl Body for ThroughputReadingBody -where - B: Body, -{ - type Data = bytes::Bytes; - type Error = BoxError; + fn create_test_stream( + async_sleep: impl AsyncSleep + Clone, + ) -> impl futures_util::Stream> { + futures_util::stream::unfold(1, move |state| { + let async_sleep = async_sleep.clone(); + async move { + if state > 255 { + None + } else { + async_sleep.sleep(Duration::from_secs(1)).await; + Some(( + Result::<_, Infallible>::Ok(Bytes::from_static(b"00000000")), + state + 1, + )) + } + } + }) + } - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - // this code is called quite frequently in production—one every millisecond or so when downloading - // a stream. However, SystemTime::now is on the order of nanoseconds - let now = self.time_source.now(); - // Attempt to read the data from the inner body, then update the - // throughput logs. - let this = self.as_mut().project(); - match this.inner.poll_data(cx) { - Poll::Ready(Some(Ok(bytes))) => { - tracing::trace!("received data: {}", bytes.len()); - this.throughput - .push_bytes_transferred(now, bytes.len() as u64); - Poll::Ready(Some(Ok(bytes))) + static EXPECTED_BYTES: Lazy> = + Lazy::new(|| (1..=255).flat_map(|_| b"00000000").copied().collect()); + + fn eight_byte_per_second_stream_with_minimum_throughput_timeout( + minimum_throughput: Throughput, + ) -> ( + impl Future>, + ManualTimeSource, + InstantSleep, + ) { + let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH); + let time_clone = time_source.clone(); + + // Will send ~8 bytes per second. + let stream = create_test_stream(async_sleep.clone()); + let body = ByteStream::new(SdkBody::from_body_0_4(hyper_0_14::body::Body::wrap_stream( + stream, + ))); + let body = body.map(move |body| { + let time_source = time_clone.clone(); + // We don't want to log these sleeps because it would duplicate + // the `sleep` calls being logged by the MTB + let async_sleep = InstantSleep::unlogged(); + SdkBody::from_body_0_4(MinimumThroughputBody::new( + time_source, + async_sleep, + body, + MinimumThroughputBodyOptions::builder() + .minimum_throughput(minimum_throughput) + .build(), + )) + }); + + (body.collect(), time_source, async_sleep) + } + + async fn expect_error(minimum_throughput: Throughput) { + let (res, ..) = + eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput); + let expected_err = Error::ThroughputBelowMinimum { + expected: minimum_throughput, + actual: Throughput::new(8, Duration::from_secs(1)), + }; + match res.await { + Ok(_) => { + panic!( + "response succeeded instead of returning the expected error '{expected_err}'" + ) } - Poll::Pending => { - tracing::trace!("received poll pending"); - this.throughput.push_pending(now); - Poll::Pending + Err(actual_err) => { + assert_eq!( + expected_err.to_string(), + // We need to source this so that we don't get the streaming error it's wrapped in. + actual_err.source().unwrap().to_string() + ); } - // If we've read all the data or an error occurred, then return that result. - res => res, } } - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let this = self.as_mut().project(); - this.inner.poll_trailers(cx) + #[tokio::test] + async fn test_throughput_timeout_less_than() { + let minimum_throughput = Throughput::new_bytes_per_second(9); + expect_error(minimum_throughput).await; } - fn size_hint(&self) -> http_body_0_4::SizeHint { - self.inner.size_hint() + async fn expect_success(minimum_throughput: Throughput) { + let (res, time_source, async_sleep) = + eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput); + match res.await { + Ok(res) => { + assert_eq!(255.0, time_source.seconds_since_unix_epoch()); + assert_eq!(Duration::from_secs(255), async_sleep.total_duration()); + assert_eq!(*EXPECTED_BYTES, res.to_vec()); + } + Err(err) => panic!("{}", DisplayErrorContext(err.source().unwrap())), + } } - fn is_end_stream(&self) -> bool { - self.inner.is_end_stream() + #[tokio::test] + async fn test_throughput_timeout_equal_to() { + let (_guard, _) = capture_test_logs(); + // a tiny bit less. To capture 0-throughput properly, we need to allow 0 to be 0 + let minimum_throughput = Throughput::new(31, Duration::from_secs(4)); + expect_success(minimum_throughput).await; + } + + #[tokio::test] + async fn test_throughput_timeout_greater_than() { + let minimum_throughput = Throughput::new(20, Duration::from_secs(3)); + expect_success(minimum_throughput).await; + } + + // A multiplier for the sine wave amplitude; Chosen arbitrarily. + const BYTE_COUNT_UPPER_LIMIT: u64 = 1000; + + /// emits 1000B/S for 5 seconds then suddenly stops + fn sudden_stop( + async_sleep: impl AsyncSleep + Clone, + ) -> impl futures_util::Stream> { + let sleep_dur = Duration::from_millis(50); + fastrand::seed(0); + futures_util::stream::unfold(1, move |i| { + let async_sleep = async_sleep.clone(); + async move { + let number_seconds = (i * sleep_dur).as_secs_f64(); + async_sleep.sleep(sleep_dur).await; + if number_seconds > 5.0 { + Some((Result::::Ok(Bytes::new()), i + 1)) + } else { + let mut bytes = BytesMut::new(); + let bytes_per_segment = + (BYTE_COUNT_UPPER_LIMIT as f64) * sleep_dur.as_secs_f64(); + for _ in 0..bytes_per_segment as usize { + bytes.put_u8(0) + } + + Some((Result::::Ok(bytes.into()), i + 1)) + } + } + }) + } + + #[tokio::test] + async fn test_stalled_stream_detection() { + test_suddenly_stopping_stream(0, Duration::from_secs(6)).await + } + + #[tokio::test] + async fn test_slow_stream_detection() { + test_suddenly_stopping_stream(BYTE_COUNT_UPPER_LIMIT / 2, Duration::from_secs_f64(5.50)) + .await + } + + #[tokio::test] + async fn test_check_interval() { + let (_guard, _) = capture_test_logs(); + let (ts, sleep) = instant_time_and_sleep(UNIX_EPOCH); + let mut body = MinimumThroughputBody::new( + ts, + sleep.clone(), + NeverBody, + MinimumThroughputBodyOptions::builder() + .check_interval(Duration::from_millis(1234)) + .grace_period(Duration::from_millis(456)) + .build(), + ); + let mut body = pin!(body); + let _ = poll_fn(|cx| body.as_mut().poll_data(cx)).await; + assert_eq!( + sleep.logs(), + vec![ + // sleep, by second sleep we know we have no data, then the grace period + Duration::from_millis(1234), + Duration::from_millis(1234), + Duration::from_millis(456) + ] + ); + } + + async fn test_suddenly_stopping_stream(throughput_limit: u64, time_until_timeout: Duration) { + let (_guard, _) = capture_test_logs(); + let options = MinimumThroughputBodyOptions::builder() + // Minimum throughput per second will be approx. half of the BYTE_COUNT_UPPER_LIMIT. + .minimum_throughput(Throughput::new_bytes_per_second(throughput_limit)) + .build(); + let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH); + let time_clone = time_source.clone(); + + let stream = sudden_stop(async_sleep.clone()); + let body = ByteStream::new(SdkBody::from_body_0_4(hyper_0_14::body::Body::wrap_stream( + stream, + ))); + let res = body + .map(move |body| { + let time_source = time_clone.clone(); + // We don't want to log these sleeps because it would duplicate + // the `sleep` calls being logged by the MTB + let async_sleep = InstantSleep::unlogged(); + SdkBody::from_body_0_4(MinimumThroughputBody::new( + time_source, + async_sleep, + body, + options.clone(), + )) + }) + .collect(); + + match res.await { + Ok(_res) => { + panic!("stream should have timed out"); + } + Err(err) => { + dbg!(err); + assert_eq!( + async_sleep.total_duration(), + time_until_timeout, + "With throughput limit {:?} expected timeout after {:?} (stream starts sending 0's at 5 seconds.", + throughput_limit, time_until_timeout + ); + } + } } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs index 113461a31e..4c8fc1177b 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs @@ -12,7 +12,6 @@ use std::time::Duration; pub struct MinimumThroughputBodyOptions { /// The minimum throughput that is acceptable. minimum_throughput: Throughput, - /// The 'grace period' after which the minimum throughput will be enforced. /// /// If this is set to 0, the minimum throughput will be enforced immediately. @@ -25,6 +24,9 @@ pub struct MinimumThroughputBodyOptions { /// stream-startup. grace_period: Duration, + /// The interval at which the throughput is checked. + check_interval: Duration, + /// The period of time to consider when computing the throughput /// /// This SHOULD be longer than the check interval, or stuck-streams may evade detection. @@ -42,6 +44,7 @@ impl MinimumThroughputBodyOptions { MinimumThroughputBodyOptionsBuilder::new() .minimum_throughput(self.minimum_throughput) .grace_period(self.grace_period) + .check_interval(self.check_interval) } /// The throughput check grace period. @@ -62,10 +65,12 @@ impl MinimumThroughputBodyOptions { self.check_window } - /// Not used. Always returns `Duration::from_millis(500)`. - #[deprecated(note = "No longer used. Always returns Duration::from_millis(500)")] + /// The rate at which the throughput is checked. + /// + /// The actual rate throughput is checked may be higher than this value, + /// but it will never be lower. pub fn check_interval(&self) -> Duration { - Duration::from_millis(500) + self.check_interval } } @@ -74,6 +79,7 @@ impl Default for MinimumThroughputBodyOptions { Self { minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT, grace_period: DEFAULT_GRACE_PERIOD, + check_interval: DEFAULT_CHECK_INTERVAL, check_window: DEFAULT_CHECK_WINDOW, } } @@ -83,10 +89,11 @@ impl Default for MinimumThroughputBodyOptions { #[derive(Debug, Default, Clone)] pub struct MinimumThroughputBodyOptionsBuilder { minimum_throughput: Option, - check_window: Option, + check_interval: Option, grace_period: Option, } +const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_millis(500); const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(0); const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput { bytes_read: 1, @@ -129,30 +136,19 @@ impl MinimumThroughputBodyOptionsBuilder { self } - /// No longer used. The check interval is now based on the check window (not currently configurable). - #[deprecated( - note = "No longer used. The check interval is now based on the check window (not currently configurable). Open an issue if you need to configure the check window." - )] - pub fn check_interval(self, _check_interval: Duration) -> Self { - self - } - - /// No longer used. The check interval is now based on the check window (not currently configurable). - #[deprecated( - note = "No longer used. The check interval is now based on the check window (not currently configurable). Open an issue if you need to configure the check window." - )] - pub fn set_check_interval(&mut self, _check_interval: Option) -> &mut Self { + /// Set the rate at which throughput is checked. + /// + /// Defaults to 1 second. + pub fn check_interval(mut self, check_interval: Duration) -> Self { + self.set_check_interval(Some(check_interval)); self } - #[allow(unused)] - pub(crate) fn check_window(mut self, check_window: Duration) -> Self { - self.set_check_window(Some(check_window)); - self - } - #[allow(unused)] - pub(crate) fn set_check_window(&mut self, check_window: Option) -> &mut Self { - self.check_window = check_window; + /// Set the rate at which throughput is checked. + /// + /// Defaults to 1 second. + pub fn set_check_interval(&mut self, check_interval: Option) -> &mut Self { + self.check_interval = check_interval; self } @@ -165,7 +161,8 @@ impl MinimumThroughputBodyOptionsBuilder { minimum_throughput: self .minimum_throughput .unwrap_or(DEFAULT_MINIMUM_THROUGHPUT), - check_window: self.check_window.unwrap_or(DEFAULT_CHECK_WINDOW), + check_interval: self.check_interval.unwrap_or(DEFAULT_CHECK_INTERVAL), + check_window: DEFAULT_CHECK_WINDOW, } } } @@ -175,6 +172,7 @@ impl From for MinimumThroughputBodyOptions { MinimumThroughputBodyOptions { grace_period: value.grace_period(), minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT, + check_interval: DEFAULT_CHECK_INTERVAL, check_window: DEFAULT_CHECK_WINDOW, } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index 57ea3318e7..e2a9b294e6 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -3,12 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ +use std::collections::VecDeque; use std::fmt; use std::time::{Duration, SystemTime}; /// Throughput representation for use when configuring [`super::MinimumThroughputBody`] #[derive(Debug, Clone, Copy)] -#[cfg_attr(test, derive(Eq))] pub struct Throughput { pub(super) bytes_read: u64, pub(super) per_time_elapsed: Duration, @@ -29,7 +29,7 @@ impl Throughput { } /// Create a new throughput in bytes per second. - pub const fn new_bytes_per_second(bytes: u64) -> Self { + pub fn new_bytes_per_second(bytes: u64) -> Self { Self { bytes_read: bytes, per_time_elapsed: Duration::from_secs(1), @@ -37,7 +37,7 @@ impl Throughput { } /// Create a new throughput in kilobytes per second. - pub const fn new_kilobytes_per_second(kilobytes: u64) -> Self { + pub fn new_kilobytes_per_second(kilobytes: u64) -> Self { Self { bytes_read: kilobytes * 1000, per_time_elapsed: Duration::from_secs(1), @@ -45,7 +45,7 @@ impl Throughput { } /// Create a new throughput in megabytes per second. - pub const fn new_megabytes_per_second(megabytes: u64) -> Self { + pub fn new_megabytes_per_second(megabytes: u64) -> Self { Self { bytes_read: megabytes * 1000 * 1000, per_time_elapsed: Duration::from_secs(1), @@ -97,288 +97,90 @@ impl From<(u64, Duration)> for Throughput { } } -/// Overall label for a given bin. -#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] -enum BinLabel { - // IMPORTANT: The order of these enums matters since it represents their priority: - // Pending > TransferredBytes > NoPolling > Empty - // - /// There is no data in this bin. - Empty, - - /// No polling took place during this bin. - NoPolling, - - /// This many bytes were transferred during this bin. - TransferredBytes, - - /// The user/remote was not providing/consuming data fast enough during this bin. - /// - /// The number is the number of bytes transferred, if this replaced TransferredBytes. - Pending, -} - -/// Represents a bin (or a cell) in a linear grid that represents a small chunk of time. -#[derive(Copy, Clone, Debug)] -struct Bin { - label: BinLabel, - bytes: u64, -} - -impl Bin { - const fn new(label: BinLabel, bytes: u64) -> Self { - Self { label, bytes } - } - const fn empty() -> Self { - Self::new(BinLabel::Empty, 0) - } - - fn is_empty(&self) -> bool { - matches!(self.label, BinLabel::Empty) - } - - fn merge(&mut self, other: Bin) -> &mut Self { - // Assign values based on this priority order (highest priority higher up): - // 1. Pending - // 2. TransferredBytes - // 3. NoPolling - // 4. Empty - self.label = if other.label > self.label { - other.label - } else { - self.label - }; - self.bytes += other.bytes; - self - } - - /// Number of bytes transferred during this bin - fn bytes(&self) -> u64 { - self.bytes - } -} - -#[derive(Copy, Clone, Debug, Default)] -struct BinCounts { - /// Number of bins with no data. - empty: usize, - /// Number of "no polling" bins. - no_polling: usize, - /// Number of "bytes transferred" bins. - transferred: usize, - /// Number of "pending" bins. - pending: usize, -} - -/// Underlying stack-allocated linear grid buffer for tracking -/// throughput events for [`ThroughputLogs`]. -#[derive(Copy, Clone, Debug)] -struct LogBuffer { - entries: [Bin; N], - // The length only needs to exist so that the `fill_gaps` function - // can differentiate between `Empty` due to there not having been enough - // time to establish a full buffer worth of data vs. `Empty` due to a - // polling gap. Once the length reaches N, it will never change again. - length: usize, -} -impl LogBuffer { - fn new() -> Self { - Self { - entries: [Bin::empty(); N], - length: 0, - } - } - - /// Mutably returns the tail of the buffer. - /// - /// ## Panics - /// - /// The buffer MUST have at least one bin in it before this is called. - fn tail_mut(&mut self) -> &mut Bin { - debug_assert!(self.length > 0); - &mut self.entries[self.length - 1] - } - - /// Pushes a bin into the buffer. If the buffer is already full, - /// then this will rotate the entire buffer to the left. - fn push(&mut self, bin: Bin) { - if self.filled() { - self.entries.rotate_left(1); - self.entries[N - 1] = bin; - } else { - self.entries[self.length] = bin; - self.length += 1; - } - } - - /// Returns the total number of bytes transferred within the time window. - fn bytes_transferred(&self) -> u64 { - self.entries.iter().take(self.length).map(Bin::bytes).sum() - } - - #[inline] - fn filled(&self) -> bool { - self.length == N - } - - /// Fills in missing NoData entries. - /// - /// We want NoData entries to represent when a future hasn't been polled. - /// Since the future is in charge of logging in the first place, the only - /// way we can know about these is by examining gaps in time. - fn fill_gaps(&mut self) { - for entry in self.entries.iter_mut().take(self.length) { - if entry.is_empty() { - *entry = Bin::new(BinLabel::NoPolling, 0); - } - } - } - - /// Returns the counts of each bin type in the buffer. - fn counts(&self) -> BinCounts { - let mut counts = BinCounts::default(); - for entry in &self.entries { - match entry.label { - BinLabel::Empty => counts.empty += 1, - BinLabel::NoPolling => counts.no_polling += 1, - BinLabel::TransferredBytes => counts.transferred += 1, - BinLabel::Pending => counts.pending += 1, - } - } - counts - } -} - -/// Report/summary of all the events in a time window. -#[cfg_attr(test, derive(Debug, Eq, PartialEq))] -pub(crate) enum ThroughputReport { - /// Not enough data to draw any conclusions. This happens early in a request/response. - Incomplete, - /// The stream hasn't been polled for most of this time window. - NoPolling, - /// The stream has been waiting for most of the time window. - Pending, - /// The stream transferred this amount of throughput during the time window. - Transferred(Throughput), -} - -const BIN_COUNT: usize = 10; - -/// Log of throughput in a request or response stream. -/// -/// Used to determine if a configured minimum throughput is being met or not -/// so that a request or response stream can be timed out in the event of a -/// stall. -/// -/// Request/response streams push data transfer or pending events to this log -/// based on what's going on in their poll functions. The log tracks three kinds -/// of events despite only receiving two: the third is "no polling". The poll -/// functions cannot know when they're not being polled, so the log examines gaps -/// in the event history to know when no polling took place. -/// -/// The event logging is simplified down to a linear grid consisting of 10 "bins", -/// with each bin representing 1/10th the total time window. When an event is pushed, -/// it is either merged into the current tail bin, or all the bins are rotated -/// left to create a new empty tail bin, and then it is merged into that one. -#[derive(Clone, Debug)] +#[derive(Clone)] pub(super) struct ThroughputLogs { - resolution: Duration, - current_tail: SystemTime, - buffer: LogBuffer, + max_length: usize, + inner: VecDeque<(SystemTime, u64)>, + bytes_processed: u64, } impl ThroughputLogs { - /// Creates a new log starting at `now` with the given `time_window`. - /// - /// Note: the `time_window` gets divided by 10 to create smaller sub-windows - /// to track throughput. The time window should be configured to be large enough - /// so that these sub-windows aren't too small for network-based events. - /// A time window of 10ms probably won't work, but 500ms might. The default - /// is one second. - pub(super) fn new(time_window: Duration, now: SystemTime) -> Self { - assert!(!time_window.is_zero()); - let resolution = time_window.div_f64(BIN_COUNT as f64); + pub(super) fn new(max_length: usize) -> Self { Self { - resolution, - current_tail: now, - buffer: LogBuffer::new(), + inner: VecDeque::with_capacity(max_length), + max_length, + bytes_processed: 0, } } - /// Returns the resolution at which events are logged at. - /// - /// The resolution is the number of bins in the time window. - pub(super) fn resolution(&self) -> Duration { - self.resolution - } - - /// Pushes a "pending" event. - /// - /// Pending indicates the streaming future is waiting for something. - /// In an upload, it is waiting for data from the user, and in a download, - /// it is waiting for data from the server. - pub(super) fn push_pending(&mut self, time: SystemTime) { - self.push(time, Bin::new(BinLabel::Pending, 0)); - } - - /// Pushes a data transferred event. - /// - /// Indicates that this number of bytes were transferred at this time. - pub(super) fn push_bytes_transferred(&mut self, time: SystemTime, bytes: u64) { - self.push(time, Bin::new(BinLabel::TransferredBytes, bytes)); - } - - fn push(&mut self, now: SystemTime, value: Bin) { - self.catch_up(now); - self.buffer.tail_mut().merge(value); - self.buffer.fill_gaps(); - } - - /// Pushes empty bins until `current_tail` is caught up to `now`. - fn catch_up(&mut self, now: SystemTime) { - while now >= self.current_tail { - self.current_tail += self.resolution; - self.buffer.push(Bin::empty()); + pub(super) fn push(&mut self, throughput: (SystemTime, u64)) { + // When the number of logs exceeds the max length, toss the oldest log. + if self.inner.len() == self.max_length { + self.bytes_processed -= self.inner.pop_front().map(|(_, sz)| sz).unwrap_or_default(); } - assert!(self.current_tail >= now); - } - /// Generates an overall report of the time window. - pub(super) fn report(&mut self, now: SystemTime) -> ThroughputReport { - self.catch_up(now); - self.buffer.fill_gaps(); - - let BinCounts { - empty, - no_polling, - transferred, - pending, - } = self.buffer.counts(); - - // If there are any empty cells at all, then we haven't been tracking - // long enough to make any judgements about the stream's progress. - if empty > 0 { - return ThroughputReport::Incomplete; + debug_assert!(self.inner.capacity() > self.inner.len()); + self.bytes_processed += throughput.1; + self.inner.push_back(throughput); + } + + fn buffer_full(&self) -> bool { + self.inner.len() == self.max_length + } + + pub(super) fn calculate_throughput( + &self, + now: SystemTime, + time_window: Duration, + ) -> Option { + // There are a lot of pathological cases that are 0 throughput. These cases largely shouldn't + // happen, because the check interval MUST be less than the check window + let total_length = self + .inner + .iter() + .last()? + .0 + .duration_since(self.inner.front()?.0) + .ok()?; + // during a "healthy" request we'll only have a few milliseconds of logs (shorter than the check window) + if total_length < time_window { + // if we haven't hit our requested time window & the buffer still isn't full, then + // return `None` — this is the "startup grace period" + return if !self.buffer_full() { + None + } else { + // Otherwise, if the entire buffer fits in the timewindow, we can the shortcut to + // avoid recomputing all the data + Some(Throughput { + bytes_read: self.bytes_processed, + per_time_elapsed: total_length, + }) + }; } + let minimum_ts = now - time_window; + let first_item = self.inner.iter().find(|(ts, _)| *ts >= minimum_ts)?.0; - let bytes = self.buffer.bytes_transferred(); - let time = self.resolution * (BIN_COUNT - empty) as u32; - let throughput = Throughput::new(bytes, time); + let time_elapsed = now.duration_since(first_item).unwrap_or_default(); - let half = BIN_COUNT / 2; - match (transferred > 0, no_polling >= half, pending >= half) { - (true, _, _) => ThroughputReport::Transferred(throughput), - (_, true, _) => ThroughputReport::NoPolling, - (_, _, true) => ThroughputReport::Pending, - _ => ThroughputReport::Incomplete, - } + let total_bytes_logged = self + .inner + .iter() + .rev() + .take_while(|(ts, _)| *ts > minimum_ts) + .map(|t| t.1) + .sum::(); + + Some(Throughput { + bytes_read: total_bytes_logged, + per_time_elapsed: time_elapsed, + }) } } #[cfg(test)] mod test { - use super::*; - use std::time::Duration; + use super::{Throughput, ThroughputLogs}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[test] fn test_throughput_eq() { @@ -390,146 +192,92 @@ mod test { assert_eq!(t2, t3); } - #[test] - fn incomplete_no_entries() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - let report = logs.report(start); - assert_eq!(ThroughputReport::Incomplete, report); - } - - #[test] - fn incomplete_with_entries() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - logs.push_pending(start); + fn build_throughput_log( + length: u32, + tick_duration: Duration, + rate: u64, + ) -> (ThroughputLogs, SystemTime) { + let mut throughput_logs = ThroughputLogs::new(length as usize); + for i in 1..=length { + throughput_logs.push((UNIX_EPOCH + (tick_duration * i), rate)); + } - let report = logs.report(start + Duration::from_millis(300)); - assert_eq!(ThroughputReport::Incomplete, report); + assert_eq!(length as usize, throughput_logs.inner.len()); + (throughput_logs, UNIX_EPOCH + (tick_duration * length)) } - #[test] - fn incomplete_with_transferred() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - logs.push_pending(start); - logs.push_bytes_transferred(start + Duration::from_millis(100), 10); - - let report = logs.report(start + Duration::from_millis(300)); - assert_eq!(ThroughputReport::Incomplete, report); + const EPSILON: f64 = 0.001; + macro_rules! assert_delta { + ($x:expr, $y:expr, $d:expr) => { + if !(($x as f64) - $y < $d || $y - ($x as f64) < $d) { + panic!(); + } + }; } #[test] - fn push_pending_at_the_beginning_of_each_tick() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - - let mut now = start; - for i in 1..=BIN_COUNT { - logs.push_pending(now); - now += logs.resolution(); - - assert_eq!(i, logs.buffer.counts().pending); + fn test_throughput_log_calculate_throughput_1() { + let (throughput_logs, now) = build_throughput_log(1000, Duration::from_secs(1), 1); + + for dur in [10, 100, 100] { + let throughput = throughput_logs + .calculate_throughput(now, Duration::from_secs(dur)) + .unwrap(); + assert_eq!(1.0, throughput.bytes_per_second()); } - - let report = dbg!(&mut logs).report(now); - assert_eq!(ThroughputReport::Pending, report); + let throughput = throughput_logs + .calculate_throughput(now, Duration::from_secs_f64(101.5)) + .unwrap(); + assert_delta!(1, throughput.bytes_per_second(), EPSILON); } #[test] - fn push_pending_at_the_end_of_each_tick() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - - let mut now = start; - for i in 1..BIN_COUNT { - now += logs.resolution(); - logs.push_pending(now); - - assert_eq!(i, dbg!(&logs).buffer.counts().pending); - assert_eq!(0, logs.buffer.counts().transferred); - assert_eq!(1, logs.buffer.counts().no_polling); - } - // This should replace the initial "no polling" bin - now += logs.resolution(); - logs.push_pending(now); - assert_eq!(0, logs.buffer.counts().no_polling); + fn test_throughput_log_calculate_throughput_2() { + let (throughput_logs, now) = build_throughput_log(1000, Duration::from_secs(5), 5); - let report = dbg!(&mut logs).report(now); - assert_eq!(ThroughputReport::Pending, report); + let throughput = throughput_logs + .calculate_throughput(now, Duration::from_secs(1000)) + .unwrap(); + assert_eq!(1.0, throughput.bytes_per_second()); } #[test] - fn push_transferred_at_the_beginning_of_each_tick() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - - let mut now = start; - for i in 1..=BIN_COUNT { - logs.push_bytes_transferred(now, 10); - if i != BIN_COUNT { - now += logs.resolution(); - } + fn test_throughput_log_calculate_throughput_3() { + let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(200), 1024); - assert_eq!(i, logs.buffer.counts().transferred); - assert_eq!(0, logs.buffer.counts().pending); - assert_eq!(0, logs.buffer.counts().no_polling); - } - - let report = dbg!(&mut logs).report(now); - assert_eq!( - ThroughputReport::Transferred(Throughput::new(100, Duration::from_secs(1))), - report - ); + let throughput = throughput_logs + .calculate_throughput(now, Duration::from_secs(5)) + .unwrap(); + let expected_throughput = 1024.0 * 5.0; + assert_eq!(expected_throughput, throughput.bytes_per_second()); } #[test] - fn no_polling() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - let report = logs.report(start + Duration::from_secs(2)); - assert_eq!(ThroughputReport::NoPolling, report); - } + fn test_throughput_log_calculate_throughput_4() { + let (throughput_logs, now) = build_throughput_log(1000, Duration::from_millis(100), 12); - // Transferred bytes MUST take priority over pending - #[test] - fn mixed_bag_mostly_pending() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - - logs.push_bytes_transferred(start + Duration::from_millis(50), 10); - logs.push_pending(start + Duration::from_millis(150)); - logs.push_pending(start + Duration::from_millis(250)); - logs.push_bytes_transferred(start + Duration::from_millis(350), 10); - logs.push_pending(start + Duration::from_millis(450)); - // skip 550 - logs.push_pending(start + Duration::from_millis(650)); - logs.push_pending(start + Duration::from_millis(750)); - logs.push_pending(start + Duration::from_millis(850)); - - let report = logs.report(start + Duration::from_millis(999)); - assert_eq!( - ThroughputReport::Transferred(Throughput::new_bytes_per_second(20)), - report - ); + let throughput = throughput_logs + .calculate_throughput(now, Duration::from_secs(1)) + .unwrap(); + let expected_throughput = 12.0 * 10.0; + + assert_eq!(expected_throughput, throughput.bytes_per_second()); } #[test] - fn mixed_bag_mostly_pending_no_transferred() { - let start = SystemTime::UNIX_EPOCH; - let mut logs = ThroughputLogs::new(Duration::from_secs(1), start); - - logs.push_pending(start + Duration::from_millis(50)); - logs.push_pending(start + Duration::from_millis(150)); - logs.push_pending(start + Duration::from_millis(250)); - // skip 350 - logs.push_pending(start + Duration::from_millis(450)); - // skip 550 - logs.push_pending(start + Duration::from_millis(650)); - logs.push_pending(start + Duration::from_millis(750)); - logs.push_pending(start + Duration::from_millis(850)); - - let report = logs.report(start + Duration::from_millis(999)); - assert_eq!(ThroughputReport::Pending, report); + fn test_throughput_followed_by_0() { + let tick = Duration::from_millis(100); + let (mut throughput_logs, now) = build_throughput_log(1000, tick, 12); + let throughput = throughput_logs + .calculate_throughput(now, Duration::from_secs(1)) + .unwrap(); + let expected_throughput = 12.0 * 10.0; + + assert_eq!(expected_throughput, throughput.bytes_per_second()); + throughput_logs.push((now + tick, 0)); + let throughput = throughput_logs + .calculate_throughput(now + tick, Duration::from_secs(1)) + .unwrap(); + assert_eq!(108.0, throughput.bytes_per_second()); } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index 112fbe85ba..f8bbc2c05c 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -5,12 +5,9 @@ use self::auth::orchestrate_auth; use crate::client::interceptors::Interceptors; +use crate::client::orchestrator::endpoints::orchestrate_endpoint; use crate::client::orchestrator::http::{log_response_body, read_body}; use crate::client::timeout::{MaybeTimeout, MaybeTimeoutConfig, TimeoutKind}; -use crate::client::{ - http::body::minimum_throughput::MaybeUploadThroughputCheckFuture, - orchestrator::endpoints::orchestrate_endpoint, -}; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_runtime_api::box_error::BoxError; use aws_smithy_runtime_api::client::http::{HttpClient, HttpConnector, HttpConnectorSettings}; @@ -388,12 +385,7 @@ async fn try_attempt( builder.build() }; let connector = http_client.http_connector(&settings, runtime_components); - let response_future = MaybeUploadThroughputCheckFuture::new( - cfg, - runtime_components, - connector.call(request), - ); - response_future.await.map_err(OrchestratorError::connector) + connector.call(request).await.map_err(OrchestratorError::connector) }); trace!(response = ?response, "received response from service"); ctx.set_response(response); diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs index e761cf601f..bd875c72e6 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs @@ -12,6 +12,7 @@ use crate::client::orchestrator::endpoints::StaticUriEndpointResolver; use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy}; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_async::time::TimeSource; +use aws_smithy_runtime_api::box_error::BoxError; use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver; use aws_smithy_runtime_api::client::auth::{ AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver, @@ -34,9 +35,6 @@ use aws_smithy_runtime_api::client::ser_de::{ DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer, }; use aws_smithy_runtime_api::shared::IntoShared; -use aws_smithy_runtime_api::{ - box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig, -}; use aws_smithy_types::config_bag::{ConfigBag, Layer}; use aws_smithy_types::retry::RetryConfig; use aws_smithy_types::timeout::TimeoutConfig; @@ -295,15 +293,6 @@ impl OperationBuilder { self } - /// Configures stalled stream protection with the given config. - pub fn stalled_stream_protection( - mut self, - stalled_stream_protection: StalledStreamProtectionConfig, - ) -> Self { - self.config.store_put(stalled_stream_protection); - self - } - /// Configures the serializer for the builder. pub fn serializer( mut self, @@ -350,28 +339,6 @@ impl OperationBuilder { } } - /// Configures the a deserializer implementation for the builder. - pub fn deserializer_impl( - mut self, - deserializer: impl DeserializeResponse + Send + Sync + 'static, - ) -> OperationBuilder - where - O2: fmt::Debug + Send + Sync + 'static, - E2: std::error::Error + fmt::Debug + Send + Sync + 'static, - { - let deserializer: SharedResponseDeserializer = deserializer.into_shared(); - self.config.store_put(deserializer); - - OperationBuilder { - service_name: self.service_name, - operation_name: self.operation_name, - config: self.config, - runtime_components: self.runtime_components, - runtime_plugins: self.runtime_plugins, - _phantom: Default::default(), - } - } - /// Creates an `Operation` from the builder. pub fn build(self) -> Operation { let service_name = self.service_name.expect("service_name required"); diff --git a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs index 83cfb64752..3e07b3f0b8 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs @@ -3,10 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::client::http::body::minimum_throughput::{ - options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody, - UploadThroughput, -}; +use crate::client::http::body::minimum_throughput::MinimumThroughputBody; use aws_smithy_async::rt::sleep::SharedAsyncSleep; use aws_smithy_async::time::SharedTimeSource; use aws_smithy_runtime_api::box_error::BoxError; @@ -21,16 +18,14 @@ use aws_smithy_types::config_bag::ConfigBag; use std::mem; /// Adds stalled stream protection when sending requests and/or receiving responses. -#[derive(Debug, Default)] -#[non_exhaustive] -pub struct StalledStreamProtectionInterceptor; +#[derive(Debug)] +pub struct StalledStreamProtectionInterceptor { + enable_for_request_body: bool, + enable_for_response_body: bool, +} /// Stalled stream protection can be enable for request bodies, response bodies, /// or both. -#[deprecated( - since = "1.2.0", - note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag." -)] pub enum StalledStreamProtectionInterceptorKind { /// Enable stalled stream protection for request bodies. RequestBody, @@ -42,13 +37,18 @@ pub enum StalledStreamProtectionInterceptorKind { impl StalledStreamProtectionInterceptor { /// Create a new stalled stream protection interceptor. - #[deprecated( - since = "1.2.0", - note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default." - )] - #[allow(deprecated)] - pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self { - Default::default() + pub fn new(kind: StalledStreamProtectionInterceptorKind) -> Self { + use StalledStreamProtectionInterceptorKind::*; + let (enable_for_request_body, enable_for_response_body) = match kind { + RequestBody => (true, false), + ResponseBody => (false, true), + RequestAndResponseBody => (true, true), + }; + + Self { + enable_for_request_body, + enable_for_response_body, + } } } @@ -63,26 +63,19 @@ impl Intercept for StalledStreamProtectionInterceptor { runtime_components: &RuntimeComponents, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - if let Some(sspcfg) = cfg.load::().cloned() { - if sspcfg.upload_enabled() { - let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; - let now = time_source.now(); - - let options: MinimumThroughputBodyOptions = sspcfg.into(); - let throughput = UploadThroughput::new(options.check_window(), now); - cfg.interceptor_state().store_put(throughput.clone()); - - tracing::trace!("adding stalled stream protection to request body"); - let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken()); - let it = it.map_preserve_contents(move |body| { - let time_source = time_source.clone(); - SdkBody::from_body_0_4(ThroughputReadingBody::new( + if self.enable_for_request_body { + if let Some(cfg) = cfg.load::() { + if cfg.is_enabled() { + let (async_sleep, time_source) = + get_runtime_component_deps(runtime_components)?; + tracing::trace!("adding stalled stream protection to request body"); + add_stalled_stream_protection_to_body( + context.request_mut().body_mut(), + cfg, + async_sleep, time_source, - throughput.clone(), - body, - )) - }); - let _ = mem::replace(context.request_mut().body_mut(), it); + ); + } } } @@ -95,25 +88,19 @@ impl Intercept for StalledStreamProtectionInterceptor { runtime_components: &RuntimeComponents, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - if let Some(sspcfg) = cfg.load::() { - if sspcfg.download_enabled() { - let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?; - tracing::trace!("adding stalled stream protection to response body"); - let sspcfg = sspcfg.clone(); - let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken()); - let it = it.map_preserve_contents(move |body| { - let sspcfg = sspcfg.clone(); - let async_sleep = async_sleep.clone(); - let time_source = time_source.clone(); - let mtb = MinimumThroughputDownloadBody::new( - time_source, + if self.enable_for_response_body { + if let Some(cfg) = cfg.load::() { + if cfg.is_enabled() { + let (async_sleep, time_source) = + get_runtime_component_deps(runtime_components)?; + tracing::trace!("adding stalled stream protection to response body"); + add_stalled_stream_protection_to_body( + context.response_mut().body_mut(), + cfg, async_sleep, - body, - sspcfg.into(), + time_source, ); - SdkBody::from_body_0_4(mtb) - }); - let _ = mem::replace(context.response_mut().body_mut(), it); + } } } Ok(()) @@ -131,3 +118,21 @@ fn get_runtime_component_deps( .ok_or("A time source is required when stalled stream protection is enabled")?; Ok((async_sleep, time_source)) } + +fn add_stalled_stream_protection_to_body( + body: &mut SdkBody, + cfg: &StalledStreamProtectionConfig, + async_sleep: SharedAsyncSleep, + time_source: SharedTimeSource, +) { + let cfg = cfg.clone(); + let it = mem::replace(body, SdkBody::taken()); + let it = it.map_preserve_contents(move |body| { + let cfg = cfg.clone(); + let async_sleep = async_sleep.clone(); + let time_source = time_source.clone(); + let mtb = MinimumThroughputBody::new(time_source, async_sleep, body, cfg.into()); + SdkBody::from_body_0_4(mtb) + }); + let _ = mem::replace(body, it); +} diff --git a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs index d5447c98ab..92b450c115 100644 --- a/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs +++ b/rust-runtime/aws-smithy-runtime/src/test_util/capture_test_logs.rs @@ -14,29 +14,6 @@ use tracing_subscriber::fmt::TestWriter; #[derive(Debug)] pub struct LogCaptureGuard(#[allow(dead_code)] DefaultGuard); -/// Enables output of test logs to stdout at trace level by default. -/// -/// The env filter can be changed with the `RUST_LOG` environment variable. -#[must_use] -pub fn show_test_logs() -> LogCaptureGuard { - let (mut writer, _rx) = Tee::stdout(); - writer.loud(); - - let env_var = env::var("RUST_LOG").ok(); - let env_filter = env_var.as_deref().unwrap_or("trace"); - eprintln!( - "Enabled verbose test logging with env filter {env_filter:?}. \ - You can change the env filter with the RUST_LOG environment variable." - ); - - let subscriber = tracing_subscriber::fmt() - .with_env_filter(env_filter) - .with_writer(Mutex::new(writer)) - .finish(); - let guard = tracing::subscriber::set_default(subscriber); - LogCaptureGuard(guard) -} - /// Capture logs from this test. /// /// The logs will be captured until the `DefaultGuard` is dropped. diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs deleted file mode 100644 index 3596fa2e38..0000000000 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_common.rs +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -#![cfg(all(feature = "client", feature = "test-util"))] - -pub use aws_smithy_async::{ - test_util::tick_advance_sleep::{ - tick_advance_time_and_sleep, TickAdvanceSleep, TickAdvanceTime, - }, - time::TimeSource, -}; -pub use aws_smithy_runtime::{ - assert_str_contains, - client::{ - orchestrator::operation::Operation, - stalled_stream_protection::StalledStreamProtectionInterceptor, - }, - test_util::capture_test_logs::show_test_logs, -}; -pub use aws_smithy_runtime_api::{ - box_error::BoxError, - client::{ - http::{ - HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, - SharedHttpConnector, - }, - interceptors::context::{Error, Output}, - orchestrator::{HttpRequest, HttpResponse, OrchestratorError}, - result::SdkError, - runtime_components::RuntimeComponents, - ser_de::DeserializeResponse, - stalled_stream_protection::StalledStreamProtectionConfig, - }, - http::{Response, StatusCode}, - shared::IntoShared, -}; -pub use aws_smithy_types::{ - body::SdkBody, error::display::DisplayErrorContext, timeout::TimeoutConfig, -}; -pub use bytes::Bytes; -pub use http_body_0_4::Body; -pub use pin_utils::pin_mut; -pub use std::{ - collections::VecDeque, - convert::Infallible, - future::poll_fn, - mem, - pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll}, - time::Duration, -}; -pub use tracing::{info, Instrument as _}; - -/// No really, it's 42 bytes long... super neat -pub const NEAT_DATA: Bytes = Bytes::from_static(b"some really neat data"); - -/// Ticks time forward by the given duration, and logs the current time for debugging. -#[macro_export] -macro_rules! tick { - ($ticker:ident, $duration:expr) => { - $ticker.tick($duration).await; - let now = $ticker - .now() - .duration_since(std::time::SystemTime::UNIX_EPOCH) - .unwrap(); - tracing::info!("ticked {:?}, now at {:?}", $duration, now); - }; -} - -#[derive(Debug)] -pub struct FakeServer(pub SharedHttpConnector); -impl HttpClient for FakeServer { - fn http_connector( - &self, - _settings: &HttpConnectorSettings, - _components: &RuntimeComponents, - ) -> SharedHttpConnector { - self.0.clone() - } -} - -struct ChannelBody { - receiver: tokio::sync::mpsc::Receiver, -} -impl http_body_0_4::Body for ChannelBody { - type Data = Bytes; - type Error = Infallible; - - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.receiver.poll_recv(cx) { - Poll::Ready(value) => Poll::Ready(value.map(|v| Ok(v))), - Poll::Pending => Poll::Pending, - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - unreachable!() - } -} - -pub fn channel_body() -> (SdkBody, tokio::sync::mpsc::Sender) { - let (sender, receiver) = tokio::sync::mpsc::channel(1000); - (SdkBody::from_body_0_4(ChannelBody { receiver }), sender) -} diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs deleted file mode 100644 index 54e953322c..0000000000 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_download.rs +++ /dev/null @@ -1,297 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -#![cfg(all(feature = "client", feature = "test-util"))] - -use std::time::Duration; - -#[macro_use] -mod stalled_stream_common; -use stalled_stream_common::*; - -/// Scenario: Successfully download at a rate above the minimum throughput. -/// Expected: MUST NOT timeout. -#[tokio::test] -async fn download_success() { - let _logs = show_test_logs(); - - let (time, sleep) = tick_advance_time_and_sleep(); - let (server, response_sender) = channel_server(); - let op = operation(server, time.clone(), sleep); - - let server = tokio::spawn(async move { - for _ in 1..100 { - response_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - } - drop(response_sender); - tick!(time, Duration::from_secs(1)); - }); - - let response_body = op.invoke(()).await.expect("initial success"); - let result = eagerly_consume(response_body).await; - server.await.unwrap(); - - result.ok().expect("response MUST NOT timeout"); -} - -/// Scenario: Download takes a some time to start, but then goes normally. -/// Expected: MUT NOT timeout. -#[tokio::test] -async fn download_slow_start() { - let _logs = show_test_logs(); - - let (time, sleep) = tick_advance_time_and_sleep(); - let (server, response_sender) = channel_server(); - let op = operation(server, time.clone(), sleep); - - let server = tokio::spawn(async move { - // Delay almost to the end of the grace period before sending anything - tick!(time, Duration::from_secs(4)); - for _ in 1..100 { - response_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - } - drop(response_sender); - tick!(time, Duration::from_secs(1)); - }); - - let response_body = op.invoke(()).await.expect("initial success"); - let result = eagerly_consume(response_body).await; - server.await.unwrap(); - - result.ok().expect("response MUST NOT timeout"); -} - -/// Scenario: Download starts fine, and then slowly falls below minimum throughput. -/// Expected: MUST timeout. -#[tokio::test] -async fn download_too_slow() { - let _logs = show_test_logs(); - - let (time, sleep) = tick_advance_time_and_sleep(); - let (server, response_sender) = channel_server(); - let op = operation(server, time.clone(), sleep); - - let server = tokio::spawn(async move { - // Get slower with every poll - for delay in 1..100 { - let _ = response_sender.send(NEAT_DATA).await; - tick!(time, Duration::from_secs(delay)); - } - drop(response_sender); - tick!(time, Duration::from_secs(1)); - }); - - let response_body = op.invoke(()).await.expect("initial success"); - let result = eagerly_consume(response_body).await; - server.await.unwrap(); - - let err = result.expect_err("should have timed out"); - assert_str_contains!( - DisplayErrorContext(err.as_ref()).to_string(), - "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" - ); -} - -/// Scenario: Download starts fine, and then the server stalls and stops sending data. -/// Expected: MUST timeout. -#[tokio::test] -async fn download_stalls() { - let _logs = show_test_logs(); - - let (time, sleep) = tick_advance_time_and_sleep(); - let (server, response_sender) = channel_server(); - let op = operation(server, time.clone(), sleep); - - let server = tokio::spawn(async move { - for _ in 1..10 { - response_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - } - tick!(time, Duration::from_secs(10)); - }); - - let response_body = op.invoke(()).await.expect("initial success"); - let result = tokio::spawn(eagerly_consume(response_body)); - server.await.unwrap(); - - let err = result - .await - .expect("no panics") - .expect_err("should have timed out"); - assert_str_contains!( - DisplayErrorContext(err.as_ref()).to_string(), - "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" - ); -} - -/// Scenario: Download starts fine, but then the server stalls for a time within the -/// grace period. Following that, it starts sending data again. -/// Expected: MUST NOT timeout. -#[tokio::test] -async fn download_stall_recovery_in_grace_period() { - let _logs = show_test_logs(); - - let (time, sleep) = tick_advance_time_and_sleep(); - let (server, response_sender) = channel_server(); - let op = operation(server, time.clone(), sleep); - - let server = tokio::spawn(async move { - for _ in 1..10 { - response_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - } - // Delay almost to the end of the grace period - tick!(time, Duration::from_secs(4)); - // And now recover - for _ in 1..10 { - response_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - } - drop(response_sender); - tick!(time, Duration::from_secs(1)); - }); - - let response_body = op.invoke(()).await.expect("initial success"); - let result = eagerly_consume(response_body).await; - server.await.unwrap(); - - result.ok().expect("response MUST NOT timeout"); -} - -/// Scenario: The server sends data fast enough, but the customer doesn't consume the -/// data fast enough. -/// Expected: MUST NOT timeout. -#[tokio::test] -async fn user_downloads_data_too_slowly() { - let _logs = show_test_logs(); - - let (time, sleep) = tick_advance_time_and_sleep(); - let (server, response_sender) = channel_server(); - let op = operation(server, time.clone(), sleep); - - let server = tokio::spawn(async move { - for _ in 1..100 { - response_sender.send(NEAT_DATA).await.unwrap(); - } - drop(response_sender); - }); - - let response_body = op.invoke(()).await.expect("initial success"); - let result = slowly_consume(time, response_body).await; - server.await.unwrap(); - - result.ok().expect("response MUST NOT timeout"); -} - -use download_test_tools::*; -mod download_test_tools { - use crate::stalled_stream_common::*; - - fn response(body: SdkBody) -> HttpResponse { - HttpResponse::try_from(http::Response::builder().status(200).body(body).unwrap()).unwrap() - } - - pub fn operation( - http_connector: impl HttpConnector + 'static, - time: TickAdvanceTime, - sleep: TickAdvanceSleep, - ) -> Operation<(), SdkBody, Infallible> { - #[derive(Debug)] - struct Deserializer; - impl DeserializeResponse for Deserializer { - fn deserialize_streaming( - &self, - response: &mut HttpResponse, - ) -> Option>> { - let mut body = SdkBody::taken(); - mem::swap(response.body_mut(), &mut body); - Some(Ok(Output::erase(body))) - } - - fn deserialize_nonstreaming( - &self, - _: &HttpResponse, - ) -> Result> { - unreachable!() - } - } - - let operation = Operation::builder() - .service_name("test") - .operation_name("test") - .http_client(FakeServer(http_connector.into_shared())) - .endpoint_url("http://localhost:1234/doesntmatter") - .no_auth() - .no_retry() - .timeout_config(TimeoutConfig::disabled()) - .serializer(|_body: ()| Ok(HttpRequest::new(SdkBody::empty()))) - .deserializer_impl(Deserializer) - .stalled_stream_protection( - StalledStreamProtectionConfig::enabled() - .grace_period(Duration::from_secs(5)) - .build(), - ) - .interceptor(StalledStreamProtectionInterceptor::default()) - .sleep_impl(sleep) - .time_source(time) - .build(); - operation - } - - /// Fake server/connector that responds with a channel body. - pub fn channel_server() -> (SharedHttpConnector, tokio::sync::mpsc::Sender) { - #[derive(Debug)] - struct FakeServerConnector { - body: Arc>>, - } - impl HttpConnector for FakeServerConnector { - fn call(&self, _request: HttpRequest) -> HttpConnectorFuture { - let body = self.body.lock().unwrap().take().unwrap(); - HttpConnectorFuture::new(async move { Ok(response(body)) }) - } - } - - let (body, body_sender) = channel_body(); - ( - FakeServerConnector { - body: Arc::new(Mutex::new(Some(body))), - } - .into_shared(), - body_sender, - ) - } - - /// Simulate a client eagerly consuming all the data sent to it from the server. - pub async fn eagerly_consume(body: SdkBody) -> Result<(), BoxError> { - pin_mut!(body); - while let Some(result) = poll_fn(|cx| body.as_mut().poll_data(cx)).await { - if let Err(err) = result { - return Err(err); - } else { - tracing::info!("consumed bytes from the response body"); - } - } - Ok(()) - } - - /// Simulate a client very slowly consuming data with an eager server. - /// - /// This implementation will take longer than the grace period to consume - /// the next piece of data. - pub async fn slowly_consume(time: TickAdvanceTime, body: SdkBody) -> Result<(), BoxError> { - pin_mut!(body); - while let Some(result) = poll_fn(|cx| body.as_mut().poll_data(cx)).await { - if let Err(err) = result { - return Err(err); - } else { - tracing::info!("consumed bytes from the response body"); - tick!(time, Duration::from_secs(10)); - } - } - Ok(()) - } -} diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs index f1ed0f779a..70211cfe52 100644 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs +++ b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_performance.rs @@ -7,7 +7,7 @@ use aws_smithy_async::rt::sleep::TokioSleep; use aws_smithy_async::time::{SystemTimeSource, TimeSource}; -use aws_smithy_runtime::client::http::body::minimum_throughput::MinimumThroughputDownloadBody; +use aws_smithy_runtime::client::http::body::minimum_throughput::MinimumThroughputBody; use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; use aws_smithy_types::body::SdkBody; use aws_smithy_types::byte_stream::ByteStream; @@ -92,7 +92,7 @@ async fn make_request(address: &str, wrap_body: bool) -> Duration { let time_source = SystemTimeSource::new(); let sleep = TokioSleep::new(); let opts = StalledStreamProtectionConfig::enabled().build(); - let mtb = MinimumThroughputDownloadBody::new(time_source, sleep, body, opts.into()); + let mtb = MinimumThroughputBody::new(time_source, sleep, body, opts.into()); SdkBody::from_body_0_4(mtb) }); } diff --git a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs b/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs deleted file mode 100644 index f64fa321b2..0000000000 --- a/rust-runtime/aws-smithy-runtime/tests/stalled_stream_upload.rs +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -#![cfg(all(feature = "client", feature = "test-util"))] - -#[macro_use] -mod stalled_stream_common; -use stalled_stream_common::*; - -/// Scenario: Successful upload at a rate above the minimum throughput. -/// Expected: MUST NOT timeout. -#[tokio::test] -async fn upload_success() { - let _logs = show_test_logs(); - - let (server, time, sleep) = eager_server(true); - let op = operation(server, time, sleep); - - let (body, body_sender) = channel_body(); - let result = tokio::spawn(async move { op.invoke(body).await }); - - for _ in 0..100 { - body_sender.send(NEAT_DATA).await.unwrap(); - } - drop(body_sender); - - assert_eq!(200, result.await.unwrap().expect("success").as_u16()); -} - -/// Scenario: Upload takes some time to start, but then goes normally. -/// Expected: MUST NOT timeout. -#[tokio::test] -async fn upload_slow_start() { - let _logs = show_test_logs(); - - let (server, time, sleep) = eager_server(false); - let op = operation(server, time.clone(), sleep); - - let (body, body_sender) = channel_body(); - let result = tokio::spawn(async move { op.invoke(body).await }); - - let _streamer = tokio::spawn(async move { - // Advance longer than the grace period. This shouldn't fail since - // it is the customer's side that hasn't produced data yet, not a server issue. - time.tick(Duration::from_secs(10)).await; - - for _ in 0..100 { - body_sender.send(NEAT_DATA).await.unwrap(); - time.tick(Duration::from_secs(1)).await; - } - drop(body_sender); - time.tick(Duration::from_secs(1)).await; - }); - - assert_eq!(200, result.await.unwrap().expect("success").as_u16()); -} - -/// Scenario: The upload is going fine, but falls below the minimum throughput. -/// Expected: MUST timeout. -#[tokio::test] -async fn upload_too_slow() { - let _logs = show_test_logs(); - - // Server that starts off fast enough, but gets slower over time until it should timeout. - let (server, time, sleep) = time_sequence_server([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); - let op = operation(server, time, sleep); - - let (body, body_sender) = channel_body(); - let result = tokio::spawn(async move { op.invoke(body).await }); - - let _streamer = tokio::spawn(async move { - for send in 0..100 { - info!("send {send}"); - body_sender.send(NEAT_DATA).await.unwrap(); - } - drop(body_sender); - }); - - expect_timeout(result.await.expect("no panics")); -} - -/// Scenario: The server stops asking for data, the client maxes out its send buffer, -/// and the request stream stops being polled. -/// Expected: MUST timeout after the grace period completes. -#[tokio::test] -async fn upload_stalls() { - let _logs = show_test_logs(); - - let (server, time, sleep) = stalling_server(); - let op = operation(server, time.clone(), sleep); - - let (body, body_sender) = channel_body(); - let result = tokio::spawn(async move { op.invoke(body).await }); - - let _streamer = tokio::spawn(async move { - for send in 1..=100 { - info!("send {send}"); - body_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - } - drop(body_sender); - time.tick(Duration::from_secs(1)).await; - }); - - expect_timeout(result.await.expect("no panics")); -} - -/// Scenario: All the request data is either uploaded to the server or buffered in the -/// HTTP client, but the response doesn't start coming through within the grace period. -/// Expected: MUST timeout after the grace period completes. -#[tokio::test] -async fn complete_upload_no_response() { - let _logs = show_test_logs(); - - let (server, time, sleep) = stalling_server(); - let op = operation(server, time.clone(), sleep); - - let (body, body_sender) = channel_body(); - let result = tokio::spawn(async move { op.invoke(body).await }); - - let _streamer = tokio::spawn(async move { - body_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - drop(body_sender); - time.tick(Duration::from_secs(6)).await; - }); - - expect_timeout(result.await.expect("no panics")); -} - -// Scenario: The server stops asking for data, the client maxes out its send buffer, -// and the request stream stops being polled. However, before the grace period -// is over, the server recovers and starts asking for data again. -// Expected: MUST NOT timeout. -#[tokio::test] -async fn upload_stall_recovery_in_grace_period() { - let _logs = show_test_logs(); - - // Server starts off fast enough, but then slows down almost up to - // the grace period, and then recovers. - let (server, time, sleep) = time_sequence_server([1, 4, 1]); - let op = operation(server, time, sleep); - - let (body, body_sender) = channel_body(); - let result = tokio::spawn(async move { op.invoke(body).await }); - - let _streamer = tokio::spawn(async move { - for send in 0..100 { - info!("send {send}"); - body_sender.send(NEAT_DATA).await.unwrap(); - } - drop(body_sender); - }); - - assert_eq!(200, result.await.unwrap().expect("success").as_u16()); -} - -// Scenario: The customer isn't providing data on the stream fast enough to satisfy -// the minimum throughput. This shouldn't be considered a stall since the -// server is asking for more data and could handle it if it were available. -// Expected: MUST NOT timeout. -#[tokio::test] -async fn user_provides_data_too_slowly() { - let _logs = show_test_logs(); - - let (server, time, sleep) = eager_server(false); - let op = operation(server, time.clone(), sleep.clone()); - - let (body, body_sender) = channel_body(); - let result = tokio::spawn(async move { op.invoke(body).await }); - - let _streamer = tokio::spawn(async move { - body_sender.send(NEAT_DATA).await.unwrap(); - tick!(time, Duration::from_secs(1)); - body_sender.send(NEAT_DATA).await.unwrap(); - - // Now advance 10 seconds before sending more data, simulating a - // customer taking time to produce more data to stream. - tick!(time, Duration::from_secs(10)); - body_sender.send(NEAT_DATA).await.unwrap(); - drop(body_sender); - tick!(time, Duration::from_secs(1)); - }); - - assert_eq!(200, result.await.unwrap().expect("success").as_u16()); -} - -use upload_test_tools::*; -mod upload_test_tools { - use crate::stalled_stream_common::*; - - pub fn successful_response() -> HttpResponse { - HttpResponse::try_from( - http::Response::builder() - .status(200) - .body(SdkBody::empty()) - .unwrap(), - ) - .unwrap() - } - - pub fn operation( - http_connector: impl HttpConnector + 'static, - time: TickAdvanceTime, - sleep: TickAdvanceSleep, - ) -> Operation { - let operation = Operation::builder() - .service_name("test") - .operation_name("test") - .http_client(FakeServer(http_connector.into_shared())) - .endpoint_url("http://localhost:1234/doesntmatter") - .no_auth() - .no_retry() - .timeout_config(TimeoutConfig::disabled()) - .serializer(|body: SdkBody| Ok(HttpRequest::new(body))) - .deserializer::<_, Infallible>(|response| Ok(response.status())) - .stalled_stream_protection( - StalledStreamProtectionConfig::enabled() - .grace_period(Duration::from_secs(5)) - .build(), - ) - .interceptor(StalledStreamProtectionInterceptor::default()) - .sleep_impl(sleep) - .time_source(time) - .build(); - operation - } - - /// Creates a fake HttpConnector implementation that calls the given async $body_fn - /// to get the response body. This $body_fn is given a request body, time, and sleep. - macro_rules! fake_server { - ($name:ident, $body_fn:expr) => { - fake_server!($name, $body_fn, (), ()) - }; - ($name:ident, $body_fn:expr, $params_ty:ty, $params:expr) => {{ - #[derive(Debug)] - struct $name(TickAdvanceTime, TickAdvanceSleep, $params_ty); - impl HttpConnector for $name { - fn call(&self, mut request: HttpRequest) -> HttpConnectorFuture { - let time = self.0.clone(); - let sleep = self.1.clone(); - let params = self.2.clone(); - let span = tracing::span!(tracing::Level::INFO, "FAKE SERVER"); - HttpConnectorFuture::new( - async move { - let mut body = SdkBody::taken(); - mem::swap(request.body_mut(), &mut body); - pin_mut!(body); - - Ok($body_fn(body, time, sleep, params).await) - } - .instrument(span), - ) - } - } - let (time, sleep) = tick_advance_time_and_sleep(); - ( - $name(time.clone(), sleep.clone(), $params).into_shared(), - time, - sleep, - ) - }}; - } - - /// Fake server/connector that immediately reads all incoming data with an - /// optional 1 second gap in between polls. - pub fn eager_server( - advance_time: bool, - ) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { - async fn fake_server( - mut body: Pin<&mut SdkBody>, - time: TickAdvanceTime, - _: TickAdvanceSleep, - advance_time: bool, - ) -> HttpResponse { - while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { - if advance_time { - tick!(time, Duration::from_secs(1)); - } - } - successful_response() - } - fake_server!(FakeServerConnector, fake_server, bool, advance_time) - } - - /// Fake server/connector that reads some data, and then stalls. - pub fn stalling_server() -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { - async fn fake_server( - mut body: Pin<&mut SdkBody>, - _time: TickAdvanceTime, - _sleep: TickAdvanceSleep, - _: (), - ) -> HttpResponse { - let mut times = 5; - while times > 0 && poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { - times -= 1; - } - // never awake after this - tracing::info!("stalling indefinitely"); - std::future::pending::<()>().await; - unreachable!() - } - fake_server!(FakeServerConnector, fake_server) - } - - /// Fake server/connector that polls data after each period of time in the given - /// sequence. Once the sequence completes, it will delay 1 second after each poll. - pub fn time_sequence_server( - time_sequence: impl IntoIterator, - ) -> (SharedHttpConnector, TickAdvanceTime, TickAdvanceSleep) { - async fn fake_server( - mut body: Pin<&mut SdkBody>, - time: TickAdvanceTime, - _sleep: TickAdvanceSleep, - time_sequence: Vec, - ) -> HttpResponse { - let mut time_sequence: VecDeque = - time_sequence.into_iter().map(Duration::from_secs).collect(); - while poll_fn(|cx| body.as_mut().poll_data(cx)).await.is_some() { - let next_time = time_sequence.pop_front().unwrap_or(Duration::from_secs(1)); - tick!(time, next_time); - } - successful_response() - } - fake_server!( - FakeServerConnector, - fake_server, - Vec, - time_sequence.into_iter().collect() - ) - } - - pub fn expect_timeout(result: Result>>) { - let err = result.expect_err("should have timed out"); - assert_str_contains!( - DisplayErrorContext(&err).to_string(), - "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" - ); - } -}