Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/expr/src/window_state.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Structures used to hold window function state (for implementing WindowUDFs)
19
20
use std::{collections::VecDeque, ops::Range, sync::Arc};
21
22
use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
23
24
use arrow::{
25
    array::ArrayRef,
26
    compute::{concat, concat_batches, SortOptions},
27
    datatypes::{DataType, SchemaRef},
28
    record_batch::RecordBatch,
29
};
30
use datafusion_common::{
31
    internal_err,
32
    utils::{compare_rows, get_row_at_idx, search_in_slice},
33
    DataFusionError, Result, ScalarValue,
34
};
35
36
/// Holds the state of evaluating a window function
37
#[derive(Debug)]
38
pub struct WindowAggState {
39
    /// The range that we calculate the window function
40
    pub window_frame_range: Range<usize>,
41
    pub window_frame_ctx: Option<WindowFrameContext>,
42
    /// The index of the last row that its result is calculated inside the partition record batch buffer.
43
    pub last_calculated_index: usize,
44
    /// The offset of the deleted row number
45
    pub offset_pruned_rows: usize,
46
    /// Stores the results calculated by window frame
47
    pub out_col: ArrayRef,
48
    /// Keeps track of how many rows should be generated to be in sync with input record_batch.
49
    // (For each row in the input record batch we need to generate a window result).
50
    pub n_row_result_missing: usize,
51
    /// flag indicating whether we have received all data for this partition
52
    pub is_end: bool,
53
}
54
55
impl WindowAggState {
56
17
    pub fn prune_state(&mut self, n_prune: usize) {
57
17
        self.window_frame_range = Range {
58
17
            start: self.window_frame_range.start - n_prune,
59
17
            end: self.window_frame_range.end - n_prune,
60
17
        };
61
17
        self.last_calculated_index -= n_prune;
62
17
        self.offset_pruned_rows += n_prune;
63
17
64
17
        match self.window_frame_ctx.as_mut() {
65
            // Rows have no state do nothing
66
9
            Some(WindowFrameContext::Rows(_)) => {}
67
8
            Some(WindowFrameContext::Range { .. }) => {}
68
0
            Some(WindowFrameContext::Groups { state, .. }) => {
69
0
                let mut n_group_to_del = 0;
70
0
                for (_, end_idx) in &state.group_end_indices {
71
0
                    if n_prune < *end_idx {
72
0
                        break;
73
0
                    }
74
0
                    n_group_to_del += 1;
75
                }
76
0
                state.group_end_indices.drain(0..n_group_to_del);
77
0
                state
78
0
                    .group_end_indices
79
0
                    .iter_mut()
80
0
                    .for_each(|(_, start_idx)| *start_idx -= n_prune);
81
0
                state.current_group_idx -= n_group_to_del;
82
0
            }
83
0
            None => {}
84
        };
85
17
    }
86
87
21
    pub fn update(
88
21
        &mut self,
89
21
        out_col: &ArrayRef,
90
21
        partition_batch_state: &PartitionBatchState,
91
21
    ) -> Result<()> {
92
21
        self.last_calculated_index += out_col.len();
93
21
        self.out_col = concat(&[&self.out_col, &out_col])
?0
;
94
21
        self.n_row_result_missing =
95
21
            partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
96
21
        self.is_end = partition_batch_state.is_end;
97
21
        Ok(())
98
21
    }
99
100
6
    pub fn new(out_type: &DataType) -> Result<Self> {
101
6
        let empty_out_col = ScalarValue::try_from(out_type)
?0
.to_array_of_size(0)
?0
;
102
6
        Ok(Self {
103
6
            window_frame_range: Range { start: 0, end: 0 },
104
6
            window_frame_ctx: None,
105
6
            last_calculated_index: 0,
106
6
            offset_pruned_rows: 0,
107
6
            out_col: empty_out_col,
108
6
            n_row_result_missing: 0,
109
6
            is_end: false,
110
6
        })
111
6
    }
112
}
113
114
/// This object stores the window frame state for use in incremental calculations.
115
#[derive(Debug)]
116
pub enum WindowFrameContext {
117
    /// ROWS frames are inherently stateless.
118
    Rows(Arc<WindowFrame>),
119
    /// RANGE frames are stateful, they store indices specifying where the
120
    /// previous search left off. This amortizes the overall cost to O(n)
121
    /// where n denotes the row count.
122
    Range {
123
        window_frame: Arc<WindowFrame>,
124
        state: WindowFrameStateRange,
125
    },
126
    /// GROUPS frames are stateful, they store group boundaries and indices
127
    /// specifying where the previous search left off. This amortizes the
128
    /// overall cost to O(n) where n denotes the row count.
129
    Groups {
130
        window_frame: Arc<WindowFrame>,
131
        state: WindowFrameStateGroups,
132
    },
133
}
134
135
impl WindowFrameContext {
136
    /// Create a new state object for the given window frame.
137
6
    pub fn new(window_frame: Arc<WindowFrame>, sort_options: Vec<SortOptions>) -> Self {
138
6
        match window_frame.units {
139
3
            WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame),
140
3
            WindowFrameUnits::Range => WindowFrameContext::Range {
141
3
                window_frame,
142
3
                state: WindowFrameStateRange::new(sort_options),
143
3
            },
144
0
            WindowFrameUnits::Groups => WindowFrameContext::Groups {
145
0
                window_frame,
146
0
                state: WindowFrameStateGroups::default(),
147
0
            },
148
        }
149
6
    }
