/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/nth_value.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines NTH_VALUE aggregate expression which may specify ordering requirement |
19 | | //! that can evaluated at runtime during query execution |
20 | | |
21 | | use std::any::Any; |
22 | | use std::collections::VecDeque; |
23 | | use std::sync::Arc; |
24 | | |
25 | | use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; |
26 | | use arrow_schema::{DataType, Field, Fields}; |
27 | | |
28 | | use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; |
29 | | use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; |
30 | | use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; |
31 | | use datafusion_expr::utils::format_state_name; |
32 | | use datafusion_expr::{ |
33 | | lit, Accumulator, AggregateUDFImpl, ExprFunctionExt, ReversedUDAF, Signature, |
34 | | SortExpr, Volatility, |
35 | | }; |
36 | | use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; |
37 | | use datafusion_functions_aggregate_common::utils::ordering_fields; |
38 | | use datafusion_physical_expr::expressions::Literal; |
39 | | use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; |
40 | | |
// NOTE(review): `create_func!` is defined elsewhere in this crate; it appears
// to generate the shared UDAF instance and the `nth_value_udaf()` accessor
// used below — confirm against the macro definition.
create_func!(NthValueAgg, nth_value_udaf);
42 | | |
43 | | /// Returns the nth value in a group of values. |
44 | 0 | pub fn nth_value( |
45 | 0 | expr: datafusion_expr::Expr, |
46 | 0 | n: i64, |
47 | 0 | order_by: Vec<SortExpr>, |
48 | 0 | ) -> datafusion_expr::Expr { |
49 | 0 | let args = vec![expr, lit(n)]; |
50 | 0 | if !order_by.is_empty() { |
51 | 0 | nth_value_udaf() |
52 | 0 | .call(args) |
53 | 0 | .order_by(order_by) |
54 | 0 | .build() |
55 | 0 | .unwrap() |
56 | | } else { |
57 | 0 | nth_value_udaf().call(args) |
58 | | } |
59 | 0 | } |
60 | | |
/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub struct NthValueAgg {
    /// Accepted type signature: any two arguments (the value expression and
    /// the literal `n` index).
    signature: Signature,
}
68 | | |
69 | | impl NthValueAgg { |
70 | | /// Create a new `NthValueAgg` aggregate function |
71 | 0 | pub fn new() -> Self { |
72 | 0 | Self { |
73 | 0 | signature: Signature::any(2, Volatility::Immutable), |
74 | 0 | } |
75 | 0 | } |
76 | | } |
77 | | |
78 | | impl Default for NthValueAgg { |
79 | 0 | fn default() -> Self { |
80 | 0 | Self::new() |
81 | 0 | } |
82 | | } |
83 | | |
impl AggregateUDFImpl for NthValueAgg {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "nth_value"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    // NTH_VALUE returns one of the input values, so the result type equals
    // the type of the first argument; `n` only selects which entry comes back.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        // `n` must be an `Int64` literal; any other expression falls through
        // to the `not_impl_err` branch below.
        let n = match acc_args.exprs[1]
            .as_any()
            .downcast_ref::<Literal>()
            .map(|lit| lit.value())
        {
            Some(ScalarValue::Int64(Some(value))) => {
                // When the aggregation is reversed, NTH_VALUE(n) over the
                // reversed ordering equals NTH_VALUE(-n) over the original
                // ordering, so negate `n` here.
                if acc_args.is_reversed {
                    -*value
                } else {
                    *value
                }
            }
            _ => {
                return not_impl_err!(
                    "{} not supported for n: {}",
                    self.name(),
                    &acc_args.exprs[1]
                )
            }
        };

