Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/nth_value.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE`
19
//! functions that can be evaluated at run time during query execution.
20
21
use std::any::Any;
22
use std::cmp::Ordering;
23
use std::ops::Range;
24
use std::sync::Arc;
25
26
use crate::window::window_expr::{NthValueKind, NthValueState};
27
use crate::window::BuiltInWindowFunctionExpr;
28
use crate::PhysicalExpr;
29
30
use arrow::array::{Array, ArrayRef};
31
use arrow::datatypes::{DataType, Field};
32
use datafusion_common::Result;
33
use datafusion_common::ScalarValue;
34
use datafusion_expr::window_state::WindowAggState;
35
use datafusion_expr::PartitionEvaluator;
36
37
/// nth_value expression
38
#[derive(Debug)]
39
pub struct NthValue {
40
    name: String,
41
    expr: Arc<dyn PhysicalExpr>,
42
    /// Output data type
43
    data_type: DataType,
44
    kind: NthValueKind,
45
    ignore_nulls: bool,
46
}
47
48
impl NthValue {
49
    /// Create a new FIRST_VALUE window aggregate function
50
0
    pub fn first(
51
0
        name: impl Into<String>,
52
0
        expr: Arc<dyn PhysicalExpr>,
53
0
        data_type: DataType,
54
0
        ignore_nulls: bool,
55
0
    ) -> Self {
56
0
        Self {
57
0
            name: name.into(),
58
0
            expr,
59
0
            data_type,
60
0
            kind: NthValueKind::First,
61
0
            ignore_nulls,
62
0
        }
63
0
    }
64
65
    /// Create a new LAST_VALUE window aggregate function
66
1
    pub fn last(
67
1
        name: impl Into<String>,
68
1
        expr: Arc<dyn PhysicalExpr>,
69
1
        data_type: DataType,
70
1
        ignore_nulls: bool,
71
1
    ) -> Self {
72
1
        Self {
73
1
            name: name.into(),
74
1
            expr,
75
1
            data_type,
76
1
            kind: NthValueKind::Last,
77
1
            ignore_nulls,
78
1
        }
79
1
    }
80
81
    /// Create a new NTH_VALUE window aggregate function
82
2
    pub fn nth(
83
2
        name: impl Into<String>,
84
2
        expr: Arc<dyn PhysicalExpr>,
85
2
        data_type: DataType,
86
2
        n: i64,
87
2
        ignore_nulls: bool,
88
2
    ) -> Result<Self> {
89
2
        Ok(Self {
90
2
            name: name.into(),
91
2
            expr,
92
2
            data_type,
93
2
            kind: NthValueKind::Nth(n),
94
2
            ignore_nulls,
95
2
        })
96
2
    }
97
98
    /// Get the NTH_VALUE kind
99
0
    pub fn get_kind(&self) -> NthValueKind {
100
0
        self.kind
101
0
    }
102
}
103
104
impl BuiltInWindowFunctionExpr for NthValue {
105
    /// Return a reference to Any that can be used for downcasting
106
0
    fn as_any(&self) -> &dyn Any {
107
0
        self
108
0
    }
109
110
18
    fn field(&self) -> Result<Field> {
111
18
        let nullable = true;
112
18
        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
113
18
    }
114
115
12
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
116
12
        vec![Arc::clone(&self.expr)]
117
12
    }
118
119
3
    fn name(&self) -> &str {
120
3
        &self.name
121
3
    }
122
123
3
    fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
124
3
        let state = NthValueState {
125
3
            finalized_result: None,
126
3
            kind: self.kind,
127
3
        };
128
3
        Ok(Box::new(NthValueEvaluator {
129
3
            state,
130
3
            ignore_nulls: self.ignore_nulls,
131
3
        }))
132
3
    }
133
134
2
    fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
135
2
        let reversed_kind = match self.kind {
136
0
            NthValueKind::First => NthValueKind::Last,
137
0
            NthValueKind::Last => NthValueKind::First,
138
2
            NthValueKind::Nth(idx) => NthValueKind::Nth(-idx),
139
        };
140
2
        Some(Arc::new(Self {
141
2
            name: self.name.clone(),
142
2
            expr: Arc::clone(&self.expr),
143
2
            data_type: self.data_type.clone(),
144
2
            kind: reversed_kind,
145
2
            ignore_nulls: self.ignore_nulls,
146
2
        }))
147
2
    }
148
}
149
150
/// Value evaluator for nth_value functions
151
#[derive(Debug)]
152
pub(crate) struct NthValueEvaluator {
153
    state: NthValueState,
154
    ignore_nulls: bool,
155
}
156
157
impl PartitionEvaluator for NthValueEvaluator {
158
    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
159
    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
160
    /// can memoize the result.  Once result is calculated, it will always stay
161
    /// same. Hence, we do not need to keep past data as we process the entire
162
    /// dataset.
163
12
    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
164
12
        let out = &state.out_col;
165
12
        let size = out.len();
166
12
        let mut buffer_size = 1;
167
        // Decide if we arrived at a final result yet:
168
12
        let (is_prunable, is_reverse_direction) = match self.state.kind {
169
            NthValueKind::First => {
170
0
                let n_range =
171
0
                    state.window_frame_range.end - state.window_frame_range.start;
172
0
                (n_range > 0 && size > 0, false)
173
            }
174
4
            NthValueKind::Last => (true, true),
175
8
            NthValueKind::Nth(n) => {
176
8
                let n_range =
177
8
                    state.window_frame_range.end - state.window_frame_range.start;
178
8
                match n.cmp(&0) {
179
                    Ordering::Greater => {
180
0
                        (n_range >= (n as usize) && size > (n as usize), false)
181
                    }
182
                    Ordering::Less => {
183
8
                        let reverse_index = (-n) as usize;
184
8
                        buffer_size = reverse_index;
185
8
                        // Negative index represents reverse direction.
186
8
                        (n_range >= reverse_index, true)
187
                    }
188
0
                    Ordering::Equal => (true, false),
189
                }
190
            }
191
        };
192
        // Do not memoize results when nulls are ignored.
193
12
        if is_prunable && !self.ignore_nulls {
194
12
            if self.state.finalized_result.is_none() && !is_reverse_direction {
195
0
                let result = ScalarValue::try_from_array(out, size - 1)?;
196
0
                self.state.finalized_result = Some(result);
197
12
            }
198
12
            state.window_frame_range.start =
199
12
                state.window_frame_range.end.saturating_sub(buffer_size);
200
0
        }
201
12
        Ok(())
202
12
    }
203
204
27
    fn evaluate(
205
27
        &mut self,
206
27
        values: &[ArrayRef],
207
27
        range: &Range<usize>,
208
27
    ) -> Result<ScalarValue> {
209
27
        if let Some(
ref result0
) = self.state.finalized_result {
210
0
            Ok(result.clone())
211
        } else {
212
            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
213
27
            let arr = &values[0];
214
27
            let n_range = range.end - range.start;
215
27
            if n_range == 0 {
216
                // We produce None if the window is empty.
217
0
                return ScalarValue::try_from(arr.data_type());
218
27
            }
219
220
            // Extract valid indices if ignoring nulls.
221
27
            let valid_indices = if self.ignore_nulls {
222
                // Calculate valid indices, inside the window frame boundaries
223
0
                let slice = arr.slice(range.start, n_range);
224
0
                let valid_indices = slice
225
0
                    .nulls()
226
0
                    .map(|nulls| {
227
0
                        nulls
228
0
                            .valid_indices()
229
0
                            // Add offset `range.start` to valid indices, to point correct index in the original arr.
230
0
                            .map(|idx| idx + range.start)
231
0
                            .collect::<Vec<_>>()
232
0
                    })
233
0
                    .unwrap_or_default();
234
0
                if valid_indices.is_empty() {
235
0
                    return ScalarValue::try_from(arr.data_type());
236
0
                }
237
0
                Some(valid_indices)
238
            } else {
239
27
                None
240
            };
241
27
            match self.state.kind {
242
                NthValueKind::First => {
243
0
                    if let Some(valid_indices) = &valid_indices {
244
0
                        ScalarValue::try_from_array(arr, valid_indices[0])
245
                    } else {
246
0
                        ScalarValue::try_from_array(arr, range.start)
247
                    }
248
                }
249
                NthValueKind::Last => {
250
9
                    if let Some(
valid_indices0
) = &valid_indices {
251
0
                        ScalarValue::try_from_array(
252
0
                            arr,
253
0
                            valid_indices[valid_indices.len() - 1],
254
0
                        )
255
                    } else {
256
9
                        ScalarValue::try_from_array(arr, range.end - 1)
257
                    }
258
                }
259
18
                NthValueKind::Nth(n) => {
260
18
                    match n.cmp(&0) {
261
                        Ordering::Greater => {
262
                            // SQL indices are not 0-based.
263
0
                            let index = (n as usize) - 1;
264
0
                            if index >= n_range {
265
                                // Outside the range, return NULL:
266
0
                                ScalarValue::try_from(arr.data_type())
267
0
                            } else if let Some(valid_indices) = valid_indices {
268
0
                                if index >= valid_indices.len() {
269
0
                                    return ScalarValue::try_from(arr.data_type());
270
0
                                }
271
0
                                ScalarValue::try_from_array(&arr, valid_indices[index])
272
                            } else {
273
0
                                ScalarValue::try_from_array(arr, range.start + index)
274
                            }
275
                        }
276
                        Ordering::Less => {
277
18
                            let reverse_index = (-n) as usize;
278
18
                            if n_range < reverse_index {
279
                                // Outside the range, return NULL:
280
1
                                ScalarValue::try_from(arr.data_type())
281
17
                            } else if let Some(
valid_indices0
) = valid_indices {
282
0
                                if reverse_index > valid_indices.len() {
283
0
                                    return ScalarValue::try_from(arr.data_type());
284
0
                                }
285
0
                                let new_index =
286
0
                                    valid_indices[valid_indices.len() - reverse_index];
287
0
                                ScalarValue::try_from_array(&arr, new_index)
288
                            } else {
289
17
                                ScalarValue::try_from_array(
290
17
                                    arr,
291
17
                                    range.start + n_range - reverse_index,
292
17
                                )
293
                            }
294
                        }
295
0
                        Ordering::Equal => ScalarValue::try_from(arr.data_type()),
296
                    }
297
                }
298
            }
299
        }
300
27
    }
301
302
0
    fn supports_bounded_execution(&self) -> bool {
303
0
        true
304
0
    }
305
306
51
    fn uses_window_frame(&self) -> bool {
307
51
        true
308
51
    }
309
}
310
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::{array::*, datatypes::*};
    use datafusion_common::cast::as_int32_array;

    /// Column reference to the single test column `arr` at index 0.
    fn arr_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("arr", 0))
    }

    /// Evaluates `expr` over the fixed input `[1, -2, 3, -4, 5, -6, 7, 8]`
    /// for the growing frames `0..1`, `0..2`, ..., `0..8` and compares the
    /// collected outputs against `expected`.
    fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

        // One frame per row, each starting at 0 and ending just past row `i`.
        let ranges: Vec<Range<usize>> = (0..8).map(|i| 0..i + 1).collect();

        let mut evaluator = expr.create_evaluator()?;
        let values = expr.evaluate_args(&batch)?;
        let mut scalars: Vec<ScalarValue> = Vec::with_capacity(ranges.len());
        for range in &ranges {
            scalars.push(evaluator.evaluate(&values, range)?);
        }

        let result = ScalarValue::iter_to_array(scalars.into_iter())?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn first_value() -> Result<()> {
        let first_value = NthValue::first(
            "first_value".to_owned(),
            arr_column(),
            DataType::Int32,
            false,
        );
        test_i32_result(first_value, Int32Array::from(vec![1; 8]))
    }

    #[test]
    fn last_value() -> Result<()> {
        let last_value = NthValue::last(
            "last_value".to_owned(),
            arr_column(),
            DataType::Int32,
            false,
        );
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);
        test_i32_result(last_value, expected)
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let nth_value = NthValue::nth(
            "nth_value".to_owned(),
            arr_column(),
            DataType::Int32,
            1,
            false,
        )?;
        test_i32_result(nth_value, Int32Array::from(vec![1; 8]))
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        let nth_value = NthValue::nth(
            "nth_value".to_owned(),
            arr_column(),
            DataType::Int32,
            2,
            false,
        )?;
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);
        test_i32_result(nth_value, expected)
    }
}