diff --git a/Cargo.lock b/Cargo.lock index b2f798a5e7..328978b6d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,6 +49,20 @@ dependencies = [ "generic-array", ] +[[package]] +name = "aligned-buffer" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3b0ccaa876f7e00c6a4818bf878d5d4b387eff28c79bea024824469ec25509d" +dependencies = [ + "const_panic", + "crossbeam-utils", + "rkyv", + "stable_deref_trait", + "static_assertions", + "thiserror 1.0.69", +] + [[package]] name = "alloc-no-stdlib" version = "2.0.4" @@ -551,6 +565,29 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytecheck" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50c8f430744b23b54ad15161fcbc22d82a29b73eacbe425fea23ec822600bc6f" +dependencies = [ + "bytecheck_derive", + "ptr_meta", + "rancor", + "simdutf8", +] + +[[package]] +name = "bytecheck_derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "523363cbe1df49b68215efdf500b103ac3b0fb4836aed6d15689a076eadb8fff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "bytecount" version = "0.6.8" @@ -1003,6 +1040,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c50fcfdf972929aff202c16b80086aa3cfc6a3a820af714096c58c7c1d0582" +[[package]] +name = "const_panic" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53857514f72ee4a2b583de67401e3ff63a5472ca4acf289d09a9ea7636dfec17" + [[package]] name = "core-foundation" version = "0.9.4" @@ -2739,6 +2782,26 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +[[package]] +name = "munge" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64142d38c84badf60abf06ff9bd80ad2174306a5b11bd4706535090a30a419df" +dependencies = [ + "munge_macro", +] + +[[package]] +name = "munge_macro" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bb5c1d8184f13f7d0ccbeeca0def2f9a181bce2624302793005f5ca8aa62e5e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "nanorand" version = "0.7.0" @@ -3396,6 +3459,26 @@ dependencies = [ "prost", ] +[[package]] +name = "ptr_meta" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe9e76f66d3f9606f44e45598d155cb13ecf09f4a28199e48daf8c8fc937ea90" +dependencies = [ + "ptr_meta_derive", +] + +[[package]] +name = "ptr_meta_derive" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca414edb151b4c8d125c12566ab0d74dc9cdba36fb80eb7b848c15f495fd32d1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "pyo3" version = "0.22.6" @@ -3573,6 +3656,15 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rancor" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf5f7161924b9d1cea0e4cabc97c372cea92b5f927fc13c6bca67157a0ad947" +dependencies = [ + "ptr_meta", +] + [[package]] name = "rand" version = "0.8.5" @@ -3693,6 +3785,15 @@ version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" +[[package]] +name = "rend" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a35e8a6bf28cd121053a66aa2e6a2e3eaffad4a60012179f0e864aa5ffeff215" +dependencies = [ + "bytecheck", +] + [[package]] name = "reqwest" version = "0.12.9" @@ -3759,6 +3860,36 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rkyv" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b11a153aec4a6ab60795f8ebe2923c597b16b05bb1504377451e705ef1a45323" +dependencies = [ + "bytecheck", + "bytes", + "hashbrown 0.15.2", + "indexmap", + "munge", + "ptr_meta", + "rancor", + "rend", + "rkyv_derive", + "tinyvec", + "uuid", +] + +[[package]] +name = "rkyv_derive" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "beb382a4d9f53bd5c0be86b10d8179c3f8a14c30bf774ff77096ed6581e35981" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "rstest" version = "0.23.0" @@ -4058,6 +4189,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "simplelog" version = "0.12.2" @@ -5055,24 +5192,18 @@ dependencies = [ name = "vortex-ipc" version = "0.21.1" dependencies = [ - "arrow-array", - "arrow-ipc", - "arrow-schema", - "arrow-select", + "aligned-buffer", "bytes", - "criterion", "flatbuffers", - "futures-executor", "futures-util", "itertools 0.13.0", + "pin-project-lite", "tokio", "vortex-array", "vortex-buffer", "vortex-dtype", "vortex-error", "vortex-flatbuffers", - "vortex-io", - "vortex-sampling-compressor", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index de2bf26012..ec719bc738 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ categories = ["database-implementations", "data-structures", "compression"] [workspace.dependencies] anyhow = "1.0" +aligned-buffer = "0.2.0" arbitrary = "1.3.2" arrayref = "0.3.7" arrow = { version = "53.0.0" } @@ -53,7 +54,6 @@ arrow-arith = "53.0.0" arrow-array = "53.0.0" arrow-buffer = "53.0.0" arrow-cast = "53.0.0" -arrow-ipc = "53.0.0" arrow-ord = "53.0.0" arrow-schema = "53.0.0" arrow-select = "53.0.0" @@ -103,6 +103,7 @@ once_cell = "1.20.2" parquet = "53.0.0" paste = "1.0.14" pin-project = "1.1.5" +pin-project-lite = "0.2.15" prettytable-rs = "0.10.0" tabled = { version = "0.17.0", default-features = false } prost = "0.13.0" diff --git a/bench-vortex/benches/bytes_at.rs b/bench-vortex/benches/bytes_at.rs index 0a71036657..911bc826ee 100644 --- a/bench-vortex/benches/bytes_at.rs +++ b/bench-vortex/benches/bytes_at.rs @@ -4,16 +4,12 @@ use std::sync::Arc; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use futures::executor::block_on; -use futures::StreamExt; use vortex::array::{PrimitiveArray, VarBinArray, VarBinViewArray}; -use vortex::buffer::Buffer; use vortex::dtype::{DType, Nullability}; -use vortex::io::VortexBufReader; -use vortex::ipc::stream_reader::StreamArrayReader; -use vortex::ipc::stream_writer::StreamArrayWriter; +use vortex::ipc::iterator::{ArrayIteratorIPC, SyncIPCReader}; +use vortex::iter::ArrayIteratorExt; use vortex::validity::Validity; -use vortex::{Context, IntoArrayData, IntoCanonical}; +use vortex::{Context, IntoArrayData, IntoArrayVariant}; fn array_data_fixture() -> VarBinArray { VarBinArray::try_new( @@ -27,27 +23,16 @@ fn array_data_fixture() -> VarBinArray { fn array_view_fixture() -> VarBinViewArray { let array_data = array_data_fixture(); - let mut buffer = Vec::new(); - let writer = StreamArrayWriter::new(&mut buffer); - block_on(writer.write_array(array_data.into_array())).unwrap(); + let buffer = array_data + .into_array() + .into_array_iterator() + .write_ipc(vec![]) + .unwrap(); - let buffer = Buffer::from(buffer); - - let ctx = Arc::new(Context::default()); - let reader = block_on(StreamArrayReader::try_new( - VortexBufReader::new(buffer), - ctx.clone(), - )) - .unwrap(); - let reader = block_on(reader.load_dtype()).unwrap(); - - let mut stream = Box::pin(reader.into_array_stream()); - - block_on(stream.next()) - .unwrap() + SyncIPCReader::try_new(buffer.as_slice(), Arc::new(Context::default())) .unwrap() - .into_canonical() + .into_array_data() .unwrap() .into_varbinview() .unwrap() diff --git a/bench-vortex/src/data_downloads.rs b/bench-vortex/src/data_downloads.rs index 1be52e809e..a040652b1d 100644 --- a/bench-vortex/src/data_downloads.rs +++ b/bench-vortex/src/data_downloads.rs @@ -6,6 +6,7 @@ use std::path::PathBuf; use arrow_array::RecordBatchReader; use bzip2::read::BzDecoder; +use futures::StreamExt; use log::info; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use tokio::runtime::Runtime; @@ -13,8 +14,8 @@ use vortex::array::ChunkedArray; use vortex::arrow::FromArrowType; use vortex::dtype::DType; use vortex::error::{VortexError, VortexResult}; -use vortex::io::TokioAdapter; -use vortex::ipc::stream_writer::StreamArrayWriter; +use vortex::io::{TokioAdapter, VortexWrite}; +use vortex::ipc::stream::ArrayStreamIPC; use vortex::{ArrayData, IntoArrayData}; use crate::idempotent; @@ -56,11 +57,11 @@ pub fn data_vortex_uncompressed(fname_out: &str, downloaded_data: PathBuf) -> Pa Runtime::new() .unwrap() .block_on(async move { - let write = TokioAdapter(tokio::fs::File::create(path).await.unwrap()); - StreamArrayWriter::new(write) - .write_array(array) - .await - .unwrap(); + let mut write = TokioAdapter(tokio::fs::File::create(path).await.unwrap()); + let mut bytes = array.into_array_stream().into_ipc(); + while let Some(buffer) = bytes.next().await { + write.write_all(buffer.unwrap()).await.unwrap(); + } Ok::<(), VortexError>(()) }) .unwrap(); diff --git a/vortex-array/src/iter/ext.rs b/vortex-array/src/iter/ext.rs index 919b1b30af..fe9ae94ff1 100644 --- a/vortex-array/src/iter/ext.rs +++ b/vortex-array/src/iter/ext.rs @@ -4,6 +4,7 @@ use vortex_error::VortexResult; use crate::array::ChunkedArray; use crate::iter::ArrayIterator; use crate::stream::{ArrayStream, ArrayStreamAdapter}; +use crate::{ArrayData, IntoArrayData}; pub trait ArrayIteratorExt: ArrayIterator { fn into_stream(self) -> impl ArrayStream @@ -13,11 +14,21 @@ pub trait ArrayIteratorExt: ArrayIterator { ArrayStreamAdapter::new(self.dtype().clone(), futures_util::stream::iter(self)) } - fn try_into_chunked(self) -> VortexResult + /// Collect the iterator into a single `ArrayData`. + /// + /// If the iterator yields multiple chunks, they will be returned as a [`ChunkedArray`]. + fn into_array_data(self) -> VortexResult where Self: Sized, { let dtype = self.dtype().clone(); - ChunkedArray::try_new(self.try_collect()?, dtype) + let mut chunks: Vec = self.try_collect()?; + if chunks.len() == 1 { + Ok(chunks.remove(0)) + } else { + Ok(ChunkedArray::try_new(chunks, dtype)?.into_array()) + } } } + +impl ArrayIteratorExt for I {} diff --git a/vortex-array/src/stream/ext.rs b/vortex-array/src/stream/ext.rs index 3282af2547..e6ac71dcd6 100644 --- a/vortex-array/src/stream/ext.rs +++ b/vortex-array/src/stream/ext.rs @@ -6,21 +6,28 @@ use vortex_error::VortexResult; use crate::array::ChunkedArray; use crate::stream::take_rows::TakeRows; use crate::stream::{ArrayStream, ArrayStreamAdapter}; -use crate::ArrayData; +use crate::{ArrayData, IntoArrayData}; pub trait ArrayStreamExt: ArrayStream { - fn collect_chunked(self) -> impl Future> + /// Collect the stream into a single `ArrayData`. + /// + /// If the stream yields multiple chunks, they will be returned as a [`ChunkedArray`]. + fn into_array_data(self) -> impl Future> where Self: Sized, { - async { + async move { let dtype = self.dtype().clone(); - self.try_collect() - .await - .and_then(|chunks| ChunkedArray::try_new(chunks, dtype)) + let mut chunks: Vec = self.try_collect().await?; + if chunks.len() == 1 { + Ok(chunks.remove(0)) + } else { + Ok(ChunkedArray::try_new(chunks, dtype)?.into_array()) + } } } + /// Perform a row-wise selection on the stream from an array of sorted indicessss. fn take_rows(self, indices: ArrayData) -> VortexResult where Self: Sized, @@ -32,4 +39,4 @@ pub trait ArrayStreamExt: ArrayStream { } } -impl ArrayStreamExt for R {} +impl ArrayStreamExt for S {} diff --git a/vortex-buffer/src/lib.rs b/vortex-buffer/src/lib.rs index a73a1ee21a..ba42277a24 100644 --- a/vortex-buffer/src/lib.rs +++ b/vortex-buffer/src/lib.rs @@ -63,6 +63,7 @@ impl Buffer { #[allow(clippy::same_name_method)] /// Return a new view on the buffer, but limited to the given index range. + /// TODO(ngates): implement std::ops::Index pub fn slice(&self, range: Range) -> Self { match &self.0 { Inner::Arrow(b) => Buffer(Inner::Arrow( diff --git a/vortex-file/src/byte_range.rs b/vortex-file/src/byte_range.rs new file mode 100644 index 0000000000..473fe14344 --- /dev/null +++ b/vortex-file/src/byte_range.rs @@ -0,0 +1,39 @@ +use std::fmt::{Display, Formatter}; +use std::ops::Range; + +use vortex_error::VortexUnwrap; + +#[derive(Copy, Clone, Debug)] +pub struct ByteRange { + pub begin: u64, + pub end: u64, +} + +impl Display for ByteRange { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "[{}, {})", self.begin, self.end) + } +} + +impl ByteRange { + pub fn new(begin: u64, end: u64) -> Self { + assert!(begin < end, "Buffer begin must be before its end"); + Self { begin, end } + } + + pub fn len(&self) -> u64 { + self.end - self.begin + } + + pub fn is_empty(&self) -> bool { + self.begin == self.end + } + + pub fn as_range(&self) -> Range { + Range { + // TODO(ngates): this cast is unsafe and can panic + start: self.begin.try_into().vortex_unwrap(), + end: self.end.try_into().vortex_unwrap(), + } + } +} diff --git a/vortex-file/src/dtype_reader.rs b/vortex-file/src/dtype_reader.rs deleted file mode 100644 index b7113d2819..0000000000 --- a/vortex-file/src/dtype_reader.rs +++ /dev/null @@ -1,28 +0,0 @@ -use vortex_dtype::DType; -use vortex_error::VortexResult; -use vortex_io::{VortexBufReader, VortexReadAt}; -use vortex_ipc::messages::reader::MessageReader; - -/// Reader for serialized dtype messages -pub struct DTypeReader { - msgs: MessageReader, -} - -impl DTypeReader { - /// Create new [DTypeReader] given readable contents - pub async fn new(read: VortexBufReader) -> VortexResult { - Ok(Self { - msgs: MessageReader::try_new(read).await?, - }) - } - - /// Deserialize dtype out of ipc serialized format - pub async fn read_dtype(&mut self) -> VortexResult { - self.msgs.read_dtype().await - } - - /// Deconstruct this reader into its underlying contents for further reuse - pub fn into_inner(self) -> VortexBufReader { - self.msgs.into_inner() - } -} diff --git a/vortex-file/src/lib.rs b/vortex-file/src/lib.rs index 58fa8740bb..611bfbc806 100644 --- a/vortex-file/src/lib.rs +++ b/vortex-file/src/lib.rs @@ -66,13 +66,10 @@ //! If you ultimately seek Arrow arrays, [`VortexRecordBatchReader`] converts a //! [`VortexFileArrayStream`] into a [`RecordBatchReader`](arrow_array::RecordBatchReader). -mod dtype_reader; - -pub use dtype_reader::*; - mod read; mod write; +mod byte_range; mod pruning; #[cfg(test)] mod tests; diff --git a/vortex-file/src/read/buffered.rs b/vortex-file/src/read/buffered.rs index e207498dd4..f63ec6c385 100644 --- a/vortex-file/src/read/buffered.rs +++ b/vortex-file/src/read/buffered.rs @@ -165,7 +165,7 @@ where self.dispatcher .dispatch(move || async move { let read_messages = reader - .read_byte_ranges(messages.iter().map(|msg| msg.1.to_range()).collect()) + .read_byte_ranges(messages.iter().map(|msg| msg.1.as_range()).collect()) .map(move |read_res| { Ok(messages .into_iter() diff --git a/vortex-file/src/read/layouts/chunked.rs b/vortex-file/src/read/layouts/chunked.rs index bd1c7cf903..fe49fee7ea 100644 --- a/vortex-file/src/read/layouts/chunked.rs +++ b/vortex-file/src/read/layouts/chunked.rs @@ -400,6 +400,7 @@ mod tests { use arrow_buffer::BooleanBufferBuilder; use bytes::Bytes; use flatbuffers::{root, FlatBufferBuilder}; + use futures_util::io::Cursor; use futures_util::TryStreamExt; use vortex_array::array::{ChunkedArray, PrimitiveArray}; use vortex_array::compute::FilterMask; @@ -407,9 +408,9 @@ mod tests { use vortex_dtype::PType; use vortex_expr::{BinaryExpr, Identity, Literal, Operator}; use vortex_flatbuffers::{footer, WriteFlatBuffer}; - use vortex_ipc::messages::writer::MessageWriter; - use vortex_ipc::stream_writer::ByteRange; + use vortex_ipc::messages::{AsyncMessageWriter, EncoderMessage}; + use crate::byte_range::ByteRange; use crate::layouts::chunked::{ChunkedLayoutBuilder, ChunkedLayoutReader}; use crate::read::cache::{LazyDType, RelativeLayoutCache}; use crate::read::layouts::test_read::{filter_read_layout, read_layout, read_layout_data}; @@ -420,22 +421,25 @@ mod tests { cache: Arc>, scan: Scan, ) -> (ChunkedLayoutReader, ChunkedLayoutReader, Bytes, usize) { - let mut writer = MessageWriter::new(Vec::new()); + let mut writer = Cursor::new(Vec::new()); let array = PrimitiveArray::from((0..100).collect::>()).into_array(); let array_dtype = array.dtype().clone(); let chunked = ChunkedArray::try_new(iter::repeat(array).take(5).collect(), array_dtype).unwrap(); let len = chunked.len(); - let mut byte_offsets = vec![writer.tell()]; + let mut byte_offsets = vec![writer.position()]; let mut row_offsets = vec![0]; let mut row_offset = 0; let mut chunk_stream = chunked.array_stream(); + let mut msgs = AsyncMessageWriter::new(&mut writer); while let Some(chunk) = chunk_stream.try_next().await.unwrap() { row_offset += chunk.len() as u64; row_offsets.push(row_offset); - writer.write_array(chunk).await.unwrap(); - byte_offsets.push(writer.tell()); + msgs.write_message(EncoderMessage::Array(&chunk)) + .await + .unwrap(); + byte_offsets.push(msgs.inner().position()); } let flat_layouts = byte_offsets .iter() diff --git a/vortex-file/src/read/layouts/flat.rs b/vortex-file/src/read/layouts/flat.rs index 9a525641a3..32d818befc 100644 --- a/vortex-file/src/read/layouts/flat.rs +++ b/vortex-file/src/read/layouts/flat.rs @@ -1,13 +1,14 @@ use std::collections::BTreeSet; +use std::io::Cursor; use std::sync::Arc; use bytes::Bytes; use vortex_array::{ArrayData, Context}; use vortex_error::{vortex_bail, VortexResult}; use vortex_flatbuffers::footer; -use vortex_ipc::messages::reader::ArrayMessageReader; -use vortex_ipc::stream_writer::ByteRange; +use vortex_ipc::messages::{DecoderMessage, SyncMessageReader}; +use crate::byte_range::ByteRange; use crate::read::cache::RelativeLayoutCache; use crate::read::mask::RowMask; use crate::{ @@ -72,16 +73,16 @@ impl FlatLayoutReader { MessageLocator(self.message_cache.absolute_id(&[]), self.range) } - fn array_from_bytes(&self, mut buf: Bytes) -> VortexResult { - let mut array_reader = ArrayMessageReader::new(); - let mut read_buf = Bytes::new(); - while let Some(u) = array_reader.read(read_buf)? { - read_buf = buf.split_to(u); + fn array_from_bytes(&self, buf: Bytes) -> VortexResult { + let mut reader = SyncMessageReader::new(Cursor::new(buf)); + match reader.next().transpose()? { + Some(DecoderMessage::Array(array_parts)) => array_parts.into_array_data( + self.ctx.clone(), + self.message_cache.dtype().value()?.clone(), + ), + Some(msg) => vortex_bail!("Expected Array message, got {:?}", msg), + None => vortex_bail!("Expected Array message, got EOF"), } - array_reader.into_array( - self.ctx.clone(), - self.message_cache.dtype().value()?.clone(), - ) } } @@ -119,12 +120,12 @@ mod tests { use bytes::Bytes; use vortex_array::array::PrimitiveArray; - use vortex_array::{Context, IntoArrayData, IntoArrayVariant}; + use vortex_array::{Context, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::PType; use vortex_expr::{BinaryExpr, Identity, Literal, Operator}; - use vortex_ipc::messages::writer::MessageWriter; - use vortex_ipc::stream_writer::ByteRange; + use vortex_ipc::messages::{EncoderMessage, SyncMessageWriter}; + use crate::byte_range::ByteRange; use crate::layouts::flat::FlatLayoutReader; use crate::read::cache::{LazyDType, RelativeLayoutCache}; use crate::read::layouts::test_read::{filter_read_layout, read_layout}; @@ -133,11 +134,12 @@ mod tests { async fn read_only_layout( cache: Arc>, ) -> (FlatLayoutReader, Bytes, usize, Arc) { - let mut writer = MessageWriter::new(Vec::new()); let array = PrimitiveArray::from((0..100).collect::>()).into_array(); - let len = array.len(); - writer.write_array(array).await.unwrap(); - let written = writer.into_inner(); + + let mut written = vec![]; + SyncMessageWriter::new(&mut written) + .write_message(EncoderMessage::Array(&array.to_array())) + .unwrap(); let projection_scan = Scan::empty(); let dtype = Arc::new(LazyDType::from_dtype(PType::I32.into())); @@ -150,7 +152,7 @@ mod tests { RelativeLayoutCache::new(cache, dtype.clone()), ), Bytes::from(written), - len, + array.len(), dtype, ) } diff --git a/vortex-file/src/read/layouts/test_read.rs b/vortex-file/src/read/layouts/test_read.rs index 841f34cc39..416da248ef 100644 --- a/vortex-file/src/read/layouts/test_read.rs +++ b/vortex-file/src/read/layouts/test_read.rs @@ -33,7 +33,7 @@ pub fn read_layout_data( PollRead::ReadMore(m) => { let mut write_cache_guard = cache.write().unwrap(); for MessageLocator(id, range) in m { - write_cache_guard.set(id, buf.slice(range.to_range())); + write_cache_guard.set(id, buf.slice(range.as_range())); } } PollRead::Value(a) => return Some(a), @@ -53,7 +53,7 @@ pub fn read_filters( PollRead::ReadMore(m) => { let mut write_cache_guard = cache.write().unwrap(); for MessageLocator(id, range) in m { - write_cache_guard.set(id, buf.slice(range.to_range())); + write_cache_guard.set(id, buf.slice(range.as_range())); } } PollRead::Value(a) => { diff --git a/vortex-file/src/read/mod.rs b/vortex-file/src/read/mod.rs index 2bc5a4b72c..d89aaff814 100644 --- a/vortex-file/src/read/mod.rs +++ b/vortex-file/src/read/mod.rs @@ -29,8 +29,8 @@ pub use projection::Projection; pub use recordbatchreader::{AsyncRuntime, VortexRecordBatchReader}; pub use stream::VortexFileArrayStream; use vortex_expr::ExprRef; -use vortex_ipc::stream_writer::ByteRange; +use crate::byte_range::ByteRange; pub use crate::read::mask::RowMask; // Recommended read-size according to the AWS performance guide diff --git a/vortex-file/src/tests.rs b/vortex-file/src/tests.rs index f6651efed0..f70284f6d1 100644 --- a/vortex-file/src/tests.rs +++ b/vortex-file/src/tests.rs @@ -928,7 +928,7 @@ async fn test_pruning_with_or() { .into_array(); let buffer = Vec::new(); - let written_bytes = VortexFileWriter::new(buffer) + let written_bytes: Vec = VortexFileWriter::new(buffer) .write_array_columns(array) .await .unwrap() diff --git a/vortex-file/src/write/layout.rs b/vortex-file/src/write/layout.rs index b79bec6454..b86066a1d8 100644 --- a/vortex-file/src/write/layout.rs +++ b/vortex-file/src/write/layout.rs @@ -1,8 +1,8 @@ use bytes::Bytes; use flatbuffers::{FlatBufferBuilder, WIPOffset}; use vortex_flatbuffers::{footer as fb, FlatBufferRoot, WriteFlatBuffer}; -use vortex_ipc::stream_writer::ByteRange; +use crate::byte_range::ByteRange; use crate::{LayoutId, CHUNKED_LAYOUT_ID, COLUMNAR_LAYOUT_ID, FLAT_LAYOUT_ID}; #[derive(Debug, Clone)] diff --git a/vortex-file/src/write/writer.rs b/vortex-file/src/write/writer.rs index 164f1b2dbc..53a31570d0 100644 --- a/vortex-file/src/write/writer.rs +++ b/vortex-file/src/write/writer.rs @@ -4,6 +4,7 @@ use std::{io, iter, mem}; use bytes::Bytes; use futures::TryStreamExt; +use futures_util::io::Cursor; use itertools::Itertools; use vortex_array::array::{ChunkedArray, StructArray}; use vortex_array::stats::{as_stat_bitset_bytes, ArrayStatistics, Stat}; @@ -13,9 +14,9 @@ use vortex_dtype::DType; use vortex_error::{vortex_bail, vortex_err, VortexExpect as _, VortexResult}; use vortex_flatbuffers::{FlatBufferRoot, WriteFlatBuffer, WriteFlatBufferExt}; use vortex_io::VortexWrite; -use vortex_ipc::messages::writer::MessageWriter; -use vortex_ipc::stream_writer::ByteRange; +use vortex_ipc::messages::{EncoderMessage, MessageEncoder}; +use crate::byte_range::ByteRange; use crate::write::postscript::Postscript; use crate::write::stats_accumulator::{StatArray, StatsAccumulator}; use crate::{LayoutSpec, EOF_SIZE, MAGIC_BYTES, MAX_FOOTER_SIZE, VERSION}; @@ -33,8 +34,7 @@ const STATS_TO_WRITE: &[Stat] = &[ ]; pub struct VortexFileWriter { - msgs: MessageWriter, - + write: Cursor, row_count: u64, dtype: Option, column_writers: Vec, @@ -43,7 +43,7 @@ pub struct VortexFileWriter { impl VortexFileWriter { pub fn new(write: W) -> Self { VortexFileWriter { - msgs: MessageWriter::new(write), + write: Cursor::new(write), dtype: None, column_writers: Vec::new(), row_count: 0, @@ -117,7 +117,7 @@ impl VortexFileWriter { Some(x) => x, }; - column_writer.write_chunks(stream, &mut self.msgs).await + column_writer.write_chunks(stream, &mut self.write).await } async fn write_metadata_arrays(&mut self) -> VortexResult { @@ -125,7 +125,7 @@ impl VortexFileWriter { for column_writer in mem::take(&mut self.column_writers) { column_layouts.push( column_writer - .write_metadata(self.row_count, &mut self.msgs) + .write_metadata(self.row_count, &mut self.write) .await?, ); } @@ -135,10 +135,7 @@ impl VortexFileWriter { pub async fn finalize(mut self) -> VortexResult { let top_level_layout = self.write_metadata_arrays().await?; - let dtype_offset = self.msgs.tell(); - - // we want to write raw flatbuffers from here on out, not messages - let mut writer = self.msgs.into_inner(); + let dtype_offset = self.write.position(); // write the schema, and get the start offset of the next section (layout) let layout_offset = { @@ -149,15 +146,15 @@ impl VortexFileWriter { // we write an IPCSchema instead of a DType, which allows us to evolve / add to the schema later // these bytes get deserialized as message::Schema // NB: we don't wrap the IPCSchema in an IPCMessage, because we record the lengths/offsets in the footer - let dtype_len = write_fb_raw(&mut writer, dtype).await?; + let dtype_len = write_fb_raw(&mut self.write, dtype).await?; dtype_offset + dtype_len }; // write the layout - write_fb_raw(&mut writer, top_level_layout).await?; + write_fb_raw(&mut self.write, top_level_layout).await?; let footer = Postscript::try_new(dtype_offset, layout_offset)?; - let footer_len = write_fb_raw(&mut writer, footer).await?; + let footer_len = write_fb_raw(&mut self.write, footer).await?; if footer_len > MAX_FOOTER_SIZE as u64 { vortex_bail!( "Footer is too large ({} bytes); max footer size is {}", @@ -172,8 +169,8 @@ impl VortexFileWriter { eof[2..4].copy_from_slice(&footer_len.to_le_bytes()); eof[4..8].copy_from_slice(&MAGIC_BYTES); - writer.write_all(eof).await?; - Ok(writer) + self.write.write_all(eof).await?; + Ok(self.write.into_inner()) } } @@ -206,10 +203,10 @@ impl ColumnWriter { async fn write_chunks( &mut self, mut stream: S, - msgs: &mut MessageWriter, + write: &mut Cursor, ) -> VortexResult<()> { let mut offsets = Vec::with_capacity(stream.size_hint().0 + 1); - offsets.push(msgs.tell()); + offsets.push(write.position()); let mut row_offsets = Vec::with_capacity(stream.size_hint().0 + 1); row_offsets.push( self.batch_row_offsets @@ -230,8 +227,12 @@ impl ColumnWriter { // clear the stats that we don't want to serialize into the file retain_only_stats(&chunk, STATS_TO_WRITE); - msgs.write_array(chunk).await?; - offsets.push(msgs.tell()); + let mut encoder = MessageEncoder::default(); + for buffer in encoder.encode(EncoderMessage::Array(&chunk)) { + write.write_all(buffer).await?; + } + + offsets.push(write.position()); row_offsets.push(rows_written); } @@ -244,7 +245,7 @@ impl ColumnWriter { async fn write_metadata( self, row_count: u64, - msgs: &mut MessageWriter, + write: &mut Cursor, ) -> VortexResult { let data_chunks = self .batch_byte_offsets @@ -269,9 +270,12 @@ impl ColumnWriter { let stat_bitset = as_stat_bitset_bytes(&present_stats); - let metadata_array_begin = msgs.tell(); - msgs.write_array(metadata_array).await?; - let metadata_array_end = msgs.tell(); + let metadata_array_begin = write.position(); + let mut encoder = MessageEncoder::default(); + for buffer in encoder.encode(EncoderMessage::Array(&metadata_array)) { + write.write_all(buffer).await?; + } + let metadata_array_end = write.position(); let layouts = iter::once(LayoutSpec::flat( ByteRange::new(metadata_array_begin, metadata_array_end), diff --git a/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs b/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs index 5d0148d7fc..425b1aabdf 100644 --- a/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs +++ b/vortex-flatbuffers/flatbuffers/vortex-array/array.fbs @@ -40,4 +40,5 @@ table Buffer { padding: uint16; } +root_type Array; root_type ArrayData; \ No newline at end of file diff --git a/vortex-flatbuffers/src/generated/message.rs b/vortex-flatbuffers/src/generated/message.rs index a4a7433175..e89f676ef3 100644 --- a/vortex-flatbuffers/src/generated/message.rs +++ b/vortex-flatbuffers/src/generated/message.rs @@ -4,8 +4,8 @@ // @generated use crate::dtype::*; -use crate::array::*; use crate::scalar::*; +use crate::array::*; use core::mem; use core::cmp::Ordering; diff --git a/vortex-io/src/write.rs b/vortex-io/src/write.rs index 5500d53870..12220c9ff1 100644 --- a/vortex-io/src/write.rs +++ b/vortex-io/src/write.rs @@ -41,6 +41,21 @@ where } } +impl VortexWrite for futures::io::Cursor { + fn write_all(&mut self, buffer: B) -> impl Future> { + self.set_position(self.position() + buffer.as_slice().len() as u64); + VortexWrite::write_all(self.get_mut(), buffer) + } + + fn flush(&mut self) -> impl Future> { + VortexWrite::flush(self.get_mut()) + } + + fn shutdown(&mut self) -> impl Future> { + VortexWrite::shutdown(self.get_mut()) + } +} + impl VortexWrite for &mut W { fn write_all(&mut self, buffer: B) -> impl Future> { (*self).write_all(buffer) diff --git a/vortex-ipc/Cargo.toml b/vortex-ipc/Cargo.toml index 6df69ddc92..e8d1df9719 100644 --- a/vortex-ipc/Cargo.toml +++ b/vortex-ipc/Cargo.toml @@ -14,35 +14,20 @@ readme.workspace = true categories.workspace = true [dependencies] +aligned-buffer = { workspace = true } bytes = { workspace = true } flatbuffers = { workspace = true } -futures-util = { workspace = true } +futures-util = { workspace = true, features = ["io"] } itertools = { workspace = true } +pin-project-lite = { workspace = true } vortex-array = { workspace = true } vortex-buffer = { workspace = true } vortex-dtype = { workspace = true } vortex-error = { workspace = true } vortex-flatbuffers = { workspace = true, features = ["ipc"] } -vortex-io = { workspace = true } [dev-dependencies] -arrow-array = { workspace = true } -arrow-ipc = { workspace = true } -arrow-schema = { workspace = true } -arrow-select = { workspace = true } -criterion = { workspace = true, features = ["async_futures"] } -futures-executor = { workspace = true } tokio = { workspace = true, features = ["full"] } -vortex-sampling-compressor = { path = "../vortex-sampling-compressor" } -vortex-io = { path = "../vortex-io", features = ["futures"] } [lints] workspace = true - -[[bench]] -name = "ipc_take" -harness = false - -[[bench]] -name = "ipc_array_reader_take" -harness = false diff --git a/vortex-ipc/benches/ipc_array_reader_take.rs b/vortex-ipc/benches/ipc_array_reader_take.rs deleted file mode 100644 index 88747bf257..0000000000 --- a/vortex-ipc/benches/ipc_array_reader_take.rs +++ /dev/null @@ -1,73 +0,0 @@ -#![allow(clippy::unwrap_used)] -use std::sync::Arc; -use std::time::Duration; - -use bytes::Bytes; -use criterion::async_executor::FuturesExecutor; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use futures_executor::block_on; -use futures_util::{pin_mut, TryStreamExt}; -use itertools::Itertools; -use vortex_array::array::{ChunkedArray, PrimitiveArray}; -use vortex_array::stream::ArrayStreamExt; -use vortex_array::validity::Validity; -use vortex_array::{Context, IntoArrayData}; -use vortex_io::VortexBufReader; -use vortex_ipc::stream_reader::StreamArrayReader; -use vortex_ipc::stream_writer::StreamArrayWriter; - -// 100 record batches, 100k rows each -// take from the first 20 batches and last batch -// compare with arrow -fn ipc_array_reader_take(c: &mut Criterion) { - let ctx = Arc::new(Context::default()); - - let indices = (0..20) - .map(|i| i * 100_000 + 1) - .chain([98 * 100_000 + 1]) - .collect_vec(); - let mut group = c.benchmark_group("ipc_array_reader_take"); - - group.bench_function("vortex", |b| { - let array = ChunkedArray::from_iter( - (0..100i32) - .map(|i| vec![i; 100_000]) - .map(|vec| PrimitiveArray::from_vec(vec, Validity::AllValid).into_array()), - ) - .into_array(); - - let buffer = block_on(async { StreamArrayWriter::new(vec![]).write_array(array).await }) - .unwrap() - .into_inner(); - - let buffer = Bytes::from(buffer); - - let indices = indices.clone().into_array(); - - b.to_async(FuturesExecutor).iter(|| async { - let stream_reader = - StreamArrayReader::try_new(VortexBufReader::new(buffer.clone()), ctx.clone()) - .await - .unwrap() - .load_dtype() - .await - .unwrap(); - let stream = stream_reader - .into_array_stream() - .take_rows(indices.clone()) - .unwrap(); - pin_mut!(stream); - - while let Some(arr) = stream.try_next().await.unwrap() { - black_box(arr); - } - }); - }); -} - -criterion_group!( - name = benches; - config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = ipc_array_reader_take -); -criterion_main!(benches); diff --git a/vortex-ipc/benches/ipc_take.rs b/vortex-ipc/benches/ipc_take.rs deleted file mode 100644 index 89c7a3cc48..0000000000 --- a/vortex-ipc/benches/ipc_take.rs +++ /dev/null @@ -1,95 +0,0 @@ -#![allow(clippy::unwrap_used)] -use std::sync::Arc; -use std::time::Duration; - -use arrow_array::{Array, Int32Array, RecordBatch}; -use arrow_ipc::reader::StreamReader; -use arrow_ipc::writer::{IpcWriteOptions, StreamWriter as ArrowStreamWriter}; -use arrow_ipc::{CompressionType, MetadataVersion}; -use arrow_schema::{DataType, Field, Schema}; -use bytes::Bytes; -use criterion::async_executor::FuturesExecutor; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use futures_executor::block_on; -use futures_util::{pin_mut, TryStreamExt}; -use itertools::Itertools; -use vortex_array::array::PrimitiveArray; -use vortex_array::compress::CompressionStrategy; -use vortex_array::compute::take; -use vortex_array::{Context, IntoArrayData}; -use vortex_io::VortexBufReader; -use vortex_ipc::stream_reader::StreamArrayReader; -use vortex_ipc::stream_writer::StreamArrayWriter; -use vortex_sampling_compressor::SamplingCompressor; - -fn ipc_take(c: &mut Criterion) { - let mut group = c.benchmark_group("ipc_take"); - let indices = Int32Array::from(vec![10, 11, 12, 13, 100_000, 2_999_999]); - group.bench_function("arrow", |b| { - let mut buffer = vec![]; - { - let field = Field::new("uid", DataType::Int32, true); - let schema = Schema::new(vec![field]); - let options = IpcWriteOptions::try_new(32, false, MetadataVersion::V5) - .unwrap() - .try_with_compression(Some(CompressionType::LZ4_FRAME)) - .unwrap(); - let mut writer = - ArrowStreamWriter::try_new_with_options(&mut buffer, &schema, options).unwrap(); - let array = Int32Array::from((0i32..3_000_000).rev().collect_vec()); - - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap(); - writer.write(&batch).unwrap(); - } - - b.iter(|| { - let mut cursor = std::io::Cursor::new(&buffer); - let mut reader = StreamReader::try_new(&mut cursor, None).unwrap(); - let batch = reader.next().unwrap().unwrap(); - let array_from_batch = batch.column(0); - let array = array_from_batch - .as_any() - .downcast_ref::() - .unwrap(); - black_box(arrow_select::take::take(array, &indices, None).unwrap()); - }); - }); - - group.bench_function("vortex", |b| { - let indices = PrimitiveArray::from(vec![10, 11, 12, 13, 100_000, 2_999_999]).into_array(); - let uncompressed = PrimitiveArray::from((0i32..3_000_000).rev().collect_vec()).into_array(); - let ctx = Context::default(); - let compressor: &dyn CompressionStrategy = &SamplingCompressor::default(); - let compressed = compressor.compress(&uncompressed).unwrap(); - - // Try running take over an ArrayView. - let buffer = - block_on(async { StreamArrayWriter::new(vec![]).write_array(compressed).await }) - .unwrap() - .into_inner(); - - let ctx_ref = &Arc::new(ctx); - let ro_buffer = buffer.as_slice(); - let indices_ref = &indices; - - b.to_async(FuturesExecutor).iter(|| async move { - let stream_reader = StreamArrayReader::try_new( - VortexBufReader::new(Bytes::from(ro_buffer.to_vec())), - ctx_ref.clone(), - ) - .await? - .load_dtype() - .await?; - let reader = stream_reader.into_array_stream(); - pin_mut!(reader); - let array_view = reader.try_next().await?.unwrap(); - black_box(take(&array_view, indices_ref)) - }); - }); -} - -criterion_group!( - name = benches; - config = Criterion::default().measurement_time(Duration::from_secs(10)); - targets = ipc_take); -criterion_main!(benches); diff --git a/vortex-ipc/src/iterator.rs b/vortex-ipc/src/iterator.rs new file mode 100644 index 0000000000..8ae2a5e24e --- /dev/null +++ b/vortex-ipc/src/iterator.rs @@ -0,0 +1,189 @@ +use std::io::{Read, Write}; +use std::sync::Arc; + +use aligned_buffer::UniqueAlignedBuffer; +use bytes::Bytes; +use itertools::Itertools; +use vortex_array::iter::ArrayIterator; +use vortex_array::{ArrayDType, ArrayData, Context}; +use vortex_buffer::Buffer; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, vortex_err, VortexResult}; + +use crate::messages::{DecoderMessage, EncoderMessage, MessageEncoder, SyncMessageReader}; +use crate::ALIGNMENT; + +/// An [`ArrayIterator`] for reading messages off an IPC stream. +pub struct SyncIPCReader { + reader: SyncMessageReader, + ctx: Arc, + dtype: DType, +} + +impl SyncIPCReader { + pub fn try_new(read: R, ctx: Arc) -> VortexResult { + let mut reader = SyncMessageReader::new(read); + match reader.next().transpose()? { + Some(msg) => match msg { + DecoderMessage::DType(dtype) => Ok(SyncIPCReader { reader, ctx, dtype }), + msg => { + vortex_bail!("Expected DType message, got {:?}", msg); + } + }, + None => vortex_bail!("Expected DType message, got EOF"), + } + } +} + +impl ArrayIterator for SyncIPCReader { + fn dtype(&self) -> &DType { + &self.dtype + } +} + +impl Iterator for SyncIPCReader { + type Item = VortexResult; + + fn next(&mut self) -> Option { + match self.reader.next()? { + Ok(msg) => match msg { + DecoderMessage::Array(array_parts) => Some( + array_parts + .into_array_data(self.ctx.clone(), self.dtype.clone()) + .and_then(|array| { + if array.dtype() != self.dtype() { + Err(vortex_err!( + "Array data type mismatch: expected {:?}, got {:?}", + self.dtype(), + array.dtype() + )) + } else { + Ok(array) + } + }), + ), + msg => Some(Err(vortex_err!("Expected Array message, got {:?}", msg))), + }, + Err(e) => Some(Err(e)), + } + } +} + +/// A trait for converting an [`ArrayIterator`] into an IPC stream. +pub trait ArrayIteratorIPC { + fn into_ipc(self) -> ArrayIteratorIPCBytes + where + Self: Sized; + + fn write_ipc(self, write: W) -> VortexResult + where + Self: Sized; +} + +impl ArrayIteratorIPC for I { + fn into_ipc(self) -> ArrayIteratorIPCBytes + where + Self: Sized, + { + let mut encoder = MessageEncoder::default(); + let buffers = encoder.encode(EncoderMessage::DType(self.dtype())); + ArrayIteratorIPCBytes { + inner: Box::new(self), + encoder, + buffers, + } + } + + fn write_ipc(self, mut write: W) -> VortexResult + where + Self: Sized, + { + let mut stream = self.into_ipc(); + for buffer in &mut stream { + write.write_all(buffer?.as_slice())?; + } + Ok(write) + } +} + +pub struct ArrayIteratorIPCBytes { + inner: Box, + encoder: MessageEncoder, + buffers: Vec, +} + +impl ArrayIteratorIPCBytes { + /// Collects the IPC bytes into a single `Buffer`. + pub fn collect_to_buffer(self) -> VortexResult { + // We allocate a single aligned buffer to hold the combined IPC bytes + let buffers: Vec = self.try_collect()?; + let mut buffer = + UniqueAlignedBuffer::::with_capacity(buffers.iter().map(|b| b.len()).sum()); + for buf in buffers { + buffer.extend_from_slice(buf.as_slice()); + } + Ok(Buffer::from(Bytes::from_owner(buffer))) + } +} + +impl Iterator for ArrayIteratorIPCBytes { + type Item = VortexResult; + + fn next(&mut self) -> Option { + // Try to flush any buffers we have + if !self.buffers.is_empty() { + return Some(Ok(self.buffers.remove(0))); + } + + // Or else try to serialize the next array + match self.inner.next()? { + Ok(chunk) => { + self.buffers + .extend(self.encoder.encode(EncoderMessage::Array(&chunk))); + } + Err(e) => return Some(Err(e)), + } + + // Try to flush any buffers we have again + if !self.buffers.is_empty() { + return Some(Ok(self.buffers.remove(0))); + } + + // Otherwise, we're done + None + } +} + +#[cfg(test)] +mod test { + use std::io::Cursor; + use std::sync::Arc; + + use vortex_array::array::PrimitiveArray; + use vortex_array::iter::{ArrayIterator, ArrayIteratorExt}; + use vortex_array::validity::Validity; + use vortex_array::{ArrayDType, Context, IntoArrayVariant, ToArrayData}; + + use super::*; + + #[test] + fn test_sync_stream() { + let array = PrimitiveArray::from_vec::(vec![1, 2, 3], Validity::NonNullable); + let ipc_buffer = array + .to_array() + .into_array_iterator() + .into_ipc() + .collect_to_buffer() + .unwrap(); + + let reader = + SyncIPCReader::try_new(Cursor::new(ipc_buffer), Arc::new(Context::default())).unwrap(); + + assert_eq!(reader.dtype(), array.dtype()); + let result = reader.into_array_data().unwrap().into_primitive().unwrap(); + assert_eq!( + array.maybe_null_slice::(), + result.maybe_null_slice::() + ); + } +} diff --git a/vortex-ipc/src/lib.rs b/vortex-ipc/src/lib.rs index 5e4dcee994..da4fdd16c1 100644 --- a/vortex-ipc/src/lib.rs +++ b/vortex-ipc/src/lib.rs @@ -5,12 +5,11 @@ //! data buffers. //! //! This crate provides both in-memory message representations for holding IPC messages -//! before/after serialization, as well as streaming readers and writers that sit on top +//! before/after serialization, and streaming readers and writers that sit on top //! of any type implementing `VortexRead` or `VortexWrite` respectively. - +pub mod iterator; pub mod messages; -pub mod stream_reader; -pub mod stream_writer; +pub mod stream; /// All messages in Vortex are aligned to start at a multiple of 64 bytes. /// @@ -18,105 +17,3 @@ pub mod stream_writer; /// thus all buffers allocated with this alignment are naturally aligned /// for any data we may put inside of it. pub const ALIGNMENT: usize = 64; - -#[cfg(test)] -#[allow(clippy::panic_in_result_fn)] -mod test { - use std::sync::Arc; - - use bytes::Bytes; - use futures_executor::block_on; - use futures_util::{pin_mut, StreamExt, TryStreamExt}; - use itertools::Itertools; - use vortex_array::array::{ChunkedArray, PrimitiveArray, PrimitiveEncoding}; - use vortex_array::encoding::EncodingVTable; - use vortex_array::stream::ArrayStreamExt; - use vortex_array::{ArrayDType, Context, IntoArrayData, IntoArrayVariant}; - use vortex_buffer::Buffer; - use vortex_error::VortexResult; - use vortex_io::VortexBufReader; - - use crate::stream_reader::StreamArrayReader; - use crate::stream_writer::StreamArrayWriter; - - fn write_ipc(array: A) -> Vec { - block_on(async { - StreamArrayWriter::new(vec![]) - .write_array(array.into_array()) - .await - .unwrap() - .into_inner() - }) - } - - #[tokio::test] - #[cfg_attr(miri, ignore)] - async fn test_empty_index() -> VortexResult<()> { - let data = PrimitiveArray::from((0i32..3_000_000).collect_vec()); - let buffer = write_ipc(data); - - let indices = PrimitiveArray::from(vec![1, 2, 10]).into_array(); - - let ctx = Arc::new(Context::default()); - let stream_reader = - StreamArrayReader::try_new(VortexBufReader::new(Bytes::from(buffer)), ctx) - .await - .unwrap() - .load_dtype() - .await - .unwrap(); - let reader = stream_reader.into_array_stream(); - - let result_iter = reader.take_rows(indices)?; - pin_mut!(result_iter); - - let _result = block_on(async { result_iter.next().await.unwrap().unwrap() }); - Ok(()) - } - - #[tokio::test] - #[cfg_attr(miri, ignore)] - async fn test_write_read_chunked() -> VortexResult<()> { - let indices = PrimitiveArray::from(vec![ - 10u32, 11, 12, 13, 100_000, 2_999_999, 2_999_999, 3_000_000, - ]) - .into_array(); - - // NB: the order is reversed here to ensure we aren't grabbing indexes instead of values - let data = PrimitiveArray::from((0i32..3_000_000).rev().collect_vec()).into_array(); - let data2 = - PrimitiveArray::from((3_000_000i32..6_000_000).rev().collect_vec()).into_array(); - let chunked = ChunkedArray::try_new(vec![data.clone(), data2], data.dtype().clone())?; - let buffer = write_ipc(chunked); - let buffer = Buffer::from(buffer); - - let ctx = Arc::new(Context::default()); - let stream_reader = StreamArrayReader::try_new(VortexBufReader::new(buffer), ctx) - .await - .unwrap() - .load_dtype() - .await - .unwrap(); - - let take_iter = stream_reader.into_array_stream().take_rows(indices)?; - pin_mut!(take_iter); - - let next = block_on(async { take_iter.try_next().await })?.expect("Expected a chunk"); - assert_eq!(next.encoding().id(), PrimitiveEncoding.id()); - - assert_eq!( - next.into_primitive().unwrap().maybe_null_slice::(), - vec![2999989, 2999988, 2999987, 2999986, 2899999, 0, 0] - ); - assert_eq!( - block_on(async { take_iter.try_next().await })? - .expect("Expected a chunk") - .into_primitive() - .unwrap() - .maybe_null_slice::(), - vec![5999999] - ); - - Ok(()) - } -} diff --git a/vortex-ipc/src/messages/decoder.rs b/vortex-ipc/src/messages/decoder.rs new file mode 100644 index 0000000000..821bcdc040 --- /dev/null +++ b/vortex-ipc/src/messages/decoder.rs @@ -0,0 +1,356 @@ +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use bytes::{Buf, BytesMut}; +use flatbuffers::{root, root_unchecked, Follow}; +use itertools::Itertools; +use vortex_array::{flatbuffers as fba, ArrayData, Context}; +use vortex_buffer::Buffer; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, vortex_err, VortexExpect, VortexResult}; +use vortex_flatbuffers::message as fb; +use vortex_flatbuffers::message::{MessageHeader, MessageVersion}; + +use crate::ALIGNMENT; + +/// A message decoded from an IPC stream. +/// +/// Note that the `Array` variant cannot fully decode into an [`ArrayData`] without a [`Context`] +/// and a [`DType`]. As such, we partially decode into an [`ArrayParts`] and allow the caller to +/// finish the decoding. +#[derive(Debug)] +pub enum DecoderMessage { + Array(ArrayParts), + Buffer(Buffer), + DType(DType), +} + +/// ArrayParts represents a partially decoded Vortex array. +/// It can be completely decoded calling `into_array_data` with a context and dtype. +pub struct ArrayParts { + row_count: usize, + // Typed as fb::Array + array_flatbuffer: Buffer, + array_flatbuffer_loc: usize, + buffers: Vec, +} + +impl Debug for ArrayParts { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ArrayComponents") + .field("row_count", &self.row_count) + .field("array_flatbuffer", &self.array_flatbuffer.len()) + .field("buffers", &self.buffers.len()) + .finish() + } +} + +impl ArrayParts { + pub fn into_array_data(self, ctx: Arc, dtype: DType) -> VortexResult { + ArrayData::try_new_viewed( + ctx, + dtype, + self.row_count, + self.array_flatbuffer, + // SAFETY: ArrayComponents guarantees the buffers are valid. + |buf| unsafe { Ok(fba::Array::follow(buf, self.array_flatbuffer_loc)) }, + self.buffers, + ) + } +} + +#[derive(Default)] +enum State { + #[default] + Length, + Header(usize), + Array(ReadingArray), + Buffer(ReadingBuffer), +} + +struct ReadingArray { + header: Buffer, + buffers_length: usize, +} + +struct ReadingBuffer { + length: usize, + length_with_padding: usize, +} + +#[derive(Debug)] +pub enum PollRead { + Some(DecoderMessage), + /// Returns the _total_ number of bytes needed to make progress. + /// Note this is _not_ the incremental number of bytes needed to make progress. + NeedMore(usize), +} + +/// A stateful reader for decoding IPC messages from an arbitrary stream of bytes. +/// +/// NOTE(ngates): we should design some trait that the Decoder can take that doesn't require unique +/// ownership of the underlying bytes. The decoder needs to split out bytes, and advance a cursor, +/// but it doesn't need to mutate any bytes. So in theory, we should be able to do this zero-copy +/// over a shared buffer of bytes, instead of requiring a `BytesMut`. +pub struct MessageDecoder { + /// The minimum alignment to use when reading a data `Buffer`. + alignment: usize, + /// The current state of the decoder. + state: State, +} + +impl Default for MessageDecoder { + fn default() -> Self { + Self { + alignment: ALIGNMENT, + state: Default::default(), + } + } +} + +/// The alignment required for a flatbuffer message. +/// This is based on the assumption that the maximum primitive type is 8 bytes. +/// See: https://groups.google.com/g/flatbuffers/c/PSgQeWeTx_g +const FB_ALIGNMENT: usize = 8; + +impl MessageDecoder { + /// Attempt to read the next message from the bytes object. + /// + /// If the message is incomplete, the function will return `NeedMore` with the _total_ number + /// of bytes needed to make progress. The next call to read_next _should_ provide at least + /// this number of bytes otherwise it will be given the same `NeedMore` response. + pub fn read_next(&mut self, bytes: &mut BytesMut) -> VortexResult { + loop { + match &self.state { + State::Length => { + if bytes.len() < 4 { + return Ok(PollRead::NeedMore(4)); + } + + let msg_length = bytes.get_u32_le(); + self.state = State::Header(msg_length as usize); + } + State::Header(msg_length) => { + if bytes.len() < *msg_length { + bytes.try_reserve_aligned(*msg_length, FB_ALIGNMENT); + return Ok(PollRead::NeedMore(*msg_length)); + } + + let mut msg_bytes = bytes.split_to_aligned(*msg_length, FB_ALIGNMENT); + let msg = root::(msg_bytes.as_ref())?; + if msg.version() != MessageVersion::V0 { + vortex_bail!("Unsupported message version {:?}", msg.version()); + } + + match msg.header_type() { + MessageHeader::ArrayData => { + let array_data = msg + .header_as_array_data() + .vortex_expect("array data header"); + let buffers_length: u64 = array_data + .buffers() + .unwrap_or_default() + .iter() + .map(|buffer| buffer.length() + (buffer.padding() as u64)) + .sum(); + + let buffers_length = usize::try_from(buffers_length).map_err(|_| { + vortex_err!("buffers length is too large for usize") + })?; + + self.state = State::Array(ReadingArray { + header: Buffer::from(msg_bytes.split().freeze()), + buffers_length, + }); + } + MessageHeader::Buffer => { + let buffer = msg.header_as_buffer().vortex_expect("buffer header"); + let length = usize::try_from(buffer.length()) + .vortex_expect("Buffer length is too large for usize"); + let length_with_padding = length + buffer.padding() as usize; + + self.state = State::Buffer(ReadingBuffer { + length, + length_with_padding, + }); + } + MessageHeader::DType => { + let dtype = msg.header_as_dtype().vortex_expect("dtype header"); + let dtype = DType::try_from(dtype)?; + + // Nothing else to read, so we reset the state to Length + self.state = Default::default(); + return Ok(PollRead::Some(DecoderMessage::DType(dtype))); + } + _ => { + vortex_bail!("Unsupported message header type {:?}", msg.header_type()); + } + } + } + State::Buffer(ReadingBuffer { + length, + length_with_padding, + }) => { + if bytes.len() < *length_with_padding { + bytes.try_reserve_aligned(*length_with_padding, self.alignment); + return Ok(PollRead::NeedMore(*length_with_padding)); + } + let buffer = bytes.split_to_aligned(*length, self.alignment); + + let msg = DecoderMessage::Buffer(Buffer::from(buffer.freeze())); + let _padding = bytes.split_to(length_with_padding - length); + + // Nothing else to read, so we reset the state to Length + self.state = Default::default(); + return Ok(PollRead::Some(msg)); + } + State::Array(ReadingArray { + header, + buffers_length, + }) => { + if bytes.len() < *buffers_length { + bytes.try_reserve_aligned(*buffers_length, self.alignment); + return Ok(PollRead::NeedMore(*buffers_length)); + } + + // SAFETY: we've already validated the header + let msg = unsafe { root_unchecked::(header.as_ref()) }; + let array_data_msg = msg + .header_as_array_data() + .vortex_expect("array data header"); + let array_msg = array_data_msg + .array() + .ok_or_else(|| vortex_err!("array data message missing array"))?; + + let buffers = array_data_msg + .buffers() + .unwrap_or_default() + .iter() + .map(|buffer_msg| { + let buffer_len = usize::try_from(buffer_msg.length()) + .vortex_expect("buffer length is too large for usize"); + let buffer = bytes.split_to_aligned(buffer_len, self.alignment); + let _padding = bytes.split_to(buffer_msg.padding() as usize); + Buffer::from(buffer.freeze()) + }) + .collect_vec(); + + let row_count = usize::try_from(array_data_msg.row_count()) + .map_err(|_| vortex_err!("row count is too large for usize"))?; + + let msg = DecoderMessage::Array(ArrayParts { + row_count, + array_flatbuffer: header.clone(), + array_flatbuffer_loc: array_msg._tab.loc(), + buffers, + }); + + self.state = Default::default(); + return Ok(PollRead::Some(msg)); + } + } + } + } +} + +trait BytesMutAlignedSplit { + /// If the buffer is empty, advances the cursor to the next aligned position and ensures there + /// is sufficient capacity for the requested length. + /// + /// If the buffer is not empty, this function does nothing. + /// + /// This allows us to optimistically align buffers that might be read into from an I/O source. + /// However, if the source of the decoder's BytesMut is a fully formed in-memory IPC buffer, + /// then it would be wasteful to copy the whole thing, and we'd rather only copy the individual + /// buffers that require alignment. + fn try_reserve_aligned(&mut self, capacity: usize, align: usize); + + /// Splits the buffer at the given index, ensuring the returned BytesMut is aligned + /// as requested. + /// + /// If the buffer isn't already aligned, the split data will be copied into a new + /// buffer that is aligned. + fn split_to_aligned(&mut self, at: usize, align: usize) -> BytesMut; +} + +impl BytesMutAlignedSplit for BytesMut { + fn try_reserve_aligned(&mut self, capacity: usize, align: usize) { + if !self.is_empty() { + return; + } + + // Reserve up to the worst-cast alignment + self.reserve(capacity + align); + let padding = self.as_ptr().align_offset(align); + unsafe { self.set_len(padding) }; + self.advance(padding); + } + + fn split_to_aligned(&mut self, at: usize, align: usize) -> BytesMut { + let buffer = self.split_to(at); + + // If the buffer is already aligned, we can return it directly. + if buffer.as_ptr().align_offset(align) == 0 { + return buffer; + } + + // Otherwise, we allocate a new buffer, align the start, and copy the data. + // NOTE(ngates): this case will rarely be hit. Only if the caller mutates the bytes after + // they have been aligned by the decoder using `reserve_aligned`. + let mut aligned = BytesMut::with_capacity(buffer.len() + align); + let padding = aligned.as_ptr().align_offset(align); + unsafe { aligned.set_len(padding) }; + aligned.advance(padding); + aligned.extend_from_slice(&buffer); + + aligned + } +} + +#[cfg(test)] +mod test { + use vortex_array::array::{ConstantArray, PrimitiveArray}; + use vortex_array::{ArrayDType, IntoArrayData}; + use vortex_error::vortex_panic; + + use super::*; + use crate::messages::{EncoderMessage, MessageEncoder}; + + fn write_and_read(expected: ArrayData) { + let mut ipc_bytes = BytesMut::new(); + let mut encoder = MessageEncoder::default(); + for buf in encoder.encode(EncoderMessage::Array(&expected)) { + ipc_bytes.extend_from_slice(buf.as_ref()); + } + + let mut decoder = MessageDecoder::default(); + + // Since we provide all bytes up-front, we should never hit a NeedMore. + let mut buffer = BytesMut::from(ipc_bytes.as_ref()); + let array_parts = match decoder.read_next(&mut buffer).unwrap() { + PollRead::Some(DecoderMessage::Array(array_parts)) => array_parts, + otherwise => vortex_panic!("Expected an array, got {:?}", otherwise), + }; + + // Decode the array parts with the context + let actual = array_parts + .into_array_data(Arc::new(Context::default()), expected.dtype().clone()) + .unwrap(); + + assert_eq!(expected.len(), actual.len()); + assert_eq!(expected.encoding(), actual.encoding()); + } + + #[test] + fn array_ipc() { + write_and_read(PrimitiveArray::from(vec![0i32, 1, 2, 3]).into_array()); + } + + #[test] + fn array_no_buffers() { + // Constant arrays have no buffers + let array = ConstantArray::new(10i32, 20).into_array(); + assert!(array.buffer().is_none(), "Array should have no buffers"); + write_and_read(array); + } +} diff --git a/vortex-ipc/src/messages/encoder.rs b/vortex-ipc/src/messages/encoder.rs new file mode 100644 index 0000000000..319b9810d3 --- /dev/null +++ b/vortex-ipc/src/messages/encoder.rs @@ -0,0 +1,239 @@ +use flatbuffers::{FlatBufferBuilder, WIPOffset}; +use itertools::Itertools; +use vortex_array::stats::ArrayStatistics; +use vortex_array::{flatbuffers as fba, ArrayData}; +use vortex_buffer::Buffer; +use vortex_dtype::DType; +use vortex_error::{vortex_panic, VortexExpect}; +use vortex_flatbuffers::{message as fb, FlatBufferRoot, WriteFlatBuffer}; + +use crate::ALIGNMENT; + +/// An IPC message ready to be passed to the encoder. +pub enum EncoderMessage<'a> { + Array(&'a ArrayData), + Buffer(&'a Buffer), + DType(&'a DType), +} + +pub struct MessageEncoder { + /// The alignment used for each message and buffer. + /// TODO(ngates): I'm not sure we need to include this much padding in the stream itself. + alignment: usize, + /// The current position in the stream. Used to calculate leading padding. + pos: usize, + /// A reusable buffer of zeros used for padding. + zeros: Buffer, +} + +impl Default for MessageEncoder { + fn default() -> Self { + Self::new(ALIGNMENT) + } +} + +impl MessageEncoder { + /// Create a new message encoder that pads each message and buffer with the given alignment. + /// + /// ## Panics + /// + /// Panics if `alignment` is greater than `u16::MAX` or is not a power of 2. + pub fn new(alignment: usize) -> Self { + // We guarantee that alignment fits inside u16. + u16::try_from(alignment).vortex_expect("Alignment must fit into u16"); + if !alignment.is_power_of_two() { + vortex_panic!("Alignment must be a power of 2"); + } + + Self { + alignment, + pos: 0, + zeros: Buffer::from(vec![0; alignment]), + } + } + + /// Encode an IPC message for writing to a byte stream. + /// + /// The returned buffers should be written contiguously to the stream. + pub fn encode(&mut self, message: EncoderMessage) -> Vec { + let mut buffers = vec![]; + assert_eq!( + self.pos.next_multiple_of(self.alignment), + self.pos, + "pos must be aligned at start of a message" + ); + + // We'll push one buffer as a placeholder for the flatbuffer message length, and one + // for the flatbuffer itself. + buffers.push(self.zeros.clone()); + buffers.push(self.zeros.clone()); + + // We initialize the flatbuffer builder with a 4-byte vector that we will use to store + // the flatbuffer length into. By passing this vector into the FlatBufferBuilder, the + // flatbuffers internal alignment mechanisms will handle everything else for us. + // TODO(ngates): again, this a ton of padding... + let mut fbb = FlatBufferBuilder::from_vec(vec![0u8; 4]); + + let header = match message { + EncoderMessage::Array(array) => { + let row_count = array.len(); + let array_def = ArrayWriter { + array, + buffer_idx: 0, + } + .write_flatbuffer(&mut fbb); + + let mut fb_buffers = vec![]; + for child in array.depth_first_traversal() { + if let Some(buffer) = child.buffer() { + let end_excl_padding = self.pos + buffer.len(); + let end_incl_padding = end_excl_padding.next_multiple_of(self.alignment); + let padding = u16::try_from(end_incl_padding - end_excl_padding) + .vortex_expect("We know padding fits into u16"); + fb_buffers.push(fba::Buffer::create( + &mut fbb, + &fba::BufferArgs { + length: buffer.len() as u64, + padding, + }, + )); + buffers.push(buffer.clone()); + if padding > 0 { + buffers.push(self.zeros.slice(0..usize::from(padding))); + } + } + } + let fb_buffers = fbb.create_vector(&fb_buffers); + + fba::ArrayData::create( + &mut fbb, + &fba::ArrayDataArgs { + array: Some(array_def), + row_count: row_count as u64, + buffers: Some(fb_buffers), + }, + ) + .as_union_value() + } + EncoderMessage::Buffer(buffer) => { + let end_incl_padding = buffer.len().next_multiple_of(self.alignment); + let padding = u16::try_from(end_incl_padding - buffer.len()) + .vortex_expect("We know padding fits into u16"); + buffers.push(buffer.clone()); + if padding > 0 { + buffers.push(self.zeros.slice(0..usize::from(padding))); + } + fba::Buffer::create( + &mut fbb, + &fba::BufferArgs { + length: buffer.len() as u64, + padding, + }, + ) + .as_union_value() + } + EncoderMessage::DType(dtype) => dtype.write_flatbuffer(&mut fbb).as_union_value(), + }; + + let mut msg = fb::MessageBuilder::new(&mut fbb); + msg.add_version(Default::default()); + msg.add_header_type(match message { + EncoderMessage::Array(_) => fb::MessageHeader::ArrayData, + EncoderMessage::Buffer(_) => fb::MessageHeader::Buffer, + EncoderMessage::DType(_) => fb::MessageHeader::DType, + }); + msg.add_header(header); + let msg = msg.finish(); + + // Finish the flatbuffer and swap it out for the placeholder buffer. + fbb.finish_minimal(msg); + let (mut fbv, pos) = fbb.collapse(); + + // Add some padding to the flatbuffer vector to ensure it is aligned. + // Note that we have to include the 4-byte length prefix in the alignment calculation. + let unaligned_len = fbv.len() - pos + 4; + let padding = unaligned_len.next_multiple_of(self.alignment) - unaligned_len; + fbv.extend_from_slice(&self.zeros.as_slice()[0..padding]); + let fbv_len = fbv.len(); + let fb_buffer = Buffer::from(fbv).slice(pos..fbv_len); + + let fb_buffer_len = u32::try_from(fb_buffer.len()) + .vortex_expect("IPC flatbuffer headers must fit into u32 bytes"); + buffers[0] = Buffer::from(fb_buffer_len.to_le_bytes().to_vec()); + buffers[1] = fb_buffer; + + // Update the write cursor. + self.pos += buffers.iter().map(|b| b.len()).sum::(); + + buffers + } +} + +struct ArrayWriter<'a> { + array: &'a ArrayData, + buffer_idx: u16, +} + +impl FlatBufferRoot for ArrayWriter<'_> {} + +impl WriteFlatBuffer for ArrayWriter<'_> { + type Target<'t> = fba::Array<'t>; + + fn write_flatbuffer<'fb>( + &self, + fbb: &mut FlatBufferBuilder<'fb>, + ) -> WIPOffset> { + let encoding = self.array.encoding().id().code(); + let metadata = self + .array + .metadata_bytes() + .vortex_expect("IPCArray is missing metadata during serialization"); + let metadata = Some(fbb.create_vector(metadata.as_ref())); + + // Assign buffer indices for all child arrays. + // The second tuple element holds the buffer_index for this Array subtree. If this array + // has a buffer, that is its buffer index. If it does not, that buffer index belongs + // to one of the children. + let child_buffer_idx = self.buffer_idx + if self.array.buffer().is_some() { 1 } else { 0 }; + + let children = self + .array + .children() + .iter() + .scan(child_buffer_idx, |buffer_idx, child| { + // Update the number of buffers required. + let msg = ArrayWriter { + array: child, + buffer_idx: *buffer_idx, + } + .write_flatbuffer(fbb); + *buffer_idx = u16::try_from(child.cumulative_nbuffers()) + .ok() + .and_then(|nbuffers| nbuffers.checked_add(*buffer_idx)) + .vortex_expect("Too many buffers (u16) for ArrayData"); + Some(msg) + }) + .collect_vec(); + let children = Some(fbb.create_vector(&children)); + + let buffers = self + .array + .buffer() + .is_some() + .then_some(self.buffer_idx) + .map(|buffer_idx| fbb.create_vector_from_iter(std::iter::once(buffer_idx))); + + let stats = Some(self.array.statistics().write_flatbuffer(fbb)); + + fba::Array::create( + fbb, + &fba::ArrayArgs { + encoding, + metadata, + children, + buffers, + stats, + }, + ) + } +} diff --git a/vortex-ipc/src/messages/mod.rs b/vortex-ipc/src/messages/mod.rs index dd00afa7a5..98a0530cb2 100644 --- a/vortex-ipc/src/messages/mod.rs +++ b/vortex-ipc/src/messages/mod.rs @@ -1,173 +1,13 @@ -use flatbuffers::{FlatBufferBuilder, WIPOffset}; -use itertools::Itertools; -use vortex_array::stats::ArrayStatistics; -use vortex_array::{flatbuffers as fba, ArrayData}; -use vortex_buffer::Buffer; -use vortex_dtype::DType; -use vortex_error::VortexExpect; -use vortex_flatbuffers::{message as fb, FlatBufferRoot, WriteFlatBuffer}; - -use crate::ALIGNMENT; - -pub mod reader; -pub mod writer; - -pub enum IPCMessage { - Array(ArrayData), - Buffer(Buffer), - DType(DType), -} - -impl FlatBufferRoot for IPCMessage {} - -impl WriteFlatBuffer for IPCMessage { - type Target<'a> = fb::Message<'a>; - - fn write_flatbuffer<'fb>( - &self, - fbb: &mut FlatBufferBuilder<'fb>, - ) -> WIPOffset> { - let header = match self { - Self::Array(array) => ArrayDataWriter { array } - .write_flatbuffer(fbb) - .as_union_value(), - Self::Buffer(buffer) => { - let aligned_len = buffer.len().next_multiple_of(ALIGNMENT); - let padding = aligned_len - buffer.len(); - fba::Buffer::create( - fbb, - &fba::BufferArgs { - length: buffer.len() as u64, - padding: padding.try_into().vortex_expect("padding must fit in u16"), - }, - ) - .as_union_value() - } - Self::DType(dtype) => dtype.write_flatbuffer(fbb).as_union_value(), - }; - - let mut msg = fb::MessageBuilder::new(fbb); - msg.add_version(Default::default()); - msg.add_header_type(match self { - Self::Array(_) => fb::MessageHeader::ArrayData, - Self::Buffer(_) => fb::MessageHeader::Buffer, - Self::DType(_) => fb::MessageHeader::DType, - }); - msg.add_header(header); - msg.finish() - } -} - -struct ArrayDataWriter<'a> { - array: &'a ArrayData, -} - -impl WriteFlatBuffer for ArrayDataWriter<'_> { - type Target<'t> = fba::ArrayData<'t>; - - fn write_flatbuffer<'fb>( - &self, - fbb: &mut FlatBufferBuilder<'fb>, - ) -> WIPOffset> { - let array = Some( - ArrayWriter { - array: self.array, - buffer_idx: 0, - } - .write_flatbuffer(fbb), - ); - - // Walk the ColumnData depth-first to compute the buffer lengths. - let mut buffers = vec![]; - for array_data in self.array.depth_first_traversal() { - if let Some(buffer) = array_data.buffer() { - let aligned_size = buffer.len().next_multiple_of(ALIGNMENT); - let padding = aligned_size - buffer.len(); - buffers.push(fba::Buffer::create( - fbb, - &fba::BufferArgs { - length: buffer.len() as u64, - padding: padding.try_into().vortex_expect("padding must fit in u16"), - }, - )); - } - } - let buffers = Some(fbb.create_vector(&buffers)); - - fba::ArrayData::create( - fbb, - &fba::ArrayDataArgs { - array, - row_count: self.array.len() as u64, - buffers, - }, - ) - } -} - -struct ArrayWriter<'a> { - array: &'a ArrayData, - buffer_idx: u16, -} - -impl WriteFlatBuffer for ArrayWriter<'_> { - type Target<'t> = fba::Array<'t>; - - fn write_flatbuffer<'fb>( - &self, - fbb: &mut FlatBufferBuilder<'fb>, - ) -> WIPOffset> { - let encoding = self.array.encoding().id().code(); - let metadata = self - .array - .metadata_bytes() - .vortex_expect("IPCArray is missing metadata during serialization"); - let metadata = Some(fbb.create_vector(metadata.as_ref())); - - // Assign buffer indices for all child arrays. - // The second tuple element holds the buffer_index for this Array subtree. If this array - // has a buffer, that is its buffer index. If it does not, that buffer index belongs - // to one of the children. - let child_buffer_idx = self.buffer_idx + if self.array.buffer().is_some() { 1 } else { 0 }; - - let children = self - .array - .children() - .iter() - .scan(child_buffer_idx, |buffer_idx, child| { - // Update the number of buffers required. - let msg = ArrayWriter { - array: child, - buffer_idx: *buffer_idx, - } - .write_flatbuffer(fbb); - *buffer_idx = u16::try_from(child.cumulative_nbuffers()) - .ok() - .and_then(|nbuffers| nbuffers.checked_add(*buffer_idx)) - .vortex_expect("Too many buffers (u16) for ArrayData"); - Some(msg) - }) - .collect_vec(); - let children = Some(fbb.create_vector(&children)); - - let buffers = self - .array - .buffer() - .is_some() - .then_some(self.buffer_idx) - .map(|buffer_idx| fbb.create_vector_from_iter(std::iter::once(buffer_idx))); - - let stats = Some(self.array.statistics().write_flatbuffer(fbb)); - - fba::Array::create( - fbb, - &fba::ArrayArgs { - encoding, - metadata, - children, - buffers, - stats, - }, - ) - } -} +mod decoder; +mod encoder; +mod reader_async; +mod reader_sync; +mod writer_async; +mod writer_sync; + +pub use decoder::*; +pub use encoder::*; +pub use reader_async::*; +pub use reader_sync::*; +pub use writer_async::*; +pub use writer_sync::*; diff --git a/vortex-ipc/src/messages/reader.rs b/vortex-ipc/src/messages/reader.rs deleted file mode 100644 index 2a4d72abb5..0000000000 --- a/vortex-ipc/src/messages/reader.rs +++ /dev/null @@ -1,389 +0,0 @@ -use std::io; -use std::sync::Arc; - -use bytes::{Buf, Bytes}; -use flatbuffers::{root, root_unchecked}; -use futures_util::stream::try_unfold; -use vortex_array::stream::{ArrayStream, ArrayStreamAdapter}; -use vortex_array::{flatbuffers as fba, ArrayData, Context}; -use vortex_buffer::Buffer; -use vortex_dtype::DType; -use vortex_error::{vortex_bail, vortex_err, VortexExpect, VortexResult}; -use vortex_flatbuffers::message as fb; -use vortex_io::{VortexBufReader, VortexReadAt}; - -pub const MESSAGE_PREFIX_LENGTH: usize = 4; - -/// A stateful reader of [`Message`s][fb::Message] from a stream. -pub struct MessageReader { - read: VortexBufReader, - message: Option, - prev_message: Option, - finished: bool, -} - -impl MessageReader { - pub async fn try_new(read: VortexBufReader) -> VortexResult { - let mut reader = Self { - read, - message: None, - prev_message: None, - finished: false, - }; - reader.load_next_message().await?; - Ok(reader) - } - - async fn load_next_message(&mut self) -> VortexResult { - let mut buffer = match self.read.read_bytes(MESSAGE_PREFIX_LENGTH as u64).await { - Ok(b) => b, - Err(e) => { - return match e.kind() { - io::ErrorKind::UnexpectedEof => Ok(false), - _ => Err(e.into()), - }; - } - }; - - let len = buffer.get_u32_le(); - if len == u32::MAX { - // Marker for no more messages. - return Ok(false); - } else if len == 0 { - vortex_bail!(InvalidSerde: "Invalid IPC stream") - } - - let next_msg = self.read.read_bytes(len as u64).await?; - - // Validate that the message is valid a flatbuffer. - root::(&next_msg).map_err( - |e| vortex_err!(InvalidSerde: "Failed to parse flatbuffer message: {:?}", e), - )?; - - self.message = Some(next_msg); - - Ok(true) - } - - fn peek(&self) -> Option { - if self.finished { - return None; - } - // The message has been validated by the next() call. - Some(unsafe { - root_unchecked::( - self.message - .as_ref() - .vortex_expect("MessageReader: message"), - ) - }) - } - - async fn next(&mut self) -> VortexResult { - if self.finished { - vortex_bail!("Reader is finished, should've peeked!") - } - self.prev_message = self.message.take(); - if !self.load_next_message().await? { - self.finished = true; - } - Ok(Buffer::from( - self.prev_message - .clone() - .vortex_expect("MessageReader prev_message"), - )) - } - - pub async fn read_dtype(&mut self) -> VortexResult { - if self.peek().and_then(|m| m.header_as_dtype()).is_none() { - vortex_bail!("Expected DType message") - } - - let buf = self.next().await?; - let msg = unsafe { root_unchecked::(&buf) } - .header_as_dtype() - .ok_or_else(|| { - vortex_err!("Expected schema message; this was checked earlier in the function") - })?; - - DType::try_from(msg) - } - - pub async fn maybe_read_chunk( - &mut self, - ctx: Arc, - dtype: DType, - ) -> VortexResult> { - let all_buffers_size = match self.peek().and_then(|m| m.header_as_array_data()) { - None => return Ok(None), - Some(array_data) => array_data - .buffers() - .unwrap_or_default() - .iter() - .map(|b| b.length() + (b.padding() as u64)) - .sum(), - }; - - let mut array_reader = ArrayMessageReader::from_fb_bytes(Buffer::from( - self.message.clone().vortex_expect("MessageReader: message"), - )); - - // Issue a single read to grab all buffers - let all_buffers = self.read.read_bytes(all_buffers_size).await?; - - if array_reader.read(all_buffers)?.is_some() { - unreachable!("This is an implementation bug") - }; - - let _ = self.next().await?; - array_reader.into_array(ctx, dtype).map(Some) - } - - pub fn array_stream(&mut self, ctx: Arc, dtype: DType) -> impl ArrayStream + '_ { - struct State<'a, R: VortexReadAt> { - msgs: &'a mut MessageReader, - ctx: Arc, - dtype: DType, - } - - let init = State { - msgs: self, - ctx, - dtype: dtype.clone(), - }; - - ArrayStreamAdapter::new( - dtype, - try_unfold(init, |state| async move { - match state - .msgs - .maybe_read_chunk(state.ctx.clone(), state.dtype.clone()) - .await? - { - None => Ok(None), - Some(array) => Ok(Some((array, state))), - } - }), - ) - } - - pub fn into_array_stream(self, ctx: Arc, dtype: DType) -> impl ArrayStream { - struct State { - msgs: MessageReader, - ctx: Arc, - dtype: DType, - } - - let init = State { - msgs: self, - ctx, - dtype: dtype.clone(), - }; - - ArrayStreamAdapter::new( - dtype, - try_unfold(init, |mut state| async move { - match state - .msgs - .maybe_read_chunk(state.ctx.clone(), state.dtype.clone()) - .await? - { - None => Ok(None), - Some(array) => Ok(Some((array, state))), - } - }), - ) - } - - pub async fn maybe_read_buffer(&mut self) -> VortexResult> { - let Some(buffer_msg) = self.peek().and_then(|m| m.header_as_buffer()) else { - return Ok(None); - }; - - let buffer_len = buffer_msg.length(); - let total_len = buffer_len + (buffer_msg.padding() as u64); - - let buffer = self.read.read_bytes(total_len).await?; - let page_buffer = Ok(Some(Buffer::from( - buffer.slice(..usize::try_from(buffer_len)?), - ))); - let _ = self.next().await?; - page_buffer - } - - pub fn into_inner(self) -> VortexBufReader { - self.read - } -} - -pub enum ReadState { - Init, - ReadingLength, - ReadingFb, - ReadingBuffers, - Finished, -} - -pub struct ArrayMessageReader { - state: ReadState, - fb_msg: Option, - buffers: Vec, -} - -impl Default for ArrayMessageReader { - fn default() -> Self { - Self::new() - } -} - -impl ArrayMessageReader { - pub fn new() -> Self { - Self { - state: ReadState::Init, - fb_msg: None, - buffers: Vec::new(), - } - } - - pub fn from_fb_bytes(fb_bytes: Buffer) -> Self { - Self { - state: ReadState::ReadingBuffers, - fb_msg: Some(fb_bytes), - buffers: Vec::new(), - } - } - - /// Parse the given bytes and optionally request more from the input. - pub fn read(&mut self, mut bytes: Bytes) -> VortexResult> { - match self.state { - ReadState::Init => { - self.state = ReadState::ReadingLength; - Ok(Some(MESSAGE_PREFIX_LENGTH)) - } - ReadState::ReadingLength => { - self.state = ReadState::ReadingFb; - Ok(Some(bytes.get_u32_le() as usize)) - } - ReadState::ReadingFb => { - // SAFETY: Assumes that any flatbuffer bytes passed have been validated. - // This is currently the case in stream and file implementations. - let array_data = unsafe { - root_unchecked::(&bytes) - .header_as_array_data() - .ok_or_else(|| vortex_err!("Message was not a batch"))? - }; - - let buffers_size = array_data - .buffers() - .map(|buffers| { - buffers - .iter() - .map(|buffer| buffer.length() + buffer.padding() as u64) - .sum::() - }) - .unwrap_or_default(); - - self.fb_msg = Some(Buffer::from(bytes)); - self.state = ReadState::ReadingBuffers; - Ok(Some( - buffers_size - .try_into() - .vortex_expect("Cannot cast to usize"), - )) - } - ReadState::ReadingBuffers => { - // Split out into individual buffers - // Initialize the column's buffers for a vectored read. - // To start with, we include the padding and then truncate the buffers after. - let batch_msg = self.fb_bytes_as_array_data()?; - - let ipc_buffers = batch_msg.buffers().unwrap_or_default(); - let buffers = ipc_buffers - .iter() - .map(|buffer| { - // Grab the buffer - let data_buffer = bytes.split_to( - buffer - .length() - .try_into() - .vortex_expect("Buffer size does not fit into usize"), - ); - // Strip off any padding from the previous buffer - bytes.advance(buffer.padding() as usize); - Buffer::from(data_buffer) - }) - .collect::>(); - - self.buffers = buffers; - self.state = ReadState::Finished; - Ok(None) - } - ReadState::Finished => vortex_bail!("Reader is already finished"), - } - } - - fn fb_bytes_as_array_data(&self) -> VortexResult { - unsafe { - root_unchecked::( - self.fb_msg - .as_ref() - .ok_or_else(|| vortex_err!("Populated in previous step"))?, - ) - } - .header_as_array_data() - .ok_or_else(|| vortex_err!("Checked in previous step")) - } - - /// Produce the array buffered in the reader - pub fn into_array(self, ctx: Arc, dtype: DType) -> VortexResult { - let row_count: usize = self.fb_bytes_as_array_data()?.row_count().try_into()?; - let fb_msg = self - .fb_msg - .ok_or_else(|| vortex_err!("Populated in previous step"))?; - ArrayData::try_new_viewed( - ctx, - dtype, - row_count, - fb_msg, - |flatbuffer| { - unsafe { root_unchecked::(flatbuffer) } - .header_as_array_data() - .ok_or_else(|| vortex_err!("Failed to get root header as batch"))? - .array() - .ok_or_else(|| vortex_err!("Chunk missing Array")) - }, - self.buffers, - ) - } -} - -#[cfg(test)] -mod test { - use bytes::Bytes; - use futures_executor::block_on; - use vortex_buffer::Buffer; - use vortex_io::VortexBufReader; - - use crate::messages::reader::MessageReader; - use crate::messages::writer::MessageWriter; - - #[test] - fn read_write_page() { - let write = Vec::new(); - let mut writer = MessageWriter::new(write); - block_on(async { - writer - .write_page(Buffer::from(Bytes::from("somevalue"))) - .await - }) - .unwrap(); - let written = Buffer::from(writer.into_inner()); - let mut reader = - block_on(async { MessageReader::try_new(VortexBufReader::new(written)).await }) - .unwrap(); - let read_page = block_on(async { reader.maybe_read_buffer().await }) - .unwrap() - .unwrap(); - assert_eq!(read_page, Buffer::from(Bytes::from("somevalue"))); - } -} diff --git a/vortex-ipc/src/messages/reader_async.rs b/vortex-ipc/src/messages/reader_async.rs new file mode 100644 index 0000000000..1cdf7137a1 --- /dev/null +++ b/vortex-ipc/src/messages/reader_async.rs @@ -0,0 +1,67 @@ +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use bytes::BytesMut; +use futures_util::{AsyncRead, Stream}; +use pin_project_lite::pin_project; +use vortex_error::VortexResult; + +use crate::messages::{DecoderMessage, MessageDecoder, PollRead}; + +pin_project! { + /// An IPC message reader backed by an `AsyncRead` stream. + pub struct AsyncMessageReader { + #[pin] + read: R, + buffer: BytesMut, + decoder: MessageDecoder, + bytes_read: usize, + } +} + +impl AsyncMessageReader { + pub fn new(read: R) -> Self { + AsyncMessageReader { + read, + buffer: BytesMut::new(), + decoder: MessageDecoder::default(), + bytes_read: 0, + } + } +} + +impl Stream for AsyncMessageReader { + type Item = VortexResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { + match this.decoder.read_next(this.buffer)? { + PollRead::Some(msg) => return Poll::Ready(Some(Ok(msg))), + PollRead::NeedMore(nbytes) => { + this.buffer.resize(nbytes, 0x00); + + match ready!(this + .read + .as_mut() + .poll_read(cx, &mut this.buffer.as_mut()[*this.bytes_read..])) + { + Ok(0) => { + // End of file + return Poll::Ready(None); + } + Ok(nbytes) => { + *this.bytes_read += nbytes; + // If we've finished the read operation, then we continue the loop + // and the decoder should present us with a new response. + if *this.bytes_read == nbytes { + *this.bytes_read = 0; + } + } + Err(e) => return Poll::Ready(Some(Err(e.into()))), + } + } + } + } + } +} diff --git a/vortex-ipc/src/messages/reader_sync.rs b/vortex-ipc/src/messages/reader_sync.rs new file mode 100644 index 0000000000..055655da90 --- /dev/null +++ b/vortex-ipc/src/messages/reader_sync.rs @@ -0,0 +1,51 @@ +use std::io::Read; + +use bytes::BytesMut; +use vortex_error::VortexResult; + +use crate::messages::{DecoderMessage, MessageDecoder, PollRead}; + +/// An IPC message reader backed by a `Read` stream. +pub struct SyncMessageReader { + read: R, + buffer: BytesMut, + decoder: MessageDecoder, +} + +impl SyncMessageReader { + pub fn new(read: R) -> Self { + SyncMessageReader { + read, + buffer: BytesMut::new(), + decoder: MessageDecoder::default(), + } + } +} + +impl Iterator for SyncMessageReader { + type Item = VortexResult; + + fn next(&mut self) -> Option { + loop { + match self.decoder.read_next(&mut self.buffer) { + Ok(PollRead::Some(msg)) => { + return Some(Ok(msg)); + } + Ok(PollRead::NeedMore(nbytes)) => { + self.buffer.resize(nbytes, 0x00); + match self.read.read(&mut self.buffer) { + Ok(0) => { + // EOF + return None; + } + Ok(_nbytes) => { + // Continue in the loop + } + Err(e) => return Some(Err(e.into())), + } + } + Err(e) => return Some(Err(e)), + } + } + } +} diff --git a/vortex-ipc/src/messages/writer.rs b/vortex-ipc/src/messages/writer.rs deleted file mode 100644 index c7941fe68f..0000000000 --- a/vortex-ipc/src/messages/writer.rs +++ /dev/null @@ -1,140 +0,0 @@ -#![allow(clippy::assertions_on_constants)] -use std::io; - -use bytes::Bytes; -use flatbuffers::FlatBufferBuilder; -use vortex_array::ArrayData; -use vortex_buffer::io_buf::IoBuf; -use vortex_buffer::Buffer; -use vortex_dtype::DType; -use vortex_error::VortexUnwrap; -use vortex_flatbuffers::{WriteFlatBuffer, WriteFlatBufferExt}; -use vortex_io::VortexWrite; - -use crate::messages::IPCMessage; -use crate::ALIGNMENT; - -static ZEROS: [u8; 512] = [0; 512]; - -#[derive(Debug)] -pub struct MessageWriter { - write: W, - pos: u64, - alignment: usize, - - scratch: Option>, -} - -impl MessageWriter { - pub fn new(write: W) -> Self { - assert!(ALIGNMENT <= ZEROS.len(), "ALIGNMENT must be <= 512"); - Self { - write, - pos: 0, - alignment: ALIGNMENT, - scratch: Some(Vec::new()), - } - } - - pub fn into_inner(self) -> W { - self.write - } - - /// Returns the current position in the stream. - pub fn tell(&self) -> u64 { - self.pos - } - - pub async fn write_dtype_raw(&mut self, dtype: &DType) -> io::Result<()> { - let buffer = dtype.write_flatbuffer_bytes(); - let written_len = buffer.len(); - self.write_all(buffer).await?; - - let aligned_size = written_len.next_multiple_of(self.alignment); - let padding = aligned_size - written_len; - - self.write_all(Bytes::from(&ZEROS[..padding])).await?; - - Ok(()) - } - - pub async fn write_dtype(&mut self, dtype: DType) -> io::Result<()> { - self.write_message(IPCMessage::DType(dtype)).await - } - - pub async fn write_array(&mut self, array: ArrayData) -> io::Result<()> { - self.write_message(IPCMessage::Array(array.clone())).await?; - - for array in array.depth_first_traversal() { - if let Some(buffer) = array.buffer() { - let buffer_len = buffer.len(); - let padding = buffer_len.next_multiple_of(self.alignment) - buffer_len; - self.write_all(buffer.clone()).await?; - self.write_all(Bytes::from(&ZEROS[..padding])).await?; - } - } - - Ok(()) - } - - pub async fn write_page(&mut self, buffer: Buffer) -> io::Result<()> { - let buffer_len = buffer.len(); - let padding = buffer_len.next_multiple_of(self.alignment) - buffer_len; - self.write_message(IPCMessage::Buffer(buffer.clone())) - .await?; - self.write_all(buffer).await?; - self.write_all(Bytes::from(&ZEROS[..padding])).await?; - - Ok(()) - } - - pub async fn write_message(&mut self, flatbuffer: F) -> io::Result<()> { - // We reuse the scratch buffer each time and then replace it at the end. - // The scratch buffer may be missing if a previous write failed. We could use scopeguard - // or similar here if it becomes a problem in practice. - let mut scratch = self.scratch.take().unwrap_or_default(); - scratch.clear(); - - // In order for FlatBuffers to use the correct alignment, we insert 4 bytes at the start - // of the flatbuffer vector since we will be writing this to the stream later. - scratch.extend_from_slice(&[0_u8; 4]); - - let mut fbb = FlatBufferBuilder::from_vec(scratch); - let root = flatbuffer.write_flatbuffer(&mut fbb); - fbb.finish_minimal(root); - - let (buffer, buffer_begin) = fbb.collapse(); - let buffer_end = buffer.len(); - let buffer_len = buffer_end - buffer_begin; - - let unaligned_size = 4 + buffer_len; - let aligned_size = (unaligned_size + (self.alignment - 1)) & !(self.alignment - 1); - let padding = aligned_size - unaligned_size; - - // Write the size as u32, followed by the buffer, followed by padding. - self.write_all( - u32::try_from(aligned_size - 4) - .vortex_unwrap() - .to_le_bytes(), - ) - .await?; - let buffer = self - .write_all(buffer.slice_owned(buffer_begin..buffer_end)) - .await? - .into_inner(); - self.write_all(Bytes::from(&ZEROS[..padding])).await?; - - assert_eq!(self.pos % self.alignment as u64, 0); - - // Replace the scratch buffer - self.scratch = Some(buffer); - - Ok(()) - } - - async fn write_all(&mut self, buf: B) -> io::Result { - let buf = self.write.write_all(buf).await?; - self.pos += buf.bytes_init() as u64; - Ok(buf) - } -} diff --git a/vortex-ipc/src/messages/writer_async.rs b/vortex-ipc/src/messages/writer_async.rs new file mode 100644 index 0000000000..a9a55e03f9 --- /dev/null +++ b/vortex-ipc/src/messages/writer_async.rs @@ -0,0 +1,33 @@ +use futures_util::{AsyncWrite, AsyncWriteExt}; +use vortex_error::VortexResult; + +use crate::messages::{EncoderMessage, MessageEncoder}; + +pub struct AsyncMessageWriter { + write: W, + encoder: MessageEncoder, +} + +impl AsyncMessageWriter { + pub fn new(write: W) -> Self { + Self { + write, + encoder: MessageEncoder::default(), + } + } + + pub async fn write_message(&mut self, message: EncoderMessage<'_>) -> VortexResult<()> { + for buffer in self.encoder.encode(message) { + self.write.write_all(&buffer).await?; + } + Ok(()) + } + + pub fn inner(&self) -> &W { + &self.write + } + + pub fn into_inner(self) -> W { + self.write + } +} diff --git a/vortex-ipc/src/messages/writer_sync.rs b/vortex-ipc/src/messages/writer_sync.rs new file mode 100644 index 0000000000..d64ac3c4b5 --- /dev/null +++ b/vortex-ipc/src/messages/writer_sync.rs @@ -0,0 +1,26 @@ +use std::io::Write; + +use vortex_error::VortexResult; + +use crate::messages::{EncoderMessage, MessageEncoder}; + +pub struct SyncMessageWriter { + write: W, + encoder: MessageEncoder, +} + +impl SyncMessageWriter { + pub fn new(write: W) -> Self { + Self { + write, + encoder: MessageEncoder::default(), + } + } + + pub fn write_message(&mut self, message: EncoderMessage) -> VortexResult<()> { + for buffer in self.encoder.encode(message) { + self.write.write_all(&buffer)?; + } + Ok(()) + } +} diff --git a/vortex-ipc/src/stream.rs b/vortex-ipc/src/stream.rs new file mode 100644 index 0000000000..dacc725349 --- /dev/null +++ b/vortex-ipc/src/stream.rs @@ -0,0 +1,231 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Poll}; + +use aligned_buffer::UniqueAlignedBuffer; +use bytes::Bytes; +use futures_util::{AsyncRead, AsyncWrite, AsyncWriteExt, Stream, StreamExt, TryStreamExt}; +use pin_project_lite::pin_project; +use vortex_array::stream::ArrayStream; +use vortex_array::{ArrayDType, ArrayData, Context}; +use vortex_buffer::Buffer; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, vortex_err, VortexResult}; + +use crate::messages::{AsyncMessageReader, DecoderMessage, EncoderMessage, MessageEncoder}; +use crate::ALIGNMENT; + +pin_project! { + /// An [`ArrayStream`] for reading messages off an async IPC stream. + pub struct AsyncIPCReader { + #[pin] + reader: AsyncMessageReader, + ctx: Arc, + dtype: DType, + } +} + +impl AsyncIPCReader { + pub async fn try_new(read: R, ctx: Arc) -> VortexResult { + let mut reader = AsyncMessageReader::new(read); + + let dtype = match reader.next().await.transpose()? { + Some(msg) => match msg { + DecoderMessage::DType(dtype) => dtype, + msg => { + vortex_bail!("Expected DType message, got {:?}", msg); + } + }, + None => vortex_bail!("Expected DType message, got EOF"), + }; + + Ok(AsyncIPCReader { reader, ctx, dtype }) + } +} + +impl ArrayStream for AsyncIPCReader { + fn dtype(&self) -> &DType { + &self.dtype + } +} + +impl Stream for AsyncIPCReader { + type Item = VortexResult; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + + match ready!(this.reader.poll_next(cx)) { + None => Poll::Ready(None), + Some(msg) => match msg { + Ok(DecoderMessage::Array(array_parts)) => Poll::Ready(Some( + array_parts + .into_array_data(this.ctx.clone(), this.dtype.clone()) + .and_then(|array| { + if array.dtype() != this.dtype { + Err(vortex_err!( + "Array data type mismatch: expected {:?}, got {:?}", + this.dtype, + array.dtype() + )) + } else { + Ok(array) + } + }), + )), + Ok(msg) => Poll::Ready(Some(Err(vortex_err!( + "Expected Array message, got {:?}", + msg + )))), + Err(e) => Poll::Ready(Some(Err(e))), + }, + } + } +} + +/// A trait for convering an [`ArrayStream`] into IPC streams. +pub trait ArrayStreamIPC { + fn into_ipc(self) -> ArrayStreamIPCBytes + where + Self: Sized; + + fn write_ipc(self, write: W) -> impl Future> + where + Self: Sized; +} + +impl ArrayStreamIPC for S { + fn into_ipc(self) -> ArrayStreamIPCBytes + where + Self: Sized, + { + ArrayStreamIPCBytes { + stream: Box::pin(self), + encoder: MessageEncoder::default(), + buffers: vec![], + written_dtype: false, + } + } + + async fn write_ipc(self, mut write: W) -> VortexResult + where + Self: Sized, + { + let mut stream = self.into_ipc(); + while let Some(chunk) = stream.next().await { + write.write_all(chunk?.as_slice()).await?; + } + Ok(write) + } +} + +pub struct ArrayStreamIPCBytes { + stream: Pin>, + encoder: MessageEncoder, + buffers: Vec, + written_dtype: bool, +} + +impl ArrayStreamIPCBytes { + /// Collects the IPC bytes into a single `Buffer`. + pub async fn collect_to_buffer(self) -> VortexResult { + // We allocate a single aligned buffer to hold the combined IPC bytes + let buffers: Vec = self.try_collect().await?; + let mut buffer = + UniqueAlignedBuffer::::with_capacity(buffers.iter().map(|b| b.len()).sum()); + for buf in buffers { + buffer.extend_from_slice(buf.as_slice()); + } + Ok(Buffer::from(Bytes::from_owner(buffer))) + } +} + +impl Stream for ArrayStreamIPCBytes { + type Item = VortexResult; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + + // If we haven't written the dtype yet, we write it + if !this.written_dtype { + this.buffers.extend( + this.encoder + .encode(EncoderMessage::DType(this.stream.dtype())), + ); + this.written_dtype = true; + } + + // Try to flush any buffers we have + if !this.buffers.is_empty() { + return Poll::Ready(Some(Ok(this.buffers.remove(0)))); + } + + // Or else try to serialize the next array + match ready!(this.stream.poll_next_unpin(cx)) { + None => return Poll::Ready(None), + Some(chunk) => match chunk { + Ok(chunk) => { + this.buffers + .extend(this.encoder.encode(EncoderMessage::Array(&chunk))); + } + Err(e) => return Poll::Ready(Some(Err(e))), + }, + } + + // Try to flush any buffers we have again + if !this.buffers.is_empty() { + return Poll::Ready(Some(Ok(this.buffers.remove(0)))); + } + + // Otherwise, we're done + Poll::Ready(None) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use futures_util::io::Cursor; + use vortex_array::array::PrimitiveArray; + use vortex_array::stream::{ArrayStream, ArrayStreamExt}; + use vortex_array::validity::Validity; + use vortex_array::{ArrayDType, Context, IntoArrayVariant, ToArrayData}; + + use super::*; + + #[tokio::test] + async fn test_async_stream() { + let array = PrimitiveArray::from_vec::(vec![1, 2, 3], Validity::NonNullable); + let ipc_buffer = array + .to_array() + .into_array_stream() + .into_ipc() + .collect_to_buffer() + .await + .unwrap(); + + let reader = AsyncIPCReader::try_new(Cursor::new(ipc_buffer), Arc::new(Context::default())) + .await + .unwrap(); + + assert_eq!(reader.dtype(), array.dtype()); + let result = reader + .into_array_data() + .await + .unwrap() + .into_primitive() + .unwrap(); + assert_eq!( + array.maybe_null_slice::(), + result.maybe_null_slice::() + ); + } +} diff --git a/vortex-ipc/src/stream_reader/mod.rs b/vortex-ipc/src/stream_reader/mod.rs deleted file mode 100644 index c7414b5f2a..0000000000 --- a/vortex-ipc/src/stream_reader/mod.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::ops::Deref; -use std::sync::Arc; - -use futures_util::stream::try_unfold; -use futures_util::Stream; -use vortex_array::stream::ArrayStream; -use vortex_array::Context; -use vortex_buffer::Buffer; -use vortex_dtype::DType; -use vortex_error::{VortexExpect as _, VortexResult}; -use vortex_io::{VortexBufReader, VortexReadAt}; - -use crate::messages::reader::MessageReader; - -pub struct StreamArrayReader { - msgs: MessageReader, - ctx: Arc, - dtype: Option>, -} - -impl StreamArrayReader { - pub async fn try_new(read: VortexBufReader, ctx: Arc) -> VortexResult { - Ok(Self { - msgs: MessageReader::try_new(read).await?, - ctx, - dtype: None, - }) - } - - pub fn with_dtype(mut self, dtype: Arc) -> Self { - assert!(self.dtype.is_none(), "DType already set"); - self.dtype = Some(dtype); - self - } - - pub async fn load_dtype(mut self) -> VortexResult { - assert!(self.dtype.is_none(), "DType already set"); - self.dtype = Some(Arc::new(self.msgs.read_dtype().await?)); - Ok(self) - } - - /// Reads a single array from the stream. - pub fn array_stream(&mut self) -> impl ArrayStream + '_ { - let dtype = self - .dtype - .as_ref() - .vortex_expect("Cannot read array from stream: DType not set") - .deref() - .clone(); - self.msgs.array_stream(self.ctx.clone(), dtype) - } - - pub fn into_array_stream(self) -> impl ArrayStream { - let dtype = self - .dtype - .as_ref() - .vortex_expect("Cannot read array from stream: DType not set") - .deref() - .clone(); - self.msgs.into_array_stream(self.ctx, dtype) - } - - /// Reads a single page from the stream. - pub async fn next_page(&mut self) -> VortexResult> { - self.msgs.maybe_read_buffer().await - } - - /// Reads consecutive pages from the stream until the message type changes. - pub async fn page_stream(&mut self) -> impl Stream> + '_ { - try_unfold(self, |reader| async { - match reader.next_page().await? { - Some(page) => Ok(Some((page, reader))), - None => Ok(None), - } - }) - } -} diff --git a/vortex-ipc/src/stream_writer/mod.rs b/vortex-ipc/src/stream_writer/mod.rs deleted file mode 100644 index ef6ac38a19..0000000000 --- a/vortex-ipc/src/stream_writer/mod.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::fmt::{Display, Formatter}; -use std::ops::Range; - -use futures_util::{Stream, TryStreamExt}; -use vortex_array::array::ChunkedArray; -use vortex_array::stream::ArrayStream; -use vortex_array::ArrayData; -use vortex_buffer::Buffer; -use vortex_dtype::DType; -use vortex_error::{VortexResult, VortexUnwrap}; -use vortex_io::VortexWrite; - -use crate::messages::writer::MessageWriter; - -#[cfg(test)] -mod tests; - -pub struct StreamArrayWriter { - msgs: MessageWriter, - - array_layouts: Vec, - page_ranges: Vec, -} - -impl StreamArrayWriter { - pub fn new(write: W) -> Self { - Self { - msgs: MessageWriter::new(write), - array_layouts: vec![], - page_ranges: vec![], - } - } - - pub fn array_layouts(&self) -> &[ArrayLayout] { - &self.array_layouts - } - - pub fn page_ranges(&self) -> &[ByteRange] { - &self.page_ranges - } - - pub fn into_inner(self) -> W { - self.msgs.into_inner() - } - - async fn write_dtype(&mut self, dtype: DType) -> VortexResult { - let begin = self.msgs.tell(); - self.msgs.write_dtype(dtype).await?; - let end = self.msgs.tell(); - Ok(ByteRange { begin, end }) - } - - async fn write_array_chunks(&mut self, mut stream: S) -> VortexResult - where - S: Stream> + Unpin, - { - let mut byte_offsets = vec![self.msgs.tell()]; - let mut row_offsets = vec![0]; - let mut row_offset = 0; - - while let Some(chunk) = stream.try_next().await? { - row_offset += chunk.len() as u64; - row_offsets.push(row_offset); - self.msgs.write_array(chunk).await?; - byte_offsets.push(self.msgs.tell()); - } - - Ok(ChunkOffsets::new(byte_offsets, row_offsets)) - } - - pub async fn write_array_stream( - mut self, - mut array_stream: S, - ) -> VortexResult { - let dtype_pos = self.write_dtype(array_stream.dtype().clone()).await?; - let chunk_pos = self.write_array_chunks(&mut array_stream).await?; - self.array_layouts.push(ArrayLayout { - dtype: dtype_pos, - chunks: chunk_pos, - }); - Ok(self) - } - - pub async fn write_array(self, array: ArrayData) -> VortexResult { - if let Some(chunked_array) = ChunkedArray::maybe_from(&array) { - self.write_array_stream(chunked_array.array_stream()).await - } else { - self.write_array_stream(array.into_array_stream()).await - } - } - - pub async fn write_page(mut self, buffer: Buffer) -> VortexResult { - let begin = self.msgs.tell(); - self.msgs.write_page(buffer).await?; - let end = self.msgs.tell(); - self.page_ranges.push(ByteRange { begin, end }); - Ok(self) - } -} - -#[derive(Copy, Clone, Debug)] -pub struct ByteRange { - pub begin: u64, - pub end: u64, -} - -impl Display for ByteRange { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "[{}, {})", self.begin, self.end) - } -} - -impl ByteRange { - pub fn new(begin: u64, end: u64) -> Self { - assert!(end > begin, "Buffer end must be after its beginning"); - Self { begin, end } - } - - pub fn len(&self) -> u64 { - self.end - self.begin - } - - pub fn is_empty(&self) -> bool { - self.begin == self.end - } - - pub fn to_range(&self) -> Range { - Range { - start: self.begin.try_into().vortex_unwrap(), - end: self.end.try_into().vortex_unwrap(), - } - } -} - -#[derive(Clone, Debug)] -pub struct ArrayLayout { - pub dtype: ByteRange, - pub chunks: ChunkOffsets, -} - -#[derive(Clone, Debug)] -pub struct ChunkOffsets { - pub byte_offsets: Vec, - pub row_offsets: Vec, -} - -impl ChunkOffsets { - pub fn new(byte_offsets: Vec, row_offsets: Vec) -> Self { - Self { - byte_offsets, - row_offsets, - } - } -} diff --git a/vortex-ipc/src/stream_writer/tests.rs b/vortex-ipc/src/stream_writer/tests.rs deleted file mode 100644 index 15d32bad5d..0000000000 --- a/vortex-ipc/src/stream_writer/tests.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::sync::Arc; - -use arrow_array::cast::AsArray as _; -use arrow_array::types::Int32Type; -use arrow_array::PrimitiveArray; -use vortex_array::arrow::FromArrowArray; -use vortex_array::stream::ArrayStreamExt; -use vortex_array::{ArrayData, Context, IntoCanonical}; -use vortex_buffer::Buffer; -use vortex_io::VortexBufReader; - -use crate::stream_reader::StreamArrayReader; -use crate::stream_writer::StreamArrayWriter; - -#[tokio::test] -async fn broken_data() { - let arrow_arr: PrimitiveArray = [Some(1), Some(2), Some(3), None].iter().collect(); - let vortex_arr = ArrayData::from_arrow(&arrow_arr, true); - let written = StreamArrayWriter::new(Vec::new()) - .write_array(vortex_arr) - .await - .unwrap() - .into_inner(); - let written = Buffer::from(written); - let reader = - StreamArrayReader::try_new(VortexBufReader::new(written), Arc::new(Context::default())) - .await - .unwrap(); - let arr = reader - .load_dtype() - .await - .unwrap() - .into_array_stream() - .collect_chunked() - .await - .unwrap(); - let round_tripped = arr.into_arrow().unwrap(); - assert_eq!(&arrow_arr, round_tripped.as_primitive::()); -}