150
151
    /// This function calculates beginning/ending indices for the frame of the current row.
152
40
    pub fn calculate_range(
153
40
        &mut self,
154
40
        range_columns: &[ArrayRef],
155
40
        last_range: &Range<usize>,
156
40
        length: usize,
157
40
        idx: usize,
158
40
    ) -> Result<Range<usize>> {
159
40
        match self {
160
27
            WindowFrameContext::Rows(window_frame) => {
161
27
                Self::calculate_range_rows(window_frame, length, idx)
162
            }
163
            // Sort options is used in RANGE mode calculations because the
164
            // ordering or position of NULLs impact range calculations and
165
            // comparison of rows.
166
            WindowFrameContext::Range {
167
13
                window_frame,
168
13
                ref mut state,
169
13
            } => state.calculate_range(
170
13
                window_frame,
171
13
                last_range,
172
13
                range_columns,
173
13
                length,
174
13
                idx,
175
13
            ),
176
            // Sort options is not used in GROUPS mode calculations as the
177
            // inequality of two rows indicates a group change, and ordering
178
            // or position of NULLs do not impact inequality.
179
            WindowFrameContext::Groups {
180
0
                window_frame,
181
0
                ref mut state,
182
0
            } => state.calculate_range(window_frame, range_columns, length, idx),
183
        }
184
40
    }
185
186
    /// This function calculates beginning/ending indices for the frame of the current row.
187
27
    fn calculate_range_rows(
188
27
        window_frame: &Arc<WindowFrame>,
189
27
        length: usize,
190
27
        idx: usize,
191
27
    ) -> Result<Range<usize>> {
192
27
        let start = match window_frame.start_bound {
193
            // UNBOUNDED PRECEDING
194
27
            WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
195
0
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
196
0
                if idx >= n as usize {
197
0
                    idx - n as usize
198
                } else {
199
0
                    0
200
                }
201
            }
202
0
            WindowFrameBound::CurrentRow => idx,
203
            // UNBOUNDED FOLLOWING
204
            WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
205
0
                return internal_err!(
206
0
                    "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'"
207
0
                )
208
            }
209
0
            WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
210
0
                std::cmp::min(idx + n as usize, length)
211
            }
212
            // ERRONEOUS FRAMES
213
            WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
214
0
                return internal_err!("Rows should be Uint")
215
            }
216
        };
217
27
        let end = match window_frame.end_bound {
218
            // UNBOUNDED PRECEDING
219
            WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
220
0
                return internal_err!(
221
0
                    "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'"
222
0
                )
223
            }
224
0
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
225
0
                if idx >= n as usize {
226
0
                    idx - n as usize + 1
227
                } else {
228
0
                    0
229
                }
230
            }
