Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the sort preserving merge plan

use std::any::Any;
use std::sync::Arc;

use crate::common::spawn_buffered;
use crate::expressions::PhysicalSortExpr;
use crate::limit::LimitStream;
use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::{
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
    Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
};

use datafusion_common::{internal_err, Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalSortRequirement;

use datafusion_physical_expr_common::sort_expr::LexRequirement;
use log::{debug, trace};

/// Sort preserving merge execution plan
///
/// This takes an input execution plan and a list of sort expressions, and
/// provided each partition of the input plan is sorted with respect to
/// these sort expressions, this operator will yield a single partition
/// that is also sorted with respect to them
///
/// ```text
/// ┌─────────────────────────┐
/// │ ┌───┬───┬───┬───┐       │
/// │ │ A │ B │ C │ D │ ...   │──┐
/// │ └───┴───┴───┴───┘       │  │
/// └─────────────────────────┘  │  ┌───────────────────┐    ┌───────────────────────────────┐
///   Stream 1                   │  │                   │    │ ┌───┬───╦═══╦───┬───╦═══╗     │
///                              ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │
///                              │  │                   │    │ └───┴─▲─╩═══╩───┴───╩═══╝     │
/// ┌─────────────────────────┐  │  └───────────────────┘    └─┬─────┴───────────────────────┘
/// │ ╔═══╦═══╗               │  │
/// │ ║ B ║ E ║     ...       │──┘                             │
/// │ ╚═══╩═══╝               │              Note Stable Sort: the merged stream
/// └─────────────────────────┘                places equal rows from stream 1
///   Stream 2                                 before equal rows from stream 2
///
///
///  Input Streams                                             Output stream
///    (sorted)                                                  (sorted)
/// ```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// the output and inputs are not polled again.
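///
/// # Example (illustrative sketch)
///
/// A minimal, hypothetical wiring of this operator, not taken from the
/// original source: `input` stands for any [`ExecutionPlan`] whose
/// partitions are each already sorted by column "a".
///
/// ```ignore
/// // Merge all (already sorted) input partitions into one sorted partition;
/// // `with_fetch` optionally stops the merged output after N rows.
/// let sort = vec![PhysicalSortExpr {
///     expr: col("a", &input.schema())?,
///     options: Default::default(),
/// }];
/// let merge = SortPreservingMergeExec::new(sort, input).with_fetch(Some(10));
/// ```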
#[derive(Debug)]
pub struct SortPreservingMergeExec {
    /// Input plan
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions
    expr: Vec<PhysicalSortExpr>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Optional number of rows to fetch. Stops producing rows after this fetch
    fetch: Option<usize>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl SortPreservingMergeExec {
    /// Create a new sort preserving merge execution plan
    pub fn new(expr: Vec<PhysicalSortExpr>, input: Arc<dyn ExecutionPlan>) -> Self {
        let cache = Self::compute_properties(&input, expr.clone());
        Self {
            input,
            expr,
            metrics: ExecutionPlanMetricsSet::new(),
            fetch: None,
            cache,
        }
    }

    /// Sets the number of rows to fetch
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

    /// Input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Sort expressions
    pub fn expr(&self) -> &[PhysicalSortExpr] {
        &self.expr
    }

    /// Fetch
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// This function creates the cache object that stores the plan properties
    /// such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        ordering: Vec<PhysicalSortExpr>,
    ) -> PlanProperties {
        let mut eq_properties = input.equivalence_properties().clone();
        eq_properties.clear_per_partition_constants();
        eq_properties.add_new_orderings(vec![ordering]);
        PlanProperties::new(
            eq_properties,                        // Equivalence Properties
            Partitioning::UnknownPartitioning(1), // Output Partitioning
            input.execution_mode(),               // Execution Mode
        )
    }
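
    // Added note (not in the original source): regardless of how many
    // partitions the input has, the properties computed above advertise a
    // single output partition (`Partitioning::UnknownPartitioning(1)`)
    // ordered by `ordering`, and they inherit the input's equivalence
    // properties minus any per-partition constants, which no longer hold
    // once the partitions are merged into one.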
}

impl DisplayAs for SortPreservingMergeExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "SortPreservingMergeExec: [{}]",
                    PhysicalSortExpr::format_list(&self.expr)
                )?;
                if let Some(fetch) = self.fetch {
                    write!(f, ", fetch={fetch}")?;
                };

                Ok(())
            }
        }
    }
}

impl ExecutionPlan for SortPreservingMergeExec {
    fn name(&self) -> &'static str {
        "SortPreservingMergeExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Sets the number of rows to fetch
    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            expr: self.expr.clone(),
            metrics: self.metrics.clone(),
            fetch: limit,
            cache: self.cache.clone(),
        }))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::UnspecifiedDistribution]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
        vec![Some(PhysicalSortRequirement::from_sort_exprs(&self.expr))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(
            SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0]))
                .with_fetch(self.fetch),
        ))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start SortPreservingMergeExec::execute for partition: {}",
            partition
        );
        if 0 != partition {
            return internal_err!(
                "SortPreservingMergeExec invalid partition {partition}"
            );
        }

        let input_partitions = self.input.output_partitioning().partition_count();
        trace!(
            "Number of input partitions of SortPreservingMergeExec::execute: {}",
            input_partitions
        );
        let schema = self.schema();

        let reservation =
            MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
                .register(&context.runtime_env().memory_pool);

        match input_partitions {
            0 => internal_err!(
                "SortPreservingMergeExec requires at least one input partition"
            ),
            1 => match self.fetch {
                Some(fetch) => {
                    let stream = self.input.execute(0, context)?;
                    debug!(
                        "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"
                    );
                    Ok(Box::pin(LimitStream::new(
                        stream,
                        0,
                        Some(fetch),
                        BaselineMetrics::new(&self.metrics, partition),
                    )))
                }
                None => {
                    let stream = self.input.execute(0, context);
                    debug!(
                        "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"
                    );
                    stream
                }
            },
            _ => {
                let receivers = (0..input_partitions)
                    .map(|partition| {
                        let stream =
                            self.input.execute(partition, Arc::clone(&context))?;
                        Ok(spawn_buffered(stream, 1))
                    })
                    .collect::<Result<_>>()?;

                debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");

                let result = StreamingMergeBuilder::new()
                    .with_streams(receivers)
                    .with_schema(schema)
                    .with_expressions(&self.expr)
                    .with_metrics(BaselineMetrics::new(&self.metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(self.fetch)
                    .with_reservation(reservation)
                    .build()?;

                debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");

                Ok(result)
            }
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.statistics()
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Formatter;
    use std::pin::Pin;
    use std::sync::Mutex;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use super::*;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::expressions::col;
    use crate::memory::MemoryExec;
    use crate::metrics::{MetricValue, Timestamp};
    use crate::sorts::sort::SortExec;
    use crate::stream::RecordBatchReceiverStream;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::test::{self, assert_is_pending, make_partition};
    use crate::{collect, common, ExecutionMode};

    use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray};
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use arrow_schema::SchemaRef;
    use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError};
    use datafusion_common_runtime::SpawnedTask;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::RecordBatchStream;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::EquivalenceProperties;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    use futures::{FutureExt, Stream, StreamExt};
    use tokio::time::timeout;

    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2  | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7  | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9  | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |", // input b2 before b1
                "| 3  | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_exprs() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap();

        let schema = batch.schema();
        let sort = vec![]; // no sort expressions
        let exec = MemoryExec::try_new(&[vec![batch.clone()], vec![batch]], schema, None)
            .unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));

        let res = collect(merge, task_ctx).await.unwrap_err();
        assert_contains!(
            res.to_string(),
            "Internal error: Sort expressions cannot be empty for streaming merge"
        );
    }

    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 70  | c | 1970-01-01T00:00:00.000000004 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 90  | d | 1970-01-01T00:00:00.000000006 |",
                "| 30  | e | 1970-01-01T00:00:00.000000002 |",
                "| 3   | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a  | b | c                             |",
                "+----+---+-------------------------------+",
                "| 1  | a | 1970-01-01T00:00:00.000000008 |",
                "| 2  | b | 1970-01-01T00:00:00.000000007 |",
                "| 7  | c | 1970-01-01T00:00:00.000000006 |",
                "| 9  | d | 1970-01-01T00:00:00.000000005 |",
                "| 3  | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a   | b | c                             |",
                "+-----+---+-------------------------------+",
                "| 1   | a | 1970-01-01T00:00:00.000000008 |",
                "| 2   | b | 1970-01-01T00:00:00.000000007 |",
                "| 7   | c | 1970-01-01T00:00:00.000000006 |",
                "| 9   | d | 1970-01-01T00:00:00.000000005 |",
                "| 10  | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3   | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20  | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70  | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90  | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30  | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }

    async fn _test_merge(
        partitions: &[Vec<RecordBatch>],
        exp: &[&str],
        context: Arc<TaskContext>,
    ) {
        let schema = partitions[0][0].schema();
        let sort = vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: Default::default(),
            },
        ];
        let exec = MemoryExec::try_new(partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));

        let collected = collect(merge, context).await.unwrap();
        assert_batches_eq!(exp, collected.as_slice());
    }

    async fn sorted_merge(
        input: Arc<dyn ExecutionPlan>,
        sort: Vec<PhysicalSortExpr>,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let mut result = collect(merge, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    async fn partition_sort(
        input: Arc<dyn ExecutionPlan>,
        sort: Vec<PhysicalSortExpr>,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let sort_exec =
            Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
        sorted_merge(sort_exec, sort, context).await
    }

    async fn basic_sort(
        src: Arc<dyn ExecutionPlan>,
        sort: Vec<PhysicalSortExpr>,
        context: Arc<TaskContext>,
    ) -> RecordBatch {
        let merge = Arc::new(CoalescePartitionsExec::new(src));
        let sort_exec = Arc::new(SortExec::new(sort, merge));
        let mut result = collect(sort_exec, context).await.unwrap();
        assert_eq!(result.len(), 1);
        result.remove(0)
    }

    #[tokio::test]
    async fn test_partition_sort() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);
        let schema = csv.schema();

        let sort = vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions {
                descending: true,
                nulls_first: true,
            },
        }];

        let basic =
            basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    // Split the provided record batch into multiple record batches of at most batch_size rows
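    // For example (added illustration): splitting a 1200-row batch with
    // batch_size = 13 yields ceil(1200 / 13) = 93 batches, the last one
    // holding only 1200 - 92 * 13 = 4 rows.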
    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
        let batches = (sorted.num_rows() + batch_size - 1) / batch_size;

        // Split the sorted RecordBatch into multiple smaller batches
        (0..batches)
            .map(|batch_idx| {
                let columns = (0..sorted.num_columns())
                    .map(|column_idx| {
                        let length =
                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);

                        sorted
                            .column(column_idx)
                            .slice(batch_idx * batch_size, length)
                    })
                    .collect();

                RecordBatch::try_new(sorted.schema(), columns).unwrap()
            })
            .collect()
    }

    async fn sorted_partitioned_input(
        sort: Vec<PhysicalSortExpr>,
        sizes: &[usize],
        context: Arc<TaskContext>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = 4;
        let csv = test::scan_partitioned(partitions);

        let sorted = basic_sort(csv, sort, context).await;
        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();

        Ok(Arc::new(
            MemoryExec::try_new(&split, sorted.schema(), None).unwrap(),
        ))
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort = vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: Default::default(),
        }];

        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[partition])
            .unwrap()
            .to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        let sort = vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: Default::default(),
        }];

        // Test streaming with default batch size
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // batch size of 23
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await.unwrap();

        assert_eq!(merged.len(), 53);

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice())
            .unwrap()
            .to_string();

        assert_eq!(basic, partition);

        Ok(())
    }

    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = vec![
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ];
        let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+-------------------------------+",
                "| a | b | c                             |",
                "+---+---+-------------------------------+",
                "| 1 |   | 1970-01-01T00:00:00.000000008 |",
                "| 1 |   | 1970-01-01T00:00:00.000000008 |",
                "| 2 | a |                               |",
                "| 7 | b | 1970-01-01T00:00:00.000000006 |",
                "| 2 | b |                               |",
                "| 9 | d |                               |",
                "| 3 | e | 1970-01-01T00:00:00.000000004 |",
                "| 3 | g | 1970-01-01T00:00:00.000000005 |",
                "| 4 | h |                               |",
                "| 5 | i | 1970-01-01T00:00:00.000000004 |",
                "+---+---+-------------------------------+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_with_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }];
        let exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(
            SortPreservingMergeExec::new(sort, Arc::new(exec)).with_fetch(Some(2)),
        );

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+",
                "| a | b |",
                "+---+---+",
                "| 1 | a |",
                "| 2 | b |",
                "+---+---+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }];
        let exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_batches_eq!(
            &[
                "+---+---+",
                "| a | b |",
                "+---+---+",
                "| 1 | a |",
                "| 2 | b |",
                "| 7 | c |",
                "| 9 | d |",
                "| 3 | e |",
                "+---+---+",
            ],
            collected.as_slice()
        );
    }

    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort = vec![PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }];

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    // This causes the MergeStream to wait for more input
                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(sort.as_slice())
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_merge_metrics() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

        let schema = b1.schema();
        let sort = vec![PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: Default::default(),
        }];
        let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));

        let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
            .await
            .unwrap();
        let expected = [
            "+----+---+",
            "| a  | b |",
            "+----+---+",
            "| 1  | a |",
            "| 10 | b |",
            "| 2  | c |",
            "| 20 | d |",
            "+----+---+",
        ];
        assert_batches_eq!(expected, collected.as_slice());

        // Now, validate metrics
        let metrics = merge.metrics().unwrap();

        assert_eq!(metrics.output_rows().unwrap(), 4);
        assert!(metrics.elapsed_compute().unwrap() > 0);

        let mut saw_start = false;
        let mut saw_end = false;
        metrics.iter().for_each(|m| match m.value() {
            MetricValue::StartTimestamp(ts) => {
                saw_start = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            MetricValue::EndTimestamp(ts) => {
                saw_end = true;
                assert!(nanos_from_timestamp(ts) > 0);
            }
            _ => {}
        });

        assert!(saw_start);
        assert!(saw_end);
    }

    fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
        ts.value().unwrap().timestamp_nanos_opt().unwrap()
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
            vec![PhysicalSortExpr {
                expr: col("a", &schema)?,
                options: SortOptions::default(),
            }],
            blocking_exec,
        ));

        let fut = collect(sort_preserving_merge_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_stable_sort() {
        let task_ctx = Arc::new(TaskContext::default());

        // Create record batches like:
        // batch_number | value
        // -------------+------
        //    1         | A
        //    1         | B
        //
        // Ensure that the output is in the same order the batches were fed
        let partitions: Vec<Vec<RecordBatch>> = (0..10)
            .map(|batch_number| {
                let batch_number: Int32Array =
                    vec![Some(batch_number), Some(batch_number)]
                        .into_iter()
                        .collect();
                let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();

                let batch = RecordBatch::try_from_iter(vec![
                    ("batch_number", Arc::new(batch_number) as ArrayRef),
                    ("value", Arc::new(value) as ArrayRef),
                ])
                .unwrap();

                vec![batch]
            })
            .collect();

        let schema = partitions[0][0].schema();

        let sort = vec![PhysicalSortExpr {
            expr: col("value", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }];

        let exec = MemoryExec::try_new(&partitions, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Expect the data to be sorted first by "batch_number" (because
        // that was the order it was fed in, even though only "value"
        // is in the sort key)
        assert_batches_eq!(
            &[
                "+--------------+-------+",
                "| batch_number | value |",
                "+--------------+-------+",
                "| 0            | A     |",
                "| 1            | A     |",
                "| 2            | A     |",
                "| 3            | A     |",
                "| 4            | A     |",
                "| 5            | A     |",
                "| 6            | A     |",
                "| 7            | A     |",
                "| 8            | A     |",
                "| 9            | A     |",
                "| 0            | B     |",
                "| 1            | B     |",
                "| 2            | B     |",
                "| 3            | B     |",
                "| 4            | B     |",
                "| 5            | B     |",
                "| 6            | B     |",
                "| 7            | B     |",
                "| 8            | B     |",
                "| 9            | B     |",
                "+--------------+-------+",
            ],
            collected.as_slice()
        );
    }

    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
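    //
    // Illustrative poll sequence implied by the description above (added
    // sketch, not part of the original source):
    //   poll(partition 0) -> Ready(None)   // exhausted immediately
    //   poll(partition 1) -> Pending       // congested
    //   poll(partition 2) -> Ready(None)   // clears the congestion flag
    //   poll(partition 1) -> Ready(None)   // now unblocked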
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        cache: PlanProperties,
        congestion_cleared: Arc<Mutex<bool>>,
    }

    impl CongestedExec {
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            let columns = schema
                .fields
                .iter()
                .enumerate()
                .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
                .collect::<Vec<_>>();
            let mut eq_properties = EquivalenceProperties::new(schema);
            eq_properties.add_new_orderings(vec![columns
                .iter()
                .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr)))
                .collect::<Vec<_>>()]);
            let mode = ExecutionMode::Unbounded;
            PlanProperties::new(eq_properties, Partitioning::Hash(columns, 3), mode)
        }
    }

    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &PlanProperties {
            &self.cache
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion_cleared: Arc::clone(&self.congestion_cleared),
                partition,
            }))
        }
    }

    impl DisplayAs for CongestedExec {
        fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "CongestedExec",).unwrap()
                }
            }
            Ok(())
        }
    }

    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
    /// partition is exhausted from the start, and if it is polled more than once, it panics.
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        none_polled_once: bool,
        congestion_cleared: Arc<Mutex<bool>>,
        partition: usize,
    }

    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    if self.none_polled_once {
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                1 => {
                    let cleared = self.congestion_cleared.lock().unwrap();
                    if *cleared {
                        Poll::Ready(None)
                    } else {
                        Poll::Pending
                    }
                }
                2 => {
                    let mut cleared = self.congestion_cleared.lock().unwrap();
                    *cleared = true;
                    Poll::Ready(None)
                }
                _ => unreachable!(),
            }
        }
    }

    impl RecordBatchStream for CongestedStream {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
    }

    #[tokio::test]
    async fn test_spm_congestion() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
        let source = CongestedExec {
            schema: schema.clone(),
            cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
            congestion_cleared: Arc::new(Mutex::new(false)),
        };
        let spm = SortPreservingMergeExec::new(
            vec![PhysicalSortExpr::new_default(Arc::new(Column::new(
                "c1", 0,
            )))],
            Arc::new(source),
        );
        let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));

        let result = timeout(Duration::from_secs(3), spm_task.join()).await;
        match result {
            Ok(Ok(Ok(_batches))) => Ok(()),
            Ok(Ok(Err(e))) => Err(e),
            Ok(Err(_)) => Err(DataFusionError::Execution(
                "SortPreservingMerge task panicked or was cancelled".to_string(),
            )),
            Err(_) => Err(DataFusionError::Execution(
                "SortPreservingMerge caused a deadlock".to_string(),
            )),
        }
    }
}