Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/lead_lag.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines physical expressions for `lead` and `lag` that can be evaluated
19
//! at runtime during query execution
20
use crate::window::BuiltInWindowFunctionExpr;
21
use crate::PhysicalExpr;
22
use arrow::array::ArrayRef;
23
use arrow::datatypes::{DataType, Field};
24
use arrow_array::Array;
25
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
26
use datafusion_expr::PartitionEvaluator;
27
use std::any::Any;
28
use std::cmp::min;
29
use std::collections::VecDeque;
30
use std::ops::{Neg, Range};
31
use std::sync::Arc;
32
33
/// window shift expression
34
#[derive(Debug)]
35
pub struct WindowShift {
36
    name: String,
37
    /// Output data type
38
    data_type: DataType,
39
    shift_offset: i64,
40
    expr: Arc<dyn PhysicalExpr>,
41
    default_value: ScalarValue,
42
    ignore_nulls: bool,
43
}
44
45
impl WindowShift {
46
    /// Get shift_offset of window shift expression
47
0
    pub fn get_shift_offset(&self) -> i64 {
48
0
        self.shift_offset
49
0
    }
50
51
    /// Get the default_value for window shift expression.
52
0
    pub fn get_default_value(&self) -> ScalarValue {
53
0
        self.default_value.clone()
54
0
    }
55
}
56
57
/// lead() window function
58
0
pub fn lead(
59
0
    name: String,
60
0
    data_type: DataType,
61
0
    expr: Arc<dyn PhysicalExpr>,
62
0
    shift_offset: Option<i64>,
63
0
    default_value: ScalarValue,
64
0
    ignore_nulls: bool,
65
0
) -> WindowShift {
66
0
    WindowShift {
67
0
        name,
68
0
        data_type,
69
0
        shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1),
70
0
        expr,
71
0
        default_value,
72
0
        ignore_nulls,
73
0
    }
74
0
}
75
76
/// lag() window function
77
0
pub fn lag(
78
0
    name: String,
79
0
    data_type: DataType,
80
0
    expr: Arc<dyn PhysicalExpr>,
81
0
    shift_offset: Option<i64>,
82
0
    default_value: ScalarValue,
83
0
    ignore_nulls: bool,
84
0
) -> WindowShift {
85
0
    WindowShift {
86
0
        name,
87
0
        data_type,
88
0
        shift_offset: shift_offset.unwrap_or(1),
89
0
        expr,
90
0
        default_value,
91
0
        ignore_nulls,
92
0
    }
93
0
}
94
95
impl BuiltInWindowFunctionExpr for WindowShift {
96
    /// Return a reference to Any that can be used for downcasting
97
0
    fn as_any(&self) -> &dyn Any {
98
0
        self
99
0
    }
100
101
0
    fn field(&self) -> Result<Field> {
102
0
        let nullable = true;
103
0
        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
104
0
    }
105
106
0
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
107
0
        vec![Arc::clone(&self.expr)]
108
0
    }
109
110
0
    fn name(&self) -> &str {
111
0
        &self.name
112
0
    }
113
114
0
    fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
115
0
        Ok(Box::new(WindowShiftEvaluator {
116
0
            shift_offset: self.shift_offset,
117
0
            default_value: self.default_value.clone(),
118
0
            ignore_nulls: self.ignore_nulls,
119
0
            non_null_offsets: VecDeque::new(),
120
0
        }))
121
0
    }
122
123
0
    fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
124
0
        Some(Arc::new(Self {
125
0
            name: self.name.clone(),
126
0
            data_type: self.data_type.clone(),
127
0
            shift_offset: -self.shift_offset,
128
0
            expr: Arc::clone(&self.expr),
129
0
            default_value: self.default_value.clone(),
130
0
            ignore_nulls: self.ignore_nulls,
131
0
        }))
132
0
    }
