Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/approx_percentile_cont.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::any::Any;
19
use std::fmt::{Debug, Formatter};
20
use std::sync::Arc;
21
22
use arrow::array::{Array, RecordBatch};
23
use arrow::compute::{filter, is_not_null};
24
use arrow::{
25
    array::{
26
        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
27
        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
28
    },
29
    datatypes::DataType,
30
};
31
use arrow_schema::{Field, Schema};
32
33
use datafusion_common::{
34
    downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
35
    DataFusionError, Result, ScalarValue,
36
};
37
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
38
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
39
use datafusion_expr::utils::format_state_name;
40
use datafusion_expr::{
41
    Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature,
42
    Volatility,
43
};
44
use datafusion_functions_aggregate_common::tdigest::{
45
    TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
46
};
47
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
48
49
create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);
50
51
/// Computes the approximate percentile continuous of a set of numbers
52
0
pub fn approx_percentile_cont(
53
0
    expression: Expr,
54
0
    percentile: Expr,
55
0
    centroids: Option<Expr>,
56
0
) -> Expr {
57
0
    let args = if let Some(centroids) = centroids {
58
0
        vec![expression, percentile, centroids]
59
    } else {
60
0
        vec![expression, percentile]
61
    };
62
0
    approx_percentile_cont_udaf().call(args)
63
0
}
64
65
/// Aggregate function registered under the name `approx_percentile_cont`
/// (see the `AggregateUDFImpl` implementation in this file).
pub struct ApproxPercentileCont {
66
    /// Accepted argument type combinations, built in `ApproxPercentileCont::new`.
    signature: Signature,
67
}
68
69
impl Debug for ApproxPercentileCont {
70
0
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
71
0
        f.debug_struct("ApproxPercentileCont")
72
0
            .field("name", &self.name())
73
0
            .field("signature", &self.signature)
74
0
            .finish()
75
0
    }
76
}
77
78
impl Default for ApproxPercentileCont {
79
0
    fn default() -> Self {
80
0
        Self::new()
81
0
    }
82
}
83
84
impl ApproxPercentileCont {
85
    /// Create a new [`ApproxPercentileCont`] aggregate function.
86
0
    pub fn new() -> Self {
87
0
        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
88
        // Accept any numeric value paired with a float64 percentile
89
0
        for num in NUMERICS {
90
0
            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
91
            // Additionally accept an integer number of centroids for T-Digest
92
0
            for int in INTEGERS {
93
0
                variants.push(TypeSignature::Exact(vec![
94
0
                    num.clone(),
95
0
                    DataType::Float64,
96
0
                    int.clone(),
97
0
                ]))
98
            }
99
        }
100
0
        Self {
101
0
            signature: Signature::one_of(variants, Volatility::Immutable),
102
0
        }
103
0
    }
104
105
0
    pub(crate) fn create_accumulator(
106
0
        &self,
107
0
        args: AccumulatorArgs,
108
0
    ) -> Result<ApproxPercentileAccumulator> {
109
0
        let percentile = validate_input_percentile_expr(&args.exprs[1])?;
110
0
        let tdigest_max_size = if args.exprs.len() == 3 {
111
0
            Some(validate_input_max_size_expr(&args.exprs[2])?)
112
        } else {
113
0
            None
114
        };
115
116
0
        let data_type = args.exprs[0].data_type(args.schema)?;
117
0
        let accumulator: ApproxPercentileAccumulator = match data_type {
118
0
            t @ (DataType::UInt8
119
            | DataType::UInt16
120
            | DataType::UInt32
121
            | DataType::UInt64
122
            | DataType::Int8
123
            | DataType::Int16
124
            | DataType::Int32
125
            | DataType::Int64
126
            | DataType::Float32
127
            | DataType::Float64) => {
128
0
                if let Some(max_size) = tdigest_max_size {
129
0
                    ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
130
                }else{
131
0
                    ApproxPercentileAccumulator::new(percentile, t)
132
133
                }
134
            }
135
0
            other => {
136
0
                return not_impl_err!(
137
0
                    "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
138
0
                )
139
            }
140
        };
141
142
0
        Ok(accumulator)
143
0
    }
144
}
145
146
0
fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
147
0
    let empty_schema = Arc::new(Schema::empty());