231
27
            WindowFrameBound::CurrentRow => idx + 1,
232
            // UNBOUNDED FOLLOWING
233
0
            WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
234
0
            WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
235
0
                std::cmp::min(idx + n as usize + 1, length)
236
            }
237
            // ERRONEOUS FRAMES
238
            WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
239
0
                return internal_err!("Rows should be Uint")
240
            }
241
        };
242
27
        Ok(Range { start, end })
243
27
    }
244
}
245
246
/// State for each unique partition determined according to PARTITION BY column(s)
247
#[derive(Debug)]
248
pub struct PartitionBatchState {
249
    /// The record batch belonging to current partition
250
    pub record_batch: RecordBatch,
251
    /// The record batch that contains the most recent row at the input.
252
    /// Please note that this batch doesn't necessarily have the same partitioning
253
    /// with `record_batch`. Keeping track of this batch enables us to prune
254
    /// `record_batch` when cardinality of the partition is sparse.
255
    pub most_recent_row: Option<RecordBatch>,
256
    /// Flag indicating whether we have received all data for this partition
257
    pub is_end: bool,
258
    /// Number of rows emitted for each partition
259
    pub n_out_row: usize,
260
}
261
262
impl PartitionBatchState {
263
4
    pub fn new(schema: SchemaRef) -> Self {
264
4
        Self {
265
4
            record_batch: RecordBatch::new_empty(schema),
266
4
            most_recent_row: None,
267
4
            is_end: false,
268
4
            n_out_row: 0,
269
4
        }
270
4
    }
271
272
8
    pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> {
273
8
        self.record_batch =
274
8
            concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])
?0
;
275
8
        Ok(())
276
8
    }
277
278
9
    pub fn set_most_recent_row(&mut self, batch: RecordBatch) {
279
9
        // It is enough for the batch to contain only a single row (the rest
280
9
        // are not necessary).
281
9
        self.most_recent_row = Some(batch);
282
9
    }
283
}
284
285
/// This structure encapsulates all the state information we require as we scan
286
/// ranges of data while processing RANGE frames.
287
/// Attribute `sort_options` stores the column ordering specified by the ORDER
288
/// BY clause. This information is used to calculate the range.
289
#[derive(Debug, Default)]
290
pub struct WindowFrameStateRange {
291
    sort_options: Vec<SortOptions>,
292
}
293
294
impl WindowFrameStateRange {
295
    /// Create a new object to store the search state.
296
3
    fn new(sort_options: Vec<SortOptions>) -> Self {
297
3
        Self { sort_options }
298
3
    }
299
300
    /// This function calculates beginning/ending indices for the frame of the current row.
301
    // Argument `last_range` stores the resulting indices from the previous search. Since the indices only
302
    // advance forward, we start from `last_range` subsequently. Thus, the overall
303
    // time complexity of linear search amortizes to O(n) where n denotes the total
304
    // row count.
305
13
    fn calculate_range(
306
13
        &mut self,
307
13
        window_frame: &Arc<WindowFrame>,
308
13
        last_range: &Range<usize>,
309
13
        range_columns: &[ArrayRef],
310
13
        length: usize,
311
13
        idx: usize,
312
13
    ) -> Result<Range<usize>> {
313
13
        let start = match window_frame.start_bound {
314
0
            WindowFrameBound::Preceding(ref n) => {
315
0
                if n.is_null() {
316
                    // UNBOUNDED PRECEDING
317
0
                    0
318
                } else {
319
0
                    self.calculate_index_of_row::<true, true>(
320
0
                        range_columns,
321
0
                        last_range,
322
0
                        idx,
323
0
                        Some(n),
324
0
                        length,
325
0
                    )?
326
                }
327
            }
328
13
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
329
13
                range_columns,
330
13
                last_range,
331
13
                idx,
332
13
                None,
333
13
                length,
334
13
            )
?0
,
335
0
            WindowFrameBound::Following(ref n) => self
336
0
                .calculate_index_of_row::<true, false>(
337
0
                    range_columns,
338
0
                    last_range,
339
0
                    idx,
340
0
                    Some(n),
341
0
                    length,
342
0
                )?,
343
        };
344
13
        let end = match window_frame.end_bound {
345
0
            WindowFrameBound::Preceding(ref n) => self
346
0
                .calculate_index_of_row::<false, true>(
347
0
                    range_columns,
348
0
                    last_range,
349
0
                    idx,
350
0
                    Some(n),
351
0
                    length,
352
0
                )?,
353
0
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
354
0
                range_columns,
355
0
                last_range,
356
0
                idx,
357
0
                None,
358
0
                length,
359
0
            )?,
360
13
            WindowFrameBound::Following(ref n) => {
361
13
                if n.is_null() {
362
                    // UNBOUNDED FOLLOWING
363
0
                    length
364
                } else {
365
13
                    self.calculate_index_of_row::<false, false>(
366
13
                        range_columns,
367
13
                        last_range,
368
13
                        idx,
369
13
                        Some(n),
370
13
                        length,
371
13
                    )
?0
372
                }
373
            }
374
        };
