diff --git a/datafusion/core/src/physical_plan/repartition/distributor_channels.rs b/datafusion/core/src/physical_plan/repartition/distributor_channels.rs index 1eedf0d707be..519218f2c236 100644 --- a/datafusion/core/src/physical_plan/repartition/distributor_channels.rs +++ b/datafusion/core/src/physical_plan/repartition/distributor_channels.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Special channel construction to distribute data from varios inputs into N outputs. +//! Special channel construction to distribute data from various inputs into N outputs +//! minimizing buffering but preventing deadlocks when repartitioning //! //! # Design //! @@ -106,9 +107,10 @@ impl std::error::Error for SendError {} /// /// This handle can be cloned. All clones will write into the same channel. Dropping the last sender will close the /// channel. In this case, the [receiver](DistributionReceiver) will still be able to poll the remaining data, but will -/// receiver `None` afterwards. +/// receive `None` afterwards. #[derive(Debug)] pub struct DistributionSender { + /// To prevent lock inversion / deadlock, channel lock is always acquired prior to gate lock channel: SharedChannel, gate: SharedGate, } @@ -185,6 +187,7 @@ impl<'a, T> Future for SendFuture<'a, T> { let mut guard_gate = this.gate.lock(); // does ANY receiver need data? 
+ // if so, allow sender to create another if guard_gate.empty_channels == 0 { guard_gate .send_wakers diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs index 90e063e94816..1d0f1fe5cc5d 100644 --- a/datafusion/core/src/physical_plan/repartition/mod.rs +++ b/datafusion/core/src/physical_plan/repartition/mod.rs @@ -365,6 +365,8 @@ impl ExecutionPlan for RepartitionExec { // if this is the first partition to be invoked then we need to set up initial state if state.channels.is_empty() { // create one channel per *output* partition + // note we use a custom channel that ensures there is always data for each receiver + // but limits the amount of buffering if required. let (txs, rxs) = channels(num_output_partitions); for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { let reservation = Arc::new(Mutex::new(