diff --git a/datafusion/core/src/physical_plan/repartition/distributor_channels.rs b/datafusion/core/src/physical_plan/repartition/distributor_channels.rs
new file mode 100644
index 000000000000..412926fbc61a
--- /dev/null
+++ b/datafusion/core/src/physical_plan/repartition/distributor_channels.rs
@@ -0,0 +1,710 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Special channel construction to distribute data from various inputs into N outputs,
+//! minimizing buffering while preventing deadlocks when repartitioning.
+//!
+//! # Design
+//!
+//! ```text
+//! +----+      +------+
+//! | TX |==||  | Gate |
+//! +----+  ||  |      |  +--------+  +----+
+//!         ====|      |==| Buffer |==| RX |
+//! +----+  ||  |      |  +--------+  +----+
+//! | TX |==||  |      |
+//! +----+      |      |
+//!             |      |
+//! +----+      |      |  +--------+  +----+
+//! | TX |======|      |==| Buffer |==| RX |
+//! +----+      +------+  +--------+  +----+
+//! ```
+//!
+//! There are `N` virtual MPSC (multi-producer, single consumer) channels with unbounded capacity. However, if all
+//! buffers/channels are non-empty, then a global gate will be closed, preventing new data from being written (the
+//! sender futures will be [pending](Poll::Pending)) until at least one channel is empty (and not closed).
+
+use std::{
+    collections::VecDeque,
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll, Waker},
+};
+
+use parking_lot::Mutex;
+
+/// Create `n` empty channels.
+pub fn channels<T>(
+    n: usize,
+) -> (Vec<DistributionSender<T>>, Vec<DistributionReceiver<T>>) {
+    let channels = (0..n)
+        .map(|id| {
+            Arc::new(Mutex::new(Channel {
+                data: VecDeque::default(),
+                n_senders: 1,
+                recv_alive: true,
+                recv_wakers: Vec::default(),
+                id,
+            }))
+        })
+        .collect::<Vec<_>>();
+    let gate = Arc::new(Mutex::new(Gate {
+        empty_channels: n,
+        send_wakers: Vec::default(),
+    }));
+    let senders = channels
+        .iter()
+        .map(|channel| DistributionSender {
+            channel: Arc::clone(channel),
+            gate: Arc::clone(&gate),
+        })
+        .collect();
+    let receivers = channels
+        .into_iter()
+        .map(|channel| DistributionReceiver {
+            channel,
+            gate: Arc::clone(&gate),
+        })
+        .collect();
+    (senders, receivers)
+}
+
+/// Error during [send](DistributionSender::send).
+///
+/// This occurs when the [receiver](DistributionReceiver) is gone.
+#[derive(PartialEq, Eq)]
+pub struct SendError<T>(pub T);
+
+impl<T> std::fmt::Debug for SendError<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_tuple("SendError").finish()
+    }
+}
+
+impl<T> std::fmt::Display for SendError<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "cannot send data, receiver is gone")
+    }
+}
+
+impl<T> std::error::Error for SendError<T> {}
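(A minimal usage sketch of the API introduced above; this block is illustrative and not part of the diff. It assumes a tokio runtime and that `channels` is in scope, since the module is crate-private in this PR.)

```rust
#[tokio::main]
async fn main() {
    // two channels, one receiver each
    let (txs, mut rxs) = channels::<String>(2);

    // senders can be cloned and moved into tasks
    let tx = txs[0].clone();
    tokio::spawn(async move {
        // resolves immediately while at least one channel is empty,
        // parks (Poll::Pending) once all channels hold data
        tx.send("hello".to_string()).await.unwrap();
    });

    // `Some(_)` once data arrives; `None` after the last sender is dropped
    assert_eq!(rxs[0].recv().await, Some("hello".to_string()));
}
```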
+
+/// Sender side of distribution [channels].
+///
+/// This handle can be cloned. All clones will write into the same channel. Dropping the last sender will close the
+/// channel. In this case, the [receiver](DistributionReceiver) will still be able to poll the remaining data, but will
+/// receive `None` afterwards.
+#[derive(Debug)]
+pub struct DistributionSender<T> {
+    /// To prevent lock inversion / deadlock, channel lock is always acquired prior to gate lock
+    channel: SharedChannel<T>,
+    gate: SharedGate,
+}
+
+impl<T> DistributionSender<T> {
+    /// Send data.
+    ///
+    /// This fails if the [receiver](DistributionReceiver) is gone.
+    pub fn send(&self, element: T) -> SendFuture<'_, T> {
+        SendFuture {
+            channel: &self.channel,
+            gate: &self.gate,
+            element: Box::new(Some(element)),
+        }
+    }
+}
+
+impl<T> Clone for DistributionSender<T> {
+    fn clone(&self) -> Self {
+        let mut guard = self.channel.lock();
+        guard.n_senders += 1;
+
+        Self {
+            channel: Arc::clone(&self.channel),
+            gate: Arc::clone(&self.gate),
+        }
+    }
+}
+
+impl<T> Drop for DistributionSender<T> {
+    fn drop(&mut self) {
+        let mut guard_channel = self.channel.lock();
+        guard_channel.n_senders -= 1;
+
+        if guard_channel.n_senders == 0 {
+            // Note: the recv_alive check is so that we don't double-clear the status
+            if guard_channel.data.is_empty() && guard_channel.recv_alive {
+                // channel is gone, so we need to clear our signal
+                let mut guard_gate = self.gate.lock();
+                guard_gate.empty_channels -= 1;
+            }
+
+            // receiver may be waiting for data, but should return `None` now since the channel is closed
+            guard_channel.wake_receivers();
+        }
+    }
+}
+
+/// Future backing [send](DistributionSender::send).
+#[derive(Debug)]
+pub struct SendFuture<'a, T> {
+    channel: &'a SharedChannel<T>,
+    gate: &'a SharedGate,
+    // the additional Box is required for `Self: Unpin`
+    element: Box<Option<T>>,
+}
+
+impl<'a, T> Future for SendFuture<'a, T> {
+    type Output = Result<(), SendError<T>>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = &mut *self;
+        assert!(this.element.is_some(), "polled ready future");
+
+        let mut guard_channel = this.channel.lock();
+
+        // receiver end still alive?
+        if !guard_channel.recv_alive {
+            return Poll::Ready(Err(SendError(
+                this.element.take().expect("just checked"),
+            )));
+        }
+
+        let mut guard_gate = this.gate.lock();
+
+        // does ANY receiver need data?
+        // if so, allow sender to create another
+        if guard_gate.empty_channels == 0 {
+            guard_gate
+                .send_wakers
+                .push((cx.waker().clone(), guard_channel.id));
+            return Poll::Pending;
+        }
+
+        let was_empty = guard_channel.data.is_empty();
+        guard_channel
+            .data
+            .push_back(this.element.take().expect("just checked"));
+        if was_empty {
+            guard_gate.empty_channels -= 1;
+            guard_channel.wake_receivers();
+        }
+
+        Poll::Ready(Ok(()))
+    }
+}
+
+/// Receiver side of distribution [channels].
+#[derive(Debug)]
+pub struct DistributionReceiver<T> {
+    channel: SharedChannel<T>,
+    gate: SharedGate,
+}
+
+impl<T> DistributionReceiver<T> {
+    /// Receive data from channel.
+    ///
+    /// Returns `None` if the channel is empty and no [senders](DistributionSender) are left.
+    pub fn recv(&mut self) -> RecvFuture<'_, T> {
+        RecvFuture {
+            channel: &mut self.channel,
+            gate: &mut self.gate,
+            rdy: false,
+        }
+    }
+}
+
+impl<T> Drop for DistributionReceiver<T> {
+    fn drop(&mut self) {
+        let mut guard_channel = self.channel.lock();
+        let mut guard_gate = self.gate.lock();
+        guard_channel.recv_alive = false;
+
+        // Note: n_senders check is here so we don't double-clear the signal
+        if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) {
+            // channel is gone, so we need to clear our signal
+            guard_gate.empty_channels -= 1;
+        }
+
+        // senders may be waiting for gate to open but should error now that the channel is closed
+        guard_gate.wake_channel_senders(guard_channel.id);
+
+        // clear potential remaining data from channel
+        guard_channel.data.clear();
+    }
+}
+
+/// Future backing [recv](DistributionReceiver::recv).
+pub struct RecvFuture<'a, T> {
+    channel: &'a mut SharedChannel<T>,
+    gate: &'a mut SharedGate,
+    rdy: bool,
+}
+
+impl<'a, T> Future for RecvFuture<'a, T> {
+    type Output = Option<T>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = &mut *self;
+        assert!(!this.rdy, "polled ready future");
+
+        let mut guard_channel = this.channel.lock();
+
+        match guard_channel.data.pop_front() {
+            Some(element) => {
+                // change "empty" signal for this channel?
+                if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) {
+                    let mut guard_gate = this.gate.lock();
+
+                    // update counter
+                    let old_counter = guard_gate.empty_channels;
+                    guard_gate.empty_channels += 1;
+
+                    // open gate?
+                    if old_counter == 0 {
+                        guard_gate.wake_all_senders();
+                    }
+
+                    drop(guard_gate);
+                    drop(guard_channel);
+                }
+
+                this.rdy = true;
+                Poll::Ready(Some(element))
+            }
+            None if guard_channel.n_senders == 0 => {
+                this.rdy = true;
+                Poll::Ready(None)
+            }
+            None => {
+                guard_channel.recv_wakers.push(cx.waker().clone());
+                Poll::Pending
+            }
+        }
+    }
+}
+
+/// Links senders and receivers.
+#[derive(Debug)]
+struct Channel<T> {
+    /// Buffered data.
+    data: VecDeque<T>,
+
+    /// Reference counter for the sender side.
+    n_senders: usize,
+
+    /// Reference "counter"/flag for the single receiver.
+    recv_alive: bool,
+
+    /// Wakers for the receiver side.
+    ///
+    /// The receiver will be pending if the [buffer](Self::data) is empty and
+    /// there are senders left (according to the [reference counter](Self::n_senders)).
+    recv_wakers: Vec<Waker>,
+
+    /// Channel ID.
+    ///
+    /// This is used to address [send wakers](Gate::send_wakers).
+    id: usize,
+}
+
+impl<T> Channel<T> {
+    fn wake_receivers(&mut self) {
+        for waker in self.recv_wakers.drain(..) {
+            waker.wake();
+        }
+    }
+}
+
+/// Shared channel.
+///
+/// One or multiple senders and a single receiver will share a channel.
+type SharedChannel<T> = Arc<Mutex<Channel<T>>>;
+
+/// The "all channels have data" gate.
+#[derive(Debug)]
+struct Gate {
+    /// Number of currently empty (and still open) channels.
+    empty_channels: usize,
+
+    /// Wakers for the sender side, including their channel IDs.
+    send_wakers: Vec<(Waker, usize)>,
+}
+
+impl Gate {
+    /// Wake all senders.
+    ///
+    /// This is helpful to signal that there are some channels empty now and hence the gate was opened.
+    fn wake_all_senders(&mut self) {
+        for (waker, _id) in self.send_wakers.drain(..) {
+            waker.wake();
+        }
+    }
+
+    /// Wake senders for a specific channel.
+    ///
+    /// This is helpful to signal that the receiver side is gone and the senders shall now error.
+    fn wake_channel_senders(&mut self, id: usize) {
+        // `drain_filter` is unstable, so implement our own
+        let (wake, keep) = self
+            .send_wakers
+            .drain(..)
+            .partition(|(_waker, id2)| id == *id2);
+        self.send_wakers = keep;
+        for (waker, _id) in wake {
+            waker.wake();
+        }
+    }
+}
+
+/// Gate shared by all senders and receivers.
+type SharedGate = Arc<Mutex<Gate>>;
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::{AtomicBool, Ordering};
+
+    use futures::{task::ArcWake, FutureExt};
+
+    use super::*;
+
+    #[test]
+    fn test_single_channel_no_gate() {
+        // use two channels so that the first one never hits the gate
+        let (mut txs, mut rxs) = channels(2);
+
+        let mut recv_fut = rxs[0].recv();
+        let waker = poll_pending(&mut recv_fut);
+
+        poll_ready(&mut txs[0].send("foo")).unwrap();
+        assert!(waker.woken());
+        assert_eq!(poll_ready(&mut recv_fut), Some("foo"),);
+
+        poll_ready(&mut txs[0].send("bar")).unwrap();
+        poll_ready(&mut txs[0].send("baz")).unwrap();
+        poll_ready(&mut txs[0].send("end")).unwrap();
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("baz"),);
+
+        // close channel
+        txs.remove(0);
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("end"),);
+        assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
+        assert_eq!(poll_ready(&mut rxs[0].recv()), None,);
+    }
+
+    #[test]
+    fn test_multi_sender() {
+        // use two channels so that the first one never hits the gate
+        let (txs, mut rxs) = channels(2);
+
+        let tx_clone = txs[0].clone();
+
+        poll_ready(&mut txs[0].send("foo")).unwrap();
+        poll_ready(&mut tx_clone.send("bar")).unwrap();
+
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("foo"),);
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("bar"),);
+    }
+
+    #[test]
+    fn test_gate() {
+        let (txs, mut rxs) = channels(2);
+
+        // gate initially open
+        poll_ready(&mut txs[0].send("0_a")).unwrap();
+
+        // gate still open because channel 1 is still empty
+        poll_ready(&mut txs[0].send("0_b")).unwrap();
+
+        // gate still open because channel 1 was still empty prior to this call, so this call still goes through
+        poll_ready(&mut txs[1].send("1_a")).unwrap();
+
+        // both channels non-empty => gate closed
+
+        let mut send_fut = txs[1].send("1_b");
+        let waker = poll_pending(&mut send_fut);
+
+        // drain channel 0
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_a"),);
+        poll_pending(&mut send_fut);
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("0_b"),);
+
+        // channel 0 empty => gate open
+        assert!(waker.woken());
+        poll_ready(&mut send_fut).unwrap();
+    }
+
+    #[test]
+    fn test_close_channel_by_dropping_tx() {
+        let (mut txs, mut rxs) = channels(2);
+
+        let tx0 = txs.remove(0);
+        let tx1 = txs.remove(0);
+        let tx0_clone = tx0.clone();
+
+        let mut recv_fut = rxs[0].recv();
+
+        poll_ready(&mut tx1.send("a")).unwrap();
+        let recv_waker = poll_pending(&mut recv_fut);
+
+        // drop original sender
+        drop(tx0);
+
+        // not yet closed (there's a clone left)
+        assert!(!recv_waker.woken());
+        poll_ready(&mut tx1.send("b")).unwrap();
+        let recv_waker = poll_pending(&mut recv_fut);
+
+        // create new clone
+        let tx0_clone2 = tx0_clone.clone();
+        assert!(!recv_waker.woken());
+        poll_ready(&mut tx1.send("c")).unwrap();
+        let recv_waker = poll_pending(&mut recv_fut);
+
+        // drop first clone
+        drop(tx0_clone);
+        assert!(!recv_waker.woken());
+        poll_ready(&mut tx1.send("d")).unwrap();
+        let recv_waker = poll_pending(&mut recv_fut);
+
+        // drop last clone
+        drop(tx0_clone2);
+
+        // channel closed => also close gate
+        poll_pending(&mut tx1.send("e"));
+
+        assert!(recv_waker.woken());
+        assert_eq!(poll_ready(&mut recv_fut), None,);
+    }
+
+    #[test]
+    fn test_close_channel_by_dropping_rx_on_open_gate() {
+        let (txs, mut rxs) = channels(2);
+
+        let rx0 = rxs.remove(0);
+        let _rx1 = rxs.remove(0);
+
+        poll_ready(&mut txs[1].send("a")).unwrap();
+
+        // drop receiver => also close gate
+        drop(rx0);
+
+        poll_pending(&mut txs[1].send("b"));
+        assert_eq!(poll_ready(&mut txs[0].send("foo")), Err(SendError("foo")),);
+    }
+
+    #[test]
+    fn test_close_channel_by_dropping_rx_on_closed_gate() {
+        let (txs, mut rxs) = channels(2);
+
+        let rx0 = rxs.remove(0);
+        let mut rx1 = rxs.remove(0);
+
+        // fill both channels
+        poll_ready(&mut txs[0].send("0_a")).unwrap();
+        poll_ready(&mut txs[1].send("1_a")).unwrap();
+
+        let mut send_fut0 = txs[0].send("0_b");
+        let mut send_fut1 = txs[1].send("1_b");
+        let waker0 = poll_pending(&mut send_fut0);
+        let waker1 = poll_pending(&mut send_fut1);
+
+        // drop receiver
+        drop(rx0);
+
+        assert!(waker0.woken());
+        assert!(!waker1.woken());
+        assert_eq!(poll_ready(&mut send_fut0), Err(SendError("0_b")),);
+
+        // gate closed, so cannot send on channel 1
+        poll_pending(&mut send_fut1);
+
+        // channel 1 can still receive data
+        assert_eq!(poll_ready(&mut rx1.recv()), Some("1_a"),);
+    }
+
+    #[test]
+    fn test_drop_rx_three_channels() {
+        let (mut txs, mut rxs) = channels(3);
+
+        let tx0 = txs.remove(0);
+        let tx1 = txs.remove(0);
+        let tx2 = txs.remove(0);
+        let mut rx0 = rxs.remove(0);
+        let rx1 = rxs.remove(0);
+        let _rx2 = rxs.remove(0);
+
+        // fill channels
+        poll_ready(&mut tx0.send("0_a")).unwrap();
+        poll_ready(&mut tx1.send("1_a")).unwrap();
+        poll_ready(&mut tx2.send("2_a")).unwrap();
+
+        // drop / close one channel
+        drop(rx1);
+
+        // receive data
+        assert_eq!(poll_ready(&mut rx0.recv()), Some("0_a"),);
+
+        // use senders again
+        poll_ready(&mut tx0.send("0_b")).unwrap();
+        assert_eq!(poll_ready(&mut tx1.send("1_b")), Err(SendError("1_b")),);
+        poll_pending(&mut tx2.send("2_b"));
+    }
+
+    #[test]
+    fn test_close_channel_by_dropping_rx_clears_data() {
+        let (txs, rxs) = channels(1);
+
+        let obj = Arc::new(());
+        let counter = Arc::downgrade(&obj);
+        assert_eq!(counter.strong_count(), 1);
+
+        // add object to channel
+        poll_ready(&mut txs[0].send(obj)).unwrap();
+        assert_eq!(counter.strong_count(), 1);
+
+        // drop receiver
+        drop(rxs);
+
+        assert_eq!(counter.strong_count(), 0);
+    }
+
+    #[test]
+    #[should_panic(expected = "polled ready future")]
+    fn test_panic_poll_send_future_after_ready_ok() {
+        let (txs, _rxs) = channels(1);
+        let mut fut = txs[0].send("foo");
+        poll_ready(&mut fut).unwrap();
+        poll_ready(&mut fut).ok();
+    }
+
+    #[test]
+    #[should_panic(expected = "polled ready future")]
+    fn test_panic_poll_send_future_after_ready_err() {
+        let (txs, rxs) = channels(1);
+
+        drop(rxs);
+
+        let mut fut = txs[0].send("foo");
+        poll_ready(&mut fut).unwrap_err();
+        poll_ready(&mut fut).ok();
+    }
+
+    #[test]
+    #[should_panic(expected = "polled ready future")]
+    fn test_panic_poll_recv_future_after_ready_some() {
+        let (txs, mut rxs) = channels(1);
+
+        poll_ready(&mut txs[0].send("foo")).unwrap();
+
+        let mut fut = rxs[0].recv();
+        poll_ready(&mut fut).unwrap();
+        poll_ready(&mut fut);
+    }
+
+    #[test]
+    #[should_panic(expected = "polled ready future")]
+    fn test_panic_poll_recv_future_after_ready_none() {
+        let (txs, mut rxs) = channels::<u8>(1);
+
+        drop(txs);
+
+        let mut fut = rxs[0].recv();
+        assert!(poll_ready(&mut fut).is_none());
+        poll_ready(&mut fut);
+    }
+
+    #[test]
+    #[should_panic(expected = "future is pending")]
+    fn test_meta_poll_ready_wrong_state() {
+        let mut fut = futures::future::pending::<u8>();
+        poll_ready(&mut fut);
+    }
+
+    #[test]
+    #[should_panic(expected = "future is ready")]
+    fn test_meta_poll_pending_wrong_state() {
+        let mut fut = futures::future::ready(1);
+        poll_pending(&mut fut);
+    }
+
+    #[test]
+    fn test_meta_poll_pending_waker() {
+        let (tx, mut rx) = futures::channel::oneshot::channel();
+        let waker = poll_pending(&mut rx);
+        assert!(!waker.woken());
+        tx.send(1).unwrap();
+        assert!(waker.woken());
+    }
+
+    /// Poll a given [`Future`] and ensure it is [ready](Poll::Ready).
+    #[track_caller]
+    fn poll_ready<F>(fut: &mut F) -> F::Output
+    where
+        F: Future + Unpin,
+    {
+        match poll(fut).0 {
+            Poll::Ready(x) => x,
+            Poll::Pending => panic!("future is pending"),
+        }
+    }
+
+    /// Poll a given [`Future`] and ensure it is [pending](Poll::Pending).
+    ///
+    /// Returns a waker that can later be checked.
+    #[track_caller]
+    fn poll_pending<F>(fut: &mut F) -> Arc<TestWaker>
+    where
+        F: Future + Unpin,
+    {
+        let (res, waker) = poll(fut);
+        match res {
+            Poll::Ready(_) => panic!("future is ready"),
+            Poll::Pending => waker,
+        }
+    }
+
+    fn poll<F>(fut: &mut F) -> (Poll<F::Output>, Arc<TestWaker>)
+    where
+        F: Future + Unpin,
+    {
+        let test_waker = Arc::new(TestWaker::default());
+        let waker = futures::task::waker(Arc::clone(&test_waker));
+        let mut cx = std::task::Context::from_waker(&waker);
+        let res = fut.poll_unpin(&mut cx);
+        (res, test_waker)
+    }
+
+    /// A test [`Waker`] that signals whether [`wake`](Waker::wake) was called.
+    #[derive(Debug, Default)]
+    struct TestWaker {
+        woken: AtomicBool,
+    }
+
+    impl TestWaker {
+        /// Was [`wake`](Waker::wake) called?
+        fn woken(&self) -> bool {
+            self.woken.load(Ordering::SeqCst)
+        }
+    }
+
+    impl ArcWake for TestWaker {
+        fn wake_by_ref(arc_self: &Arc<Self>) {
+            arc_self.woken.store(true, Ordering::SeqCst);
+        }
+    }
+}
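(Another illustrative sketch, not part of the diff: the gate behavior that `test_gate` above exercises via manual polling, written with plain `await`s in an assumed async context.)

```rust
async fn gate_semantics_demo() {
    let (txs, mut rxs) = channels::<u32>(2);

    txs[0].send(1).await.unwrap(); // channel 1 still empty => gate open
    txs[1].send(2).await.unwrap(); // all channels now non-empty => gate closes

    // a further send would park until some channel drains...
    let parked = txs[0].send(3);
    // ...so drain channel 0, which reopens the gate...
    assert_eq!(rxs[0].recv().await, Some(1));
    // ...and the parked send completes
    parked.await.unwrap();
}
```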
diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition/mod.rs
similarity index 88%
rename from datafusion/core/src/physical_plan/repartition.rs
rename to datafusion/core/src/physical_plan/repartition/mod.rs
index 451b0fba4b13..1d0f1fe5cc5d 100644
--- a/datafusion/core/src/physical_plan/repartition.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -26,6 +26,7 @@ use std::{any::Any, vec};
 use crate::error::{DataFusionError, Result};
 use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use crate::physical_plan::hash_utils::create_hashes;
+use crate::physical_plan::repartition::distributor_channels::channels;
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
 };
@@ -34,7 +35,8 @@ use arrow::datatypes::SchemaRef;
 use arrow::error::{ArrowError, Result as ArrowResult};
 use arrow::record_batch::RecordBatch;
 use log::debug;
-use tokio_stream::wrappers::UnboundedReceiverStream;
+
+use self::distributor_channels::{DistributionReceiver, DistributionSender};
 
 use super::common::{AbortOnDropMany, AbortOnDropSingle};
 use super::expressions::PhysicalSortExpr;
@@ -44,12 +46,13 @@ use super::{RecordBatchStream, SendableRecordBatchStream};
 use crate::execution::context::TaskContext;
 use datafusion_physical_expr::PhysicalExpr;
 use futures::stream::Stream;
-use futures::StreamExt;
+use futures::{FutureExt, StreamExt};
 use hashbrown::HashMap;
 use parking_lot::Mutex;
-use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;
 
+mod distributor_channels;
+
 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
 type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
 
@@ -61,8 +64,8 @@ struct RepartitionExecState {
     channels: HashMap<
         usize,
         (
-            UnboundedSender<MaybeBatch>,
-            UnboundedReceiver<MaybeBatch>,
+            DistributionSender<MaybeBatch>,
+            DistributionReceiver<MaybeBatch>,
             SharedMemoryReservation,
         ),
     >,
@@ -132,67 +135,92 @@ impl BatchPartitioner {
     where
         F: FnMut(usize, RecordBatch) -> Result<()>,
     {
-        match &mut self.state {
-            BatchPartitionerState::RoundRobin {
-                num_partitions,
-                next_idx,
-            } => {
-                let idx = *next_idx;
-                *next_idx = (*next_idx + 1) % *num_partitions;
-                f(idx, batch)?;
-            }
-            BatchPartitionerState::Hash {
-                random_state,
-                exprs,
-                num_partitions: partitions,
-                hash_buffer,
-            } => {
-                let mut timer = self.timer.timer();
-
-                let arrays = exprs
-                    .iter()
-                    .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
-                    .collect::<Result<Vec<_>>>()?;
+        self.partition_iter(batch)?.try_for_each(|res| match res {
+            Ok((partition, batch)) => f(partition, batch),
+            Err(e) => Err(e),
+        })
+    }
 
-                hash_buffer.clear();
-                hash_buffer.resize(batch.num_rows(), 0);
+    /// Actual implementation of [`partition`](Self::partition).
+    ///
+    /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
+    /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
+    /// this (so we don't need to clone the entire implementation).
+    fn partition_iter(
+        &mut self,
+        batch: RecordBatch,
+    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
+        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
+            match &mut self.state {
+                BatchPartitionerState::RoundRobin {
+                    num_partitions,
+                    next_idx,
+                } => {
+                    let idx = *next_idx;
+                    *next_idx = (*next_idx + 1) % *num_partitions;
+                    Box::new(std::iter::once(Ok((idx, batch))))
+                }
+                BatchPartitionerState::Hash {
+                    random_state,
+                    exprs,
+                    num_partitions: partitions,
+                    hash_buffer,
+                } => {
+                    let timer = self.timer.timer();
+
+                    let arrays = exprs
+                        .iter()
+                        .map(|expr| {
+                            Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))
+                        })
+                        .collect::<Result<Vec<_>>>()?;
 
-                create_hashes(&arrays, random_state, hash_buffer)?;
+                    hash_buffer.clear();
+                    hash_buffer.resize(batch.num_rows(), 0);
 
-                let mut indices: Vec<_> = (0..*partitions)
-                    .map(|_| UInt64Builder::with_capacity(batch.num_rows()))
-                    .collect();
+                    create_hashes(&arrays, random_state, hash_buffer)?;
 
-                for (index, hash) in hash_buffer.iter().enumerate() {
-                    indices[(*hash % *partitions as u64) as usize]
-                        .append_value(index as u64);
-                }
+                    let mut indices: Vec<_> = (0..*partitions)
+                        .map(|_| UInt64Builder::with_capacity(batch.num_rows()))
+                        .collect();
 
-                for (partition, mut indices) in indices.into_iter().enumerate() {
-                    let indices = indices.finish();
-                    if indices.is_empty() {
-                        continue;
+                    for (index, hash) in hash_buffer.iter().enumerate() {
+                        indices[(*hash % *partitions as u64) as usize]
+                            .append_value(index as u64);
                     }
 
-                    // Produce batches based on indices
-                    let columns = batch
-                        .columns()
-                        .iter()
-                        .map(|c| {
-                            arrow::compute::take(c.as_ref(), &indices, None)
-                                .map_err(DataFusionError::ArrowError)
+                    let it = indices
+                        .into_iter()
+                        .enumerate()
+                        .filter_map(|(partition, mut indices)| {
+                            let indices = indices.finish();
+                            (!indices.is_empty()).then_some((partition, indices))
                         })
-                        .collect::<Result<Vec<ArrayRef>>>()?;
-
-                    let batch = RecordBatch::try_new(batch.schema(), columns).unwrap();
-
-                    timer.stop();
-                    f(partition, batch)?;
-                    timer.restart();
+                        .map(move |(partition, indices)| {
+                            // Produce batches based on indices
+                            let columns = batch
+                                .columns()
+                                .iter()
+                                .map(|c| {
+                                    arrow::compute::take(c.as_ref(), &indices, None)
+                                        .map_err(DataFusionError::ArrowError)
+                                })
+                                .collect::<Result<Vec<ArrayRef>>>()?;
+
+                            let batch =
+                                RecordBatch::try_new(batch.schema(), columns).unwrap();
+
+                            // bind timer so it drops w/ this iterator
+                            let _ = &timer;
+
+                            Ok((partition, batch))
+                        });
+
+                    Box::new(it)
                 }
-            }
-        }
-        Ok(())
+            };
+
+        Ok(it)
     }
 }
 
@@ -337,22 +365,15 @@ impl ExecutionPlan for RepartitionExec {
         // if this is the first partition to be invoked then we need to set up initial state
         if state.channels.is_empty() {
             // create one channel per *output* partition
-            for partition in 0..num_output_partitions {
-                // Note that this operator uses unbounded channels to avoid deadlocks because
-                // the output partitions can be read in any order and this could cause input
-                // partitions to be blocked when sending data to output UnboundedReceivers that are not
-                // being read yet. This may cause high memory usage if the next operator is
-                // reading output partitions in order rather than concurrently. One workaround
-                // for this would be to add spill-to-disk capabilities.
-                let (sender, receiver) =
-                    mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
+            // note we use a custom channel that ensures there is always data for each receiver
+            // but limits the amount of buffering if required.
+            let (txs, rxs) = channels(num_output_partitions);
+            for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
                 let reservation = Arc::new(Mutex::new(
                     MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
                         .register(context.memory_pool()),
                 ));
-                state
-                    .channels
-                    .insert(partition, (sender, receiver, reservation));
+                state.channels.insert(partition, (tx, rx, reservation));
             }
 
             // launch one async task per *input* partition
@@ -407,7 +428,7 @@ impl ExecutionPlan for RepartitionExec {
                 num_input_partitions,
                 num_input_partitions_processed: 0,
                 schema: self.input.schema(),
-                input: UnboundedReceiverStream::new(rx),
+                input: rx,
                 drop_helper: Arc::clone(&state.abort_helper),
                 reservation,
             }))
@@ -460,7 +481,10 @@ impl RepartitionExec {
     async fn pull_from_input(
        input: Arc<dyn ExecutionPlan>,
        i: usize,
-        mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
+        mut txs: HashMap<
+            usize,
+            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
+        >,
        partitioning: Partitioning,
        r_metrics: RepartitionMetrics,
        context: Arc<TaskContext>,
@@ -487,23 +511,23 @@ impl RepartitionExec {
                 None => break,
             };
 
-            partitioner.partition(batch, |partition, partitioned| {
-                let size = partitioned.get_array_memory_size();
+            for res in partitioner.partition_iter(batch)? {
+                let (partition, batch) = res?;
+                let size = batch.get_array_memory_size();
 
                 let timer = r_metrics.send_time.timer();
                 // if there is still a receiver, send to it
                 if let Some((tx, reservation)) = txs.get_mut(&partition) {
                     reservation.lock().try_grow(size)?;
 
-                    if tx.send(Some(Ok(partitioned))).is_err() {
+                    if tx.send(Some(Ok(batch))).await.is_err() {
                         // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                         reservation.lock().shrink(size);
                         txs.remove(&partition);
                     }
                 }
                 timer.done();
-                Ok(())
-            })?;
+            }
         }
 
         Ok(())
@@ -516,7 +540,7 @@ impl RepartitionExec {
     /// channels.
     async fn wait_for_task(
         input_task: AbortOnDropSingle<Result<()>>,
-        txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+        txs: HashMap<usize, DistributionSender<Option<ArrowResult<RecordBatch>>>>,
     ) {
         // wait for completion, and propagate error
         // note we ignore errors on send (.ok) as that means the receiver has already shutdown.
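(Sketch of the notification pattern implemented by the hunks below, reduced to its shape; the real code wraps errors in `ArrowError`/`DataFusionError`, and a plain `String` stands in here.)

```rust
async fn notify_outputs(
    txs: Vec<DistributionSender<Option<Result<RecordBatch, String>>>>,
    result: Result<(), String>,
) {
    for tx in txs {
        let msg = match &result {
            Ok(()) => None,                 // input finished: no more data
            Err(e) => Some(Err(e.clone())), // forward error to every output
        };
        // `send` is now a future, so completion/error propagation happens in
        // async context; errors are ignored (receiver already shut down).
        tx.send(msg).await.ok();
    }
}
```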
@@ -532,7 +556,7 @@ impl RepartitionExec {
                         Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                     ),
                 )));
-                tx.send(Some(err)).ok();
+                tx.send(Some(err)).await.ok();
             }
         }
         // Error from running input task
@@ -542,14 +566,14 @@ impl RepartitionExec {
             for (_, tx) in txs {
                 // wrap it because need to send error to all output partitions
                 let err = Err(ArrowError::ExternalError(Box::new(e.clone())));
-                tx.send(Some(err)).ok();
+                tx.send(Some(err)).await.ok();
             }
         }
         // Input task completed successfully
         Ok(Ok(())) => {
             // notify each output partition that this input partition has no more data
             for (_, tx) in txs {
-                tx.send(None).ok();
+                tx.send(None).await.ok();
             }
         }
     }
@@ -567,7 +591,7 @@ struct RepartitionStream {
     schema: SchemaRef,
 
     /// channel containing the repartitioned batches
-    input: UnboundedReceiverStream<Option<ArrowResult<RecordBatch>>>,
+    input: DistributionReceiver<MaybeBatch>,
 
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
@@ -585,7 +609,7 @@ impl Stream for RepartitionStream {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         loop {
-            match self.input.poll_next_unpin(cx) {
+            match self.input.recv().poll_unpin(cx) {
                 Poll::Ready(Some(Some(v))) => {
                     if let Ok(batch) = &v {
                         self.reservation
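(The `Stream` impl above creates a fresh `RecvFuture` on every `poll_next` call. That is sound because a Pending poll parks the waker in the shared channel state, which outlives any individual future instance. A condensed sketch of the pattern with a hypothetical stream type, assuming `DistributionReceiver` is in scope:)

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{FutureExt, Stream};

struct MyStream {
    input: DistributionReceiver<RecordBatch>,
}

impl Stream for MyStream {
    type Item = RecordBatch;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<RecordBatch>> {
        // a fresh future per poll; its waker, if registered, lives in the channel
        self.input.recv().poll_unpin(cx)
    }
}
```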