/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/sliding_aggregate.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Physical exec for aggregate window function expressions. |
19 | | |
20 | | use std::any::Any; |
21 | | use std::ops::Range; |
22 | | use std::sync::Arc; |
23 | | |
24 | | use arrow::array::{Array, ArrayRef}; |
25 | | use arrow::datatypes::Field; |
26 | | use arrow::record_batch::RecordBatch; |
27 | | |
28 | | use datafusion_common::{Result, ScalarValue}; |
29 | | use datafusion_expr::{Accumulator, WindowFrame}; |
30 | | |
31 | | use crate::aggregate::AggregateFunctionExpr; |
32 | | use crate::window::window_expr::AggregateWindowExpr; |
33 | | use crate::window::{ |
34 | | PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr, |
35 | | }; |
36 | | use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; |
37 | | |
38 | | /// A window expr that takes the form of an aggregate function that |
39 | | /// can be incrementally computed over sliding windows. |
40 | | /// |
41 | | /// See comments on [`WindowExpr`] for more details. |
#[derive(Debug)]
pub struct SlidingAggregateWindowExpr {
    // The aggregate function evaluated over each window frame.
    aggregate: AggregateFunctionExpr,
    // PARTITION BY expressions of the window definition.
    partition_by: Vec<Arc<dyn PhysicalExpr>>,
    // ORDER BY expressions of the window definition.
    order_by: Vec<PhysicalSortExpr>,
    // Frame (e.g. ROWS BETWEEN ... AND ...) bounding each evaluation range.
    window_frame: Arc<WindowFrame>,
}
49 | | |
50 | | impl SlidingAggregateWindowExpr { |
51 | | /// Create a new (sliding) aggregate window function expression. |
52 | 1 | pub fn new( |
53 | 1 | aggregate: AggregateFunctionExpr, |
54 | 1 | partition_by: &[Arc<dyn PhysicalExpr>], |
55 | 1 | order_by: &[PhysicalSortExpr], |
56 | 1 | window_frame: Arc<WindowFrame>, |
57 | 1 | ) -> Self { |
58 | 1 | Self { |
59 | 1 | aggregate, |
60 | 1 | partition_by: partition_by.to_vec(), |
61 | 1 | order_by: order_by.to_vec(), |
62 | 1 | window_frame, |
63 | 1 | } |
64 | 1 | } |
65 | | |
66 | | /// Get the [AggregateFunctionExpr] of this object. |
67 | 0 | pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr { |
68 | 0 | &self.aggregate |
69 | 0 | } |
70 | | } |
71 | | |
72 | | /// Incrementally update window function using the fact that batch is |
73 | | /// pre-sorted given the sort columns and then per partition point. |
74 | | /// |
75 | | /// Evaluates the peer group (e.g. `SUM` or `MAX` gives the same results |
76 | | /// for peers) and concatenate the results. |
77 | | impl WindowExpr for SlidingAggregateWindowExpr { |
78 | | /// Return a reference to Any that can be used for downcasting |
79 | 1 | fn as_any(&self) -> &dyn Any { |
80 | 1 | self |
81 | 1 | } |
82 | | |
83 | 12 | fn field(&self) -> Result<Field> { |
84 | 12 | Ok(self.aggregate.field()) |
85 | 12 | } |
86 | | |
87 | 1 | fn name(&self) -> &str { |
88 | 1 | self.aggregate.name() |
89 | 1 | } |
90 | | |
91 | 9 | fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { |
92 | 9 | self.aggregate.expressions() |
93 | 9 | } |
94 | | |
95 | 0 | fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> { |
96 | 0 | self.aggregate_evaluate(batch) |
97 | 0 | } |
98 | | |
99 | 5 | fn evaluate_stateful( |
100 | 5 | &self, |
101 | 5 | partition_batches: &PartitionBatches, |
102 | 5 | window_agg_state: &mut PartitionWindowAggStates, |
103 | 5 | ) -> Result<()> { |
104 | 5 | self.aggregate_evaluate_stateful(partition_batches, window_agg_state) |
105 | 5 | } |
106 | | |
107 | 12 | fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] { |
108 | 12 | &self.partition_by |
109 | 12 | } |
110 | | |
111 | 30 | fn order_by(&self) -> &[PhysicalSortExpr] { |
112 | 30 | &self.order_by |
113 | 30 | } |
114 | | |
115 | 13 | fn get_window_frame(&self) -> &Arc<WindowFrame> { |
116 | 13 | &self.window_frame |
117 | 13 | } |
118 | | |
119 | 0 | fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> { |
120 | 0 | self.aggregate.reverse_expr().map(|reverse_expr| { |
121 | 0 | let reverse_window_frame = self.window_frame.reverse(); |
122 | 0 | if reverse_window_frame.start_bound.is_unbounded() { |
123 | 0 | Arc::new(PlainAggregateWindowExpr::new( |
124 | 0 | reverse_expr, |
125 | 0 | &self.partition_by.clone(), |
126 | 0 | &reverse_order_bys(&self.order_by), |
127 | 0 | Arc::new(self.window_frame.reverse()), |
128 | 0 | )) as _ |
129 | | } else { |
130 | 0 | Arc::new(SlidingAggregateWindowExpr::new( |
131 | 0 | reverse_expr, |
132 | 0 | &self.partition_by.clone(), |
133 | 0 | &reverse_order_bys(&self.order_by), |
134 | 0 | Arc::new(self.window_frame.reverse()), |
135 | 0 | )) as _ |
136 | | } |
137 | 0 | }) |
138 | 0 | } |
139 | | |
140 | 0 | fn uses_bounded_memory(&self) -> bool { |
141 | 0 | !self.window_frame.end_bound.is_unbounded() |
142 | 0 | } |
143 | | |
144 | 0 | fn with_new_expressions( |
145 | 0 | &self, |
146 | 0 | args: Vec<Arc<dyn PhysicalExpr>>, |
147 | 0 | partition_bys: Vec<Arc<dyn PhysicalExpr>>, |
148 | 0 | order_by_exprs: Vec<Arc<dyn PhysicalExpr>>, |
149 | 0 | ) -> Option<Arc<dyn WindowExpr>> { |
150 | 0 | debug_assert_eq!(self.order_by.len(), order_by_exprs.len()); |
151 | | |
152 | 0 | let new_order_by = self |
153 | 0 | .order_by |
154 | 0 | .iter() |
155 | 0 | .zip(order_by_exprs) |
156 | 0 | .map(|(req, new_expr)| PhysicalSortExpr { |
157 | 0 | expr: new_expr, |
158 | 0 | options: req.options, |
159 | 0 | }) |
160 | 0 | .collect::<Vec<_>>(); |
161 | 0 | Some(Arc::new(SlidingAggregateWindowExpr { |
162 | 0 | aggregate: self.aggregate.with_new_expressions(args, vec![])?, |
163 | 0 | partition_by: partition_bys, |
164 | 0 | order_by: new_order_by, |
165 | 0 | window_frame: Arc::clone(&self.window_frame), |
166 | | })) |
167 | 0 | } |
168 | | } |
169 | | |
impl AggregateWindowExpr for SlidingAggregateWindowExpr {
    // A sliding frame needs an accumulator supporting `retract_batch`,
    // hence `create_sliding_accumulator` rather than a plain accumulator.
    fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        self.aggregate.create_sliding_accumulator()
    }

    /// Given current range and the last range, calculates the accumulator
    /// result for the range of interest.
    ///
    /// Only the delta between `last_range` and `cur_range` is fed to the
    /// accumulator: rows entering the window are added via `update_batch`
    /// and rows leaving it are removed via `retract_batch`.
    ///
    /// NOTE(review): the subtractions below assume both range endpoints
    /// advance monotonically (`cur_range.end >= last_range.end` and
    /// `cur_range.start >= last_range.start`); otherwise they would
    /// underflow — confirm this invariant is guaranteed by the caller.
    fn get_aggregate_result_inside_range(
        &self,
        last_range: &Range<usize>,
        cur_range: &Range<usize>,
        value_slice: &[ArrayRef],
        accumulator: &mut Box<dyn Accumulator>,
    ) -> Result<ScalarValue> {
        if cur_range.start == cur_range.end {
            // Empty frame: emit the aggregate's default value without
            // touching the accumulator.
            self.aggregate
                .default_value(self.aggregate.field().data_type())
        } else {
            // Accumulate any new rows that have entered the window:
            let update_bound = cur_range.end - last_range.end;
            if update_bound > 0 {
                let update: Vec<ArrayRef> = value_slice
                    .iter()
                    .map(|v| v.slice(last_range.end, update_bound))
                    .collect();
                accumulator.update_batch(&update)?
            }

            // Remove rows that have now left the window:
            let retract_bound = cur_range.start - last_range.start;
            if retract_bound > 0 {
                let retract: Vec<ArrayRef> = value_slice
                    .iter()
                    .map(|v| v.slice(last_range.start, retract_bound))
                    .collect();
                accumulator.retract_batch(&retract)?
            }
            accumulator.evaluate()
        }
    }
}