Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/nth_value.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines the NTH_VALUE aggregate expression, which may specify an ordering requirement
19
//! that can be evaluated at runtime during query execution
20
21
use std::any::Any;
22
use std::collections::VecDeque;
23
use std::sync::Arc;
24
25
use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
26
use arrow_schema::{DataType, Field, Fields};
27
28
use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
29
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
30
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31
use datafusion_expr::utils::format_state_name;
32
use datafusion_expr::{
33
    lit, Accumulator, AggregateUDFImpl, ExprFunctionExt, ReversedUDAF, Signature,
34
    SortExpr, Volatility,
35
};
36
use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
37
use datafusion_functions_aggregate_common::utils::ordering_fields;
38
use datafusion_physical_expr::expressions::Literal;
39
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
40
41
create_func!(NthValueAgg, nth_value_udaf);
42
43
/// Returns the nth value in a group of values.
44
0
pub fn nth_value(
45
0
    expr: datafusion_expr::Expr,
46
0
    n: i64,
47
0
    order_by: Vec<SortExpr>,
48
0
) -> datafusion_expr::Expr {
49
0
    let args = vec![expr, lit(n)];
50
0
    if !order_by.is_empty() {
51
0
        nth_value_udaf()
52
0
            .call(args)
53
0
            .order_by(order_by)
54
0
            .build()
55
0
            .unwrap()
56
    } else {
57
0
        nth_value_udaf().call(args)
58
    }
59
0
}
60
61
/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
62
/// partition setting, partial aggregations are computed for every partition,
63
/// and then their results are merged.
64
#[derive(Debug)]
65
pub struct NthValueAgg {
66
    signature: Signature,
67
}
68
69
impl NthValueAgg {
70
    /// Create a new `NthValueAgg` aggregate function
71
0
    pub fn new() -> Self {
72
0
        Self {
73
0
            signature: Signature::any(2, Volatility::Immutable),
74
0
        }
75
0
    }
76
}
77
78
impl Default for NthValueAgg {
79
0
    fn default() -> Self {
80
0
        Self::new()
81
0
    }
82
}
83
84
impl AggregateUDFImpl for NthValueAgg {
85
0
    fn as_any(&self) -> &dyn Any {
86
0
        self
87
0
    }
88
89
0
    fn name(&self) -> &str {
90
0
        "nth_value"
91
0
    }
92
93
0
    fn signature(&self) -> &Signature {
94
0
        &self.signature
95
0
    }
96
97
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
98
0
        Ok(arg_types[0].clone())
99
0
    }
100
101
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
102
0
        let n = match acc_args.exprs[1]
103
0
            .as_any()
104
0
            .downcast_ref::<Literal>()
105
0
            .map(|lit| lit.value())
106
        {
107
0
            Some(ScalarValue::Int64(Some(value))) => {
108
0
                if acc_args.is_reversed {
109
0
                    -*value
110
                } else {
111
0
                    *value
112
                }
113
            }
114
            _ => {
115
0
                return not_impl_err!(
116
0
                    "{} not supported for n: {}",
117
0
                    self.name(),
118
0
                    &acc_args.exprs[1]
119
0
                )
120
            }
121
        };
122
123
0
        let ordering_dtypes = acc_args
124
0
            .ordering_req
125
0
            .iter()
126
0
            .map(|e| e.expr.data_type(acc_args.schema))
127
0
            .collect::<Result<Vec<_>>>()?;
128
129
0
        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
130
0
        NthValueAccumulator::try_new(
131
0
            n,
132
0
            &data_type,
133
0
            &ordering_dtypes,
134
0
            acc_args.ordering_req.to_vec(),
135
0
        )
136
0
        .map(|acc| Box::new(acc) as _)
137
0
    }
138
139
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
140
0
        let mut fields = vec![Field::new_list(
141
0
            format_state_name(self.name(), "nth_value"),
142
0
            // See COMMENTS.md to understand why nullable is set to true
143
0
            Field::new("item", args.input_types[0].clone(), true),
144
0
            false,
145
0
        )];
146
0
        let orderings = args.ordering_fields.to_vec();
147
0
        if !orderings.is_empty() {
148
0
            fields.push(Field::new_list(
149
0
                format_state_name(self.name(), "nth_value_orderings"),
150
0
                Field::new("item", DataType::Struct(Fields::from(orderings)), true),
151
0
                false,
152
0
            ));
153
0
        }
154
0
        Ok(fields)
155
0
    }
156
157
0
    fn aliases(&self) -> &[String] {
158
0
        &[]
159
0
    }
160
161
0
    fn reverse_expr(&self) -> ReversedUDAF {
162
0
        ReversedUDAF::Reversed(nth_value_udaf())
163
0
    }
