/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/window_expr.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::any::Any; |
19 | | use std::fmt::Debug; |
20 | | use std::ops::Range; |
21 | | use std::sync::Arc; |
22 | | |
23 | | use crate::{LexOrderingRef, PhysicalExpr, PhysicalSortExpr}; |
24 | | |
25 | | use arrow::array::{new_empty_array, Array, ArrayRef}; |
26 | | use arrow::compute::kernels::sort::SortColumn; |
27 | | use arrow::compute::SortOptions; |
28 | | use arrow::datatypes::Field; |
29 | | use arrow::record_batch::RecordBatch; |
30 | | use datafusion_common::utils::compare_rows; |
31 | | use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; |
32 | | use datafusion_expr::window_state::{ |
33 | | PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups, |
34 | | }; |
35 | | use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound}; |
36 | | |
37 | | use indexmap::IndexMap; |
38 | | |
39 | | /// Common trait for [window function] implementations |
40 | | /// |
41 | | /// # Aggregate Window Expressions |
42 | | /// |
43 | | /// These expressions take the form |
44 | | /// |
45 | | /// ```text |
46 | | /// OVER({ROWS | RANGE| GROUPS} BETWEEN UNBOUNDED PRECEDING AND ...) |
47 | | /// ``` |
48 | | /// |
49 | | /// For example, cumulative window frames uses `PlainAggregateWindowExpr`. |
50 | | /// |
51 | | /// # Non Aggregate Window Expressions |
52 | | /// |
53 | | /// The expressions have the form |
54 | | /// |
55 | | /// ```text |
56 | | /// OVER({ROWS | RANGE| GROUPS} BETWEEN M {PRECEDING| FOLLOWING} AND ...) |
57 | | /// ``` |
58 | | /// |
59 | | /// For example, sliding window frames use [`SlidingAggregateWindowExpr`]. |
60 | | /// |
61 | | /// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) |
62 | | /// [`PlainAggregateWindowExpr`]: crate::window::PlainAggregateWindowExpr |
63 | | /// [`SlidingAggregateWindowExpr`]: crate::window::SlidingAggregateWindowExpr |
64 | | pub trait WindowExpr: Send + Sync + Debug { |
65 | | /// Returns the window expression as [`Any`] so that it can be |
66 | | /// downcast to a specific implementation. |
67 | | fn as_any(&self) -> &dyn Any; |
68 | | |
69 | | /// The field of the final result of this window function. |
70 | | fn field(&self) -> Result<Field>; |
71 | | |
72 | | /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default |
73 | | /// implementation returns placeholder text. |
74 | 0 | fn name(&self) -> &str { |
75 | 0 | "WindowExpr: default name" |
76 | 0 | } |
77 | | |
78 | | /// Expressions that are passed to the WindowAccumulator. |
79 | | /// Functions which take a single input argument, such as `sum`, return a single [`datafusion_expr::expr::Expr`], |
80 | | /// others (e.g. `cov`) return many. |
81 | | fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>; |
82 | | |
83 | | /// Evaluate the window function arguments against the batch and return |
84 | | /// array ref, normally the resulting `Vec` is a single element one. |
85 | 21 | fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> { |
86 | 21 | self.expressions() |
87 | 21 | .iter() |
88 | 21 | .map(|e| { |
89 | 21 | e.evaluate(batch) |
90 | 21 | .and_then(|v| v.into_array(batch.num_rows())) |
91 | 21 | }) |
92 | 21 | .collect() |
93 | 21 | } |
94 | | |
95 | | /// Evaluate the window function values against the batch |
96 | | fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>; |
97 | | |
98 | | /// Evaluate the window function against the batch. This function facilitates |
99 | | /// stateful, bounded-memory implementations. |
100 | 0 | fn evaluate_stateful( |
101 | 0 | &self, |
102 | 0 | _partition_batches: &PartitionBatches, |
103 | 0 | _window_agg_state: &mut PartitionWindowAggStates, |
104 | 0 | ) -> Result<()> { |
105 | 0 | internal_err!("evaluate_stateful is not implemented for {}", self.name()) |
106 | 0 | } |
107 | | |
108 | | /// Expressions that's from the window function's partition by clause, empty if absent |
109 | | fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>]; |
110 | | |
111 | | /// Expressions that's from the window function's order by clause, empty if absent |
112 | | fn order_by(&self) -> &[PhysicalSortExpr]; |
113 | | |
114 | | /// Get order by columns, empty if absent |
115 | 30 | fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> { |
116 | 30 | self.order_by() |
117 | 30 | .iter() |
118 | 30 | .map(|e| e.evaluate_to_sort_column(batch)18 ) |
119 | 30 | .collect::<Result<Vec<SortColumn>>>() |
120 | 30 | } |
121 | | |
122 | | /// Get the window frame of this [WindowExpr]. |
123 | | fn get_window_frame(&self) -> &Arc<WindowFrame>; |
124 | | |
125 | | /// Return a flag indicating whether this [WindowExpr] can run with |
126 | | /// bounded memory. |
127 | | fn uses_bounded_memory(&self) -> bool; |
128 | | |
129 | | /// Get the reverse expression of this [WindowExpr]. |
130 | | fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>; |
131 | | |
132 | | /// Returns all expressions used in the [`WindowExpr`]. |
133 | | /// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions. |
134 | 0 | fn all_expressions(&self) -> WindowPhysicalExpressions { |
135 | 0 | let args = self.expressions(); |
136 | 0 | let partition_by_exprs = self.partition_by().to_vec(); |
137 | 0 | let order_by_exprs = self |
138 | 0 | .order_by() |
139 | 0 | .iter() |
140 | 0 | .map(|sort_expr| Arc::clone(&sort_expr.expr)) |
141 | 0 | .collect::<Vec<_>>(); |
142 | 0 | WindowPhysicalExpressions { |
143 | 0 | args, |
144 | 0 | partition_by_exprs, |
145 | 0 | order_by_exprs, |
146 | 0 | } |
147 | 0 | } |
148 | | |
149 | | /// Rewrites [`WindowExpr`], with new expressions given. The argument should be consistent |
150 | | /// with the return value of the [`WindowExpr::all_expressions`] method. |
151 | | /// Returns `Some(Arc<dyn WindowExpr>)` if re-write is supported, otherwise returns `None`. |
152 | 0 | fn with_new_expressions( |
153 | 0 | &self, |
154 | 0 | _args: Vec<Arc<dyn PhysicalExpr>>, |
155 | 0 | _partition_bys: Vec<Arc<dyn PhysicalExpr>>, |
156 | 0 | _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>, |
157 | 0 | ) -> Option<Arc<dyn WindowExpr>> { |
158 | 0 | None |
159 | 0 | } |
160 | | } |
161 | | |
/// Stores the physical expressions used inside a `WindowExpr`, grouped by
/// the role they play. This is the value returned by
/// `WindowExpr::all_expressions`.
pub struct WindowPhysicalExpressions {
    /// Window function arguments
    pub args: Vec<Arc<dyn PhysicalExpr>>,
    /// PARTITION BY expressions
    pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
    /// ORDER BY expressions (the expressions only, without sort options)
    pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}
