Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/median.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::collections::HashSet;
19
use std::fmt::Formatter;
20
use std::{fmt::Debug, sync::Arc};
21
22
use arrow::array::{downcast_integer, ArrowNumericType};
23
use arrow::{
24
    array::{ArrayRef, AsArray},
25
    datatypes::{
26
        DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type,
27
        Float64Type,
28
    },
29
};
30
31
use arrow::array::Array;
32
use arrow::array::ArrowNativeTypeOp;
33
use arrow::datatypes::ArrowNativeType;
34
35
use datafusion_common::{DataFusionError, Result, ScalarValue};
36
use datafusion_expr::function::StateFieldsArgs;
37
use datafusion_expr::{
38
    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
39
    Signature, Volatility,
40
};
41
use datafusion_functions_aggregate_common::utils::Hashable;
42
43
make_udaf_expr_and_func!(
44
    Median,
45
    median,
46
    expression,
47
    "Computes the median of a set of numbers",
48
    median_udaf
49
);
50
51
/// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a
52
/// lot of memory because all values need to be stored in memory before a result can be
53
/// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more
54
/// efficient solution.
55
///
56
/// If using the distinct variation, the memory usage will be similarly high if the
57
/// cardinality is high as it stores all distinct values in memory before computing the
58
/// result, but if cardinality is low then memory usage will also be lower.
59
pub struct Median {
60
    signature: Signature,
61
}
62
63
impl Debug for Median {
64
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
65
0
        f.debug_struct("Median")
66
0
            .field("name", &self.name())
67
0
            .field("signature", &self.signature)
68
0
            .finish()
69
0
    }
70
}
71
72
impl Default for Median {
73
1
    fn default() -> Self {
74
1
        Self::new()
75
1
    }
76
}
77
78
impl Median {
79
1
    pub fn new() -> Self {
80
1
        Self {
81
1
            signature: Signature::numeric(1, Volatility::Immutable),
82
1
        }
83
1
    }
84
}
85
86
impl AggregateUDFImpl for Median {
87
0
    fn as_any(&self) -> &dyn std::any::Any {
88
0
        self
89
0
    }
90
91
1
    fn name(&self) -> &str {
92
1
        "median"
93
1
    }
94
95
1
    fn signature(&self) -> &Signature {
96
1
        &self.signature
97
1
    }
98
99
1
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100
1
        Ok(arg_types[0].clone())
101
1
    }
102
103
1
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
104
1
        //Intermediate state is a list of the elements we have collected so far
105
1
        let field = Field::new("item", args.input_types[0].clone(), true);
106
1
        let state_name = if args.is_distinct {
107
0
            "distinct_median"
108
        } else {
109
1
            "median"
110
        };
111
112
1
        Ok(vec![Field::new(
113
1
            format_state_name(args.name, state_name),
114
1
            DataType::List(Arc::new(field)),
115
1
            true,
116
1
        )])
117
1
    }
118
119
1
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
120
        macro_rules! helper {
121
            ($t:ty, $dt:expr) => {
122
                if acc_args.is_distinct {
123
                    Ok(Box::new(DistinctMedianAccumulator::<$t> {
124
                        data_type: $dt.clone(),
125
                        distinct_values: HashSet::new(),
126
                    }))
127
                } else {
128
                    Ok(Box::new(MedianAccumulator::<$t> {
129
                        data_type: $dt.clone(),
130
                        all_values: vec![],
131
                    }))
132
                }
133
            };
134
        }
135
136
1
        let dt = acc_args.exprs[0].data_type(acc_args.schema)
?0
;
137
0
        downcast_integer! {
138
1
            dt => (helper, 
dt0
),
139
0
            DataType::Float16 => helper!(Float16Type, dt),
140
0
            DataType::Float32 => helper!(Float32Type, dt),
141
0
            DataType::Float64 => helper!(Float64Type, dt),
142
0
            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
143
0
            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
144
0
            _ => Err(DataFusionError::NotImplemented(format!(
145
0
                "MedianAccumulator not supported for {} with {}",
146
0
                acc_args.name,
147
0
                dt,
148
0
            ))),
149
        }
150
1
    }
151
152
0
    fn aliases(&self) -> &[String] {
153
0
        &[]
154
0
    }
