/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/window/nth_value.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE` |
19 | | //! functions that can be evaluated at run time during query execution. |
20 | | |
21 | | use std::any::Any; |
22 | | use std::cmp::Ordering; |
23 | | use std::ops::Range; |
24 | | use std::sync::Arc; |
25 | | |
26 | | use crate::window::window_expr::{NthValueKind, NthValueState}; |
27 | | use crate::window::BuiltInWindowFunctionExpr; |
28 | | use crate::PhysicalExpr; |
29 | | |
30 | | use arrow::array::{Array, ArrayRef}; |
31 | | use arrow::datatypes::{DataType, Field}; |
32 | | use datafusion_common::Result; |
33 | | use datafusion_common::ScalarValue; |
34 | | use datafusion_expr::window_state::WindowAggState; |
35 | | use datafusion_expr::PartitionEvaluator; |
36 | | |
37 | | /// nth_value expression |
38 | | #[derive(Debug)] |
39 | | pub struct NthValue { |
40 | | name: String, |
41 | | expr: Arc<dyn PhysicalExpr>, |
42 | | /// Output data type |
43 | | data_type: DataType, |
44 | | kind: NthValueKind, |
45 | | ignore_nulls: bool, |
46 | | } |
47 | | |
48 | | impl NthValue { |
49 | | /// Create a new FIRST_VALUE window aggregate function |
50 | 0 | pub fn first( |
51 | 0 | name: impl Into<String>, |
52 | 0 | expr: Arc<dyn PhysicalExpr>, |
53 | 0 | data_type: DataType, |
54 | 0 | ignore_nulls: bool, |
55 | 0 | ) -> Self { |
56 | 0 | Self { |
57 | 0 | name: name.into(), |
58 | 0 | expr, |
59 | 0 | data_type, |
60 | 0 | kind: NthValueKind::First, |
61 | 0 | ignore_nulls, |
62 | 0 | } |
63 | 0 | } |
64 | | |
65 | | /// Create a new LAST_VALUE window aggregate function |
66 | 1 | pub fn last( |
67 | 1 | name: impl Into<String>, |
68 | 1 | expr: Arc<dyn PhysicalExpr>, |
69 | 1 | data_type: DataType, |
70 | 1 | ignore_nulls: bool, |
71 | 1 | ) -> Self { |
72 | 1 | Self { |
73 | 1 | name: name.into(), |
74 | 1 | expr, |
75 | 1 | data_type, |
76 | 1 | kind: NthValueKind::Last, |
77 | 1 | ignore_nulls, |
78 | 1 | } |
79 | 1 | } |
80 | | |
81 | | /// Create a new NTH_VALUE window aggregate function |
82 | 2 | pub fn nth( |
83 | 2 | name: impl Into<String>, |
84 | 2 | expr: Arc<dyn PhysicalExpr>, |
85 | 2 | data_type: DataType, |
86 | 2 | n: i64, |
87 | 2 | ignore_nulls: bool, |
88 | 2 | ) -> Result<Self> { |
89 | 2 | Ok(Self { |
90 | 2 | name: name.into(), |
91 | 2 | expr, |
92 | 2 | data_type, |
93 | 2 | kind: NthValueKind::Nth(n), |
94 | 2 | ignore_nulls, |
95 | 2 | }) |
96 | 2 | } |
97 | | |
98 | | /// Get the NTH_VALUE kind |
99 | 0 | pub fn get_kind(&self) -> NthValueKind { |
100 | 0 | self.kind |
101 | 0 | } |
102 | | } |
103 | | |
104 | | impl BuiltInWindowFunctionExpr for NthValue { |
105 | | /// Return a reference to Any that can be used for downcasting |
106 | 0 | fn as_any(&self) -> &dyn Any { |
107 | 0 | self |
108 | 0 | } |
109 | | |
110 | 18 | fn field(&self) -> Result<Field> { |
111 | 18 | let nullable = true; |
112 | 18 | Ok(Field::new(&self.name, self.data_type.clone(), nullable)) |
113 | 18 | } |
114 | | |
115 | 12 | fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { |
116 | 12 | vec![Arc::clone(&self.expr)] |
117 | 12 | } |
118 | | |
119 | 3 | fn name(&self) -> &str { |
120 | 3 | &self.name |
121 | 3 | } |
122 | | |
123 | 3 | fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { |
124 | 3 | let state = NthValueState { |
125 | 3 | finalized_result: None, |
126 | 3 | kind: self.kind, |
127 | 3 | }; |
128 | 3 | Ok(Box::new(NthValueEvaluator { |
129 | 3 | state, |
130 | 3 | ignore_nulls: self.ignore_nulls, |
131 | 3 | })) |
132 | 3 | } |
133 | | |
134 | 2 | fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> { |
135 | 2 | let reversed_kind = match self.kind { |
136 | 0 | NthValueKind::First => NthValueKind::Last, |
137 | 0 | NthValueKind::Last => NthValueKind::First, |
138 | 2 | NthValueKind::Nth(idx) => NthValueKind::Nth(-idx), |
139 | | }; |
140 | 2 | Some(Arc::new(Self { |
141 | 2 | name: self.name.clone(), |
142 | 2 | expr: Arc::clone(&self.expr), |
143 | 2 | data_type: self.data_type.clone(), |
144 | 2 | kind: reversed_kind, |
145 | 2 | ignore_nulls: self.ignore_nulls, |
146 | 2 | })) |
147 | 2 | } |
148 | | } |
149 | | |
150 | | /// Value evaluator for nth_value functions |
151 | | #[derive(Debug)] |
152 | | pub(crate) struct NthValueEvaluator { |
153 | | state: NthValueState, |
154 | | ignore_nulls: bool, |
155 | | } |
156 | | |
157 | | impl PartitionEvaluator for NthValueEvaluator { |
158 | | /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), |
159 | | /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we |
160 | | /// can memoize the result. Once result is calculated, it will always stay |
161 | | /// same. Hence, we do not need to keep past data as we process the entire |
162 | | /// dataset. |
163 | 12 | fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { |
164 | 12 | let out = &state.out_col; |
165 | 12 | let size = out.len(); |
166 | 12 | let mut buffer_size = 1; |
167 | | // Decide if we arrived at a final result yet: |
168 | 12 | let (is_prunable, is_reverse_direction) = match self.state.kind { |
169 | | NthValueKind::First => { |
170 | 0 | let n_range = |
171 | 0 | state.window_frame_range.end - state.window_frame_range.start; |
172 | 0 | (n_range > 0 && size > 0, false) |
173 | | } |
174 | 4 | NthValueKind::Last => (true, true), |
175 | 8 | NthValueKind::Nth(n) => { |
176 | 8 | let n_range = |
177 | 8 | state.window_frame_range.end - state.window_frame_range.start; |
178 | 8 | match n.cmp(&0) { |
179 | | Ordering::Greater => { |
180 | 0 | (n_range >= (n as usize) && size > (n as usize), false) |
181 | | } |
182 | | Ordering::Less => { |
183 | 8 | let reverse_index = (-n) as usize; |
184 | 8 | buffer_size = reverse_index; |
185 | 8 | // Negative index represents reverse direction. |
186 | 8 | (n_range >= reverse_index, true) |
187 | | } |
188 | 0 | Ordering::Equal => (true, false), |
189 | | } |
190 | | } |
191 | | }; |
192 | | // Do not memoize results when nulls are ignored. |
193 | 12 | if is_prunable && !self.ignore_nulls { |
194 | 12 | if self.state.finalized_result.is_none() && !is_reverse_direction { |
195 | 0 | let result = ScalarValue::try_from_array(out, size - 1)?; |
196 | 0 | self.state.finalized_result = Some(result); |
197 | 12 | } |
198 | 12 | state.window_frame_range.start = |
199 | 12 | state.window_frame_range.end.saturating_sub(buffer_size); |
200 | 0 | } |
201 | 12 | Ok(()) |
202 | 12 | } |
203 | | |
204 | 27 | fn evaluate( |
205 | 27 | &mut self, |
206 | 27 | values: &[ArrayRef], |
207 | 27 | range: &Range<usize>, |
208 | 27 | ) -> Result<ScalarValue> { |
209 | 27 | if let Some(ref result0 ) = self.state.finalized_result { |
210 | 0 | Ok(result.clone()) |
211 | | } else { |
212 | | // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1. |
213 | 27 | let arr = &values[0]; |
214 | 27 | let n_range = range.end - range.start; |
215 | 27 | if n_range == 0 { |
216 | | // We produce None if the window is empty. |
217 | 0 | return ScalarValue::try_from(arr.data_type()); |
218 | 27 | } |
219 | | |
220 | | // Extract valid indices if ignoring nulls. |
221 | 27 | let valid_indices = if self.ignore_nulls { |
222 | | // Calculate valid indices, inside the window frame boundaries |
223 | 0 | let slice = arr.slice(range.start, n_range); |
224 | 0 | let valid_indices = slice |
225 | 0 | .nulls() |
226 | 0 | .map(|nulls| { |
227 | 0 | nulls |
228 | 0 | .valid_indices() |
229 | 0 | // Add offset `range.start` to valid indices, to point correct index in the original arr. |
230 | 0 | .map(|idx| idx + range.start) |
231 | 0 | .collect::<Vec<_>>() |
232 | 0 | }) |
233 | 0 | .unwrap_or_default(); |
234 | 0 | if valid_indices.is_empty() { |
235 | 0 | return ScalarValue::try_from(arr.data_type()); |
236 | 0 | } |
237 | 0 | Some(valid_indices) |
238 | | } else { |
239 | 27 | None |
240 | | }; |
241 | 27 | match self.state.kind { |
242 | | NthValueKind::First => { |
243 | 0 | if let Some(valid_indices) = &valid_indices { |
244 | 0 | ScalarValue::try_from_array(arr, valid_indices[0]) |
245 | | } else { |
246 | 0 | ScalarValue::try_from_array(arr, range.start) |
247 | | } |
248 | | } |
249 | | NthValueKind::Last => { |
250 | 9 | if let Some(valid_indices0 ) = &valid_indices { |
251 | 0 | ScalarValue::try_from_array( |
252 | 0 | arr, |
253 | 0 | valid_indices[valid_indices.len() - 1], |
254 | 0 | ) |
255 | | } else { |
256 | 9 | ScalarValue::try_from_array(arr, range.end - 1) |
257 | | } |
258 | | } |
259 | 18 | NthValueKind::Nth(n) => { |
260 | 18 | match n.cmp(&0) { |
261 | | Ordering::Greater => { |
262 | | // SQL indices are not 0-based. |
263 | 0 | let index = (n as usize) - 1; |
264 | 0 | if index >= n_range { |
265 | | // Outside the range, return NULL: |
266 | 0 | ScalarValue::try_from(arr.data_type()) |
267 | 0 | } else if let Some(valid_indices) = valid_indices { |
268 | 0 | if index >= valid_indices.len() { |
269 | 0 | return ScalarValue::try_from(arr.data_type()); |
270 | 0 | } |
271 | 0 | ScalarValue::try_from_array(&arr, valid_indices[index]) |
272 | | } else { |
273 | 0 | ScalarValue::try_from_array(arr, range.start + index) |
274 | | } |
275 | | } |
276 | | Ordering::Less => { |
277 | 18 | let reverse_index = (-n) as usize; |
278 | 18 | if n_range < reverse_index { |
279 | | // Outside the range, return NULL: |
280 | 1 | ScalarValue::try_from(arr.data_type()) |
281 | 17 | } else if let Some(valid_indices0 ) = valid_indices { |
282 | 0 | if reverse_index > valid_indices.len() { |
283 | 0 | return ScalarValue::try_from(arr.data_type()); |
284 | 0 | } |
285 | 0 | let new_index = |
286 | 0 | valid_indices[valid_indices.len() - reverse_index]; |
287 | 0 | ScalarValue::try_from_array(&arr, new_index) |
288 | | } else { |
289 | 17 | ScalarValue::try_from_array( |
290 | 17 | arr, |
291 | 17 | range.start + n_range - reverse_index, |
292 | 17 | ) |
293 | | } |
294 | | } |
295 | 0 | Ordering::Equal => ScalarValue::try_from(arr.data_type()), |
296 | | } |
297 | | } |
298 | | } |
299 | | } |
300 | 27 | } |
301 | | |
302 | 0 | fn supports_bounded_execution(&self) -> bool { |
303 | 0 | true |
304 | 0 | } |
305 | | |
306 | 51 | fn uses_window_frame(&self) -> bool { |
307 | 51 | true |
308 | 51 | } |
309 | | } |
310 | | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::Column;
    use arrow::{array::*, datatypes::*};
    use datafusion_common::cast::as_int32_array;

    /// Evaluates `expr` on the Int32 column `arr` over the growing frames
    /// `0..1`, `0..2`, ..., `0..arr.len()` and compares the per-frame results
    /// with `expected`.
    fn check_i32_frames(
        arr: ArrayRef,
        expr: NthValue,
        expected: Int32Array,
    ) -> Result<()> {
        let len = arr.len();
        // Declared nullable so inputs containing NULLs are accepted.
        let schema = Schema::new(vec![Field::new("arr", DataType::Int32, true)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
        let mut evaluator = expr.create_evaluator()?;
        let values = expr.evaluate_args(&batch)?;
        let result = (1..=len)
            .map(|end| evaluator.evaluate(&values, &(0..end)))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let result = ScalarValue::iter_to_array(result.into_iter())?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// Runs `check_i32_frames` on a fixed 8-row input without NULLs.
    fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        check_i32_frames(arr, expr, expected)
    }

    /// A 4-row input whose first and third rows are NULL.
    fn nullable_input() -> ArrayRef {
        Arc::new(Int32Array::from(vec![None, Some(2), None, Some(4)]))
    }

    #[test]
    fn first_value() -> Result<()> {
        let first_value = NthValue::first(
            "first_value".to_owned(),
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            false,
        );
        test_i32_result(first_value, Int32Array::from(vec![1; 8]))?;
        Ok(())
    }

    #[test]
    fn last_value() -> Result<()> {
        let last_value = NthValue::last(
            "last_value".to_owned(),
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            false,
        );
        test_i32_result(
            last_value,
            Int32Array::from(vec![
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
            ]),
        )?;
        Ok(())
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        let nth_value = NthValue::nth(
            "nth_value".to_owned(),
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            1,
            false,
        )?;
        test_i32_result(nth_value, Int32Array::from(vec![1; 8]))?;
        Ok(())
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        let nth_value = NthValue::nth(
            "nth_value".to_owned(),
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            2,
            false,
        )?;
        test_i32_result(
            nth_value,
            Int32Array::from(vec![
                None,
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
            ]),
        )?;
        Ok(())
    }

    #[test]
    fn nth_value_negative() -> Result<()> {
        // -3 selects the third value counted from the end of the frame.
        let nth_value = NthValue::nth(
            "nth_value".to_owned(),
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            -3,
            false,
        )?;
        test_i32_result(
            nth_value,
            Int32Array::from(vec![
                None,
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
            ]),
        )?;
        Ok(())
    }

    #[test]
    fn nth_value_zero() -> Result<()> {
        let nth_value = NthValue::nth(
            "nth_value".to_owned(),
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            0,
            false,
        )?;
        test_i32_result(nth_value, Int32Array::from(vec![None::<i32>; 8]))?;
        Ok(())
    }

    #[test]
    fn first_value_ignore_nulls() -> Result<()> {
        let expr = NthValue::first(
            "first_value",
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            true,
        );
        check_i32_frames(
            nullable_input(),
            expr,
            Int32Array::from(vec![None, Some(2), Some(2), Some(2)]),
        )
    }

    #[test]
    fn last_value_ignore_nulls() -> Result<()> {
        let expr = NthValue::last(
            "last_value",
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            true,
        );
        check_i32_frames(
            nullable_input(),
            expr,
            Int32Array::from(vec![None, Some(2), Some(2), Some(4)]),
        )
    }

    #[test]
    fn nth_value_ignore_nulls() -> Result<()> {
        let expr = NthValue::nth(
            "nth_value",
            Arc::new(Column::new("arr", 0)),
            DataType::Int32,
            2,
            true,
        )?;
        check_i32_frames(
            nullable_input(),
            expr,
            Int32Array::from(vec![None, None, None, Some(4)]),
        )
    }
}