171 | | |
172 | | /// Extension trait that adds common functionality to [`AggregateWindowExpr`]s |
173 | | pub trait AggregateWindowExpr: WindowExpr { |
174 | | /// Get the accumulator for the window expression. Note that distinct |
175 | | /// window expressions may return distinct accumulators; e.g. sliding |
176 | | /// (non-sliding) expressions will return sliding (normal) accumulators. |
177 | | fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>; |
178 | | |
179 | | /// Given current range and the last range, calculates the accumulator |
180 | | /// result for the range of interest. |
181 | | fn get_aggregate_result_inside_range( |
182 | | &self, |
183 | | last_range: &Range<usize>, |
184 | | cur_range: &Range<usize>, |
185 | | value_slice: &[ArrayRef], |
186 | | accumulator: &mut Box<dyn Accumulator>, |
187 | | ) -> Result<ScalarValue>; |
188 | | |
189 | | /// Evaluates the window function against the batch. |
190 | 0 | fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> { |
191 | 0 | let mut accumulator = self.get_accumulator()?; |
192 | 0 | let mut last_range = Range { start: 0, end: 0 }; |
193 | 0 | let sort_options: Vec<SortOptions> = |
194 | 0 | self.order_by().iter().map(|o| o.options).collect(); |
195 | 0 | let mut window_frame_ctx = |
196 | 0 | WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options); |
197 | 0 | self.get_result_column( |
198 | 0 | &mut accumulator, |
199 | 0 | batch, |
200 | 0 | None, |
201 | 0 | &mut last_range, |
202 | 0 | &mut window_frame_ctx, |
203 | 0 | 0, |
204 | 0 | false, |
205 | 0 | ) |
206 | 0 | } |
207 | | |
208 | | /// Statefully evaluates the window function against the batch. Maintains |
209 | | /// state so that it can work incrementally over multiple chunks. |
210 | 5 | fn aggregate_evaluate_stateful( |
211 | 5 | &self, |
212 | 5 | partition_batches: &PartitionBatches, |
213 | 5 | window_agg_state: &mut PartitionWindowAggStates, |
214 | 5 | ) -> Result<()> { |
215 | 5 | let field = self.field()?0 ; |
216 | 5 | let out_type = field.data_type(); |
217 | 9 | for (partition_row, partition_batch_state) in partition_batches.iter()5 { |
218 | 9 | if !window_agg_state.contains_key(partition_row) { |
219 | 3 | let accumulator = self.get_accumulator()?0 ; |
220 | 3 | window_agg_state.insert( |
221 | 3 | partition_row.clone(), |
222 | 3 | WindowState { |
223 | 3 | state: WindowAggState::new(out_type)?0 , |
224 | 3 | window_fn: WindowFn::Aggregate(accumulator), |
225 | | }, |
226 | | ); |
227 | 6 | }; |
228 | 9 | let window_state = |
229 | 9 | window_agg_state.get_mut(partition_row).ok_or_else(|| { |
230 | 0 | DataFusionError::Execution("Cannot find state".to_string()) |
231 | 9 | })?0 ; |
232 | 9 | let accumulator = match &mut window_state.window_fn { |
233 | 9 | WindowFn::Aggregate(accumulator) => accumulator, |
234 | 0 | _ => unreachable!(), |
235 | | }; |
236 | 9 | let state = &mut window_state.state; |
237 | 9 | let record_batch = &partition_batch_state.record_batch; |
238 | 9 | let most_recent_row = partition_batch_state.most_recent_row.as_ref(); |
239 | 9 | |
240 | 9 | // If there is no window state context, initialize it. |
241 | 9 | let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { |
242 | 3 | let sort_options: Vec<SortOptions> = |
243 | 3 | self.order_by().iter().map(|o| o.options).collect(); |
244 | 3 | WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options) |
245 | 9 | }); |
246 | 9 | let out_col = self.get_result_column( |
247 | 9 | accumulator, |
248 | 9 | record_batch, |
249 | 9 | most_recent_row, |
250 | 9 | // Start search from the last range |
251 | 9 | &mut state.window_frame_range, |
252 | 9 | window_frame_ctx, |
253 | 9 | state.last_calculated_index, |
254 | 9 | !partition_batch_state.is_end, |
255 | 9 | )?0 ; |
256 | 9 | state.update(&out_col, partition_batch_state)?0 ; |
257 | | } |
258 | 5 | Ok(()) |
259 | 5 | } |
260 | | |
261 | | /// Calculates the window expression result for the given record batch. |
262 | | /// Assumes that `record_batch` belongs to a single partition. |
263 | | #[allow(clippy::too_many_arguments)] |
264 | 9 | fn get_result_column( |
265 | 9 | &self, |
266 | 9 | accumulator: &mut Box<dyn Accumulator>, |
267 | 9 | record_batch: &RecordBatch, |
268 | 9 | most_recent_row: Option<&RecordBatch>, |
269 | 9 | last_range: &mut Range<usize>, |
270 | 9 | window_frame_ctx: &mut WindowFrameContext, |
271 | 9 | mut idx: usize, |
272 | 9 | not_end: bool, |
273 | 9 | ) -> Result<ArrayRef> { |
274 | 9 | let values = self.evaluate_args(record_batch)?0 ; |
275 | 9 | let order_bys = get_orderby_values(self.order_by_columns(record_batch)?0 ); |
276 | | |
277 | 9 | let most_recent_row_order_bys = most_recent_row |
278 | 9 | .map(|batch| self.order_by_columns(batch)) |
279 | 9 | .transpose()?0 |
280 | 9 | .map(get_orderby_values); |
281 | 9 | |
282 | 9 | // We iterate on each row to perform a running calculation. |
283 | 9 | let length = values[0].len(); |
284 | 9 | let mut row_wise_results: Vec<ScalarValue> = vec![]; |
285 | 9 | let is_causal = self.get_window_frame().is_causal(); |
286 | 17 | while idx < length { |
287 | | // Start search from the last_range. This squeezes searched range. |
288 | 13 | let cur_range = |
289 | 13 | window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?0 ; |
290 | | // Exit if the range is non-causal and extends all the way: |
291 | 13 | if cur_range.end == length |
292 | 9 | && !is_causal |
293 | 9 | && not_end |
294 | 9 | && !is_end_bound_safe( |
295 | 9 | window_frame_ctx, |
296 | 9 | &order_bys, |
297 | 9 | most_recent_row_order_bys.as_deref(), |
298 | 9 | self.order_by(), |
299 | 9 | idx, |
300 | 9 | )?0 |
301 | | { |
302 | 5 | break; |
303 | 8 | } |
304 | 8 | let value = self.get_aggregate_result_inside_range( |
305 | 8 | last_range, |
306 | 8 | &cur_range, |
307 | 8 | &values, |
308 | 8 | accumulator, |
309 | 8 | )?0 ; |
310 | | // Update last range |
311 | 8 | *last_range = cur_range; |
312 | 8 | row_wise_results.push(value); |
313 | 8 | idx += 1; |
314 | | } |
315 | | |
316 | 9 | if row_wise_results.is_empty() { |
317 | 5 | let field = self.field()?0 ; |
318 | 5 | let out_type = field.data_type(); |
319 | 5 | Ok(new_empty_array(out_type)) |
320 | | } else { |
321 | 4 | ScalarValue::iter_to_array(row_wise_results) |
322 | | } |
323 | 9 | } |
324 | | } |
325 | | |
326 | | /// Determines whether the end bound calculation for a window frame context is |
327 | | /// safe, meaning that the end bound stays the same, regardless of future data, |
328 | | /// based on the current sort expressions and ORDER BY columns. This function |
329 | | /// delegates work to specific functions for each frame type. |
330 | | /// |
331 | | /// # Parameters |
332 | | /// |
333 | | /// * `window_frame_ctx`: The context of the window frame being evaluated. |
334 | | /// * `order_bys`: A slice of `ArrayRef` representing the ORDER BY columns. |
335 | | /// * `most_recent_order_bys`: An optional reference to the most recent ORDER BY |
336 | | /// columns. |
337 | | /// * `sort_exprs`: Defines the lexicographical ordering in question. |
338 | | /// * `idx`: The current index in the window frame. |
339 | | /// |
340 | | /// # Returns |
341 | | /// |
342 | | /// A `Result` which is `Ok(true)` if the end bound is safe, `Ok(false)` otherwise. |
343 | 9 | pub(crate) fn is_end_bound_safe( |
344 | 9 | window_frame_ctx: &WindowFrameContext, |
345 | 9 | order_bys: &[ArrayRef], |
346 | 9 | most_recent_order_bys: Option<&[ArrayRef]>, |
347 | 9 | sort_exprs: LexOrderingRef, |
348 | 9 | idx: usize, |
349 | 9 | ) -> Result<bool> { |
350 | 9 | if sort_exprs.is_empty() { |
351 | | // Early return if no sort expressions are present: |
352 | 0 | return Ok(false); |
353 | 9 | } |
354 | 9 | |
355 | 9 | match window_frame_ctx { |
356 | 0 | WindowFrameContext::Rows(window_frame) => { |
357 | 0 | is_end_bound_safe_for_rows(&window_frame.end_bound) |
358 | | } |
359 | 9 | WindowFrameContext::Range { window_frame, .. } => is_end_bound_safe_for_range( |
360 | 9 | &window_frame.end_bound, |
361 | 9 | &order_bys[0], |
362 | 9 | most_recent_order_bys.map(|items| &items[0]), |
363 | 9 | &sort_exprs[0].options, |
364 | 9 | idx, |
365 | 9 | ), |
366 | | WindowFrameContext::Groups { |
367 | 0 | window_frame, |
368 | 0 | state, |
369 | 0 | } => is_end_bound_safe_for_groups( |
370 | 0 | &window_frame.end_bound, |
371 | 0 | state, |
372 | 0 | &order_bys[0], |
373 | 0 | most_recent_order_bys.map(|items| &items[0]), |
374 | 0 | &sort_exprs[0].options, |
375 | 0 | ), |
376 | | } |
377 | 9 | } |
378 | | |
379 | | /// For row-based window frames, determines whether the end bound calculation |
380 | | /// is safe, which is trivially the case for `Preceding` and `CurrentRow` bounds. |
381 | | /// For 'Following' bounds, it compares the bound value to zero to ensure that |
382 | | /// it doesn't extend beyond the current row. |
383 | | /// |
384 | | /// # Parameters |
385 | | /// |
386 | | /// * `end_bound`: Reference to the window frame bound in question. |
387 | | /// |
388 | | /// # Returns |
389 | | /// |
390 | | /// A `Result` indicating whether the end bound is safe for row-based window frames. |
391 | 0 | fn is_end_bound_safe_for_rows(end_bound: &WindowFrameBound) -> Result<bool> { |
392 | 0 | if let WindowFrameBound::Following(value) = end_bound { |
393 | 0 | let zero = ScalarValue::new_zero(&value.data_type()); |
394 | 0 | Ok(zero.map(|zero| value.eq(&zero)).unwrap_or(false)) |
395 | | } else { |
396 | 0 | Ok(true) |
397 | | } |
398 | 0 | } |
399 | | |
400 | | /// For row-based window frames, determines whether the end bound calculation |
401 | | /// is safe by comparing it against specific values (zero, current row). It uses |
402 | | /// the `is_row_ahead` helper function to determine if the current row is ahead |
403 | | /// of the most recent row based on the ORDER BY column and sorting options. |
404 | | /// |
405 | | /// # Parameters |
406 | | /// |
407 | | /// * `end_bound`: Reference to the window frame bound in question. |
408 | | /// * `orderby_col`: Reference to the column used for ordering. |
409 | | /// * `most_recent_ob_col`: Optional reference to the most recent order-by column. |
410 | | /// * `sort_options`: The sorting options used in the window frame. |
411 | | /// * `idx`: The current index in the window frame. |
412 | | /// |
413 | | /// # Returns |
414 | | /// |
415 | | /// A `Result` indicating whether the end bound is safe for range-based window frames. |
416 | 9 | fn is_end_bound_safe_for_range( |
417 | 9 | end_bound: &WindowFrameBound, |
418 | 9 | orderby_col: &ArrayRef, |
419 | 9 | most_recent_ob_col: Option<&ArrayRef>, |
420 | 9 | sort_options: &SortOptions, |
421 | 9 | idx: usize, |
422 | 9 | ) -> Result<bool> { |
423 | 9 | match end_bound { |
424 | 0 | WindowFrameBound::Preceding(value) => { |
425 | 0 | let zero = ScalarValue::new_zero(&value.data_type())?; |
426 | 0 | if value.eq(&zero) { |
427 | 0 | is_row_ahead(orderby_col, most_recent_ob_col, sort_options) |
428 | | } else { |
429 | 0 | Ok(true) |
430 | | } |
431 | | } |
432 | | WindowFrameBound::CurrentRow => { |
433 | 0 | is_row_ahead(orderby_col, most_recent_ob_col, sort_options) |
434 | | } |
435 | 9 | WindowFrameBound::Following(delta) => { |
436 | 9 | let Some(most_recent_ob_col) = most_recent_ob_col else { |
437 | 0 | return Ok(false); |
438 | | }; |
439 | 9 | let most_recent_row_value = |
440 | 9 | ScalarValue::try_from_array(most_recent_ob_col, 0)?0 ; |
441 | 9 | let current_row_value = ScalarValue::try_from_array(orderby_col, idx)?0 ; |
442 | | |
443 | 9 | if sort_options.descending { |
444 | 0 | current_row_value |
445 | 0 | .sub(delta) |
446 | 0 | .map(|value| value > most_recent_row_value) |
447 | | } else { |
448 | 9 | current_row_value |
449 | 9 | .add(delta) |
450 | 9 | .map(|value| most_recent_row_value > value) |
451 | | } |
452 | | } |
453 | | } |
454 | 9 | } |
455 | | |
456 | | /// For group-based window frames, determines whether the end bound calculation |
457 | | /// is safe by considering the group offset and whether the current row is ahead |
458 | | /// of the most recent row in terms of sorting. It checks if the end bound is |
459 | | /// within the bounds of the current group based on group end indices. |
460 | | /// |
461 | | /// # Parameters |
462 | | /// |
463 | | /// * `end_bound`: Reference to the window frame bound in question. |
464 | | /// * `state`: The state of the window frame for group calculations. |
465 | | /// * `orderby_col`: Reference to the column used for ordering. |
466 | | /// * `most_recent_ob_col`: Optional reference to the most recent order-by column. |
467 | | /// * `sort_options`: The sorting options used in the window frame. |
468 | | /// |
469 | | /// # Returns |
470 | | /// |
471 | | /// A `Result` indicating whether the end bound is safe for group-based window frames. |
472 | 0 | fn is_end_bound_safe_for_groups( |
473 | 0 | end_bound: &WindowFrameBound, |
474 | 0 | state: &WindowFrameStateGroups, |
475 | 0 | orderby_col: &ArrayRef, |
476 | 0 | most_recent_ob_col: Option<&ArrayRef>, |
477 | 0 | sort_options: &SortOptions, |
478 | 0 | ) -> Result<bool> { |
479 | 0 | match end_bound { |
480 | 0 | WindowFrameBound::Preceding(value) => { |
481 | 0 | let zero = ScalarValue::new_zero(&value.data_type())?; |
482 | 0 | if value.eq(&zero) { |
483 | 0 | is_row_ahead(orderby_col, most_recent_ob_col, sort_options) |
484 | | } else { |
485 | 0 | Ok(true) |
486 | | } |
487 | | } |
488 | | WindowFrameBound::CurrentRow => { |
489 | 0 | is_row_ahead(orderby_col, most_recent_ob_col, sort_options) |
490 | | } |
491 | 0 | WindowFrameBound::Following(ScalarValue::UInt64(Some(offset))) => { |
492 | 0 | let delta = state.group_end_indices.len() - state.current_group_idx; |
493 | 0 | if delta == (*offset as usize) + 1 { |
494 | 0 | is_row_ahead(orderby_col, most_recent_ob_col, sort_options) |
495 | | } else { |
496 | 0 | Ok(false) |
497 | | } |
498 | | } |
499 | 0 | _ => Ok(false), |
500 | | } |
501 | 0 | } |
502 | | |
503 | | /// This utility function checks whether `current_cols` is ahead of the `old_cols` |
504 | | /// in terms of `sort_options`. |
505 | 0 | fn is_row_ahead( |
506 | 0 | old_col: &ArrayRef, |
507 | 0 | current_col: Option<&ArrayRef>, |
508 | 0 | sort_options: &SortOptions, |
509 | 0 | ) -> Result<bool> { |
510 | 0 | let Some(current_col) = current_col else { |
511 | 0 | return Ok(false); |
512 | | }; |
513 | 0 | if old_col.is_empty() || current_col.is_empty() { |
514 | 0 | return Ok(false); |
515 | 0 | } |
516 | 0 | let last_value = ScalarValue::try_from_array(old_col, old_col.len() - 1)?; |
517 | 0 | let current_value = ScalarValue::try_from_array(current_col, 0)?; |
518 | 0 | let cmp = compare_rows(&[current_value], &[last_value], &[*sort_options])?; |
519 | 0 | Ok(cmp.is_gt()) |
520 | 0 | } |
521 | | |
522 | | /// Get order by expression results inside `order_by_columns`. |
523 | 30 | pub(crate) fn get_orderby_values(order_by_columns: Vec<SortColumn>) -> Vec<ArrayRef> { |
524 | 30 | order_by_columns.into_iter().map(|s| s.values18 ).collect() |
525 | 30 | } |
526 | | |
/// The evaluator stored in a [`WindowState`], tagged by the kind of window
/// function it serves.
#[derive(Debug)]
pub enum WindowFn {
    /// Evaluator of a built-in window function.
    Builtin(Box<dyn PartitionEvaluator>),
    /// Accumulator of an aggregate window function; this is the variant that
    /// `AggregateWindowExpr::aggregate_evaluate_stateful` creates and expects.
    Aggregate(Box<dyn Accumulator>),
}
532 | | |
/// State for the rank family of built-in window functions (`rank`,
/// `dense_rank`, `percent_rank`).
#[derive(Debug, Clone, Default)]
pub struct RankState {
    /// The most recently seen rank values; `n_rank` is increased when these
    /// values change.
    pub last_rank_data: Option<Vec<ScalarValue>>,
    /// The index at which the most recent rank boundary starts.
    // NOTE(review): wording inferred from the field name; confirm against the
    // code that maintains this state.
    pub last_rank_boundary: usize,
    /// The number of entries in the current rank.
    pub current_group_count: usize,
    /// Rank number, counted from the start.
    pub n_rank: usize,
}
545 | | |
/// Tag to differentiate special use cases of the NTH_VALUE built-in window function.
#[derive(Debug, Copy, Clone)]
pub enum NthValueKind {
    /// The first value is requested (presumably `FIRST_VALUE` — TODO confirm).
    First,
    /// The last value is requested (presumably `LAST_VALUE` — TODO confirm).
    Last,
    /// A general `NTH_VALUE`; the payload is presumably the requested
    /// position `n` (sign convention to be confirmed against the evaluator).
    Nth(i64),
}
553 | | |
/// State kept for the NTH_VALUE built-in window function.
#[derive(Debug, Clone)]
pub struct NthValueState {
    // In certain cases, we can finalize the result early. Consider this usage:
    // ```
    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
    // ```
    // The result will always be the first entry in the table. We can store such
    // early-finalizing results and then just reuse them as necessary. This opens
    // opportunities to prune our datasets.
    /// The early-finalized result, if one has been determined; `None` otherwise.
    pub finalized_result: Option<ScalarValue>,
    /// Which NTH_VALUE use case this state belongs to.
    pub kind: NthValueKind,
}
567 | | |
/// Key of the `IndexMap`s that hold per-partition data; there is one key for
/// each unique partition.
///
/// For instance, if the window is defined as `OVER(PARTITION BY a,b)`, each
/// `PartitionKey` is one unique `[a,b]` value pair.
pub type PartitionKey = Vec<ScalarValue>;
573 | | |
/// Everything that is kept per partition for a stateful window evaluation:
/// the aggregation state together with the evaluator that produces results.
#[derive(Debug)]
pub struct WindowState {
    /// Bookkeeping of the window evaluation for this partition.
    pub state: WindowAggState,
    /// The evaluator (built-in evaluator or aggregate accumulator) in use.
    pub window_fn: WindowFn,
}
/// The IndexMap (i.e. an ordered HashMap) that holds the [`WindowState`] of each partition.
pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
580 | | |
/// The IndexMap (i.e. an ordered HashMap) that holds the record batches of
/// each partition separately, keyed by [`PartitionKey`].
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
583 | | |
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::window::window_expr::is_row_ahead;

    use arrow_array::{ArrayRef, Float64Array};
    use arrow_schema::SortOptions;
    use datafusion_common::Result;

    /// With an ascending order, a new row is "ahead" only if its value is
    /// strictly greater than the last value already seen.
    #[test]
    fn test_is_row_ahead() -> Result<()> {
        let ascending = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let seen_so_far: ArrayRef =
            Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]));

        // 11.0 is past the last seen value (10.0) ...
        let strictly_greater: ArrayRef = Arc::new(Float64Array::from(vec![11.0]));
        assert!(is_row_ahead(
            &seen_so_far,
            Some(&strictly_greater),
            &ascending
        )?);

        // ... whereas an equal value is not.
        let equal_to_last: ArrayRef = Arc::new(Float64Array::from(vec![10.0]));
        assert!(!is_row_ahead(
            &seen_so_far,
            Some(&equal_to_last),
            &ascending
        )?);

        Ok(())
    }
}