148
0
    let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
149
0
    if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
150
0
        Ok(s)
151
    } else {
152
0
        internal_err!("Didn't expect ColumnarValue::Array")
153
    }
154
0
}
155
156
0
fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
157
0
    let percentile = match get_scalar_value(expr)
158
0
        .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
159
0
        ScalarValue::Float32(Some(value)) => {
160
0
            value as f64
161
        }
162
0
        ScalarValue::Float64(Some(value)) => {
163
0
            value
164
        }
165
0
        sv => {
166
0
            return not_impl_err!(
167
0
                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
168
0
                sv.data_type()
169
0
            )
170
        }
171
    };
172
173
    // Ensure the percentile is between 0 and 1.
174
0
    if !(0.0..=1.0).contains(&percentile) {
175
0
        return plan_err!(
176
0
            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
177
0
        );
178
0
    }
179
0
    Ok(percentile)
180
0
}
181
182
0
fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
183
0
    let max_size = match get_scalar_value(expr)
184
0
        .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
185
0
        ScalarValue::UInt8(Some(q)) => q as usize,
186
0
        ScalarValue::UInt16(Some(q)) => q as usize,
187
0
        ScalarValue::UInt32(Some(q)) => q as usize,
188
0
        ScalarValue::UInt64(Some(q)) => q as usize,
189
0
        ScalarValue::Int32(Some(q)) if q > 0 => q as usize,
190
0
        ScalarValue::Int64(Some(q)) if q > 0 => q as usize,
191
0
        ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
192
0
        ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
193
0
        sv => {
194
0
            return not_impl_err!(
195
0
                "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
196
0
                sv.data_type()
197
0
            )
198
        },
199
    };
200
201
0
    Ok(max_size)
202
0
}
203
204
impl AggregateUDFImpl for ApproxPercentileCont {
205
0
    fn as_any(&self) -> &dyn Any {
206
0
        self
207
0
    }
208
209
    #[allow(rustdoc::private_intra_doc_links)]
210
    /// See [`TDigest::to_scalar_state()`] for a description of the serialised
211
    /// state.
212
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
213
0
        Ok(vec![
214
0
            Field::new(
215
0
                format_state_name(args.name, "max_size"),
216
0
                DataType::UInt64,
217
0
                false,
218
0
            ),
219
0
            Field::new(
220
0
                format_state_name(args.name, "sum"),
221
0
                DataType::Float64,
222
0
                false,
223
0
            ),
224
0
            Field::new(
225
0
                format_state_name(args.name, "count"),
226
0
                DataType::UInt64,
227
0
                false,
228
0
            ),
229
0
            Field::new(
230
0
                format_state_name(args.name, "max"),
231
0
                DataType::Float64,
232
0
                false,
233
0
            ),
234
0
            Field::new(
235
0
                format_state_name(args.name, "min"),
236
0
                DataType::Float64,
237
0
                false,
238
0
            ),
239
0
            Field::new_list(
240
0
                format_state_name(args.name, "centroids"),
241
0
                Field::new("item", DataType::Float64, true),
242
0
                false,
243
0
            ),
244
0
        ])
245
0
    }
246
247
0
    fn name(&self) -> &str {
248
0
        "approx_percentile_cont"
249
0
    }
250
251
0
    fn signature(&self) -> &Signature {
252
0
        &self.signature
253
0
    }
254
255
    #[inline]
256
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
257
0
        Ok(Box::new(self.create_accumulator(acc_args)?))
258
0
    }
259
260
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
261
0
        if !arg_types[0].is_numeric() {
262
0
            return plan_err!("approx_percentile_cont requires numeric input types");
263
0
        }
