/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines the sort preserving merge plan |
19 | | |
20 | | use std::any::Any; |
21 | | use std::sync::Arc; |
22 | | |
23 | | use crate::common::spawn_buffered; |
24 | | use crate::expressions::PhysicalSortExpr; |
25 | | use crate::limit::LimitStream; |
26 | | use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; |
27 | | use crate::sorts::streaming_merge::StreamingMergeBuilder; |
28 | | use crate::{ |
29 | | DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, |
30 | | Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, |
31 | | }; |
32 | | |
33 | | use datafusion_common::{internal_err, Result}; |
34 | | use datafusion_execution::memory_pool::MemoryConsumer; |
35 | | use datafusion_execution::TaskContext; |
36 | | use datafusion_physical_expr::PhysicalSortRequirement; |
37 | | |
38 | | use datafusion_physical_expr_common::sort_expr::LexRequirement; |
39 | | use log::{debug, trace}; |
40 | | |
41 | | /// Sort preserving merge execution plan |
42 | | /// |
43 | | /// This takes an input execution plan and a list of sort expressions, and |
44 | | /// provided each partition of the input plan is sorted with respect to |
45 | | /// these sort expressions, this operator will yield a single partition |
46 | | /// that is also sorted with respect to them |
47 | | /// |
48 | | /// ```text |
49 | | /// ┌─────────────────────────┐ |
50 | | /// │ ┌───┬───┬───┬───┐ │ |
51 | | /// │ │ A │ B │ C │ D │ ... │──┐ |
52 | | /// │ └───┴───┴───┴───┘ │ │ |
53 | | /// └─────────────────────────┘ │ ┌───────────────────┐ ┌───────────────────────────────┐ |
54 | | /// Stream 1 │ │ │ │ ┌───┬───╦═══╦───┬───╦═══╗ │ |
55 | | /// ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │ |
56 | | /// │ │ │ │ └───┴─▲─╩═══╩───┴───╩═══╝ │ |
57 | | /// ┌─────────────────────────┐ │ └───────────────────┘ └─┬─────┴───────────────────────┘ |
58 | | /// │ ╔═══╦═══╗ │ │ |
59 | | /// │ ║ B ║ E ║ ... │──┘ │ |
60 | | /// │ ╚═══╩═══╝ │ Note Stable Sort: the merged stream |
61 | | /// └─────────────────────────┘ places equal rows from stream 1 |
62 | | /// Stream 2 |
63 | | /// |
64 | | /// |
65 | | /// Input Streams Output stream |
66 | | /// (sorted) (sorted) |
67 | | /// ``` |
68 | | /// |
69 | | /// # Error Handling |
70 | | /// |
71 | | /// If any of the input partitions return an error, the error is propagated to |
72 | | /// the output and inputs are not polled again. |
73 | | #[derive(Debug)] |
74 | | pub struct SortPreservingMergeExec { |
75 | | /// Input plan |
76 | | input: Arc<dyn ExecutionPlan>, |
77 | | /// Sort expressions |
78 | | expr: Vec<PhysicalSortExpr>, |
79 | | /// Execution metrics |
80 | | metrics: ExecutionPlanMetricsSet, |
81 | | /// Optional number of rows to fetch. Stops producing rows after this fetch |
82 | | fetch: Option<usize>, |
83 | | /// Cache holding plan properties like equivalences, output partitioning etc. |
84 | | cache: PlanProperties, |
85 | | } |
86 | | |
87 | | impl SortPreservingMergeExec { |
88 | | /// Create a new sort preserving merge execution plan |
89 | 15 | pub fn new(expr: Vec<PhysicalSortExpr>, input: Arc<dyn ExecutionPlan>) -> Self { |
90 | 15 | let cache = Self::compute_properties(&input, expr.clone()); |
91 | 15 | Self { |
92 | 15 | input, |
93 | 15 | expr, |
94 | 15 | metrics: ExecutionPlanMetricsSet::new(), |
95 | 15 | fetch: None, |
96 | 15 | cache, |
97 | 15 | } |
98 | 15 | } |
99 | | /// Sets the number of rows to fetch |
100 | 1 | pub fn with_fetch(mut self, fetch: Option<usize>) -> Self { |
101 | 1 | self.fetch = fetch; |
102 | 1 | self |
103 | 1 | } |
104 | | |
105 | | /// Input plan |
106 | 0 | pub fn input(&self) -> &Arc<dyn ExecutionPlan> { |
107 | 0 | &self.input |
108 | 0 | } |
109 | | |
110 | | /// Sort expressions |
111 | 0 | pub fn expr(&self) -> &[PhysicalSortExpr] { |
112 | 0 | &self.expr |
113 | 0 | } |
114 | | |
115 | | /// Optional number of rows to fetch |
116 | 0 | pub fn fetch(&self) -> Option<usize> { |
117 | 0 | self.fetch |
118 | 0 | } |
119 | | |
120 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
121 | 15 | fn compute_properties( |
122 | 15 | input: &Arc<dyn ExecutionPlan>, |
123 | 15 | ordering: Vec<PhysicalSortExpr>, |
124 | 15 | ) -> PlanProperties { |
125 | 15 | let mut eq_properties = input.equivalence_properties().clone(); |
126 | 15 | eq_properties.clear_per_partition_constants(); |
127 | 15 | eq_properties.add_new_orderings(vec![ordering]); |
128 | 15 | PlanProperties::new( |
129 | 15 | eq_properties, // Equivalence Properties |
130 | 15 | Partitioning::UnknownPartitioning(1), // Output Partitioning |
131 | 15 | input.execution_mode(), // Execution Mode |
132 | 15 | ) |
133 | 15 | } |
134 | | } |
135 | | |
136 | | impl DisplayAs for SortPreservingMergeExec { |
137 | 0 | fn fmt_as( |
138 | 0 | &self, |
139 | 0 | t: DisplayFormatType, |
140 | 0 | f: &mut std::fmt::Formatter, |
141 | 0 | ) -> std::fmt::Result { |
142 | 0 | match t { |
143 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
144 | 0 | write!( |
145 | 0 | f, |
146 | 0 | "SortPreservingMergeExec: [{}]", |
147 | 0 | PhysicalSortExpr::format_list(&self.expr) |
148 | 0 | )?; |
149 | 0 | if let Some(fetch) = self.fetch { |
150 | 0 | write!(f, ", fetch={fetch}")?; |
151 | 0 | }; |
152 | | |
153 | 0 | Ok(()) |
154 | | } |
155 | | } |
156 | 0 | } |
157 | | } |
158 | | |
159 | | impl ExecutionPlan for SortPreservingMergeExec { |
160 | 0 | fn name(&self) -> &'static str { |
161 | 0 | "SortPreservingMergeExec" |
162 | 0 | } |
163 | | |
164 | | /// Return a reference to Any that can be used for downcasting |
165 | 0 | fn as_any(&self) -> &dyn Any { |
166 | 0 | self |
167 | 0 | } |
168 | | |
169 | 30 | fn properties(&self) -> &PlanProperties { |
170 | 30 | &self.cache |
171 | 30 | } |
172 | | |
173 | 0 | fn fetch(&self) -> Option<usize> { |
174 | 0 | self.fetch |
175 | 0 | } |
176 | | |
177 | | /// Returns a new plan with the same input and sort expressions, with the number of rows to fetch set to `limit` |
178 | 0 | fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> { |
179 | 0 | Some(Arc::new(Self { |
180 | 0 | input: Arc::clone(&self.input), |
181 | 0 | expr: self.expr.clone(), |
182 | 0 | metrics: self.metrics.clone(), |
183 | 0 | fetch: limit, |
184 | 0 | cache: self.cache.clone(), |
185 | 0 | })) |
186 | 0 | } |
187 | | |
188 | 0 | fn required_input_distribution(&self) -> Vec<Distribution> { |
189 | 0 | vec![Distribution::UnspecifiedDistribution] |
190 | 0 | } |
191 | | |
192 | 0 | fn benefits_from_input_partitioning(&self) -> Vec<bool> { |
193 | 0 | vec![false] |
194 | 0 | } |
195 | | |
196 | 0 | fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> { |
197 | 0 | vec![Some(PhysicalSortRequirement::from_sort_exprs(&self.expr))] |
198 | 0 | } |
199 | | |
200 | 0 | fn maintains_input_order(&self) -> Vec<bool> { |
201 | 0 | vec![true] |
202 | 0 | } |
203 | | |
204 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
205 | 0 | vec![&self.input] |
206 | 0 | } |
207 | | |
208 | 0 | fn with_new_children( |
209 | 0 | self: Arc<Self>, |
210 | 0 | children: Vec<Arc<dyn ExecutionPlan>>, |
211 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
212 | 0 | Ok(Arc::new( |
213 | 0 | SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0])) |
214 | 0 | .with_fetch(self.fetch), |
215 | 0 | )) |
216 | 0 | } |
217 | | |
218 | 15 | fn execute( |
219 | 15 | &self, |
220 | 15 | partition: usize, |
221 | 15 | context: Arc<TaskContext>, |
222 | 15 | ) -> Result<SendableRecordBatchStream> { |
223 | 15 | trace!( |
224 | 0 | "Start SortPreservingMergeExec::execute for partition: {}", |
225 | | partition |
226 | | ); |
227 | 15 | if 0 != partition { |
228 | 0 | return internal_err!( |
229 | 0 | "SortPreservingMergeExec invalid partition {partition}" |
230 | 0 | ); |
231 | 15 | } |
232 | 15 | |
233 | 15 | let input_partitions = self.input.output_partitioning().partition_count(); |
234 | 15 | trace!( |
235 | 0 | "Number of input partitions of SortPreservingMergeExec::execute: {}", |
236 | | input_partitions |
237 | | ); |
238 | 15 | let schema = self.schema(); |
239 | 15 | |
240 | 15 | let reservation = |
241 | 15 | MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]")) |
242 | 15 | .register(&context.runtime_env().memory_pool); |
243 | 15 | |
244 | 15 | match input_partitions { |
245 | 0 | 0 => internal_err!( |
246 | 0 | "SortPreservingMergeExec requires at least one input partition" |
247 | 0 | ), |
248 | 2 | 1 => match self.fetch { |
249 | 1 | Some(fetch) => { |
250 | 1 | let stream = self.input.execute(0, context)?0 ; |
251 | 1 | debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"0 ); |
252 | 1 | Ok(Box::pin(LimitStream::new( |
253 | 1 | stream, |
254 | 1 | 0, |
255 | 1 | Some(fetch), |
256 | 1 | BaselineMetrics::new(&self.metrics, partition), |
257 | 1 | ))) |
258 | | } |
259 | | None => { |
260 | 1 | let stream = self.input.execute(0, context); |
261 | 1 | debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"0 ); |
262 | 1 | stream |
263 | | } |
264 | | }, |
265 | | _ => { |
266 | 13 | let receivers = (0..input_partitions) |
267 | 40 | .map(|partition| { |
268 | 40 | let stream = |
269 | 40 | self.input.execute(partition, Arc::clone(&context))?0 ; |
270 | 40 | Ok(spawn_buffered(stream, 1)) |
271 | 40 | }) |
272 | 13 | .collect::<Result<_>>()?0 ; |
273 | | |
274 | 13 | debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute"0 ); |
275 | | |
276 | 13 | let result12 = StreamingMergeBuilder::new() |
277 | 13 | .with_streams(receivers) |
278 | 13 | .with_schema(schema) |
279 | 13 | .with_expressions(&self.expr) |
280 | 13 | .with_metrics(BaselineMetrics::new(&self.metrics, partition)) |
281 | 13 | .with_batch_size(context.session_config().batch_size()) |
282 | 13 | .with_fetch(self.fetch) |
283 | 13 | .with_reservation(reservation) |
284 | 13 | .build()?1 ; |
285 | | |
286 | 12 | debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"0 ); |
287 | | |
288 | 12 | Ok(result) |
289 | | } |
290 | | } |
291 | 15 | } |
292 | | |
293 | 1 | fn metrics(&self) -> Option<MetricsSet> { |
294 | 1 | Some(self.metrics.clone_inner()) |
295 | 1 | } |
296 | | |
297 | 0 | fn statistics(&self) -> Result<Statistics> { |
298 | 0 | self.input.statistics() |
299 | 0 | } |
300 | | |
301 | 0 | fn supports_limit_pushdown(&self) -> bool { |
302 | 0 | true |
303 | 0 | } |
304 | | } |
305 | | |
306 | | #[cfg(test)] |
307 | | mod tests { |
308 | | use std::fmt::Formatter; |
309 | | use std::pin::Pin; |
310 | | use std::sync::Mutex; |
311 | | use std::task::{Context, Poll}; |
312 | | use std::time::Duration; |
313 | | |
314 | | use super::*; |
315 | | use crate::coalesce_partitions::CoalescePartitionsExec; |
316 | | use crate::expressions::col; |
317 | | use crate::memory::MemoryExec; |
318 | | use crate::metrics::{MetricValue, Timestamp}; |
319 | | use crate::sorts::sort::SortExec; |
320 | | use crate::stream::RecordBatchReceiverStream; |
321 | | use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; |
322 | | use crate::test::{self, assert_is_pending, make_partition}; |
323 | | use crate::{collect, common, ExecutionMode}; |
324 | | |
325 | | use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray}; |
326 | | use arrow::compute::SortOptions; |
327 | | use arrow::datatypes::{DataType, Field, Schema}; |
328 | | use arrow::record_batch::RecordBatch; |
329 | | use arrow_schema::SchemaRef; |
330 | | use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError}; |
331 | | use datafusion_common_runtime::SpawnedTask; |
332 | | use datafusion_execution::config::SessionConfig; |
333 | | use datafusion_execution::RecordBatchStream; |
334 | | use datafusion_physical_expr::expressions::Column; |
335 | | use datafusion_physical_expr::EquivalenceProperties; |
336 | | use datafusion_physical_expr_common::physical_expr::PhysicalExpr; |
337 | | |
338 | | use futures::{FutureExt, Stream, StreamExt}; |
339 | | use tokio::time::timeout; |
340 | | |
341 | | #[tokio::test] |
342 | 1 | async fn test_merge_interleave() { |
343 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
344 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
345 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
346 | 1 | Some("a"), |
347 | 1 | Some("c"), |
348 | 1 | Some("e"), |
349 | 1 | Some("g"), |
350 | 1 | Some("j"), |
351 | 1 | ])); |
352 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); |
353 | 1 | let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
354 | 1 | |
355 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30])); |
356 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
357 | 1 | Some("b"), |
358 | 1 | Some("d"), |
359 | 1 | Some("f"), |
360 | 1 | Some("h"), |
361 | 1 | Some("j"), |
362 | 1 | ])); |
363 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); |
364 | 1 | let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
365 | 1 | |
366 | 1 | _test_merge( |
367 | 1 | &[vec![b1], vec![b2]], |
368 | 1 | &[ |
369 | 1 | "+----+---+-------------------------------+", |
370 | 1 | "| a | b | c |", |
371 | 1 | "+----+---+-------------------------------+", |
372 | 1 | "| 1 | a | 1970-01-01T00:00:00.000000008 |", |
373 | 1 | "| 10 | b | 1970-01-01T00:00:00.000000004 |", |
374 | 1 | "| 2 | c | 1970-01-01T00:00:00.000000007 |", |
375 | 1 | "| 20 | d | 1970-01-01T00:00:00.000000006 |", |
376 | 1 | "| 7 | e | 1970-01-01T00:00:00.000000006 |", |
377 | 1 | "| 70 | f | 1970-01-01T00:00:00.000000002 |", |
378 | 1 | "| 9 | g | 1970-01-01T00:00:00.000000005 |", |
379 | 1 | "| 90 | h | 1970-01-01T00:00:00.000000002 |", |
380 | 1 | "| 30 | j | 1970-01-01T00:00:00.000000006 |", // input b2 before b1 |
381 | 1 | "| 3 | j | 1970-01-01T00:00:00.000000008 |", |
382 | 1 | "+----+---+-------------------------------+", |
383 | 1 | ], |
384 | 1 | task_ctx, |
385 | 1 | ) |
386 | 1 | .await0 ; |
387 | 1 | } |
388 | | |
389 | | #[tokio::test] |
390 | 1 | async fn test_merge_no_exprs() { |
391 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
392 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
393 | 1 | let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); |
394 | 1 | |
395 | 1 | let schema = batch.schema(); |
396 | 1 | let sort = vec![]; // no sort expressions |
397 | 1 | let exec = MemoryExec::try_new(&[vec![batch.clone()], vec![batch]], schema, None) |
398 | 1 | .unwrap(); |
399 | 1 | let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); |
400 | 1 | |
401 | 1 | let res = collect(merge, task_ctx).await0 .unwrap_err(); |
402 | 1 | assert_contains!( |
403 | 1 | res.to_string(), |
404 | 1 | "Internal error: Sort expressions cannot be empty for streaming merge" |
405 | 1 | ); |
406 | 1 | } |
407 | | |
408 | | #[tokio::test] |
409 | 1 | async fn test_merge_some_overlap() { |
410 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
411 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
412 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
413 | 1 | Some("a"), |
414 | 1 | Some("b"), |
415 | 1 | Some("c"), |
416 | 1 | Some("d"), |
417 | 1 | Some("e"), |
418 | 1 | ])); |
419 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); |
420 | 1 | let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
421 | 1 | |
422 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110])); |
423 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
424 | 1 | Some("c"), |
425 | 1 | Some("d"), |
426 | 1 | Some("e"), |
427 | 1 | Some("f"), |
428 | 1 | Some("g"), |
429 | 1 | ])); |
430 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); |
431 | 1 | let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
432 | 1 | |
433 | 1 | _test_merge( |
434 | 1 | &[vec![b1], vec![b2]], |
435 | 1 | &[ |
436 | 1 | "+-----+---+-------------------------------+", |
437 | 1 | "| a | b | c |", |
438 | 1 | "+-----+---+-------------------------------+", |
439 | 1 | "| 1 | a | 1970-01-01T00:00:00.000000008 |", |
440 | 1 | "| 2 | b | 1970-01-01T00:00:00.000000007 |", |
441 | 1 | "| 70 | c | 1970-01-01T00:00:00.000000004 |", |
442 | 1 | "| 7 | c | 1970-01-01T00:00:00.000000006 |", |
443 | 1 | "| 9 | d | 1970-01-01T00:00:00.000000005 |", |
444 | 1 | "| 90 | d | 1970-01-01T00:00:00.000000006 |", |
445 | 1 | "| 30 | e | 1970-01-01T00:00:00.000000002 |", |
446 | 1 | "| 3 | e | 1970-01-01T00:00:00.000000008 |", |
447 | 1 | "| 100 | f | 1970-01-01T00:00:00.000000002 |", |
448 | 1 | "| 110 | g | 1970-01-01T00:00:00.000000006 |", |
449 | 1 | "+-----+---+-------------------------------+", |
450 | 1 | ], |
451 | 1 | task_ctx, |
452 | 1 | ) |
453 | 1 | .await0 ; |
454 | 1 | } |
455 | | |
456 | | #[tokio::test] |
457 | 1 | async fn test_merge_no_overlap() { |
458 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
459 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
460 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
461 | 1 | Some("a"), |
462 | 1 | Some("b"), |
463 | 1 | Some("c"), |
464 | 1 | Some("d"), |
465 | 1 | Some("e"), |
466 | 1 | ])); |
467 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); |
468 | 1 | let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
469 | 1 | |
470 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30])); |
471 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
472 | 1 | Some("f"), |
473 | 1 | Some("g"), |
474 | 1 | Some("h"), |
475 | 1 | Some("i"), |
476 | 1 | Some("j"), |
477 | 1 | ])); |
478 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); |
479 | 1 | let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
480 | 1 | |
481 | 1 | _test_merge( |
482 | 1 | &[vec![b1], vec![b2]], |
483 | 1 | &[ |
484 | 1 | "+----+---+-------------------------------+", |
485 | 1 | "| a | b | c |", |
486 | 1 | "+----+---+-------------------------------+", |
487 | 1 | "| 1 | a | 1970-01-01T00:00:00.000000008 |", |
488 | 1 | "| 2 | b | 1970-01-01T00:00:00.000000007 |", |
489 | 1 | "| 7 | c | 1970-01-01T00:00:00.000000006 |", |
490 | 1 | "| 9 | d | 1970-01-01T00:00:00.000000005 |", |
491 | 1 | "| 3 | e | 1970-01-01T00:00:00.000000008 |", |
492 | 1 | "| 10 | f | 1970-01-01T00:00:00.000000004 |", |
493 | 1 | "| 20 | g | 1970-01-01T00:00:00.000000006 |", |
494 | 1 | "| 70 | h | 1970-01-01T00:00:00.000000002 |", |
495 | 1 | "| 90 | i | 1970-01-01T00:00:00.000000002 |", |
496 | 1 | "| 30 | j | 1970-01-01T00:00:00.000000006 |", |
497 | 1 | "+----+---+-------------------------------+", |
498 | 1 | ], |
499 | 1 | task_ctx, |
500 | 1 | ) |
501 | 1 | .await0 ; |
502 | 1 | } |
503 | | |
504 | | #[tokio::test] |
505 | 1 | async fn test_merge_three_partitions() { |
506 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
507 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
508 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
509 | 1 | Some("a"), |
510 | 1 | Some("b"), |
511 | 1 | Some("c"), |
512 | 1 | Some("d"), |
513 | 1 | Some("f"), |
514 | 1 | ])); |
515 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); |
516 | 1 | let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
517 | 1 | |
518 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30])); |
519 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
520 | 1 | Some("e"), |
521 | 1 | Some("g"), |
522 | 1 | Some("h"), |
523 | 1 | Some("i"), |
524 | 1 | Some("j"), |
525 | 1 | ])); |
526 | 1 | let c: ArrayRef = |
527 | 1 | Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60])); |
528 | 1 | let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
529 | 1 | |
530 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300])); |
531 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
532 | 1 | Some("f"), |
533 | 1 | Some("g"), |
534 | 1 | Some("h"), |
535 | 1 | Some("i"), |
536 | 1 | Some("j"), |
537 | 1 | ])); |
538 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); |
539 | 1 | let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
540 | 1 | |
541 | 1 | _test_merge( |
542 | 1 | &[vec![b1], vec![b2], vec![b3]], |
543 | 1 | &[ |
544 | 1 | "+-----+---+-------------------------------+", |
545 | 1 | "| a | b | c |", |
546 | 1 | "+-----+---+-------------------------------+", |
547 | 1 | "| 1 | a | 1970-01-01T00:00:00.000000008 |", |
548 | 1 | "| 2 | b | 1970-01-01T00:00:00.000000007 |", |
549 | 1 | "| 7 | c | 1970-01-01T00:00:00.000000006 |", |
550 | 1 | "| 9 | d | 1970-01-01T00:00:00.000000005 |", |
551 | 1 | "| 10 | e | 1970-01-01T00:00:00.000000040 |", |
552 | 1 | "| 100 | f | 1970-01-01T00:00:00.000000004 |", |
553 | 1 | "| 3 | f | 1970-01-01T00:00:00.000000008 |", |
554 | 1 | "| 200 | g | 1970-01-01T00:00:00.000000006 |", |
555 | 1 | "| 20 | g | 1970-01-01T00:00:00.000000060 |", |
556 | 1 | "| 700 | h | 1970-01-01T00:00:00.000000002 |", |
557 | 1 | "| 70 | h | 1970-01-01T00:00:00.000000020 |", |
558 | 1 | "| 900 | i | 1970-01-01T00:00:00.000000002 |", |
559 | 1 | "| 90 | i | 1970-01-01T00:00:00.000000020 |", |
560 | 1 | "| 300 | j | 1970-01-01T00:00:00.000000006 |", |
561 | 1 | "| 30 | j | 1970-01-01T00:00:00.000000060 |", |
562 | 1 | "+-----+---+-------------------------------+", |
563 | 1 | ], |
564 | 1 | task_ctx, |
565 | 1 | ) |
566 | 1 | .await0 ; |
567 | 1 | } |
568 | | |
569 | 4 | async fn _test_merge( |
570 | 4 | partitions: &[Vec<RecordBatch>], |
571 | 4 | exp: &[&str], |
572 | 4 | context: Arc<TaskContext>, |
573 | 4 | ) { |
574 | 4 | let schema = partitions[0][0].schema(); |
575 | 4 | let sort = vec![ |
576 | 4 | PhysicalSortExpr { |
577 | 4 | expr: col("b", &schema).unwrap(), |
578 | 4 | options: Default::default(), |
579 | 4 | }, |
580 | 4 | PhysicalSortExpr { |
581 | 4 | expr: col("c", &schema).unwrap(), |
582 | 4 | options: Default::default(), |
583 | 4 | }, |
584 | 4 | ]; |
585 | 4 | let exec = MemoryExec::try_new(partitions, schema, None).unwrap(); |
586 | 4 | let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); |
587 | | |
588 | 4 | let collected = collect(merge, context).await0 .unwrap(); |
589 | 4 | assert_batches_eq!(exp, collected.as_slice()); |
590 | 4 | } |
591 | | |
592 | 2 | async fn sorted_merge( |
593 | 2 | input: Arc<dyn ExecutionPlan>, |
594 | 2 | sort: Vec<PhysicalSortExpr>, |
595 | 2 | context: Arc<TaskContext>, |
596 | 2 | ) -> RecordBatch { |
597 | 2 | let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); |
598 | 2 | let mut result = collect(merge, context).await0 .unwrap(); |
599 | 2 | assert_eq!(result.len(), 1); |
600 | 2 | result.remove(0) |
601 | 2 | } |
602 | | |
603 | 1 | async fn partition_sort( |
604 | 1 | input: Arc<dyn ExecutionPlan>, |
605 | 1 | sort: Vec<PhysicalSortExpr>, |
606 | 1 | context: Arc<TaskContext>, |
607 | 1 | ) -> RecordBatch { |
608 | 1 | let sort_exec = |
609 | 1 | Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true)); |
610 | 1 | sorted_merge(sort_exec, sort, context).await0 |
611 | 1 | } |
612 | | |
613 | 7 | async fn basic_sort( |
614 | 7 | src: Arc<dyn ExecutionPlan>, |
615 | 7 | sort: Vec<PhysicalSortExpr>, |
616 | 7 | context: Arc<TaskContext>, |
617 | 7 | ) -> RecordBatch { |
618 | 7 | let merge = Arc::new(CoalescePartitionsExec::new(src)); |
619 | 7 | let sort_exec = Arc::new(SortExec::new(sort, merge)); |
620 | 217 | let mut result7 = collect(sort_exec, context)7 .await.unwrap(); |
621 | 7 | assert_eq!(result.len(), 1); |
622 | 7 | result.remove(0) |
623 | 7 | } |
624 | | |
625 | | #[tokio::test] |
626 | 1 | async fn test_partition_sort() -> Result<()> { |
627 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
628 | 1 | let partitions = 4; |
629 | 1 | let csv = test::scan_partitioned(partitions); |
630 | 1 | let schema = csv.schema(); |
631 | 1 | |
632 | 1 | let sort = vec![PhysicalSortExpr { |
633 | 1 | expr: col("i", &schema).unwrap(), |
634 | 1 | options: SortOptions { |
635 | 1 | descending: true, |
636 | 1 | nulls_first: true, |
637 | 1 | }, |
638 | 1 | }]; |
639 | 1 | |
640 | 1 | let basic = |
641 | 1 | basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await; |
642 | 1 | let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await0 ; |
643 | 1 | |
644 | 1 | let basic = arrow::util::pretty::pretty_format_batches(&[basic]) |
645 | 1 | .unwrap() |
646 | 1 | .to_string(); |
647 | 1 | let partition = arrow::util::pretty::pretty_format_batches(&[partition]) |
648 | 1 | .unwrap() |
649 | 1 | .to_string(); |
650 | 1 | |
651 | 1 | assert_eq!( |
652 | 1 | basic, partition, |
653 | 1 | "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"0 |
654 | 1 | ); |
655 | 1 | |
656 | 1 | Ok(()) |
657 | 1 | } |
658 | | |
659 | | // Split the provided record batch into record batches of at most `batch_size` rows each |
660 | 9 | fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> { |
661 | 9 | let batches = (sorted.num_rows() + batch_size - 1) / batch_size; |
662 | 9 | |
663 | 9 | // Split the sorted RecordBatch into multiple |
664 | 9 | (0..batches) |
665 | 634 | .map(|batch_idx| { |
666 | 634 | let columns = (0..sorted.num_columns()) |
667 | 634 | .map(|column_idx| { |
668 | 634 | let length = |
669 | 634 | batch_size.min(sorted.num_rows() - batch_idx * batch_size); |
670 | 634 | |
671 | 634 | sorted |
672 | 634 | .column(column_idx) |
673 | 634 | .slice(batch_idx * batch_size, length) |
674 | 634 | }) |
675 | 634 | .collect(); |
676 | 634 | |
677 | 634 | RecordBatch::try_new(sorted.schema(), columns).unwrap() |
678 | 634 | }) |
679 | 9 | .collect() |
680 | 9 | } |
681 | | |
682 | 3 | async fn sorted_partitioned_input( |
683 | 3 | sort: Vec<PhysicalSortExpr>, |
684 | 3 | sizes: &[usize], |
685 | 3 | context: Arc<TaskContext>, |
686 | 3 | ) -> Result<Arc<dyn ExecutionPlan>> { |
687 | 3 | let partitions = 4; |
688 | 3 | let csv = test::scan_partitioned(partitions); |
689 | | |
690 | 3 | let sorted = basic_sort(csv, sort, context).await; |
691 | 9 | let split: Vec<_> = sizes.iter().map(3 |x| split_batch(&sorted, *x)).collect(); |
692 | 3 | |
693 | 3 | Ok(Arc::new( |
694 | 3 | MemoryExec::try_new(&split, sorted.schema(), None).unwrap(), |
695 | 3 | )) |
696 | 3 | } |
697 | | |
698 | | #[tokio::test] |
699 | 1 | async fn test_partition_sort_streaming_input() -> Result<()> { |
700 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
701 | 1 | let schema = make_partition(11).schema(); |
702 | 1 | let sort = vec![PhysicalSortExpr { |
703 | 1 | expr: col("i", &schema).unwrap(), |
704 | 1 | options: Default::default(), |
705 | 1 | }]; |
706 | 1 | |
707 | 1 | let input = |
708 | 1 | sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx)) |
709 | 1 | .await?0 ; |
710 | 1 | let basic = |
711 | 71 | basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx))1 .await; |
712 | 1 | let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await0 ; |
713 | 1 | |
714 | 1 | assert_eq!(basic.num_rows(), 1200); |
715 | 1 | assert_eq!(partition.num_rows(), 1200); |
716 | 1 | |
717 | 1 | let basic = arrow::util::pretty::pretty_format_batches(&[basic]) |
718 | 1 | .unwrap() |
719 | 1 | .to_string(); |
720 | 1 | let partition = arrow::util::pretty::pretty_format_batches(&[partition]) |
721 | 1 | .unwrap() |
722 | 1 | .to_string(); |
723 | 1 | |
724 | 1 | assert_eq!(basic, partition); |
725 | 1 | |
726 | 1 | Ok(()) |
727 | 1 | } |
728 | | |
729 | | #[tokio::test] |
730 | 1 | async fn test_partition_sort_streaming_input_output() -> Result<()> { |
731 | 1 | let schema = make_partition(11).schema(); |
732 | 1 | let sort = vec![PhysicalSortExpr { |
733 | 1 | expr: col("i", &schema).unwrap(), |
734 | 1 | options: Default::default(), |
735 | 1 | }]; |
736 | 1 | |
737 | 1 | // Test streaming with default batch size |
738 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
739 | 1 | let input = |
740 | 1 | sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx)) |
741 | 1 | .await?0 ; |
742 | 51 | let basic1 = basic_sort(Arc::clone(&input), sort.clone(), task_ctx)1 .await; |
743 | 1 | |
744 | 1 | // batch size of 23 |
745 | 1 | let task_ctx = TaskContext::default() |
746 | 1 | .with_session_config(SessionConfig::new().with_batch_size(23)); |
747 | 1 | let task_ctx = Arc::new(task_ctx); |
748 | 1 | |
749 | 1 | let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); |
750 | 1 | let merged = collect(merge, task_ctx).await0 .unwrap(); |
751 | 1 | |
752 | 1 | assert_eq!(merged.len(), 53); |
753 | 1 | |
754 | 1 | assert_eq!(basic.num_rows(), 1200); |
755 | 53 | assert_eq!(merged.iter().map(1 |x| x.num_rows()).sum::<usize>(), 1200)1 ; |
756 | 1 | |
757 | 1 | let basic = arrow::util::pretty::pretty_format_batches(&[basic]) |
758 | 1 | .unwrap() |
759 | 1 | .to_string(); |
760 | 1 | let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice()) |
761 | 1 | .unwrap() |
762 | 1 | .to_string(); |
763 | 1 | |
764 | 1 | assert_eq!(basic, partition); |
765 | 1 | |
766 | 1 | Ok(()) |
767 | 1 | } |
768 | | |
769 | | #[tokio::test] |
770 | 1 | async fn test_nulls() { |
771 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
772 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
773 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
774 | 1 | None, |
775 | 1 | Some("a"), |
776 | 1 | Some("b"), |
777 | 1 | Some("d"), |
778 | 1 | Some("e"), |
779 | 1 | ])); |
780 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![ |
781 | 1 | Some(8), |
782 | 1 | None, |
783 | 1 | Some(6), |
784 | 1 | None, |
785 | 1 | Some(4), |
786 | 1 | ])); |
787 | 1 | let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
788 | 1 | |
789 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); |
790 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ |
791 | 1 | None, |
792 | 1 | Some("b"), |
793 | 1 | Some("g"), |
794 | 1 | Some("h"), |
795 | 1 | Some("i"), |
796 | 1 | ])); |
797 | 1 | let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![ |
798 | 1 | Some(8), |
799 | 1 | None, |
800 | 1 | Some(5), |
801 | 1 | None, |
802 | 1 | Some(4), |
803 | 1 | ])); |
804 | 1 | let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); |
805 | 1 | let schema = b1.schema(); |
806 | 1 | |
807 | 1 | let sort = vec![ |
808 | 1 | PhysicalSortExpr { |
809 | 1 | expr: col("b", &schema).unwrap(), |
810 | 1 | options: SortOptions { |
811 | 1 | descending: false, |
812 | 1 | nulls_first: true, |
813 | 1 | }, |
814 | 1 | }, |
815 | 1 | PhysicalSortExpr { |
816 | 1 | expr: col("c", &schema).unwrap(), |
817 | 1 | options: SortOptions { |
818 | 1 | descending: false, |
819 | 1 | nulls_first: false, |
820 | 1 | }, |
821 | 1 | }, |
822 | 1 | ]; |
823 | 1 | let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); |
824 | 1 | let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); |
825 | 1 | |
826 | 1 | let collected = collect(merge, task_ctx).await0 .unwrap(); |
827 | 1 | assert_eq!(collected.len(), 1); |
828 | 1 | |
829 | 1 | assert_batches_eq!( |
830 | 1 | &[ |
831 | 1 | "+---+---+-------------------------------+", |
832 | 1 | "| a | b | c |", |
833 | 1 | "+---+---+-------------------------------+", |
834 | 1 | "| 1 | | 1970-01-01T00:00:00.000000008 |", |
835 | 1 | "| 1 | | 1970-01-01T00:00:00.000000008 |", |
836 | 1 | "| 2 | a | |", |
837 | 1 | "| 7 | b | 1970-01-01T00:00:00.000000006 |", |
838 | 1 | "| 2 | b | |", |
839 | 1 | "| 9 | d | |", |
840 | 1 | "| 3 | e | 1970-01-01T00:00:00.000000004 |", |
841 | 1 | "| 3 | g | 1970-01-01T00:00:00.000000005 |", |
842 | 1 | "| 4 | h | |", |
843 | 1 | "| 5 | i | 1970-01-01T00:00:00.000000004 |", |
844 | 1 | "+---+---+-------------------------------+", |
845 | 1 | ], |
846 | 1 | collected.as_slice() |
847 | 1 | ); |
848 | 1 | } |
849 | | |
850 | | #[tokio::test] |
851 | 1 | async fn test_sort_merge_single_partition_with_fetch() { |
852 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
853 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
854 | 1 | let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); |
855 | 1 | let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); |
856 | 1 | let schema = batch.schema(); |
857 | 1 | |
858 | 1 | let sort = vec![PhysicalSortExpr { |
859 | 1 | expr: col("b", &schema).unwrap(), |
860 | 1 | options: SortOptions { |
861 | 1 | descending: false, |
862 | 1 | nulls_first: true, |
863 | 1 | }, |
864 | 1 | }]; |
865 | 1 | let exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap(); |
866 | 1 | let merge = Arc::new( |
867 | 1 | SortPreservingMergeExec::new(sort, Arc::new(exec)).with_fetch(Some(2)), |
868 | 1 | ); |
869 | 1 | |
870 | 1 | let collected = collect(merge, task_ctx).await0 .unwrap(); |
871 | 1 | assert_eq!(collected.len(), 1); |
872 | 1 | |
873 | 1 | assert_batches_eq!( |
874 | 1 | &[ |
875 | 1 | "+---+---+", |
876 | 1 | "| a | b |", |
877 | 1 | "+---+---+", |
878 | 1 | "| 1 | a |", |
879 | 1 | "| 2 | b |", |
880 | 1 | "+---+---+", |
881 | 1 | ], |
882 | 1 | collected.as_slice() |
883 | 1 | ); |
884 | 1 | } |
885 | | |
886 | | #[tokio::test] |
887 | 1 | async fn test_sort_merge_single_partition_without_fetch() { |
888 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
889 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); |
890 | 1 | let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); |
891 | 1 | let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); |
892 | 1 | let schema = batch.schema(); |
893 | 1 | |
894 | 1 | let sort = vec![PhysicalSortExpr { |
895 | 1 | expr: col("b", &schema).unwrap(), |
896 | 1 | options: SortOptions { |
897 | 1 | descending: false, |
898 | 1 | nulls_first: true, |
899 | 1 | }, |
900 | 1 | }]; |
901 | 1 | let exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap(); |
902 | 1 | let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); |
903 | 1 | |
904 | 1 | let collected = collect(merge, task_ctx).await0 .unwrap(); |
905 | 1 | assert_eq!(collected.len(), 1); |
906 | 1 | |
907 | 1 | assert_batches_eq!( |
908 | 1 | &[ |
909 | 1 | "+---+---+", |
910 | 1 | "| a | b |", |
911 | 1 | "+---+---+", |
912 | 1 | "| 1 | a |", |
913 | 1 | "| 2 | b |", |
914 | 1 | "| 7 | c |", |
915 | 1 | "| 9 | d |", |
916 | 1 | "| 3 | e |", |
917 | 1 | "+---+---+", |
918 | 1 | ], |
919 | 1 | collected.as_slice() |
920 | 1 | ); |
921 | 1 | } |
922 | | |
923 | | #[tokio::test] |
924 | 1 | async fn test_async() -> Result<()> { |
925 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
926 | 1 | let schema = make_partition(11).schema(); |
927 | 1 | let sort = vec![PhysicalSortExpr { |
928 | 1 | expr: col("i", &schema).unwrap(), |
929 | 1 | options: SortOptions::default(), |
930 | 1 | }]; |
931 | 1 | |
932 | 1 | let batches = |
933 | 1 | sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx)) |
934 | 1 | .await?0 ; |
935 | 1 | |
936 | 1 | let partition_count = batches.output_partitioning().partition_count(); |
937 | 1 | let mut streams = Vec::with_capacity(partition_count); |
938 | 1 | |
939 | 3 | for partition in 0..partition_count1 { |
940 | 3 | let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1); |
941 | 3 | |
942 | 3 | let sender = builder.tx(); |
943 | 3 | |
944 | 3 | let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap(); |
945 | 3 | builder.spawn(async move { |
946 | 275 | while let Some(batch272 ) = stream.next().await0 { |
947 | 272 | sender.send(batch).await130 .unwrap(); |
948 | 272 | // This causes the MergeStream to wait for more input |
949 | 272 | tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; |
950 | 1 | } |
951 | 1 | |
952 | 3 | Ok(()) |
953 | 3 | }); |
954 | 3 | |
955 | 3 | streams.push(builder.build()); |
956 | 3 | } |
957 | 1 | |
958 | 1 | let metrics = ExecutionPlanMetricsSet::new(); |
959 | 1 | let reservation = |
960 | 1 | MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool); |
961 | 1 | |
962 | 1 | let fetch = None; |
963 | 1 | let merge_stream = StreamingMergeBuilder::new() |
964 | 1 | .with_streams(streams) |
965 | 1 | .with_schema(batches.schema()) |
966 | 1 | .with_expressions(sort.as_slice()) |
967 | 1 | .with_metrics(BaselineMetrics::new(&metrics, 0)) |
968 | 1 | .with_batch_size(task_ctx.session_config().batch_size()) |
969 | 1 | .with_fetch(fetch) |
970 | 1 | .with_reservation(reservation) |
971 | 1 | .build()?0 ; |
972 | 1 | |
973 | 135 | let mut merged1 = common::collect(merge_stream)1 .await.unwrap(); |
974 | 1 | |
975 | 1 | assert_eq!(merged.len(), 1); |
976 | 1 | let merged = merged.remove(0); |
977 | 91 | let basic1 = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx))1 .await; |
978 | 1 | |
979 | 1 | let basic = arrow::util::pretty::pretty_format_batches(&[basic]) |
980 | 1 | .unwrap() |
981 | 1 | .to_string(); |
982 | 1 | let partition = arrow::util::pretty::pretty_format_batches(&[merged]) |
983 | 1 | .unwrap() |
984 | 1 | .to_string(); |
985 | 1 | |
986 | 1 | assert_eq!( |
987 | 1 | basic, partition, |
988 | 1 | "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"0 |
989 | 1 | ); |
990 | 1 | |
991 | 1 | Ok(()) |
992 | 1 | } |
993 | | |
994 | | #[tokio::test] |
995 | 1 | async fn test_merge_metrics() { |
996 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
997 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); |
998 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); |
999 | 1 | let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); |
1000 | 1 | |
1001 | 1 | let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20])); |
1002 | 1 | let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")])); |
1003 | 1 | let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); |
1004 | 1 | |
1005 | 1 | let schema = b1.schema(); |
1006 | 1 | let sort = vec![PhysicalSortExpr { |
1007 | 1 | expr: col("b", &schema).unwrap(), |
1008 | 1 | options: Default::default(), |
1009 | 1 | }]; |
1010 | 1 | let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); |
1011 | 1 | let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); |
1012 | 1 | |
1013 | 1 | let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx) |
1014 | 1 | .await0 |
1015 | 1 | .unwrap(); |
1016 | 1 | let expected = [ |
1017 | 1 | "+----+---+", |
1018 | 1 | "| a | b |", |
1019 | 1 | "+----+---+", |
1020 | 1 | "| 1 | a |", |
1021 | 1 | "| 10 | b |", |
1022 | 1 | "| 2 | c |", |
1023 | 1 | "| 20 | d |", |
1024 | 1 | "+----+---+", |
1025 | 1 | ]; |
1026 | 1 | assert_batches_eq!(expected, collected.as_slice()); |
1027 | 1 | |
1028 | 1 | // Now, validate metrics |
1029 | 1 | let metrics = merge.metrics().unwrap(); |
1030 | 1 | |
1031 | 1 | assert_eq!(metrics.output_rows().unwrap(), 4); |
1032 | 1 | assert!(metrics.elapsed_compute().unwrap() > 0); |
1033 | 1 | |
1034 | 1 | let mut saw_start = false; |
1035 | 1 | let mut saw_end = false; |
1036 | 4 | metrics.iter().for_each(|m| match m.value() { |
1037 | 1 | MetricValue::StartTimestamp(ts) => { |
1038 | 1 | saw_start = true; |
1039 | 1 | assert!(nanos_from_timestamp(ts) > 0); |
1040 | 1 | } |
1041 | 1 | MetricValue::EndTimestamp(ts) => { |
1042 | 1 | saw_end = true; |
1043 | 1 | assert!(nanos_from_timestamp(ts) > 0); |
1044 | 1 | } |
1045 | 2 | _ => {} |
1046 | 4 | }); |
1047 | 1 | |
1048 | 1 | assert!(saw_start); |
1049 | 1 | assert!(saw_end); |
1050 | 1 | } |
1051 | | |
1052 | 2 | fn nanos_from_timestamp(ts: &Timestamp) -> i64 { |
1053 | 2 | ts.value().unwrap().timestamp_nanos_opt().unwrap() |
1054 | 2 | } |
1055 | | |
1056 | | #[tokio::test] |
1057 | 1 | async fn test_drop_cancel() -> Result<()> { |
1058 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1059 | 1 | let schema = |
1060 | 1 | Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); |
1061 | 1 | |
1062 | 1 | let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); |
1063 | 1 | let refs = blocking_exec.refs(); |
1064 | 1 | let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( |
1065 | 1 | vec![PhysicalSortExpr { |
1066 | 1 | expr: col("a", &schema)?0 , |
1067 | 1 | options: SortOptions::default(), |
1068 | 1 | }], |
1069 | 1 | blocking_exec, |
1070 | 1 | )); |
1071 | 1 | |
1072 | 1 | let fut = collect(sort_preserving_merge_exec, task_ctx); |
1073 | 1 | let mut fut = fut.boxed(); |
1074 | 1 | |
1075 | 1 | assert_is_pending(&mut fut); |
1076 | 1 | drop(fut); |
1077 | 1 | assert_strong_count_converges_to_zero(refs).await0 ; |
1078 | 1 | |
1079 | 1 | Ok(()) |
1080 | 1 | } |
1081 | | |
1082 | | #[tokio::test] |
1083 | 1 | async fn test_stable_sort() { |
1084 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1085 | 1 | |
1086 | 1 | // Create record batches like: |
1087 | 1 | // batch_number |value |
1088 | 1 | // -------------+------ |
1089 | 1 | // 1 | A |
1090 | 1 | // 1 | B |
1091 | 1 | // |
1092 | 1 | // Ensure that the output is in the same order the batches were fed |
1093 | 1 | let partitions: Vec<Vec<RecordBatch>> = (0..10) |
1094 | 10 | .map(|batch_number| { |
1095 | 10 | let batch_number: Int32Array = |
1096 | 10 | vec![Some(batch_number), Some(batch_number)] |
1097 | 10 | .into_iter() |
1098 | 10 | .collect(); |
1099 | 10 | let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect(); |
1100 | 10 | |
1101 | 10 | let batch = RecordBatch::try_from_iter(vec![ |
1102 | 10 | ("batch_number", Arc::new(batch_number) as ArrayRef), |
1103 | 10 | ("value", Arc::new(value) as ArrayRef), |
1104 | 10 | ]) |
1105 | 10 | .unwrap(); |
1106 | 10 | |
1107 | 10 | vec![batch] |
1108 | 10 | }) |
1109 | 1 | .collect(); |
1110 | 1 | |
1111 | 1 | let schema = partitions[0][0].schema(); |
1112 | 1 | |
1113 | 1 | let sort = vec![PhysicalSortExpr { |
1114 | 1 | expr: col("value", &schema).unwrap(), |
1115 | 1 | options: SortOptions { |
1116 | 1 | descending: false, |
1117 | 1 | nulls_first: true, |
1118 | 1 | }, |
1119 | 1 | }]; |
1120 | 1 | |
1121 | 1 | let exec = MemoryExec::try_new(&partitions, schema, None).unwrap(); |
1122 | 1 | let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); |
1123 | 1 | |
1124 | 1 | let collected = collect(merge, task_ctx).await0 .unwrap(); |
1125 | 1 | assert_eq!(collected.len(), 1); |
1126 | 1 | |
1127 | 1 | // Expect rows with equal "value" to come out in "batch_number" order |
1128 | 1 | // (because that was the order they were fed in, even though only |
1129 | 1 | // "value" is in the sort key) |
1130 | 1 | assert_batches_eq!( |
1131 | 1 | &[ |
1132 | 1 | "+--------------+-------+", |
1133 | 1 | "| batch_number | value |", |
1134 | 1 | "+--------------+-------+", |
1135 | 1 | "| 0 | A |", |
1136 | 1 | "| 1 | A |", |
1137 | 1 | "| 2 | A |", |
1138 | 1 | "| 3 | A |", |
1139 | 1 | "| 4 | A |", |
1140 | 1 | "| 5 | A |", |
1141 | 1 | "| 6 | A |", |
1142 | 1 | "| 7 | A |", |
1143 | 1 | "| 8 | A |", |
1144 | 1 | "| 9 | A |", |
1145 | 1 | "| 0 | B |", |
1146 | 1 | "| 1 | B |", |
1147 | 1 | "| 2 | B |", |
1148 | 1 | "| 3 | B |", |
1149 | 1 | "| 4 | B |", |
1150 | 1 | "| 5 | B |", |
1151 | 1 | "| 6 | B |", |
1152 | 1 | "| 7 | B |", |
1153 | 1 | "| 8 | B |", |
1154 | 1 | "| 9 | B |", |
1155 | 1 | "+--------------+-------+", |
1156 | 1 | ], |
1157 | 1 | collected.as_slice() |
1158 | 1 | ); |
1159 | 1 | } |
1160 | | |
1161 | | /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st |
1162 | | /// partition is exhausted from the start, and if it is polled more than once, it panics. |
1163 | | #[derive(Debug, Clone)] |
1164 | | struct CongestedExec { |
1165 | | schema: Schema, |
1166 | | cache: PlanProperties, |
1167 | | congestion_cleared: Arc<Mutex<bool>>, |
1168 | | } |
1169 | | |
1170 | | impl CongestedExec { |
1171 | 1 | fn compute_properties(schema: SchemaRef) -> PlanProperties { |
1172 | 1 | let columns = schema |
1173 | 1 | .fields |
1174 | 1 | .iter() |
1175 | 1 | .enumerate() |
1176 | 1 | .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>) |
1177 | 1 | .collect::<Vec<_>>(); |
1178 | 1 | let mut eq_properties = EquivalenceProperties::new(schema); |
1179 | 1 | eq_properties.add_new_orderings(vec![columns |
1180 | 1 | .iter() |
1181 | 1 | .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))) |
1182 | 1 | .collect::<Vec<_>>()]); |
1183 | 1 | let mode = ExecutionMode::Unbounded; |
1184 | 1 | PlanProperties::new(eq_properties, Partitioning::Hash(columns, 3), mode) |
1185 | 1 | } |
1186 | | } |
1187 | | |
1188 | | impl ExecutionPlan for CongestedExec { |
1189 | 0 | fn name(&self) -> &'static str { |
1190 | 0 | Self::static_name() |
1191 | 0 | } |
1192 | 0 | fn as_any(&self) -> &dyn Any { |
1193 | 0 | self |
1194 | 0 | } |
1195 | 3 | fn properties(&self) -> &PlanProperties { |
1196 | 3 | &self.cache |
1197 | 3 | } |
1198 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
1199 | 0 | vec![] |
1200 | 0 | } |
1201 | 0 | fn with_new_children( |
1202 | 0 | self: Arc<Self>, |
1203 | 0 | _: Vec<Arc<dyn ExecutionPlan>>, |
1204 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
1205 | 0 | Ok(self) |
1206 | 0 | } |
1207 | 3 | fn execute( |
1208 | 3 | &self, |
1209 | 3 | partition: usize, |
1210 | 3 | _context: Arc<TaskContext>, |
1211 | 3 | ) -> Result<SendableRecordBatchStream> { |
1212 | 3 | Ok(Box::pin(CongestedStream { |
1213 | 3 | schema: Arc::new(self.schema.clone()), |
1214 | 3 | none_polled_once: false, |
1215 | 3 | congestion_cleared: Arc::clone(&self.congestion_cleared), |
1216 | 3 | partition, |
1217 | 3 | })) |
1218 | 3 | } |
1219 | | } |
1220 | | |
1221 | | impl DisplayAs for CongestedExec { |
1222 | 0 | fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { |
1223 | 0 | match t { |
1224 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
1225 | 0 | write!(f, "CongestedExec",).unwrap() |
1226 | 0 | } |
1227 | 0 | } |
1228 | 0 | Ok(()) |
1229 | 0 | } |
1230 | | } |
1231 | | |
1232 | | /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st |
1233 | | /// partition is exhausted from the start, and if it is polled more than once, it panics. |
1234 | | #[derive(Debug)] |
1235 | | pub struct CongestedStream { |
1236 | | schema: SchemaRef, |
1237 | | none_polled_once: bool, |
1238 | | congestion_cleared: Arc<Mutex<bool>>, |
1239 | | partition: usize, |
1240 | | } |
1241 | | |
1242 | | impl Stream for CongestedStream { |
1243 | | type Item = Result<RecordBatch>; |
1244 | 4 | fn poll_next( |
1245 | 4 | mut self: Pin<&mut Self>, |
1246 | 4 | _cx: &mut Context<'_>, |
1247 | 4 | ) -> Poll<Option<Self::Item>> { |
1248 | 4 | match self.partition { |
1249 | | 0 => { |
1250 | 1 | if self.none_polled_once { |
1251 | 0 | panic!("Exhausted stream is polled more than one") |
1252 | | } else { |
1253 | 1 | self.none_polled_once = true; |
1254 | 1 | Poll::Ready(None) |
1255 | | } |
1256 | | } |
1257 | | 1 => { |
1258 | 2 | let cleared = self.congestion_cleared.lock().unwrap(); |
1259 | 2 | if *cleared { |
1260 | 1 | Poll::Ready(None) |
1261 | | } else { |
1262 | 1 | Poll::Pending |
1263 | | } |
1264 | | } |
1265 | | 2 => { |
1266 | 1 | let mut cleared = self.congestion_cleared.lock().unwrap(); |
1267 | 1 | *cleared = true; |
1268 | 1 | Poll::Ready(None) |
1269 | | } |
1270 | 0 | _ => unreachable!(), |
1271 | | } |
1272 | 4 | } |
1273 | | } |
1274 | | |
1275 | | impl RecordBatchStream for CongestedStream { |
1276 | 0 | fn schema(&self) -> SchemaRef { |
1277 | 0 | Arc::clone(&self.schema) |
1278 | 0 | } |
1279 | | } |
1280 | | |
1281 | | #[tokio::test] |
1282 | 1 | async fn test_spm_congestion() -> Result<()> { |
1283 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1284 | 1 | let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]); |
1285 | 1 | let source = CongestedExec { |
1286 | 1 | schema: schema.clone(), |
1287 | 1 | cache: CongestedExec::compute_properties(Arc::new(schema.clone())), |
1288 | 1 | congestion_cleared: Arc::new(Mutex::new(false)), |
1289 | 1 | }; |
1290 | 1 | let spm = SortPreservingMergeExec::new( |
1291 | 1 | vec![PhysicalSortExpr::new_default(Arc::new(Column::new( |
1292 | 1 | "c1", 0, |
1293 | 1 | )))], |
1294 | 1 | Arc::new(source), |
1295 | 1 | ); |
1296 | 1 | let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx)); |
1297 | 1 | |
1298 | 1 | let result = timeout(Duration::from_secs(3), spm_task.join()).await; |
1299 | 1 | match result { |
1300 | 1 | Ok(Ok(Ok(_batches))) => Ok(()), |
1301 | 1 | Ok(Ok(Err(e))) => Err(e)0 , |
1302 | 1 | Ok(Err(_)) => Err(DataFusionError::Execution( |
1303 | 0 | "SortPreservingMerge task panicked or was cancelled".to_string(), |
1304 | 0 | )), |
1305 | 1 | Err(_) => Err(DataFusionError::Execution( |
1306 | 0 | "SortPreservingMerge caused a deadlock".to_string(), |
1307 | 0 | )), |
1308 | 1 | } |
1309 | 1 | } |
1310 | | } |