133
}
134
135
#[derive(Debug)]
136
pub(crate) struct WindowShiftEvaluator {
137
    shift_offset: i64,
138
    default_value: ScalarValue,
139
    ignore_nulls: bool,
140
    // VecDeque contains offset values that between non-null entries
141
    non_null_offsets: VecDeque<usize>,
142
}
143
144
impl WindowShiftEvaluator {
145
0
    fn is_lag(&self) -> bool {
146
0
        // Mode is LAG, when shift_offset is positive
147
0
        self.shift_offset > 0
148
0
    }
149
}
150
151
// implement ignore null for evaluate_all
152
0
fn evaluate_all_with_ignore_null(
153
0
    array: &ArrayRef,
154
0
    offset: i64,
155
0
    default_value: &ScalarValue,
156
0
    is_lag: bool,
157
0
) -> Result<ArrayRef, DataFusionError> {
158
0
    let valid_indices: Vec<usize> =
159
0
        array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
160
0
    let direction = !is_lag;
161
0
    let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
162
0
        .map(|id| {
163
0
            let result_index = match valid_indices.binary_search(&id) {
164
0
                Ok(pos) => if direction {
165
0
                    pos.checked_add(offset as usize)
166
                } else {
167
0
                    pos.checked_sub(offset.unsigned_abs() as usize)
168
                }
169
0
                .and_then(|new_pos| {
170
0
                    if new_pos < valid_indices.len() {
171
0
                        Some(valid_indices[new_pos])
172
                    } else {
173
0
                        None
174
                    }
175
0
                }),
176
0
                Err(pos) => if direction {
177
0
                    pos.checked_add(offset as usize)
178
0
                } else if pos > 0 {
179
0
                    pos.checked_sub(offset.unsigned_abs() as usize)
180
                } else {
181
0
                    None
182
                }
183
0
                .and_then(|new_pos| {
184
0
                    if new_pos < valid_indices.len() {
185
0
                        Some(valid_indices[new_pos])
186
                    } else {
187
0
                        None
188
                    }
189
0
                }),
190
            };
191
192
0
            match result_index {
193
0
                Some(index) => ScalarValue::try_from_array(array, index),
194
0
                None => Ok(default_value.clone()),
195
            }
196
0
        })
197
0
        .collect();
198
199
0
    let new_array = new_array_results?;
200
0
    ScalarValue::iter_to_array(new_array)
201
0
}
202
// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
203
0
fn shift_with_default_value(
204
0
    array: &ArrayRef,
205
0
    offset: i64,
206
0
    default_value: &ScalarValue,
207
0
) -> Result<ArrayRef> {
208
    use arrow::compute::concat;
209
210
0
    let value_len = array.len() as i64;
211
0
    if offset == 0 {
212
0
        Ok(Arc::clone(array))
213
0
    } else if offset == i64::MIN || offset.abs() >= value_len {
214
0
        default_value.to_array_of_size(value_len as usize)
215
    } else {
216
0
        let slice_offset = (-offset).clamp(0, value_len) as usize;
217
0
        let length = array.len() - offset.unsigned_abs() as usize;
218
0
        let slice = array.slice(slice_offset, length);
219
0
220
0
        // Generate array with remaining `null` items
221
0
        let nulls = offset.unsigned_abs() as usize;
222
0
        let default_values = default_value.to_array_of_size(nulls)?;
223
224
        // Concatenate both arrays, add nulls after if shift > 0 else before
225
0
        if offset > 0 {
226
0
            concat(&[default_values.as_ref(), slice.as_ref()])
227
0
                .map_err(|e| arrow_datafusion_err!(e))
228
        } else {
229
0
            concat(&[slice.as_ref(), default_values.as_ref()])
230
0
                .map_err(|e| arrow_datafusion_err!(e))
231
        }
232
    }
233
0
}
234
235
impl PartitionEvaluator for WindowShiftEvaluator {
236
0
    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
237
0
        if self.is_lag() {
238
0
            let start = if self.non_null_offsets.len() == self.shift_offset as usize {
239
                // How many rows needed previous than the current row to get necessary lag result
240
0
                let offset: usize = self.non_null_offsets.iter().sum();
241
0
                idx.saturating_sub(offset)
242
0
            } else if !self.ignore_nulls {
243
0
                let offset = self.shift_offset as usize;
244
0
                idx.saturating_sub(offset)
245
            } else {
246
0
                0
247
            };
248
0
            let end = idx + 1;
249
0
            Ok(Range { start, end })
250
        } else {
251
0
            let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
252
                // How many rows needed further than the current row to get necessary lead result
253
0
                let offset: usize = self.non_null_offsets.iter().sum();
254
0
                min(idx + offset + 1, n_rows)
255
0
            } else if !self.ignore_nulls {
256
0
                let offset = (-self.shift_offset) as usize;
257
0
                min(idx + offset, n_rows)
258
            } else {
259
0
                n_rows
260
            };
261
0
            Ok(Range { start: idx, end })
262
        }
263
0
    }