        // Resolve the data types of the ORDER BY expressions; the accumulator
        // stores these alongside each value so partial results can be merged
        // in the required order.
        let ordering_dtypes = acc_args
            .ordering_req
            .iter()
            .map(|e| e.expr.data_type(acc_args.schema))
            .collect::<Result<Vec<_>>>()?;

        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
        NthValueAccumulator::try_new(
            n,
            &data_type,
            &ordering_dtypes,
            acc_args.ordering_req.to_vec(),
        )
        .map(|acc| Box::new(acc) as _)
    }

    // Intermediate state layout: a list of collected values, plus — only when
    // there is an ordering requirement — a parallel list of structs holding
    // the ordering-expression values for each collected row.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        let mut fields = vec![Field::new_list(
            format_state_name(self.name(), "nth_value"),
            // See COMMENTS.md to understand why nullable is set to true
            Field::new("item", args.input_types[0].clone(), true),
            false,
        )];
        let orderings = args.ordering_fields.to_vec();
        if !orderings.is_empty() {
            fields.push(Field::new_list(
                format_state_name(self.name(), "nth_value_orderings"),
                Field::new("item", DataType::Struct(Fields::from(orderings)), true),
                false,
            ));
        }
        Ok(fields)
    }

    fn aliases(&self) -> &[String] {
        &[]
    }

    // Reversing NTH_VALUE yields NTH_VALUE again; the accumulator negates `n`
    // (see `accumulator` above) to account for the flipped ordering.
    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Reversed(nth_value_udaf())
    }
}
165 | | |
#[derive(Debug)]
pub struct NthValueAccumulator {
    /// The `N` value; positive `n` counts from the start of the group,
    /// negative `n` counts from the end (1-based in both directions; 0 is
    /// rejected at construction).
    n: i64,
    /// Stores entries in the `NTH_VALUE` result.
    values: VecDeque<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values`. This information is used when merging results from
    /// different partitions. For detailed information how merging is done, see
    /// [`merge_ordered_arrays`].
    ordering_values: VecDeque<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions (index 0 is the value type; ordering types follow).
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
}
183 | | |
184 | | impl NthValueAccumulator { |
185 | | /// Create a new order-sensitive NTH_VALUE accumulator based on the given |
186 | | /// item data type. |
187 | 0 | pub fn try_new( |
188 | 0 | n: i64, |
189 | 0 | datatype: &DataType, |
190 | 0 | ordering_dtypes: &[DataType], |
191 | 0 | ordering_req: LexOrdering, |
192 | 0 | ) -> Result<Self> { |
193 | 0 | if n == 0 { |
194 | | // n cannot be 0 |
195 | 0 | return internal_err!("Nth value indices are 1 based. 0 is invalid index"); |
196 | 0 | } |
197 | 0 | let mut datatypes = vec![datatype.clone()]; |
198 | 0 | datatypes.extend(ordering_dtypes.iter().cloned()); |
199 | 0 | Ok(Self { |
200 | 0 | n, |
201 | 0 | values: VecDeque::new(), |
202 | 0 | ordering_values: VecDeque::new(), |
203 | 0 | datatypes, |
204 | 0 | ordering_req, |
205 | 0 | }) |
206 | 0 | } |
207 | | } |
208 | | |
impl Accumulator for NthValueAccumulator {
    /// Updates its state with the `values`. Assumes data in the `values` satisfies the required
    /// ordering for the accumulator (across consecutive batches, not just batch-wise).
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        if from_start {
            // direction is from start: only the first `n_required` rows can
            // ever be the answer, so cap how many new rows we retain.
            let n_remaining = n_required.saturating_sub(self.values.len());
            self.append_new_data(values, Some(n_remaining))?;
        } else {
            // direction is from end: keep all new rows, then trim from the
            // front so that only the last `n_required` entries survive.
            self.append_new_data(values, None)?;
            let start_offset = self.values.len().saturating_sub(n_required);
            if start_offset > 0 {
                self.values.drain(0..start_offset);
                self.ordering_values.drain(0..start_offset);
            }
        }

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        // First entry in the state is the aggregation result.
        let array_agg_values = &states[0];
        let n_required = self.n.unsigned_abs() as usize;
        if self.ordering_req.is_empty() {
            // No ordering requirement: simply concatenate partition results,
            // stopping once we have more than enough values.
            let array_agg_res =
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
            for v in array_agg_res.into_iter() {
                self.values.extend(v);
                if self.values.len() > n_required {
                    // There is enough data collected; we can stop merging.
                    break;
                }
            }
        } else if let Some(agg_orderings) = states[1].as_list_opt::<i32>() {
            // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside NTH_VALUE list.
            // For each `StructArray` inside NTH_VALUE list, we will receive an `Array` that stores
            // values received from its ordering requirement expression. (This information is necessary for during merging).

            // Stores NTH_VALUE results coming from each partition
            let mut partition_values: Vec<VecDeque<ScalarValue>> = vec![];
            // Stores ordering requirement expression results coming from each partition
            let mut partition_ordering_values: Vec<VecDeque<Vec<ScalarValue>>> = vec![];

            // Existing values should be merged also.
            partition_values.push(self.values.clone());

            partition_ordering_values.push(self.ordering_values.clone());

            let array_agg_res =
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;

            for v in array_agg_res.into_iter() {
                partition_values.push(v.into());
            }

            let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;

            // Decode each partition's list of ordering structs into row-wise
            // vectors of scalars (one scalar per ORDER BY expression).
            let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
                // Extract value from struct to ordering_rows for each group/partition
                partition_ordering_rows.into_iter().map(|ordering_row| {
                    if let ScalarValue::Struct(s) = ordering_row {
                        let mut ordering_columns_per_row = vec![];

                        for column in s.columns() {
                            let sv = ScalarValue::try_from_array(column, 0)?;
                            ordering_columns_per_row.push(sv);
                        }

                        Ok(ordering_columns_per_row)
                    } else {
                        exec_err!(
                            "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
                            ordering_row.data_type()
                        )
                    }
                }).collect::<Result<Vec<_>>>()
            }).collect::<Result<Vec<_>>>()?;
            for ordering_values in ordering_values.into_iter() {
                partition_ordering_values.push(ordering_values.into());
            }

            // Merge all partitions into a single ordered run using the sort
            // options of the ordering requirement.
            let sort_options = self
                .ordering_req
                .iter()
                .map(|sort_expr| sort_expr.options)
                .collect::<Vec<_>>();
            let (new_values, new_orderings) = merge_ordered_arrays(
                &mut partition_values,
                &mut partition_ordering_values,
                &sort_options,
            )?;
            self.values = new_values.into();
            self.ordering_values = new_orderings.into();
        } else {
            return exec_err!("Expects to receive a list array");
        }
        Ok(())
    }

    // State is the serialized values list, plus the ordering structs when an
    // ordering requirement exists (mirrors `state_fields`).
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.evaluate_values()];
        if !self.ordering_req.is_empty() {
            result.push(self.evaluate_orderings()?);
        }
        Ok(result)
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        let n_required = self.n.unsigned_abs() as usize;
        let from_start = self.n > 0;
        let nth_value_idx = if from_start {
            // index is from start (1-based `n` maps to 0-based `n - 1`)
            let forward_idx = n_required - 1;
            (forward_idx < self.values.len()).then_some(forward_idx)
        } else {
            // index is from end (`None` when fewer than `n_required` values)
            self.values.len().checked_sub(n_required)
        };
        if let Some(idx) = nth_value_idx {
            Ok(self.values[idx].clone())
        } else {
            // Not enough rows to reach the requested position: produce a
            // null scalar of the value's data type.
            ScalarValue::try_from(self.datatypes[0].clone())
        }
    }

    // Approximate in-memory size; subtracts `size_of_val` of containers that
    // are already counted inside `size_of_val(self)` to avoid double counting.
    fn size(&self) -> usize {
        let mut total = std::mem::size_of_val(self)
            + ScalarValue::size_of_vec_deque(&self.values)
            - std::mem::size_of_val(&self.values);

        // Add size of the `self.ordering_values`
        total +=
            std::mem::size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row);
        }

        // Add size of the `self.datatypes`
        total += std::mem::size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - std::mem::size_of_val(dtype);
        }

        // Add size of the `self.ordering_req`
        total += std::mem::size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
        total
    }
}
369 | | |
impl NthValueAccumulator {
    /// Packs the collected ordering-expression values into a
    /// `ScalarValue::List` of structs — one struct per stored row, one struct
    /// field per ORDER BY expression — for use as the second state field.
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
        // `datatypes[0]` is the value type; ordering types start at index 1.
        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
        let struct_field = Fields::from(fields.clone());

        // Transpose the row-wise `ordering_values` storage into one Arrow
        // array per ordering column.
        let mut column_wise_ordering_values = vec![];
        let num_columns = fields.len();
        for i in 0..num_columns {
            let column_values = self
                .ordering_values
                .iter()
                .map(|x| x[i].clone())
                .collect::<Vec<_>>();
            let array = if column_values.is_empty() {
                // `iter_to_array` needs at least one value to infer a type,
                // so build an explicitly-typed empty array instead.
                new_empty_array(fields[i].data_type())
            } else {
                ScalarValue::iter_to_array(column_values.into_iter())?
            };
            column_wise_ordering_values.push(array);
        }

        let ordering_array =
            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;

        Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable(
            Arc::new(ordering_array),
        ))))
    }

    /// Packs the collected values into a `ScalarValue::List` for use as the
    /// first state field.
    fn evaluate_values(&self) -> ScalarValue {
        // Clone so `make_contiguous` can expose the `VecDeque` as one slice
        // without mutating `self`.
        let mut values_cloned = self.values.clone();
        let values_slice = values_cloned.make_contiguous();
        ScalarValue::List(ScalarValue::new_list_nullable(
            values_slice,
            &self.datatypes[0],
        ))
    }

    /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete
    /// None represents all of the new `values` need to be added to the state.
    fn append_new_data(
        &mut self,
        values: &[ArrayRef],
        fetch: Option<usize>,
    ) -> Result<()> {
        let n_row = values[0].len();
        // Never read past the end of the batch, even if `fetch` asks for more.
        let n_to_add = if let Some(fetch) = fetch {
            std::cmp::min(fetch, n_row)
        } else {
            n_row
        };
        for index in 0..n_to_add {
            let row = get_row_at_idx(values, index)?;
            self.values.push_back(row[0].clone());
            // At index 1, we have n index argument.
            // Ordering values cover starting from 2nd index to end
            self.ordering_values.push_back(row[2..].to_vec());
        }
        Ok(())
    }
}