375
13
        Ok(Range { start, end })
376
13
    }
377
378
    /// This function does the heavy lifting when finding range boundaries. It is meant to be
379
    /// called twice, in succession, to get window frame start and end indices (with `SIDE`
380
    /// supplied as true and false, respectively).
381
26
    fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
382
26
        &mut self,
383
26
        range_columns: &[ArrayRef],
384
26
        last_range: &Range<usize>,
385
26
        idx: usize,
386
26
        delta: Option<&ScalarValue>,
387
26
        length: usize,
388
26
    ) -> Result<usize> {
389
26
        let current_row_values = get_row_at_idx(range_columns, idx)
?0
;
390
26
        let end_range = if let Some(
delta13
) = delta {
391
13
            let is_descending: bool = self
392
13
                .sort_options
393
13
                .first()
394
13
                .ok_or_else(|| {
395
0
                    DataFusionError::Internal(
396
0
                        "Sort options unexpectedly absent in a window frame".to_string(),
397
0
                    )
398
13
                })
?0
399
                .descending;
400
401
13
            current_row_values
402
13
                .iter()
403
13
                .map(|value| {
404
13
                    if value.is_null() {
405
0
                        return Ok(value.clone());
406
13
                    }
407
13
                    if SEARCH_SIDE == is_descending {
408
                        // TODO: Handle positive overflows.
409
13
                        value.add(delta)
410
0
                    } else if value.is_unsigned() && value < delta {
411
                        // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue.
412
                        //       If we decide to implement a "default" construction mechanism for ScalarValue,
413
                        //       change the following statement to use that.
414
0
                        value.sub(value)
415
                    } else {
416
                        // TODO: Handle negative overflows.
417
0
                        value.sub(delta)
418
                    }
419
13
                })
420
13
                .collect::<Result<Vec<ScalarValue>>>()
?0
421
        } else {
422
13
            current_row_values
423
        };
424
26
        let search_start = if SIDE {
425
13
            last_range.start
426
        } else {
427
13
            last_range.end
428
        };
429
41
        let 
compare_fn26
= |current: &[ScalarValue], target: &[ScalarValue]| {
430
41
            let cmp = compare_rows(current, target, &self.sort_options)
?0
;
431
41
            Ok(if SIDE { 
cmp.is_lt()21
} else {
cmp.is_le()20
})
432
41
        };
433
26
        search_in_slice(range_columns, &end_range, compare_fn, search_start, length)
434
26
    }