264
265
0
    fn is_causal(&self) -> bool {
266
0
        // Lagging windows are causal by definition:
267
0
        self.is_lag()
268
0
    }
269
270
0
    fn evaluate(
271
0
        &mut self,
272
0
        values: &[ArrayRef],
273
0
        range: &Range<usize>,
274
0
    ) -> Result<ScalarValue> {
275
0
        let array = &values[0];
276
0
        let len = array.len();
277
278
        // LAG mode
279
0
        let i = if self.is_lag() {
280
0
            (range.end as i64 - self.shift_offset - 1) as usize
281
        } else {
282
            // LEAD mode
283
0
            (range.start as i64 - self.shift_offset) as usize
284
        };
285
286
0
        let mut idx: Option<usize> = if i < len { Some(i) } else { None };
287
288
        // LAG with IGNORE NULLS calculated as the current row index - offset, but only for non-NULL rows
289
        // If current row index points to NULL value the row is NOT counted
290
0
        if self.ignore_nulls && self.is_lag() {
291
            // LAG when NULLS are ignored.
292
            // Find the nonNULL row index that shifted by offset comparing to current row index
293
0
            idx = if self.non_null_offsets.len() == self.shift_offset as usize {
294
0
                let total_offset: usize = self.non_null_offsets.iter().sum();
295
0
                Some(range.end - 1 - total_offset)
296
            } else {
297
0
                None
298
            };
299
300
            // Keep track of offset values between non-null entries
301
0
            if array.is_valid(range.end - 1) {
302
                // Non-null add new offset
303
0
                self.non_null_offsets.push_back(1);
304
0
                if self.non_null_offsets.len() > self.shift_offset as usize {
305
0
                    // WE do not need to keep track of more than `lag number of offset` values.
306
0
                    self.non_null_offsets.pop_front();
307
0
                }
308
0
            } else if !self.non_null_offsets.is_empty() {
309
0
                // Entry is null, increment offset value of the last entry.
310
0
                let end_idx = self.non_null_offsets.len() - 1;
311
0
                self.non_null_offsets[end_idx] += 1;
312
0
            }
313
0
        } else if self.ignore_nulls && !self.is_lag() {
314
            // LEAD when NULLS are ignored.
315
            // Stores the necessary non-null entry number further than the current row.
316
0
            let non_null_row_count = (-self.shift_offset) as usize;
317
0
318
0
            if self.non_null_offsets.is_empty() {
319
                // When empty, fill non_null offsets with the data further than the current row.
320
0
                let mut offset_val = 1;
321
0
                for idx in range.start + 1..range.end {
322
0
                    if array.is_valid(idx) {
323
0
                        self.non_null_offsets.push_back(offset_val);
324
0
                        offset_val = 1;
325
0
                    } else {
326
0
                        offset_val += 1;
327
0
                    }
328
                    // It is enough to keep track of `non_null_row_count + 1` non-null offset.
329
                    // further data is unnecessary for the result.
330
0
                    if self.non_null_offsets.len() == non_null_row_count + 1 {
331
0
                        break;
332
0
                    }
333
                }
334
0
            } else if range.end < len && array.is_valid(range.end) {
335
                // Update `non_null_offsets` with the new end data.
336
0
                if array.is_valid(range.end) {
337
0
                    // When non-null, append a new offset.
338
0
                    self.non_null_offsets.push_back(1);
339
0
                } else {
340
0
                    // When null, increment offset count of the last entry
341
0
                    let last_idx = self.non_null_offsets.len() - 1;
342
0
                    self.non_null_offsets[last_idx] += 1;
343
0
                }
344
0
            }
345
346
            // Find the nonNULL row index that shifted by offset comparing to current row index
347
0
            idx = if self.non_null_offsets.len() >= non_null_row_count {
348
0
                let total_offset: usize =
349
0
                    self.non_null_offsets.iter().take(non_null_row_count).sum();
350
0
                Some(range.start + total_offset)
351
            } else {
352
0
                None
353
            };
354
            // Prune `self.non_null_offsets` from the start. so that at next iteration
355
            // start of the `self.non_null_offsets` matches with current row.
356
0
            if !self.non_null_offsets.is_empty() {
357
0
                self.non_null_offsets[0] -= 1;
358
0
                if self.non_null_offsets[0] == 0 {
359
0
                    // When offset is 0. Remove it.
360
0
                    self.non_null_offsets.pop_front();
361
0
                }
362
0
            }
363
0
        }
364
365
        // Set the default value if
366
        // - index is out of window bounds
367
        // OR
368
        // - ignore nulls mode and current value is null and is within window bounds
369
        // .unwrap() is safe here as there is a none check in front
370
        #[allow(clippy::unnecessary_unwrap)]
371
0
        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
372
0
            ScalarValue::try_from_array(array, idx.unwrap())
373
        } else {
374
0
            Ok(self.default_value.clone())
375
        }
