diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b4db299..c6ee239 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,13 +20,31 @@ jobs: permissions: contents: none name: CI - needs: lint + needs: [test, lint] runs-on: ubuntu-latest if: always() steps: - name: Failed run: exit 1 if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') + test: + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v4 + with: + submodules: true + token: ${{ secrets.GH_TOKEN }} + - name: install rust + uses: dtolnay/rust-toolchain@stable + with: + toolchain: stable + components: rustfmt, clippy + - name: install protoc + uses: arduino/setup-protoc@v3 + - uses: Swatinem/rust-cache@v2 + - name: Run cargo tests + run: cargo test lint: runs-on: ubuntu-latest steps: diff --git a/Cargo.toml b/Cargo.toml index 0cdcf37..76e9b21 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ edition = "2021" license = "Apache-2.0" [dependencies] +async-stream = "0.3.6" backon = "1.2.0" bytesize = "1.3.0" futures = "0.3.31" @@ -22,6 +23,7 @@ secrecy = "0.8.0" serde = { version = "1.0.214", optional = true, features = ["derive"] } sync_docs = { path = "sync_docs" } thiserror = "1.0.67" +tokio = { version = "1.41.1", features = ["time"] } tonic = { version = "0.12.3", features = ["tls", "tls-webpki-roots"] } tower-service = "0.3.3" @@ -29,7 +31,9 @@ tower-service = "0.3.3" tonic-build = { version = "0.12.3", features = ["prost"] } [dev-dependencies] -tokio = { version = "*", features = ["full"] } +rstest = "0.23.0" +tokio = { version = "1.41.1", features = ["full", "test-util"] } +tokio-stream = "0.1.16" [features] serde = ["dep:serde"] diff --git a/examples/basic.rs b/examples/basic.rs index cf9d416..771f969 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -2,8 +2,8 @@ use std::time::Duration; use futures::StreamExt; use streamstore::{ + batching::AppendRecordsBatchingStream, client::{Client, ClientConfig, ClientError, HostEndpoints}, - streams::AppendRecordStream, types::{ AppendInput, AppendRecord, CreateBasinRequest, CreateStreamRequest, DeleteBasinRequest, DeleteStreamRequest, ListBasinsRequest, ListStreamsRequest, ReadSessionRequest, @@ -123,7 +123,7 @@ async fn main() { }; let append_session_req = - AppendRecordStream::new(futures::stream::iter(records), Default::default()).unwrap(); + AppendRecordsBatchingStream::new(futures::stream::iter(records), Default::default()); match stream_client.append_session(append_session_req).await { Ok(mut stream) => { diff --git a/src/client.rs b/src/client.rs index 43f6bc6..dddd26b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -579,7 +579,7 @@ impl StreamClient { pub async fn read_session( &self, req: types::ReadSessionRequest, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.inner .send_retryable(ReadSessionServiceRequest::new( self.inner.stream_service_client(), @@ -587,7 +587,7 @@ impl StreamClient { req, )) .await - .map(Streaming::new) + .map(|s| Box::pin(s) as _) } #[sync_docs] @@ -610,7 +610,7 @@ impl StreamClient { req: S, ) -> Result, ClientError> where - S: 'static + Send + futures::Stream + Unpin, + S: 'static + Send + Unpin + futures::Stream, { self.inner .send(AppendSessionServiceRequest::new( @@ -619,7 +619,7 @@ impl StreamClient { req, )) .await - .map(Streaming::new) + .map(|s| Box::pin(s) as _) } } diff --git a/src/lib.rs b/src/lib.rs index c5bc06f..b5670ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,11 +2,10 @@ mod api; mod service; pub mod client; -pub mod streams; pub mod types; pub use bytesize; pub use futures; pub use http::uri; pub use secrecy::SecretString; -pub use service::Streaming; +pub use service::stream::batching; diff --git a/src/service.rs b/src/service.rs index 43bb238..689f345 100644 --- a/src/service.rs +++ b/src/service.rs @@ -193,21 +193,5 @@ impl futures::Stream for ServiceStreamingResponse { } } -pub struct Streaming(Box>>); - -impl Streaming { - pub(crate) fn new(s: ServiceStreamingResponse) -> Self - where - S: StreamingResponse + Send + 'static, - { - Self(Box::new(s)) - } -} - -impl futures::Stream for Streaming { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.0.poll_next_unpin(cx) - } -} +/// Generic type for streaming response. +pub type Streaming = Pin>>>; diff --git a/src/service/stream.rs b/src/service/stream.rs index 265c72c..15b9faa 100644 --- a/src/service/stream.rs +++ b/src/service/stream.rs @@ -1,3 +1,5 @@ +pub mod batching; + use tonic::{transport::Channel, IntoRequest}; use super::{ @@ -161,7 +163,7 @@ impl IdempotentRequest for ReadSessionServiceRequest { pub struct ReadSessionStreamingResponse; impl StreamingResponse for ReadSessionStreamingResponse { - type ResponseItem = types::ReadSessionResponse; + type ResponseItem = types::ReadOutput; type ApiResponseItem = api::ReadSessionResponse; fn parse_response_item( diff --git a/src/service/stream/batching.rs b/src/service/stream/batching.rs new file mode 100644 index 0000000..7c1d661 --- /dev/null +++ b/src/service/stream/batching.rs @@ -0,0 +1,406 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use bytesize::ByteSize; +use futures::{Stream, StreamExt}; + +use crate::types::{self, MeteredSize as _}; + +/// Options to configure append records batching scheme. +#[derive(Debug, Clone)] +pub struct AppendRecordsBatchingOpts { + max_batch_records: usize, + max_batch_size: ByteSize, + match_seq_num: Option, + fencing_token: Option>, + linger_duration: Duration, +} + +impl Default for AppendRecordsBatchingOpts { + fn default() -> Self { + Self { + max_batch_records: 1000, + max_batch_size: ByteSize::mib(1), + match_seq_num: None, + fencing_token: None, + linger_duration: Duration::from_millis(5), + } + } +} + +impl AppendRecordsBatchingOpts { + /// Construct an options struct with defaults. + pub fn new() -> Self { + Self::default() + } + + /// Maximum number of records in a batch. + pub fn with_max_batch_records(self, max_batch_records: usize) -> Self { + assert!( + max_batch_records > 0 && max_batch_records <= 1000, + "max_batch_records should be between (0, 1000]" + ); + + Self { + max_batch_records, + ..self + } + } + + /// Maximum size of a batch in bytes. + #[cfg(test)] + pub fn with_max_batch_size(self, max_batch_size: impl Into) -> Self { + let max_batch_size = max_batch_size.into(); + + assert!( + max_batch_size > ByteSize(0) && max_batch_size <= ByteSize::mib(1), + "max_batch_size should be between (0, 1] MiB" + ); + + Self { + max_batch_size, + ..self + } + } + + /// Enforce that the sequence number issued to the first record matches. + /// + /// This is incremented automatically for each batch. + pub fn with_match_seq_num(self, match_seq_num: impl Into) -> Self { + Self { + match_seq_num: Some(match_seq_num.into()), + ..self + } + } + + /// Enforce a fencing token. + pub fn with_fencing_token(self, fencing_token: impl Into>) -> Self { + Self { + fencing_token: Some(fencing_token.into()), + ..self + } + } + + /// Linger duration for records before flushing. + /// + /// A linger duration of 5ms is set by default. Set to `Duration::ZERO` + /// to disable. + pub fn with_linger(self, linger_duration: impl Into) -> Self { + Self { + linger_duration: linger_duration.into(), + ..self + } + } +} + +/// Wrapper stream that takes a stream of append records and batches them +/// together to send as an `AppendOutput`. +pub struct AppendRecordsBatchingStream(Pin + Send>>); + +impl AppendRecordsBatchingStream { + /// Create a new batching stream. + pub fn new(stream: S, opts: AppendRecordsBatchingOpts) -> Self + where + R: 'static + Into, + S: 'static + Send + Stream + Unpin, + { + Self(Box::pin(append_records_batching_stream(stream, opts))) + } +} + +impl Stream for AppendRecordsBatchingStream { + type Item = types::AppendInput; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_next_unpin(cx) + } +} + +fn append_records_batching_stream( + mut stream: S, + opts: AppendRecordsBatchingOpts, +) -> impl Stream + Send +where + R: Into, + S: 'static + Send + Stream + Unpin, +{ + async_stream::stream! { + let mut terminated = false; + let mut batch_builder = BatchBuilder::new(&opts); + + let batch_deadline = tokio::time::sleep(Duration::ZERO); + tokio::pin!(batch_deadline); + + while !terminated { + while !batch_builder.is_full() { + if batch_builder.len() == 1 { + // Start the timer when the first record is added. + batch_deadline + .as_mut() + .reset(tokio::time::Instant::now() + opts.linger_duration); + } + + tokio::select! { + biased; + next = stream.next() => { + if let Some(record) = next { + batch_builder.push(record); + } else { + terminated = true; + break; + } + }, + _ = &mut batch_deadline, if !batch_builder.is_empty() => { + break; + } + }; + } + + if !batch_builder.is_empty() { + yield batch_builder.flush(); + } + + // Now that we have flushed (if required), the batch builder should + // definitely not be full. It might not be empty since the peeked + // record might have been pushed into the batch. + assert!( + !batch_builder.is_full(), + "dangling peeked record does not fit into size limits", + ); + } + } +} + +struct BatchBuilder<'a> { + opts: &'a AppendRecordsBatchingOpts, + peeked_record: Option, + next_match_seq_num: Option, + batch: Vec, + batch_size: ByteSize, +} + +impl<'a> BatchBuilder<'a> { + pub fn new<'b: 'a>(opts: &'b AppendRecordsBatchingOpts) -> Self { + Self { + peeked_record: None, + next_match_seq_num: opts.match_seq_num, + batch: Vec::with_capacity(opts.max_batch_records), + batch_size: ByteSize(0), + opts, + } + } + + pub fn push(&mut self, record: impl Into) { + assert!(!self.is_full()); + let record = record.into(); + let record_size = record.metered_size(); + if self.batch_size + record_size > self.opts.max_batch_size { + let ret = self.peeked_record.replace(record); + assert!(ret.is_none()); + } else { + self.batch_size += record_size; + self.batch.push(record); + } + } + + pub fn is_empty(&self) -> bool { + if self.batch.is_empty() { + assert_eq!(self.batch_size, ByteSize(0)); + true + } else { + false + } + } + + pub fn len(&self) -> usize { + self.batch.len() + } + + pub fn is_full(&self) -> bool { + assert!(self.batch.len() <= self.opts.max_batch_records); + self.batch.len() == self.opts.max_batch_records || self.peeked_record.is_some() + } + + pub fn flush(&mut self) -> types::AppendInput { + assert!(!self.is_empty()); + + let match_seq_num = self.next_match_seq_num; + if let Some(next_match_seq_num) = self.next_match_seq_num.as_mut() { + *next_match_seq_num += self.batch.len() as u64; + } + + // Reset the inner batch, batch_size and push back the peeked record + // into the batch. + let records = { + self.batch_size = ByteSize(0); + std::mem::replace( + &mut self.batch, + Vec::with_capacity(self.opts.max_batch_records), + ) + }; + if let Some(record) = self.peeked_record.take() { + self.push(record); + } + + types::AppendInput { + records, + match_seq_num, + fencing_token: self.opts.fencing_token.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use bytesize::ByteSize; + use futures::StreamExt as _; + use rstest::rstest; + use tokio::sync::mpsc; + use tokio_stream::wrappers::UnboundedReceiverStream; + + use super::{AppendRecordsBatchingOpts, AppendRecordsBatchingStream}; + use crate::types; + + #[rstest] + #[case(Some(2), None)] + #[case(None, Some(ByteSize::b(30)))] + #[case(Some(2), Some(ByteSize::b(100)))] + #[case(Some(10), Some(ByteSize::b(30)))] + #[tokio::test] + async fn test_append_record_stream_batching( + #[case] max_batch_records: Option, + #[case] max_batch_size: Option, + ) { + let stream_iter = (0..100).map(|i| types::AppendRecord::new(format!("r_{i}"))); + let stream = futures::stream::iter(stream_iter); + + let mut opts = AppendRecordsBatchingOpts::new().with_linger(Duration::ZERO); + if let Some(max_batch_records) = max_batch_records { + opts = opts.with_max_batch_records(max_batch_records); + } + if let Some(max_batch_size) = max_batch_size { + opts = opts.with_max_batch_size(max_batch_size); + } + + let batch_stream = AppendRecordsBatchingStream::new(stream, opts); + + let batches = batch_stream + .map(|batch| batch.records) + .collect::>() + .await; + + let mut i = 0; + for batch in batches { + assert_eq!(batch.len(), 2); + for record in batch { + assert_eq!(record.body, format!("r_{i}").into_bytes()); + i += 1; + } + } + } + + #[tokio::test(start_paused = true)] + async fn test_append_record_stream_linger() { + let (stream_tx, stream_rx) = mpsc::unbounded_channel::(); + let mut i = 0; + + let collect_batches_handle = tokio::spawn(async move { + let batch_stream = AppendRecordsBatchingStream::new( + UnboundedReceiverStream::new(stream_rx), + AppendRecordsBatchingOpts::new() + .with_linger(Duration::from_secs(2)) + .with_max_batch_records(3) + .with_max_batch_size(ByteSize::b(40)), + ); + + batch_stream + .map(|batch| { + batch + .records + .into_iter() + .map(|rec| rec.body) + .collect::>() + }) + .collect::>() + .await + }); + + let mut send_next = |padding: Option<&str>| { + let mut record = types::AppendRecord::new(format!("r_{i}")); + if let Some(padding) = padding { + // The padding exists just to increase the size of record in + // order to test the size limits. + record = record.with_headers(vec![types::Header::new("padding", padding)]); + } + stream_tx.send(record).unwrap(); + i += 1; + }; + + async fn sleep_secs(secs: u64) { + let dur = Duration::from_secs(secs) + Duration::from_millis(10); + tokio::time::sleep(dur).await; + } + + send_next(None); + send_next(None); + + sleep_secs(2).await; + + send_next(None); + + // Waiting for a short time before sending next record. + sleep_secs(1).await; + + send_next(None); + + sleep_secs(1).await; + + // Checking batch count limits here. The first 3 records should be + // flushed immediately. + send_next(None); + send_next(None); + send_next(None); + send_next(None); + + // Waiting for a long time before sending any records. + sleep_secs(200).await; + + // Checking size limits here. The first record should be flushed + // immediately. + send_next(Some("large string")); + send_next(None); + + std::mem::drop(stream_tx); // Should close the stream + + let batches = collect_batches_handle.await.unwrap(); + + let expected_batches = vec![ + vec![b"r_0".to_owned(), b"r_1".to_owned()], + vec![b"r_2".to_owned(), b"r_3".to_owned()], + vec![b"r_4".to_owned(), b"r_5".to_owned(), b"r_6".to_owned()], + vec![b"r_7".to_owned()], + vec![b"r_8".to_owned()], + vec![b"r_9".to_owned()], + ]; + + assert_eq!(batches, expected_batches); + } + + #[tokio::test] + #[should_panic] + async fn test_append_record_stream_panic_size_limits() { + let stream = + futures::stream::iter([types::AppendRecord::new("too long to fit into size limits")]); + + let mut batch_stream = AppendRecordsBatchingStream::new( + stream, + AppendRecordsBatchingOpts::new().with_max_batch_size(ByteSize::b(1)), + ); + + let _ = batch_stream.next().await; + } +} diff --git a/src/streams.rs b/src/streams.rs deleted file mode 100644 index 2656491..0000000 --- a/src/streams.rs +++ /dev/null @@ -1,192 +0,0 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use bytesize::ByteSize; -use futures::{Stream, StreamExt}; - -use crate::types::{self, MeteredSize as _}; - -/// Options to configure [`AppendRecordStream`]. -#[derive(Debug, Clone)] -pub struct AppendRecordStreamOpts { - /// Maximum number of records in a batch. - pub max_batch_records: usize, - /// Maximum size of a batch in bytes. - pub max_batch_size: ByteSize, - /// Enforce that the sequence number issued to the first record matches. - /// - /// This is incremented automatically for each batch. - pub match_seq_num: Option, - /// Enforce a fencing token. - pub fencing_token: Option>, -} - -impl Default for AppendRecordStreamOpts { - fn default() -> Self { - Self { - max_batch_records: 1000, - max_batch_size: ByteSize::mib(1), - match_seq_num: None, - fencing_token: None, - } - } -} - -impl AppendRecordStreamOpts { - /// Construct an options struct with defaults. - pub fn new() -> Self { - Self::default() - } - - /// Construct from existing options with the new maximum batch records. - pub fn with_max_batch_records(self, max_batch_records: impl Into) -> Self { - Self { - max_batch_records: max_batch_records.into(), - ..self - } - } - - /// Construct from existing options with the new maximum batch size. - pub fn with_max_batch_size(self, max_batch_size: impl Into) -> Self { - Self { - max_batch_size: max_batch_size.into(), - ..self - } - } - - /// Construct from existing options with the initial match sequence number. - pub fn with_match_seq_num(self, match_seq_num: impl Into) -> Self { - Self { - match_seq_num: Some(match_seq_num.into()), - ..self - } - } - - /// Construct from existing options with the fencing token. - pub fn with_fencing_token(self, fencing_token: impl Into>) -> Self { - Self { - fencing_token: Some(fencing_token.into()), - ..self - } - } -} - -#[derive(Debug, thiserror::Error)] -pub enum AppendRecordStreamError { - #[error("max_batch_size should not be more than 1 Mib")] - BatchSizeTooLarge, -} - -/// Wrapper over a stream of append records that can be sent over to -/// [`crate::client::StreamClient::append_session`]. -pub struct AppendRecordStream -where - R: Into, - S: Send + Stream + Unpin, -{ - stream: S, - peeked_record: Option, - terminated: bool, - opts: AppendRecordStreamOpts, -} - -impl AppendRecordStream -where - R: Into, - S: Send + Stream + Unpin, -{ - /// Try constructing a new [`AppendRecordStream`] from the given stream and options. - pub fn new(stream: S, opts: AppendRecordStreamOpts) -> Result { - if opts.max_batch_size > ByteSize::mib(1) { - return Err(AppendRecordStreamError::BatchSizeTooLarge); - } - - Ok(Self { - stream, - peeked_record: None, - terminated: false, - opts, - }) - } - - fn push_record_to_batch( - &mut self, - record: types::AppendRecord, - batch: &mut Vec, - batch_size: &mut ByteSize, - ) { - let record_size = record.metered_size(); - if *batch_size + record_size > self.opts.max_batch_size { - // Set the peeked record and move on. - self.peeked_record = Some(record); - } else { - *batch_size += record_size; - batch.push(record); - } - } -} - -impl Stream for AppendRecordStream -where - R: Into, - S: Send + Stream + Unpin, -{ - type Item = types::AppendInput; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.terminated { - return Poll::Ready(None); - } - - let mut batch = Vec::with_capacity(self.opts.max_batch_records); - let mut batch_size = ByteSize::b(0); - - if let Some(peeked) = self.peeked_record.take() { - self.push_record_to_batch(peeked, &mut batch, &mut batch_size); - } - - while batch.len() < self.opts.max_batch_records && self.peeked_record.is_none() { - match self.stream.poll_next_unpin(cx) { - Poll::Pending => break, - Poll::Ready(None) => { - self.terminated = true; - break; - } - Poll::Ready(Some(record)) => { - self.push_record_to_batch(record.into(), &mut batch, &mut batch_size); - } - } - } - - if batch.is_empty() { - assert!( - self.peeked_record.is_none(), - "dangling peeked record does not fit into size limits" - ); - - if self.terminated { - Poll::Ready(None) - } else { - Poll::Pending - } - } else { - if self.peeked_record.is_some() { - // Ensure we poll again to return the peeked stream (at least). - cx.waker().wake_by_ref(); - } - - let match_seq_num = self.opts.match_seq_num; - if let Some(m) = self.opts.match_seq_num.as_mut() { - *m += batch.len() as u64 - } - - Poll::Ready(Some(types::AppendInput { - records: batch, - match_seq_num, - fencing_token: self.opts.fencing_token.clone(), - })) - } - } -} diff --git a/src/types.rs b/src/types.rs index 2f4c063..a036e25 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1140,19 +1140,11 @@ impl ReadSessionRequest { } } -#[sync_docs] -#[derive(Debug, Clone)] -pub struct ReadSessionResponse { - pub output: ReadOutput, -} - -impl TryFrom for ReadSessionResponse { +impl TryFrom for ReadOutput { type Error = ConvertError; fn try_from(value: api::ReadSessionResponse) -> Result { let api::ReadSessionResponse { output } = value; let output = output.ok_or("missing output in read session response")?; - Ok(Self { - output: output.try_into()?, - }) + output.try_into() } }