Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/repartition/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This file implements the [`RepartitionExec`] operator, which maps N input
//! partitions to M output partitions based on a partitioning scheme, optionally
//! maintaining the order of the input rows in the output.

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use super::common::SharedMemoryReservation;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::hash_utils::create_hashes;
use crate::metrics::BaselineMetrics;
use crate::repartition::distributor_channels::{
    channels, partition_aware_channels, DistributionReceiver, DistributionSender,
};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::stream::RecordBatchStreamAdapter;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

use arrow::datatypes::{SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_array::{PrimitiveArray, RecordBatchOptions};
use datafusion_common::utils::{take_arrays, transpose};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr};

use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt};
use hashbrown::HashMap;
use log::trace;
use parking_lot::Mutex;

mod distributor_channels;

type MaybeBatch = Option<Result<RecordBatch>>;
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

/// Inner state of [`RepartitionExec`].
#[derive(Debug)]
struct RepartitionExecState {
    /// Channels for sending batches from input partitions to output partitions.
    /// Key is the partition number.
    channels: HashMap<
        usize,
        (
            InputPartitionsToCurrentPartitionSender,
            InputPartitionsToCurrentPartitionReceiver,
            SharedMemoryReservation,
        ),
    >,

    /// Helper that ensures that the background job is killed once it is no longer needed.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}

impl RepartitionExecState {
    fn new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        metrics: ExecutionPlanMetricsSet,
        preserve_order: bool,
        name: String,
        context: Arc<TaskContext>,
    ) -> Self {
        let num_input_partitions = input.output_partitioning().partition_count();
        let num_output_partitions = partitioning.partition_count();

        let (txs, rxs) = if preserve_order {
            let (txs, rxs) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
            let txs = transpose(txs);
            let rxs = transpose(rxs);
            (txs, rxs)
        } else {
            // create one channel per *output* partition
            // note we use a custom channel that ensures there is always data for each receiver
            // but limits the amount of buffering if required.
            let (txs, rxs) = channels(num_output_partitions);
            // Clone the sender for each input partition
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{}[{partition}]", name))
                    .register(context.memory_pool()),
            ));
            channels.insert(partition, (tx, rx, reservation));
        }

        // launch one async task per *input* partition
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for i in 0..num_input_partitions {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, (tx, _rx, reservation))| {
                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
                })
                .collect();

            let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics);

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                Arc::clone(&input),
                i,
                txs.clone(),
                partitioning.clone(),
                r_metrics,
                Arc::clone(&context),
            ));

            // In a separate task, wait for each input to be done
            // (and pass along any errors, including panic!s)
            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
                input_task,
                txs.into_iter()
                    .map(|(partition, (tx, _reservation))| (partition, tx))
                    .collect(),
            ));
            spawned_tasks.push(wait_for_task);
        }

        Self {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        }
    }
}

/// Lazily initialized state
///
/// Note that the state is initialized ONCE for all partitions by a single task (thread).
/// This may take a short while. It is also likely that multiple threads call
/// `execute` at the same time, because we have just started "target partitions"
/// tasks, a number commonly set to the number of CPU cores, and they all call
/// `execute` at the same time.
///
/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
/// in a futex lock but instead allow other threads to do something useful.
///
/// Uses a parking_lot `Mutex` to control other accesses, as they are very short duration
/// (e.g. removing channels on completion) where the overhead of `await` is not warranted.
type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;

/// A utility that can be used to partition batches based on [`Partitioning`]
pub struct BatchPartitioner {
    state: BatchPartitionerState,
    timer: metrics::Time,
}

enum BatchPartitionerState {
    Hash {
        random_state: ahash::RandomState,
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        hash_buffer: Vec<u64>,
    },
    RoundRobin {
        num_partitions: usize,
        next_idx: usize,
    },
}

impl BatchPartitioner {
    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`]
    ///
    /// The time spent repartitioning will be recorded to `timer`
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                // Use fixed random hash
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
        };

        Ok(Self { state, timer })
    }

    /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`]es
    /// based on the [`Partitioning`] specified on construction
    ///
    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
    /// partition index. Any error returned by `f` will be immediately returned by this
    /// function without attempting to publish further [`RecordBatch`]es
    ///
    /// The time spent repartitioning, not including time spent in `f`, will be recorded
    /// to the [`metrics::Time`] provided on construction
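    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative only; it assumes `batch` is an
    /// existing [`RecordBatch`] and that a default `metrics::Time` is an
    /// acceptable timing sink):
    ///
    /// ```ignore
    /// // Distribute `batch` round-robin across 4 output partitions,
    /// // handling each partitioned batch in the closure.
    /// let mut partitioner = BatchPartitioner::try_new(
    ///     Partitioning::RoundRobinBatch(4),
    ///     metrics::Time::default(), // assumption: any Time instance works here
    /// )?;
    /// partitioner.partition(batch, |partition, partitioned| {
    ///     println!("partition {partition}: {} rows", partitioned.num_rows());
    ///     Ok(())
    /// })?;
    /// ```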
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?.try_for_each(|res| match res {
            Ok((partition, batch)) => f(partition, batch),
            Err(e) => Err(e),
        })
    }

    /// Actual implementation of [`partition`](Self::partition).
    ///
    /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
    /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
    /// this (so we don't need to clone the entire implementation).
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    random_state,
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                } => {
                    // Tracking time required for distributing indexes across output partitions
                    let timer = self.timer.timer();

                    let arrays = exprs
                        .iter()
                        .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
                        .collect::<Result<Vec<_>>>()?;

                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);

                    create_hashes(&arrays, random_state, hash_buffer)?;

                    let mut indices: Vec<_> = (0..*partitions)
                        .map(|_| Vec::with_capacity(batch.num_rows()))
                        .collect();

                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
                    }

                    // Finished building index-arrays for output partitions
                    timer.done();

                    // Borrowing partitioner timer to prevent moving `self` to closure
                    let partitioner_timer = &self.timer;
                    let it = indices
                        .into_iter()
                        .enumerate()
                        .filter_map(|(partition, indices)| {
                            let indices: PrimitiveArray<UInt32Type> = indices.into();
                            (!indices.is_empty()).then_some((partition, indices))
                        })
                        .map(move |(partition, indices)| {
                            // Tracking time required for repartitioned batches construction
                            let _timer = partitioner_timer.timer();

                            // Produce batches based on indices
                            let columns = take_arrays(batch.columns(), &indices)?;

                            let mut options = RecordBatchOptions::new();
                            options = options.with_row_count(Some(indices.len()));
                            let batch = RecordBatch::try_new_with_options(
                                batch.schema(),
                                columns,
                                &options,
                            )
                            .unwrap();

                            Ok((partition, batch))
                        });

                    Box::new(it)
                }
            };

        Ok(it)
    }

    // return the number of output partitions
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}

/// Maps `N` input partitions to `M` output partitions based on a
/// [`Partitioning`] scheme.
///
/// # Background
///
/// DataFusion, like most other commercial systems, with the
/// notable exception of DuckDB, uses the "Exchange Operator" based
/// approach to parallelism which works well in practice given
/// sufficient care in implementation.
///
/// DataFusion's planner picks the target number of partitions and
/// then `RepartitionExec` redistributes [`RecordBatch`]es to that number
/// of output partitions.
///
/// For example, given `target_partitions=3` (trying to use 3 cores)
/// but scanning an input with 2 partitions, `RepartitionExec` can be
/// used to get 3 even streams of `RecordBatch`es
///
///
///```text
///        ▲                  ▲                  ▲
///        │                  │                  │
///        │                  │                  │
///        │                  │                  │
///┌───────────────┐  ┌───────────────┐  ┌───────────────┐
///│    GroupBy    │  │    GroupBy    │  │    GroupBy    │
///│   (Partial)   │  │   (Partial)   │  │   (Partial)   │
///└───────────────┘  └───────────────┘  └───────────────┘
///        ▲                  ▲                  ▲
///        └──────────────────┼──────────────────┘
///                           │
///              ┌─────────────────────────┐
///              │     RepartitionExec     │
///              │   (hash/round robin)    │
///              └─────────────────────────┘
///                         ▲   ▲
///             ┌───────────┘   └───────────┐
///             │                           │
///             │                           │
///        .─────────.                 .─────────.
///     ,─'           '─.           ,─'           '─.
///    ;      Input      :         ;      Input      :
///    :   Partition 0   ;         :   Partition 1   ;
///     ╲               ╱           ╲               ╱
///      '─.         ,─'             '─.         ,─'
///         `───────'                   `───────'
///```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// all output partitions and inputs are not polled again.
///
/// # Output Ordering
///
/// If more than one stream is being repartitioned, the output will be some
/// arbitrary interleaving (and thus unordered) unless
/// [`Self::with_preserve_order`] specifies otherwise.
///
/// # Footnote
///
/// The "Exchange Operator" was first described in the 1989 paper
/// [Encapsulation of parallelism in the Volcano query processing
/// system
/// Paper](https://w6113.github.io/files/papers/volcanoparallelism-89.pdf)
/// which uses the term "Exchange" for the concept of repartitioning
/// data across threads.
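///
/// # Example
///
/// A minimal sketch, assuming `input` is some `Arc<dyn ExecutionPlan>` with
/// fewer partitions than desired:
///
/// ```ignore
/// // Redistribute the input's batches across 8 output partitions
/// // using round robin (`input` is hypothetical here).
/// let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8))?;
/// assert_eq!(exec.partitioning().partition_count(), 8);
/// ```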
#[derive(Debug)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the first output stream is created.
    state: LazyState,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Boolean flag to decide whether to preserve ordering. If true, the
    /// operator acts as `SortPreservingRepartitionExec`; if false, as plain
    /// `RepartitionExec`.
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos to execute child operator and fetch batches
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos
    repartition_time: metrics::Time,
    /// Time in nanos for sending resulting batches to channels.
    ///
    /// One metric per output partition.
    send_time: Vec<metrics::Time>,
}

impl RepartitionMetrics {
    pub fn new(
        input_partition: usize,
        num_output_partitions: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        // Time in nanos to execute child operator and fetch batches
        let fetch_time =
            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);

        // Time in nanos to perform repartitioning
        let repartition_time =
            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);

        // Time in nanos for sending resulting batches to channels
        let send_time = (0..num_output_partitions)
            .map(|output_partition| {
                let label =
                    metrics::Label::new("outputPartition", output_partition.to_string());
                MetricBuilder::new(metrics)
                    .with_label(label)
                    .subset_time("send_time", input_partition)
            })
            .collect();

        Self {
            fetch_time,
            repartition_time,
            send_time,
        }
    }
}

impl RepartitionExec {
    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Partitioning scheme to use
    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    /// Get preserve_order flag of the RepartitionExecutor
    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    /// Get name used to display this Exec
    pub fn name(&self) -> &str {
        "RepartitionExec"
    }
}

impl DisplayAs for RepartitionExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    self.input.output_partitioning().partition_count()
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(
                        f,
                        ", sort_exprs={}",
                        PhysicalSortExpr::format_list(sort_exprs)
                    )?;
                }
                Ok(())
            }
        }
    }
}

impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let lazy_state = Arc::clone(&self.state);
        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.preserve_order;
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        // Get existing ordering to use for merging
        let sort_exprs = self.sort_exprs().unwrap_or(&[]).to_owned();

        let stream = futures::stream::once(async move {
            let num_input_partitions = input.output_partitioning().partition_count();

            let input_captured = Arc::clone(&input);
            let metrics_captured = metrics.clone();
            let name_captured = name.clone();
            let context_captured = Arc::clone(&context);
            let state = lazy_state
                .get_or_init(|| async move {
                    Mutex::new(RepartitionExecState::new(
                        input_captured,
                        partitioning,
                        metrics_captured,
                        preserve_order,
                        name_captured,
                        context_captured,
                    ))
                })
                .await;

            // lock scope
            let (mut rx, reservation, abort_helper) = {
                // lock mutexes
                let mut state = state.lock();

                // now return stream for the specified *output* partition which will
                // read from the channel
                let (_tx, rx, reservation) = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (rx, reservation, Arc::clone(&state.abort_helper))
            };

            trace!(
                "Before returning stream in {}::execute for partition: {}",
                name,
                partition
            );

            if preserve_order {
                // Store streams from all the input partitions:
                let input_streams = rx
                    .into_iter()
                    .map(|receiver| {
                        Box::pin(PerPartitionStream {
                            schema: Arc::clone(&schema_captured),
                            receiver,
                            drop_helper: Arc::clone(&abort_helper),
                            reservation: Arc::clone(&reservation),
                        }) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.

                // Merge streams (while preserving ordering) coming from
                // input partitions to this partition:
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{}[Merge {partition}]", name))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs)
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .build()
            } else {
                Ok(Box::pin(RepartitionStream {
                    num_input_partitions,
                    num_input_partitions_processed: 0,
                    schema: input.schema(),
                    input: rx.swap_remove(0),
                    drop_helper: abort_helper,
                    reservation,
                }) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.statistics()
    }
}

impl RepartitionExec {
    /// Create a new RepartitionExec, that produces output `partitioning`, and
    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
    /// for more details)
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache =
            Self::compute_properties(&input, partitioning.clone(), preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        // We preserve ordering when repartition is order preserving variant or input partitioning is 1
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        // Equivalence Properties
        let mut eq_properties = input.equivalence_properties().clone();
        // If the ordering is lost, reset the ordering equivalence class:
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // When there is more than one input partition, the partitions will be
        // fused at the output. Therefore, remove per-partition constants.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        // Equivalence Properties
        let eq_properties = Self::eq_properties_helper(input, preserve_order);

        PlanProperties::new(
            eq_properties,          // Equivalence Properties
            partitioning,           // Output Partitioning
            input.execution_mode(), // Execution Mode
        )
    }

    /// Specify if this repartitioning operation should preserve the order of
    /// rows from its input when producing output. Preserving order is more
    /// expensive at runtime, so should only be set if the output of this
    /// operator can take advantage of it.
    ///
    /// If the input is not ordered, or has only one partition, this is a no op,
    /// and the node remains a `RepartitionExec`.
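    ///
    /// A hedged sketch, assuming `input` is an ordered plan with more than one
    /// partition (otherwise this is a no op, as noted above):
    ///
    /// ```ignore
    /// // Order-preserving repartition: the merged output keeps the input's sort order.
    /// let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2))?
    ///     .with_preserve_order();
    /// ```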
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
                // If the input isn't ordered, there is no ordering to preserve
                self.input.output_ordering().is_some() &&
                // if there is only one input partition, merging is not required
                // to maintain order
                self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// Return the sort expressions that are used to merge
    fn sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Pulls data from the specified input plan, feeding it to the
    /// output partitions based on the desired partitioning
    ///
    /// txs hold the output sending channels for each output partition
    async fn pull_from_input(
        input: Arc<dyn ExecutionPlan>,
        partition: usize,
        mut output_channels: HashMap<
            usize,
            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
        >,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
        context: Arc<TaskContext>,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;

        // execute the child operator
        let timer = metrics.fetch_time.timer();
        let mut stream = input.execute(partition, context)?;
        timer.done();

        // While there are still outputs to send to, keep pulling inputs
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // fetch the next batch
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // if there is still a receiver, send to it
                if let Some((tx, reservation)) = output_channels.get_mut(&partition) {
                    reservation.lock().try_grow(size)?;

                    if tx.send(Some(Ok(batch))).await.is_err() {
                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                        reservation.lock().shrink(size);
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // If the input stream is endless, we may spin forever and
            // never yield back to tokio.  See
            // https://github.com/apache/datafusion/issues/5278.
            //
            // However, yielding on every batch causes a bottleneck
            // when running with multiple cores. See
            // https://github.com/apache/datafusion/issues/6290
            //
            // Thus, heuristically yield after producing num_partition
            // batches
            //
            // In round robin this is ideal as each input will get a
            // new batch. In hash partitioning it may yield too often
            // on uneven distributions even if some partition can not
            // make progress, but parallelism is going to be limited
            // in that case anyways
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Waits for `input_task` which is consuming one of the inputs to
    /// complete. Upon each successful completion, sends a `None` to
    /// each of the output tx channels to signal one of the inputs is
    /// complete. Upon error, propagates the errors to all output tx
    /// channels.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        // wait for completion, and propagate error
        // note we ignore errors on send (.ok) as that means the receiver has already shut down.

        match input_task.join().await {
            // Error in joining task
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Error from running input task
            Ok(Err(e)) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    // wrap it because we need to send the error to all output partitions
                    let err = Err(DataFusionError::External(Box::new(Arc::clone(&e))));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Input task completed successfully
            Ok(Ok(())) => {
                // notify each output partition that this input partition has no more data
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}

struct RepartitionStream {
    /// Number of input partitions that will be sending batches to this output channel
    num_input_partitions: usize,

    /// Number of input partitions that have finished sending batches to this output channel
    num_input_partitions_processed: usize,

    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    input: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    #[allow(dead_code)]
    drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,
}

impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.input.recv().poll_unpin(cx) {
                Poll::Ready(Some(Some(v))) => {
                    if let Ok(batch) = &v {
                        self.reservation
                            .lock()
                            .shrink(batch.get_array_memory_size());
                    }

                    return Poll::Ready(Some(v));
                }
                Poll::Ready(Some(None)) => {
                    self.num_input_partitions_processed += 1;

                    if self.num_input_partitions == self.num_input_partitions_processed {
                        // all input partitions have finished sending batches
                        return Poll::Ready(None);
                    } else {
                        // other partitions still have data to send
                        continue;
                    }
                }
                Poll::Ready(None) => {
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}

impl RecordBatchStream for RepartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// This struct converts a receiver to a stream.
/// Receiver receives data on an SPSC channel.
struct PerPartitionStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    #[allow(dead_code)]
    drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,
}

impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.receiver.recv().poll_unpin(cx) {
            Poll::Ready(Some(Some(v))) => {
                if let Ok(batch) = &v {
                    self.reservation
                        .lock()
                        .shrink(batch.get_array_memory_size());
                }
                Poll::Ready(Some(v))
            }
            Poll::Ready(Some(None)) => {
                // Input partition has finished sending batches
                Poll::Ready(None)
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl RecordBatchStream for PerPartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1009
1010
#[cfg(test)]
1011
mod tests {
1012
    use std::collections::HashSet;
1013
1014
    use super::*;
1015
    use crate::{
1016
        test::{
1017
            assert_is_pending,
1018
            exec::{
1019
                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
1020
                ErrorExec, MockExec,
1021
            },
1022
        },
1023
        {collect, expressions::col, memory::MemoryExec},
1024
    };
1025
1026
    use arrow::array::{ArrayRef, StringArray, UInt32Array};
1027
    use arrow::datatypes::{DataType, Field, Schema};
1028
    use datafusion_common::cast::as_string_array;
1029
    use datafusion_common::{arrow_datafusion_err, assert_batches_sorted_eq, exec_err};
1030
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1031
1032
    use tokio::task::JoinSet;
1033
1034
    #[tokio::test]
1035
1
    async fn one_to_many_round_robin() -> Result<()> {
1036
1
        // define input partitions
1037
1
        let schema = test_schema();
1038
1
        let partition = create_vec_batches(50);
1039
1
        let partitions = vec![partition];
1040
1
1041
1
        // repartition from 1 input to 4 output
1042
1
        let output_partitions =
1043
15
            
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4))1
.await
?0
;
1044
1
1045
1
        assert_eq!(4, output_partitions.len());
1046
1
        assert_eq!(13, output_partitions[0].len());
1047
1
        assert_eq!(13, output_partitions[1].len());
1048
1
        assert_eq!(12, output_partitions[2].len());
1049
1
        assert_eq!(12, output_partitions[3].len());
1050
1
1051
1
        Ok(())
1052
1
    }
1053
1054
    #[tokio::test]
1055
1
    async fn many_to_one_round_robin() -> Result<()> {
1056
1
        // define input partitions
1057
1
        let schema = test_schema();
1058
1
        let partition = create_vec_batches(50);
1059
1
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1060
1
1061
1
        // repartition from 3 input to 1 output
1062
1
        let output_partitions =
1063
153
            
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1))1
.await
?0
;
1064
1
1065
1
        assert_eq!(1, output_partitions.len());
1066
1
        assert_eq!(150, output_partitions[0].len());
1067
1
1068
1
        Ok(())
1069
1
    }
1070
1071
    #[tokio::test]
1072
1
    async fn many_to_many_round_robin() -> Result<()> {
1073
1
        // define input partitions
1074
1
        let schema = test_schema();
1075
1
        let partition = create_vec_batches(50);
1076
1
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1077
1
1078
1
        // repartition from 3 input to 5 output
1079
1
        let output_partitions =
1080
35
            
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5))1
.await
?0
;
1081
1
1082
1
        assert_eq!(5, output_partitions.len());
1083
1
        assert_eq!(30, output_partitions[0].len());
1084
1
        assert_eq!(30, output_partitions[1].len());
1085
1
        assert_eq!(30, output_partitions[2].len());
1086
1
        assert_eq!(30, output_partitions[3].len());
1087
1
        assert_eq!(30, output_partitions[4].len());
1088
1
1089
1
        Ok(())
1090
1
    }
1091
1092
    #[tokio::test]
1093
1
    async fn many_to_many_hash_partition() -> Result<()> {
1094
1
        // define input partitions
1095
1
        let schema = test_schema();
1096
1
        let partition = create_vec_batches(50);
1097
1
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];
1098
1
1099
1
        let output_partitions = repartition(
1100
1
            &schema,
1101
1
            partitions,
1102
1
            Partitioning::Hash(vec![col("c0", &schema)
?0
], 8),
1103
1
        )
1104
9
        .await
?0
;
1105
1
1106
1
        let total_rows: usize = output_partitions
1107
1
            .iter()
1108
600
            .map(|x| 
x.iter().map(8
|x| x.num_rows()
).sum::<usize>()8
)
1109
1
            .sum();
1110
1
1111
1
        assert_eq!(8, output_partitions.len());
1112
1
        assert_eq!(total_rows, 8 * 50 * 3);
1113
1
1114
1
        Ok(())
1115
1
    }
1116
1117
12
    fn test_schema() -> Arc<Schema> {
1118
12
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1119
12
    }
1120
1121
5
    async fn repartition(
1122
5
        schema: &SchemaRef,
1123
5
        input_partitions: Vec<Vec<RecordBatch>>,
1124
5
        partitioning: Partitioning,
1125
5
    ) -> Result<Vec<Vec<RecordBatch>>> {
1126
5
        let task_ctx = Arc::new(TaskContext::default());
1127
        // create physical plan
1128
5
        let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)
?0
;
1129
5
        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)
?0
;
1130
1131
        // execute and collect results
1132
5
        let mut output_partitions = vec![];
1133
23
        for i in 0..
exec.partitioning().partition_count()5
{
1134
            // execute this *output* partition and collect all batches
1135
23
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))
?0
;
1136
23
            let mut batches = vec![];
1137
1.12k
            while let Some(
result1.10k
) = stream.next().
await247
{
1138
1.10k
                batches.push(result
?0
);
1139
            }
1140
23
            output_partitions.push(batches);
1141
        }
1142
5
        Ok(output_partitions)
1143
5
    }
1144
1145
    #[tokio::test]
1146
1
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
1147
1
        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
1148
1
            SpawnedTask::spawn(async move {
1149
1
                // define input partitions
1150
1
                let schema = test_schema();
1151
1
                let partition = create_vec_batches(50);
1152
1
                let partitions =
1153
1
                    vec![partition.clone(), partition.clone(), partition.clone()];
1154
1
1155
1
                // repartition from 3 input to 5 output
1156
35
                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1157
1
            });
1158
1
1159
1
        let output_partitions = handle.join().await.unwrap().unwrap();
1160
1
1161
1
        assert_eq!(5, output_partitions.len());
1162
1
        assert_eq!(30, output_partitions[0].len());
1163
1
        assert_eq!(30, output_partitions[1].len());
1164
1
        assert_eq!(30, output_partitions[2].len());
1165
1
        assert_eq!(30, output_partitions[3].len());
1166
1
        assert_eq!(30, output_partitions[4].len());
1167
1
1168
1
        Ok(())
1169
1
    }
1170
1171
    #[tokio::test]
1172
1
    async fn unsupported_partitioning() {
1173
1
        let task_ctx = Arc::new(TaskContext::default());
1174
1
        // have to send at least one batch through to provoke error
1175
1
        let batch = RecordBatch::try_from_iter(vec![(
1176
1
            "my_awesome_field",
1177
1
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1178
1
        )])
1179
1
        .unwrap();
1180
1
1181
1
        let schema = batch.schema();
1182
1
        let input = MockExec::new(vec![Ok(batch)], schema);
1183
1
        // This generates an error (partitioning type not supported)
1184
1
        // but only after the plan is executed. The error should be
1185
1
        // returned and no results produced
1186
1
        let partitioning = Partitioning::UnknownPartitioning(1);
1187
1
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1188
1
        let output_stream = exec.execute(0, task_ctx).unwrap();
1189
1
1190
1
        // Expect that an error is returned
1191
1
        let result_string = crate::common::collect(output_stream)
1192
1
            .await
1193
1
            .unwrap_err()
1194
1
            .to_string();
1195
1
        assert!(
1196
1
            result_string
1197
1
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
1198
1
            
"actual: {result_string}"0
1199
1
        );
1200
1
    }
1201
1202
    #[tokio::test]
1203
1
    async fn error_for_input_exec() {
1204
1
        // This generates an error on a call to execute. The error
1205
1
        // should be returned and no results produced.
1206
1
1207
1
        let task_ctx = Arc::new(TaskContext::default());
1208
1
        let input = ErrorExec::new();
1209
1
        let partitioning = Partitioning::RoundRobinBatch(1);
1210
1
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1211
1
1212
1
        // Note: this should pass (the stream can be created) but the
1213
1
        // error when the input is executed should get passed back
1214
1
        let output_stream = exec.execute(0, task_ctx).unwrap();
1215
1
1216
1
        // Expect that an error is returned
1217
1
        let result_string = crate::common::collect(output_stream)
1218
1
            .await
1219
1
            .unwrap_err()
1220
1
            .to_string();
1221
1
        assert!(
1222
1
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
1223
1
            
"actual: {result_string}"0
1224
1
        );
1225
1
    }
1226
1227
    #[tokio::test]
1228
1
    async fn repartition_with_error_in_stream() {
1229
1
        let task_ctx = Arc::new(TaskContext::default());
1230
1
        let batch = RecordBatch::try_from_iter(vec![(
1231
1
            "my_awesome_field",
1232
1
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1233
1
        )])
1234
1
        .unwrap();
1235
1
1236
1
        // input stream returns one good batch and then one error. The
1237
1
        // error should be returned.
1238
1
        let err = exec_err!("bad data error");
1239
1
1240
1
        let schema = batch.schema();
1241
1
        let input = MockExec::new(vec![Ok(batch), err], schema);
1242
1
        let partitioning = Partitioning::RoundRobinBatch(1);
1243
1
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1244
1
1245
1
        // Note: this should pass (the stream can be created) but the
1246
1
        // error when the input is executed should get passed back
1247
1
        let output_stream = exec.execute(0, task_ctx).unwrap();
1248
1
1249
1
        // Expect that an error is returned
1250
1
        let result_string = crate::common::collect(output_stream)
1251
2
            .await
1252
1
            .unwrap_err()
1253
1
            .to_string();
1254
1
        assert!(
1255
1
            result_string.contains("bad data error"),
1256
1
            
"actual: {result_string}"0
1257
1
        );
1258
1
    }
1259
1260
    #[tokio::test]
1261
1
    async fn repartition_with_delayed_stream() {
1262
1
        let task_ctx = Arc::new(TaskContext::default());
1263
1
        let batch1 = RecordBatch::try_from_iter(vec![(
1264
1
            "my_awesome_field",
1265
1
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1266
1
        )])
1267
1
        .unwrap();
1268
1
1269
1
        let batch2 = RecordBatch::try_from_iter(vec![(
1270
1
            "my_awesome_field",
1271
1
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
1272
1
        )])
1273
1
        .unwrap();
1274
1
1275
1
        // The mock exec doesn't return immediately (instead it
1276
1
        // requires the input to wait at least once)
1277
1
        let schema = batch1.schema();
1278
1
        let expected_batches = vec![batch1.clone(), batch2.clone()];
1279
1
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
1280
1
        let partitioning = Partitioning::RoundRobinBatch(1);
1281
1
1282
1
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1283
1
1284
1
        let expected = vec![
1285
1
            "+------------------+",
1286
1
            "| my_awesome_field |",
1287
1
            "+------------------+",
1288
1
            "| foo              |",
1289
1
            "| bar              |",
1290
1
            "| frob             |",
1291
1
            "| baz              |",
1292
1
            "+------------------+",
1293
1
        ];
1294
1
1295
1
        assert_batches_sorted_eq!(&expected, &expected_batches);
1296
1
1297
1
        let output_stream = exec.execute(0, task_ctx).unwrap();
1298
3
        let 
batches1
=
crate::common::collect(output_stream)1
.await.unwrap();
1299
1
1300
1
        assert_batches_sorted_eq!(&expected, &batches);
1301
1
    }
1302
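With `Partitioning::RoundRobinBatch(1)` every input batch lands in the single output partition, which is why both batches show up in the collected result. A sketch of the round-robin routing rule itself; `round_robin_target` is a hypothetical helper, not DataFusion's internal code:

/// Hypothetical helper: the output partition a round-robin
/// repartitioner assigns to the i-th batch.
fn round_robin_target(batch_index: usize, num_output_partitions: usize) -> usize {
    batch_index % num_output_partitions
}

fn main() {
    // With one output partition, every batch maps to partition 0.
    assert!((0..8).all(|i| round_robin_target(i, 1) == 0));
    // With four output partitions, batches cycle 0, 1, 2, 3, 0, ...
    assert_eq!(round_robin_target(5, 4), 1);
}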
1303
    #[tokio::test]
1304
1
    async fn robin_repartition_with_dropping_output_stream() {
1305
1
        let task_ctx = Arc::new(TaskContext::default());
1306
1
        let partitioning = Partitioning::RoundRobinBatch(2);
1307
1
        // The barrier exec waits to be pinged
1308
1
        // (it requires the input to wait at least once)
1309
1
        let input = Arc::new(make_barrier_exec());
1310
1
1311
1
        // partition into two output streams
1312
1
        let exec = RepartitionExec::try_new(
1313
1
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
1314
1
            partitioning,
1315
1
        )
1316
1
        .unwrap();
1317
1
1318
1
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
1319
1
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1320
1
1321
1
        // now, purposely drop output stream 0
1322
1
        // *before* any outputs are produced
1323
1
        std::mem::drop(output_stream0);
1324
1
1325
1
        // Now, start sending input
1326
1
        let mut background_task = JoinSet::new();
1327
1
        background_task.spawn(async move {
1328
1
            input.wait().await;
1329
1
        });
1330
1
1331
1
        // output stream 1 should *not* error and should still receive its share of the input batches
1332
4
        let batches = crate::common::collect(output_stream1).await.unwrap();
1333
1
1334
1
        let expected = vec![
1335
1
            "+------------------+",
1336
1
            "| my_awesome_field |",
1337
1
            "+------------------+",
1338
1
            "| baz              |",
1339
1
            "| frob             |",
1340
1
            "| gaz              |",
1341
1
            "| grob             |",
1342
1
            "+------------------+",
1343
1
        ];
1344
1
1345
1
        assert_batches_sorted_eq!(&expected, &batches);
1346
1
    }
1347
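The property exercised here, that dropping one output must not poison the others, falls out of a channel-per-output design: a closed channel only fails its own sends. A minimal sketch with `std::sync::mpsc` standing in for this module's `DistributionSender` (an illustrative substitution, not the real wiring):

use std::sync::mpsc;

fn main() {
    let (tx0, rx0) = mpsc::channel::<&str>();
    let (tx1, rx1) = mpsc::channel::<&str>();

    // Consumer 0 goes away before any data is produced.
    drop(rx0);

    // The producer fans out, tolerating failures on closed channels.
    for (tx, batch) in [(&tx0, "baz"), (&tx1, "frob")] {
        let _ = tx.send(batch); // Err on the dropped receiver, not a panic
    }
    drop(tx1); // close the surviving channel so the iterator below ends

    // The surviving consumer still receives its data.
    assert_eq!(rx1.into_iter().collect::<Vec<_>>(), vec!["frob"]);
}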
1348
    #[tokio::test]
1349
    // As the hash results might be different on different platforms or
1350
    // with different compilers, we will compare the same execution with
1351
    // and without dropping the output stream.
1352
1
    async fn hash_repartition_with_dropping_output_stream() {
1353
1
        let task_ctx = Arc::new(TaskContext::default());
1354
1
        let partitioning = Partitioning::Hash(
1355
1
            vec![Arc::new(crate::expressions::Column::new(
1356
1
                "my_awesome_field",
1357
1
                0,
1358
1
            ))],
1359
1
            2,
1360
1
        );
1361
1
1362
1
        // We first collect the results without dropping the output stream.
1363
1
        let input = Arc::new(make_barrier_exec());
1364
1
        let exec = RepartitionExec::try_new(
1365
1
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
1366
1
            partitioning.clone(),
1367
1
        )
1368
1
        .unwrap();
1369
1
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1370
1
        let mut background_task = JoinSet::new();
1371
1
        background_task.spawn(async move {
1372
1
            input.wait().await;
1373
1
        });
1374
5
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();
1375
1
1376
1
        // run some checks on the result
1377
1
        let items_vec = str_batches_to_vec(&batches_without_drop);
1378
1
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
1379
1
        assert_eq!(items_vec.len(), items_set.len());
1380
1
        let source_str_set: HashSet<&str> =
1381
1
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
1382
1
                .iter()
1383
1
                .copied()
1384
1
                .collect();
1385
1
        assert_eq!(items_set.difference(&source_str_set).count(), 0);
1386
1
1387
1
        // Now do the same but dropping the stream before waiting for the barrier
1388
1
        let input = Arc::new(make_barrier_exec());
1389
1
        let exec = RepartitionExec::try_new(
1390
1
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
1391
1
            partitioning,
1392
1
        )
1393
1
        .unwrap();
1394
1
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
1395
1
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1396
1
        // now, purposely drop output stream 0
1397
1
        // *before* any outputs are produced
1398
1
        std::mem::drop(output_stream0);
1399
1
        let mut background_task = JoinSet::new();
1400
1
        background_task.spawn(async move {
1401
1
            input.wait().await;
1402
1
        });
1403
5
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
1404
1
1405
1
        assert_eq!(batches_without_drop, batches_with_drop);
1406
1
    }
1407
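The comment above the test explains the comparison strategy: hash partitioning assigns each row to `hash(key) % partitions`, and the concrete hash value may differ across platforms or compilers, so the test compares two runs instead of hard-coding buckets. A conceptual sketch using the standard library's `DefaultHasher` (DataFusion's real path goes through `create_hashes`; this only shows the bucketing idea):

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hypothetical helper: which of `partitions` buckets a key hashes into.
fn hash_partition(key: &str, partitions: u64) -> u64 {
    let mut hasher = DefaultHasher::new();
    key.hash(&mut hasher);
    hasher.finish() % partitions
}

fn main() {
    for key in ["foo", "bar", "frob", "baz"] {
        let p = hash_partition(key, 2);
        // Each key lands in exactly one bucket, stably within a run...
        assert!(p < 2);
        assert_eq!(p, hash_partition(key, 2));
    }
    // ...but the concrete bucket is an implementation detail, hence the
    // run-vs-run comparison in the test rather than fixed expectations.
}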
1408
1
    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
1409
1
        batches
1410
1
            .iter()
1411
3
            .flat_map(|batch| {
1412
3
                assert_eq!(batch.columns().len(), 1);
1413
3
                let string_array = as_string_array(batch.column(0))
1414
3
                    .expect("Unexpected type for repartitoned batch");
1415
3
1416
3
                string_array
1417
3
                    .iter()
1418
4
                    .map(|v| v.expect("Unexpected null"))
1419
3
                    .collect::<Vec<_>>()
1420
3
            })
1421
1
            .collect::<Vec<_>>()
1422
1
    }
1423
1424
    /// Create a BarrierExec that returns two partitions of two batches each
1425
3
    fn make_barrier_exec() -> BarrierExec {
1426
3
        let batch1 = RecordBatch::try_from_iter(vec![(
1427
3
            "my_awesome_field",
1428
3
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
1429
3
        )])
1430
3
        .unwrap();
1431
3
1432
3
        let batch2 = RecordBatch::try_from_iter(vec![(
1433
3
            "my_awesome_field",
1434
3
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
1435
3
        )])
1436
3
        .unwrap();
1437
3
1438
3
        let batch3 = RecordBatch::try_from_iter(vec![(
1439
3
            "my_awesome_field",
1440
3
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
1441
3
        )])
1442
3
        .unwrap();
1443
3
1444
3
        let batch4 = RecordBatch::try_from_iter(vec![(
1445
3
            "my_awesome_field",
1446
3
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
1447
3
        )])
1448
3
        .unwrap();
1449
3
1450
3
        // The barrier exec waits to be pinged
1451
3
        // (it requires the input to wait at least once)
1452
3
        let schema = batch1.schema();
1453
3
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
1454
3
    }
1455
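`BarrierExec` holds its batches back until the test calls `wait()`, letting a test arrange (or drop) its consumers before any data flows. The rendezvous underneath is the ordinary barrier pattern; a stdlib sketch with threads rather than the async fixture used here:

use std::sync::{Arc, Barrier};
use std::thread;

fn main() {
    // One producer plus the "pinger" (the test body) meet at the barrier.
    let barrier = Arc::new(Barrier::new(2));

    let producer = {
        let barrier = Arc::clone(&barrier);
        thread::spawn(move || {
            barrier.wait(); // blocks until the test says go
            "batch produced after the ping"
        })
    };

    // ... consumers could be set up or dropped here ...
    barrier.wait(); // ping: release the producer

    assert_eq!(producer.join().unwrap(), "batch produced after the ping");
}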
1456
    #[tokio::test]
1457
1
    async fn test_drop_cancel() -> Result<()> {
1458
1
        let task_ctx = Arc::new(TaskContext::default());
1459
1
        let schema =
1460
1
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1461
1
1462
1
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1463
1
        let refs = blocking_exec.refs();
1464
1
        let repartition_exec = Arc::new(RepartitionExec::try_new(
1465
1
            blocking_exec,
1466
1
            Partitioning::UnknownPartitioning(1),
1467
1
        )?);
1468
1
1469
1
        let fut = collect(repartition_exec, task_ctx);
1470
1
        let mut fut = fut.boxed();
1471
1
1472
1
        assert_is_pending(&mut fut);
1473
1
        drop(fut);
1474
1
        assert_strong_count_converges_to_zero(refs).await;
1475
1
1476
1
        Ok(())
1477
1
    }
1478
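`test_drop_cancel` asserts that dropping the collect future tears the plan down and releases every `Arc` it captured. That guarantee comes from Rust futures being inert values: dropping one drops everything it owns, whether or not it was ever polled to completion. A dependency-free sketch of the same property:

use std::sync::Arc;

fn main() {
    let resource = Arc::new(());
    let weak = Arc::downgrade(&resource);

    // The future owns `resource` from the moment it is created,
    // even though it is never polled.
    let fut = async move {
        let _held = resource;
        std::future::pending::<()>().await
    };
    assert_eq!(weak.strong_count(), 1);

    // Dropping the future drops everything it captured: the same
    // strong-count-converges-to-zero behavior the test checks.
    drop(fut);
    assert_eq!(weak.strong_count(), 0);
}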
1479
    #[tokio::test]
1480
1
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
1481
1
        let task_ctx = Arc::new(TaskContext::default());
1482
1
        let batch = RecordBatch::try_from_iter(vec![(
1483
1
            "a",
1484
1
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
1485
1
        )])
1486
1
        .unwrap();
1487
1
        let partitioning = Partitioning::Hash(
1488
1
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
1489
1
            2,
1490
1
        );
1491
1
        let schema = batch.schema();
1492
1
        let input = MockExec::new(vec![Ok(batch)], schema);
1493
1
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
1494
1
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
1495
1
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
1496
1
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
1497
1
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
1498
1
        assert!(batch0.is_empty() || batch1.is_empty());
1499
1
        Ok(())
1500
1
    }
1501
1502
    #[tokio::test]
1503
1
    async fn oom() -> Result<()> {
1504
1
        // define input partitions
1505
1
        let schema = test_schema();
1506
1
        let partition = create_vec_batches(50);
1507
1
        let input_partitions = vec![partition];
1508
1
        let partitioning = Partitioning::RoundRobinBatch(4);
1509
1
1510
1
        // set up context
1511
1
        let runtime = RuntimeEnvBuilder::default()
1512
1
            .with_memory_limit(1, 1.0)
1513
1
            .build_arc()?;
1514
1
1515
1
        let task_ctx = TaskContext::default().with_runtime(runtime);
1516
1
        let task_ctx = Arc::new(task_ctx);
1517
1
1518
1
        // create physical plan
1519
1
        let exec = MemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
1520
1
        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
1521
1
1522
1
        // pull partitions
1523
4
        for i in 0..exec.partitioning().partition_count() {
1524
4
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1525
4
            let err =
1526
4
                arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into());
1527
4
            let err = err.find_root();
1528
4
            assert!(
1529
4
                matches!(err, DataFusionError::ResourcesExhausted(_)),
1530
1
                "Wrong error type: {err}",
1531
1
            );
1532
1
        }
1533
1
1534
1
        Ok(())
1535
1
    }
1536
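The `oom` test wires a 1-byte memory limit into the runtime so the repartitioner's very first reservation fails with `ResourcesExhausted`. A hypothetical, minimal budget tracker showing the grow-or-error shape that a memory pool enforces (this is not DataFusion's `MemoryPool` API, just the idea):

/// Hypothetical budget tracker: growing past the limit is an error.
struct Budget {
    used: usize,
    limit: usize,
}

impl Budget {
    fn try_grow(&mut self, bytes: usize) -> Result<(), String> {
        if self.used + bytes > self.limit {
            Err(format!(
                "Resources exhausted: needed {bytes} bytes, only {} remain",
                self.limit - self.used
            ))
        } else {
            self.used += bytes;
            Ok(())
        }
    }
}

fn main() {
    // A 1-byte limit, like the test's RuntimeEnv, rejects any real batch.
    let mut budget = Budget { used: 0, limit: 1 };
    assert!(budget.try_grow(1).is_ok());
    assert!(budget.try_grow(1024).is_err());
}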
1537
    /// Create a vector of `n` cloned record batches
1538
6
    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
1539
6
        let batch = create_batch();
1540
300
        (0..n).map(|_| batch.clone()).collect()
1541
6
    }
1542
1543
    /// Create a record batch with a single UInt32 column of eight values
1544
6
    fn create_batch() -> RecordBatch {
1545
6
        let schema = test_schema();
1546
6
        RecordBatch::try_new(
1547
6
            schema,
1548
6
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
1549
6
        )
1550
6
        .unwrap()
1551
6
    }
1552
}
1553
1554
#[cfg(test)]
1555
mod test {
1556
    use arrow_schema::{DataType, Field, Schema, SortOptions};
1557
1558
    use datafusion_physical_expr::expressions::col;
1559
1560
    use crate::memory::MemoryExec;
1561
    use crate::union::UnionExec;
1562
1563
    use super::*;
1564
1565
    /// Asserts that the plan is as expected
1566
    ///
1567
    /// `$EXPECTED_PLAN_LINES`: the expected plan, as lines of text
1568
    /// `$PLAN`: the plan to be checked
1569
    ///
1570
    macro_rules! assert_plan {
1571
        ($EXPECTED_PLAN_LINES: expr,  $PLAN: expr) => {
1572
            let physical_plan = $PLAN;
1573
            let formatted = crate::displayable(&physical_plan).indent(true).to_string();
1574
            let actual: Vec<&str> = formatted.trim().lines().collect();
1575
1576
            let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES
1577
10
                .iter().map(|s| *s).collect();
1578
1579
            assert_eq!(
1580
                expected_plan_lines, actual,
1581
                "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n"
1582
            );
1583
        };
1584
    }
1585
1586
    #[tokio::test]
1587
1
    async fn test_preserve_order() -> Result<()> {
1588
1
        let schema = test_schema();
1589
1
        let sort_exprs = sort_exprs(&schema);
1590
1
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
1591
1
        let source2 = sorted_memory_exec(&schema, sort_exprs);
1592
1
        // output has multiple partitions, and is sorted
1593
1
        let union = UnionExec::new(vec![source1, source2]);
1594
1
        let exec =
1595
1
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
1596
1
                .unwrap()
1597
1
                .with_preserve_order();
1598
1
1599
1
        // Repartition should preserve order
1600
1
        let expected_plan = [
1601
1
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC",
1602
1
            "  UnionExec",
1603
1
            "    MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
1604
1
            "    MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
1605
1
        ];
1606
1
        assert_plan!(expected_plan, exec);
1607
1
        Ok(())
1608
1
    }
1609
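`with_preserve_order` only takes effect when there are multiple sorted input partitions, as in this plan: the operator then merges the partitions instead of interleaving them arbitrarily. The per-row idea is the classic merge of sorted sequences; a small sketch of that step (the real implementation is the streaming merge assembled via `StreamingMergeBuilder`):

/// Merge two already-sorted slices into one sorted vector.
fn merge_sorted(a: &[u32], b: &[u32]) -> Vec<u32> {
    let (mut i, mut j) = (0, 0);
    let mut out = Vec::with_capacity(a.len() + b.len());
    while i < a.len() && j < b.len() {
        if a[i] <= b[j] {
            out.push(a[i]);
            i += 1;
        } else {
            out.push(b[j]);
            j += 1;
        }
    }
    // One side is exhausted; append whatever remains of the other.
    out.extend_from_slice(&a[i..]);
    out.extend_from_slice(&b[j..]);
    out
}

fn main() {
    // Two sorted "partitions" merge into one globally sorted output.
    assert_eq!(merge_sorted(&[1, 3, 5], &[2, 3, 6]), vec![1, 2, 3, 3, 5, 6]);
}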
1610
    #[tokio::test]
1611
1
    async fn test_preserve_order_one_partition() -> Result<()> {
1612
1
        let schema = test_schema();
1613
1
        let sort_exprs = sort_exprs(&schema);
1614
1
        let source = sorted_memory_exec(&schema, sort_exprs);
1615
1
        // output is sorted, but has only a single partition, so no need to sort
1616
1
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))
1617
1
            .unwrap()
1618
1
            .with_preserve_order();
1619
1
1620
1
        // Repartition should not preserve order
1621
1
        let expected_plan = [
1622
1
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1",
1623
1
            "  MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC",
1624
1
        ];
1625
1
        assert_plan!(expected_plan, exec);
1626
1
        Ok(())
1627
1
    }
1628
1629
    #[tokio::test]
1630
1
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
1631
1
        let schema = test_schema();
1632
1
        let source1 = memory_exec(&schema);
1633
1
        let source2 = memory_exec(&schema);
1634
1
        // output has multiple partitions, but is not sorted
1635
1
        let union = UnionExec::new(vec![source1, source2]);
1636
1
        let exec =
1637
1
            RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10))
1638
1
                .unwrap()
1639
1
                .with_preserve_order();
1640
1
1641
1
        // Repartition should not preserve order, as there is no order to preserve
1642
1
        let expected_plan = [
1643
1
            "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2",
1644
1
            "  UnionExec",
1645
1
            "    MemoryExec: partitions=1, partition_sizes=[0]",
1646
1
            "    MemoryExec: partitions=1, partition_sizes=[0]",
1647
1
        ];
1648
1
        assert_plan!(expected_plan, exec);
1649
1
        Ok(())
1650
1
    }
1651
1652
3
    fn test_schema() -> Arc<Schema> {
1653
3
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
1654
3
    }
1655
1656
2
    fn sort_exprs(schema: &Schema) -> Vec<PhysicalSortExpr> {
1657
2
        let options = SortOptions::default();
1658
2
        vec![PhysicalSortExpr {
1659
2
            expr: col("c0", schema).unwrap(),
1660
2
            options,
1661
2
        }]
1662
2
    }
1663
1664
2
    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
1665
2
        Arc::new(MemoryExec::try_new(&[vec![]], Arc::clone(schema), None).unwrap())
1666
2
    }
1667
1668
3
    fn sorted_memory_exec(
1669
3
        schema: &SchemaRef,
1670
3
        sort_exprs: Vec<PhysicalSortExpr>,
1671
3
    ) -> Arc<dyn ExecutionPlan> {
1672
3
        Arc::new(
1673
3
            MemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
1674
3
                .unwrap()
1675
3
                .with_sort_information(vec![sort_exprs]),
1676
3
        )
1677
3
    }
1678
}