376
0
    }
377
378
0
    fn evaluate_all(
379
0
        &mut self,
380
0
        values: &[ArrayRef],
381
0
        _num_rows: usize,
382
0
    ) -> Result<ArrayRef> {
383
0
        // LEAD, LAG window functions take single column, values will have size 1
384
0
        let value = &values[0];
385
0
        if !self.ignore_nulls {
386
0
            shift_with_default_value(value, self.shift_offset, &self.default_value)
387
        } else {
388
0
            evaluate_all_with_ignore_null(
389
0
                value,
390
0
                self.shift_offset,
391
0
                &self.default_value,
392
0
                self.is_lag(),
393
0
            )
394
        }
395
0
    }
396
397
0
    fn supports_bounded_execution(&self) -> bool {
398
0
        true
399
0
    }
400
}
401
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::{array::*, datatypes::*};
    use datafusion_common::cast::as_int32_array;

    /// Evaluates `expr` over a fixed eight-row Int32 column with `evaluate_all`
    /// and checks the output against `expected`.
    fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
        let values = expr.evaluate_args(&batch)?;
        let mut evaluator = expr.create_evaluator()?;
        let result = evaluator.evaluate_all(&values, batch.num_rows())?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // Builds an evaluator with a NULL default and the given tracked offsets.
        let evaluator = |shift_offset: i64, ignore_nulls: bool, offsets: Vec<usize>| {
            WindowShiftEvaluator {
                shift_offset,
                default_value: ScalarValue::Null,
                ignore_nulls,
                non_null_offsets: offsets.into(),
            }
        };

        // LAG(2)
        let lag_fn = evaluator(2, false, vec![]);
        assert_eq!(lag_fn.get_range(6, 10)?, 4..7);
        assert_eq!(lag_fn.get_range(0, 10)?, 0..1);

        // LAG(2 ignore nulls)
        // models data received [<Some>, <Some>, <Some>, NULL, <Some>, NULL, <current row>, ...]
        // The full history would be [1, 1, 2, 2]; only the last two offsets are kept.
        let lag_fn = evaluator(2, true, vec![2, 2]);
        assert_eq!(lag_fn.get_range(6, 10)?, 2..7);

        // LEAD(2)
        let lead_fn = evaluator(-2, false, vec![]);
        assert_eq!(lead_fn.get_range(6, 10)?, 6..8);
        assert_eq!(lead_fn.get_range(9, 10)?, 9..10);

        // LEAD(2 ignore nulls)
        // models data received [..., <current row>, NULL, <Some>, NULL, <Some>, ..]
        let lead_fn = evaluator(-2, true, vec![2, 2]);
        assert_eq!(lead_fn.get_range(4, 10)?, 4..9);

        Ok(())
    }

    #[test]
    fn lead_lag_window_shift() -> Result<()> {
        // LEAD with the default offset and a NULL default value
        let expected = Int32Array::from(vec![
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
            None,
        ]);
        test_i32_result(
            lead(
                "lead".to_owned(),
                DataType::Int32,
                Arc::new(Column::new("c3", 0)),
                None,
                ScalarValue::Null.cast_to(&DataType::Int32)?,
                false,
            ),
            expected,
        )?;

        // LAG with the default offset and a NULL default value
        let expected = Int32Array::from(vec![
            None,
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
        ]);
        test_i32_result(
            lag(
                "lead".to_owned(),
                DataType::Int32,
                Arc::new(Column::new("c3", 0)),
                None,
                ScalarValue::Null.cast_to(&DataType::Int32)?,
                false,
            ),
            expected,
        )?;

        // LAG with the default offset and an explicit default value
        let expected = Int32Array::from(vec![
            Some(100),
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
        ]);
        test_i32_result(
            lag(
                "lead".to_owned(),
                DataType::Int32,
                Arc::new(Column::new("c3", 0)),
                None,
                ScalarValue::Int32(Some(100)),
                false,
            ),
            expected,
        )?;
        Ok(())
    }
}