264
0
        if arg_types.len() == 3 && !arg_types[2].is_integer() {
265
0
            return plan_err!(
266
0
                "approx_percentile_cont requires integer max_size input types"
267
0
            );
268
0
        }
269
0
        Ok(arg_types[0].clone())
270
0
    }
271
}
272
273
/// Accumulator for `approx_percentile_cont`, backed by a [`TDigest`].
#[derive(Debug)]
274
pub struct ApproxPercentileAccumulator {
275
    /// Digest updated by `update_batch` / `merge_batch` and queried by `evaluate`.
    digest: TDigest,
276
    /// Quantile passed to `TDigest::estimate_quantile` in `evaluate`.
    percentile: f64,
277
    /// Data type of the scalar produced by `evaluate`.
    return_type: DataType,
278
}
279
280
impl ApproxPercentileAccumulator {
281
0
    pub fn new(percentile: f64, return_type: DataType) -> Self {
282
0
        Self {
283
0
            digest: TDigest::new(DEFAULT_MAX_SIZE),
284
0
            percentile,
285
0
            return_type,
286
0
        }
287
0
    }
288
289
0
    pub fn new_with_max_size(
290
0
        percentile: f64,
291
0
        return_type: DataType,
292
0
        max_size: usize,
293
0
    ) -> Self {
294
0
        Self {
295
0
            digest: TDigest::new(max_size),
296
0
            percentile,
297
0
            return_type,
298
0
        }
299
0
    }
300
301
    // public for approx_percentile_cont_with_weight
302
0
    pub fn merge_digests(&mut self, digests: &[TDigest]) {
303
0
        let digests = digests.iter().chain(std::iter::once(&self.digest));
304
0
        self.digest = TDigest::merge_digests(digests)
305
0
    }
306
307
    // public for approx_percentile_cont_with_weight
308
0
    pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
309
0
        match values.data_type() {
310
            DataType::Float64 => {
311
0
                let array = downcast_value!(values, Float64Array);
312
0
                Ok(array
313
0
                    .values()
314
0
                    .iter()
315
0
                    .filter_map(|v| v.try_as_f64().transpose())
316
0
                    .collect::<Result<Vec<_>>>()?)
317
            }
318
            DataType::Float32 => {
319
0
                let array = downcast_value!(values, Float32Array);
320
0
                Ok(array
321
0
                    .values()
322
0
                    .iter()
323
0
                    .filter_map(|v| v.try_as_f64().transpose())
324
0
                    .collect::<Result<Vec<_>>>()?)
325
            }
326
            DataType::Int64 => {
327
0
                let array = downcast_value!(values, Int64Array);
328
0
                Ok(array
329
0
                    .values()
330
0
                    .iter()
331
0
                    .filter_map(|v| v.try_as_f64().transpose())
332
0
                    .collect::<Result<Vec<_>>>()?)
333
            }
334
            DataType::Int32 => {
335
0
                let array = downcast_value!(values, Int32Array);
336
0
                Ok(array
337
0
                    .values()
338
0
                    .iter()
339
0
                    .filter_map(|v| v.try_as_f64().transpose())
340
0
                    .collect::<Result<Vec<_>>>()?)
341
            }
342
            DataType::Int16 => {
343
0
                let array = downcast_value!(values, Int16Array);
344
0
                Ok(array
345
0
                    .values()
346
0
                    .iter()
347
0
                    .filter_map(|v| v.try_as_f64().transpose())
348
0
                    .collect::<Result<Vec<_>>>()?)
349
            }
350
            DataType::Int8 => {
351
0
                let array = downcast_value!(values, Int8Array);
352
0
                Ok(array
353
0
                    .values()
354
0
                    .iter()
355
0
                    .filter_map(|v| v.try_as_f64().transpose())
356
0
                    .collect::<Result<Vec<_>>>()?)
357
            }
358
            DataType::UInt64 => {
359
0
                let array = downcast_value!(values, UInt64Array);
360
0
                Ok(array
361
0
                    .values()
362
0
                    .iter()
363
0
                    .filter_map(|v| v.try_as_f64().transpose())
364
0
                    .collect::<Result<Vec<_>>>()?)
365
            }
366
            DataType::UInt32 => {
367
0
                let array = downcast_value!(values, UInt32Array);
368
0
                Ok(array
369
0
                    .values()
370
0
                    .iter()
371
0
                    .filter_map(|v| v.try_as_f64().transpose())
372
0
                    .collect::<Result<Vec<_>>>()?)
373
            }
374
            DataType::UInt16 => {
375
0
                let array = downcast_value!(values, UInt16Array);
376
0
                Ok(array
377
0
                    .values()
378
0
                    .iter()
379
0
                    .filter_map(|v| v.try_as_f64().transpose())
380
0
                    .collect::<Result<Vec<_>>>()?)
381
            }
382
            DataType::UInt8 => {
383
0
                let array = downcast_value!(values, UInt8Array);
384
0
                Ok(array
385
0
                    .values()
386
0
                    .iter()
387
0
                    .filter_map(|v| v.try_as_f64().transpose())
388
0
                    .collect::<Result<Vec<_>>>()?)
389
            }
390
0
            e => internal_err!(
391
0
                "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
392
0
            ),
393
        }
394
0
    }