435
}
436
437
// In GROUPS mode, rows with duplicate sorting values are grouped together.
438
// Therefore, there must be an ORDER BY clause in the window definition to use GROUPS mode.
439
// The syntax is as follows:
440
//     GROUPS frame_start [ frame_exclusion ]
441
//     GROUPS BETWEEN frame_start AND frame_end [ frame_exclusion ]
442
// The optional frame_exclusion specifier is not yet supported.
443
// The frame_start and frame_end parameters allow us to specify which rows the window
444
// frame starts and ends with. They accept the following values:
445
//    - UNBOUNDED PRECEDING: Start with the first row of the partition. Possible only in frame_start.
446
//    - offset PRECEDING: When used in frame_start, it refers to the first row of the group
447
//                        that comes "offset" groups before the current group (i.e. the group
448
//                        containing the current row). When used in frame_end, it refers to the
449
//                        last row of the group that comes "offset" groups before the current group.
450
//    - CURRENT ROW: When used in frame_start, it refers to the first row of the group containing
451
//                   the current row. When used in frame_end, it refers to the last row of the group
452
//                   containing the current row.
453
//    - offset FOLLOWING: When used in frame_start, it refers to the first row of the group
454
//                        that comes "offset" groups after the current group (i.e. the group
455
//                        containing the current row). When used in frame_end, it refers to the
456
//                        last row of the group that comes "offset" groups after the current group.
457
//    - UNBOUNDED FOLLOWING: End with the last row of the partition. Possible only in frame_end.
458
459
/// This structure encapsulates all the state information we require as we
460
/// scan groups of data while processing window frames.
461
#[derive(Debug, Default)]
462
pub struct WindowFrameStateGroups {
463
    /// A tuple containing group values and the row index where the group ends.
464
    /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to
465
    ///          [([1, 1], 2), ([2, 1], 4), ...].
466
    pub group_end_indices: VecDeque<(Vec<ScalarValue>, usize)>,
467
    /// The group index to which the row index belongs.
468
    pub current_group_idx: usize,
469
}
470
471
impl WindowFrameStateGroups {
472
0
    fn calculate_range(
473
0
        &mut self,
474
0
        window_frame: &Arc<WindowFrame>,
475
0
        range_columns: &[ArrayRef],
476
0
        length: usize,
477
0
        idx: usize,
478
0
    ) -> Result<Range<usize>> {
479
0
        let start = match window_frame.start_bound {
480
0
            WindowFrameBound::Preceding(ref n) => {
481
0
                if n.is_null() {
482
                    // UNBOUNDED PRECEDING
483
0
                    0
484
                } else {
485
0
                    self.calculate_index_of_row::<true, true>(
486
0
                        range_columns,
487
0
                        idx,
488
0
                        Some(n),
489
0
                        length,
490
0
                    )?
491
                }
492
            }
493
0
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
494
0
                range_columns,
495
0
                idx,
496
0
                None,
497
0
                length,
498
0
            )?,
499
0
            WindowFrameBound::Following(ref n) => self
500
0
                .calculate_index_of_row::<true, false>(
501
0
                    range_columns,
502
0
                    idx,
503
0
                    Some(n),
504
0
                    length,
505
0
                )?,
506
        };
507
0
        let end = match window_frame.end_bound {
508
0
            WindowFrameBound::Preceding(ref n) => self
509
0
                .calculate_index_of_row::<false, true>(
510
0
                    range_columns,
511
0
                    idx,
512
0
                    Some(n),
513
0
                    length,
514
0
                )?,
515
0
            WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
516
0
                range_columns,
517
0
                idx,
518
0
                None,
519
0
                length,
520
0
            )?,
521
0
            WindowFrameBound::Following(ref n) => {
522
0
                if n.is_null() {
523
                    // UNBOUNDED FOLLOWING
524
0
                    length
525
                } else {
526
0
                    self.calculate_index_of_row::<false, false>(
527
0
                        range_columns,
528
0
                        idx,
529
0
                        Some(n),
530
0
                        length,
531
0
                    )?
532
                }
533
            }
534
        };
535
0
        Ok(Range { start, end })
536
0
    }
537
538
    /// This function does the heavy lifting when finding range boundaries. It is meant to be
539
    /// called twice, in succession, to get window frame start and end indices (with `SIDE`
540
    /// supplied as true and false, respectively). Generic argument `SEARCH_SIDE` determines
541
    /// the sign of `delta` (where true/false represents negative/positive respectively).
542
0
    fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
543
0
        &mut self,
544
0
        range_columns: &[ArrayRef],
545
0
        idx: usize,
546
0
        delta: Option<&ScalarValue>,
547
0
        length: usize,
