/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/aggregate.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Physical exec for aggregate window function expressions. |
19 | | |
20 | | use std::any::Any; |
21 | | use std::ops::Range; |
22 | | use std::sync::Arc; |
23 | | |
24 | | use arrow::array::Array; |
25 | | use arrow::record_batch::RecordBatch; |
26 | | use arrow::{array::ArrayRef, datatypes::Field}; |
27 | | |
28 | | use datafusion_common::ScalarValue; |
29 | | use datafusion_common::{DataFusionError, Result}; |
30 | | use datafusion_expr::{Accumulator, WindowFrame}; |
31 | | |
32 | | use crate::aggregate::AggregateFunctionExpr; |
33 | | use crate::window::window_expr::AggregateWindowExpr; |
34 | | use crate::window::{ |
35 | | PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr, |
36 | | }; |
37 | | use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; |
38 | | |
39 | | /// A window expr that takes the form of an aggregate function. |
40 | | /// |
41 | | /// See comments on [`WindowExpr`] for more details. |
#[derive(Debug)]
pub struct PlainAggregateWindowExpr {
    /// The underlying aggregate function being evaluated over the window.
    aggregate: AggregateFunctionExpr,
    /// PARTITION BY expressions; rows are grouped by these before aggregation.
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
    /// ORDER BY expressions defining the ordering within each partition.
    order_by: Vec<PhysicalSortExpr>,
    /// The window frame (e.g. ROWS/RANGE bounds) this expression evaluates over.
    window_frame: Arc<WindowFrame>,
}
49 | | |
50 | | impl PlainAggregateWindowExpr { |
51 | | /// Create a new aggregate window function expression |
52 | 1 | pub fn new( |
53 | 1 | aggregate: AggregateFunctionExpr, |
54 | 1 | partition_by: &[Arc<dyn PhysicalExpr>], |
55 | 1 | order_by: &[PhysicalSortExpr], |
56 | 1 | window_frame: Arc<WindowFrame>, |
57 | 1 | ) -> Self { |
58 | 1 | Self { |
59 | 1 | aggregate, |
60 | 1 | partition_by: partition_by.to_vec(), |
61 | 1 | order_by: order_by.to_vec(), |
62 | 1 | window_frame, |
63 | 1 | } |
64 | 1 | } |
65 | | |
66 | | /// Get aggregate expr of AggregateWindowExpr |
67 | 0 | pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr { |
68 | 0 | &self.aggregate |
69 | 0 | } |
70 | | } |
71 | | |
72 | | /// peer based evaluation based on the fact that batch is pre-sorted given the sort columns |
73 | | /// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same |
74 | | /// results for peers) and concatenate the results. |
75 | | impl WindowExpr for PlainAggregateWindowExpr { |
76 | | /// Return a reference to Any that can be used for downcasting |
77 | 1 | fn as_any(&self) -> &dyn Any { |
78 | 1 | self |
79 | 1 | } |
80 | | |
81 | 1 | fn field(&self) -> Result<Field> { |
82 | 1 | Ok(self.aggregate.field()) |
83 | 1 | } |
84 | | |
85 | 0 | fn name(&self) -> &str { |
86 | 0 | self.aggregate.name() |
87 | 0 | } |
88 | | |
89 | 0 | fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { |
90 | 0 | self.aggregate.expressions() |
91 | 0 | } |
92 | | |
93 | 0 | fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> { |
94 | 0 | self.aggregate_evaluate(batch) |
95 | 0 | } |
96 | | |
97 | 0 | fn evaluate_stateful( |
98 | 0 | &self, |
99 | 0 | partition_batches: &PartitionBatches, |
100 | 0 | window_agg_state: &mut PartitionWindowAggStates, |
101 | 0 | ) -> Result<()> { |
102 | 0 | self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?; |
103 | | |
104 | | // Update window frame range for each partition. As we know that |
105 | | // non-sliding aggregations will never call `retract_batch`, this value |
106 | | // can safely increase, and we can remove "old" parts of the state. |
107 | | // This enables us to run queries involving UNBOUNDED PRECEDING frames |
108 | | // using bounded memory for suitable aggregations. |
109 | 0 | for partition_row in partition_batches.keys() { |
110 | 0 | let window_state = |
111 | 0 | window_agg_state.get_mut(partition_row).ok_or_else(|| { |
112 | 0 | DataFusionError::Execution("Cannot find state".to_string()) |
113 | 0 | })?; |
114 | 0 | let state = &mut window_state.state; |
115 | 0 | if self.window_frame.start_bound.is_unbounded() { |
116 | 0 | state.window_frame_range.start = |
117 | 0 | state.window_frame_range.end.saturating_sub(1); |
118 | 0 | } |
119 | | } |
120 | 0 | Ok(()) |
121 | 0 | } |
122 | | |
123 | 3 | fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] { |
124 | 3 | &self.partition_by |
125 | 3 | } |
126 | | |
127 | 0 | fn order_by(&self) -> &[PhysicalSortExpr] { |
128 | 0 | &self.order_by |
129 | 0 | } |
130 | | |
131 | 0 | fn get_window_frame(&self) -> &Arc<WindowFrame> { |
132 | 0 | &self.window_frame |
133 | 0 | } |
134 | | |
135 | 0 | fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> { |
136 | 0 | self.aggregate.reverse_expr().map(|reverse_expr| { |
137 | 0 | let reverse_window_frame = self.window_frame.reverse(); |
138 | 0 | if reverse_window_frame.start_bound.is_unbounded() { |
139 | 0 | Arc::new(PlainAggregateWindowExpr::new( |
140 | 0 | reverse_expr, |
141 | 0 | &self.partition_by.clone(), |
142 | 0 | &reverse_order_bys(&self.order_by), |
143 | 0 | Arc::new(self.window_frame.reverse()), |
144 | 0 | )) as _ |
145 | | } else { |
146 | 0 | Arc::new(SlidingAggregateWindowExpr::new( |
147 | 0 | reverse_expr, |
148 | 0 | &self.partition_by.clone(), |
149 | 0 | &reverse_order_bys(&self.order_by), |
150 | 0 | Arc::new(self.window_frame.reverse()), |
151 | 0 | )) as _ |
152 | | } |
153 | 0 | }) |
154 | 0 | } |
155 | | |
156 | 0 | fn uses_bounded_memory(&self) -> bool { |
157 | 0 | !self.window_frame.end_bound.is_unbounded() |
158 | 0 | } |
159 | | } |
160 | | |
161 | | impl AggregateWindowExpr for PlainAggregateWindowExpr { |
162 | 0 | fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> { |
163 | 0 | self.aggregate.create_accumulator() |
164 | 0 | } |
165 | | |
166 | | /// For a given range, calculate accumulation result inside the range on |
167 | | /// `value_slice` and update accumulator state. |
168 | | // We assume that `cur_range` contains `last_range` and their start points |
169 | | // are same. In summary if `last_range` is `Range{start: a,end: b}` and |
170 | | // `cur_range` is `Range{start: a1, end: b1}`, it is guaranteed that a1=a and b1>=b. |
171 | 0 | fn get_aggregate_result_inside_range( |
172 | 0 | &self, |
173 | 0 | last_range: &Range<usize>, |
174 | 0 | cur_range: &Range<usize>, |
175 | 0 | value_slice: &[ArrayRef], |
176 | 0 | accumulator: &mut Box<dyn Accumulator>, |
177 | 0 | ) -> Result<ScalarValue> { |
178 | 0 | if cur_range.start == cur_range.end { |
179 | 0 | self.aggregate |
180 | 0 | .default_value(self.aggregate.field().data_type()) |
181 | | } else { |
182 | | // Accumulate any new rows that have entered the window: |
183 | 0 | let update_bound = cur_range.end - last_range.end; |
184 | 0 | // A non-sliding aggregation only processes new data, it never |
185 | 0 | // deals with expiring data as its starting point is always the |
186 | 0 | // same point (i.e. the beginning of the table/frame). Hence, we |
187 | 0 | // do not call `retract_batch`. |
188 | 0 | if update_bound > 0 { |
189 | 0 | let update: Vec<ArrayRef> = value_slice |
190 | 0 | .iter() |
191 | 0 | .map(|v| v.slice(last_range.end, update_bound)) |
192 | 0 | .collect(); |
193 | 0 | accumulator.update_batch(&update)? |
194 | 0 | } |
195 | 0 | accumulator.evaluate() |
196 | | } |
197 | 0 | } |
198 | | } |