395
}
396
397
impl Accumulator for ApproxPercentileAccumulator {
398
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
399
0
        Ok(self.digest.to_scalar_state().into_iter().collect())
400
0
    }
401
402
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
403
0
        // Remove any nulls before computing the percentile
404
0
        let mut values = Arc::clone(&values[0]);
405
0
        if values.nulls().is_some() {
406
0
            values = filter(&values, &is_not_null(&values)?)?;
407
0
        }
408
0
        let sorted_values = &arrow::compute::sort(&values, None)?;
409
0
        let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
410
0
        self.digest = self.digest.merge_sorted_f64(&sorted_values);
411
0
        Ok(())
412
0
    }
413
414
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
415
0
        if self.digest.count() == 0 {
416
0
            return ScalarValue::try_from(self.return_type.clone());
417
0
        }
418
0
        let q = self.digest.estimate_quantile(self.percentile);
419
0
420
0
        // These acceptable return types MUST match the validation in
421
0
        // ApproxPercentile::create_accumulator.
422
0
        Ok(match &self.return_type {
423
0
            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
424
0
            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
425
0
            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
426
0
            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
427
0
            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
428
0
            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
429
0
            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
430
0
            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
431
0
            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
432
0
            DataType::Float64 => ScalarValue::Float64(Some(q)),
433
0
            v => unreachable!("unexpected return type {:?}", v),
434
        })
435
0
    }
436
437
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
438
0
        if states.is_empty() {
439
0
            return Ok(());
440
0
        }
441
442
0
        let states = (0..states[0].len())
443
0
            .map(|index| {
444
0
                states
445
0
                    .iter()
446
0
                    .map(|array| ScalarValue::try_from_array(array, index))
447
0
                    .collect::<Result<Vec<_>>>()
448
0
                    .map(|state| TDigest::from_scalar_state(&state))
449
0
            })
450
0
            .collect::<Result<Vec<_>>>()?;
451
452
0
        self.merge_digests(&states);
453
0
454
0
        Ok(())
455
0
    }
456
457
0
    fn size(&self) -> usize {
458
0
        std::mem::size_of_val(self) + self.digest.size()
459
0
            - std::mem::size_of_val(&self.digest)
460
0
            + self.return_type.size()
461
0
            - std::mem::size_of_val(&self.return_type)
462
0
    }
463
}
464
465
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging two digests of 50_000 values each into a fresh accumulator
    /// must yield counts of 50_000 and then 100_000.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // 50 digests, each built from the values 1 to 1_000
        // (50_000 values in total)
        let digests: Vec<TDigest> = (1..=50)
            .map(|_| {
                let values: Vec<_> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);

        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}