548
0
    ) -> Result<usize> {
549
0
        let delta = if let Some(delta) = delta {
550
0
            if let ScalarValue::UInt64(Some(value)) = delta {
551
0
                *value as usize
552
            } else {
553
0
                return internal_err!(
554
0
                    "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame"
555
0
                );
556
            }
557
        } else {
558
0
            0
559
        };
560
0
        let mut group_start = 0;
561
0
        let last_group = self.group_end_indices.back_mut();
562
0
        if let Some((group_row, group_end)) = last_group {
563
0
            if *group_end < length {
564
0
                let new_group_row = get_row_at_idx(range_columns, *group_end)?;
565
                // If last/current group keys are the same, we extend the last group:
566
0
                if new_group_row.eq(group_row) {
567
                    // Update the end boundary of the group (search right boundary):
568
0
                    *group_end = search_in_slice(
569
0
                        range_columns,
570
0
                        group_row,
571
0
                        check_equality,
572
0
                        *group_end,
573
0
                        length,
574
0
                    )?;
575
0
                }
576
0
            }
577
            // Start searching from the last group boundary:
578
0
            group_start = *group_end;
579
0
        }
580
581
        // Advance groups until `idx` is inside a group:
582
0
        while idx >= group_start {
583
0
            let group_row = get_row_at_idx(range_columns, group_start)?;
584
            // Find end boundary of the group (search right boundary):
585
0
            let group_end = search_in_slice(
586
0
                range_columns,
587
0
                &group_row,
588
0
                check_equality,
589
0
                group_start,
590
0
                length,
591
0
            )?;
592
0
            self.group_end_indices.push_back((group_row, group_end));
593
0
            group_start = group_end;
594
        }
595
596
        // Update the group index `idx` belongs to:
597
0
        while self.current_group_idx < self.group_end_indices.len()
598
0
            && idx >= self.group_end_indices[self.current_group_idx].1
599
0
        {
600
0
            self.current_group_idx += 1;
601
0
        }
602
603
        // Find the group index of the frame boundary:
604
0
        let group_idx = if SEARCH_SIDE {
605
0
            if self.current_group_idx > delta {
606
0
                self.current_group_idx - delta
607
            } else {
608
0
                0
609
            }
610
        } else {
611
0
            self.current_group_idx + delta
612
        };
613
614
        // Extend `group_start_indices` until it includes at least `group_idx`:
615
0
        while self.group_end_indices.len() <= group_idx && group_start < length {
616
0
            let group_row = get_row_at_idx(range_columns, group_start)?;
617
            // Find end boundary of the group (search right boundary):
618
0
            let group_end = search_in_slice(
619
0
                range_columns,
620
0
                &group_row,
621
0
                check_equality,
622
0
                group_start,
623
0
                length,
624
0
            )?;
625
0
            self.group_end_indices.push_back((group_row, group_end));
626
0
            group_start = group_end;
627
        }
628
629
        // Calculate index of the group boundary:
630
0
        Ok(match (SIDE, SEARCH_SIDE) {
631
            // Window frame start:
632
            (true, _) => {
633
0
                let group_idx = std::cmp::min(group_idx, self.group_end_indices.len());
634
0
                if group_idx > 0 {
635
                    // Normally, start at the boundary of the previous group.
636
0
                    self.group_end_indices[group_idx - 1].1
637
                } else {
638
                    // If previous group is out of the table, start at zero.
639
0
                    0
640
                }
641
            }
642
            // Window frame end, PRECEDING n
643
            (false, true) => {
644
0
                if self.current_group_idx >= delta {
645
0
                    let group_idx = self.current_group_idx - delta;
646
0
                    self.group_end_indices[group_idx].1
647
                } else {
648
                    // Group is out of the table, therefore end at zero.
649
0
                    0
650
                }
651
            }
652
            // Window frame end, FOLLOWING n
653
            (false, false) => {
654
0
                let group_idx = std::cmp::min(
655
0
                    self.current_group_idx + delta,
656
0
                    self.group_end_indices.len() - 1,
657
0
                );
658
0
                self.group_end_indices[group_idx].1
659
            }
660
        })
661
0
    }
