Order Preserving RepartitionExec Implementation (#6742)
* Write tests for functionality

* Implement sort preserving repartition exec

* Minor changes

* Implement second design (per partition merge)

* Simplifications

* Address reviews

* Move the fuzz test to appropriate folder, improve comments

* Decrease code duplication

* simplifications

* Update comment

---------

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
3 people authored Jun 26, 2023
1 parent 1522e7a commit f24a724
Showing 4 changed files with 405 additions and 25 deletions.
datafusion/core/src/physical_plan/repartition/distributor_channels.rs
@@ -83,6 +83,19 @@ pub fn channels<T>(
(senders, receivers)
}

type PartitionAwareSenders<T> = Vec<Vec<DistributionSender<T>>>;
type PartitionAwareReceivers<T> = Vec<Vec<DistributionReceiver<T>>>;

/// Create `n_out` empty channels for each of the `n_in` inputs.
/// This way, each distinct input/output partition pair communicates via a dedicated channel.
/// This SPSC structure enables us to track which input partition the data comes from.
pub fn partition_aware_channels<T>(
n_in: usize,
n_out: usize,
) -> (PartitionAwareSenders<T>, PartitionAwareReceivers<T>) {
(0..n_in).map(|_| channels(n_out)).unzip()
}

/// Erroring during [send](DistributionSender::send).
///
/// This occurs when the [receiver](DistributionReceiver) is gone.
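
To make the new helper concrete, here is a minimal, self-contained sketch of the channel matrix it builds, with `std::sync::mpsc` standing in for DataFusion's `DistributionSender`/`DistributionReceiver` (a substitution made purely for illustration); senders and receivers are indexed as `[input][output]`:

use std::sync::mpsc::{channel, Receiver, Sender};

fn partition_aware_channels<T>(
    n_in: usize,
    n_out: usize,
) -> (Vec<Vec<Sender<T>>>, Vec<Vec<Receiver<T>>>) {
    // One dedicated SPSC channel per (input partition, output partition) pair.
    (0..n_in)
        .map(|_| (0..n_out).map(|_| channel()).unzip())
        .unzip()
}

fn main() {
    let (txs, rxs) = partition_aware_channels::<i32>(2, 3);
    // txs[i][o] sends from input partition i to output partition o.
    txs[1][2].send(42).unwrap();
    assert_eq!(rxs[1][2].recv().unwrap(), 42);
}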
179 changes: 154 additions & 25 deletions datafusion/core/src/physical_plan/repartition/mod.rs
@@ -16,15 +16,18 @@
// under the License.

//! The repartition operator maps N input partitions to M output partitions based on a
//! partitioning scheme.
//! partitioning scheme (if the `preserve_order` flag is set and the input is ordered,
//! the output ordering is preserved during repartitioning).
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::repartition::distributor_channels::channels;
use crate::physical_plan::repartition::distributor_channels::{
channels, partition_aware_channels,
};
use crate::physical_plan::{
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
};
@@ -42,6 +45,9 @@ use super::expressions::PhysicalSortExpr;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{RecordBatchStream, SendableRecordBatchStream};

use crate::physical_plan::common::transpose;
use crate::physical_plan::metrics::BaselineMetrics;
use crate::physical_plan::sorts::streaming_merge;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::Stream;
@@ -53,6 +59,8 @@ use tokio::task::JoinHandle;
mod distributor_channels;

type MaybeBatch = Option<Result<RecordBatch>>;
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

/// Inner state of [`RepartitionExec`].
#[derive(Debug)]
@@ -62,8 +70,8 @@ struct RepartitionExecState {
channels: HashMap<
usize,
(
DistributionSender<MaybeBatch>,
DistributionReceiver<MaybeBatch>,
InputPartitionsToCurrentPartitionSender,
InputPartitionsToCurrentPartitionReceiver,
SharedMemoryReservation,
),
>,
@@ -245,6 +253,9 @@ pub struct RepartitionExec {

/// Execution metrics
metrics: ExecutionPlanMetricsSet,

/// Boolean flag to decide whether to preserve ordering
preserve_order: bool,
}

#[derive(Debug, Clone)]
@@ -298,6 +309,15 @@ impl RepartitionExec {
pub fn partitioning(&self) -> &Partitioning {
&self.partitioning
}

/// Get the name of the executor
pub fn name(&self) -> &str {
if self.preserve_order {
"SortPreservingRepartitionExec"
} else {
"RepartitionExec"
}
}
}

impl ExecutionPlan for RepartitionExec {
@@ -345,8 +365,12 @@ impl ExecutionPlan for RepartitionExec {
}

fn maintains_input_order(&self) -> Vec<bool> {
// We preserve ordering when input partitioning is 1
vec![self.input().output_partitioning().partition_count() <= 1]
if self.preserve_order {
vec![true]
} else {
// We preserve ordering when input partitioning is 1
vec![self.input().output_partitioning().partition_count() <= 1]
}
}

fn equivalence_properties(&self) -> EquivalenceProperties {
@@ -359,7 +383,8 @@ impl ExecutionPlan for RepartitionExec {
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
trace!(
"Start RepartitionExec::execute for partition: {}",
"Start {}::execute for partition: {}",
self.name(),
partition
);
// lock mutexes
@@ -370,13 +395,29 @@

// if this is the first partition to be invoked then we need to set up initial state
if state.channels.is_empty() {
// create one channel per *output* partition
// note we use a custom channel that ensures there is always data for each receiver
// but limits the amount of buffering if required.
let (txs, rxs) = channels(num_output_partitions);
let (txs, rxs) = if self.preserve_order {
let (txs, rxs) =
partition_aware_channels(num_input_partitions, num_output_partitions);
// Take the transpose of senders and receivers; `state.channels` tracks entries per output partition.
let txs = transpose(txs);
let rxs = transpose(rxs);
(txs, rxs)
} else {
// create one channel per *output* partition
// note we use a custom channel that ensures there is always data for each receiver
// but limits the amount of buffering if required.
let (txs, rxs) = channels(num_output_partitions);
// Clone the sender for each input partition
let txs = txs
.into_iter()
.map(|item| vec![item; num_input_partitions])
.collect::<Vec<_>>();
let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
(txs, rxs)
};
for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
let reservation = Arc::new(Mutex::new(
MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
MemoryConsumer::new(format!("{}[{partition}]", self.name()))
.register(context.memory_pool()),
));
state.channels.insert(partition, (tx, rx, reservation));
@@ -389,7 +430,7 @@
.channels
.iter()
.map(|(partition, (tx, _rx, reservation))| {
(*partition, (tx.clone(), Arc::clone(reservation)))
(*partition, (tx[i].clone(), Arc::clone(reservation)))
})
.collect();
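
The `transpose` call above reshapes the `[input][output]` channel matrix from `partition_aware_channels` into `[output][input]` form, so that `state.channels` can be keyed by output partition. A minimal sketch of the helper's assumed semantics (the real implementation lives in `physical_plan::common` and may differ in detail):

fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> {
    let n_cols = original.first().map_or(0, |row| row.len());
    let mut result: Vec<Vec<T>> = (0..n_cols).map(|_| Vec::new()).collect();
    for row in original {
        // Distribute each row element onto the column it belongs to.
        for (item, column) in row.into_iter().zip(result.iter_mut()) {
            column.push(item);
        }
    }
    result
}

fn main() {
    // senders[input][output] becomes senders[output][input]:
    let t = transpose(vec![vec![(0, 0), (0, 1)], vec![(1, 0), (1, 1)]]);
    assert_eq!(t, vec![vec![(0, 0), (1, 0)], vec![(0, 1), (1, 1)]]);
}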

@@ -420,24 +461,53 @@
}

trace!(
"Before returning stream in RepartitionExec::execute for partition: {}",
"Before returning stream in {}::execute for partition: {}",
self.name(),
partition
);

// now return stream for the specified *output* partition which will
// read from the channel
let (_tx, rx, reservation) = state
let (_tx, mut rx, reservation) = state
.channels
.remove(&partition)
.expect("partition not used yet");
Ok(Box::pin(RepartitionStream {
num_input_partitions,
num_input_partitions_processed: 0,
schema: self.input.schema(),
input: rx,
drop_helper: Arc::clone(&state.abort_helper),
reservation,
}))

if self.preserve_order {
// Store streams from all the input partitions:
let input_streams = rx
.into_iter()
.map(|receiver| {
Box::pin(PerPartitionStream {
schema: self.schema(),
receiver,
drop_helper: Arc::clone(&state.abort_helper),
reservation: reservation.clone(),
}) as SendableRecordBatchStream
})
.collect::<Vec<_>>();
// Note that the number of receivers (`rx.len()`) equals `num_input_partitions`.

// Get existing ordering:
let sort_exprs = self.input.output_ordering().unwrap_or(&[]);
// Merge streams (while preserving ordering) coming from input partitions to this partition:
streaming_merge(
input_streams,
self.schema(),
sort_exprs,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
)
} else {
Ok(Box::pin(RepartitionStream {
num_input_partitions,
num_input_partitions_processed: 0,
schema: self.input.schema(),
input: rx.swap_remove(0),
drop_helper: Arc::clone(&state.abort_helper),
reservation,
}))
}
}
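
`streaming_merge` performs a k-way merge across the already-sorted per-input-partition streams. As a self-contained illustration of that idea (over plain sorted vectors rather than `RecordBatch` streams; a sketch of the concept, not DataFusion's implementation), a `BinaryHeap`-based merge looks like this:

use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn merge_sorted(inputs: Vec<Vec<i32>>) -> Vec<i32> {
    // Min-heap of (value, input index, position within that input).
    let mut heap = BinaryHeap::new();
    for (i, input) in inputs.iter().enumerate() {
        if let Some(&v) = input.first() {
            heap.push(Reverse((v, i, 0usize)));
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((v, i, pos))) = heap.pop() {
        out.push(v);
        // Replace the popped element with its successor from the same input.
        if let Some(&next) = inputs[i].get(pos + 1) {
            heap.push(Reverse((next, i, pos + 1)));
        }
    }
    out
}

fn main() {
    let merged = merge_sorted(vec![vec![1, 4, 7], vec![2, 5], vec![3, 6]]);
    assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7]);
}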

fn metrics(&self) -> Option<MetricsSet> {
Expand All @@ -453,7 +523,8 @@ impl ExecutionPlan for RepartitionExec {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
f,
"RepartitionExec: partitioning={}, input_partitions={}",
"{}: partitioning={}, input_partitions={}",
self.name(),
self.partitioning,
self.input.output_partitioning().partition_count()
)
@@ -480,9 +551,16 @@ impl RepartitionExec {
abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
})),
metrics: ExecutionPlanMetricsSet::new(),
preserve_order: false,
})
}

/// Set the order-preserving flag
pub fn with_preserve_order(mut self) -> Self {
self.preserve_order = true;
self
}
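
A hedged usage sketch of the new builder-style flag, assuming the public `datafusion` crate paths of this era (the helper name and the choice of round-robin partitioning are illustrative, not from this commit):

use std::sync::Arc;
use datafusion::error::Result;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};

// Hypothetical helper: repartition `input` into 8 partitions while
// keeping its sort order intact.
fn order_preserving_repartition(
    input: Arc<dyn ExecutionPlan>,
) -> Result<RepartitionExec> {
    Ok(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8))?
        .with_preserve_order())
}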

/// Pulls data from the specified input plan, feeding it to the
/// output partitions based on the desired partitioning
///
@@ -575,7 +653,7 @@
/// channels.
async fn wait_for_task(
input_task: AbortOnDropSingle<Result<()>>,
txs: HashMap<usize, DistributionSender<Option<Result<RecordBatch>>>>,
txs: HashMap<usize, DistributionSender<MaybeBatch>>,
) {
// wait for completion, and propagate error
// note we ignore errors on send (.ok) as that means the receiver has already shutdown.
@@ -681,6 +759,56 @@
}
}

/// This struct converts a receiver into a stream.
/// The receiver gets its data over an SPSC channel.
struct PerPartitionStream {
/// Schema wrapped by Arc
schema: SchemaRef,

/// channel containing the repartitioned batches
receiver: DistributionReceiver<MaybeBatch>,

/// Handle to ensure background tasks are killed when no longer needed.
#[allow(dead_code)]
drop_helper: Arc<AbortOnDropMany<()>>,

/// Memory reservation.
reservation: SharedMemoryReservation,
}

impl Stream for PerPartitionStream {
type Item = Result<RecordBatch>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.receiver.recv().poll_unpin(cx) {
Poll::Ready(Some(Some(v))) => {
if let Ok(batch) = &v {
self.reservation
.lock()
.shrink(batch.get_array_memory_size());
}
Poll::Ready(Some(v))
}
Poll::Ready(Some(None)) => {
// Input partition has finished sending batches
Poll::Ready(None)
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}

impl RecordBatchStream for PerPartitionStream {
/// Get the schema
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
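
The match arms in `poll_next` above decode a small sentinel protocol: `Some(Ok(batch))` carries data, `Some(Err(e))` carries an error to forward, an explicit `None` means the input partition finished, and a closed channel likewise ends the stream. The same protocol, sketched with `std::sync::mpsc` standing in for the crate's channel types (again an illustrative substitution):

use std::sync::mpsc::{channel, Receiver};

fn drain<T>(rx: Receiver<Option<Result<T, String>>>) -> Vec<Result<T, String>> {
    let mut out = Vec::new();
    // recv() returning Err means the sender hung up, which also ends the stream.
    while let Ok(msg) = rx.recv() {
        match msg {
            Some(item) => out.push(item), // a batch (or error) to forward downstream
            None => break,                // sentinel: input partition finished sending
        }
    }
    out
}

fn main() {
    let (tx, rx) = channel();
    tx.send(Some(Ok(1))).unwrap();
    tx.send(None).unwrap();
    assert_eq!(drain(rx), vec![Ok(1)]);
}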

#[cfg(test)]
mod tests {
use super::*;
@@ -705,6 +833,7 @@ mod tests {
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use futures::FutureExt;
use std::collections::HashSet;
use tokio::task::JoinHandle;

#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
1 change: 1 addition & 0 deletions datafusion/core/tests/fuzz_cases/mod.rs
@@ -19,4 +19,5 @@ mod aggregate_fuzz;
mod join_fuzz;
mod merge_fuzz;
mod order_spill_fuzz;
mod sort_preserving_repartition_fuzz;
mod window_fuzz;
datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs (new file; diff not loaded)