164
}
165
166
#[derive(Debug)]
167
pub struct NthValueAccumulator {
168
    /// The `N` value.
169
    n: i64,
170
    /// Stores entries in the `NTH_VALUE` result.
171
    values: VecDeque<ScalarValue>,
172
    /// Stores values of ordering requirement expressions corresponding to each
173
    /// entry in `values`. This information is used when merging results from
174
    /// different partitions. For detailed information on how merging is done, see
175
    /// [`merge_ordered_arrays`].
176
    ordering_values: VecDeque<Vec<ScalarValue>>,
177
    /// Stores datatypes of expressions inside values and ordering requirement
178
    /// expressions.
179
    datatypes: Vec<DataType>,
180
    /// Stores the ordering requirement of the `Accumulator`.
181
    ordering_req: LexOrdering,
182
}
183
184
impl NthValueAccumulator {
185
    /// Create a new order-sensitive NTH_VALUE accumulator based on the given
186
    /// item data type.
187
0
    pub fn try_new(
188
0
        n: i64,
189
0
        datatype: &DataType,
190
0
        ordering_dtypes: &[DataType],
191
0
        ordering_req: LexOrdering,
192
0
    ) -> Result<Self> {
193
0
        if n == 0 {
194
            // n cannot be 0
195
0
            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
196
0
        }
197
0
        let mut datatypes = vec![datatype.clone()];
198
0
        datatypes.extend(ordering_dtypes.iter().cloned());
199
0
        Ok(Self {
200
0
            n,
201
0
            values: VecDeque::new(),
202
0
            ordering_values: VecDeque::new(),
203
0
            datatypes,
204
0
            ordering_req,
205
0
        })
206
0
    }
207
}
208
209
impl Accumulator for NthValueAccumulator {
210
    /// Updates its state with the `values`. Assumes data in the `values` satisfies the required
211
    /// ordering for the accumulator (across consecutive batches, not just batch-wise).
212
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
213
0
        if values.is_empty() {
214
0
            return Ok(());
215
0
        }
216
0
217
0
        let n_required = self.n.unsigned_abs() as usize;
218
0
        let from_start = self.n > 0;
219
0
        if from_start {
220
            // direction is from start
221
0
            let n_remaining = n_required.saturating_sub(self.values.len());
222
0
            self.append_new_data(values, Some(n_remaining))?;
223
        } else {
224
            // direction is from end
225
0
            self.append_new_data(values, None)?;
226
0
            let start_offset = self.values.len().saturating_sub(n_required);
227
0
            if start_offset > 0 {
228
0
                self.values.drain(0..start_offset);
229
0
                self.ordering_values.drain(0..start_offset);
230
0
            }
231
        }
232
233
0
        Ok(())
234
0
    }
235
236
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
237
0
        if states.is_empty() {
238
0
            return Ok(());
239
0
        }
240
0
        // First entry in the state is the aggregation result.
241
0
        let array_agg_values = &states[0];
242
0
        let n_required = self.n.unsigned_abs() as usize;
243
0
        if self.ordering_req.is_empty() {
244
0
            let array_agg_res =
245
0
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
246
0
            for v in array_agg_res.into_iter() {
247
0
                self.values.extend(v);
248
0
                if self.values.len() > n_required {
249
                    // Enough data has been collected, so merging can stop
250
0
                    break;
251
0
                }
252
            }
253
0
        } else if let Some(agg_orderings) = states[1].as_list_opt::<i32>() {
254
            // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside NTH_VALUE list.
255
            // For each `StructArray` inside NTH_VALUE list, we will receive an `Array` that stores
256
            // values received from its ordering requirement expression. (This information is necessary during merging).
257
258
            // Stores NTH_VALUE results coming from each partition
259
0
            let mut partition_values: Vec<VecDeque<ScalarValue>> = vec![];
260
0
            // Stores ordering requirement expression results coming from each partition
261
0
            let mut partition_ordering_values: Vec<VecDeque<Vec<ScalarValue>>> = vec![];
262
0
263
0
            // Existing values should be merged also.
264
0
            partition_values.push(self.values.clone());
265
0
266
0
            partition_ordering_values.push(self.ordering_values.clone());
267
268
0
            let array_agg_res =
269
0
                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
270
271
0
            for v in array_agg_res.into_iter() {
272
0
                partition_values.push(v.into());
273
0
            }
274
275
0
            let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
276
277
0
            let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
278
0
                // Extract value from struct to ordering_rows for each group/partition
279
0
                partition_ordering_rows.into_iter().map(|ordering_row| {
280
0
                    if let ScalarValue::Struct(s) = ordering_row {
281
0
                        let mut ordering_columns_per_row = vec![];
282
283
0
                        for column in s.columns() {
284
0
                            let sv = ScalarValue::try_from_array(column, 0)?;
285
0
                            ordering_columns_per_row.push(sv);
286
                        }
287
288
0
                        Ok(ordering_columns_per_row)
289
                    } else {
290
0
                        exec_err!(
291
0
                            "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
292
0
                            ordering_row.data_type()
293
0
                        )
294
                    }
295
0
                }).collect::<Result<Vec<_>>>()
296
0
            }).collect::<Result<Vec<_>>>()?;