662
}
663
664
0
fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<bool> {
665
0
    Ok(current == target)
666
0
}
667
668
#[cfg(test)]
669
mod tests {
670
    use super::*;
671
672
    use arrow::array::Float64Array;
673
674
    fn get_test_data() -> (Vec<ArrayRef>, Vec<SortOptions>) {
675
        let range_columns: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
676
            5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11.,
677
        ]))];
678
        let sort_options = vec![SortOptions {
679
            descending: false,
680
            nulls_first: false,
681
        }];
682
683
        (range_columns, sort_options)
684
    }
685
686
    fn assert_expected(
687
        expected_results: Vec<(Range<usize>, usize)>,
688
        window_frame: &Arc<WindowFrame>,
689
    ) -> Result<()> {
690
        let mut window_frame_groups = WindowFrameStateGroups::default();
691
        let (range_columns, _) = get_test_data();
692
        let n_row = range_columns[0].len();
693
        for (idx, (expected_range, expected_group_idx)) in
694
            expected_results.into_iter().enumerate()
695
        {
696
            let range = window_frame_groups.calculate_range(
697
                window_frame,
698
                &range_columns,
699
                n_row,
700
                idx,
701
            )?;
702
            assert_eq!(range, expected_range);
703
            assert_eq!(window_frame_groups.current_group_idx, expected_group_idx);
704
        }
705
        Ok(())
706
    }
707
708
    #[test]
709
    fn test_window_frame_group_boundaries() -> Result<()> {
710
        let window_frame = Arc::new(WindowFrame::new_bounds(
711
            WindowFrameUnits::Groups,
712
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
713
            WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
714
        ));
715
        let expected_results = vec![
716
            (Range { start: 0, end: 2 }, 0),
717
            (Range { start: 0, end: 4 }, 1),
718
            (Range { start: 1, end: 5 }, 2),
719
            (Range { start: 1, end: 5 }, 2),
720
            (Range { start: 2, end: 8 }, 3),
721
            (Range { start: 4, end: 9 }, 4),
722
            (Range { start: 4, end: 9 }, 4),
723
            (Range { start: 4, end: 9 }, 4),
724
            (Range { start: 5, end: 9 }, 5),
725
        ];
726
        assert_expected(expected_results, &window_frame)
727
    }
728
729
    #[test]
730
    fn test_window_frame_group_boundaries_both_following() -> Result<()> {
731
        let window_frame = Arc::new(WindowFrame::new_bounds(
732
            WindowFrameUnits::Groups,
733
            WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
734
            WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
735
        ));
736
        let expected_results = vec![
737
            (Range::<usize> { start: 1, end: 4 }, 0),
738
            (Range::<usize> { start: 2, end: 5 }, 1),
739
            (Range::<usize> { start: 4, end: 8 }, 2),
740
            (Range::<usize> { start: 4, end: 8 }, 2),
741
            (Range::<usize> { start: 5, end: 9 }, 3),
742
            (Range::<usize> { start: 8, end: 9 }, 4),
743
            (Range::<usize> { start: 8, end: 9 }, 4),
744
            (Range::<usize> { start: 8, end: 9 }, 4),
745
            (Range::<usize> { start: 9, end: 9 }, 5),
746
        ];
747
        assert_expected(expected_results, &window_frame)
748
    }
749
750
    #[test]
751
    fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
752
        let window_frame = Arc::new(WindowFrame::new_bounds(
753
            WindowFrameUnits::Groups,
754
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
755
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
756
        ));
757
        let expected_results = vec![
758
            (Range::<usize> { start: 0, end: 0 }, 0),
759
            (Range::<usize> { start: 0, end: 1 }, 1),
760
            (Range::<usize> { start: 0, end: 2 }, 2),
761
            (Range::<usize> { start: 0, end: 2 }, 2),
762
            (Range::<usize> { start: 1, end: 4 }, 3),
763
            (Range::<usize> { start: 2, end: 5 }, 4),
764
            (Range::<usize> { start: 2, end: 5 }, 4),
765
            (Range::<usize> { start: 2, end: 5 }, 4),
766
            (Range::<usize> { start: 4, end: 8 }, 5),
767
        ];
768
        assert_expected(expected_results, &window_frame)
769
    }
770
}