/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/limit.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines the LIMIT plan |
19 | | |
20 | | use std::any::Any; |
21 | | use std::pin::Pin; |
22 | | use std::sync::Arc; |
23 | | use std::task::{Context, Poll}; |
24 | | |
25 | | use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; |
26 | | use super::{ |
27 | | DisplayAs, ExecutionMode, ExecutionPlanProperties, PlanProperties, RecordBatchStream, |
28 | | SendableRecordBatchStream, Statistics, |
29 | | }; |
30 | | use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning}; |
31 | | |
32 | | use arrow::datatypes::SchemaRef; |
33 | | use arrow::record_batch::RecordBatch; |
34 | | use datafusion_common::{internal_err, Result}; |
35 | | use datafusion_execution::TaskContext; |
36 | | |
37 | | use futures::stream::{Stream, StreamExt}; |
38 | | use log::trace; |
39 | | |
40 | | /// Limit execution plan |
41 | | #[derive(Debug)] |
42 | | pub struct GlobalLimitExec { |
43 | | /// Input execution plan |
44 | | input: Arc<dyn ExecutionPlan>, |
45 | | /// Number of rows to skip before fetch |
46 | | skip: usize, |
47 | | /// Maximum number of rows to fetch, |
48 | | /// `None` means fetching all rows |
49 | | fetch: Option<usize>, |
50 | | /// Execution metrics |
51 | | metrics: ExecutionPlanMetricsSet, |
52 | | cache: PlanProperties, |
53 | | } |
54 | | |
55 | | impl GlobalLimitExec { |
56 | | /// Create a new GlobalLimitExec |
57 | 24 | pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self { |
58 | 24 | let cache = Self::compute_properties(&input); |
59 | 24 | GlobalLimitExec { |
60 | 24 | input, |
61 | 24 | skip, |
62 | 24 | fetch, |
63 | 24 | metrics: ExecutionPlanMetricsSet::new(), |
64 | 24 | cache, |
65 | 24 | } |
66 | 24 | } |
67 | | |
68 | | /// Input execution plan |
69 | 0 | pub fn input(&self) -> &Arc<dyn ExecutionPlan> { |
70 | 0 | &self.input |
71 | 0 | } |
72 | | |
73 | | /// Number of rows to skip before fetch |
74 | 0 | pub fn skip(&self) -> usize { |
75 | 0 | self.skip |
76 | 0 | } |
77 | | |
78 | | /// Maximum number of rows to fetch |
79 | 0 | pub fn fetch(&self) -> Option<usize> { |
80 | 0 | self.fetch |
81 | 0 | } |
82 | | |
83 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
84 | 24 | fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties { |
85 | 24 | PlanProperties::new( |
86 | 24 | input.equivalence_properties().clone(), // Equivalence Properties |
87 | 24 | Partitioning::UnknownPartitioning(1), // Output Partitioning |
88 | 24 | ExecutionMode::Bounded, // Execution Mode |
89 | 24 | ) |
90 | 24 | } |
91 | | } |
92 | | |
93 | | impl DisplayAs for GlobalLimitExec { |
94 | 0 | fn fmt_as( |
95 | 0 | &self, |
96 | 0 | t: DisplayFormatType, |
97 | 0 | f: &mut std::fmt::Formatter, |
98 | 0 | ) -> std::fmt::Result { |
99 | 0 | match t { |
100 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
101 | 0 | write!( |
102 | 0 | f, |
103 | 0 | "GlobalLimitExec: skip={}, fetch={}", |
104 | 0 | self.skip, |
105 | 0 | self.fetch.map_or("None".to_string(), |x| x.to_string()) |
106 | 0 | ) |
107 | 0 | } |
108 | 0 | } |
109 | 0 | } |
110 | | } |
111 | | |
112 | | impl ExecutionPlan for GlobalLimitExec { |
113 | 0 | fn name(&self) -> &'static str { |
114 | 0 | "GlobalLimitExec" |
115 | 0 | } |
116 | | |
117 | | /// Return a reference to Any that can be used for downcasting |
118 | 0 | fn as_any(&self) -> &dyn Any { |
119 | 0 | self |
120 | 0 | } |
121 | | |
122 | 16 | fn properties(&self) -> &PlanProperties { |
123 | 16 | &self.cache |
124 | 16 | } |
125 | | |
126 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
127 | 0 | vec![&self.input] |
128 | 0 | } |
129 | | |
130 | 0 | fn required_input_distribution(&self) -> Vec<Distribution> { |
131 | 0 | vec![Distribution::SinglePartition] |
132 | 0 | } |
133 | | |
134 | 0 | fn maintains_input_order(&self) -> Vec<bool> { |
135 | 0 | vec![true] |
136 | 0 | } |
137 | | |
138 | 0 | fn benefits_from_input_partitioning(&self) -> Vec<bool> { |
139 | 0 | vec![false] |
140 | 0 | } |
141 | | |
142 | 0 | fn with_new_children( |
143 | 0 | self: Arc<Self>, |
144 | 0 | children: Vec<Arc<dyn ExecutionPlan>>, |
145 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
146 | 0 | Ok(Arc::new(GlobalLimitExec::new( |
147 | 0 | Arc::clone(&children[0]), |
148 | 0 | self.skip, |
149 | 0 | self.fetch, |
150 | 0 | ))) |
151 | 0 | } |
152 | | |
153 | 8 | fn execute( |
154 | 8 | &self, |
155 | 8 | partition: usize, |
156 | 8 | context: Arc<TaskContext>, |
157 | 8 | ) -> Result<SendableRecordBatchStream> { |
158 | 8 | trace!( |
159 | 0 | "Start GlobalLimitExec::execute for partition: {}", |
160 | | partition |
161 | | ); |
162 | | // GlobalLimitExec has a single output partition |
163 | 8 | if 0 != partition { |
164 | 0 | return internal_err!("GlobalLimitExec invalid partition {partition}"); |
165 | 8 | } |
166 | 8 | |
167 | 8 | // GlobalLimitExec requires a single input partition |
168 | 8 | if 1 != self.input.output_partitioning().partition_count() { |
169 | 0 | return internal_err!("GlobalLimitExec requires a single input partition"); |
170 | 8 | } |
171 | 8 | |
172 | 8 | let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); |
173 | 8 | let stream = self.input.execute(0, context)?0 ; |
174 | 8 | Ok(Box::pin(LimitStream::new( |
175 | 8 | stream, |
176 | 8 | self.skip, |
177 | 8 | self.fetch, |
178 | 8 | baseline_metrics, |
179 | 8 | ))) |
180 | 8 | } |
181 | | |
182 | 0 | fn metrics(&self) -> Option<MetricsSet> { |
183 | 0 | Some(self.metrics.clone_inner()) |
184 | 0 | } |
185 | | |
186 | 16 | fn statistics(&self) -> Result<Statistics> { |
187 | 16 | Statistics::with_fetch( |
188 | 16 | self.input.statistics()?0 , |
189 | 16 | self.schema(), |
190 | 16 | self.fetch, |
191 | 16 | self.skip, |
192 | | 1, |
193 | | ) |
194 | 16 | } |
195 | | |
196 | 0 | fn fetch(&self) -> Option<usize> { |
197 | 0 | self.fetch |
198 | 0 | } |
199 | | |
200 | 0 | fn supports_limit_pushdown(&self) -> bool { |
201 | 0 | true |
202 | 0 | } |
203 | | } |
204 | | |
205 | | /// LocalLimitExec applies a limit to a single partition |
206 | | #[derive(Debug)] |
207 | | pub struct LocalLimitExec { |
208 | | /// Input execution plan |
209 | | input: Arc<dyn ExecutionPlan>, |
210 | | /// Maximum number of rows to return |
211 | | fetch: usize, |
212 | | /// Execution metrics |
213 | | metrics: ExecutionPlanMetricsSet, |
214 | | cache: PlanProperties, |
215 | | } |
216 | | |
217 | | impl LocalLimitExec { |
218 | | /// Create a new LocalLimitExec partition |
219 | 1 | pub fn new(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Self { |
220 | 1 | let cache = Self::compute_properties(&input); |
221 | 1 | Self { |
222 | 1 | input, |
223 | 1 | fetch, |
224 | 1 | metrics: ExecutionPlanMetricsSet::new(), |
225 | 1 | cache, |
226 | 1 | } |
227 | 1 | } |
228 | | |
229 | | /// Input execution plan |
230 | 0 | pub fn input(&self) -> &Arc<dyn ExecutionPlan> { |
231 | 0 | &self.input |
232 | 0 | } |
233 | | |
234 | | /// Maximum number of rows to fetch |
235 | 0 | pub fn fetch(&self) -> usize { |
236 | 0 | self.fetch |
237 | 0 | } |
238 | | |
239 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
240 | 1 | fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties { |
241 | 1 | PlanProperties::new( |
242 | 1 | input.equivalence_properties().clone(), // Equivalence Properties |
243 | 1 | input.output_partitioning().clone(), // Output Partitioning |
244 | 1 | ExecutionMode::Bounded, // Execution Mode |
245 | 1 | ) |
246 | 1 | } |
247 | | } |
248 | | |
249 | | impl DisplayAs for LocalLimitExec { |
250 | 0 | fn fmt_as( |
251 | 0 | &self, |
252 | 0 | t: DisplayFormatType, |
253 | 0 | f: &mut std::fmt::Formatter, |
254 | 0 | ) -> std::fmt::Result { |
255 | 0 | match t { |
256 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
257 | 0 | write!(f, "LocalLimitExec: fetch={}", self.fetch) |
258 | 0 | } |
259 | 0 | } |
260 | 0 | } |
261 | | } |
262 | | |
263 | | impl ExecutionPlan for LocalLimitExec { |
264 | 0 | fn name(&self) -> &'static str { |
265 | 0 | "LocalLimitExec" |
266 | 0 | } |
267 | | |
268 | | /// Return a reference to Any that can be used for downcasting |
269 | 0 | fn as_any(&self) -> &dyn Any { |
270 | 0 | self |
271 | 0 | } |
272 | | |
273 | 1 | fn properties(&self) -> &PlanProperties { |
274 | 1 | &self.cache |
275 | 1 | } |
276 | | |
277 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
278 | 0 | vec![&self.input] |
279 | 0 | } |
280 | | |
281 | 0 | fn benefits_from_input_partitioning(&self) -> Vec<bool> { |
282 | 0 | vec![false] |
283 | 0 | } |
284 | | |
285 | 0 | fn maintains_input_order(&self) -> Vec<bool> { |
286 | 0 | vec![true] |
287 | 0 | } |
288 | | |
289 | 0 | fn with_new_children( |
290 | 0 | self: Arc<Self>, |
291 | 0 | children: Vec<Arc<dyn ExecutionPlan>>, |
292 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
293 | 0 | match children.len() { |
294 | 0 | 1 => Ok(Arc::new(LocalLimitExec::new( |
295 | 0 | Arc::clone(&children[0]), |
296 | 0 | self.fetch, |
297 | 0 | ))), |
298 | 0 | _ => internal_err!("LocalLimitExec wrong number of children"), |
299 | | } |
300 | 0 | } |
301 | | |
302 | 0 | fn execute( |
303 | 0 | &self, |
304 | 0 | partition: usize, |
305 | 0 | context: Arc<TaskContext>, |
306 | 0 | ) -> Result<SendableRecordBatchStream> { |
307 | 0 | trace!("Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); |
308 | 0 | let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); |
309 | 0 | let stream = self.input.execute(partition, context)?; |
310 | 0 | Ok(Box::pin(LimitStream::new( |
311 | 0 | stream, |
312 | 0 | 0, |
313 | 0 | Some(self.fetch), |
314 | 0 | baseline_metrics, |
315 | 0 | ))) |
316 | 0 | } |
317 | | |
318 | 0 | fn metrics(&self) -> Option<MetricsSet> { |
319 | 0 | Some(self.metrics.clone_inner()) |
320 | 0 | } |
321 | | |
322 | 1 | fn statistics(&self) -> Result<Statistics> { |
323 | 1 | Statistics::with_fetch( |
324 | 1 | self.input.statistics()?0 , |
325 | 1 | self.schema(), |
326 | 1 | Some(self.fetch), |
327 | | 0, |
328 | | 1, |
329 | | ) |
330 | 1 | } |
331 | | |
332 | 0 | fn fetch(&self) -> Option<usize> { |
333 | 0 | Some(self.fetch) |
334 | 0 | } |
335 | | |
336 | 0 | fn supports_limit_pushdown(&self) -> bool { |
337 | 0 | true |
338 | 0 | } |
339 | | } |
340 | | |
341 | | /// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows. |
342 | | pub struct LimitStream { |
343 | | /// The remaining number of rows to skip |
344 | | skip: usize, |
345 | | /// The remaining number of rows to produce |
346 | | fetch: usize, |
347 | | /// The input to read from. This is set to None once the limit is |
348 | | /// reached to enable early termination |
349 | | input: Option<SendableRecordBatchStream>, |
350 | | /// Copy of the input schema |
351 | | schema: SchemaRef, |
352 | | /// Execution time metrics |
353 | | baseline_metrics: BaselineMetrics, |
354 | | } |
355 | | |
356 | | impl LimitStream { |
357 | 14 | pub fn new( |
358 | 14 | input: SendableRecordBatchStream, |
359 | 14 | skip: usize, |
360 | 14 | fetch: Option<usize>, |
361 | 14 | baseline_metrics: BaselineMetrics, |
362 | 14 | ) -> Self { |
363 | 14 | let schema = input.schema(); |
364 | 14 | Self { |
365 | 14 | skip, |
366 | 14 | fetch: fetch.unwrap_or(usize::MAX), |
367 | 14 | input: Some(input), |
368 | 14 | schema, |
369 | 14 | baseline_metrics, |
370 | 14 | } |
371 | 14 | } |
372 | | |
373 | 10 | fn poll_and_skip( |
374 | 10 | &mut self, |
375 | 10 | cx: &mut Context<'_>, |
376 | 10 | ) -> Poll<Option<Result<RecordBatch>>> { |
377 | 10 | let input = self.input.as_mut().unwrap(); |
378 | 22 | loop { |
379 | 22 | let poll = input.poll_next_unpin(cx); |
380 | 22 | let poll = poll.map_ok(|batch| { |
381 | 14 | if batch.num_rows() <= self.skip { |
382 | 12 | self.skip -= batch.num_rows(); |
383 | 12 | RecordBatch::new_empty(input.schema()) |
384 | | } else { |
385 | 2 | let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip); |
386 | 2 | self.skip = 0; |
387 | 2 | new_batch |
388 | | } |
389 | 22 | }14 ); |
390 | | |
391 | 14 | match &poll { |
392 | 14 | Poll::Ready(Some(Ok(batch))) => { |
393 | 14 | if batch.num_rows() > 0 { |
394 | 2 | break poll; |
395 | 12 | } else { |
396 | 12 | // continue to poll input stream |
397 | 12 | } |
398 | | } |
399 | 0 | Poll::Ready(Some(Err(_e))) => break poll, |
400 | 3 | Poll::Ready(None) => break poll, |
401 | 5 | Poll::Pending => break poll, |
402 | | } |
403 | | } |
404 | 10 | } |
405 | | |
406 | | /// fetches from the batch |
407 | 22 | fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> { |
408 | 22 | // records time on drop |
409 | 22 | let _timer = self.baseline_metrics.elapsed_compute().timer(); |
410 | 22 | if self.fetch == 0 { |
411 | 0 | self.input = None; // clear input so it can be dropped early |
412 | 0 | None |
413 | 22 | } else if batch.num_rows() < self.fetch { |
414 | | // |
415 | 13 | self.fetch -= batch.num_rows(); |
416 | 13 | Some(batch) |
417 | 9 | } else if batch.num_rows() >= self.fetch { |
418 | 9 | let batch_rows = self.fetch; |
419 | 9 | self.fetch = 0; |
420 | 9 | self.input = None; // clear input so it can be dropped early |
421 | 9 | |
422 | 9 | // It is guaranteed that batch_rows is <= batch.num_rows |
423 | 9 | Some(batch.slice(0, batch_rows)) |
424 | | } else { |
425 | 0 | unreachable!() |
426 | | } |
427 | 22 | } |
428 | | } |
429 | | |
430 | | impl Stream for LimitStream { |
431 | | type Item = Result<RecordBatch>; |
432 | | |
433 | 44 | fn poll_next( |
434 | 44 | mut self: Pin<&mut Self>, |
435 | 44 | cx: &mut Context<'_>, |
436 | 44 | ) -> Poll<Option<Self::Item>> { |
437 | 44 | let fetch_started = self.skip == 0; |
438 | 44 | let poll = match &mut self.input { |
439 | 35 | Some(input) => { |
440 | 35 | let poll = if fetch_started { |
441 | 25 | input.poll_next_unpin(cx) |
442 | | } else { |
443 | 10 | self.poll_and_skip(cx) |
444 | | }; |
445 | | |
446 | 35 | poll.map(|x| m27 atch x22 { |
447 | 22 | Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(), |
448 | 5 | other => other, |
449 | 35 | }27 ) |
450 | | } |
451 | | // input has been cleared |
452 | 9 | None => Poll::Ready(None), |
453 | | }; |
454 | | |
455 | 44 | self.baseline_metrics.record_poll(poll) |
456 | 44 | } |
457 | | } |
458 | | |
459 | | impl RecordBatchStream for LimitStream { |
460 | | /// Get the schema |
461 | 0 | fn schema(&self) -> SchemaRef { |
462 | 0 | Arc::clone(&self.schema) |
463 | 0 | } |
464 | | } |
465 | | |
466 | | #[cfg(test)] |
467 | | mod tests { |
468 | | use super::*; |
469 | | use crate::coalesce_partitions::CoalescePartitionsExec; |
470 | | use crate::common::collect; |
471 | | use crate::{common, test}; |
472 | | |
473 | | use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; |
474 | | use arrow_array::RecordBatchOptions; |
475 | | use arrow_schema::Schema; |
476 | | use datafusion_common::stats::Precision; |
477 | | use datafusion_physical_expr::expressions::col; |
478 | | use datafusion_physical_expr::PhysicalExpr; |
479 | | |
480 | | #[tokio::test] |
481 | 1 | async fn limit() -> Result<()> { |
482 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
483 | 1 | |
484 | 1 | let num_partitions = 4; |
485 | 1 | let csv = test::scan_partitioned(num_partitions); |
486 | 1 | |
487 | 1 | // input should have 4 partitions |
488 | 1 | assert_eq!(csv.output_partitioning().partition_count(), num_partitions); |
489 | 1 | |
490 | 1 | let limit = |
491 | 1 | GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7)); |
492 | 1 | |
493 | 1 | // the result should contain 4 batches (one per input partition) |
494 | 1 | let iter = limit.execute(0, task_ctx)?0 ; |
495 | 1 | let batches = common::collect(iter).await?0 ; |
496 | 1 | |
497 | 1 | // there should be a total of 100 rows |
498 | 1 | let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); |
499 | 1 | assert_eq!(row_count, 7); |
500 | 1 | |
501 | 1 | Ok(()) |
502 | 1 | } |
503 | | |
504 | | #[tokio::test] |
505 | 1 | async fn limit_early_shutdown() -> Result<()> { |
506 | 1 | let batches = vec![ |
507 | 1 | test::make_partition(5), |
508 | 1 | test::make_partition(10), |
509 | 1 | test::make_partition(15), |
510 | 1 | test::make_partition(20), |
511 | 1 | test::make_partition(25), |
512 | 1 | ]; |
513 | 1 | let input = test::exec::TestStream::new(batches); |
514 | 1 | |
515 | 1 | let index = input.index(); |
516 | 1 | assert_eq!(index.value(), 0); |
517 | 1 | |
518 | 1 | // limit of six needs to consume the entire first record batch |
519 | 1 | // (5 rows) and 1 row from the second (1 row) |
520 | 1 | let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); |
521 | 1 | let limit_stream = |
522 | 1 | LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics); |
523 | 1 | assert_eq!(index.value(), 0); |
524 | 1 | |
525 | 1 | let results = collect(Box::pin(limit_stream)).await0 .unwrap(); |
526 | 2 | let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); |
527 | 1 | // Only 6 rows should have been produced |
528 | 1 | assert_eq!(num_rows, 6); |
529 | 1 | |
530 | 1 | // Only the first two batches should be consumed |
531 | 1 | assert_eq!(index.value(), 2); |
532 | 1 | |
533 | 1 | Ok(()) |
534 | 1 | } |
535 | | |
536 | | #[tokio::test] |
537 | 1 | async fn limit_equals_batch_size() -> Result<()> { |
538 | 1 | let batches = vec![ |
539 | 1 | test::make_partition(6), |
540 | 1 | test::make_partition(6), |
541 | 1 | test::make_partition(6), |
542 | 1 | ]; |
543 | 1 | let input = test::exec::TestStream::new(batches); |
544 | 1 | |
545 | 1 | let index = input.index(); |
546 | 1 | assert_eq!(index.value(), 0); |
547 | 1 | |
548 | 1 | // limit of six needs to consume the entire first record batch |
549 | 1 | // (6 rows) and stop immediately |
550 | 1 | let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); |
551 | 1 | let limit_stream = |
552 | 1 | LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics); |
553 | 1 | assert_eq!(index.value(), 0); |
554 | 1 | |
555 | 1 | let results = collect(Box::pin(limit_stream)).await0 .unwrap(); |
556 | 1 | let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); |
557 | 1 | // Only 6 rows should have been produced |
558 | 1 | assert_eq!(num_rows, 6); |
559 | 1 | |
560 | 1 | // Only the first batch should be consumed |
561 | 1 | assert_eq!(index.value(), 1); |
562 | 1 | |
563 | 1 | Ok(()) |
564 | 1 | } |
565 | | |
566 | | #[tokio::test] |
567 | 1 | async fn limit_no_column() -> Result<()> { |
568 | 1 | let batches = vec![ |
569 | 1 | make_batch_no_column(6), |
570 | 1 | make_batch_no_column(6), |
571 | 1 | make_batch_no_column(6), |
572 | 1 | ]; |
573 | 1 | let input = test::exec::TestStream::new(batches); |
574 | 1 | |
575 | 1 | let index = input.index(); |
576 | 1 | assert_eq!(index.value(), 0); |
577 | 1 | |
578 | 1 | // limit of six needs to consume the entire first record batch |
579 | 1 | // (6 rows) and stop immediately |
580 | 1 | let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); |
581 | 1 | let limit_stream = |
582 | 1 | LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics); |
583 | 1 | assert_eq!(index.value(), 0); |
584 | 1 | |
585 | 1 | let results = collect(Box::pin(limit_stream)).await0 .unwrap(); |
586 | 1 | let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); |
587 | 1 | // Only 6 rows should have been produced |
588 | 1 | assert_eq!(num_rows, 6); |
589 | 1 | |
590 | 1 | // Only the first batch should be consumed |
591 | 1 | assert_eq!(index.value(), 1); |
592 | 1 | |
593 | 1 | Ok(()) |
594 | 1 | } |
595 | | |
596 | | // test cases for "skip" |
597 | 7 | async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> { |
598 | 7 | let task_ctx = Arc::new(TaskContext::default()); |
599 | 7 | |
600 | 7 | // 4 partitions @ 100 rows apiece |
601 | 7 | let num_partitions = 4; |
602 | 7 | let csv = test::scan_partitioned(num_partitions); |
603 | 7 | |
604 | 7 | assert_eq!(csv.output_partitioning().partition_count(), num_partitions); |
605 | | |
606 | 7 | let offset = |
607 | 7 | GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); |
608 | | |
609 | | // the result should contain 4 batches (one per input partition) |
610 | 7 | let iter = offset.execute(0, task_ctx)?0 ; |
611 | 7 | let batches = common::collect(iter).await?0 ; |
612 | 10 | Ok(batches.iter().map(7 |batch| batch.num_rows()).sum())7 |
613 | 7 | } |
614 | | |
615 | | #[tokio::test] |
616 | 1 | async fn skip_none_fetch_none() -> Result<()> { |
617 | 1 | let row_count = skip_and_fetch(0, None).await?0 ; |
618 | 1 | assert_eq!(row_count, 400); |
619 | 1 | Ok(()) |
620 | 1 | } |
621 | | |
622 | | #[tokio::test] |
623 | 1 | async fn skip_none_fetch_50() -> Result<()> { |
624 | 1 | let row_count = skip_and_fetch(0, Some(50)).await?0 ; |
625 | 1 | assert_eq!(row_count, 50); |
626 | 1 | Ok(()) |
627 | 1 | } |
628 | | |
629 | | #[tokio::test] |
630 | 1 | async fn skip_3_fetch_none() -> Result<()> { |
631 | 1 | // there are total of 400 rows, we skipped 3 rows (offset = 3) |
632 | 1 | let row_count = skip_and_fetch(3, None).await?0 ; |
633 | 1 | assert_eq!(row_count, 397); |
634 | 1 | Ok(()) |
635 | 1 | } |
636 | | |
637 | | #[tokio::test] |
638 | 1 | async fn skip_3_fetch_10_stats() -> Result<()> { |
639 | 1 | // there are total of 100 rows, we skipped 3 rows (offset = 3) |
640 | 1 | let row_count = skip_and_fetch(3, Some(10)).await?0 ; |
641 | 1 | assert_eq!(row_count, 10); |
642 | 1 | Ok(()) |
643 | 1 | } |
644 | | |
645 | | #[tokio::test] |
646 | 1 | async fn skip_400_fetch_none() -> Result<()> { |
647 | 1 | let row_count = skip_and_fetch(400, None).await?0 ; |
648 | 1 | assert_eq!(row_count, 0); |
649 | 1 | Ok(()) |
650 | 1 | } |
651 | | |
652 | | #[tokio::test] |
653 | 1 | async fn skip_400_fetch_1() -> Result<()> { |
654 | 1 | // there are a total of 400 rows |
655 | 1 | let row_count = skip_and_fetch(400, Some(1)).await?0 ; |
656 | 1 | assert_eq!(row_count, 0); |
657 | 1 | Ok(()) |
658 | 1 | } |
659 | | |
660 | | #[tokio::test] |
661 | 1 | async fn skip_401_fetch_none() -> Result<()> { |
662 | 1 | // there are total of 400 rows, we skipped 401 rows (offset = 3) |
663 | 1 | let row_count = skip_and_fetch(401, None).await?0 ; |
664 | 1 | assert_eq!(row_count, 0); |
665 | 1 | Ok(()) |
666 | 1 | } |
667 | | |
668 | | #[tokio::test] |
669 | 1 | async fn test_row_number_statistics_for_global_limit() -> Result<()> { |
670 | 1 | let row_count = row_number_statistics_for_global_limit(0, Some(10)).await0 ?0 ; |
671 | 1 | assert_eq!(row_count, Precision::Exact(10)); |
672 | 1 | |
673 | 1 | let row_count = row_number_statistics_for_global_limit(5, Some(10)).await0 ?0 ; |
674 | 1 | assert_eq!(row_count, Precision::Exact(10)); |
675 | 1 | |
676 | 1 | let row_count = row_number_statistics_for_global_limit(400, Some(10)).await0 ?0 ; |
677 | 1 | assert_eq!(row_count, Precision::Exact(0)); |
678 | 1 | |
679 | 1 | let row_count = row_number_statistics_for_global_limit(398, Some(10)).await0 ?0 ; |
680 | 1 | assert_eq!(row_count, Precision::Exact(2)); |
681 | 1 | |
682 | 1 | let row_count = row_number_statistics_for_global_limit(398, Some(1)).await0 ?0 ; |
683 | 1 | assert_eq!(row_count, Precision::Exact(1)); |
684 | 1 | |
685 | 1 | let row_count = row_number_statistics_for_global_limit(398, None).await0 ?0 ; |
686 | 1 | assert_eq!(row_count, Precision::Exact(2)); |
687 | 1 | |
688 | 1 | let row_count = |
689 | 1 | row_number_statistics_for_global_limit(0, Some(usize::MAX)).await0 ?0 ; |
690 | 1 | assert_eq!(row_count, Precision::Exact(400)); |
691 | 1 | |
692 | 1 | let row_count = |
693 | 1 | row_number_statistics_for_global_limit(398, Some(usize::MAX)).await0 ?0 ; |
694 | 1 | assert_eq!(row_count, Precision::Exact(2)); |
695 | 1 | |
696 | 1 | let row_count = |
697 | 1 | row_number_inexact_statistics_for_global_limit(0, Some(10)).await0 ?0 ; |
698 | 1 | assert_eq!(row_count, Precision::Inexact(10)); |
699 | 1 | |
700 | 1 | let row_count = |
701 | 1 | row_number_inexact_statistics_for_global_limit(5, Some(10)).await0 ?0 ; |
702 | 1 | assert_eq!(row_count, Precision::Inexact(10)); |
703 | 1 | |
704 | 1 | let row_count = |
705 | 1 | row_number_inexact_statistics_for_global_limit(400, Some(10)).await0 ?0 ; |
706 | 1 | assert_eq!(row_count, Precision::Exact(0)); |
707 | 1 | |
708 | 1 | let row_count = |
709 | 1 | row_number_inexact_statistics_for_global_limit(398, Some(10)).await0 ?0 ; |
710 | 1 | assert_eq!(row_count, Precision::Inexact(2)); |
711 | 1 | |
712 | 1 | let row_count = |
713 | 1 | row_number_inexact_statistics_for_global_limit(398, Some(1)).await0 ?0 ; |
714 | 1 | assert_eq!(row_count, Precision::Inexact(1)); |
715 | 1 | |
716 | 1 | let row_count = row_number_inexact_statistics_for_global_limit(398, None).await0 ?0 ; |
717 | 1 | assert_eq!(row_count, Precision::Inexact(2)); |
718 | 1 | |
719 | 1 | let row_count = |
720 | 1 | row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await0 ?0 ; |
721 | 1 | assert_eq!(row_count, Precision::Inexact(400)); |
722 | 1 | |
723 | 1 | let row_count = |
724 | 1 | row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await0 ?0 ; |
725 | 1 | assert_eq!(row_count, Precision::Inexact(2)); |
726 | 1 | |
727 | 1 | Ok(()) |
728 | 1 | } |
729 | | |
730 | | #[tokio::test] |
731 | 1 | async fn test_row_number_statistics_for_local_limit() -> Result<()> { |
732 | 1 | let row_count = row_number_statistics_for_local_limit(4, 10).await0 ?0 ; |
733 | 1 | assert_eq!(row_count, Precision::Exact(10)); |
734 | 1 | |
735 | 1 | Ok(()) |
736 | 1 | } |
737 | | |
738 | 8 | async fn row_number_statistics_for_global_limit( |
739 | 8 | skip: usize, |
740 | 8 | fetch: Option<usize>, |
741 | 8 | ) -> Result<Precision<usize>> { |
742 | 8 | let num_partitions = 4; |
743 | 8 | let csv = test::scan_partitioned(num_partitions); |
744 | 8 | |
745 | 8 | assert_eq!(csv.output_partitioning().partition_count(), num_partitions); |
746 | | |
747 | 8 | let offset = |
748 | 8 | GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); |
749 | 8 | |
750 | 8 | Ok(offset.statistics()?0 .num_rows) |
751 | 8 | } |
752 | | |
753 | 8 | pub fn build_group_by( |
754 | 8 | input_schema: &SchemaRef, |
755 | 8 | columns: Vec<String>, |
756 | 8 | ) -> PhysicalGroupBy { |
757 | 8 | let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![]; |
758 | 8 | for column in columns.iter() { |
759 | 8 | group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); |
760 | 8 | } |
761 | 8 | PhysicalGroupBy::new_single(group_by_expr.clone()) |
762 | 8 | } |
763 | | |
764 | 8 | async fn row_number_inexact_statistics_for_global_limit( |
765 | 8 | skip: usize, |
766 | 8 | fetch: Option<usize>, |
767 | 8 | ) -> Result<Precision<usize>> { |
768 | 8 | let num_partitions = 4; |
769 | 8 | let csv = test::scan_partitioned(num_partitions); |
770 | 8 | |
771 | 8 | assert_eq!(csv.output_partitioning().partition_count(), num_partitions); |
772 | | |
773 | | // Adding a "GROUP BY i" changes the input stats from Exact to Inexact. |
774 | 8 | let agg = AggregateExec::try_new( |
775 | 8 | AggregateMode::Final, |
776 | 8 | build_group_by(&csv.schema(), vec!["i".to_string()]), |
777 | 8 | vec![], |
778 | 8 | vec![], |
779 | 8 | Arc::clone(&csv), |
780 | 8 | Arc::clone(&csv.schema()), |
781 | 8 | )?0 ; |
782 | 8 | let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg); |
783 | 8 | |
784 | 8 | let offset = GlobalLimitExec::new( |
785 | 8 | Arc::new(CoalescePartitionsExec::new(agg_exec)), |
786 | 8 | skip, |
787 | 8 | fetch, |
788 | 8 | ); |
789 | 8 | |
790 | 8 | Ok(offset.statistics()?0 .num_rows) |
791 | 8 | } |
792 | | |
793 | 1 | async fn row_number_statistics_for_local_limit( |
794 | 1 | num_partitions: usize, |
795 | 1 | fetch: usize, |
796 | 1 | ) -> Result<Precision<usize>> { |
797 | 1 | let csv = test::scan_partitioned(num_partitions); |
798 | 1 | |
799 | 1 | assert_eq!(csv.output_partitioning().partition_count(), num_partitions); |
800 | | |
801 | 1 | let offset = LocalLimitExec::new(csv, fetch); |
802 | 1 | |
803 | 1 | Ok(offset.statistics()?0 .num_rows) |
804 | 1 | } |
805 | | |
806 | | /// Return a RecordBatch with a single array with row_count sz |
807 | 3 | fn make_batch_no_column(sz: usize) -> RecordBatch { |
808 | 3 | let schema = Arc::new(Schema::empty()); |
809 | 3 | |
810 | 3 | let options = RecordBatchOptions::new().with_row_count(Option::from(sz)); |
811 | 3 | RecordBatch::try_new_with_options(schema, vec![], &options).unwrap() |
812 | 3 | } |
813 | | } |