/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/mod.rs
Line | Count | Source
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Aggregate functionality
19 | | |
20 | | use std::any::Any; |
21 | | use std::sync::Arc; |
22 | | |
23 | | use super::{DisplayAs, ExecutionMode, ExecutionPlanProperties, PlanProperties}; |
24 | | use crate::aggregates::{ |
25 | | no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, |
26 | | topk_stream::GroupedTopKAggregateStream, |
27 | | }; |
28 | | use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; |
29 | | use crate::projection::get_field_metadata; |
30 | | use crate::windows::get_ordered_partition_by_indices; |
31 | | use crate::{ |
32 | | DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, |
33 | | SendableRecordBatchStream, Statistics, |
34 | | }; |
35 | | |
36 | | use arrow::array::ArrayRef; |
37 | | use arrow::datatypes::{Field, Schema, SchemaRef}; |
38 | | use arrow::record_batch::RecordBatch; |
39 | | use datafusion_common::stats::Precision; |
40 | | use datafusion_common::{internal_err, not_impl_err, Result}; |
41 | | use datafusion_execution::TaskContext; |
42 | | use datafusion_expr::Accumulator; |
43 | | use datafusion_physical_expr::{ |
44 | | equivalence::{collapse_lex_req, ProjectionMapping}, |
45 | | expressions::Column, |
46 | | physical_exprs_contains, EquivalenceProperties, LexOrdering, LexRequirement, |
47 | | PhysicalExpr, PhysicalSortRequirement, |
48 | | }; |
49 | | |
50 | | use datafusion_physical_expr::aggregate::AggregateFunctionExpr; |
51 | | use itertools::Itertools; |
52 | | |
53 | | pub mod group_values; |
54 | | mod no_grouping; |
55 | | pub mod order; |
56 | | mod row_hash; |
57 | | mod topk; |
58 | | mod topk_stream; |
59 | | |
60 | | /// Hash aggregate modes |
61 | | /// |
62 | | /// See [`Accumulator::state`] for background information on multi-phase |
63 | | /// aggregation and how these modes are used. |
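 | | ///
 | | /// For example, a typical two-phase `GROUP BY a` plan looks roughly like the
 | | /// sketch below (illustrative only; actual plans depend on the optimizer and
 | | /// the input partitioning):
 | | ///
 | | /// ```text
 | | /// AggregateExec: mode=Final, gby=[a]          <-- merges partial states
 | | ///   CoalescePartitionsExec                    <-- combines the partitions
 | | ///     AggregateExec: mode=Partial, gby=[a]    <-- runs on each partition
 | | ///       ...input...
 | | /// ```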
64 | | #[derive(Debug, Copy, Clone, PartialEq, Eq)] |
65 | | pub enum AggregateMode { |
66 | | /// Partial aggregate that can be applied in parallel across input |
67 | | /// partitions. |
68 | | /// |
69 | | /// This is the first phase of a multi-phase aggregation. |
70 | | Partial, |
71 | | /// Final aggregate that produces a single partition of output by combining |
72 | | /// the output of multiple partial aggregates. |
73 | | /// |
74 | | /// This is the second phase of a multi-phase aggregation. |
75 | | Final, |
76 | | /// Final aggregate that works on pre-partitioned data. |
77 | | /// |
78 | | /// This requires the invariant that all rows with a particular
79 | | /// grouping key are in the same partition, as is the case with
80 | | /// hash repartitioning on the group keys. If a group key appears
81 | | /// in more than one partition, duplicate groups would be produced.
82 | | FinalPartitioned, |
83 | | /// Applies the entire logical aggregation operation in a single operator, |
84 | | /// as opposed to Partial / Final modes which apply the logical aggregation using |
85 | | /// two operators. |
86 | | /// |
87 | | /// This mode requires that the input is a single partition (like Final) |
88 | | Single, |
89 | | /// Applies the entire logical aggregation operation in a single operator, |
90 | | /// as opposed to Partial / Final modes which apply the logical aggregation using |
91 | | /// two operators. |
92 | | /// |
93 | | /// This mode requires that the input is partitioned by group key (like |
94 | | /// FinalPartitioned) |
95 | | SinglePartitioned, |
96 | | } |
97 | | |
98 | | impl AggregateMode { |
99 | | /// Checks whether this aggregation step describes a "first stage" calculation. |
100 | | /// In other words, its input is not another aggregation result and the |
101 | | /// `merge_batch` method will not be called for these modes. |
102 | 59 | pub fn is_first_stage(&self) -> bool { |
103 | 59 | match self { |
104 | | AggregateMode::Partial |
105 | | | AggregateMode::Single |
106 | 34 | | AggregateMode::SinglePartitioned => true, |
107 | 25 | AggregateMode::Final | AggregateMode::FinalPartitioned => false, |
108 | | } |
109 | 59 | } |
110 | | } |
111 | | |
112 | | /// Represents the `GROUP BY` clause in the plan (including the more general GROUPING SET)
113 | | /// In the case of a simple `GROUP BY a, b` clause, this will contain the expressions [a, b]
114 | | /// and a single group [false, false].
115 | | /// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression |
116 | | /// into multiple groups, using null expressions to align each group. |
117 | | /// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should |
118 | | /// create a `PhysicalGroupBy` like |
119 | | /// ```text |
120 | | /// PhysicalGroupBy { |
121 | | /// expr: [(col(a), a), (col(b), b)], |
122 | | /// null_expr: [(NULL, a), (NULL, b)], |
123 | | /// groups: [ |
124 | | /// [false, false], // (a,b) |
125 | | /// [false, true], // (a) <=> (a, NULL) |
126 | | /// [true, false] // (b) <=> (NULL, b) |
127 | | /// ] |
128 | | /// } |
129 | | /// ``` |
130 | | #[derive(Clone, Debug, Default)] |
131 | | pub struct PhysicalGroupBy { |
132 | | /// Distinct (Physical Expr, Alias) in the grouping set |
133 | | expr: Vec<(Arc<dyn PhysicalExpr>, String)>, |
134 | | /// Corresponding NULL expressions for expr |
135 | | null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>, |
136 | | /// Null mask for each group in this grouping set. Each group is |
137 | | /// composed of either one of the group expressions in expr or a null |
138 | | /// expression in null_expr. If `groups[i][j]` is true, then the |
139 | | /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`. |
140 | | groups: Vec<Vec<bool>>, |
141 | | } |
142 | | |
143 | | impl PhysicalGroupBy { |
144 | | /// Create a new `PhysicalGroupBy` |
145 | 1 | pub fn new( |
146 | 1 | expr: Vec<(Arc<dyn PhysicalExpr>, String)>, |
147 | 1 | null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>, |
148 | 1 | groups: Vec<Vec<bool>>, |
149 | 1 | ) -> Self { |
150 | 1 | Self { |
151 | 1 | expr, |
152 | 1 | null_expr, |
153 | 1 | groups, |
154 | 1 | } |
155 | 1 | } |
156 | | |
157 | | /// Create a GROUPING SET with only a single group. This is the "standard" |
158 | | /// case when building a plan from an expression such as `GROUP BY a,b,c` |
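 | | ///
 | | /// For example, a minimal sketch (assuming `schema` is a `SchemaRef` with
 | | /// columns "a" and "b", and `col` is the physical column helper):
 | | ///
 | | /// ```text
 | | /// let group_by = PhysicalGroupBy::new_single(vec![
 | | ///     (col("a", &schema)?, "a".to_string()),
 | | ///     (col("b", &schema)?, "b".to_string()),
 | | /// ]);
 | | /// ```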
159 | 100 | pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self { |
160 | 100 | let num_exprs = expr.len(); |
161 | 100 | Self { |
162 | 100 | expr, |
163 | 100 | null_expr: vec![], |
164 | 100 | groups: vec![vec![false; num_exprs]], |
165 | 100 | } |
166 | 100 | } |
167 | | |
168 | | /// Returns, for each GROUP BY expression, whether it can be NULL in the output
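 | | ///
 | | /// For example, with `groups = [[false, true], [true, false]]` (i.e.
 | | /// `GROUPING SETS ((a), (b))`), both expressions are replaced by NULL in
 | | /// some group, so this returns `vec![true, true]`.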
169 | 50 | pub fn exprs_nullable(&self) -> Vec<bool> { |
170 | 50 | let mut exprs_nullable = vec![false; self.expr.len()]; |
171 | 59 | for group in self.groups.iter() {
172 | 83 | group.iter().enumerate().for_each(|(index, is_null)| {
173 | 83 | if *is_null {
174 | 15 | exprs_nullable[index] = true;
175 | 68 | }
176 | 83 | })
177 | | } |
178 | 50 | exprs_nullable |
179 | 50 | } |
180 | | |
181 | | /// Returns the group expressions |
182 | 4 | pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] { |
183 | 4 | &self.expr |
184 | 4 | } |
185 | | |
186 | | /// Returns the null expressions |
187 | 0 | pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] { |
188 | 0 | &self.null_expr |
189 | 0 | } |
190 | | |
191 | | /// Returns the group null masks |
192 | 0 | pub fn groups(&self) -> &[Vec<bool>] { |
193 | 0 | &self.groups |
194 | 0 | } |
195 | | |
196 | | /// Returns true if this `PhysicalGroupBy` has no group expressions |
197 | 0 | pub fn is_empty(&self) -> bool { |
198 | 0 | self.expr.is_empty() |
199 | 0 | } |
200 | | |
201 | | /// Check whether grouping set is single group |
202 | 62 | pub fn is_single(&self) -> bool { |
203 | 62 | self.null_expr.is_empty() |
204 | 62 | } |
205 | | |
206 | | /// Calculate GROUP BY expressions according to input schema. |
207 | 59 | pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> { |
208 | 59 | self.expr |
209 | 59 | .iter() |
210 | 59 | .map(|(expr, _alias)| Arc::clone(expr))
211 | 59 | .collect() |
212 | 59 | } |
213 | | |
214 | | /// Return grouping expressions as they occur in the output schema. |
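 | | ///
 | | /// For example (illustrative), for `GROUP BY a + 1 AS x`, [`Self::input_exprs`]
 | | /// returns `[a + 1]` (evaluated against the input schema), while this method
 | | /// returns `[Column("x", 0)]` (a reference into the output schema).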
215 | 70 | pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> { |
216 | 70 | self.expr |
217 | 70 | .iter() |
218 | 70 | .enumerate() |
219 | 84 | .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _) |
220 | 70 | .collect() |
221 | 70 | } |
222 | | } |
223 | | |
224 | | impl PartialEq for PhysicalGroupBy { |
225 | 0 | fn eq(&self, other: &PhysicalGroupBy) -> bool { |
226 | 0 | self.expr.len() == other.expr.len() |
227 | 0 | && self |
228 | 0 | .expr |
229 | 0 | .iter() |
230 | 0 | .zip(other.expr.iter()) |
231 | 0 | .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2) |
232 | 0 | && self.null_expr.len() == other.null_expr.len() |
233 | 0 | && self |
234 | 0 | .null_expr |
235 | 0 | .iter() |
236 | 0 | .zip(other.null_expr.iter()) |
237 | 0 | .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2) |
238 | 0 | && self.groups == other.groups |
239 | 0 | } |
240 | | } |
241 | | |
242 | | enum StreamType { |
243 | | AggregateStream(AggregateStream), |
244 | | GroupedHash(GroupedHashAggregateStream), |
245 | | GroupedPriorityQueue(GroupedTopKAggregateStream), |
246 | | } |
247 | | |
248 | | impl From<StreamType> for SendableRecordBatchStream { |
249 | 72 | fn from(stream: StreamType) -> Self { |
250 | 72 | match stream { |
251 | 2 | StreamType::AggregateStream(stream) => Box::pin(stream), |
252 | 70 | StreamType::GroupedHash(stream) => Box::pin(stream), |
253 | 0 | StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), |
254 | | } |
255 | 72 | } |
256 | | } |
257 | | |
258 | | /// Hash aggregate execution plan |
259 | | #[derive(Debug)] |
260 | | pub struct AggregateExec { |
261 | | /// Aggregation mode (full, partial) |
262 | | mode: AggregateMode, |
263 | | /// Group by expressions |
264 | | group_by: PhysicalGroupBy, |
265 | | /// Aggregate expressions |
266 | | aggr_expr: Vec<AggregateFunctionExpr>, |
267 | | /// FILTER (WHERE clause) expression for each aggregate expression |
268 | | filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>, |
269 | | /// Set if the output of this aggregation is truncated by an upstream sort/limit clause
270 | | limit: Option<usize>, |
271 | | /// Input plan, could be a partial aggregate or the input to the aggregate |
272 | | pub input: Arc<dyn ExecutionPlan>, |
273 | | /// Schema after the aggregate is applied |
274 | | schema: SchemaRef, |
275 | | /// Input schema before any aggregation is applied. For partial aggregate this will be the |
276 | | /// same as input.schema() but for the final aggregate it will be the same as the input |
277 | | /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`. |
278 | | /// We need the input schema of partial aggregate to be able to deserialize aggregate |
279 | | /// expressions from protobuf for final aggregate. |
280 | | pub input_schema: SchemaRef, |
281 | | /// Execution metrics |
282 | | metrics: ExecutionPlanMetricsSet, |
283 | | required_input_ordering: Option<LexRequirement>, |
284 | | /// Describes how the input is ordered relative to the group by columns |
285 | | input_order_mode: InputOrderMode, |
286 | | cache: PlanProperties, |
287 | | } |
288 | | |
289 | | impl AggregateExec { |
290 | | /// Rewrites this `AggregateExec` with new aggregate expressions, cloning
291 | | /// all other fields from the old one. Used in the `OptimizeAggregateOrder`
292 | | /// optimizer rule, which needs to replace only part of the plan node.
293 | 0 | pub fn with_new_aggr_exprs(&self, aggr_expr: Vec<AggregateFunctionExpr>) -> Self { |
294 | 0 | Self { |
295 | 0 | aggr_expr, |
296 | 0 | // clone the rest of the fields |
297 | 0 | required_input_ordering: self.required_input_ordering.clone(), |
298 | 0 | metrics: ExecutionPlanMetricsSet::new(), |
299 | 0 | input_order_mode: self.input_order_mode.clone(), |
300 | 0 | cache: self.cache.clone(), |
301 | 0 | mode: self.mode, |
302 | 0 | group_by: self.group_by.clone(), |
303 | 0 | filter_expr: self.filter_expr.clone(), |
304 | 0 | limit: self.limit, |
305 | 0 | input: Arc::clone(&self.input), |
306 | 0 | schema: Arc::clone(&self.schema), |
307 | 0 | input_schema: Arc::clone(&self.input_schema), |
308 | 0 | } |
309 | 0 | } |
310 | | |
311 | 0 | pub fn cache(&self) -> &PlanProperties { |
312 | 0 | &self.cache |
313 | 0 | } |
314 | | |
315 | | /// Create a new hash aggregate execution plan |
316 | 49 | pub fn try_new( |
317 | 49 | mode: AggregateMode, |
318 | 49 | group_by: PhysicalGroupBy, |
319 | 49 | aggr_expr: Vec<AggregateFunctionExpr>, |
320 | 49 | filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>, |
321 | 49 | input: Arc<dyn ExecutionPlan>, |
322 | 49 | input_schema: SchemaRef, |
323 | 49 | ) -> Result<Self> { |
324 | 49 | let schema = create_schema( |
325 | 49 | &input.schema(), |
326 | 49 | &group_by.expr, |
327 | 49 | &aggr_expr, |
328 | 49 | group_by.exprs_nullable(), |
329 | 49 | mode, |
330 | 49 | )?;
331 | | |
332 | 49 | let schema = Arc::new(schema); |
333 | 49 | AggregateExec::try_new_with_schema( |
334 | 49 | mode, |
335 | 49 | group_by, |
336 | 49 | aggr_expr, |
337 | 49 | filter_expr, |
338 | 49 | input, |
339 | 49 | input_schema, |
340 | 49 | schema, |
341 | 49 | ) |
342 | 49 | } |
343 | | |
344 | | /// Create a new hash aggregate execution plan with the given schema. |
345 | | /// This constructor isn't part of the public API, it is used internally |
346 | | /// by DataFusion to enforce schema consistency when re-creating
347 | | /// `AggregateExec`s inside optimization rules. Schema field names of an
348 | | /// `AggregateExec` depend on the names of aggregate expressions. Since
349 | | /// a rule may re-write aggregate expressions (e.g. reverse them) during |
350 | | /// initialization, field names may change inadvertently if one re-creates |
351 | | /// the schema in such cases. |
352 | | #[allow(clippy::too_many_arguments)] |
353 | 50 | fn try_new_with_schema( |
354 | 50 | mode: AggregateMode, |
355 | 50 | group_by: PhysicalGroupBy, |
356 | 50 | mut aggr_expr: Vec<AggregateFunctionExpr>, |
357 | 50 | filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>, |
358 | 50 | input: Arc<dyn ExecutionPlan>, |
359 | 50 | input_schema: SchemaRef, |
360 | 50 | schema: SchemaRef, |
361 | 50 | ) -> Result<Self> { |
362 | 50 | // Make sure arguments are consistent in size |
363 | 50 | if aggr_expr.len() != filter_expr.len() { |
364 | 0 | return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr); |
365 | 50 | } |
366 | 50 | |
367 | 50 | let input_eq_properties = input.equivalence_properties(); |
368 | 50 | // Get GROUP BY expressions: |
369 | 50 | let groupby_exprs = group_by.input_exprs(); |
370 | 50 | // If existing ordering satisfies a prefix of the GROUP BY expressions, |
371 | 50 | // prefix requirements with this section. In this case, aggregation will |
372 | 50 | // work more efficiently. |
373 | 50 | let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); |
374 | 50 | let mut new_requirement = LexRequirement::new( |
375 | 50 | indices |
376 | 50 | .iter() |
377 | 50 | .map(|&idx| PhysicalSortRequirement { |
378 | 1 | expr: Arc::clone(&groupby_exprs[idx]), |
379 | 1 | options: None, |
380 | 50 | }) |
381 | 50 | .collect::<Vec<_>>(), |
382 | 50 | ); |
383 | | |
384 | 50 | let req = get_finer_aggregate_exprs_requirement( |
385 | 50 | &mut aggr_expr, |
386 | 50 | &group_by, |
387 | 50 | input_eq_properties, |
388 | 50 | &mode, |
389 | 50 | )?;
390 | 50 | new_requirement.inner.extend(req); |
391 | 50 | new_requirement = collapse_lex_req(new_requirement); |
392 | 50 | |
393 | 50 | // If our aggregation has grouping sets then our base grouping exprs will |
394 | 50 | // be expanded based on the flags in `group_by.groups` where for each |
395 | 50 | // group we swap the grouping expr for `null` if the flag is `true` |
396 | 50 | // That means that each index in `indices` is valid if and only if |
397 | 50 | // it is not null in every group |
398 | 50 | let indices: Vec<usize> = indices |
399 | 50 | .into_iter() |
400 | 50 | .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
401 | 50 | .collect(); |
402 | | |
403 | 50 | let input_order_mode = if indices.len() == groupby_exprs.len() |
404 | 2 | && !indices.is_empty() |
405 | 0 | && group_by.groups.len() == 1 |
406 | | { |
407 | 0 | InputOrderMode::Sorted |
408 | 50 | } else if !indices.is_empty() { |
409 | 0 | InputOrderMode::PartiallySorted(indices) |
410 | | } else { |
411 | 50 | InputOrderMode::Linear |
412 | | }; |
413 | | |
414 | | // construct a map from the input expression to the output expression of the Aggregation group by |
415 | 50 | let projection_mapping = |
416 | 50 | ProjectionMapping::try_new(&group_by.expr, &input.schema())?;
417 | | |
418 | 50 | let required_input_ordering = |
419 | 50 | (!new_requirement.is_empty()).then_some(new_requirement); |
420 | 50 | |
421 | 50 | let cache = Self::compute_properties( |
422 | 50 | &input, |
423 | 50 | Arc::clone(&schema), |
424 | 50 | &projection_mapping, |
425 | 50 | &mode, |
426 | 50 | &input_order_mode, |
427 | 50 | ); |
428 | 50 | |
429 | 50 | Ok(AggregateExec { |
430 | 50 | mode, |
431 | 50 | group_by, |
432 | 50 | aggr_expr, |
433 | 50 | filter_expr, |
434 | 50 | input, |
435 | 50 | schema, |
436 | 50 | input_schema, |
437 | 50 | metrics: ExecutionPlanMetricsSet::new(), |
438 | 50 | required_input_ordering, |
439 | 50 | limit: None, |
440 | 50 | input_order_mode, |
441 | 50 | cache, |
442 | 50 | }) |
443 | 50 | } |
444 | | |
445 | | /// Aggregation mode (full, partial) |
446 | 0 | pub fn mode(&self) -> &AggregateMode { |
447 | 0 | &self.mode |
448 | 0 | } |
449 | | |
450 | | /// Set the `limit` of this AggExec |
451 | 0 | pub fn with_limit(mut self, limit: Option<usize>) -> Self { |
452 | 0 | self.limit = limit; |
453 | 0 | self |
454 | 0 | } |
455 | | /// Grouping expressions |
456 | 4 | pub fn group_expr(&self) -> &PhysicalGroupBy { |
457 | 4 | &self.group_by |
458 | 4 | } |
459 | | |
460 | | /// Grouping expressions as they occur in the output schema |
461 | 0 | pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> { |
462 | 0 | self.group_by.output_exprs() |
463 | 0 | } |
464 | | |
465 | | /// Aggregate expressions |
466 | 0 | pub fn aggr_expr(&self) -> &[AggregateFunctionExpr] { |
467 | 0 | &self.aggr_expr |
468 | 0 | } |
469 | | |
470 | | /// FILTER (WHERE clause) expression for each aggregate expression |
471 | 0 | pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] { |
472 | 0 | &self.filter_expr |
473 | 0 | } |
474 | | |
475 | | /// Input plan |
476 | 16 | pub fn input(&self) -> &Arc<dyn ExecutionPlan> { |
477 | 16 | &self.input |
478 | 16 | } |
479 | | |
480 | | /// Get the input schema before any aggregates are applied |
481 | 0 | pub fn input_schema(&self) -> SchemaRef { |
482 | 0 | Arc::clone(&self.input_schema) |
483 | 0 | } |
484 | | |
485 | | /// Soft limit on the number of output rows of this AggregateExec
486 | 0 | pub fn limit(&self) -> Option<usize> { |
487 | 0 | self.limit |
488 | 0 | } |
489 | | |
490 | 72 | fn execute_typed( |
491 | 72 | &self, |
492 | 72 | partition: usize, |
493 | 72 | context: Arc<TaskContext>, |
494 | 72 | ) -> Result<StreamType> { |
495 | 72 | // no group by at all |
496 | 72 | if self.group_by.expr.is_empty() { |
497 | 2 | return Ok(StreamType::AggregateStream(AggregateStream::new( |
498 | 2 | self, context, partition, |
499 | 2 | )?));
500 | 70 | } |
501 | | |
502 | | // grouping by an expression that has a sort/limit upstream |
503 | 70 | if let Some(limit) = self.limit {
504 | 0 | if !self.is_unordered_unfiltered_group_by_distinct() { |
505 | | return Ok(StreamType::GroupedPriorityQueue( |
506 | 0 | GroupedTopKAggregateStream::new(self, context, partition, limit)?, |
507 | | )); |
508 | 0 | } |
509 | 70 | } |
510 | | |
511 | | // grouping by something else and we need to just materialize all results |
512 | 70 | Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new( |
513 | 70 | self, context, partition, |
514 | 70 | )?))
515 | 72 | } |
516 | | |
517 | | /// Finds the DataType and SortDirection for this Aggregate, if there is one |
518 | 0 | pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { |
519 | 0 | let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; |
520 | 0 | agg_expr.get_minmax_desc() |
521 | 0 | } |
522 | | |
523 | | /// Returns true if this Aggregate has a group-by with no required or
524 | | /// explicit ordering, no filtering, and no aggregate expressions.
525 | | /// This method qualifies an `AggregateExec` for the
526 | | /// `LimitedDistinctAggregation` rewrite rule.
527 | 0 | pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { |
528 | 0 | // ensure there is a group by |
529 | 0 | if self.group_expr().is_empty() { |
530 | 0 | return false; |
531 | 0 | } |
532 | 0 | // ensure there are no aggregate expressions |
533 | 0 | if !self.aggr_expr().is_empty() { |
534 | 0 | return false; |
535 | 0 | } |
536 | 0 | // ensure there are no filters on aggregate expressions; the above check |
537 | 0 | // may preclude this case |
538 | 0 | if self.filter_expr().iter().any(|e| e.is_some()) { |
539 | 0 | return false; |
540 | 0 | } |
541 | 0 | // ensure there are no order by expressions |
542 | 0 | if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) { |
543 | 0 | return false; |
544 | 0 | } |
545 | 0 | // ensure there is no output ordering; can this rule be relaxed? |
546 | 0 | if self.properties().output_ordering().is_some() { |
547 | 0 | return false; |
548 | 0 | } |
549 | 0 | // ensure no ordering is required on the input |
550 | 0 | if self.required_input_ordering()[0].is_some() { |
551 | 0 | return false; |
552 | 0 | } |
553 | 0 | true |
554 | 0 | } |
555 | | |
556 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
557 | 50 | pub fn compute_properties( |
558 | 50 | input: &Arc<dyn ExecutionPlan>, |
559 | 50 | schema: SchemaRef, |
560 | 50 | projection_mapping: &ProjectionMapping, |
561 | 50 | mode: &AggregateMode, |
562 | 50 | input_order_mode: &InputOrderMode, |
563 | 50 | ) -> PlanProperties { |
564 | 50 | // Construct equivalence properties: |
565 | 50 | let eq_properties = input |
566 | 50 | .equivalence_properties() |
567 | 50 | .project(projection_mapping, schema); |
568 | 50 | |
569 | 50 | // Get output partitioning: |
570 | 50 | let input_partitioning = input.output_partitioning().clone(); |
571 | 50 | let output_partitioning = if mode.is_first_stage() { |
572 | | // First stage aggregation will not change the output partitioning, |
573 | | // but needs to respect aliases (e.g. mapping in the GROUP BY |
574 | | // expression). |
575 | 25 | let input_eq_properties = input.equivalence_properties(); |
576 | 25 | input_partitioning.project(projection_mapping, input_eq_properties) |
577 | | } else { |
578 | 25 | input_partitioning.clone() |
579 | | }; |
580 | | |
581 | | // Determine execution mode: |
582 | 50 | let mut exec_mode = input.execution_mode(); |
583 | 50 | if exec_mode == ExecutionMode::Unbounded |
584 | 0 | && *input_order_mode == InputOrderMode::Linear |
585 | 0 | { |
586 | 0 | // Cannot run without breaking the pipeline |
587 | 0 | exec_mode = ExecutionMode::PipelineBreaking; |
588 | 50 | } |
589 | | |
590 | 50 | PlanProperties::new(eq_properties, output_partitioning, exec_mode) |
591 | 50 | } |
592 | | |
593 | 0 | pub fn input_order_mode(&self) -> &InputOrderMode { |
594 | 0 | &self.input_order_mode |
595 | 0 | } |
596 | | } |
597 | | |
598 | | impl DisplayAs for AggregateExec { |
599 | 0 | fn fmt_as( |
600 | 0 | &self, |
601 | 0 | t: DisplayFormatType, |
602 | 0 | f: &mut std::fmt::Formatter, |
603 | 0 | ) -> std::fmt::Result { |
604 | 0 | match t { |
605 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
606 | 0 | write!(f, "AggregateExec: mode={:?}", self.mode)?; |
607 | 0 | let g: Vec<String> = if self.group_by.is_single() { |
608 | 0 | self.group_by |
609 | 0 | .expr |
610 | 0 | .iter() |
611 | 0 | .map(|(e, alias)| { |
612 | 0 | let e = e.to_string(); |
613 | 0 | if &e != alias { |
614 | 0 | format!("{e} as {alias}") |
615 | | } else { |
616 | 0 | e |
617 | | } |
618 | 0 | }) |
619 | 0 | .collect() |
620 | | } else { |
621 | 0 | self.group_by |
622 | 0 | .groups |
623 | 0 | .iter() |
624 | 0 | .map(|group| { |
625 | 0 | let terms = group |
626 | 0 | .iter() |
627 | 0 | .enumerate() |
628 | 0 | .map(|(idx, is_null)| { |
629 | 0 | if *is_null { |
630 | 0 | let (e, alias) = &self.group_by.null_expr[idx]; |
631 | 0 | let e = e.to_string(); |
632 | 0 | if &e != alias { |
633 | 0 | format!("{e} as {alias}") |
634 | | } else { |
635 | 0 | e |
636 | | } |
637 | | } else { |
638 | 0 | let (e, alias) = &self.group_by.expr[idx]; |
639 | 0 | let e = e.to_string(); |
640 | 0 | if &e != alias { |
641 | 0 | format!("{e} as {alias}") |
642 | | } else { |
643 | 0 | e |
644 | | } |
645 | | } |
646 | 0 | }) |
647 | 0 | .collect::<Vec<String>>() |
648 | 0 | .join(", "); |
649 | 0 | format!("({terms})") |
650 | 0 | }) |
651 | 0 | .collect() |
652 | | }; |
653 | | |
654 | 0 | write!(f, ", gby=[{}]", g.join(", "))?; |
655 | | |
656 | 0 | let a: Vec<String> = self |
657 | 0 | .aggr_expr |
658 | 0 | .iter() |
659 | 0 | .map(|agg| agg.name().to_string()) |
660 | 0 | .collect(); |
661 | 0 | write!(f, ", aggr=[{}]", a.join(", "))?; |
662 | 0 | if let Some(limit) = self.limit { |
663 | 0 | write!(f, ", lim=[{limit}]")?; |
664 | 0 | } |
665 | | |
666 | 0 | if self.input_order_mode != InputOrderMode::Linear { |
667 | 0 | write!(f, ", ordering_mode={:?}", self.input_order_mode)?; |
668 | 0 | } |
669 | | } |
670 | | } |
671 | 0 | Ok(()) |
672 | 0 | } |
673 | | } |
674 | | |
675 | | impl ExecutionPlan for AggregateExec { |
676 | 0 | fn name(&self) -> &'static str { |
677 | 0 | "AggregateExec" |
678 | 0 | } |
679 | | |
680 | | /// Return a reference to Any that can be used for down-casting |
681 | 0 | fn as_any(&self) -> &dyn Any { |
682 | 0 | self |
683 | 0 | } |
684 | | |
685 | 154 | fn properties(&self) -> &PlanProperties { |
686 | 154 | &self.cache |
687 | 154 | } |
688 | | |
689 | 0 | fn required_input_distribution(&self) -> Vec<Distribution> { |
690 | 0 | match &self.mode { |
691 | | AggregateMode::Partial => { |
692 | 0 | vec![Distribution::UnspecifiedDistribution] |
693 | | } |
694 | | AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { |
695 | 0 | vec![Distribution::HashPartitioned(self.group_by.input_exprs())] |
696 | | } |
697 | | AggregateMode::Final | AggregateMode::Single => { |
698 | 0 | vec![Distribution::SinglePartition] |
699 | | } |
700 | | } |
701 | 0 | } |
702 | | |
703 | 0 | fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> { |
704 | 0 | vec![self.required_input_ordering.clone()] |
705 | 0 | } |
706 | | |
707 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
708 | 0 | vec![&self.input] |
709 | 0 | } |
710 | | |
711 | 1 | fn with_new_children( |
712 | 1 | self: Arc<Self>, |
713 | 1 | children: Vec<Arc<dyn ExecutionPlan>>, |
714 | 1 | ) -> Result<Arc<dyn ExecutionPlan>> { |
715 | 1 | let mut me = AggregateExec::try_new_with_schema( |
716 | 1 | self.mode, |
717 | 1 | self.group_by.clone(), |
718 | 1 | self.aggr_expr.clone(), |
719 | 1 | self.filter_expr.clone(), |
720 | 1 | Arc::clone(&children[0]), |
721 | 1 | Arc::clone(&self.input_schema), |
722 | 1 | Arc::clone(&self.schema), |
723 | 1 | )?;
724 | 1 | me.limit = self.limit; |
725 | 1 | |
726 | 1 | Ok(Arc::new(me)) |
727 | 1 | } |
728 | | |
729 | 70 | fn execute( |
730 | 70 | &self, |
731 | 70 | partition: usize, |
732 | 70 | context: Arc<TaskContext>, |
733 | 70 | ) -> Result<SendableRecordBatchStream> { |
734 | 70 | self.execute_typed(partition, context) |
735 | 70 | .map(|stream| stream.into()) |
736 | 70 | } |
737 | | |
738 | 8 | fn metrics(&self) -> Option<MetricsSet> { |
739 | 8 | Some(self.metrics.clone_inner()) |
740 | 8 | } |
741 | | |
742 | 8 | fn statistics(&self) -> Result<Statistics> { |
743 | 8 | // TODO stats: group expressions: |
744 | 8 | // - once expressions will be able to compute their own stats, use it here |
745 | 8 | // - case where we group by on a column for which with have the `distinct` stat |
746 | 8 | // TODO stats: aggr expression: |
747 | 8 | // - aggregations sometimes also preserve invariants such as min, max... |
748 | 8 | let column_statistics = Statistics::unknown_column(&self.schema()); |
749 | 8 | match self.mode { |
750 | 8 | AggregateMode::Final | AggregateMode::FinalPartitioned |
751 | 8 | if self.group_by.expr.is_empty() =>
752 | | { |
753 | 0 | Ok(Statistics { |
754 | 0 | num_rows: Precision::Exact(1), |
755 | 0 | column_statistics, |
756 | 0 | total_byte_size: Precision::Absent, |
757 | 0 | }) |
758 | | } |
759 | | _ => { |
760 | | // When the input row count is 0 or 1, we can adopt that statistic keeping its reliability. |
761 | | // When it is larger than 1, we degrade the precision since it may decrease after aggregation. |
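 | | // Illustrative mapping (input num_rows -> output num_rows estimate):
 | | //   Exact(0)    -> Exact(1)   (aggregating an empty table yields one row)
 | | //   Exact(1)    -> Exact(1)
 | | //   Exact(n>1)  -> Inexact(n) (grouping may reduce the row count)
 | | //   Absent      -> Absent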
762 | 8 | let num_rows = if let Some(value) = |
763 | 8 | self.input().statistics()?.num_rows.get_value()
764 | | {
765 | 8 | if *value > 1 {
766 | 8 | self.input().statistics()?.num_rows.to_inexact()
767 | 0 | } else if *value == 0 { |
768 | | // Aggregation on an empty table creates a null row. |
769 | 0 | self.input() |
770 | 0 | .statistics()? |
771 | | .num_rows |
772 | 0 | .add(&Precision::Exact(1)) |
773 | | } else { |
774 | | // num_rows = 1 case |
775 | 0 | self.input().statistics()?.num_rows |
776 | | } |
777 | | } else { |
778 | 0 | Precision::Absent |
779 | | }; |
780 | 8 | Ok(Statistics { |
781 | 8 | num_rows, |
782 | 8 | column_statistics, |
783 | 8 | total_byte_size: Precision::Absent, |
784 | 8 | }) |
785 | | } |
786 | | } |
787 | 8 | } |
788 | | } |
789 | | |
790 | 50 | fn create_schema( |
791 | 50 | input_schema: &Schema, |
792 | 50 | group_expr: &[(Arc<dyn PhysicalExpr>, String)], |
793 | 50 | aggr_expr: &[AggregateFunctionExpr], |
794 | 50 | group_expr_nullable: Vec<bool>, |
795 | 50 | mode: AggregateMode, |
796 | 50 | ) -> Result<Schema> { |
797 | 50 | let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); |
798 | 59 | for (index, (expr, name)) in group_expr.iter().enumerate() {
799 | 59 | fields.push( |
800 | 59 | Field::new( |
801 | 59 | name, |
802 | 59 | expr.data_type(input_schema)?,
803 | | // In cases where we have multiple grouping sets, we will use NULL expressions in |
804 | | // order to align the grouping sets. So the field must be nullable even if the underlying |
805 | | // schema field is not. |
806 | 59 | group_expr_nullable[index] || expr.nullable(input_schema)?,
807 | | ) |
808 | 59 | .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()), |
809 | | ) |
810 | | } |
811 | | |
812 | 50 | match mode { |
813 | | AggregateMode::Partial => { |
814 | | // in partial mode, the fields of the accumulator's state |
815 | 49 | for expr in aggr_expr {
816 | 25 | fields.extend(expr.state_fields()?.iter().cloned())
817 | | } |
818 | | } |
819 | | AggregateMode::Final |
820 | | | AggregateMode::FinalPartitioned |
821 | | | AggregateMode::Single |
822 | | | AggregateMode::SinglePartitioned => { |
823 | | // in final mode, the field with the final result of the accumulator |
824 | 44 | for expr in aggr_expr {
825 | 18 | fields.push(expr.field()) |
826 | | } |
827 | | } |
828 | | } |
829 | | |
830 | 50 | Ok(Schema::new_with_metadata( |
831 | 50 | fields, |
832 | 50 | input_schema.metadata().clone(), |
833 | 50 | )) |
834 | 50 | } |
835 | | |
836 | 70 | fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { |
837 | 70 | let group_fields = schema.fields()[0..group_count].to_vec(); |
838 | 70 | Arc::new(Schema::new(group_fields)) |
839 | 70 | } |
840 | | |
841 | | /// Determines the lexical ordering requirement for an aggregate expression. |
842 | | /// |
843 | | /// # Parameters |
844 | | /// |
845 | | /// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the |
846 | | /// aggregate expression. |
847 | | /// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the |
848 | | /// physical GROUP BY expression. |
849 | | /// - `agg_mode`: A reference to an `AggregateMode` instance representing the |
850 | | /// mode of aggregation. |
851 | | /// |
852 | | /// # Returns |
853 | | /// |
854 | | /// A `LexOrdering` instance indicating the lexical ordering requirement for |
855 | | /// the aggregate expression. |
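 | | ///
 | | /// For example (illustrative): for an order-sensitive aggregate such as
 | | /// `ARRAY_AGG(x ORDER BY ts)` combined with `GROUP BY ts`, the `ts`
 | | /// requirement is dropped, since every row within a group shares the same
 | | /// `ts` value.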
856 | 54 | fn get_aggregate_expr_req( |
857 | 54 | aggr_expr: &AggregateFunctionExpr, |
858 | 54 | group_by: &PhysicalGroupBy, |
859 | 54 | agg_mode: &AggregateMode, |
860 | 54 | ) -> LexOrdering { |
861 | 54 | // If the aggregate function's ordering requirement is not absolutely
862 | 54 | // necessary, or the aggregation is performing a "second stage" calculation, |
863 | 54 | // then ignore the ordering requirement. |
864 | 54 | if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() {
865 | 45 | return vec![]; |
866 | 9 | } |
867 | 9 | |
868 | 9 | let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec(); |
869 | 9 | |
870 | 9 | // In non-first stage modes, we accumulate data (using `merge_batch`) from |
871 | 9 | // different partitions (i.e. merge partial results). During this merge, we |
872 | 9 | // consider the ordering of each partial result. Hence, we do not need to |
873 | 9 | // use the ordering requirement in such modes as long as partial results are |
874 | 9 | // generated with the correct ordering. |
875 | 9 | if group_by.is_single() { |
876 | 9 | // Remove all orderings that occur in the group by. These requirements |
877 | 9 | // will definitely be satisfied -- Each group by expression will have |
878 | 9 | // distinct values per group, hence all requirements are satisfied. |
879 | 9 | let physical_exprs = group_by.input_exprs(); |
880 | 18 | req.retain(|sort_expr| { |
881 | 18 | !physical_exprs_contains(&physical_exprs, &sort_expr.expr) |
882 | 18 | }); |
883 | 9 | }
884 | 9 | req |
885 | 54 | } |
886 | | |
887 | | /// Computes the finer ordering between the given existing ordering requirement
888 | | /// and the ordering requirement of the given aggregate expression.
889 | | /// |
890 | | /// # Parameters |
891 | | /// |
892 | | /// * `existing_req` - The existing lexical ordering that needs refinement. |
893 | | /// * `aggr_expr` - A reference to an aggregate expression trait object. |
894 | | /// * `group_by` - Information about the physical grouping (e.g group by expression). |
895 | | /// * `eq_properties` - Equivalence properties relevant to the computation. |
896 | | /// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). |
897 | | /// |
898 | | /// # Returns |
899 | | /// |
900 | | /// An `Option<LexOrdering>` representing the computed finer lexical ordering, |
901 | | /// or `None` if there is no finer ordering; e.g. the existing requirement and |
901 | | /// the aggregator requirement are incompatible.
903 | 54 | fn finer_ordering( |
904 | 54 | existing_req: &LexOrdering, |
905 | 54 | aggr_expr: &AggregateFunctionExpr, |
906 | 54 | group_by: &PhysicalGroupBy, |
907 | 54 | eq_properties: &EquivalenceProperties, |
908 | 54 | agg_mode: &AggregateMode, |
909 | 54 | ) -> Option<LexOrdering> { |
910 | 54 | let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); |
911 | 54 | eq_properties.get_finer_ordering(existing_req, &aggr_req) |
912 | 54 | } |
913 | | |
914 | | /// Concatenates the given slices. |
915 | 0 | pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> { |
916 | 0 | [lhs, rhs].concat() |
917 | 0 | } |
918 | | |
919 | | /// Get the common requirement that satisfies all the aggregate expressions. |
920 | | /// |
921 | | /// # Parameters |
922 | | /// |
923 | | /// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the |
924 | | /// aggregate expressions. |
925 | | /// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the |
926 | | /// physical GROUP BY expression. |
927 | | /// - `eq_properties`: A reference to an `EquivalenceProperties` instance |
928 | | /// representing equivalence properties for ordering. |
929 | | /// - `agg_mode`: A reference to an `AggregateMode` instance representing the |
930 | | /// mode of aggregation. |
931 | | /// |
932 | | /// # Returns |
933 | | /// |
934 | | /// A `LexRequirement` instance, which is the requirement that satisfies all the |
935 | | /// aggregate requirements. Returns an error in case of conflicting requirements. |
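 | | ///
 | | /// For example (illustrative): if one aggregate requires the ordering
 | | /// `[a ASC]` and another requires `[a ASC, b ASC]`, the finer requirement
 | | /// `[a ASC, b ASC]` satisfies both. If two requirements conflict and
 | | /// reversing an aggregate does not resolve the conflict, a "not supported"
 | | /// error is returned.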
936 | 51 | pub fn get_finer_aggregate_exprs_requirement( |
937 | 51 | aggr_exprs: &mut [AggregateFunctionExpr], |
938 | 51 | group_by: &PhysicalGroupBy, |
939 | 51 | eq_properties: &EquivalenceProperties, |
940 | 51 | agg_mode: &AggregateMode, |
941 | 51 | ) -> Result<LexRequirement> { |
942 | 51 | let mut requirement = vec![]; |
943 | 51 | for aggr_expr in aggr_exprs.iter_mut() {
944 | 48 | if let Some(finer_ordering) = |
945 | 48 | finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) |
946 | | { |
947 | 48 | if eq_properties.ordering_satisfy(&finer_ordering) { |
948 | | // Requirement is satisfied by existing ordering |
949 | 45 | requirement = finer_ordering; |
950 | 45 | continue; |
951 | 3 | } |
952 | 0 | } |
953 | 3 | if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { |
954 | 3 | if let Some(finer_ordering) = finer_ordering(
955 | 3 | &requirement, |
956 | 3 | &reverse_aggr_expr, |
957 | 3 | group_by, |
958 | 3 | eq_properties, |
959 | 3 | agg_mode, |
960 | 3 | ) { |
961 | 1 | if eq_properties.ordering_satisfy(&finer_ordering) { |
962 | | // Reverse requirement is satisfied by existing ordering.
963 | | // Hence reverse the aggregator |
964 | 0 | requirement = finer_ordering; |
965 | 0 | *aggr_expr = reverse_aggr_expr; |
966 | 0 | continue; |
967 | 1 | } |
968 | 2 | } |
969 | 0 | } |
970 | 3 | if let Some(finer_ordering) = |
971 | 3 | finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) |
972 | | { |
973 | | // There is a requirement that both satisfies existing requirement and current |
974 | | // aggregate requirement. Use updated requirement |
975 | 3 | requirement = finer_ordering; |
976 | 3 | continue; |
977 | 0 | } |
978 | 0 | if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { |
979 | 0 | if let Some(finer_ordering) = finer_ordering( |
980 | 0 | &requirement, |
981 | 0 | &reverse_aggr_expr, |
982 | 0 | group_by, |
983 | 0 | eq_properties, |
984 | 0 | agg_mode, |
985 | 0 | ) { |
986 | | // There is a requirement that both satisfies existing requirement and reverse |
987 | | // aggregate requirement. Use updated requirement |
988 | 0 | requirement = finer_ordering; |
989 | 0 | *aggr_expr = reverse_aggr_expr; |
990 | 0 | continue; |
991 | 0 | } |
992 | 0 | } |
993 | | |
994 | | // Neither the existing requirement nor the current aggregate requirement
995 | | // satisfies the other; this means the requirements are conflicting.
996 | | // Currently, we do not support conflicting requirements.
997 | 0 | return not_impl_err!( |
998 | 0 | "Conflicting ordering requirements in aggregate functions is not supported" |
999 | 0 | ); |
1000 | | } |
1001 | | |
1002 | 51 | Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) |
1003 | 51 | } |
1004 | | |
1005 | | /// Returns physical expressions for arguments to evaluate against a batch. |
1006 | | /// |
1007 | | /// The expressions are different depending on `mode`: |
1008 | | /// * Partial: AggregateFunctionExpr::expressions |
1009 | | /// * Final: columns of `AggregateFunctionExpr::state_fields()` |
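 | | ///
 | | /// For example (illustrative): for `AVG(b)`, whose partial state consists of
 | | /// a count and a sum, `Partial` mode returns the input expression `[b]`,
 | | /// while `Final` mode returns column references to the two state fields
 | | /// produced by the partial stage.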
1010 | 142 | pub fn aggregate_expressions( |
1011 | 142 | aggr_expr: &[AggregateFunctionExpr], |
1012 | 142 | mode: &AggregateMode, |
1013 | 142 | col_idx_base: usize, |
1014 | 142 | ) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> { |
1015 | 142 | match mode { |
1016 | | AggregateMode::Partial |
1017 | | | AggregateMode::Single |
1018 | 55 | | AggregateMode::SinglePartitioned => Ok(aggr_expr |
1019 | 55 | .iter() |
1020 | 55 | .map(|agg| { |
1021 | 55 | let mut result = agg.expressions(); |
1022 | | // Append ordering requirements to expressions' results. This |
1023 | | // way order sensitive aggregators can satisfy requirement |
1024 | | // themselves. |
1025 | 55 | if let Some(ordering_req) = agg.order_bys() {
1026 | 32 | result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr))); |
1027 | 32 | }
1028 | 55 | result |
1029 | 55 | }) |
1030 | 55 | .collect()), |
1031 | | // In this mode, we build the merge expressions of the aggregation. |
1032 | | AggregateMode::Final | AggregateMode::FinalPartitioned => { |
1033 | 87 | let mut col_idx_base = col_idx_base; |
1034 | 87 | aggr_expr |
1035 | 87 | .iter() |
1036 | 87 | .map(|agg| { |
1037 | 87 | let exprs = merge_expressions(col_idx_base, agg)?;
1038 | 87 | col_idx_base += exprs.len(); |
1039 | 87 | Ok(exprs) |
1040 | 87 | }) |
1041 | 87 | .collect() |
1042 | | } |
1043 | | } |
1044 | 142 | } |
1045 | | |
1046 | | /// Uses `state_fields` to build a vec of physical column expressions required
1047 | | /// to merge the `AggregateFunctionExpr`'s accumulator's state.
1048 | | /// |
1049 | | /// `index_base` is the starting physical column index for the next expanded state field. |
1050 | 87 | fn merge_expressions( |
1051 | 87 | index_base: usize, |
1052 | 87 | expr: &AggregateFunctionExpr, |
1053 | 87 | ) -> Result<Vec<Arc<dyn PhysicalExpr>>> { |
1054 | 87 | expr.state_fields().map(|fields| { |
1055 | 87 | fields |
1056 | 87 | .iter() |
1057 | 87 | .enumerate() |
1058 | 201 | .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _) |
1059 | 87 | .collect() |
1060 | 87 | }) |
1061 | 87 | } |
1062 | | |
1063 | | pub type AccumulatorItem = Box<dyn Accumulator>; |
1064 | | |
1065 | 2 | pub fn create_accumulators( |
1066 | 2 | aggr_expr: &[AggregateFunctionExpr], |
1067 | 2 | ) -> Result<Vec<AccumulatorItem>> { |
1068 | 2 | aggr_expr |
1069 | 2 | .iter() |
1070 | 2 | .map(|expr| expr.create_accumulator()) |
1071 | 2 | .collect() |
1072 | 2 | } |
1073 | | |
1074 | | /// returns a vector of ArrayRefs, where each entry corresponds to either the |
1075 | | /// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial) |
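 | | ///
 | | /// For example (illustrative): for an `AVG` accumulator, `Partial` mode emits
 | | /// the state arrays (a count array and a sum array), while the other modes
 | | /// emit a single array holding the final averages.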
1076 | 0 | pub fn finalize_aggregation( |
1077 | 0 | accumulators: &mut [AccumulatorItem], |
1078 | 0 | mode: &AggregateMode, |
1079 | 0 | ) -> Result<Vec<ArrayRef>> { |
1080 | 0 | match mode { |
1081 | | AggregateMode::Partial => { |
1082 | | // Build the vector of states |
1083 | 0 | accumulators |
1084 | 0 | .iter_mut() |
1085 | 0 | .map(|accumulator| { |
1086 | 0 | accumulator.state().and_then(|e| { |
1087 | 0 | e.iter() |
1088 | 0 | .map(|v| v.to_array()) |
1089 | 0 | .collect::<Result<Vec<ArrayRef>>>() |
1090 | 0 | }) |
1091 | 0 | }) |
1092 | 0 | .flatten_ok() |
1093 | 0 | .collect() |
1094 | | } |
1095 | | AggregateMode::Final |
1096 | | | AggregateMode::FinalPartitioned |
1097 | | | AggregateMode::Single |
1098 | | | AggregateMode::SinglePartitioned => { |
1099 | | // Merge the state to the final value |
1100 | 0 | accumulators |
1101 | 0 | .iter_mut() |
1102 | 0 | .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) |
1103 | 0 | .collect() |
1104 | | } |
1105 | | } |
1106 | 0 | } |
1107 | | |
1108 | | /// Evaluates expressions against a record batch. |
1109 | 131 | fn evaluate( |
1110 | 131 | expr: &[Arc<dyn PhysicalExpr>], |
1111 | 131 | batch: &RecordBatch, |
1112 | 131 | ) -> Result<Vec<ArrayRef>> { |
1113 | 131 | expr.iter() |
1114 | 245 | .map(|expr| { |
1115 | 245 | expr.evaluate(batch) |
1116 | 245 | .and_then(|v| v.into_array(batch.num_rows())) |
1117 | 245 | }) |
1118 | 131 | .collect() |
1119 | 131 | } |
1120 | | |
1121 | | /// Evaluates each set of expressions against a record batch.
1122 | 131 | pub(crate) fn evaluate_many( |
1123 | 131 | expr: &[Vec<Arc<dyn PhysicalExpr>>], |
1124 | 131 | batch: &RecordBatch, |
1125 | 131 | ) -> Result<Vec<Vec<ArrayRef>>> { |
1126 | 131 | expr.iter().map(|expr| evaluate(expr, batch)).collect() |
1127 | 131 | } |
1128 | | |
1129 | 131 | fn evaluate_optional( |
1130 | 131 | expr: &[Option<Arc<dyn PhysicalExpr>>], |
1131 | 131 | batch: &RecordBatch, |
1132 | 131 | ) -> Result<Vec<Option<ArrayRef>>> { |
1133 | 131 | expr.iter() |
1134 | 131 | .map(|expr| { |
1135 | 131 | expr.as_ref() |
1136 | 131 | .map(|expr| { |
1137 | 0 | expr.evaluate(batch) |
1138 | 0 | .and_then(|v| v.into_array(batch.num_rows())) |
1139 | 131 | }) |
1140 | 131 | .transpose() |
1141 | 131 | }) |
1142 | 131 | .collect() |
1143 | 131 | } |
1144 | | |
1145 | | /// Evaluate a group by expression against a `RecordBatch` |
1146 | | /// |
1147 | | /// Arguments: |
1148 | | /// - `group_by`: the expression to evaluate |
1149 | | /// - `batch`: the `RecordBatch` to evaluate against |
1150 | | /// |
1151 | | /// Returns: A Vec of Vecs of Arrays of results
1152 | | /// The outer Vec has one entry per grouping set
1153 | | /// The inner Vec contains the results per expression
1154 | | /// The innermost Array contains the results per row
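 | | ///
 | | /// For example (illustrative), with `GROUPING SETS ((a), (b))` the result has
 | | /// two outer entries: `[[a_values, nulls_b], [nulls_a, b_values]]`, where
 | | /// `nulls_a`/`nulls_b` are the evaluated null expressions for the nullified
 | | /// columns.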
1155 | 131 | pub(crate) fn evaluate_group_by( |
1156 | 131 | group_by: &PhysicalGroupBy, |
1157 | 131 | batch: &RecordBatch, |
1158 | 131 | ) -> Result<Vec<Vec<ArrayRef>>> { |
1159 | 131 | let exprs: Vec<ArrayRef> = group_by |
1160 | 131 | .expr |
1161 | 131 | .iter() |
1162 | 163 | .map(|(expr, _)| { |
1163 | 163 | let value = expr.evaluate(batch)?;
1164 | 163 | value.into_array(batch.num_rows())
1165 | 163 | })
1166 | 131 | .collect::<Result<Vec<_>>>()?;
1167 | | |
1168 | 131 | let null_exprs: Vec<ArrayRef> = group_by |
1169 | 131 | .null_expr |
1170 | 131 | .iter() |
1171 | 131 | .map(|(expr, _)| {
1172 | 44 | let value = expr.evaluate(batch)?;
1173 | 44 | value.into_array(batch.num_rows())
1174 | 131 | })
1175 | 131 | .collect::<Result<Vec<_>>>()?;
1176 | | |
1177 | 131 | Ok(group_by |
1178 | 131 | .groups |
1179 | 131 | .iter() |
1180 | 171 | .map(|group| { |
1181 | 171 | group |
1182 | 171 | .iter() |
1183 | 171 | .enumerate() |
1184 | 251 | .map(|(idx, is_null)| { |
1185 | 251 | if *is_null { |
1186 | 56 | Arc::clone(&null_exprs[idx]) |
1187 | | } else { |
1188 | 195 | Arc::clone(&exprs[idx]) |
1189 | | } |
1190 | 251 | }) |
1191 | 171 | .collect() |
1192 | 171 | }) |
1193 | 131 | .collect()) |
1194 | 131 | } |
1195 | | |
1196 | | #[cfg(test)] |
1197 | | mod tests { |
1198 | | use std::task::{Context, Poll}; |
1199 | | |
1200 | | use super::*; |
1201 | | use crate::coalesce_batches::CoalesceBatchesExec; |
1202 | | use crate::coalesce_partitions::CoalescePartitionsExec; |
1203 | | use crate::common; |
1204 | | use crate::expressions::col; |
1205 | | use crate::memory::MemoryExec; |
1206 | | use crate::test::assert_is_pending; |
1207 | | use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; |
1208 | | use crate::RecordBatchStream; |
1209 | | |
1210 | | use arrow::array::{Float64Array, UInt32Array}; |
1211 | | use arrow::compute::{concat_batches, SortOptions}; |
1212 | | use arrow::datatypes::{DataType, Int32Type}; |
1213 | | use arrow_array::{ |
1214 | | DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array, |
1215 | | }; |
1216 | | use datafusion_common::{ |
1217 | | assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, |
1218 | | ScalarValue, |
1219 | | }; |
1220 | | use datafusion_execution::config::SessionConfig; |
1221 | | use datafusion_execution::memory_pool::FairSpillPool; |
1222 | | use datafusion_execution::runtime_env::RuntimeEnvBuilder; |
1223 | | use datafusion_functions_aggregate::array_agg::array_agg_udaf; |
1224 | | use datafusion_functions_aggregate::average::avg_udaf; |
1225 | | use datafusion_functions_aggregate::count::count_udaf; |
1226 | | use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf}; |
1227 | | use datafusion_functions_aggregate::median::median_udaf; |
1228 | | use datafusion_functions_aggregate::sum::sum_udaf; |
1229 | | use datafusion_physical_expr::expressions::lit; |
1230 | | use datafusion_physical_expr::PhysicalSortExpr; |
1231 | | |
1232 | | use crate::common::collect; |
1233 | | use datafusion_physical_expr::aggregate::AggregateExprBuilder; |
1234 | | use datafusion_physical_expr::expressions::Literal; |
1235 | | use datafusion_physical_expr::Partitioning; |
1236 | | use futures::{FutureExt, Stream}; |
1237 | | |
1238 | | // Generate a schema which consists of 5 columns (a, b, c, d, e) |
1239 | 1 | fn create_test_schema() -> Result<SchemaRef> { |
1240 | 1 | let a = Field::new("a", DataType::Int32, true); |
1241 | 1 | let b = Field::new("b", DataType::Int32, true); |
1242 | 1 | let c = Field::new("c", DataType::Int32, true); |
1243 | 1 | let d = Field::new("d", DataType::Int32, true); |
1244 | 1 | let e = Field::new("e", DataType::Int32, true); |
1245 | 1 | let schema = Arc::new(Schema::new(vec![a, b, c, d, e])); |
1246 | 1 | |
1247 | 1 | Ok(schema) |
1248 | 1 | } |
1249 | | |
1250 | | /// Some mock data to aggregate
1251 | 43 | fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) { |
1252 | 43 | // define a schema. |
1253 | 43 | let schema = Arc::new(Schema::new(vec![ |
1254 | 43 | Field::new("a", DataType::UInt32, false), |
1255 | 43 | Field::new("b", DataType::Float64, false), |
1256 | 43 | ])); |
1257 | 43 | |
1258 | 43 | // define data. |
1259 | 43 | ( |
1260 | 43 | Arc::clone(&schema), |
1261 | 43 | vec![ |
1262 | 43 | RecordBatch::try_new( |
1263 | 43 | Arc::clone(&schema), |
1264 | 43 | vec![ |
1265 | 43 | Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), |
1266 | 43 | Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), |
1267 | 43 | ], |
1268 | 43 | ) |
1269 | 43 | .unwrap(), |
1270 | 43 | RecordBatch::try_new( |
1271 | 43 | schema, |
1272 | 43 | vec![ |
1273 | 43 | Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), |
1274 | 43 | Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), |
1275 | 43 | ], |
1276 | 43 | ) |
1277 | 43 | .unwrap(), |
1278 | 43 | ], |
1279 | 43 | ) |
1280 | 43 | } |
1281 | | |
1282 | | /// Generates some mock data for aggregate tests. |
1283 | 8 | fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) { |
1284 | 8 | // Define a schema: |
1285 | 8 | let schema = Arc::new(Schema::new(vec![ |
1286 | 8 | Field::new("a", DataType::UInt32, false), |
1287 | 8 | Field::new("b", DataType::Float64, false), |
1288 | 8 | ])); |
1289 | 8 | |
1290 | 8 | // Generate data so that first and last value results are at 2nd and |
1291 | 8 | // 3rd partitions. With this construction, we guarantee we don't receive |
1292 | 8 | // the expected result by accident, but merging actually works properly; |
1293 | 8 | // i.e. it doesn't depend on the data insertion order. |
1294 | 8 | ( |
1295 | 8 | Arc::clone(&schema), |
1296 | 8 | vec![ |
1297 | 8 | RecordBatch::try_new( |
1298 | 8 | Arc::clone(&schema), |
1299 | 8 | vec![ |
1300 | 8 | Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), |
1301 | 8 | Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), |
1302 | 8 | ], |
1303 | 8 | ) |
1304 | 8 | .unwrap(), |
1305 | 8 | RecordBatch::try_new( |
1306 | 8 | Arc::clone(&schema), |
1307 | 8 | vec![ |
1308 | 8 | Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), |
1309 | 8 | Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])), |
1310 | 8 | ], |
1311 | 8 | ) |
1312 | 8 | .unwrap(), |
1313 | 8 | RecordBatch::try_new( |
1314 | 8 | Arc::clone(&schema), |
1315 | 8 | vec![ |
1316 | 8 | Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), |
1317 | 8 | Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])), |
1318 | 8 | ], |
1319 | 8 | ) |
1320 | 8 | .unwrap(), |
1321 | 8 | RecordBatch::try_new( |
1322 | 8 | schema, |
1323 | 8 | vec![ |
1324 | 8 | Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), |
1325 | 8 | Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])), |
1326 | 8 | ], |
1327 | 8 | ) |
1328 | 8 | .unwrap(), |
1329 | 8 | ], |
1330 | 8 | ) |
1331 | 8 | } |
1332 | | |
1333 | 12 | fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> { |
1334 | 12 | let session_config = SessionConfig::new().with_batch_size(batch_size); |
1335 | 12 | let runtime = RuntimeEnvBuilder::default() |
1336 | 12 | .with_memory_pool(Arc::new(FairSpillPool::new(max_memory))) |
1337 | 12 | .build_arc() |
1338 | 12 | .unwrap(); |
1339 | 12 | let task_ctx = TaskContext::default() |
1340 | 12 | .with_session_config(session_config) |
1341 | 12 | .with_runtime(runtime); |
1342 | 12 | Arc::new(task_ctx) |
1343 | 12 | } |
1344 | | |
1345 | 4 | async fn check_grouping_sets( |
1346 | 4 | input: Arc<dyn ExecutionPlan>, |
1347 | 4 | spill: bool, |
1348 | 4 | ) -> Result<()> { |
1349 | 4 | let input_schema = input.schema(); |
1350 | | |
1351 | 4 | let grouping_set = PhysicalGroupBy { |
1352 | 4 | expr: vec![ |
1353 | 4 | (col("a", &input_schema)?, "a".to_string()),
1354 | 4 | (col("b", &input_schema)?, "b".to_string()),
1355 | 4 | ], |
1356 | 4 | null_expr: vec![ |
1357 | 4 | (lit(ScalarValue::UInt32(None)), "a".to_string()), |
1358 | 4 | (lit(ScalarValue::Float64(None)), "b".to_string()), |
1359 | 4 | ], |
1360 | 4 | groups: vec![ |
1361 | 4 | vec![false, true], // (a, NULL) |
1362 | 4 | vec![true, false], // (NULL, b) |
1363 | 4 | vec![false, false], // (a,b) |
1364 | 4 | ], |
1365 | | }; |
1366 | | |
1367 | 4 | let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) |
1368 | 4 | .schema(Arc::clone(&input_schema)) |
1369 | 4 | .alias("COUNT(1)") |
1370 | 4 | .build()?];
1371 | | |
1372 | 4 | let task_ctx = if spill { |
1373 | | // Adjust the max memory size so the partial aggregate spills and emits partial results.
1374 | 2 | new_spill_ctx(4, 500) |
1375 | | } else { |
1376 | 2 | Arc::new(TaskContext::default()) |
1377 | | }; |
1378 | | |
1379 | 4 | let partial_aggregate = Arc::new(AggregateExec::try_new( |
1380 | 4 | AggregateMode::Partial, |
1381 | 4 | grouping_set.clone(), |
1382 | 4 | aggregates.clone(), |
1383 | 4 | vec![None], |
1384 | 4 | input, |
1385 | 4 | Arc::clone(&input_schema), |
1386 | 4 | )?);
1387 | | |
1388 | 4 | let result = |
1389 | 4 | common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1390 | | |
1391 | 4 | let expected = if spill { |
1392 | | // In spill mode, we test with limited memory; if memory usage exceeds
1393 | | // the limit, we trigger early emission, which outputs the partial
1394 | | // aggregate result.
1394 | 2 | vec![ |
1395 | 2 | "+---+-----+-----------------+", |
1396 | 2 | "| a | b | COUNT(1)[count] |", |
1397 | 2 | "+---+-----+-----------------+", |
1398 | 2 | "| | 1.0 | 1 |", |
1399 | 2 | "| | 1.0 | 1 |", |
1400 | 2 | "| | 2.0 | 1 |", |
1401 | 2 | "| | 2.0 | 1 |", |
1402 | 2 | "| | 3.0 | 1 |", |
1403 | 2 | "| | 3.0 | 1 |", |
1404 | 2 | "| | 4.0 | 1 |", |
1405 | 2 | "| | 4.0 | 1 |", |
1406 | 2 | "| 2 | | 1 |", |
1407 | 2 | "| 2 | | 1 |", |
1408 | 2 | "| 2 | 1.0 | 1 |", |
1409 | 2 | "| 2 | 1.0 | 1 |", |
1410 | 2 | "| 3 | | 1 |", |
1411 | 2 | "| 3 | | 2 |", |
1412 | 2 | "| 3 | 2.0 | 2 |", |
1413 | 2 | "| 3 | 3.0 | 1 |", |
1414 | 2 | "| 4 | | 1 |", |
1415 | 2 | "| 4 | | 2 |", |
1416 | 2 | "| 4 | 3.0 | 1 |", |
1417 | 2 | "| 4 | 4.0 | 2 |", |
1418 | 2 | "+---+-----+-----------------+", |
1419 | 2 | ] |
1420 | | } else { |
1421 | 2 | vec![ |
1422 | 2 | "+---+-----+-----------------+", |
1423 | 2 | "| a | b | COUNT(1)[count] |", |
1424 | 2 | "+---+-----+-----------------+", |
1425 | 2 | "| | 1.0 | 2 |", |
1426 | 2 | "| | 2.0 | 2 |", |
1427 | 2 | "| | 3.0 | 2 |", |
1428 | 2 | "| | 4.0 | 2 |", |
1429 | 2 | "| 2 | | 2 |", |
1430 | 2 | "| 2 | 1.0 | 2 |", |
1431 | 2 | "| 3 | | 3 |", |
1432 | 2 | "| 3 | 2.0 | 2 |", |
1433 | 2 | "| 3 | 3.0 | 1 |", |
1434 | 2 | "| 4 | | 3 |", |
1435 | 2 | "| 4 | 3.0 | 1 |", |
1436 | 2 | "| 4 | 4.0 | 2 |", |
1437 | 2 | "+---+-----+-----------------+", |
1438 | 2 | ] |
1439 | | }; |
1440 | 4 | assert_batches_sorted_eq!(expected, &result); |
1441 | | |
1442 | 4 | let groups = partial_aggregate.group_expr().expr().to_vec(); |
1443 | 4 | |
1444 | 4 | let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); |
1445 | | |
1446 | 4 | let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = groups |
1447 | 4 | .iter() |
1448 | 8 | .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
1449 | 4 | .collect::<Result<_>>()?;
1450 | | |
1451 | 4 | let final_grouping_set = PhysicalGroupBy::new_single(final_group); |
1452 | | |
1453 | 4 | let task_ctx = if spill { |
1454 | 2 | new_spill_ctx(4, 3160) |
1455 | | } else { |
1456 | 2 | task_ctx |
1457 | | }; |
1458 | | |
1459 | 4 | let merged_aggregate = Arc::new(AggregateExec::try_new( |
1460 | 4 | AggregateMode::Final, |
1461 | 4 | final_grouping_set, |
1462 | 4 | aggregates, |
1463 | 4 | vec![None], |
1464 | 4 | merge, |
1465 | 4 | input_schema, |
1466 | 4 | )?);
1467 | |
1468 | 4 | let result =
1469 | 4 | common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1470 | 4 | let batch = concat_batches(&result[0].schema(), &result)?;
1471 | 4 | assert_eq!(batch.num_columns(), 3); |
1472 | 4 | assert_eq!(batch.num_rows(), 12); |
1473 | | |
1474 | 4 | let expected = vec![ |
1475 | 4 | "+---+-----+----------+", |
1476 | 4 | "| a | b | COUNT(1) |", |
1477 | 4 | "+---+-----+----------+", |
1478 | 4 | "| | 1.0 | 2 |", |
1479 | 4 | "| | 2.0 | 2 |", |
1480 | 4 | "| | 3.0 | 2 |", |
1481 | 4 | "| | 4.0 | 2 |", |
1482 | 4 | "| 2 | | 2 |", |
1483 | 4 | "| 2 | 1.0 | 2 |", |
1484 | 4 | "| 3 | | 3 |", |
1485 | 4 | "| 3 | 2.0 | 2 |", |
1486 | 4 | "| 3 | 3.0 | 1 |", |
1487 | 4 | "| 4 | | 3 |", |
1488 | 4 | "| 4 | 3.0 | 1 |", |
1489 | 4 | "| 4 | 4.0 | 2 |", |
1490 | 4 | "+---+-----+----------+", |
1491 | 4 | ]; |
1492 | 4 | |
1493 | 4 | assert_batches_sorted_eq!(&expected, &result); |
1494 | | |
1495 | 4 | let metrics = merged_aggregate.metrics().unwrap(); |
1496 | 4 | let output_rows = metrics.output_rows().unwrap(); |
1497 | 4 | assert_eq!(12, output_rows); |
1498 | | |
1499 | 4 | Ok(()) |
1500 | 4 | } |
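// The two-phase plan above corresponds roughly to the SQL query
//
//     SELECT a, b, COUNT(1) FROM t GROUP BY GROUPING SETS ((a), (b), (a, b));
//
// (a sketch for orientation; the table name `t` is illustrative only).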
1501 | | |
1502 | | /// build the aggregates on the data from some_data() and check the results |
1503 | 4 | async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> { |
1504 | 4 | let input_schema = input.schema(); |
1505 | | |
1506 | 4 | let grouping_set = PhysicalGroupBy { |
1507 | 4 | expr: vec![(col("a", &input_schema)?, "a".to_string())],
1508 | 4 | null_expr: vec![], |
1509 | 4 | groups: vec![vec![false]], |
1510 | | }; |
1511 | | |
1512 | 4 | let aggregates: Vec<AggregateFunctionExpr> = |
1513 | 4 | vec![ |
1514 | 4 | AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
1515 | 4 | .schema(Arc::clone(&input_schema))
1516 | 4 | .alias("AVG(b)")
1517 | 4 | .build()?,
1518 | | ]; |
1519 | | |
1520 | 4 | let task_ctx = if spill { |
1521 | | // Set the memory limit low enough to trigger spilling.
1522 | 2 | new_spill_ctx(2, 1600) |
1523 | | } else { |
1524 | 2 | Arc::new(TaskContext::default()) |
1525 | | }; |
1526 | | |
1527 | 4 | let partial_aggregate = Arc::new(AggregateExec::try_new( |
1528 | 4 | AggregateMode::Partial, |
1529 | 4 | grouping_set.clone(), |
1530 | 4 | aggregates.clone(), |
1531 | 4 | vec![None], |
1532 | 4 | input, |
1533 | 4 | Arc::clone(&input_schema), |
1534 | 4 | )?);
1535 | |
1536 | 4 | let result =
1537 | 4 | common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1538 | | |
1539 | 4 | let expected = if spill { |
1540 | 2 | vec![ |
1541 | 2 | "+---+---------------+-------------+", |
1542 | 2 | "| a | AVG(b)[count] | AVG(b)[sum] |", |
1543 | 2 | "+---+---------------+-------------+", |
1544 | 2 | "| 2 | 1 | 1.0 |", |
1545 | 2 | "| 2 | 1 | 1.0 |", |
1546 | 2 | "| 3 | 1 | 2.0 |", |
1547 | 2 | "| 3 | 2 | 5.0 |", |
1548 | 2 | "| 4 | 3 | 11.0 |", |
1549 | 2 | "+---+---------------+-------------+", |
1550 | 2 | ] |
1551 | | } else { |
1552 | 2 | vec![ |
1553 | 2 | "+---+---------------+-------------+", |
1554 | 2 | "| a | AVG(b)[count] | AVG(b)[sum] |", |
1555 | 2 | "+---+---------------+-------------+", |
1556 | 2 | "| 2 | 2 | 2.0 |", |
1557 | 2 | "| 3 | 3 | 7.0 |", |
1558 | 2 | "| 4 | 3 | 11.0 |", |
1559 | 2 | "+---+---------------+-------------+", |
1560 | 2 | ] |
1561 | | }; |
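// AVG is carried through the partial phase as (count, sum) state rather than a
// final value. Merging the spilled partial rows above reproduces the non-spill
// table, e.g. for group a = 3: count 1 + 2 = 3 and sum 2.0 + 5.0 = 7.0, which
// the final phase turns into AVG(b) = 7.0 / 3 ≈ 2.3333 (see the expected
// output further down).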
1562 | 4 | assert_batches_sorted_eq!(expected, &result); |
1563 | | |
1564 | 4 | let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); |
1565 | | |
1566 | 4 | let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set |
1567 | 4 | .expr |
1568 | 4 | .iter() |
1569 | 4 | .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
1570 | 4 | .collect::<Result<_>>()?;
1571 | | |
1572 | 4 | let final_grouping_set = PhysicalGroupBy::new_single(final_group); |
1573 | | |
1574 | 4 | let merged_aggregate = Arc::new(AggregateExec::try_new( |
1575 | 4 | AggregateMode::Final, |
1576 | 4 | final_grouping_set, |
1577 | 4 | aggregates, |
1578 | 4 | vec![None], |
1579 | 4 | merge, |
1580 | 4 | input_schema, |
1581 | 4 | )?);
1582 | |
1583 | 4 | let task_ctx = if spill {
1584 | | // enlarge memory limit to let the final aggregation finish
1585 | 2 | new_spill_ctx(2, 2600)
1586 | | } else {
1587 | 2 | Arc::clone(&task_ctx)
1588 | | };
1589 | 4 | let result = common::collect(merged_aggregate.execute(0, task_ctx)?).await?;
1590 | 4 | let batch = concat_batches(&result[0].schema(), &result)?;
1591 | 4 | assert_eq!(batch.num_columns(), 2); |
1592 | 4 | assert_eq!(batch.num_rows(), 3); |
1593 | | |
1594 | 4 | let expected = vec![ |
1595 | 4 | "+---+--------------------+", |
1596 | 4 | "| a | AVG(b) |", |
1597 | 4 | "+---+--------------------+", |
1598 | 4 | "| 2 | 1.0 |", |
1599 | 4 | "| 3 | 2.3333333333333335 |", // 3, (2 + 3 + 2) / 3 |
1600 | 4 | "| 4 | 3.6666666666666665 |", // 4, (3 + 4 + 4) / 3 |
1601 | 4 | "+---+--------------------+", |
1602 | 4 | ]; |
1603 | 4 | |
1604 | 4 | assert_batches_sorted_eq!(&expected, &result); |
1605 | | |
1606 | 4 | let metrics = merged_aggregate.metrics().unwrap(); |
1607 | 4 | let output_rows = metrics.output_rows().unwrap(); |
1608 | 4 | if spill { |
1609 | | // When spilling, the output rows metric becomes partial output size + final output size,
1610 | | // because the final aggregation starts while the partial aggregation is still emitting.
1611 | 2 | assert_eq!(8, output_rows); |
1612 | | } else { |
1613 | 2 | assert_eq!(3, output_rows); |
1614 | | } |
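// Worked check: the spill-mode partial table above holds 5 rows and the final
// result holds 3, so the asserted metric 5 + 3 = 8 is consistent with the
// final aggregation emitting early-released partial rows plus the merged rows.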
1615 | | |
1616 | 4 | Ok(()) |
1617 | 4 | } |
1618 | | |
1619 | | /// Define a test source that can yield back to the runtime before returning its first item
1620 | | |
1621 | | #[derive(Debug)] |
1622 | | struct TestYieldingExec { |
1623 | | /// True if this exec should yield back to runtime the first time it is polled |
1624 | | pub yield_first: bool, |
1625 | | cache: PlanProperties, |
1626 | | } |
1627 | | |
1628 | | impl TestYieldingExec { |
1629 | 9 | fn new(yield_first: bool) -> Self { |
1630 | 9 | let schema = some_data().0; |
1631 | 9 | let cache = Self::compute_properties(schema); |
1632 | 9 | Self { yield_first, cache } |
1633 | 9 | } |
1634 | | |
1635 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
1636 | 9 | fn compute_properties(schema: SchemaRef) -> PlanProperties { |
1637 | 9 | let eq_properties = EquivalenceProperties::new(schema); |
1638 | 9 | PlanProperties::new( |
1639 | 9 | eq_properties, |
1640 | 9 | // Output Partitioning |
1641 | 9 | Partitioning::UnknownPartitioning(1), |
1642 | 9 | // Execution Mode |
1643 | 9 | ExecutionMode::Bounded, |
1644 | 9 | ) |
1645 | 9 | } |
1646 | | } |
1647 | | |
1648 | | impl DisplayAs for TestYieldingExec { |
1649 | 0 | fn fmt_as( |
1650 | 0 | &self, |
1651 | 0 | t: DisplayFormatType, |
1652 | 0 | f: &mut std::fmt::Formatter, |
1653 | 0 | ) -> std::fmt::Result { |
1654 | 0 | match t { |
1655 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
1656 | 0 | write!(f, "TestYieldingExec") |
1657 | 0 | } |
1658 | 0 | } |
1659 | 0 | } |
1660 | | } |
1661 | | |
1662 | | impl ExecutionPlan for TestYieldingExec { |
1663 | 0 | fn name(&self) -> &'static str { |
1664 | 0 | "TestYieldingExec" |
1665 | 0 | } |
1666 | | |
1667 | 0 | fn as_any(&self) -> &dyn Any { |
1668 | 0 | self |
1669 | 0 | } |
1670 | | |
1671 | 89 | fn properties(&self) -> &PlanProperties { |
1672 | 89 | &self.cache |
1673 | 89 | } |
1674 | | |
1675 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
1676 | 0 | vec![] |
1677 | 0 | } |
1678 | | |
1679 | 0 | fn with_new_children( |
1680 | 0 | self: Arc<Self>, |
1681 | 0 | _: Vec<Arc<dyn ExecutionPlan>>, |
1682 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
1683 | 0 | internal_err!("Children cannot be replaced in {self:?}") |
1684 | 0 | } |
1685 | | |
1686 | 18 | fn execute( |
1687 | 18 | &self, |
1688 | 18 | _partition: usize, |
1689 | 18 | _context: Arc<TaskContext>, |
1690 | 18 | ) -> Result<SendableRecordBatchStream> { |
1691 | 18 | let stream = if self.yield_first { |
1692 | 10 | TestYieldingStream::New |
1693 | | } else { |
1694 | 8 | TestYieldingStream::Yielded |
1695 | | }; |
1696 | | |
1697 | 18 | Ok(Box::pin(stream)) |
1698 | 18 | } |
1699 | | |
1700 | 0 | fn statistics(&self) -> Result<Statistics> { |
1701 | 0 | let (_, batches) = some_data(); |
1702 | 0 | Ok(common::compute_record_batch_statistics( |
1703 | 0 | &[batches], |
1704 | 0 | &self.schema(), |
1705 | 0 | None, |
1706 | 0 | )) |
1707 | 0 | } |
1708 | | } |
1709 | | |
1710 | | /// A stream over the demo data. If initialized as `New`, it first yields to the runtime before returning records
1711 | | enum TestYieldingStream { |
1712 | | New, |
1713 | | Yielded, |
1714 | | ReturnedBatch1, |
1715 | | ReturnedBatch2, |
1716 | | } |
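// State transitions driven by poll_next (below):
//
//     New --(wake, return Pending)--> Yielded --(emit batch 0)--> ReturnedBatch1
//         --(emit batch 1)--> ReturnedBatch2 --(return None: end of stream)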
1717 | | |
1718 | | impl Stream for TestYieldingStream { |
1719 | | type Item = Result<RecordBatch>; |
1720 | | |
1721 | 60 | fn poll_next( |
1722 | 60 | mut self: std::pin::Pin<&mut Self>, |
1723 | 60 | cx: &mut Context<'_>, |
1724 | 60 | ) -> Poll<Option<Self::Item>> { |
1725 | 60 | match &*self { |
1726 | | TestYieldingStream::New => { |
1727 | 10 | *(self.as_mut()) = TestYieldingStream::Yielded; |
1728 | 10 | cx.waker().wake_by_ref(); |
1729 | 10 | Poll::Pending |
1730 | | } |
1731 | | TestYieldingStream::Yielded => { |
1732 | 18 | *(self.as_mut()) = TestYieldingStream::ReturnedBatch1; |
1733 | 18 | Poll::Ready(Some(Ok(some_data().1[0].clone()))) |
1734 | | } |
1735 | | TestYieldingStream::ReturnedBatch1 => { |
1736 | 16 | *(self.as_mut()) = TestYieldingStream::ReturnedBatch2; |
1737 | 16 | Poll::Ready(Some(Ok(some_data().1[1].clone()))) |
1738 | | } |
1739 | 16 | TestYieldingStream::ReturnedBatch2 => Poll::Ready(None), |
1740 | | } |
1741 | 60 | } |
1742 | | } |
1743 | | |
1744 | | impl RecordBatchStream for TestYieldingStream { |
1745 | 0 | fn schema(&self) -> SchemaRef { |
1746 | 0 | some_data().0 |
1747 | 0 | } |
1748 | | } |
1749 | | |
1750 | | //--- Tests ---// |
1751 | | |
1752 | | #[tokio::test] |
1753 | 1 | async fn aggregate_source_not_yielding() -> Result<()> { |
1754 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false)); |
1755 | 1 | |
1756 | 1 | check_aggregates(input, false).await0 |
1757 | 1 | } |
1758 | | |
1759 | | #[tokio::test] |
1760 | 1 | async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> { |
1761 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false)); |
1762 | 1 | |
1763 | 1 | check_grouping_sets(input, false).await0 |
1764 | 1 | } |
1765 | | |
1766 | | #[tokio::test] |
1767 | 1 | async fn aggregate_source_with_yielding() -> Result<()> { |
1768 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true)); |
1769 | 1 | |
1770 | 2 | check_aggregates(input, false).await |
1771 | 1 | } |
1772 | | |
1773 | | #[tokio::test] |
1774 | 1 | async fn aggregate_grouping_sets_with_yielding() -> Result<()> { |
1775 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true)); |
1776 | 1 | |
1777 | 2 | check_grouping_sets(input, false).await |
1778 | 1 | } |
1779 | | |
1780 | | #[tokio::test] |
1781 | 1 | async fn aggregate_source_not_yielding_with_spill() -> Result<()> { |
1782 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false)); |
1783 | 1 | |
1784 | 1 | check_aggregates(input, true).await0 |
1785 | 1 | } |
1786 | | |
1787 | | #[tokio::test] |
1788 | 1 | async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> { |
1789 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false)); |
1790 | 1 | |
1791 | 1 | check_grouping_sets(input, true).await0 |
1792 | 1 | } |
1793 | | |
1794 | | #[tokio::test] |
1795 | 1 | async fn aggregate_source_with_yielding_with_spill() -> Result<()> { |
1796 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true)); |
1797 | 1 | |
1798 | 2 | check_aggregates(input, true).await |
1799 | 1 | } |
1800 | | |
1801 | | #[tokio::test] |
1802 | 1 | async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> { |
1803 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true)); |
1804 | 1 | |
1805 | 2 | check_grouping_sets(input, true).await |
1806 | 1 | } |
1807 | | |
1808 | | // Median(a) |
1809 | 1 | fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> { |
1810 | 1 | AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
1811 | 1 | .schema(schema) |
1812 | 1 | .alias("MEDIAN(a)") |
1813 | 1 | .build() |
1814 | 1 | } |
1815 | | |
1816 | | #[tokio::test] |
1817 | 1 | async fn test_oom() -> Result<()> { |
1818 | 1 | let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true)); |
1819 | 1 | let input_schema = input.schema(); |
1820 | 1 | |
1821 | 1 | let runtime = RuntimeEnvBuilder::default() |
1822 | 1 | .with_memory_limit(1, 1.0) |
1823 | 1 | .build_arc()?;
1824 | 1 | let task_ctx = TaskContext::default().with_runtime(runtime); |
1825 | 1 | let task_ctx = Arc::new(task_ctx); |
1826 | 1 | |
1827 | 1 | let groups_none = PhysicalGroupBy::default(); |
1828 | 1 | let groups_some = PhysicalGroupBy { |
1829 | 1 | expr: vec![(col("a", &input_schema)?, "a".to_string())],
1830 | 1 | null_expr: vec![], |
1831 | 1 | groups: vec![vec![false]], |
1832 | 1 | }; |
1833 | 1 | |
1834 | 1 | // something that allocates within the aggregator |
1835 | 1 | let aggregates_v0: Vec<AggregateFunctionExpr> = |
1836 | 1 | vec![test_median_agg_expr(Arc::clone(&input_schema))?];
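// Median is a good OOM trigger because its accumulator must buffer every input
// value before it can produce a result, so even this tiny input overflows the
// 1-byte memory limit configured above.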
1837 | 1 | |
1838 | 1 | // use fast-path in `row_hash.rs`. |
1839 | 1 | let aggregates_v2: Vec<AggregateFunctionExpr> = |
1840 | 1 | vec![ |
1841 | 1 | AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
1842 | 1 | .schema(Arc::clone(&input_schema))
1843 | 1 | .alias("AVG(b)")
1844 | 1 | .build()?,
1845 | 1 | ]; |
1846 | 1 | |
1847 | 2 | for (version, groups, aggregates) in [ |
1848 | 1 | (0, groups_none, aggregates_v0), |
1849 | 1 | (2, groups_some, aggregates_v2), |
1850 | 1 | ] { |
1851 | 2 | let n_aggr = aggregates.len(); |
1852 | 2 | let partial_aggregate = Arc::new(AggregateExec::try_new( |
1853 | 2 | AggregateMode::Partial, |
1854 | 2 | groups, |
1855 | 2 | aggregates, |
1856 | 2 | vec![None; n_aggr], |
1857 | 2 | Arc::clone(&input), |
1858 | 2 | Arc::clone(&input_schema), |
1859 | 2 | )?);
1860 | 1 |
1861 | 2 | let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;
1862 | 1 | |
1863 | 1 | // ensure that we really got the version we wanted |
1864 | 2 | match version { |
1865 | 1 | 0 => { |
1866 | 1 | assert!(matches!(stream, StreamType::AggregateStream(_)));
1867 | 1 | }
1868 | 1 | 1 => {
1869 | 1 | assert!(matches!(stream, StreamType::GroupedHash(_)));
1870 | 1 | }
1871 | 1 | 2 => {
1872 | 1 | assert!(matches!(stream, StreamType::GroupedHash(_)));
1873 | 1 | }
1874 | 1 | _ => panic!("Unknown version: {version}"),
1875 | 1 | }
1876 | 1 |
1877 | 2 | let stream: SendableRecordBatchStream = stream.into();
1878 | 2 | let err = common::collect(stream).await.unwrap_err();
1879 | 2 |
1880 | 2 | // error root cause traversal is a bit complicated, see #4172.
1881 | 2 | let err = err.find_root();
1882 | 2 | assert!(
1883 | 2 | matches!(err, DataFusionError::ResourcesExhausted(_)),
1884 | 1 | "Wrong error type: {err}",
1885 | 1 | ); |
1886 | 1 | } |
1887 | 1 | |
1888 | 1 | Ok(()) |
1889 | 1 | } |
1890 | | |
1891 | | #[tokio::test] |
1892 | 1 | async fn test_drop_cancel_without_groups() -> Result<()> { |
1893 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1894 | 1 | let schema = |
1895 | 1 | Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)])); |
1896 | 1 | |
1897 | 1 | let groups = PhysicalGroupBy::default(); |
1898 | 1 | |
1899 | 1 | let aggregates: Vec<AggregateFunctionExpr> = |
1900 | 1 | vec![ |
1901 | 1 | AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
1902 | 1 | .schema(Arc::clone(&schema))
1903 | 1 | .alias("AVG(a)")
1904 | 1 | .build()?,
1905 | 1 | ]; |
1906 | 1 | |
1907 | 1 | let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); |
1908 | 1 | let refs = blocking_exec.refs(); |
1909 | 1 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
1910 | 1 | AggregateMode::Partial, |
1911 | 1 | groups.clone(), |
1912 | 1 | aggregates.clone(), |
1913 | 1 | vec![None], |
1914 | 1 | blocking_exec, |
1915 | 1 | schema, |
1916 | 1 | )?);
1917 | 1 | |
1918 | 1 | let fut = crate::collect(aggregate_exec, task_ctx); |
1919 | 1 | let mut fut = fut.boxed(); |
1920 | 1 | |
1921 | 1 | assert_is_pending(&mut fut); |
1922 | 1 | drop(fut); |
1923 | 1 | assert_strong_count_converges_to_zero(refs).await;
1924 | 1 | |
1925 | 1 | Ok(()) |
1926 | 1 | } |
1927 | | |
1928 | | #[tokio::test] |
1929 | 1 | async fn test_drop_cancel_with_groups() -> Result<()> { |
1930 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1931 | 1 | let schema = Arc::new(Schema::new(vec![ |
1932 | 1 | Field::new("a", DataType::Float64, true), |
1933 | 1 | Field::new("b", DataType::Float64, true), |
1934 | 1 | ])); |
1935 | 1 | |
1936 | 1 | let groups = |
1937 | 1 | PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
1938 | 1 | |
1939 | 1 | let aggregates: Vec<AggregateFunctionExpr> = |
1940 | 1 | vec![ |
1941 | 1 | AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
1942 | 1 | .schema(Arc::clone(&schema))
1943 | 1 | .alias("AVG(b)")
1944 | 1 | .build()?,
1945 | 1 | ]; |
1946 | 1 | |
1947 | 1 | let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); |
1948 | 1 | let refs = blocking_exec.refs(); |
1949 | 1 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
1950 | 1 | AggregateMode::Partial, |
1951 | 1 | groups, |
1952 | 1 | aggregates.clone(), |
1953 | 1 | vec![None], |
1954 | 1 | blocking_exec, |
1955 | 1 | schema, |
1956 | 1 | )?);
1957 | 1 | |
1958 | 1 | let fut = crate::collect(aggregate_exec, task_ctx); |
1959 | 1 | let mut fut = fut.boxed(); |
1960 | 1 | |
1961 | 1 | assert_is_pending(&mut fut); |
1962 | 1 | drop(fut); |
1963 | 1 | assert_strong_count_converges_to_zero(refs).await;
1964 | 1 | |
1965 | 1 | Ok(()) |
1966 | 1 | } |
1967 | | |
1968 | | #[tokio::test] |
1969 | 1 | async fn run_first_last_multi_partitions() -> Result<()> { |
1970 | 3 | for use_coalesce_batches in [false, true] {
1971 | 6 | for is_first_acc in [false, true] {
1972 | 12 | for spill in [false, true] {
1973 | 8 | first_last_multi_partitions(
1974 | 8 | use_coalesce_batches,
1975 | 8 | is_first_acc,
1976 | 8 | spill,
1977 | 8 | 4200,
1978 | 8 | )
1979 | 72 | .await?
1980 | 1 | } |
1981 | 1 | } |
1982 | 1 | } |
1983 | 1 | Ok(()) |
1984 | 1 | } |
1985 | | |
1986 | | // FIRST_VALUE(b ORDER BY b <SortOptions>) |
1987 | 5 | fn test_first_value_agg_expr( |
1988 | 5 | schema: &Schema, |
1989 | 5 | sort_options: SortOptions, |
1990 | 5 | ) -> Result<AggregateFunctionExpr> { |
1991 | 5 | let ordering_req = [PhysicalSortExpr { |
1992 | 5 | expr: col("b", schema)?0 , |
1993 | 5 | options: sort_options, |
1994 | | }]; |
1995 | 5 | let args = [col("b", schema)?0 ]; |
1996 | | |
1997 | 5 | AggregateExprBuilder::new(first_value_udaf(), args.to_vec()) |
1998 | 5 | .order_by(ordering_req.to_vec()) |
1999 | 5 | .schema(Arc::new(schema.clone())) |
2000 | 5 | .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) |
2001 | 5 | .build() |
2002 | 5 | } |
2003 | | |
2004 | | // LAST_VALUE(b ORDER BY b <SortOptions>) |
2005 | 5 | fn test_last_value_agg_expr( |
2006 | 5 | schema: &Schema, |
2007 | 5 | sort_options: SortOptions, |
2008 | 5 | ) -> Result<AggregateFunctionExpr> { |
2009 | 5 | let ordering_req = [PhysicalSortExpr { |
2010 | 5 | expr: col("b", schema)?0 , |
2011 | 5 | options: sort_options, |
2012 | | }]; |
2013 | 5 | let args = [col("b", schema)?0 ]; |
2014 | 5 | AggregateExprBuilder::new(last_value_udaf(), args.to_vec()) |
2015 | 5 | .order_by(ordering_req.to_vec()) |
2016 | 5 | .schema(Arc::new(schema.clone())) |
2017 | 5 | .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) |
2018 | 5 | .build() |
2019 | 5 | } |
2020 | | |
2021 | | // This function either constructs the physical plan below, |
2022 | | // |
2023 | | // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", |
2024 | | // " CoalesceBatchesExec: target_batch_size=1024", |
2025 | | // " CoalescePartitionsExec", |
2026 | | // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", |
2027 | | // " MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]", |
2028 | | // |
2029 | | // or |
2030 | | // |
2031 | | // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", |
2032 | | // " CoalescePartitionsExec", |
2033 | | // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", |
2034 | | // " MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]", |
2035 | | // |
2036 | | // and checks whether the function `merge_batch` works correctly for |
2037 | | // FIRST_VALUE and LAST_VALUE functions. |
2038 | 8 | async fn first_last_multi_partitions( |
2039 | 8 | use_coalesce_batches: bool, |
2040 | 8 | is_first_acc: bool, |
2041 | 8 | spill: bool, |
2042 | 8 | max_memory: usize, |
2043 | 8 | ) -> Result<()> { |
2044 | 8 | let task_ctx = if spill { |
2045 | 4 | new_spill_ctx(2, max_memory) |
2046 | | } else { |
2047 | 4 | Arc::new(TaskContext::default()) |
2048 | | }; |
2049 | | |
2050 | 8 | let (schema, data) = some_data_v2(); |
2051 | 8 | let partition1 = data[0].clone(); |
2052 | 8 | let partition2 = data[1].clone(); |
2053 | 8 | let partition3 = data[2].clone(); |
2054 | 8 | let partition4 = data[3].clone(); |
2055 | | |
2056 | 8 | let groups = |
2057 | 8 | PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2058 | 8 | |
2059 | 8 | let sort_options = SortOptions { |
2060 | 8 | descending: false, |
2061 | 8 | nulls_first: false, |
2062 | 8 | }; |
2063 | 8 | let aggregates: Vec<AggregateFunctionExpr> = if is_first_acc { |
2064 | 4 | vec![test_first_value_agg_expr(&schema, sort_options)?]
2065 | | } else {
2066 | 4 | vec![test_last_value_agg_expr(&schema, sort_options)?]
2067 | | }; |
2068 | | |
2069 | 8 | let memory_exec = Arc::new(MemoryExec::try_new( |
2070 | 8 | &[ |
2071 | 8 | vec![partition1], |
2072 | 8 | vec![partition2], |
2073 | 8 | vec![partition3], |
2074 | 8 | vec![partition4], |
2075 | 8 | ], |
2076 | 8 | Arc::clone(&schema), |
2077 | 8 | None, |
2078 | 8 | )?);
2079 | 8 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
2080 | 8 | AggregateMode::Partial, |
2081 | 8 | groups.clone(), |
2082 | 8 | aggregates.clone(), |
2083 | 8 | vec![None], |
2084 | 8 | memory_exec, |
2085 | 8 | Arc::clone(&schema), |
2086 | 8 | )?);
2087 | 8 | let coalesce = if use_coalesce_batches { |
2088 | 4 | let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)); |
2089 | 4 | Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan> |
2090 | | } else { |
2091 | 4 | Arc::new(CoalescePartitionsExec::new(aggregate_exec)) |
2092 | 4 | as Arc<dyn ExecutionPlan> |
2093 | | }; |
2094 | 8 | let aggregate_final = Arc::new(AggregateExec::try_new( |
2095 | 8 | AggregateMode::Final, |
2096 | 8 | groups, |
2097 | 8 | aggregates.clone(), |
2098 | 8 | vec![None], |
2099 | 8 | coalesce, |
2100 | 8 | schema, |
2101 | 8 | )?) as Arc<dyn ExecutionPlan>;
2102 | |
2103 | 72 | let result = crate::collect(aggregate_final, task_ctx).await?;
2104 | 8 | if is_first_acc { |
2105 | 4 | let expected = [ |
2106 | 4 | "+---+--------------------------------------------+", |
2107 | 4 | "| a | first_value(b) ORDER BY [b ASC NULLS LAST] |", |
2108 | 4 | "+---+--------------------------------------------+", |
2109 | 4 | "| 2 | 0.0 |", |
2110 | 4 | "| 3 | 1.0 |", |
2111 | 4 | "| 4 | 3.0 |", |
2112 | 4 | "+---+--------------------------------------------+", |
2113 | 4 | ]; |
2114 | 4 | assert_batches_eq!(expected, &result); |
2115 | | } else { |
2116 | 4 | let expected = [ |
2117 | 4 | "+---+-------------------------------------------+", |
2118 | 4 | "| a | last_value(b) ORDER BY [b ASC NULLS LAST] |", |
2119 | 4 | "+---+-------------------------------------------+", |
2120 | 4 | "| 2 | 3.0 |", |
2121 | 4 | "| 3 | 5.0 |", |
2122 | 4 | "| 4 | 6.0 |", |
2123 | 4 | "+---+-------------------------------------------+", |
2124 | 4 | ]; |
2125 | 4 | assert_batches_eq!(expected, &result); |
2126 | | }; |
2127 | 8 | Ok(()) |
2128 | 8 | } |
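// Sanity check against the tables above: with ASC NULLS LAST ordering on b,
// FIRST_VALUE(b) is the smallest b per group and LAST_VALUE(b) the largest,
// e.g. group a = 2 yields 0.0 and 3.0 respectively, regardless of how the
// partial states were partitioned before merge_batch combined them.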
2129 | | |
2130 | | #[tokio::test] |
2131 | 1 | async fn test_get_finest_requirements() -> Result<()> { |
2132 | 1 | let test_schema = create_test_schema()?;
2133 | 1 | |
2134 | 1 | // Assume columns a and b are aliases.
2135 | 1 | // Assume also that a ASC and c DESC describe the same global ordering for the table (since they are ordering-equivalent).
2136 | 1 | let options1 = SortOptions { |
2137 | 1 | descending: false, |
2138 | 1 | nulls_first: false, |
2139 | 1 | }; |
2140 | 1 | let col_a = &col("a", &test_schema)?;
2141 | 1 | let col_b = &col("b", &test_schema)?;
2142 | 1 | let col_c = &col("c", &test_schema)?;
2143 | 1 | let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); |
2144 | 1 | // Columns a and b are equal. |
2145 | 1 | eq_properties.add_equal_conditions(col_a, col_b)?;
2146 | 1 | // Aggregate requirements are |
2147 | 1 | // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively |
2148 | 1 | let order_by_exprs = vec![ |
2149 | 1 | None, |
2150 | 1 | Some(vec![PhysicalSortExpr { |
2151 | 1 | expr: Arc::clone(col_a), |
2152 | 1 | options: options1, |
2153 | 1 | }]), |
2154 | 1 | Some(vec![ |
2155 | 1 | PhysicalSortExpr { |
2156 | 1 | expr: Arc::clone(col_a), |
2157 | 1 | options: options1, |
2158 | 1 | }, |
2159 | 1 | PhysicalSortExpr { |
2160 | 1 | expr: Arc::clone(col_b), |
2161 | 1 | options: options1, |
2162 | 1 | }, |
2163 | 1 | PhysicalSortExpr { |
2164 | 1 | expr: Arc::clone(col_c), |
2165 | 1 | options: options1, |
2166 | 1 | }, |
2167 | 1 | ]), |
2168 | 1 | Some(vec![ |
2169 | 1 | PhysicalSortExpr { |
2170 | 1 | expr: Arc::clone(col_a), |
2171 | 1 | options: options1, |
2172 | 1 | }, |
2173 | 1 | PhysicalSortExpr { |
2174 | 1 | expr: Arc::clone(col_b), |
2175 | 1 | options: options1, |
2176 | 1 | }, |
2177 | 1 | ]), |
2178 | 1 | ]; |
2179 | 1 | |
2180 | 1 | let common_requirement = vec![ |
2181 | 1 | PhysicalSortExpr { |
2182 | 1 | expr: Arc::clone(col_a), |
2183 | 1 | options: options1, |
2184 | 1 | }, |
2185 | 1 | PhysicalSortExpr { |
2186 | 1 | expr: Arc::clone(col_c), |
2187 | 1 | options: options1, |
2188 | 1 | }, |
2189 | 1 | ]; |
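// Why [a ASC, c ASC] is the expected answer: with a and b declared equal, the
// requirement [a ASC, b ASC, c ASC] normalizes to [a ASC, c ASC], and every
// other listed requirement (None, [a ASC], [a ASC, b ASC]) reduces to a prefix
// of it, so it is the finest ordering that satisfies all of them.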
2190 | 1 | let mut aggr_exprs = order_by_exprs |
2191 | 1 | .into_iter() |
2192 | 4 | .map(|order_by_expr| { |
2193 | 4 | let ordering_req = order_by_expr.unwrap_or_default(); |
2194 | 4 | AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) |
2195 | 4 | .alias("a") |
2196 | 4 | .order_by(ordering_req.to_vec()) |
2197 | 4 | .schema(Arc::clone(&test_schema)) |
2198 | 4 | .build() |
2199 | 4 | .unwrap() |
2200 | 4 | }) |
2201 | 1 | .collect::<Vec<_>>(); |
2202 | 1 | let group_by = PhysicalGroupBy::new_single(vec![]); |
2203 | 1 | let res = get_finer_aggregate_exprs_requirement( |
2204 | 1 | &mut aggr_exprs, |
2205 | 1 | &group_by, |
2206 | 1 | &eq_properties, |
2207 | 1 | &AggregateMode::Partial, |
2208 | 1 | )?;
2209 | 1 | let res = PhysicalSortRequirement::to_sort_exprs(res); |
2210 | 1 | assert_eq!(res, common_requirement); |
2211 | 1 | Ok(()) |
2212 | 1 | } |
2213 | | |
2214 | | #[test] |
2215 | 1 | fn test_agg_exec_same_schema() -> Result<()> { |
2216 | 1 | let schema = Arc::new(Schema::new(vec![ |
2217 | 1 | Field::new("a", DataType::Float32, true), |
2218 | 1 | Field::new("b", DataType::Float32, true), |
2219 | 1 | ])); |
2220 | | |
2221 | 1 | let col_a = col("a", &schema)?;
2222 | 1 | let option_desc = SortOptions { |
2223 | 1 | descending: true, |
2224 | 1 | nulls_first: true, |
2225 | 1 | }; |
2226 | 1 | let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); |
2227 | | |
2228 | 1 | let aggregates: Vec<AggregateFunctionExpr> = vec![ |
2229 | 1 | test_first_value_agg_expr(&schema, option_desc)?,
2230 | 1 | test_last_value_agg_expr(&schema, option_desc)?,
2231 | | ]; |
2232 | 1 | let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); |
2233 | 1 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
2234 | 1 | AggregateMode::Partial, |
2235 | 1 | groups, |
2236 | 1 | aggregates, |
2237 | 1 | vec![None, None], |
2238 | 1 | Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>, |
2239 | 1 | schema, |
2240 | 1 | )?);
2241 | 1 | let new_agg =
2242 | 1 | Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
2243 | 1 | assert_eq!(new_agg.schema(), aggregate_exec.schema()); |
2244 | 1 | Ok(()) |
2245 | 1 | } |
2246 | | |
2247 | | #[tokio::test] |
2248 | 1 | async fn test_agg_exec_group_by_const() -> Result<()> { |
2249 | 1 | let schema = Arc::new(Schema::new(vec![ |
2250 | 1 | Field::new("a", DataType::Float32, true), |
2251 | 1 | Field::new("b", DataType::Float32, true), |
2252 | 1 | Field::new("const", DataType::Int32, false), |
2253 | 1 | ])); |
2254 | 1 | |
2255 | 1 | let col_a = col("a", &schema)?;
2256 | 1 | let col_b = col("b", &schema)?;
2257 | 1 | let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); |
2258 | 1 | |
2259 | 1 | let groups = PhysicalGroupBy::new( |
2260 | 1 | vec![ |
2261 | 1 | (col_a, "a".to_string()), |
2262 | 1 | (col_b, "b".to_string()), |
2263 | 1 | (const_expr, "const".to_string()), |
2264 | 1 | ], |
2265 | 1 | vec![ |
2266 | 1 | ( |
2267 | 1 | Arc::new(Literal::new(ScalarValue::Float32(None))), |
2268 | 1 | "a".to_string(), |
2269 | 1 | ), |
2270 | 1 | ( |
2271 | 1 | Arc::new(Literal::new(ScalarValue::Float32(None))), |
2272 | 1 | "b".to_string(), |
2273 | 1 | ), |
2274 | 1 | ( |
2275 | 1 | Arc::new(Literal::new(ScalarValue::Int32(None))), |
2276 | 1 | "const".to_string(), |
2277 | 1 | ), |
2278 | 1 | ], |
2279 | 1 | vec![ |
2280 | 1 | vec![false, true, true], |
2281 | 1 | vec![true, false, true], |
2282 | 1 | vec![true, true, false], |
2283 | 1 | ], |
2284 | 1 | ); |
2285 | 1 | |
2286 | 1 | let aggregates: Vec<AggregateFunctionExpr> = |
2287 | 1 | vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) |
2288 | 1 | .schema(Arc::clone(&schema)) |
2289 | 1 | .alias("1") |
2290 | 1 | .build()?];
2291 | 1 | |
2292 | 1 | let input_batches = (0..4) |
2293 | 4 | .map(|_| { |
2294 | 4 | let a = Arc::new(Float32Array::from(vec![0.; 8192])); |
2295 | 4 | let b = Arc::new(Float32Array::from(vec![0.; 8192])); |
2296 | 4 | let c = Arc::new(Int32Array::from(vec![1; 8192])); |
2297 | 4 | |
2298 | 4 | RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap() |
2299 | 4 | }) |
2300 | 1 | .collect(); |
2301 | 1 | |
2302 | 1 | let input = Arc::new(MemoryExec::try_new( |
2303 | 1 | &[input_batches], |
2304 | 1 | Arc::clone(&schema), |
2305 | 1 | None, |
2306 | 1 | )?);
2307 | 1 | |
2308 | 1 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
2309 | 1 | AggregateMode::Partial, |
2310 | 1 | groups, |
2311 | 1 | aggregates.clone(), |
2312 | 1 | vec![None], |
2313 | 1 | input, |
2314 | 1 | schema, |
2315 | 1 | )?);
2316 | 1 | |
2317 | 1 | let output = |
2318 | 1 | collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
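// Worked check: 4 batches x 8192 rows = 32768 input rows, and every row carries
// the same (a, b, const) values, so each of the three grouping sets collapses
// the whole input into a single group with count 32768 (one output row per set).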
2319 | 1 | |
2320 | 1 | let expected = [ |
2321 | 1 | "+-----+-----+-------+----------+", |
2322 | 1 | "| a | b | const | 1[count] |", |
2323 | 1 | "+-----+-----+-------+----------+", |
2324 | 1 | "| | 0.0 | | 32768 |", |
2325 | 1 | "| 0.0 | | | 32768 |", |
2326 | 1 | "| | | 1 | 32768 |", |
2327 | 1 | "+-----+-----+-------+----------+", |
2328 | 1 | ]; |
2329 | 1 | assert_batches_sorted_eq!(expected, &output); |
2330 | 1 | |
2331 | 1 | Ok(()) |
2332 | 1 | } |
2333 | | |
2334 | | #[tokio::test] |
2335 | 1 | async fn test_agg_exec_struct_of_dicts() -> Result<()> { |
2336 | 1 | let batch = RecordBatch::try_new( |
2337 | 1 | Arc::new(Schema::new(vec![ |
2338 | 1 | Field::new( |
2339 | 1 | "labels".to_string(), |
2340 | 1 | DataType::Struct( |
2341 | 1 | vec![ |
2342 | 1 | Field::new_dict( |
2343 | 1 | "a".to_string(), |
2344 | 1 | DataType::Dictionary( |
2345 | 1 | Box::new(DataType::Int32), |
2346 | 1 | Box::new(DataType::Utf8), |
2347 | 1 | ), |
2348 | 1 | true, |
2349 | 1 | 0, |
2350 | 1 | false, |
2351 | 1 | ), |
2352 | 1 | Field::new_dict( |
2353 | 1 | "b".to_string(), |
2354 | 1 | DataType::Dictionary( |
2355 | 1 | Box::new(DataType::Int32), |
2356 | 1 | Box::new(DataType::Utf8), |
2357 | 1 | ), |
2358 | 1 | true, |
2359 | 1 | 0, |
2360 | 1 | false, |
2361 | 1 | ), |
2362 | 1 | ] |
2363 | 1 | .into(), |
2364 | 1 | ), |
2365 | 1 | false, |
2366 | 1 | ), |
2367 | 1 | Field::new("value", DataType::UInt64, false), |
2368 | 1 | ])), |
2369 | 1 | vec![ |
2370 | 1 | Arc::new(StructArray::from(vec![ |
2371 | 1 | ( |
2372 | 1 | Arc::new(Field::new_dict( |
2373 | 1 | "a".to_string(), |
2374 | 1 | DataType::Dictionary( |
2375 | 1 | Box::new(DataType::Int32), |
2376 | 1 | Box::new(DataType::Utf8), |
2377 | 1 | ), |
2378 | 1 | true, |
2379 | 1 | 0, |
2380 | 1 | false, |
2381 | 1 | )), |
2382 | 1 | Arc::new( |
2383 | 1 | vec![Some("a"), None, Some("a")] |
2384 | 1 | .into_iter() |
2385 | 1 | .collect::<DictionaryArray<Int32Type>>(), |
2386 | 1 | ) as ArrayRef, |
2387 | 1 | ), |
2388 | 1 | ( |
2389 | 1 | Arc::new(Field::new_dict( |
2390 | 1 | "b".to_string(), |
2391 | 1 | DataType::Dictionary( |
2392 | 1 | Box::new(DataType::Int32), |
2393 | 1 | Box::new(DataType::Utf8), |
2394 | 1 | ), |
2395 | 1 | true, |
2396 | 1 | 0, |
2397 | 1 | false, |
2398 | 1 | )), |
2399 | 1 | Arc::new( |
2400 | 1 | vec![Some("b"), Some("c"), Some("b")] |
2401 | 1 | .into_iter() |
2402 | 1 | .collect::<DictionaryArray<Int32Type>>(), |
2403 | 1 | ) as ArrayRef, |
2404 | 1 | ), |
2405 | 1 | ])), |
2406 | 1 | Arc::new(UInt64Array::from(vec![1, 1, 1])), |
2407 | 1 | ], |
2408 | 1 | ) |
2409 | 1 | .expect("Failed to create RecordBatch"); |
2410 | 1 | |
2411 | 1 | let group_by = PhysicalGroupBy::new_single(vec![( |
2412 | 1 | col("labels", &batch.schema())?0 , |
2413 | 1 | "labels".to_string(), |
2414 | 1 | )]); |
2415 | 1 | |
2416 | 1 | let aggr_expr = vec![AggregateExprBuilder::new( |
2417 | 1 | sum_udaf(), |
2418 | 1 | vec![col("value", &batch.schema())?0 ], |
2419 | 1 | ) |
2420 | 1 | .schema(Arc::clone(&batch.schema())) |
2421 | 1 | .alias(String::from("SUM(value)")) |
2422 | 1 | .build()?0 ]; |
2423 | 1 | |
2424 | 1 | let input = Arc::new(MemoryExec::try_new( |
2425 | 1 | &[vec![batch.clone()]], |
2426 | 1 | Arc::<arrow_schema::Schema>::clone(&batch.schema()), |
2427 | 1 | None, |
2428 | 1 | )?);
2429 | 1 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
2430 | 1 | AggregateMode::FinalPartitioned, |
2431 | 1 | group_by, |
2432 | 1 | aggr_expr, |
2433 | 1 | vec![None], |
2434 | 1 | Arc::clone(&input) as Arc<dyn ExecutionPlan>, |
2435 | 1 | batch.schema(), |
2436 | 1 | )?);
2437 | 1 | |
2438 | 1 | let session_config = SessionConfig::default(); |
2439 | 1 | let ctx = TaskContext::default().with_session_config(session_config); |
2440 | 1 | let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2441 | 1 | |
2442 | 1 | let expected = [ |
2443 | 1 | "+--------------+------------+", |
2444 | 1 | "| labels | SUM(value) |", |
2445 | 1 | "+--------------+------------+", |
2446 | 1 | "| {a: a, b: b} | 2 |", |
2447 | 1 | "| {a: , b: c} | 1 |", |
2448 | 1 | "+--------------+------------+", |
2449 | 1 | ]; |
2450 | 1 | assert_batches_eq!(expected, &output); |
2451 | 1 | |
2452 | 1 | Ok(()) |
2453 | 1 | } |
2454 | | |
2455 | | #[tokio::test] |
2456 | 1 | async fn test_skip_aggregation_after_first_batch() -> Result<()> { |
2457 | 1 | let schema = Arc::new(Schema::new(vec![ |
2458 | 1 | Field::new("key", DataType::Int32, true), |
2459 | 1 | Field::new("val", DataType::Int32, true), |
2460 | 1 | ])); |
2461 | 1 | |
2462 | 1 | let group_by = |
2463 | 1 | PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2464 | 1 | |
2465 | 1 | let aggr_expr = |
2466 | 1 | vec![ |
2467 | 1 | AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2468 | 1 | .schema(Arc::clone(&schema))
2469 | 1 | .alias(String::from("COUNT(val)"))
2470 | 1 | .build()?,
2471 | 1 | ]; |
2472 | 1 | |
2473 | 1 | let input_data = vec![ |
2474 | 1 | RecordBatch::try_new( |
2475 | 1 | Arc::clone(&schema), |
2476 | 1 | vec![ |
2477 | 1 | Arc::new(Int32Array::from(vec![1, 2, 3])), |
2478 | 1 | Arc::new(Int32Array::from(vec![0, 0, 0])), |
2479 | 1 | ], |
2480 | 1 | ) |
2481 | 1 | .unwrap(), |
2482 | 1 | RecordBatch::try_new( |
2483 | 1 | Arc::clone(&schema), |
2484 | 1 | vec![ |
2485 | 1 | Arc::new(Int32Array::from(vec![2, 3, 4])), |
2486 | 1 | Arc::new(Int32Array::from(vec![0, 0, 0])), |
2487 | 1 | ], |
2488 | 1 | ) |
2489 | 1 | .unwrap(), |
2490 | 1 | ]; |
2491 | 1 | |
2492 | 1 | let input = Arc::new(MemoryExec::try_new( |
2493 | 1 | &[input_data], |
2494 | 1 | Arc::clone(&schema), |
2495 | 1 | None, |
2496 | 1 | )?);
2497 | 1 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
2498 | 1 | AggregateMode::Partial, |
2499 | 1 | group_by, |
2500 | 1 | aggr_expr, |
2501 | 1 | vec![None], |
2502 | 1 | Arc::clone(&input) as Arc<dyn ExecutionPlan>, |
2503 | 1 | schema, |
2504 | 1 | )?);
2505 | 1 | |
2506 | 1 | let mut session_config = SessionConfig::default(); |
2507 | 1 | session_config = session_config.set( |
2508 | 1 | "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", |
2509 | 1 | &ScalarValue::Int64(Some(2)), |
2510 | 1 | ); |
2511 | 1 | session_config = session_config.set( |
2512 | 1 | "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", |
2513 | 1 | &ScalarValue::Float64(Some(0.1)), |
2514 | 1 | ); |
2515 | 1 | |
2516 | 1 | let ctx = TaskContext::default().with_session_config(session_config); |
2517 | 1 | let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2518 | 1 | |
2519 | 1 | let expected = [ |
2520 | 1 | "+-----+-------------------+", |
2521 | 1 | "| key | COUNT(val)[count] |", |
2522 | 1 | "+-----+-------------------+", |
2523 | 1 | "| 1 | 1 |", |
2524 | 1 | "| 2 | 1 |", |
2525 | 1 | "| 3 | 1 |", |
2526 | 1 | "| 2 | 1 |", |
2527 | 1 | "| 3 | 1 |", |
2528 | 1 | "| 4 | 1 |", |
2529 | 1 | "+-----+-------------------+", |
2530 | 1 | ]; |
2531 | 1 | assert_batches_eq!(expected, &output); |
2532 | 1 | |
2533 | 1 | Ok(()) |
2534 | 1 | } |
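// Why the output is not fully merged: the probe runs after the first batch
// (rows threshold = 2), sees 3 distinct keys in 3 rows (ratio 1.0 > 0.1), and
// switches the partial aggregation into pass-through mode, so the second
// batch's keys 2 and 3 appear again instead of being merged.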
2535 | | |
2536 | | #[tokio::test] |
2537 | 1 | async fn test_skip_aggregation_after_threshold() -> Result<()> { |
2538 | 1 | let schema = Arc::new(Schema::new(vec![ |
2539 | 1 | Field::new("key", DataType::Int32, true), |
2540 | 1 | Field::new("val", DataType::Int32, true), |
2541 | 1 | ])); |
2542 | 1 | |
2543 | 1 | let group_by = |
2544 | 1 | PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2545 | 1 | |
2546 | 1 | let aggr_expr = |
2547 | 1 | vec![ |
2548 | 1 | AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2549 | 1 | .schema(Arc::clone(&schema))
2550 | 1 | .alias(String::from("COUNT(val)"))
2551 | 1 | .build()?,
2552 | 1 | ]; |
2553 | 1 | |
2554 | 1 | let input_data = vec![ |
2555 | 1 | RecordBatch::try_new( |
2556 | 1 | Arc::clone(&schema), |
2557 | 1 | vec![ |
2558 | 1 | Arc::new(Int32Array::from(vec![1, 2, 3])), |
2559 | 1 | Arc::new(Int32Array::from(vec![0, 0, 0])), |
2560 | 1 | ], |
2561 | 1 | ) |
2562 | 1 | .unwrap(), |
2563 | 1 | RecordBatch::try_new( |
2564 | 1 | Arc::clone(&schema), |
2565 | 1 | vec![ |
2566 | 1 | Arc::new(Int32Array::from(vec![2, 3, 4])), |
2567 | 1 | Arc::new(Int32Array::from(vec![0, 0, 0])), |
2568 | 1 | ], |
2569 | 1 | ) |
2570 | 1 | .unwrap(), |
2571 | 1 | RecordBatch::try_new( |
2572 | 1 | Arc::clone(&schema), |
2573 | 1 | vec![ |
2574 | 1 | Arc::new(Int32Array::from(vec![2, 3, 4])), |
2575 | 1 | Arc::new(Int32Array::from(vec![0, 0, 0])), |
2576 | 1 | ], |
2577 | 1 | ) |
2578 | 1 | .unwrap(), |
2579 | 1 | ]; |
2580 | 1 | |
2581 | 1 | let input = Arc::new(MemoryExec::try_new( |
2582 | 1 | &[input_data], |
2583 | 1 | Arc::clone(&schema), |
2584 | 1 | None, |
2585 | 1 | )?);
2586 | 1 | let aggregate_exec = Arc::new(AggregateExec::try_new( |
2587 | 1 | AggregateMode::Partial, |
2588 | 1 | group_by, |
2589 | 1 | aggr_expr, |
2590 | 1 | vec![None], |
2591 | 1 | Arc::clone(&input) as Arc<dyn ExecutionPlan>, |
2592 | 1 | schema, |
2593 | 1 | )?);
2594 | 1 | |
2595 | 1 | let mut session_config = SessionConfig::default(); |
2596 | 1 | session_config = session_config.set( |
2597 | 1 | "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", |
2598 | 1 | &ScalarValue::Int64(Some(5)), |
2599 | 1 | ); |
2600 | 1 | session_config = session_config.set( |
2601 | 1 | "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", |
2602 | 1 | &ScalarValue::Float64(Some(0.1)), |
2603 | 1 | ); |
2604 | 1 | |
2605 | 1 | let ctx = TaskContext::default().with_session_config(session_config); |
2606 | 1 | let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2607 | 1 | |
2608 | 1 | let expected = [ |
2609 | 1 | "+-----+-------------------+", |
2610 | 1 | "| key | COUNT(val)[count] |", |
2611 | 1 | "+-----+-------------------+", |
2612 | 1 | "| 1 | 1 |", |
2613 | 1 | "| 2 | 2 |", |
2614 | 1 | "| 3 | 2 |", |
2615 | 1 | "| 4 | 1 |", |
2616 | 1 | "| 2 | 1 |", |
2617 | 1 | "| 3 | 1 |", |
2618 | 1 | "| 4 | 1 |", |
2619 | 1 | "+-----+-------------------+", |
2620 | 1 | ]; |
2621 | 1 | assert_batches_eq!(expected, &output); |
2622 | 1 | |
2623 | 1 | Ok(()) |
2624 | 1 | } |
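// Same mechanism with a higher rows threshold (5): the first two batches (6
// rows) are aggregated normally (keys 2 and 3 reach count 2) before the probe
// fires, and the third batch is then passed through unmerged.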
2625 | | |
2626 | | #[test] |
2627 | 1 | fn group_exprs_nullable() -> Result<()> { |
2628 | 1 | let input_schema = Arc::new(Schema::new(vec![ |
2629 | 1 | Field::new("a", DataType::Float32, false), |
2630 | 1 | Field::new("b", DataType::Float32, false), |
2631 | 1 | ])); |
2632 | | |
2633 | 1 | let aggr_expr = |
2634 | 1 | vec![ |
2635 | 1 | AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
2636 | 1 | .schema(Arc::clone(&input_schema))
2637 | 1 | .alias("COUNT(a)")
2638 | 1 | .build()?,
2639 | | ]; |
2640 | | |
2641 | 1 | let grouping_set = PhysicalGroupBy { |
2642 | 1 | expr: vec![ |
2643 | 1 | (col("a", &input_schema)?0 , "a".to_string()), |
2644 | 1 | (col("b", &input_schema)?0 , "b".to_string()), |
2645 | 1 | ], |
2646 | 1 | null_expr: vec![ |
2647 | 1 | (lit(ScalarValue::Float32(None)), "a".to_string()), |
2648 | 1 | (lit(ScalarValue::Float32(None)), "b".to_string()), |
2649 | 1 | ], |
2650 | 1 | groups: vec![ |
2651 | 1 | vec![false, true], // (a, NULL) |
2652 | 1 | vec![false, false], // (a,b) |
2653 | 1 | ], |
2654 | | }; |
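// Expected nullability below follows from the grouping sets: `a` is never
// replaced by its null placeholder (false in every group), so it stays
// non-nullable, while `b` is nulled out in the (a, NULL) set and therefore
// becomes nullable even though the input field is not.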
2655 | 1 | let aggr_schema = create_schema( |
2656 | 1 | &input_schema, |
2657 | 1 | &grouping_set.expr, |
2658 | 1 | &aggr_expr, |
2659 | 1 | grouping_set.exprs_nullable(), |
2660 | 1 | AggregateMode::Final, |
2661 | 1 | )?;
2662 | 1 | let expected_schema = Schema::new(vec![ |
2663 | 1 | Field::new("a", DataType::Float32, false), |
2664 | 1 | Field::new("b", DataType::Float32, true), |
2665 | 1 | Field::new("COUNT(a)", DataType::Int64, false), |
2666 | 1 | ]); |
2667 | 1 | assert_eq!(aggr_schema, expected_schema); |
2668 | 1 | Ok(()) |
2669 | 1 | } |
2670 | | } |