diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs index 53df698c33ac..cb3dbf2d337b 100644 --- a/datafusion/core/src/physical_plan/sorts/cursor.rs +++ b/datafusion/core/src/physical_plan/sorts/cursor.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use arrow::row::{Row, Rows}; +use arrow::row::Row; use std::cmp::Ordering; +use super::RowBatch; + /// A `SortKeyCursor` is created from a `RecordBatch`, and a set of /// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. /// @@ -35,7 +37,7 @@ pub struct SortKeyCursor { // An id uniquely identifying the record batch scanned by this cursor. batch_id: usize, - rows: Rows, + rows: RowBatch, } impl std::fmt::Debug for SortKeyCursor { @@ -50,7 +52,7 @@ impl std::fmt::Debug for SortKeyCursor { impl SortKeyCursor { /// Create a new SortKeyCursor - pub fn new(stream_idx: usize, batch_id: usize, rows: Rows) -> Self { + pub fn new(stream_idx: usize, batch_id: usize, rows: RowBatch) -> Self { Self { stream_idx, cur_row: 0, diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs b/datafusion/core/src/physical_plan/sorts/mod.rs index db6ab5c604e2..34bd5f4be231 100644 --- a/datafusion/core/src/physical_plan/sorts/mod.rs +++ b/datafusion/core/src/physical_plan/sorts/mod.rs @@ -17,30 +17,485 @@ //! Sort functionalities -use crate::physical_plan::SendableRecordBatchStream; -use std::fmt::{Debug, Formatter}; - +use crate::{error::Result, physical_plan::SendableRecordBatchStream}; +use std::{ + fmt::{Debug, Formatter}, + pin::Pin, + sync::Arc, + task::Poll, +}; mod cursor; mod index; pub mod sort; pub mod sort_preserving_merge; +use arrow::{ + record_batch::RecordBatch, + row::{Row, RowParser, Rows}, +}; pub use cursor::SortKeyCursor; +use futures::{stream::Fuse, Stream, StreamExt}; pub use index::RowIndex; +use pin_project_lite::pin_project; +use tokio::task::JoinHandle; + +use super::{common::AbortOnDropSingle, metrics::MemTrackingMetrics}; + +pub(crate) type SendableRowStream = Pin> + Send>>; +pub(crate) type SortStreamItem = Result<(RecordBatch, Option)>; +pub(crate) type SendableSortStream = Pin + Send>>; -pub(crate) struct SortedStream { - stream: SendableRecordBatchStream, - mem_used: usize, +pin_project! { + pub(crate) struct SortedStream { + #[pin] + batches: Option>, + #[pin] + rows: Option>, + #[pin] + pairs_stream: Option, + pairs_rx: Option>, + last_batch: Option>, + last_row: Option>, + mem_used: usize, + // flag is only true if this was intialized wiith `new_no_row_encoding` + row_encoding_ignored: bool, + rx_drop_helper: Option>, + is_empty: bool + } } +impl SortedStream { + pub(crate) fn new(stream: SendableSortStream, mem_used: usize) -> Self { + Self { + batches: None, + rows: None, + pairs_rx: None, + pairs_stream: Some(stream), + rx_drop_helper: None, + mem_used, + row_encoding_ignored: false, + last_batch: None, + last_row: None, + is_empty: false, + } + } + pub(crate) fn new_from_rx( + rx: tokio::sync::mpsc::Receiver, + handle: JoinHandle<()>, + mem_used: usize, + ) -> Self { + Self { + batches: None, + rows: None, + pairs_rx: Some(rx), + pairs_stream: None, + rx_drop_helper: Some(AbortOnDropSingle::new(handle)), + mem_used, + row_encoding_ignored: false, + last_batch: None, + last_row: None, + is_empty: false, + } + } + pub(crate) fn new_from_streams( + stream: SendableRecordBatchStream, + mem_used: usize, + row_stream: SendableRowStream, + ) -> Self { + Self { + batches: Some(stream.fuse()), + rows: Some(row_stream.fuse()), + pairs_rx: None, + pairs_stream: None, + mem_used, + row_encoding_ignored: false, + last_batch: None, + last_row: None, + rx_drop_helper: None, + is_empty: false, + } + } + /// create stream where the row encoding for each batch is always None + pub(crate) fn new_no_row_encoding( + stream: SendableRecordBatchStream, + mem_used: usize, + ) -> Self { + Self { + batches: Some(stream.fuse()), + rows: None, + mem_used, + pairs_rx: None, + pairs_stream: None, + row_encoding_ignored: true, + last_batch: None, + last_row: None, + rx_drop_helper: None, + is_empty: false, + } + } + pub(crate) fn empty() -> Self { + Self { + is_empty: true, + + batches: None, + rows: None, + mem_used: 0, + pairs_rx: None, + pairs_stream: None, + row_encoding_ignored: true, + last_batch: None, + last_row: None, + rx_drop_helper: None, + } + } +} impl Debug for SortedStream { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "InMemSorterStream") } } +impl Stream for SortedStream { + type Item = SortStreamItem; -impl SortedStream { - pub(crate) fn new(stream: SendableRecordBatchStream, mem_used: usize) -> Self { - Self { stream, mem_used } + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + if *this.is_empty { + return Poll::Ready(None); + } + if this.pairs_rx.is_some() { + return this.pairs_rx.as_mut().unwrap().poll_recv(cx); + } + if this.pairs_stream.is_some() { + return this.pairs_stream.as_pin_mut().unwrap().poll_next(cx); + } + if this.rows.is_none() { + // even if no rows stream there has to be a batch stream + return match this.batches.as_pin_mut().unwrap().poll_next(cx) { + Poll::Ready(Some(Ok(batch))) => Poll::Ready(Some(Ok((batch, None)))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }; + } + // otherwise both batches and rows exist + let mut batches = this.batches.as_pin_mut().unwrap(); + let mut rows = this.rows.as_pin_mut().unwrap(); + if this.last_batch.is_none() { + match batches.as_mut().poll_next(cx) { + Poll::Ready(Some(res)) => *this.last_batch = Some(res), + Poll::Ready(None) | Poll::Pending => {} + } + } + if this.last_row.is_none() { + match rows.as_mut().poll_next(cx) { + Poll::Ready(Some(maybe_rows)) => *this.last_row = Some(maybe_rows), + Poll::Ready(None) | Poll::Pending => {} + } + } + if this.last_batch.is_some() && this.last_row.is_some() { + let result = this.last_batch.take().unwrap(); + let maybe_row = this.last_row.take().unwrap(); + Poll::Ready(Some(result.map(|batch| (batch, maybe_row)))) + } else if rows.is_done() || batches.is_done() { + Poll::Ready(None) + } else { + Poll::Pending + } + } +} +// helper logic used a few times. version of metrics.record_poll with different inner type +pub(crate) fn record_poll_sort_item( + metrics: &MemTrackingMetrics, + poll: std::task::Poll>, +) -> std::task::Poll> { + if let std::task::Poll::Ready(maybe_sort_item) = &poll { + match maybe_sort_item { + Some(Ok((batch, _rows))) => metrics.record_output(batch.num_rows()), + Some(Err(_)) | None => { + metrics.done(); + } + } + } + poll +} + +/// Cloneable batch of rows taken from multiple [RowSelection]s +#[derive(Debug, Clone)] +pub struct RowBatch { + // refs to the rows referenced by `indices` + rows: Vec>, + // first item = index of the ref in `rows`, second item=index within that `RowSelection` + indices: Arc>, + /// if wrapping a single rows object + single_rows: Option>, +} + +impl RowBatch { + /// Create new batch of rows selected from `rows`. + /// + /// `indices` defines where each row comes from: first element of the tuple is the index + /// of the ref in `rows`, second is the index within that `RowSelection`. + pub fn new(rows: Vec>, indices: Vec<(usize, usize)>) -> Self { + Self { + rows, + indices: Arc::new(indices), + single_rows: None, + } + } + + /// Returns the nth row in the batch. + pub fn row(&self, n: usize) -> Row { + match &self.single_rows { + Some(rows) => rows.row(n), + None => { + let (rows_ref_idx, row_idx) = self.indices[n]; + self.rows[rows_ref_idx].row(row_idx) + } + } + } + + /// Number of rows selected + pub fn num_rows(&self) -> usize { + match &self.single_rows { + Some(rows) => rows.num_rows(), + None => self.indices.len(), + } + } + /// Iterate over rows in their selected order + pub fn iter(&self) -> RowBatchIter { + RowBatchIter { + row_selection: self, + cur_idx: 0, + } + } + /// Amount of bytes + pub fn memory_size(&self) -> usize { + match &self.single_rows { + Some(rows) => rows.size() + std::mem::size_of::(), + None => { + let indices_size = self.indices.len() * 2 * std::mem::size_of::(); + let rows_size = self.rows.iter().map(|r| r.size()).sum::(); + rows_size + indices_size + std::mem::size_of::() + } + } + } +} +impl From for RowBatch { + fn from(value: RowSelection) -> Self { + Self { + indices: Arc::new((0..value.num_rows()).map(|i| (0, i)).collect()), + rows: vec![Arc::new(value)], + single_rows: None, + } + } +} +impl From for RowBatch { + fn from(value: Rows) -> Self { + Self { + rows: Vec::with_capacity(0), + indices: Arc::new(Vec::with_capacity(0)), + single_rows: Some(Arc::new(value)), + } + } +} + +/// Iterate over each row in a [`RowBatch`] +pub struct RowBatchIter<'a> { + row_selection: &'a RowBatch, + cur_idx: usize, +} +impl<'a> Iterator for RowBatchIter<'a> { + type Item = Row<'a>; + + fn next(&mut self) -> Option { + if self.cur_idx < self.row_selection.num_rows() { + let row = self.row_selection.row(self.cur_idx); + self.cur_idx += 1; + Some(row) + } else { + None + } + } +} + +/// A selection of rows from the same [`RowData`]. +#[derive(Debug)] +pub struct RowSelection { + rows: RowData, + // None when this `RowSelection` is equivalent to its `Rows` + indices: Option>, +} +#[derive(Debug)] +enum RowData { + /// Rows that have always been in memory + Rows(Rows), + /// Rows that were spilled to disk and then later read back into mem + Spilled { + parser: RowParser, + bytes: Vec>, + }, +} +impl RowData { + fn row(&self, n: usize) -> Row { + match self { + RowData::Rows(rows) => rows.row(n), + RowData::Spilled { parser, bytes } => parser.parse(&bytes[n]), + } + } + fn size(&self) -> usize { + match self { + RowData::Rows(rows) => rows.size(), + RowData::Spilled { bytes, .. } => bytes.len() + std::mem::size_of::(), + } + } + fn num_rows(&self) -> usize { + match self { + RowData::Rows(rows) => rows.num_rows(), + RowData::Spilled { bytes, .. } => bytes.len(), + } + } +} +impl RowSelection { + /// New + pub fn new(rows: Rows, indices: Vec) -> Self { + Self { + rows: RowData::Rows(rows), + indices: Some(indices), + } + } + fn from_spilled(parser: RowParser, bytes: Vec>) -> Self { + Self { + rows: RowData::Spilled { parser, bytes }, + indices: None, + } + } + /// Get the nth row of the selection. + pub fn row(&self, n: usize) -> Row { + if let Some(ref indices) = self.indices { + let idx = indices[n]; + self.rows.row(idx) + } else { + self.rows.row(n) + } + } + + /// Iterate over the rows in the selected order. + pub fn iter(&self) -> RowSelectionIter { + RowSelectionIter { + row_selection: self, + cur_n: 0, + } + } + /// Number of bytes held in rows and indices. + pub fn size(&self) -> usize { + let indices_size = self + .indices + .as_ref() + .map(|i| i.len() * std::mem::size_of::()) + .unwrap_or(0); + self.rows.size() + indices_size + std::mem::size_of::() + } + + fn num_rows(&self) -> usize { + if let Some(ref indices) = self.indices { + indices.len() + } else { + self.rows.num_rows() + } + } +} +impl From for RowSelection { + fn from(value: Rows) -> Self { + Self { + indices: None, + rows: RowData::Rows(value), + } + } +} +impl From for RowSelection { + fn from(value: RowData) -> Self { + Self { + indices: None, + rows: value, + } + } +} +/// Iterator for [`RowSelection`] +pub struct RowSelectionIter<'a> { + row_selection: &'a RowSelection, + cur_n: usize, +} +impl<'a> Iterator for RowSelectionIter<'a> { + type Item = Row<'a>; + + fn next(&mut self) -> Option { + if self.cur_n < self.row_selection.num_rows() { + let row = self.row_selection.row(self.cur_n); + self.cur_n += 1; + Some(row) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::Int64Array, + datatypes::DataType, + record_batch::RecordBatch, + row::{RowConverter, SortField}, + }; + + use crate::assert_batches_eq; + + use super::*; + + fn int64_rows( + conv: &mut RowConverter, + values: impl IntoIterator, + ) -> Rows { + let array: Int64Array = values.into_iter().map(Some).collect(); + let batch = + RecordBatch::try_from_iter(vec![("c1", Arc::new(array) as _)]).unwrap(); + conv.convert_columns(batch.columns()).unwrap() + } + + #[test] + fn test_row_batch_and_sorted_rows() { + let mut conv = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap(); + let s1 = RowSelection::new(int64_rows(&mut conv, 0..3), vec![2, 2, 1]); + let s2 = RowSelection::new(int64_rows(&mut conv, 5..8), vec![1, 2, 0]); + let s3: RowSelection = int64_rows(&mut conv, 2..4).into(); // null indices case + let selection = RowBatch::new( + vec![s1, s2, s3].into_iter().map(Arc::new).collect(), + vec![ + (2, 0), // 2 + (0, 2), // 1 + (0, 0), // 2 + (1, 1), // 7 + ], + ); + let rows: Vec = selection.iter().collect(); + assert_eq!(rows.len(), 4); + let parsed = conv.convert_rows(rows).unwrap(); + let batch = + RecordBatch::try_from_iter(vec![("c1", parsed.get(0).unwrap().clone())]) + .unwrap(); + let expected = vec![ + "+----+", // + "| c1 |", // + "+----+", // + "| 2 |", // + "| 1 |", // + "| 2 |", // + "| 7 |", // + "+----+", + ]; + assert_batches_eq!(expected, &[batch]); } } diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index c3fc06206ca1..05f6331187e3 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -19,6 +19,8 @@ //! It will do in-memory sorting if it has enough memory budget //! but spills to disk if needed. +use super::{record_poll_sort_item, RowBatch, RowSelection}; +use super::{SendableSortStream, SortStreamItem}; use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::execution::memory_pool::{ @@ -32,19 +34,20 @@ use crate::physical_plan::metrics::{ }; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream; use crate::physical_plan::sorts::SortedStream; -use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::physical_plan::stream::RecordBatchStreamAdapter; use crate::physical_plan::{ - DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, + displayable, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + SendableRecordBatchStream, Statistics, }; use crate::prelude::SessionConfig; -use arrow::array::{make_array, Array, ArrayRef, MutableArrayData}; +use arrow::array::{make_array, Array, ArrayRef, MutableArrayData, UInt32Array}; pub use arrow::compute::SortOptions; use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions}; use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; +use arrow::row::{Row, RowConverter, SortField}; use datafusion_physical_expr::EquivalenceProperties; use futures::{Stream, StreamExt, TryStreamExt}; use log::{debug, error}; @@ -54,12 +57,13 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::fs::File; use std::io::BufReader; +use std::io::{Read, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::task::{Context, Poll}; use tempfile::NamedTempFile; -use tokio::sync::mpsc::{Receiver, Sender}; -use tokio::task; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::task::{self, JoinHandle}; /// Sort arbitrary size of data to get a total order (may spill several times during sorting based on free memory available). /// @@ -73,7 +77,7 @@ use tokio::task; struct ExternalSorter { schema: SchemaRef, in_mem_batches: Vec, - spills: Vec, + spills: Vec, /// Sort expressions expr: Vec, session_config: Arc, @@ -83,6 +87,15 @@ struct ExternalSorter { fetch: Option, reservation: MemoryReservation, partition_id: usize, + use_row_encoding: bool, + // if this flag is true, the output of the sort will + // have non-None `RowBatch` + preserve_output_rows: bool, +} +struct Spill { + record_batch_file: NamedTempFile, + // `None` when row encoding not preserved + rows_file: Option, } impl ExternalSorter { @@ -100,7 +113,6 @@ impl ExternalSorter { let reservation = MemoryConsumer::new(format!("ExternalSorter[{partition_id}]")) .with_can_spill(true) .register(&runtime.memory_pool); - Self { schema, in_mem_batches: vec![], @@ -113,9 +125,18 @@ impl ExternalSorter { fetch, reservation, partition_id, + preserve_output_rows: false, + use_row_encoding: false, } } + pub fn set_preserve_output_rows(&mut self, val: bool) { + self.preserve_output_rows = val; + } + fn set_use_row_encoding(&mut self, val: bool) { + self.use_row_encoding = val; + } + async fn insert_batch( &mut self, input: RecordBatch, @@ -132,12 +153,24 @@ impl ExternalSorter { // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. let _timer = tracking_metrics.elapsed_compute().timer(); - let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?; + let partial = sort_batch( + input, + self.schema.clone(), + &self.expr, + self.fetch, + self.use_row_encoding, + )?; // The resulting batch might be smaller (or larger, see #3747) than the input // batch due to either a propagated limit or the re-construction of arrays. So // for being reliable, we need to reflect the memory usage of the partial batch. - let new_size = batch_byte_size(&partial.sorted_batch); + // + // In addition, if it's row encoding was preserved, that would also change the size. + let new_size = batch_byte_size(&partial.sorted_batch) + + match partial.sort_data { + SortData::Rows(ref rows) => rows.size(), + SortData::Arrays(_) => 0, + }; match new_size.cmp(&size) { Ordering::Greater => { // We don't have to call try_grow here, since we have already used the @@ -165,7 +198,7 @@ impl ExternalSorter { } /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`. - fn sort(&mut self) -> Result { + fn sort(&mut self) -> Result { let batch_size = self.session_config.batch_size(); if self.spilled_before() { @@ -174,7 +207,7 @@ impl ExternalSorter { .new_intermediate_tracking(self.partition_id, &self.runtime.memory_pool); let mut streams: Vec = vec![]; if !self.in_mem_batches.is_empty() { - let in_mem_stream = in_mem_partial_sort( + let mut stream = in_mem_partial_sort( &mut self.in_mem_batches, self.schema.clone(), &self.expr, @@ -183,40 +216,52 @@ impl ExternalSorter { self.fetch, )?; let prev_used = self.reservation.free(); - streams.push(SortedStream::new(in_mem_stream, prev_used)); + stream.mem_used = prev_used; + streams.push(stream); } - + let sort_fields = self + .expr + .iter() + .map(|e| { + Ok(SortField::new_with_options( + e.expr.data_type(&self.schema)?, + e.options, + )) + }) + .collect::>>()?; for spill in self.spills.drain(..) { - let stream = read_spill_as_stream(spill, self.schema.clone())?; - streams.push(SortedStream::new(stream, 0)); + let (rx, handle) = read_spill_as_stream(spill, sort_fields.to_owned())?; + streams.push(SortedStream::new_from_rx(rx, handle, 0)); } let tracking_metrics = self .metrics_set .new_final_tracking(self.partition_id, &self.runtime.memory_pool); - Ok(Box::pin(SortPreservingMergeStream::new_from_streams( + let sort_stream = SortPreservingMergeStream::new_from_streams( streams, self.schema.clone(), &self.expr, tracking_metrics, self.session_config.batch_size(), - )?)) + self.preserve_output_rows, + )?; + Ok(SortedStream::new(Box::pin(sort_stream), 0)) } else if !self.in_mem_batches.is_empty() { let tracking_metrics = self .metrics_set .new_final_tracking(self.partition_id, &self.runtime.memory_pool); - let result = in_mem_partial_sort( + let stream = in_mem_partial_sort( &mut self.in_mem_batches, self.schema.clone(), &self.expr, batch_size, tracking_metrics, self.fetch, - ); + )?; // Report to the memory manager we are no longer using memory self.reservation.free(); - result + Ok(stream) } else { - Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) + Ok(SortedStream::empty()) } } @@ -237,29 +282,43 @@ impl ExternalSorter { if self.in_mem_batches.is_empty() { return Ok(0); } - debug!("Spilling sort data of ExternalSorter to disk whilst inserting"); let tracking_metrics = self .metrics_set .new_intermediate_tracking(self.partition_id, &self.runtime.memory_pool); - let spillfile = self.runtime.disk_manager.create_tmp_file("Sorting")?; - let stream = in_mem_partial_sort( + let mut stream = in_mem_partial_sort( &mut self.in_mem_batches, self.schema.clone(), &self.expr, self.session_config.batch_size(), tracking_metrics, self.fetch, - ); - - spill_partial_sorted_stream(&mut stream?, spillfile.path(), self.schema.clone()) - .await?; + )?; + let rows_file = if stream.row_encoding_ignored { + None + } else { + Some( + self.runtime + .disk_manager + .create_tmp_file("Sorting row encodings")?, + ) + }; + spill_partial_sorted_stream( + &mut stream, + spillfile.path(), + rows_file.as_ref().map(|f| f.path()), + self.schema.clone(), + ) + .await?; self.reservation.free(); let used = self.metrics.mem_used().set(0); self.metrics.record_spill(used); - self.spills.push(spillfile); + self.spills.push(Spill { + record_batch_file: spillfile, + rows_file, + }); Ok(used) } } @@ -282,40 +341,73 @@ fn in_mem_partial_sort( batch_size: usize, tracking_metrics: MemTrackingMetrics, fetch: Option, -) -> Result { +) -> Result { assert_ne!(buffered_batches.len(), 0); if buffered_batches.len() == 1 { - let result = buffered_batches.pop(); - Ok(Box::pin(SizedRecordBatchStream::new( + let result = buffered_batches.pop().unwrap(); + let BatchWithSortArray { + sort_data, + sorted_batch, + } = result; + let rowbatch: Option = match sort_data { + SortData::Rows(rows) => Some(rows.into()), + SortData::Arrays(_) => None, + }; + let stream = Box::pin(SizedRecordBatchStream::new( schema, - vec![Arc::new(result.unwrap().sorted_batch)], + vec![Arc::new(sorted_batch)], tracking_metrics, - ))) + )); + if let Some(rowbatch) = rowbatch { + Ok(SortedStream::new_from_streams( + stream, + 0, + Box::pin(futures::stream::once(futures::future::ready(Some( + rowbatch, + )))), + )) + } else { + Ok(SortedStream::new_no_row_encoding(stream, 0)) + } } else { - let (sorted_arrays, batches): (Vec>, Vec) = - buffered_batches - .drain(..) - .map(|b| { - let BatchWithSortArray { - sort_arrays, - sorted_batch: batch, - } = b; - (sort_arrays, batch) - }) - .unzip(); + let (sort_data, batches): (Vec, Vec) = buffered_batches + .drain(..) + .map(|b| { + let BatchWithSortArray { + sort_data, + sorted_batch: batch, + } = b; + (sort_data, batch) + }) + .unzip(); let sorted_iter = { // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. let _timer = tracking_metrics.elapsed_compute().timer(); - get_sorted_iter(&sorted_arrays, expressions, batch_size, fetch)? + get_sorted_iter(&sort_data, expressions, batch_size, fetch)? }; - Ok(Box::pin(SortedSizedRecordBatchStream::new( + let rows = sort_data + .into_iter() + .map(|d| match d { + SortData::Rows(rows) => Some(rows), + SortData::Arrays(_) => None, + }) + .collect::>>(); + let used_rows = rows.is_some(); + let batch_stream = SortedSizedStream::new( schema, batches, sorted_iter, tracking_metrics, - ))) + rows.map(|rs| rs.into_iter().map(Arc::new).collect()), + ) + .boxed(); + let mut stream = SortedStream::new(batch_stream, 0); + if !used_rows { + stream.row_encoding_ignored = true; + } + Ok(stream) } } @@ -327,16 +419,16 @@ struct CompositeIndex { /// Get sorted iterator by sort concatenated `SortColumn`s fn get_sorted_iter( - sort_arrays: &[Vec], + sort_data: &[SortData], expr: &[PhysicalSortExpr], batch_size: usize, fetch: Option, ) -> Result { - let row_indices = sort_arrays + let row_indices = sort_data .iter() .enumerate() - .flat_map(|(i, arrays)| { - (0..arrays[0].len()).map(move |r| CompositeIndex { + .flat_map(|(i, d)| { + (0..d.num_rows()).map(move |r| CompositeIndex { // since we original use UInt32Array to index the combined mono batch, // component record batches won't overflow as well, // use u32 here for space efficiency. @@ -345,22 +437,52 @@ fn get_sorted_iter( }) }) .collect::>(); - - let sort_columns = expr + let rows_per_batch: Option> = sort_data .iter() - .enumerate() - .map(|(i, expr)| { - let columns_i = sort_arrays - .iter() - .map(|cs| cs[i].as_ref()) - .collect::>(); - Ok(SortColumn { - values: concat(columns_i.as_slice())?, - options: Some(expr.options), - }) + .map(|d| match d { + SortData::Rows(ref rows) => Some(rows), + SortData::Arrays(_) => None, }) - .collect::>>()?; - let indices = lexsort_to_indices(&sort_columns, fetch)?; + .collect(); + let indices = match rows_per_batch { + Some(rows_per_batch) => { + let mut to_sort = rows_per_batch + .iter() + .flat_map(|r| r.iter()) + .enumerate() + .collect::>(); + // NB: according to the rust docs, `sort` is a mergesort (while + // `sort_unstable` is quicksort.) so right here, `sort` should be faster + // since we are sorting a bunch of concatenated sorted sequences. + to_sort.sort_by(|(_, row_a), (_, row_b)| row_a.cmp(row_b)); + let limit = match fetch { + Some(lim) => lim.min(to_sort.len()), + None => to_sort.len(), + }; + UInt32Array::from_iter(to_sort.iter().take(limit).map(|(idx, _)| *idx as u32)) + } + None => { + let sort_columns = expr + .iter() + .enumerate() + .map(|(i, expr)| { + let columns_i = sort_data + .iter() + .map(|data| match data { + // todo fix + SortData::Rows(_) => unreachable!(), + SortData::Arrays(arrays) => arrays[i].as_ref(), + }) + .collect::>(); + Ok(SortColumn { + values: concat(columns_i.as_slice())?, + options: Some(expr.options), + }) + }) + .collect::>>()?; + lexsort_to_indices(&sort_columns, fetch)? + } + }; // Calculate composite index based on sorted indices let row_indices = indices @@ -471,38 +593,45 @@ fn group_indices( } /// Stream of sorted record batches -struct SortedSizedRecordBatchStream { +struct SortedSizedStream { schema: SchemaRef, batches: Vec, sorted_iter: SortedIterator, num_cols: usize, metrics: MemTrackingMetrics, + rows: Option>>, } -impl SortedSizedRecordBatchStream { +impl SortedSizedStream { /// new pub fn new( schema: SchemaRef, batches: Vec, sorted_iter: SortedIterator, mut metrics: MemTrackingMetrics, + rows: Option>>, ) -> Self { let size = batches.iter().map(batch_byte_size).sum::() - + sorted_iter.memory_size(); + + sorted_iter.memory_size() + // include rows if non-None + + rows + .as_ref() + .map_or(0, |r| r.iter().map(|r| r.size()).sum()); metrics.init_mem_used(size); let num_cols = batches[0].num_columns(); - SortedSizedRecordBatchStream { + SortedSizedStream { schema, batches, sorted_iter, + rows, num_cols, metrics, } } } -impl Stream for SortedSizedRecordBatchStream { - type Item = Result; +impl Stream for SortedSizedStream { + type Item = SortStreamItem; fn poll_next( mut self: std::pin::Pin<&mut Self>, @@ -512,6 +641,7 @@ impl Stream for SortedSizedRecordBatchStream { None => Poll::Ready(None), Some(slices) => { let num_rows = slices.iter().map(|s| s.len).sum(); + // create columns for record batch let output = (0..self.num_cols) .map(|i| { let arrays = self @@ -532,8 +662,33 @@ impl Stream for SortedSizedRecordBatchStream { .collect::>(); let batch = RecordBatch::try_new(self.schema.clone(), output).map_err(Into::into); - let poll = Poll::Ready(Some(batch)); - self.metrics.record_poll(poll) + match batch { + Ok(batch) => { + // construct `RowBatch` batch if sorted row encodings were preserved + let row_batch = self.rows.as_ref().map(|rows| { + let row_refs = + rows.iter().map(Arc::clone).collect::>(); + let indices = slices + .iter() + .flat_map(|s| { + (0..s.len).map(|i| { + ( + s.batch_idx as usize, + s.start_row_idx as usize + i, + ) + }) + }) + .collect::>(); + RowBatch::new(row_refs, indices) + }); + let poll = Poll::Ready(Some(Ok((batch, row_batch)))); + record_poll_sort_item(&self.metrics, poll) + } + Err(err) => { + let poll = Poll::Ready(Some(Err(err))); + record_poll_sort_item(&self.metrics, poll) + } + } } } } @@ -545,20 +700,17 @@ struct CompositeSlice { len: usize, } -impl RecordBatchStream for SortedSizedRecordBatchStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - async fn spill_partial_sorted_stream( - in_mem_stream: &mut SendableRecordBatchStream, + in_mem_stream: &mut SortedStream, path: &Path, + row_path: Option<&Path>, schema: SchemaRef, ) -> Result<()> { - let (sender, receiver) = tokio::sync::mpsc::channel(2); + let (sender, receiver) = mpsc::channel(2); let path: PathBuf = path.into(); - let handle = task::spawn_blocking(move || write_sorted(receiver, path, schema)); + let row_path = row_path.map(|p| p.to_path_buf()); + let handle = + task::spawn_blocking(move || write_sorted(receiver, path, row_path, schema)); while let Some(item) = in_mem_stream.next().await { sender.send(item).await.ok(); } @@ -572,48 +724,63 @@ async fn spill_partial_sorted_stream( } fn read_spill_as_stream( - path: NamedTempFile, - schema: SchemaRef, -) -> Result { - let (sender, receiver): (Sender>, Receiver>) = - tokio::sync::mpsc::channel(2); + spill: Spill, + sort_fields: Vec, +) -> Result<(mpsc::Receiver, JoinHandle<()>)> { + let (sender, receiver) = mpsc::channel::(2); let join_handle = task::spawn_blocking(move || { - if let Err(e) = read_spill(sender, path.path()) { - error!("Failure while reading spill file: {:?}. Error: {}", path, e); + if let Err(e) = read_spill(sender, &spill, sort_fields) { + error!( + "Failure while reading spill file: ({:?}, {:?}). Error: {}", + spill.record_batch_file, spill.rows_file, e + ); } }); - Ok(RecordBatchReceiverStream::create( - &schema, - receiver, - join_handle, - )) + Ok((receiver, join_handle)) } fn write_sorted( - mut receiver: Receiver>, + mut receiver: Receiver, path: PathBuf, + row_path: Option, schema: SchemaRef, ) -> Result<()> { let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; + let mut row_writer = RowWriter::try_new(row_path.as_ref())?; while let Some(batch) = receiver.blocking_recv() { - writer.write(&batch?)?; + let (recbatch, rows) = batch?; + writer.write(&recbatch)?; + row_writer.write(rows)?; } writer.finish()?; + row_writer.finish()?; debug!( "Spilled {} batches of total {} rows to disk, memory released {}", writer.num_batches, writer.num_rows, - human_readable_size(writer.num_bytes as usize), + human_readable_size( + writer.num_bytes as usize + row_writer.num_row_bytes as usize + ), ); Ok(()) } -fn read_spill(sender: Sender>, path: &Path) -> Result<()> { - let file = BufReader::new(File::open(path)?); +fn read_spill( + sender: Sender, + spill: &Spill, + sort_fields: Vec, +) -> Result<()> { + let file = BufReader::new(File::open(&spill.record_batch_file)?); let reader = FileReader::try_new(file, None)?; - for batch in reader { + let row_reader = RowReader::try_new(spill.rows_file.as_ref(), sort_fields)?; + for zipped in reader.zip(row_reader) { + let item = match zipped { + (Ok(batch), Ok(rows)) => Ok((batch, rows)), + (Err(err), Ok(_)) | (Err(err), Err(_)) => Err(err.into()), + (Ok(_), Err(err)) => Err(err), + }; sender - .blocking_send(batch.map_err(Into::into)) + .blocking_send(item) .map_err(|e| DataFusionError::Execution(format!("{e}")))?; } Ok(()) @@ -665,7 +832,6 @@ impl SortExec { fetch, } } - /// Input schema pub fn input(&self) -> &Arc { &self.input @@ -680,6 +846,83 @@ impl SortExec { pub fn fetch(&self) -> Option { self.fetch } + /// to be used by parent nodes to run execute that incldues the row + /// encodings in the result stream + pub(crate) fn execute_save_rows( + &self, + partition: usize, + context: Arc, + ) -> Result { + debug!("Start SortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + + debug!( + "Start invoking SortExec's input.execute for partition: {}", + partition + ); + + let input = self.input.execute(partition, context.clone())?; + + debug!("End SortExec's input.execute for partition: {}", partition); + Ok(Box::pin( + futures::stream::once(do_sort( + input, + partition, + self.expr.clone(), + self.metrics_set.clone(), + context, + self.fetch(), + true, + )) + .try_flatten(), + )) + } + /// to be used by parent nodes to spawn execution into tokio threadpool + /// and write results to `tx` + pub(crate) fn execution_spawn_save_rows( + &self, + partition: usize, + context: Arc, + tx: mpsc::Sender, + ) -> tokio::task::JoinHandle<()> { + let input = self.input.clone(); + let expr = self.expr.clone(); + let metrics = self.metrics_set.clone(); + let fetch = self.fetch(); + let disp = displayable(input.as_ref()).one_line().to_string(); + tokio::spawn(async move { + debug!("Start SortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + + debug!( + "Start invoking SortExec's input.execute for partition: {}", + partition + ); + let input = match input.execute(partition, context.clone()) { + Err(e) => { + tx.send(Err(e)).await.ok(); + return; + } + Ok(stream) => stream, + }; + debug!("End SortExec's input.execute for partition: {}", partition); + let mut sort_item_stream = match do_sort( + input, partition, expr, metrics, context, fetch, true, + ) + .await + { + Ok(stream) => stream, + Err(err) => { + tx.send(Err(err)).await.ok(); + return; + } + }; + while let Some(item) = sort_item_stream.next().await { + if tx.send(item).await.is_err() { + debug!("Stopping execution: output is gone, plan cancelling: {disp}"); + return; + } + } + }) + } } impl ExecutionPlan for SortExec { @@ -766,7 +1009,6 @@ impl ExecutionPlan for SortExec { let input = self.input.execute(partition, context.clone())?; debug!("End SortExec's input.execute for partition: {}", partition); - Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), futures::stream::once(do_sort( @@ -776,8 +1018,11 @@ impl ExecutionPlan for SortExec { self.metrics_set.clone(), context, self.fetch(), + // default execute shouldnt save row encodings + false, )) - .try_flatten(), + .try_flatten() + .map_ok(|(record_batch, _rows)| record_batch), ))) } @@ -808,8 +1053,23 @@ impl ExecutionPlan for SortExec { } } +enum SortData { + Rows(RowSelection), + Arrays(Vec), +} +impl SortData { + fn num_rows(&self) -> usize { + match self { + SortData::Rows(r) => r.num_rows(), + SortData::Arrays(a) => { + let first_col = &a[0]; + first_col.len() + } + } + } +} struct BatchWithSortArray { - sort_arrays: Vec, + sort_data: SortData, sorted_batch: RecordBatch, } @@ -818,13 +1078,37 @@ fn sort_batch( schema: SchemaRef, expr: &[PhysicalSortExpr], fetch: Option, + use_row_encoding: bool, ) -> Result { let sort_columns = expr .iter() .map(|e| e.evaluate_to_sort_column(&batch)) .collect::>>()?; - - let indices = lexsort_to_indices(&sort_columns, fetch)?; + let (indices, rows) = match use_row_encoding { + // if single column or there's a limit, fallback to regular sort + false => (lexsort_to_indices(&sort_columns, fetch)?, None), + _ => { + let sort_fields = sort_columns + .iter() + .map(|c| { + let datatype = c.values.data_type().to_owned(); + SortField::new_with_options(datatype, c.options.unwrap_or_default()) + }) + .collect::>(); + let arrays: Vec = + sort_columns.iter().map(|c| c.values.clone()).collect(); + let mut row_converter = RowConverter::new(sort_fields)?; + let rows = row_converter.convert_columns(&arrays)?; + + let mut to_sort: Vec<(usize, Row)> = rows.into_iter().enumerate().collect(); + to_sort.sort_unstable_by(|(_, row_a), (_, row_b)| row_a.cmp(row_b)); + let sorted_indices = to_sort.iter().map(|(idx, _)| *idx).collect::>(); + ( + UInt32Array::from_iter(sorted_indices.iter().map(|i| *i as u32)), + Some(RowSelection::new(rows, sorted_indices)), + ) + } + }; // reorder all rows based on sorted indices let sorted_batch = RecordBatch::try_new( @@ -845,26 +1129,31 @@ fn sort_batch( }) .collect::, ArrowError>>()?, )?; - - let sort_arrays = sort_columns - .into_iter() - .map(|sc| { - Ok(take( - sc.values.as_ref(), - &indices, - Some(TakeOptions { - check_bounds: false, - }), - )?) - }) - .collect::>>()?; + let sort_data = match rows { + Some(rows) => SortData::Rows(rows), + None => { + // only need sort_arrays when we dont have rows. + let sort_arrays = sort_columns + .into_iter() + .map(|sc| { + Ok(take( + sc.values.as_ref(), + &indices, + Some(TakeOptions { + check_bounds: false, + }), + )?) + }) + .collect::>>()?; + SortData::Arrays(sort_arrays) + } + }; Ok(BatchWithSortArray { - sort_arrays, + sort_data, sorted_batch, }) } - async fn do_sort( mut input: SendableRecordBatchStream, partition_id: usize, @@ -872,13 +1161,15 @@ async fn do_sort( metrics_set: CompositeMetricsSet, context: Arc, fetch: Option, -) -> Result { + preserve_rows: bool, +) -> Result { debug!( "Start do_sort for partition {} of context session_id {} and task_id {:?}", partition_id, context.session_id(), context.task_id() ); + let n_sort_cols = expr.len(); let schema = input.schema(); let tracking_metrics = metrics_set.new_intermediate_tracking(partition_id, context.memory_pool()); @@ -891,11 +1182,58 @@ async fn do_sort( context.runtime_env(), fetch, ); - while let Some(batch) = input.next().await { - let batch = batch?; - sorter.insert_batch(batch, &tracking_metrics).await?; + if preserve_rows { + sorter.set_preserve_output_rows(true); + } + sorter.set_use_row_encoding(match (n_sort_cols, fetch) { + // if single column or there's a limit, fallback to regular sort + (1, None) | (_, Some(_)) => false, + _ => true, + }); + if sorter.use_row_encoding { + // wait til more than 1 batch is seen before inserting the first batch + // (still maintains the order, just inserts first 2 batches together). + // then if theres only a single batch to sort, we dont use row encoding + let mut first_batch = Vec::with_capacity(1); + let mut inserted_first = false; + while let Some(batch) = input.next().await { + let batch = batch?; + match (inserted_first, first_batch.is_empty()) { + (false, true) => { + first_batch.push(batch); + } + (false, false) => { + // maintain batch insertion order + sorter + .insert_batch(first_batch.pop().unwrap(), &tracking_metrics) + .await?; + sorter.insert_batch(batch, &tracking_metrics).await?; + inserted_first = true; + } + (true, true) => { + sorter.insert_batch(batch, &tracking_metrics).await?; + } + (true, false) => { + unreachable!() + } + } + } + if !inserted_first && !first_batch.is_empty() { + // only one batch was inserted, dont use row encoding + sorter.set_use_row_encoding(false); + assert!(!sorter.spilled_before()); + assert!(sorter.in_mem_batches.is_empty()); + let batch = first_batch.pop().unwrap(); + sorter.insert_batch(batch, &tracking_metrics).await?; + } + } else { + while let Some(batch) = input.next().await { + let batch = batch?; + sorter.insert_batch(batch, &tracking_metrics).await?; + } } let result = sorter.sort(); + debug!( "End do_sort for partition {} of context session_id {} and task_id {:?}", partition_id, @@ -905,6 +1243,146 @@ async fn do_sort( result } +/// manages writing potential rows to and from disk +struct RowWriter { + // serializing w/ arrow ipc format for maximum code simplicity... probably sub-optimal + file: Option, + num_row_bytes: u32, +} +const MAGIC_BYTES: &[u8] = b"AROW"; +impl RowWriter { + fn try_new(path: Option>) -> Result { + match path { + Some(p) => { + let mut file = File::create(p)?; + file.write_all(MAGIC_BYTES)?; + Ok(Self { + file: Some(file), + num_row_bytes: 0, + }) + } + None => Ok(Self { + file: None, + num_row_bytes: 0, + }), + } + } + fn write(&mut self, rows: Option) -> Result<()> { + match (rows, self.file.as_mut()) { + (Some(rows), Some(file)) => { + file.write_all(&(rows.num_rows() as u32).to_le_bytes())?; + for row in rows.iter() { + let bytes: &[u8] = row.as_ref(); + let num_bytes = bytes.len() as u32; + self.num_row_bytes += num_bytes; + file.write_all(&num_bytes.to_le_bytes())?; + file.write_all(bytes)?; + } + Ok(()) + } + // no-op + _ => Ok(()), + } + } + fn finish(&mut self) -> Result<()> { + if let Some(file) = self.file.as_mut() { + file.flush()?; + Ok(()) + } else { + Ok(()) + } + } +} + +/// manages reading potential rows to and from disk. +struct RowReader { + /// temporary file format solution is storing it w/ arrow IPC + file: Option, + row_conv: RowConverter, + stopped: bool, +} +impl RowReader { + fn try_new( + path: Option>, + sort_fields: Vec, + ) -> Result { + let row_conv = RowConverter::new(sort_fields)?; + match path { + Some(p) => { + let mut file = File::open(p)?; + let mut buf = [0_u8; 4]; + file.read_exact(&mut buf)?; + if buf != MAGIC_BYTES { + return Err(DataFusionError::Internal( + "unexpected magic bytes in serialized rows file".to_owned(), + )); + } + Ok(Self { + file: Some(file), + row_conv, + stopped: false, + }) + } + None => Ok(Self { + file: None, + row_conv, + stopped: false, + }), + } + } + + fn read_batch(&mut self) -> Result> { + let file = self.file.as_mut().unwrap(); + let mut buf = [0_u8; 4]; + match file.read_exact(&mut buf) { + Ok(_) => {} + Err(io_err) => { + if io_err.kind() == std::io::ErrorKind::UnexpectedEof { + return Ok(None); + } + return Err(io_err.into()); + } + } + let num_rows = u32::from_le_bytes(buf); + let mut bytes: Vec> = Vec::with_capacity(num_rows as usize); + for _ in 0..num_rows { + let mut buf = [0_u8; 4]; + file.read_exact(&mut buf)?; + let n = u32::from_le_bytes(buf); + let mut buf = vec![0_u8; n as usize]; + file.read_exact(&mut buf)?; + bytes.push(buf); + } + Ok(Some( + RowSelection::from_spilled(self.row_conv.parser(), bytes).into(), + )) + } +} +impl Iterator for RowReader { + type Item = Result>; + + fn next(&mut self) -> Option { + if self.stopped { + return None; + } + if self.file.is_some() { + let res = self.read_batch(); + match res { + Ok(Some(batch)) => Some(Ok(Some(batch))), + Ok(None) => None, + Err(err) => { + self.stopped = true; + Some(Err(err)) + } + } + } else { + // will be zipped with the main record batch reader so + // just yield None forever + Some(Ok(None)) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -925,6 +1403,62 @@ mod tests { use futures::FutureExt; use std::collections::HashMap; + #[test] + fn test_row_writer_reader() { + use crate::prelude::SessionContext; + use arrow::array::{Int64Array, StringArray}; + use arrow::datatypes::DataType; + let sort_fields = vec![ + SortField::new(DataType::Int64), + SortField::new(DataType::Utf8), + ]; + let mut conv = RowConverter::new(sort_fields.to_owned()).unwrap(); + + fn makebatch(n: i64) -> RecordBatch { + let ints: Int64Array = (0..n).map(Some).collect(); + let varlengths: StringArray = + StringArray::from_iter((0..n).map(|i| i + 100).map(|i| { + if i % 3 == 0 { + None + } else { + Some((i.pow(2)).to_string()) + } + })); + RecordBatch::try_from_iter(vec![ + ("c1", Arc::new(ints) as _), + ("c2", Arc::new(varlengths) as _), + ]) + .unwrap() + } + let row_lens = vec![10, 0, 0, 1, 50]; + let batches = row_lens.iter().map(|i| makebatch(*i)).collect::>(); + let rows = batches + .iter() + .map(|b| conv.convert_columns(b.columns()).unwrap()) + .collect::>(); + + let ctx = SessionContext::new(); + let runtime = ctx.runtime_env(); + let tempfile = runtime.disk_manager.create_tmp_file("Sorting").unwrap(); + let mut wr = RowWriter::try_new(Some(tempfile.path())).unwrap(); + for r in rows { + wr.write(Some(r.into())).unwrap(); + } + wr.finish().unwrap(); + + let rdr = RowReader::try_new(Some(tempfile.path()), sort_fields).unwrap(); + let batches = rdr.collect::>(); + assert_eq!(batches.len(), row_lens.len()); + let read_lens = batches + .iter() + .map(|b| { + let rowbatch = b.as_ref().unwrap().as_ref().unwrap(); + rowbatch.num_rows() as i64 + }) + .collect::>(); + assert_eq!(row_lens, read_lens); + } + #[tokio::test] async fn test_in_mem_sort() -> Result<()> { let session_ctx = SessionContext::new(); @@ -985,7 +1519,9 @@ mod tests { #[tokio::test] async fn test_sort_spill() -> Result<()> { // trigger spill there will be 4 batches with 5.5KB for each - let config = RuntimeConfig::new().with_memory_limit(12288, 1.0); + // plus 1289 bytes of row data for each batch + let row_size = 1289; + let config = RuntimeConfig::new().with_memory_limit(12288 + (row_size * 4), 1.0); let runtime = Arc::new(RuntimeEnv::new(config)?); let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); @@ -1051,11 +1587,46 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_spill_no_row_encoding_edge_case() -> Result<()> { + // trigger spill there will be 4 batches with 5.5KB for each + let config = RuntimeConfig::new().with_memory_limit(12288, 1.0); + let runtime = Arc::new(RuntimeEnv::new(config)?); + let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); + + let partitions = 4; + let csv = test::scan_partitioned_csv(partitions)?; + let schema = csv.schema(); + + let sort_exec = Arc::new(SortExec::try_new( + vec![ + // c2 uin32 column + PhysicalSortExpr { + expr: col("c2", &schema)?, + options: SortOptions::default(), + }, + ], + Arc::new(CoalescePartitionsExec::new(csv)), + None, + )?); + let task_ctx = session_ctx.task_ctx(); + let result = collect(sort_exec.clone(), task_ctx).await?; + assert_eq!(result.len(), 1); + assert_eq!( + session_ctx.runtime_env().memory_pool.reserved(), + 0, + "The sort should have returned all memory used back to the memory manager" + ); + + Ok(()) + } + #[tokio::test] async fn test_sort_fetch_memory_calculation() -> Result<()> { // This test mirrors down the size from the example above. let avg_batch_size = 6000; let partitions = 4; + let added_row_size = 1289 * partitions; // A tuple of (fetch, expect_spillage) let test_options = vec![ @@ -1068,8 +1639,10 @@ mod tests { ]; for (fetch, expect_spillage) in test_options { - let config = RuntimeConfig::new() - .with_memory_limit(avg_batch_size * (partitions - 1), 1.0); + let config = RuntimeConfig::new().with_memory_limit( + avg_batch_size * (partitions - 1) + added_row_size, + 1.0, + ); let runtime = Arc::new(RuntimeEnv::new(config)?); let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 7ef4d3bf8e86..5312933a3bfd 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -30,21 +30,23 @@ use arrow::{ record_batch::RecordBatch, }; use futures::stream::{Fuse, FusedStream}; -use futures::{ready, Stream, StreamExt}; +use futures::{ready, Stream, StreamExt, TryStreamExt}; use log::debug; use tokio::sync::mpsc; +use super::{record_poll_sort_item, RowBatch, RowSelection, SortStreamItem}; use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::physical_plan::metrics::{ ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet, }; +use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream}; -use crate::physical_plan::stream::RecordBatchReceiverStream; +use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use crate::physical_plan::{ common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType, - Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, + Distribution, ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, + Statistics, }; use datafusion_physical_expr::EquivalenceProperties; @@ -192,46 +194,65 @@ impl ExecutionPlan for SortPreservingMergeExec { let receivers = match tokio::runtime::Handle::try_current() { Ok(_) => (0..input_partitions) .map(|part_i| { - let (sender, receiver) = mpsc::channel(1); - let join_handle = spawn_execution( - self.input.clone(), - sender, - part_i, - context.clone(), - ); - - SortedStream::new( - RecordBatchReceiverStream::create( - &schema, - receiver, - join_handle, - ), - 0, - ) + if let Some(sort_plan) = + self.input.as_any().downcast_ref::() + { + let (tx, rx) = mpsc::channel(1); + let join_handle = sort_plan.execution_spawn_save_rows( + part_i, + context.clone(), + tx, + ); + SortedStream::new_from_rx(rx, join_handle, 0) + } else { + let (sender, receiver) = mpsc::channel(1); + let join_handle = spawn_execution( + self.input.clone(), + sender, + part_i, + context.clone(), + ); + SortedStream::new_no_row_encoding( + RecordBatchReceiverStream::create( + &schema, + receiver, + join_handle, + ), + 0, + ) + } }) .collect(), Err(_) => (0..input_partitions) .map(|partition| { - let stream = - self.input.execute(partition, context.clone())?; - Ok(SortedStream::new(stream, 0)) + if let Some(sort_plan) = + self.input.as_any().downcast_ref::() + { + let sortstream = sort_plan + .execute_save_rows(partition, context.clone())?; + Ok(SortedStream::new(sortstream, 0)) + } else { + let stream = + self.input.execute(partition, context.clone())?; + Ok(SortedStream::new_no_row_encoding(stream, 0)) + } }) .collect::>()?, }; - debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute"); - let result = Box::pin(SortPreservingMergeStream::new_from_streams( + let result = SortPreservingMergeStream::new_from_streams( receivers, schema, &self.expr, tracking_metrics, context.session_config().batch_size(), - )?); + // dont emit row encodings for this plan + false, + )?; debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"); - - Ok(result) + Ok(result.into()) } } } @@ -260,7 +281,7 @@ impl ExecutionPlan for SortPreservingMergeExec { struct MergingStreams { /// The sorted input streams to merge together - streams: Vec>, + streams: Vec>, /// number of streams num_streams: usize, } @@ -274,7 +295,7 @@ impl std::fmt::Debug for MergingStreams { } impl MergingStreams { - fn new(input_streams: Vec>) -> Self { + fn new(input_streams: Vec>) -> Self { Self { num_streams: input_streams.len(), streams: input_streams, @@ -298,8 +319,7 @@ pub(crate) struct SortPreservingMergeStream { /// /// Exhausted batches will be popped off the front once all /// their rows have been yielded to the output - batches: Vec>, - + batches: Vec)>>, /// The accumulated row indexes for the next record batch in_progress: Vec, @@ -343,6 +363,10 @@ pub(crate) struct SortPreservingMergeStream { /// row converter row_converter: RowConverter, + /// if this is false it will always yield None for the row encoding + /// this is true when `SortPreservingMergeStream` is used within `SortExec` + /// but not when its used in `SortPreservingMergeStream` + preserve_row_encoding: bool, } impl SortPreservingMergeStream { @@ -352,11 +376,12 @@ impl SortPreservingMergeStream { expressions: &[PhysicalSortExpr], mut tracking_metrics: MemTrackingMetrics, batch_size: usize, + // when used from within SortExec this should be true + preserve_row_encoding: bool, ) -> Result { let stream_count = streams.len(); let batches = (0..stream_count).map(|_| VecDeque::new()).collect(); tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum()); - let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect(); let sort_fields = expressions .iter() @@ -370,7 +395,7 @@ impl SortPreservingMergeStream { Ok(Self { schema, batches, - streams: MergingStreams::new(wrappers), + streams: MergingStreams::new(streams.into_iter().map(|s| s.fuse()).collect()), column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), tracking_metrics, aborted: false, @@ -381,6 +406,7 @@ impl SortPreservingMergeStream { loser_tree_adjusted: false, batch_size, row_converter, + preserve_row_encoding, }) } @@ -413,30 +439,50 @@ impl SortPreservingMergeStream { Some(Err(e)) => { return Poll::Ready(Err(e)); } - Some(Ok(batch)) => { + Some(Ok((batch, preserved_rows))) => { if batch.num_rows() > 0 { - let cols = self - .column_expressions - .iter() - .map(|expr| { - Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())) - }) - .collect::>>()?; - - let rows = match self.row_converter.convert_columns(&cols) { - Ok(rows) => rows, - Err(e) => { - return Poll::Ready(Err(DataFusionError::ArrowError(e))); + // use preserved row encoding if it existed, otherwise create now + let rows = match preserved_rows { + Some(rows) => { + // dbg!(&rows); + rows + } + None => { + let cols = self + .column_expressions + .iter() + .map(|expr| { + Ok(expr + .evaluate(&batch)? + .into_array(batch.num_rows())) + }) + .collect::>>()?; + match self.row_converter.convert_columns(&cols) { + // creates RowBatch where RowSelection refs array is empty + // as well as indices array + Ok(rows) => rows.into(), + Err(e) => { + return Poll::Ready(Err( + DataFusionError::ArrowError(e), + )); + } + } } }; - + // if this stream should emit the row encoding, save it in + // batches so that the sorted rows can be constructed + // when the sroted record batches are + if self.preserve_row_encoding { + self.batches[idx].push_back((batch, Some(rows.clone()))) + } else { + self.batches[idx].push_back((batch, None)) + } self.cursors[idx] = Some(SortKeyCursor::new( idx, self.next_batch_id, // assign this batch an ID rows, )); self.next_batch_id += 1; - self.batches[idx].push_back(batch) } else { empty_batch = true; } @@ -454,7 +500,7 @@ impl SortPreservingMergeStream { /// Drains the in_progress row indexes, and builds a new RecordBatch from them /// /// Will then drop any batches for which all rows have been yielded to the output - fn build_record_batch(&mut self) -> Result { + fn build_record_batch(&mut self) -> SortStreamItem { // Mapping from stream index to the index of the first buffer from that stream let mut buffer_idx = 0; let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len()); @@ -474,7 +520,9 @@ impl SortPreservingMergeStream { .batches .iter() .flat_map(|batch| { - batch.iter().map(|batch| batch.column(column_idx).data()) + batch + .iter() + .map(|(batch, _rows)| batch.column(column_idx).data()) }) .collect(); @@ -518,6 +566,74 @@ impl SortPreservingMergeStream { make_arrow_array(array_data.freeze()) }) .collect(); + let rows = if self.preserve_row_encoding { + if self.in_progress.is_empty() { + Some(RowBatch::new(vec![], vec![])) + } else { + let rows = self + .batches + .iter() + .flat_map(|batch| { + batch.iter().map(|(_, rows)| { + rows.as_ref().expect( + "if preserve_row_encoding was true \ + then row data should've been saved in batch", + ) + }) + }) + .collect::>(); + // let stream_idx_count = self.in_progress.iter().map(|v| (v.batch_idx, v.stream_idx)).unique().count(); + let mut new_indices: Vec<(usize, usize)> = + Vec::with_capacity(self.in_progress.len()); + let mut new_rows: Vec> = + Vec::with_capacity(rows.iter().map(|r| r.num_rows()).sum()); + // map index of `rows` to the location in `new_rows` + let mut offsets: Vec> = vec![None; rows.len()]; + let first = &self.in_progress[0]; + let mut buffer_idx = + stream_to_buffer_idx[first.stream_idx] + first.batch_idx; + let mut start_row_idx = first.row_idx; + let mut end_row_idx = start_row_idx + 1; + for row_index in self.in_progress.iter().skip(1) { + let next_buffer_idx = + stream_to_buffer_idx[row_index.stream_idx] + row_index.batch_idx; + + if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx { + // subsequent row in same batch + end_row_idx += 1; + continue; + } + let row_batch = rows[buffer_idx]; + let offset = match offsets[buffer_idx] { + Some(offset) => offset, + None => { + let offset = new_rows.len(); + new_rows.extend(row_batch.rows.iter().map(Arc::clone)); + offsets[buffer_idx] = Some(offset); + offset + } + }; + let row_indices = &row_batch.indices[start_row_idx..end_row_idx]; + new_indices + .extend(row_indices.iter().map(|(x, y)| (*x + offset, *y))); + new_rows.extend(row_batch.rows.iter().map(Arc::clone)); + // start new batch of rows + buffer_idx = next_buffer_idx; + start_row_idx = row_index.row_idx; + end_row_idx = start_row_idx + 1; + } + // emit final batch of rows + let row_batch = rows[buffer_idx]; + let row_indices = &row_batch.indices[start_row_idx..end_row_idx]; + new_indices + .extend(row_indices.iter().map(|(x, y)| (*x + new_rows.len(), *y))); + new_rows.extend(row_batch.rows.iter().map(Arc::clone)); + assert_eq!(new_indices.len(), self.in_progress.len()); + Some(RowBatch::new(new_rows, new_indices)) + } + } else { + None as Option + }; self.in_progress.clear(); @@ -536,19 +652,21 @@ impl SortPreservingMergeStream { } } - RecordBatch::try_new(self.schema.clone(), columns).map_err(Into::into) + RecordBatch::try_new(self.schema.clone(), columns) + .map(|batch| (batch, rows)) + .map_err(Into::into) } } impl Stream for SortPreservingMergeStream { - type Item = Result; + type Item = SortStreamItem; fn poll_next( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { let poll = self.poll_next_inner(cx); - self.tracking_metrics.record_poll(poll) + record_poll_sort_item(&self.tracking_metrics, poll) } } @@ -557,7 +675,7 @@ impl SortPreservingMergeStream { fn poll_next_inner( self: &mut Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll> { if self.aborted { return Poll::Ready(None); } @@ -582,7 +700,6 @@ impl SortPreservingMergeStream { .as_mut() .filter(|cursor| !cursor.is_finished()) .map(|cursor| (cursor.stream_idx(), cursor.advance())); - if let Some((stream_idx, row_idx)) = next { self.loser_tree_adjusted = false; let batch_idx = self.batches[stream_idx].len() - 1; @@ -699,10 +816,12 @@ impl SortPreservingMergeStream { Poll::Ready(Ok(())) } } - -impl RecordBatchStream for SortPreservingMergeStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() +impl From for SendableRecordBatchStream { + fn from(value: SortPreservingMergeStream) -> Self { + Box::pin(RecordBatchStreamAdapter::new( + value.schema.clone(), + value.into_stream().map_ok(|(rb, _rows)| rb), + )) } } @@ -714,7 +833,6 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use futures::FutureExt; - use tokio_stream::StreamExt; use crate::arrow::array::{Int32Array, StringArray, TimestampNanosecondArray}; use crate::from_slice::FromSlice; @@ -1287,7 +1405,7 @@ mod tests { } }); - streams.push(SortedStream::new( + streams.push(SortedStream::new_no_row_encoding( RecordBatchReceiverStream::create(&schema, receiver, join_handle), 0, )); @@ -1296,17 +1414,16 @@ mod tests { let metrics = ExecutionPlanMetricsSet::new(); let tracking_metrics = MemTrackingMetrics::new(&metrics, task_ctx.memory_pool(), 0); - let merge_stream = SortPreservingMergeStream::new_from_streams( streams, batches.schema(), sort.as_slice(), tracking_metrics, task_ctx.session_config().batch_size(), + false, ) .unwrap(); - - let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap(); + let mut merged = common::collect(merge_stream.into()).await.unwrap(); assert_eq!(merged.len(), 1); let merged = merged.remove(0); diff --git a/datafusion/core/tests/sort_key_cursor.rs b/datafusion/core/tests/sort_key_cursor.rs index 7d03ffc87bf5..0cea5caa978d 100644 --- a/datafusion/core/tests/sort_key_cursor.rs +++ b/datafusion/core/tests/sort_key_cursor.rs @@ -186,7 +186,7 @@ impl CursorBuilder { SortKeyCursor::new( stream_idx.expect("stream idx not set"), batch_id.expect("batch id not set"), - rows, + rows.into(), ) } }