/Users/andrewlamb/Software/datafusion/datafusion/expr/src/window_state.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Structures used to hold window function state (for implementing WindowUDFs) |
19 | | |
20 | | use std::{collections::VecDeque, ops::Range, sync::Arc}; |
21 | | |
22 | | use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; |
23 | | |
24 | | use arrow::{ |
25 | | array::ArrayRef, |
26 | | compute::{concat, concat_batches, SortOptions}, |
27 | | datatypes::{DataType, SchemaRef}, |
28 | | record_batch::RecordBatch, |
29 | | }; |
30 | | use datafusion_common::{ |
31 | | internal_err, |
32 | | utils::{compare_rows, get_row_at_idx, search_in_slice}, |
33 | | DataFusionError, Result, ScalarValue, |
34 | | }; |
35 | | |
36 | | /// Holds the state of evaluating a window function |
37 | | #[derive(Debug)] |
38 | | pub struct WindowAggState { |
39 | | /// The range that we calculate the window function |
40 | | pub window_frame_range: Range<usize>, |
41 | | pub window_frame_ctx: Option<WindowFrameContext>, |
42 | | /// The index of the last row whose result is calculated inside the partition record batch buffer. |
43 | | pub last_calculated_index: usize, |
44 | | /// The cumulative number of rows that have been pruned (deleted) so far |
45 | | pub offset_pruned_rows: usize, |
46 | | /// Stores the results calculated by window frame |
47 | | pub out_col: ArrayRef, |
48 | | /// Keeps track of how many rows should be generated to be in sync with input record_batch. |
49 | | // (For each row in the input record batch we need to generate a window result). |
50 | | pub n_row_result_missing: usize, |
51 | | /// Flag indicating whether we have received all data for this partition |
52 | | pub is_end: bool, |
53 | | } |
54 | | |
55 | | impl WindowAggState { |
56 | 17 | pub fn prune_state(&mut self, n_prune: usize) { |
57 | 17 | self.window_frame_range = Range { |
58 | 17 | start: self.window_frame_range.start - n_prune, |
59 | 17 | end: self.window_frame_range.end - n_prune, |
60 | 17 | }; |
61 | 17 | self.last_calculated_index -= n_prune; |
62 | 17 | self.offset_pruned_rows += n_prune; |
63 | 17 | |
64 | 17 | match self.window_frame_ctx.as_mut() { |
65 | | // Rows have no state, so do nothing |
66 | 9 | Some(WindowFrameContext::Rows(_)) => {} |
67 | 8 | Some(WindowFrameContext::Range { .. }) => {} |
68 | 0 | Some(WindowFrameContext::Groups { state, .. }) => { |
69 | 0 | let mut n_group_to_del = 0; |
70 | 0 | for (_, end_idx) in &state.group_end_indices { |
71 | 0 | if n_prune < *end_idx { |
72 | 0 | break; |
73 | 0 | } |
74 | 0 | n_group_to_del += 1; |
75 | | } |
76 | 0 | state.group_end_indices.drain(0..n_group_to_del); |
77 | 0 | state |
78 | 0 | .group_end_indices |
79 | 0 | .iter_mut() |
80 | 0 | .for_each(|(_, start_idx)| *start_idx -= n_prune); |
81 | 0 | state.current_group_idx -= n_group_to_del; |
82 | 0 | } |
83 | 0 | None => {} |
84 | | }; |
85 | 17 | } |
86 | | |
87 | 21 | pub fn update( |
88 | 21 | &mut self, |
89 | 21 | out_col: &ArrayRef, |
90 | 21 | partition_batch_state: &PartitionBatchState, |
91 | 21 | ) -> Result<()> { |
92 | 21 | self.last_calculated_index += out_col.len(); |
93 | 21 | self.out_col = concat(&[&self.out_col, &out_col])?0 ; |
94 | 21 | self.n_row_result_missing = |
95 | 21 | partition_batch_state.record_batch.num_rows() - self.last_calculated_index; |
96 | 21 | self.is_end = partition_batch_state.is_end; |
97 | 21 | Ok(()) |
98 | 21 | } |
99 | | |
100 | 6 | pub fn new(out_type: &DataType) -> Result<Self> { |
101 | 6 | let empty_out_col = ScalarValue::try_from(out_type)?0 .to_array_of_size(0)?0 ; |
102 | 6 | Ok(Self { |
103 | 6 | window_frame_range: Range { start: 0, end: 0 }, |
104 | 6 | window_frame_ctx: None, |
105 | 6 | last_calculated_index: 0, |
106 | 6 | offset_pruned_rows: 0, |
107 | 6 | out_col: empty_out_col, |
108 | 6 | n_row_result_missing: 0, |
109 | 6 | is_end: false, |
110 | 6 | }) |
111 | 6 | } |
112 | | } |
113 | | |
114 | | /// This object stores the window frame state for use in incremental calculations. |
115 | | #[derive(Debug)] |
116 | | pub enum WindowFrameContext { |
117 | | /// ROWS frames are inherently stateless. |
118 | | Rows(Arc<WindowFrame>), |
119 | | /// RANGE frames are stateful, they store indices specifying where the |
120 | | /// previous search left off. This amortizes the overall cost to O(n) |
121 | | /// where n denotes the row count. |
122 | | Range { |
123 | | window_frame: Arc<WindowFrame>, |
124 | | state: WindowFrameStateRange, |
125 | | }, |
126 | | /// GROUPS frames are stateful, they store group boundaries and indices |
127 | | /// specifying where the previous search left off. This amortizes the |
128 | | /// overall cost to O(n) where n denotes the row count. |
129 | | Groups { |
130 | | window_frame: Arc<WindowFrame>, |
131 | | state: WindowFrameStateGroups, |
132 | | }, |
133 | | } |
134 | | |
135 | | impl WindowFrameContext { |
136 | | /// Create a new state object for the given window frame. |
137 | 6 | pub fn new(window_frame: Arc<WindowFrame>, sort_options: Vec<SortOptions>) -> Self { |
138 | 6 | match window_frame.units { |
139 | 3 | WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame), |
140 | 3 | WindowFrameUnits::Range => WindowFrameContext::Range { |
141 | 3 | window_frame, |
142 | 3 | state: WindowFrameStateRange::new(sort_options), |
143 | 3 | }, |
144 | 0 | WindowFrameUnits::Groups => WindowFrameContext::Groups { |
145 | 0 | window_frame, |
146 | 0 | state: WindowFrameStateGroups::default(), |
147 | 0 | }, |
148 | | } |
149 | 6 | } |
150 | | |
151 | | /// This function calculates beginning/ending indices for the frame of the current row. |
152 | 40 | pub fn calculate_range( |
153 | 40 | &mut self, |
154 | 40 | range_columns: &[ArrayRef], |
155 | 40 | last_range: &Range<usize>, |
156 | 40 | length: usize, |
157 | 40 | idx: usize, |
158 | 40 | ) -> Result<Range<usize>> { |
159 | 40 | match self { |
160 | 27 | WindowFrameContext::Rows(window_frame) => { |
161 | 27 | Self::calculate_range_rows(window_frame, length, idx) |
162 | | } |
163 | | // Sort options are used in RANGE mode calculations because the |
164 | | // ordering or position of NULLs impacts range calculations and |
165 | | // comparison of rows. |
166 | | WindowFrameContext::Range { |
167 | 13 | window_frame, |
168 | 13 | ref mut state, |
169 | 13 | } => state.calculate_range( |
170 | 13 | window_frame, |
171 | 13 | last_range, |
172 | 13 | range_columns, |
173 | 13 | length, |
174 | 13 | idx, |
175 | 13 | ), |
176 | | // Sort options are not used in GROUPS mode calculations as the |
177 | | // inequality of two rows indicates a group change, and ordering |
178 | | // or position of NULLs does not impact inequality. |
179 | | WindowFrameContext::Groups { |
180 | 0 | window_frame, |
181 | 0 | ref mut state, |
182 | 0 | } => state.calculate_range(window_frame, range_columns, length, idx), |
183 | | } |
184 | 40 | } |
185 | | |
186 | | /// This function calculates beginning/ending indices for the frame of the current row. |
187 | 27 | fn calculate_range_rows( |
188 | 27 | window_frame: &Arc<WindowFrame>, |
189 | 27 | length: usize, |
190 | 27 | idx: usize, |
191 | 27 | ) -> Result<Range<usize>> { |
192 | 27 | let start = match window_frame.start_bound { |
193 | | // UNBOUNDED PRECEDING |
194 | 27 | WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, |
195 | 0 | WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { |
196 | 0 | if idx >= n as usize { |
197 | 0 | idx - n as usize |
198 | | } else { |
199 | 0 | 0 |
200 | | } |
201 | | } |
202 | 0 | WindowFrameBound::CurrentRow => idx, |
203 | | // UNBOUNDED FOLLOWING |
204 | | WindowFrameBound::Following(ScalarValue::UInt64(None)) => { |
205 | 0 | return internal_err!( |
206 | 0 | "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'" |
207 | 0 | ) |
208 | | } |
209 | 0 | WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { |
210 | 0 | std::cmp::min(idx + n as usize, length) |
211 | | } |
212 | | // ERRONEOUS FRAMES |
213 | | WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { |
214 | 0 | return internal_err!("Rows should be Uint") |
215 | | } |
216 | | }; |
217 | 27 | let end = match window_frame.end_bound { |
218 | | // UNBOUNDED PRECEDING |
219 | | WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { |
220 | 0 | return internal_err!( |
221 | 0 | "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'" |
222 | 0 | ) |
223 | | } |
224 | 0 | WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { |
225 | 0 | if idx >= n as usize { |
226 | 0 | idx - n as usize + 1 |
227 | | } else { |
228 | 0 | 0 |
229 | | } |
230 | | } |
231 | 27 | WindowFrameBound::CurrentRow => idx + 1, |
232 | | // UNBOUNDED FOLLOWING |
233 | 0 | WindowFrameBound::Following(ScalarValue::UInt64(None)) => length, |
234 | 0 | WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { |
235 | 0 | std::cmp::min(idx + n as usize + 1, length) |
236 | | } |
237 | | // ERRONEOUS FRAMES |
238 | | WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { |
239 | 0 | return internal_err!("Rows should be Uint") |
240 | | } |
241 | | }; |
242 | 27 | Ok(Range { start, end }) |
243 | 27 | } |
244 | | } |
245 | | |
246 | | /// State for each unique partition determined according to PARTITION BY column(s) |
247 | | #[derive(Debug)] |
248 | | pub struct PartitionBatchState { |
249 | | /// The record batch belonging to current partition |
250 | | pub record_batch: RecordBatch, |
251 | | /// The record batch that contains the most recent row at the input. |
252 | | /// Please note that this batch doesn't necessarily have the same partitioning |
253 | | /// as `record_batch`. Keeping track of this batch enables us to prune |
254 | | /// `record_batch` when cardinality of the partition is sparse. |
255 | | pub most_recent_row: Option<RecordBatch>, |
256 | | /// Flag indicating whether we have received all data for this partition |
257 | | pub is_end: bool, |
258 | | /// Number of rows emitted for each partition |
259 | | pub n_out_row: usize, |
260 | | } |
261 | | |
262 | | impl PartitionBatchState { |
263 | 4 | pub fn new(schema: SchemaRef) -> Self { |
264 | 4 | Self { |
265 | 4 | record_batch: RecordBatch::new_empty(schema), |
266 | 4 | most_recent_row: None, |
267 | 4 | is_end: false, |
268 | 4 | n_out_row: 0, |
269 | 4 | } |
270 | 4 | } |
271 | | |
272 | 8 | pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> { |
273 | 8 | self.record_batch = |
274 | 8 | concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])?0 ; |
275 | 8 | Ok(()) |
276 | 8 | } |
277 | | |
278 | 9 | pub fn set_most_recent_row(&mut self, batch: RecordBatch) { |
279 | 9 | // It is enough for the batch to contain only a single row (the rest |
280 | 9 | // are not necessary). |
281 | 9 | self.most_recent_row = Some(batch); |
282 | 9 | } |
283 | | } |
284 | | |
285 | | /// This structure encapsulates all the state information we require as we scan |
286 | | /// ranges of data while processing RANGE frames. |
287 | | /// Attribute `sort_options` stores the column ordering specified by the ORDER |
288 | | /// BY clause. This information is used to calculate the range. |
289 | | #[derive(Debug, Default)] |
290 | | pub struct WindowFrameStateRange { |
291 | | sort_options: Vec<SortOptions>, |
292 | | } |
293 | | |
294 | | impl WindowFrameStateRange { |
295 | | /// Create a new object to store the search state. |
296 | 3 | fn new(sort_options: Vec<SortOptions>) -> Self { |
297 | 3 | Self { sort_options } |
298 | 3 | } |
299 | | |
300 | | /// This function calculates beginning/ending indices for the frame of the current row. |
301 | | // Argument `last_range` stores the resulting indices from the previous search. Since the indices only |
302 | | // advance forward, we start from `last_range` subsequently. Thus, the overall |
303 | | // time complexity of linear search amortizes to O(n) where n denotes the total |
304 | | // row count. |
305 | 13 | fn calculate_range( |
306 | 13 | &mut self, |
307 | 13 | window_frame: &Arc<WindowFrame>, |
308 | 13 | last_range: &Range<usize>, |
309 | 13 | range_columns: &[ArrayRef], |
310 | 13 | length: usize, |
311 | 13 | idx: usize, |
312 | 13 | ) -> Result<Range<usize>> { |
313 | 13 | let start = match window_frame.start_bound { |
314 | 0 | WindowFrameBound::Preceding(ref n) => { |
315 | 0 | if n.is_null() { |
316 | | // UNBOUNDED PRECEDING |
317 | 0 | 0 |
318 | | } else { |
319 | 0 | self.calculate_index_of_row::<true, true>( |
320 | 0 | range_columns, |
321 | 0 | last_range, |
322 | 0 | idx, |
323 | 0 | Some(n), |
324 | 0 | length, |
325 | 0 | )? |
326 | | } |
327 | | } |
328 | 13 | WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>( |
329 | 13 | range_columns, |
330 | 13 | last_range, |
331 | 13 | idx, |
332 | 13 | None, |
333 | 13 | length, |
334 | 13 | )?0 , |
335 | 0 | WindowFrameBound::Following(ref n) => self |
336 | 0 | .calculate_index_of_row::<true, false>( |
337 | 0 | range_columns, |
338 | 0 | last_range, |
339 | 0 | idx, |
340 | 0 | Some(n), |
341 | 0 | length, |
342 | 0 | )?, |
343 | | }; |
344 | 13 | let end = match window_frame.end_bound { |
345 | 0 | WindowFrameBound::Preceding(ref n) => self |
346 | 0 | .calculate_index_of_row::<false, true>( |
347 | 0 | range_columns, |
348 | 0 | last_range, |
349 | 0 | idx, |
350 | 0 | Some(n), |
351 | 0 | length, |
352 | 0 | )?, |
353 | 0 | WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>( |
354 | 0 | range_columns, |
355 | 0 | last_range, |
356 | 0 | idx, |
357 | 0 | None, |
358 | 0 | length, |
359 | 0 | )?, |
360 | 13 | WindowFrameBound::Following(ref n) => { |
361 | 13 | if n.is_null() { |
362 | | // UNBOUNDED FOLLOWING |
363 | 0 | length |
364 | | } else { |
365 | 13 | self.calculate_index_of_row::<false, false>( |
366 | 13 | range_columns, |
367 | 13 | last_range, |
368 | 13 | idx, |
369 | 13 | Some(n), |
370 | 13 | length, |
371 | 13 | )?0 |
372 | | } |
373 | | } |
374 | | }; |
375 | 13 | Ok(Range { start, end }) |
376 | 13 | } |
377 | | |
378 | | /// This function does the heavy lifting when finding range boundaries. It is meant to be |
379 | | /// called twice, in succession, to get window frame start and end indices (with `SIDE` |
380 | | /// supplied as true and false, respectively). |
381 | 26 | fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>( |
382 | 26 | &mut self, |
383 | 26 | range_columns: &[ArrayRef], |
384 | 26 | last_range: &Range<usize>, |
385 | 26 | idx: usize, |
386 | 26 | delta: Option<&ScalarValue>, |
387 | 26 | length: usize, |
388 | 26 | ) -> Result<usize> { |
389 | 26 | let current_row_values = get_row_at_idx(range_columns, idx)?0 ; |
390 | 26 | let end_range = if let Some(delta13 ) = delta { |
391 | 13 | let is_descending: bool = self |
392 | 13 | .sort_options |
393 | 13 | .first() |
394 | 13 | .ok_or_else(|| { |
395 | 0 | DataFusionError::Internal( |
396 | 0 | "Sort options unexpectedly absent in a window frame".to_string(), |
397 | 0 | ) |
398 | 13 | })?0 |
399 | | .descending; |
400 | | |
401 | 13 | current_row_values |
402 | 13 | .iter() |
403 | 13 | .map(|value| { |
404 | 13 | if value.is_null() { |
405 | 0 | return Ok(value.clone()); |
406 | 13 | } |
407 | 13 | if SEARCH_SIDE == is_descending { |
408 | | // TODO: Handle positive overflows. |
409 | 13 | value.add(delta) |
410 | 0 | } else if value.is_unsigned() && value < delta { |
411 | | // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. |
412 | | // If we decide to implement a "default" construction mechanism for ScalarValue, |
413 | | // change the following statement to use that. |
414 | 0 | value.sub(value) |
415 | | } else { |
416 | | // TODO: Handle negative overflows. |
417 | 0 | value.sub(delta) |
418 | | } |
419 | 13 | }) |
420 | 13 | .collect::<Result<Vec<ScalarValue>>>()?0 |
421 | | } else { |
422 | 13 | current_row_values |
423 | | }; |
424 | 26 | let search_start = if SIDE { |
425 | 13 | last_range.start |
426 | | } else { |
427 | 13 | last_range.end |
428 | | }; |
429 | 41 | let compare_fn26 = |current: &[ScalarValue], target: &[ScalarValue]| { |
430 | 41 | let cmp = compare_rows(current, target, &self.sort_options)?0 ; |
431 | 41 | Ok(if SIDE { cmp.is_lt()21 } else { cmp.is_le()20 }) |
432 | 41 | }; |
433 | 26 | search_in_slice(range_columns, &end_range, compare_fn, search_start, length) |
434 | 26 | } |
435 | | } |
436 | | |
437 | | // In GROUPS mode, rows with duplicate sorting values are grouped together. |
438 | | // Therefore, there must be an ORDER BY clause in the window definition to use GROUPS mode. |
439 | | // The syntax is as follows: |
440 | | // GROUPS frame_start [ frame_exclusion ] |
441 | | // GROUPS BETWEEN frame_start AND frame_end [ frame_exclusion ] |
442 | | // The optional frame_exclusion specifier is not yet supported. |
443 | | // The frame_start and frame_end parameters allow us to specify which rows the window |
444 | | // frame starts and ends with. They accept the following values: |
445 | | // - UNBOUNDED PRECEDING: Start with the first row of the partition. Possible only in frame_start. |
446 | | // - offset PRECEDING: When used in frame_start, it refers to the first row of the group |
447 | | // that comes "offset" groups before the current group (i.e. the group |
448 | | // containing the current row). When used in frame_end, it refers to the |
449 | | // last row of the group that comes "offset" groups before the current group. |
450 | | // - CURRENT ROW: When used in frame_start, it refers to the first row of the group containing |
451 | | // the current row. When used in frame_end, it refers to the last row of the group |
452 | | // containing the current row. |
453 | | // - offset FOLLOWING: When used in frame_start, it refers to the first row of the group |
454 | | // that comes "offset" groups after the current group (i.e. the group |
455 | | // containing the current row). When used in frame_end, it refers to the |
456 | | // last row of the group that comes "offset" groups after the current group. |
457 | | // - UNBOUNDED FOLLOWING: End with the last row of the partition. Possible only in frame_end. |
458 | | |
459 | | /// This structure encapsulates all the state information we require as we |
460 | | /// scan groups of data while processing window frames. |
461 | | #[derive(Debug, Default)] |
462 | | pub struct WindowFrameStateGroups { |
463 | | /// A queue of tuples, each containing a group's values and the (exclusive) row index where that group ends. |
464 | | /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to |
465 | | /// [([1, 1], 2), ([2, 1], 4), ...]. |
466 | | pub group_end_indices: VecDeque<(Vec<ScalarValue>, usize)>, |
467 | | /// The group index to which the row index belongs. |
468 | | pub current_group_idx: usize, |
469 | | } |
470 | | |
471 | | impl WindowFrameStateGroups { |
472 | 0 | fn calculate_range( |
473 | 0 | &mut self, |
474 | 0 | window_frame: &Arc<WindowFrame>, |
475 | 0 | range_columns: &[ArrayRef], |
476 | 0 | length: usize, |
477 | 0 | idx: usize, |
478 | 0 | ) -> Result<Range<usize>> { |
479 | 0 | let start = match window_frame.start_bound { |
480 | 0 | WindowFrameBound::Preceding(ref n) => { |
481 | 0 | if n.is_null() { |
482 | | // UNBOUNDED PRECEDING |
483 | 0 | 0 |
484 | | } else { |
485 | 0 | self.calculate_index_of_row::<true, true>( |
486 | 0 | range_columns, |
487 | 0 | idx, |
488 | 0 | Some(n), |
489 | 0 | length, |
490 | 0 | )? |
491 | | } |
492 | | } |
493 | 0 | WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>( |
494 | 0 | range_columns, |
495 | 0 | idx, |
496 | 0 | None, |
497 | 0 | length, |
498 | 0 | )?, |
499 | 0 | WindowFrameBound::Following(ref n) => self |
500 | 0 | .calculate_index_of_row::<true, false>( |
501 | 0 | range_columns, |
502 | 0 | idx, |
503 | 0 | Some(n), |
504 | 0 | length, |
505 | 0 | )?, |
506 | | }; |
507 | 0 | let end = match window_frame.end_bound { |
508 | 0 | WindowFrameBound::Preceding(ref n) => self |
509 | 0 | .calculate_index_of_row::<false, true>( |
510 | 0 | range_columns, |
511 | 0 | idx, |
512 | 0 | Some(n), |
513 | 0 | length, |
514 | 0 | )?, |
515 | 0 | WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>( |
516 | 0 | range_columns, |
517 | 0 | idx, |
518 | 0 | None, |
519 | 0 | length, |
520 | 0 | )?, |
521 | 0 | WindowFrameBound::Following(ref n) => { |
522 | 0 | if n.is_null() { |
523 | | // UNBOUNDED FOLLOWING |
524 | 0 | length |
525 | | } else { |
526 | 0 | self.calculate_index_of_row::<false, false>( |
527 | 0 | range_columns, |
528 | 0 | idx, |
529 | 0 | Some(n), |
530 | 0 | length, |
531 | 0 | )? |
532 | | } |
533 | | } |
534 | | }; |
535 | 0 | Ok(Range { start, end }) |
536 | 0 | } |
537 | | |
538 | | /// This function does the heavy lifting when finding range boundaries. It is meant to be |
539 | | /// called twice, in succession, to get window frame start and end indices (with `SIDE` |
540 | | /// supplied as true and false, respectively). Generic argument `SEARCH_SIDE` determines |
541 | | /// the sign of `delta` (where true/false represents negative/positive respectively). |
542 | 0 | fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>( |
543 | 0 | &mut self, |
544 | 0 | range_columns: &[ArrayRef], |
545 | 0 | idx: usize, |
546 | 0 | delta: Option<&ScalarValue>, |
547 | 0 | length: usize, |
548 | 0 | ) -> Result<usize> { |
549 | 0 | let delta = if let Some(delta) = delta { |
550 | 0 | if let ScalarValue::UInt64(Some(value)) = delta { |
551 | 0 | *value as usize |
552 | | } else { |
553 | 0 | return internal_err!( |
554 | 0 | "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame" |
555 | 0 | ); |
556 | | } |
557 | | } else { |
558 | 0 | 0 |
559 | | }; |
560 | 0 | let mut group_start = 0; |
561 | 0 | let last_group = self.group_end_indices.back_mut(); |
562 | 0 | if let Some((group_row, group_end)) = last_group { |
563 | 0 | if *group_end < length { |
564 | 0 | let new_group_row = get_row_at_idx(range_columns, *group_end)?; |
565 | | // If last/current group keys are the same, we extend the last group: |
566 | 0 | if new_group_row.eq(group_row) { |
567 | | // Update the end boundary of the group (search right boundary): |
568 | 0 | *group_end = search_in_slice( |
569 | 0 | range_columns, |
570 | 0 | group_row, |
571 | 0 | check_equality, |
572 | 0 | *group_end, |
573 | 0 | length, |
574 | 0 | )?; |
575 | 0 | } |
576 | 0 | } |
577 | | // Start searching from the last group boundary: |
578 | 0 | group_start = *group_end; |
579 | 0 | } |
580 | | |
581 | | // Advance groups until `idx` is inside a group: |
582 | 0 | while idx >= group_start { |
583 | 0 | let group_row = get_row_at_idx(range_columns, group_start)?; |
584 | | // Find end boundary of the group (search right boundary): |
585 | 0 | let group_end = search_in_slice( |
586 | 0 | range_columns, |
587 | 0 | &group_row, |
588 | 0 | check_equality, |
589 | 0 | group_start, |
590 | 0 | length, |
591 | 0 | )?; |
592 | 0 | self.group_end_indices.push_back((group_row, group_end)); |
593 | 0 | group_start = group_end; |
594 | | } |
595 | | |
596 | | // Update the group index `idx` belongs to: |
597 | 0 | while self.current_group_idx < self.group_end_indices.len() |
598 | 0 | && idx >= self.group_end_indices[self.current_group_idx].1 |
599 | 0 | { |
600 | 0 | self.current_group_idx += 1; |
601 | 0 | } |
602 | | |
603 | | // Find the group index of the frame boundary: |
604 | 0 | let group_idx = if SEARCH_SIDE { |
605 | 0 | if self.current_group_idx > delta { |
606 | 0 | self.current_group_idx - delta |
607 | | } else { |
608 | 0 | 0 |
609 | | } |
610 | | } else { |
611 | 0 | self.current_group_idx + delta |
612 | | }; |
613 | | |
614 | | // Extend `group_end_indices` until it includes at least `group_idx`: |
615 | 0 | while self.group_end_indices.len() <= group_idx && group_start < length { |
616 | 0 | let group_row = get_row_at_idx(range_columns, group_start)?; |
617 | | // Find end boundary of the group (search right boundary): |
618 | 0 | let group_end = search_in_slice( |
619 | 0 | range_columns, |
620 | 0 | &group_row, |
621 | 0 | check_equality, |
622 | 0 | group_start, |
623 | 0 | length, |
624 | 0 | )?; |
625 | 0 | self.group_end_indices.push_back((group_row, group_end)); |
626 | 0 | group_start = group_end; |
627 | | } |
628 | | |
629 | | // Calculate index of the group boundary: |
630 | 0 | Ok(match (SIDE, SEARCH_SIDE) { |
631 | | // Window frame start: |
632 | | (true, _) => { |
633 | 0 | let group_idx = std::cmp::min(group_idx, self.group_end_indices.len()); |
634 | 0 | if group_idx > 0 { |
635 | | // Normally, start at the boundary of the previous group. |
636 | 0 | self.group_end_indices[group_idx - 1].1 |
637 | | } else { |
638 | | // If previous group is out of the table, start at zero. |
639 | 0 | 0 |
640 | | } |
641 | | } |
642 | | // Window frame end, PRECEDING n |
643 | | (false, true) => { |
644 | 0 | if self.current_group_idx >= delta { |
645 | 0 | let group_idx = self.current_group_idx - delta; |
646 | 0 | self.group_end_indices[group_idx].1 |
647 | | } else { |
648 | | // Group is out of the table, therefore end at zero. |
649 | 0 | 0 |
650 | | } |
651 | | } |
652 | | // Window frame end, FOLLOWING n |
653 | | (false, false) => { |
654 | 0 | let group_idx = std::cmp::min( |
655 | 0 | self.current_group_idx + delta, |
656 | 0 | self.group_end_indices.len() - 1, |
657 | 0 | ); |
658 | 0 | self.group_end_indices[group_idx].1 |
659 | | } |
660 | | }) |
661 | 0 | } |
662 | | } |
663 | | |
664 | 0 | fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<bool> { |
665 | 0 | Ok(current == target) |
666 | 0 | } |
667 | | |
668 | | #[cfg(test)] |
669 | | mod tests { |
670 | | use super::*; |
671 | | |
672 | | use arrow::array::Float64Array; |
673 | | |
674 | | fn get_test_data() -> (Vec<ArrayRef>, Vec<SortOptions>) { |
675 | | let range_columns: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![ |
676 | | 5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11., |
677 | | ]))]; |
678 | | let sort_options = vec![SortOptions { |
679 | | descending: false, |
680 | | nulls_first: false, |
681 | | }]; |
682 | | |
683 | | (range_columns, sort_options) |
684 | | } |
685 | | |
686 | | fn assert_expected( |
687 | | expected_results: Vec<(Range<usize>, usize)>, |
688 | | window_frame: &Arc<WindowFrame>, |
689 | | ) -> Result<()> { |
690 | | let mut window_frame_groups = WindowFrameStateGroups::default(); |
691 | | let (range_columns, _) = get_test_data(); |
692 | | let n_row = range_columns[0].len(); |
693 | | for (idx, (expected_range, expected_group_idx)) in |
694 | | expected_results.into_iter().enumerate() |
695 | | { |
696 | | let range = window_frame_groups.calculate_range( |
697 | | window_frame, |
698 | | &range_columns, |
699 | | n_row, |
700 | | idx, |
701 | | )?; |
702 | | assert_eq!(range, expected_range); |
703 | | assert_eq!(window_frame_groups.current_group_idx, expected_group_idx); |
704 | | } |
705 | | Ok(()) |
706 | | } |
707 | | |
708 | | #[test] |
709 | | fn test_window_frame_group_boundaries() -> Result<()> { |
710 | | let window_frame = Arc::new(WindowFrame::new_bounds( |
711 | | WindowFrameUnits::Groups, |
712 | | WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), |
713 | | WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), |
714 | | )); |
715 | | let expected_results = vec![ |
716 | | (Range { start: 0, end: 2 }, 0), |
717 | | (Range { start: 0, end: 4 }, 1), |
718 | | (Range { start: 1, end: 5 }, 2), |
719 | | (Range { start: 1, end: 5 }, 2), |
720 | | (Range { start: 2, end: 8 }, 3), |
721 | | (Range { start: 4, end: 9 }, 4), |
722 | | (Range { start: 4, end: 9 }, 4), |
723 | | (Range { start: 4, end: 9 }, 4), |
724 | | (Range { start: 5, end: 9 }, 5), |
725 | | ]; |
726 | | assert_expected(expected_results, &window_frame) |
727 | | } |
728 | | |
729 | | #[test] |
730 | | fn test_window_frame_group_boundaries_both_following() -> Result<()> { |
731 | | let window_frame = Arc::new(WindowFrame::new_bounds( |
732 | | WindowFrameUnits::Groups, |
733 | | WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), |
734 | | WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), |
735 | | )); |
736 | | let expected_results = vec![ |
737 | | (Range::<usize> { start: 1, end: 4 }, 0), |
738 | | (Range::<usize> { start: 2, end: 5 }, 1), |
739 | | (Range::<usize> { start: 4, end: 8 }, 2), |
740 | | (Range::<usize> { start: 4, end: 8 }, 2), |
741 | | (Range::<usize> { start: 5, end: 9 }, 3), |
742 | | (Range::<usize> { start: 8, end: 9 }, 4), |
743 | | (Range::<usize> { start: 8, end: 9 }, 4), |
744 | | (Range::<usize> { start: 8, end: 9 }, 4), |
745 | | (Range::<usize> { start: 9, end: 9 }, 5), |
746 | | ]; |
747 | | assert_expected(expected_results, &window_frame) |
748 | | } |
749 | | |
750 | | #[test] |
751 | | fn test_window_frame_group_boundaries_both_preceding() -> Result<()> { |
752 | | let window_frame = Arc::new(WindowFrame::new_bounds( |
753 | | WindowFrameUnits::Groups, |
754 | | WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), |
755 | | WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), |
756 | | )); |
757 | | let expected_results = vec![ |
758 | | (Range::<usize> { start: 0, end: 0 }, 0), |
759 | | (Range::<usize> { start: 0, end: 1 }, 1), |
760 | | (Range::<usize> { start: 0, end: 2 }, 2), |
761 | | (Range::<usize> { start: 0, end: 2 }, 2), |
762 | | (Range::<usize> { start: 1, end: 4 }, 3), |
763 | | (Range::<usize> { start: 2, end: 5 }, 4), |
764 | | (Range::<usize> { start: 2, end: 5 }, 4), |
765 | | (Range::<usize> { start: 2, end: 5 }, 4), |
766 | | (Range::<usize> { start: 4, end: 8 }, 5), |
767 | | ]; |
768 | | assert_expected(expected_results, &window_frame) |
769 | | } |
770 | | } |