diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 77190c5abcbb..7a87406afebc 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -34,16 +34,13 @@ use crate::serde::protobuf::ShuffleWritePartition; use crate::serde::scheduler::PartitionStats; use async_trait::async_trait; use datafusion::arrow::array::{ - Array, ArrayBuilder, ArrayRef, StringBuilder, StructBuilder, UInt32Builder, - UInt64Builder, + ArrayBuilder, ArrayRef, StringBuilder, StructBuilder, UInt32Builder, UInt64Builder, }; -use datafusion::arrow::compute::take; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; use datafusion::physical_plan::common::IPCWriter; -use datafusion::physical_plan::hash_utils::create_hashes; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::metrics::{ self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, @@ -55,6 +52,7 @@ use datafusion::physical_plan::{ use futures::StreamExt; use datafusion::execution::context::TaskContext; +use datafusion::physical_plan::repartition::BatchPartitioner; use log::{debug, info}; /// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and @@ -81,6 +79,7 @@ pub struct ShuffleWriterExec { struct ShuffleWriteMetrics { /// Time spend writing batches to shuffle files write_time: metrics::Time, + repart_time: metrics::Time, input_rows: metrics::Count, output_rows: metrics::Count, } @@ -88,6 +87,8 @@ struct ShuffleWriteMetrics { impl ShuffleWriteMetrics { fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { let write_time = MetricBuilder::new(metrics).subset_time("write_time", partition); + let repart_time = + MetricBuilder::new(metrics).subset_time("repart_time", partition); let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); @@ -95,6 +96,7 @@ impl ShuffleWriteMetrics { Self { write_time, + repart_time, input_rows, output_rows, } @@ -202,77 +204,48 @@ impl ShuffleWriterExec { writers.push(None); } - let hashes_buf = &mut vec![]; - let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let mut partitioner = BatchPartitioner::try_new( + Partitioning::Hash(exprs.clone(), *n), + write_metrics.repart_time.clone(), + )?; while let Some(result) = stream.next().await { let input_batch = result?; write_metrics.input_rows.add(input_batch.num_rows()); - let arrays = exprs - .iter() - .map(|expr| { - Ok(expr - .evaluate(&input_batch)? - .into_array(input_batch.num_rows())) - }) - .collect::>>()?; - hashes_buf.clear(); - hashes_buf.resize(arrays[0].len(), 0); - // Hash arrays and compute buckets based on number of partitions - let hashes = create_hashes(&arrays, &random_state, hashes_buf)?; - let mut indices = vec![vec![]; num_output_partitions]; - for (index, hash) in hashes.iter().enumerate() { - indices[(*hash % num_output_partitions as u64) as usize] - .push(index as u64) - } - for (output_partition, partition_indices) in - indices.into_iter().enumerate() - { - let indices = partition_indices.into(); - - // Produce batches based on indices - let columns = input_batch - .columns() - .iter() - .map(|c| { - take(c.as_ref(), &indices, None).map_err(|e| { - DataFusionError::Execution(e.to_string()) - }) - }) - .collect::>>>()?; - - let output_batch = - RecordBatch::try_new(input_batch.schema(), columns)?; - - // write non-empty batch out - - // TODO optimize so we don't write or fetch empty partitions - // if output_batch.num_rows() > 0 { - let timer = write_metrics.write_time.timer(); - match &mut writers[output_partition] { - Some(w) => { - w.write(&output_batch)?; + partitioner.partition( + input_batch, + |output_partition, output_batch| { + // write non-empty batch out + + // TODO optimize so we don't write or fetch empty partitions + // if output_batch.num_rows() > 0 { + let timer = write_metrics.write_time.timer(); + match &mut writers[output_partition] { + Some(w) => { + w.write(&output_batch)?; + } + None => { + let mut path = path.clone(); + path.push(&format!("{}", output_partition)); + std::fs::create_dir_all(&path)?; + + path.push(format!("data-{}.arrow", input_partition)); + info!("Writing results to {:?}", path); + + let mut writer = + IPCWriter::new(&path, stream.schema().as_ref())?; + + writer.write(&output_batch)?; + writers[output_partition] = Some(writer); + } } - None => { - let mut path = path.clone(); - path.push(&format!("{}", output_partition)); - std::fs::create_dir_all(&path)?; - - path.push(format!("data-{}.arrow", input_partition)); - info!("Writing results to {:?}", path); - - let mut writer = - IPCWriter::new(&path, stream.schema().as_ref())?; - - writer.write(&output_batch)?; - writers[output_partition] = Some(writer); - } - } - write_metrics.output_rows.add(output_batch.num_rows()); - timer.done(); - } + write_metrics.output_rows.add(output_batch.num_rows()); + timer.done(); + Ok(()) + }, + )?; } let mut part_locs = vec![]; diff --git a/datafusion/core/src/physical_plan/metrics/value.rs b/datafusion/core/src/physical_plan/metrics/value.rs index ffb4ebb3f655..4bf92221fe66 100644 --- a/datafusion/core/src/physical_plan/metrics/value.rs +++ b/datafusion/core/src/physical_plan/metrics/value.rs @@ -300,6 +300,11 @@ impl<'a> ScopedTimerGuard<'a> { } } + /// Restarts the timer recording from the current time + pub fn restart(&mut self) { + self.start = Some(Instant::now()) + } + /// Stop the timer, record the time taken and consume self pub fn done(mut self) { self.stop() diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 036421637af4..37955539639f 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -26,9 +26,10 @@ use std::{any::Any, vec}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::hash_utils::create_hashes; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; +use arrow::array::{ArrayRef, UInt64Builder}; +use arrow::datatypes::SchemaRef; +use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use arrow::{array::Array, error::Result as ArrowResult}; -use arrow::{compute::take, datatypes::SchemaRef}; use log::debug; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -39,6 +40,7 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use async_trait::async_trait; use crate::execution::context::TaskContext; +use datafusion_physical_expr::PhysicalExpr; use futures::stream::Stream; use futures::StreamExt; use hashbrown::HashMap; @@ -62,6 +64,133 @@ struct RepartitionExecState { abort_helper: Arc>, } +/// A utility that can be used to partition batches based on [`Partitioning`] +pub struct BatchPartitioner { + state: BatchPartitionerState, + timer: metrics::Time, +} + +enum BatchPartitionerState { + Hash { + random_state: ahash::RandomState, + exprs: Vec>, + num_partitions: usize, + hash_buffer: Vec, + }, + RoundRobin { + num_partitions: usize, + next_idx: usize, + }, +} + +impl BatchPartitioner { + /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`] + /// + /// The time spent repartitioning will be recorded to `timer` + pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result { + let state = match partitioning { + Partitioning::RoundRobinBatch(num_partitions) => { + BatchPartitionerState::RoundRobin { + num_partitions, + next_idx: 0, + } + } + Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash { + exprs, + num_partitions, + // Use fixed random hash + random_state: ahash::RandomState::with_seeds(0, 0, 0, 0), + hash_buffer: vec![], + }, + other => { + return Err(DataFusionError::NotImplemented(format!( + "Unsupported repartitioning scheme {:?}", + other + ))) + } + }; + + Ok(Self { state, timer }) + } + + /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`] + /// based on the [`Partitioning`] specified on construction + /// + /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding + /// partition index. Any error returned by `f` will be immediately returned by this + /// function without attempting to publish further [`RecordBatch`] + /// + /// The time spent repartitioning, not including time spent in `f` will be recorded + /// to the [`metrics::Time`] provided on construction + pub fn partition(&mut self, batch: RecordBatch, mut f: F) -> Result<()> + where + F: FnMut(usize, RecordBatch) -> Result<()>, + { + match &mut self.state { + BatchPartitionerState::RoundRobin { + num_partitions, + next_idx, + } => { + let idx = *next_idx; + *next_idx = (*next_idx + 1) % *num_partitions; + f(idx, batch)?; + } + BatchPartitionerState::Hash { + random_state, + exprs, + num_partitions: partitions, + hash_buffer, + } => { + let mut timer = self.timer.timer(); + + let arrays = exprs + .iter() + .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))) + .collect::>>()?; + + hash_buffer.clear(); + hash_buffer.resize(batch.num_rows(), 0); + + create_hashes(&arrays, random_state, hash_buffer)?; + + let mut indices: Vec<_> = (0..*partitions) + .map(|_| UInt64Builder::new(batch.num_rows())) + .collect(); + + for (index, hash) in hash_buffer.iter().enumerate() { + indices[(*hash % *partitions as u64) as usize] + .append_value(index as u64) + .unwrap(); + } + + for (partition, mut indices) in indices.into_iter().enumerate() { + let indices = indices.finish(); + if indices.is_empty() { + continue; + } + + // Produce batches based on indices + let columns = batch + .columns() + .iter() + .map(|c| { + arrow::compute::take(c.as_ref(), &indices, None) + .map_err(DataFusionError::ArrowError) + }) + .collect::>>()?; + + let batch = RecordBatch::try_new(batch.schema(), columns).unwrap(); + + timer.stop(); + f(partition, batch)?; + timer.restart(); + } + } + } + Ok(()) + } +} + /// The repartition operator maps N input partitions to M output partitions based on a /// partitioning scheme. No guarantees are made about the order of the resulting partitions. #[derive(Debug)] @@ -199,8 +328,6 @@ impl ExecutionPlan for RepartitionExec { mpsc::unbounded_channel::>>(); state.channels.insert(partition, (sender, receiver)); } - // Use fixed random state - let random = ahash::RandomState::with_seeds(0, 0, 0, 0); // launch one async task per *input* partition let mut join_handles = Vec::with_capacity(num_input_partitions); @@ -215,7 +342,6 @@ impl ExecutionPlan for RepartitionExec { let input_task: JoinHandle> = tokio::spawn(Self::pull_from_input( - random.clone(), self.input.clone(), i, txs.clone(), @@ -299,7 +425,6 @@ impl RepartitionExec { /// /// txs hold the output sending channels for each output partition async fn pull_from_input( - random_state: ahash::RandomState, input: Arc, i: usize, mut txs: HashMap>>>, @@ -307,16 +432,14 @@ impl RepartitionExec { r_metrics: RepartitionMetrics, context: Arc, ) -> Result<()> { - let num_output_partitions = txs.len(); + let mut partitioner = + BatchPartitioner::try_new(partitioning, r_metrics.repart_time.clone())?; // execute the child operator let timer = r_metrics.fetch_time.timer(); let mut stream = input.execute(i, context).await?; timer.done(); - let mut counter = 0; - let hashes_buf = &mut vec![]; - // While there are still outputs to send to, keep // pulling inputs while !txs.is_empty() { @@ -326,89 +449,23 @@ impl RepartitionExec { timer.done(); // Input is done - if result.is_none() { - break; - } - let result: ArrowResult = result.unwrap(); - - match &partitioning { - Partitioning::RoundRobinBatch(_) => { - let timer = r_metrics.send_time.timer(); - let output_partition = counter % num_output_partitions; - // if there is still a receiver, send to it - if let Some(tx) = txs.get_mut(&output_partition) { - if tx.send(Some(result)).is_err() { - // If the other end has hung up, it was an early shutdown (e.g. LIMIT) - txs.remove(&output_partition); - } + let batch = match result { + Some(result) => result?, + None => break, + }; + + partitioner.partition(batch, |partition, partitioned| { + let timer = r_metrics.send_time.timer(); + // if there is still a receiver, send to it + if let Some(tx) = txs.get_mut(&partition) { + if tx.send(Some(Ok(partitioned))).is_err() { + // If the other end has hung up, it was an early shutdown (e.g. LIMIT) + txs.remove(&partition); } - timer.done(); } - Partitioning::Hash(exprs, _) => { - let timer = r_metrics.repart_time.timer(); - let input_batch = result?; - let arrays = exprs - .iter() - .map(|expr| { - Ok(expr - .evaluate(&input_batch)? - .into_array(input_batch.num_rows())) - }) - .collect::>>()?; - hashes_buf.clear(); - hashes_buf.resize(arrays[0].len(), 0); - // Hash arrays and compute buckets based on number of partitions - let hashes = create_hashes(&arrays, &random_state, hashes_buf)?; - let mut indices = vec![vec![]; num_output_partitions]; - for (index, hash) in hashes.iter().enumerate() { - indices[(*hash % num_output_partitions as u64) as usize] - .push(index as u64) - } - timer.done(); - - for (num_output_partition, partition_indices) in - indices.into_iter().enumerate() - { - if partition_indices.is_empty() { - continue; - } - let timer = r_metrics.repart_time.timer(); - let indices = partition_indices.into(); - // Produce batches based on indices - let columns = input_batch - .columns() - .iter() - .map(|c| { - take(c.as_ref(), &indices, None).map_err(|e| { - DataFusionError::Execution(e.to_string()) - }) - }) - .collect::>>>()?; - let output_batch = - RecordBatch::try_new(input_batch.schema(), columns); - timer.done(); - - let timer = r_metrics.send_time.timer(); - // if there is still a receiver, send to it - if let Some(tx) = txs.get_mut(&num_output_partition) { - if tx.send(Some(output_batch)).is_err() { - // If the other end has hung up, it was an early shutdown (e.g. LIMIT) - txs.remove(&num_output_partition); - } - } - timer.done(); - } - } - other => { - // this should be unreachable as long as the validation logic - // in the constructor is kept up-to-date - return Err(DataFusionError::NotImplemented(format!( - "Unsupported repartitioning scheme {:?}", - other - ))); - } - } - counter += 1; + timer.done(); + Ok(()) + })?; } Ok(())