Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use arrow::{
    array::ArrayRef,
    compute::{and, filter, is_not_null},
    datatypes::{DataType, Field},
};

use datafusion_common::ScalarValue;
use datafusion_common::{not_impl_err, plan_err, Result};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature};
use datafusion_functions_aggregate_common::tdigest::{
    Centroid, TDigest, DEFAULT_MAX_SIZE,
};

use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
38
39
// Registers `ApproxPercentileContWithWeight` under the SQL-facing name
// `approx_percentile_cont_with_weight`, with arguments `expression weight percentile`.
// NOTE(review): the macro presumably also generates the expression-builder
// function and the `approx_percentile_cont_with_weight_udaf` accessor — confirm
// against the macro definition.
make_udaf_expr_and_func!(
40
    ApproxPercentileContWithWeight,
41
    approx_percentile_cont_with_weight,
42
    expression weight percentile,
43
    "Computes the approximate percentile continuous with weight of a set of numbers",
44
    approx_percentile_cont_with_weight_udaf
45
);
46
47
/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
///
/// Takes `(value, weight, percentile)` arguments and reuses the unweighted
/// [`ApproxPercentileCont`] implementation for accumulator creation and state layout.
48
pub struct ApproxPercentileContWithWeight {
49
    /// Accepted argument types, built in `new`.
    signature: Signature,
50
    /// Inner unweighted implementation; `accumulator` and `state_fields` delegate to it.
    approx_percentile_cont: ApproxPercentileCont,
51
}
52
53
impl Debug for ApproxPercentileContWithWeight {
54
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
55
0
        f.debug_struct("ApproxPercentileContWithWeight")
56
0
            .field("signature", &self.signature)
57
0
            .finish()
58
0
    }
59
}
60
61
impl Default for ApproxPercentileContWithWeight {
62
0
    fn default() -> Self {
63
0
        Self::new()
64
0
    }
65
}
66
67
impl ApproxPercentileContWithWeight {
68
    /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
69
0
    pub fn new() -> Self {
70
0
        Self {
71
0
            signature: Signature::one_of(
72
0
                // Accept any numeric value paired with a float64 percentile
73
0
                NUMERICS
74
0
                    .iter()
75
0
                    .map(|t| {
76
0
                        TypeSignature::Exact(vec![
77
0
                            t.clone(),
78
0
                            t.clone(),
79
0
                            DataType::Float64,
80
0
                        ])
81
0
                    })
82
0
                    .collect(),
83
0
                Immutable,
84
0
            ),
85
0
            approx_percentile_cont: ApproxPercentileCont::new(),
86
0
        }
87
0
    }
88
}
89
90
impl AggregateUDFImpl for ApproxPercentileContWithWeight {
91
0
    fn as_any(&self) -> &dyn Any {
92
0
        self
93
0
    }
94
95
0
    fn name(&self) -> &str {
96
0
        "approx_percentile_cont_with_weight"
97
0
    }
98
99
0
    fn signature(&self) -> &Signature {
100
0
        &self.signature
101
0
    }
102
103
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
104
0
        if !arg_types[0].is_numeric() {
105
0
            return plan_err!(
106
0
                "approx_percentile_cont_with_weight requires numeric input types"
107
0
            );
108
0
        }
109
0
        if !arg_types[1].is_numeric() {
110
0
            return plan_err!(
111
0
                "approx_percentile_cont_with_weight requires numeric weight input types"
112
0
            );
113
0
        }
114
0
        if arg_types[2] != DataType::Float64 {
115
0
            return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
116
0
        }
117
0
        Ok(arg_types[0].clone())
118
0
    }
119
120
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
121
0
        if acc_args.is_distinct {
122
0
            return not_impl_err!(
123
0
                "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
124
0
            );
125
0
        }
126
0
127
0
        if acc_args.exprs.len() != 3 {
128
0
            return plan_err!(
129
0
                "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile"
130
0
            );
131
0
        }
132
0
133
0
        let sub_args = AccumulatorArgs {
134
0
            exprs: &[
135
0
                Arc::clone(&acc_args.exprs[0]),
136
0
                Arc::clone(&acc_args.exprs[2]),
137
0
            ],
138
0
            ..acc_args
139
0
        };
140
0
        let approx_percentile_cont_accumulator =
141
0
            self.approx_percentile_cont.create_accumulator(sub_args)?;
142
0
        let accumulator = ApproxPercentileWithWeightAccumulator::new(
143
0
            approx_percentile_cont_accumulator,
144
0
        );
145
0
        Ok(Box::new(accumulator))
146
0
    }
147
148
    #[allow(rustdoc::private_intra_doc_links)]
149
    /// See [`TDigest::to_scalar_state()`] for a description of the serialised
150
    /// state.
151
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
152
0
        self.approx_percentile_cont.state_fields(args)
153
0
    }
154
}
155
156
/// Accumulator for `approx_percentile_cont_with_weight`.
///
/// Only `update_batch` is weight-aware. State, evaluation and merging
/// delegate to the wrapped unweighted accumulator.
#[derive(Debug)]
157
pub struct ApproxPercentileWithWeightAccumulator {
158
    /// Inner accumulator whose digest receives the weighted centroids.
    approx_percentile_cont_accumulator: ApproxPercentileAccumulator,
159
}
160
161
impl ApproxPercentileWithWeightAccumulator {
162
0
    pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self {
163
0
        Self {
164
0
            approx_percentile_cont_accumulator,
165
0
        }
166
0
    }
167
}
168
169
impl Accumulator for ApproxPercentileWithWeightAccumulator {
170
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
171
0
        self.approx_percentile_cont_accumulator.state()
172
0
    }
173
174
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
175
0
        let means = &values[0];
176
0
        let weights = &values[1];
177
0
        debug_assert_eq!(
178
0
            means.len(),
179
0
            weights.len(),
180
0
            "invalid number of values in means and weights"
181
        );
182
0
        let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?;
183
0
        let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?;
184
0
        let mut digests: Vec<TDigest> = vec![];
185
0
        for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
186
0
            digests.push(TDigest::new_with_centroid(
187
0
                DEFAULT_MAX_SIZE,
188
0
                Centroid::new(*mean, *weight),
189
0
            ))
190
        }
191
0
        self.approx_percentile_cont_accumulator
192
0
            .merge_digests(&digests);
193
0
        Ok(())
194
0
    }
195
196
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
197
0
        self.approx_percentile_cont_accumulator.evaluate()
198
0
    }
199
200
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
201
0
        self.approx_percentile_cont_accumulator
202
0
            .merge_batch(states)?;
203
204
0
        Ok(())
205
0
    }
206
207
0
    fn size(&self) -> usize {
208
0
        std::mem::size_of_val(self)
209
0
            - std::mem::size_of_val(&self.approx_percentile_cont_accumulator)
210
0
            + self.approx_percentile_cont_accumulator.size()
211
0
    }
212
}