297
0
            for ordering_values in ordering_values.into_iter() {
298
0
                partition_ordering_values.push(ordering_values.into());
299
0
            }
300
301
0
            let sort_options = self
302
0
                .ordering_req
303
0
                .iter()
304
0
                .map(|sort_expr| sort_expr.options)
305
0
                .collect::<Vec<_>>();
306
0
            let (new_values, new_orderings) = merge_ordered_arrays(
307
0
                &mut partition_values,
308
0
                &mut partition_ordering_values,
309
0
                &sort_options,
310
0
            )?;
311
0
            self.values = new_values.into();
312
0
            self.ordering_values = new_orderings.into();
313
        } else {
314
0
            return exec_err!("Expects to receive a list array");
315
        }
316
0
        Ok(())
317
0
    }
318
319
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
320
0
        let mut result = vec![self.evaluate_values()];
321
0
        if !self.ordering_req.is_empty() {
322
0
            result.push(self.evaluate_orderings()?);
323
0
        }
324
0
        Ok(result)
325
0
    }
326
327
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
328
0
        let n_required = self.n.unsigned_abs() as usize;
329
0
        let from_start = self.n > 0;
330
0
        let nth_value_idx = if from_start {
331
            // index is from start
332
0
            let forward_idx = n_required - 1;
333
0
            (forward_idx < self.values.len()).then_some(forward_idx)
334
        } else {
335
            // index is from end
336
0
            self.values.len().checked_sub(n_required)
337
        };
338
0
        if let Some(idx) = nth_value_idx {
339
0
            Ok(self.values[idx].clone())
340
        } else {
341
0
            ScalarValue::try_from(self.datatypes[0].clone())
342
        }
343
0
    }
344
345
0
    fn size(&self) -> usize {
346
0
        let mut total = std::mem::size_of_val(self)
347
0
            + ScalarValue::size_of_vec_deque(&self.values)
348
0
            - std::mem::size_of_val(&self.values);
349
0
350
0
        // Add size of the `self.ordering_values`
351
0
        total +=
352
0
            std::mem::size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
353
0
        for row in &self.ordering_values {
354
0
            total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row);
355
0
        }
356
357
        // Add size of the `self.datatypes`
358
0
        total += std::mem::size_of::<DataType>() * self.datatypes.capacity();
359
0
        for dtype in &self.datatypes {
360
0
            total += dtype.size() - std::mem::size_of_val(dtype);
361
0
        }
362
363
        // Add size of the `self.ordering_req`
364
0
        total += std::mem::size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
365
0
        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
366
0
        total
367
0
    }
368
}
369
370
impl NthValueAccumulator {
371
0
    fn evaluate_orderings(&self) -> Result<ScalarValue> {
372
0
        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
373
0
        let struct_field = Fields::from(fields.clone());
374
0
375
0
        let mut column_wise_ordering_values = vec![];
376
0
        let num_columns = fields.len();
377
0
        for i in 0..num_columns {
378
0
            let column_values = self
379
0
                .ordering_values
380
0
                .iter()
381
0
                .map(|x| x[i].clone())
382
0
                .collect::<Vec<_>>();
383
0
            let array = if column_values.is_empty() {
384
0
                new_empty_array(fields[i].data_type())
385
            } else {
386
0
                ScalarValue::iter_to_array(column_values.into_iter())?
387
            };
388
0
            column_wise_ordering_values.push(array);
389
        }
390
391
0
        let ordering_array =
392
0
            StructArray::try_new(struct_field, column_wise_ordering_values, None)?;
393
394
0
        Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable(
395
0
            Arc::new(ordering_array),
396
0
        ))))
397
0
    }
398
399
0
    fn evaluate_values(&self) -> ScalarValue {
400
0
        let mut values_cloned = self.values.clone();
401
0
        let values_slice = values_cloned.make_contiguous();
402
0
        ScalarValue::List(ScalarValue::new_list_nullable(
403
0
            values_slice,
404
0
            &self.datatypes[0],
405
0
        ))
406
0
    }
407
408
    /// Updates the state with the `values`. `fetch` holds the number of entries still missing for the state to be complete.
409
    /// `None` means that all of the new `values` need to be added to the state.
410
0
    fn append_new_data(
411
0
        &mut self,
412
0
        values: &[ArrayRef],
413
0
        fetch: Option<usize>,
414
0
    ) -> Result<()> {
415
0
        let n_row = values[0].len();
416
0
        let n_to_add = if let Some(fetch) = fetch {
417
0
            std::cmp::min(fetch, n_row)
418
        } else {
419
0
            n_row
420
        };
421
0
        for index in 0..n_to_add {
422
0
            let row = get_row_at_idx(values, index)?;
423
0
            self.values.push_back(row[0].clone());
424
0
            // At index 1, we have the `n` index argument.
425
0
            // Ordering values span from index 2 to the end
426
0
            self.ordering_values.push_back(row[2..].to_vec());
427
        }
428
0
        Ok(())
429
0
    }
430
}