155
}
156
157
/// The median accumulator accumulates the raw input values
158
/// as `ScalarValue`s
159
///
160
/// The intermediate state is represented as a List of scalar values updated by
161
/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
162
/// in the final evaluation step so that we avoid expensive conversions and
163
/// allocations during `update_batch`.
164
struct MedianAccumulator<T: ArrowNumericType> {
165
    data_type: DataType,
166
    all_values: Vec<T::Native>,
167
}
168
169
impl<T: ArrowNumericType> std::fmt::Debug for MedianAccumulator<T> {
170
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
171
0
        write!(f, "MedianAccumulator({})", self.data_type)
172
0
    }
173
}
174
175
impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
176
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
177
0
        let all_values = self
178
0
            .all_values
179
0
            .iter()
180
0
            .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &self.data_type))
181
0
            .collect::<Result<Vec<_>>>()?;
182
183
0
        let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
184
0
        Ok(vec![ScalarValue::List(arr)])
185
0
    }
186
187
1
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
188
1
        let values = values[0].as_primitive::<T>();
189
1
        self.all_values.reserve(values.len() - values.null_count());
190
1
        self.all_values.extend(values.iter().flatten());
191
1
        Ok(())
192
1
    }
193
194
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
195
0
        let array = states[0].as_list::<i32>();
196
0
        for v in array.iter().flatten() {
197
0
            self.update_batch(&[v])?
198
        }
199
0
        Ok(())
200
0
    }
201
202
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
203
0
        let d = std::mem::take(&mut self.all_values);
204
0
        let median = calculate_median::<T>(d);
205
0
        ScalarValue::new_primitive::<T>(median, &self.data_type)
206
0
    }
207
208
2
    fn size(&self) -> usize {
209
2
        std::mem::size_of_val(self)
210
2
            + self.all_values.capacity() * std::mem::size_of::<T::Native>()
211
2
    }
212
}
213
214
/// The distinct median accumulator accumulates the raw input values
215
/// as `ScalarValue`s
216
///
217
/// The intermediate state is represented as a List of scalar values updated by
218
/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values
219
/// in the final evaluation step so that we avoid expensive conversions and
220
/// allocations during `update_batch`.
221
struct DistinctMedianAccumulator<T: ArrowNumericType> {
222
    data_type: DataType,
223
    distinct_values: HashSet<Hashable<T::Native>>,
224
}
225
226
impl<T: ArrowNumericType> std::fmt::Debug for DistinctMedianAccumulator<T> {
227
0
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
228
0
        write!(f, "DistinctMedianAccumulator({})", self.data_type)
229
0
    }
230
}
231
232
impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> {
233
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
234
0
        let all_values = self
235
0
            .distinct_values
236
0
            .iter()
237
0
            .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type))
238
0
            .collect::<Result<Vec<_>>>()?;
239
240
0
        let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type);
241
0
        Ok(vec![ScalarValue::List(arr)])
242
0
    }
243
244
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
245
0
        if values.is_empty() {
246
0
            return Ok(());
247
0
        }
248
0
249
0
        let array = values[0].as_primitive::<T>();
250
0
        match array.nulls().filter(|x| x.null_count() > 0) {
251
0
            Some(n) => {
252
0
                for idx in n.valid_indices() {
253
0
                    self.distinct_values.insert(Hashable(array.value(idx)));
254
0
                }
255
            }
256
0
            None => array.values().iter().for_each(|x| {
257
0
                self.distinct_values.insert(Hashable(*x));
258
0
            }),
259
        }
260
0
        Ok(())
261
0
    }
262
263
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
264
0
        let array = states[0].as_list::<i32>();
265
0
        for v in array.iter().flatten() {
266
0
            self.update_batch(&[v])?
267
        }
268
0
        Ok(())
269
0
    }
270
271
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
272
0
        let d = std::mem::take(&mut self.distinct_values)
273
0
            .into_iter()
274
0
            .map(|v| v.0)
275
0
            .collect::<Vec<_>>();
276
0
        let median = calculate_median::<T>(d);
277
0
        ScalarValue::new_primitive::<T>(median, &self.data_type)
278
0
    }
279
280
0
    fn size(&self) -> usize {
281
0
        std::mem::size_of_val(self)
282
0
            + self.distinct_values.capacity() * std::mem::size_of::<T::Native>()
283
0
    }
284
}
285
286
0
fn calculate_median<T: ArrowNumericType>(
287
0
    mut values: Vec<T::Native>,
288
0
) -> Option<T::Native> {
289
0
    let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);
290
291
0
    let len = values.len();
292
0
    if len == 0 {
293
0
        None
294
0
    } else if len % 2 == 0 {
295
0
        let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
296
0
        let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp);
297
0
        let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2));
298
0
        Some(median)
299
    } else {
300
0
        let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
301
0
        Some(*median)
302
    }
303
0
}