Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/approx_median.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines physical expressions for APPROX_MEDIAN that can be evaluated at runtime during query execution
19
20
use std::any::Any;
21
use std::fmt::Debug;
22
23
use arrow::{datatypes::DataType, datatypes::Field};
24
use arrow_schema::DataType::{Float64, UInt64};
25
26
use datafusion_common::{not_impl_err, plan_err, Result};
27
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
28
use datafusion_expr::type_coercion::aggregates::NUMERICS;
29
use datafusion_expr::utils::format_state_name;
30
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
31
32
use crate::approx_percentile_cont::ApproxPercentileAccumulator;
33
34
// Invokes the `make_udaf_expr_and_func!` macro for `ApproxMedian`, passing the
// identifiers `approx_median`, `expression` and `approx_median_udaf` together
// with a one-line description string.
// NOTE(review): the macro is defined outside this file; presumably it generates
// an `approx_median(expression)` expression helper and an `approx_median_udaf()`
// accessor -- confirm against the macro definition.
make_udaf_expr_and_func!(
35
    ApproxMedian,
36
    approx_median,
37
    expression,
38
    "Computes the approximate median of a set of numbers",
39
    approx_median_udaf
40
);
41
42
/// APPROX_MEDIAN aggregate expression
///
/// The only state kept on the type itself is the [`Signature`] built in
/// [`ApproxMedian::new`]; the aggregation work is delegated to an
/// `ApproxPercentileAccumulator` created with percentile `0.5`.
43
pub struct ApproxMedian {
44
    /// Argument signature handed out by `signature()`.
    signature: Signature,
45
}
46
47
// Hand-written `Debug` implementation: besides the stored `signature` it also
// prints a `name` entry, which is not a struct field but the value of `self.name()`.
impl Debug for ApproxMedian {
48
0
    /// Writes the value as a debug struct called `ApproxMedian` with the
    /// entries `name` and `signature`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49
0
        f.debug_struct("ApproxMedian")
50
0
            .field("name", &self.name())
51
0
            .field("signature", &self.signature)
52
0
            .finish()
53
0
    }
54
}
55
56
impl Default for ApproxMedian {
57
0
    /// Returns the same value as [`ApproxMedian::new`].
    fn default() -> Self {
58
0
        Self::new()
59
0
    }
60
}
61
62
impl ApproxMedian {
63
    /// Create a new APPROX_MEDIAN aggregate function
    ///
    /// The signature is built with `Signature::uniform` for a single argument
    /// whose accepted types are the entries of `NUMERICS`, and is marked
    /// `Volatility::Immutable`.
64
0
    pub fn new() -> Self {
65
0
        Self {
66
0
            signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
67
0
        }
68
0
    }
69
}
70
71
// `AggregateUDFImpl` implementation for APPROX_MEDIAN: reports the function's
// name, signature, return type and state layout, and builds its accumulator.
impl AggregateUDFImpl for ApproxMedian {
72
    /// Return a reference to Any that can be used for downcasting
73
0
    fn as_any(&self) -> &dyn Any {
74
0
        self
75
0
    }
76
77
0
    /// Describes the intermediate state as six non-nullable fields, each named
    /// via `format_state_name(args.name, ..)`:
    /// `max_size` (UInt64), `sum` (Float64), `count` (UInt64), `max` (Float64),
    /// `min` (Float64) and `centroids` (a list whose `item` entries are
    /// nullable Float64).
    ///
    /// NOTE(review): presumably this layout has to match the state produced by
    /// `ApproxPercentileAccumulator` -- confirm against that type.
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
78
0
        Ok(vec![
79
0
            Field::new(format_state_name(args.name, "max_size"), UInt64, false),
80
0
            Field::new(format_state_name(args.name, "sum"), Float64, false),
81
0
            Field::new(format_state_name(args.name, "count"), UInt64, false),
82
0
            Field::new(format_state_name(args.name, "max"), Float64, false),
83
0
            Field::new(format_state_name(args.name, "min"), Float64, false),
84
0
            Field::new_list(
85
0
                format_state_name(args.name, "centroids"),
86
0
                Field::new("item", Float64, true),
87
0
                false,
88
0
            ),
89
0
        ])
90
0
    }
91
92
0
    /// Returns the function name, `"approx_median"`.
    fn name(&self) -> &str {
93
0
        "approx_median"
94
0
    }
95
96
0
    /// Returns a reference to the signature stored at construction time.
    fn signature(&self) -> &Signature {
97
0
        &self.signature
98
0
    }
99
100
0
    /// Returns a clone of the first argument's type, so the result type equals
    /// the input type.
    ///
    /// # Errors
    /// Returns a plan error when the first argument type is not numeric.
    ///
    /// NOTE(review): `arg_types[0]` is indexed without a length check and would
    /// panic on an empty slice; presumably the one-argument signature rules
    /// that out before this is called -- confirm.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
101
0
        if !arg_types[0].is_numeric() {
102
0
            return plan_err!("ApproxMedian requires numeric input types");
103
0
        }
104
0
        Ok(arg_types[0].clone())
105
0
    }
106
107
0
    /// Builds the accumulator: an `ApproxPercentileAccumulator` constructed
    /// with percentile `0.5` and the data type of the first input expression,
    /// resolved against `acc_args.schema`.
    ///
    /// # Errors
    /// Returns a not-implemented error when `acc_args.is_distinct` is set, and
    /// propagates any error from resolving the input expression's data type.
    ///
    /// NOTE(review): `acc_args.exprs[0]` is indexed without a length check --
    /// presumably at least one input expression is always supplied; confirm.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
108
0
        if acc_args.is_distinct {
109
0
            return not_impl_err!(
110
0
                "APPROX_MEDIAN(DISTINCT) aggregations are not available"
111
0
            );
112
0
        }
113
0
114
0
        Ok(Box::new(ApproxPercentileAccumulator::new(
115
0
            // The median is the 50th percentile.
            0.5_f64,
116
0
            acc_args.exprs[0].data_type(acc_args.schema)?,
117
        )))
118
0
    }
119
}