/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/median.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::collections::HashSet; |
19 | | use std::fmt::Formatter; |
20 | | use std::{fmt::Debug, sync::Arc}; |
21 | | |
22 | | use arrow::array::{downcast_integer, ArrowNumericType}; |
23 | | use arrow::{ |
24 | | array::{ArrayRef, AsArray}, |
25 | | datatypes::{ |
26 | | DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type, |
27 | | Float64Type, |
28 | | }, |
29 | | }; |
30 | | |
31 | | use arrow::array::Array; |
32 | | use arrow::array::ArrowNativeTypeOp; |
33 | | use arrow::datatypes::ArrowNativeType; |
34 | | |
35 | | use datafusion_common::{DataFusionError, Result, ScalarValue}; |
36 | | use datafusion_expr::function::StateFieldsArgs; |
37 | | use datafusion_expr::{ |
38 | | function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, |
39 | | Signature, Volatility, |
40 | | }; |
41 | | use datafusion_functions_aggregate_common::utils::Hashable; |
42 | | |
43 | | make_udaf_expr_and_func!( |
44 | | Median, |
45 | | median, |
46 | | expression, |
47 | | "Computes the median of a set of numbers", |
48 | | median_udaf |
49 | | ); |
50 | | |
51 | | /// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a |
52 | | /// lot of memory because all values need to be stored in memory before a result can be |
53 | | /// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more |
54 | | /// efficient solution. |
55 | | /// |
56 | | /// If using the distinct variation, the memory usage will be similarly high if the |
57 | | /// cardinality is high as it stores all distinct values in memory before computing the |
58 | | /// result, but if cardinality is low then memory usage will also be lower. |
59 | | pub struct Median { |
60 | | signature: Signature, |
61 | | } |
62 | | |
63 | | impl Debug for Median { |
64 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
65 | 0 | f.debug_struct("Median") |
66 | 0 | .field("name", &self.name()) |
67 | 0 | .field("signature", &self.signature) |
68 | 0 | .finish() |
69 | 0 | } |
70 | | } |
71 | | |
72 | | impl Default for Median { |
73 | 1 | fn default() -> Self { |
74 | 1 | Self::new() |
75 | 1 | } |
76 | | } |
77 | | |
78 | | impl Median { |
79 | 1 | pub fn new() -> Self { |
80 | 1 | Self { |
81 | 1 | signature: Signature::numeric(1, Volatility::Immutable), |
82 | 1 | } |
83 | 1 | } |
84 | | } |
85 | | |
86 | | impl AggregateUDFImpl for Median { |
87 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
88 | 0 | self |
89 | 0 | } |
90 | | |
91 | 1 | fn name(&self) -> &str { |
92 | 1 | "median" |
93 | 1 | } |
94 | | |
95 | 1 | fn signature(&self) -> &Signature { |
96 | 1 | &self.signature |
97 | 1 | } |
98 | | |
99 | 1 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
100 | 1 | Ok(arg_types[0].clone()) |
101 | 1 | } |
102 | | |
103 | 1 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
104 | 1 | //Intermediate state is a list of the elements we have collected so far |
105 | 1 | let field = Field::new("item", args.input_types[0].clone(), true); |
106 | 1 | let state_name = if args.is_distinct { |
107 | 0 | "distinct_median" |
108 | | } else { |
109 | 1 | "median" |
110 | | }; |
111 | | |
112 | 1 | Ok(vec![Field::new( |
113 | 1 | format_state_name(args.name, state_name), |
114 | 1 | DataType::List(Arc::new(field)), |
115 | 1 | true, |
116 | 1 | )]) |
117 | 1 | } |
118 | | |
119 | 1 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
120 | | macro_rules! helper { |
121 | | ($t:ty, $dt:expr) => { |
122 | | if acc_args.is_distinct { |
123 | | Ok(Box::new(DistinctMedianAccumulator::<$t> { |
124 | | data_type: $dt.clone(), |
125 | | distinct_values: HashSet::new(), |
126 | | })) |
127 | | } else { |
128 | | Ok(Box::new(MedianAccumulator::<$t> { |
129 | | data_type: $dt.clone(), |
130 | | all_values: vec![], |
131 | | })) |
132 | | } |
133 | | }; |
134 | | } |
135 | | |
136 | 1 | let dt = acc_args.exprs[0].data_type(acc_args.schema)?0 ; |
137 | 0 | downcast_integer! { |
138 | 1 | dt => (helper, dt0 ), |
139 | 0 | DataType::Float16 => helper!(Float16Type, dt), |
140 | 0 | DataType::Float32 => helper!(Float32Type, dt), |
141 | 0 | DataType::Float64 => helper!(Float64Type, dt), |
142 | 0 | DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), |
143 | 0 | DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), |
144 | 0 | _ => Err(DataFusionError::NotImplemented(format!( |
145 | 0 | "MedianAccumulator not supported for {} with {}", |
146 | 0 | acc_args.name, |
147 | 0 | dt, |
148 | 0 | ))), |
149 | | } |
150 | 1 | } |
151 | | |
152 | 0 | fn aliases(&self) -> &[String] { |
153 | 0 | &[] |
154 | 0 | } |
155 | | } |
156 | | |
157 | | /// The median accumulator accumulates the raw input values |
158 | | /// as `ScalarValue`s |
159 | | /// |
160 | | /// The intermediate state is represented as a List of scalar values updated by |
161 | | /// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values |
162 | | /// in the final evaluation step so that we avoid expensive conversions and |
163 | | /// allocations during `update_batch`. |
164 | | struct MedianAccumulator<T: ArrowNumericType> { |
165 | | data_type: DataType, |
166 | | all_values: Vec<T::Native>, |
167 | | } |
168 | | |
169 | | impl<T: ArrowNumericType> std::fmt::Debug for MedianAccumulator<T> { |
170 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
171 | 0 | write!(f, "MedianAccumulator({})", self.data_type) |
172 | 0 | } |
173 | | } |
174 | | |
175 | | impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> { |
176 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
177 | 0 | let all_values = self |
178 | 0 | .all_values |
179 | 0 | .iter() |
180 | 0 | .map(|x| ScalarValue::new_primitive::<T>(Some(*x), &self.data_type)) |
181 | 0 | .collect::<Result<Vec<_>>>()?; |
182 | | |
183 | 0 | let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); |
184 | 0 | Ok(vec![ScalarValue::List(arr)]) |
185 | 0 | } |
186 | | |
187 | 1 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
188 | 1 | let values = values[0].as_primitive::<T>(); |
189 | 1 | self.all_values.reserve(values.len() - values.null_count()); |
190 | 1 | self.all_values.extend(values.iter().flatten()); |
191 | 1 | Ok(()) |
192 | 1 | } |
193 | | |
194 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
195 | 0 | let array = states[0].as_list::<i32>(); |
196 | 0 | for v in array.iter().flatten() { |
197 | 0 | self.update_batch(&[v])? |
198 | | } |
199 | 0 | Ok(()) |
200 | 0 | } |
201 | | |
202 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
203 | 0 | let d = std::mem::take(&mut self.all_values); |
204 | 0 | let median = calculate_median::<T>(d); |
205 | 0 | ScalarValue::new_primitive::<T>(median, &self.data_type) |
206 | 0 | } |
207 | | |
208 | 2 | fn size(&self) -> usize { |
209 | 2 | std::mem::size_of_val(self) |
210 | 2 | + self.all_values.capacity() * std::mem::size_of::<T::Native>() |
211 | 2 | } |
212 | | } |
213 | | |
214 | | /// The distinct median accumulator accumulates the raw input values |
215 | | /// as `ScalarValue`s |
216 | | /// |
217 | | /// The intermediate state is represented as a List of scalar values updated by |
218 | | /// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values |
219 | | /// in the final evaluation step so that we avoid expensive conversions and |
220 | | /// allocations during `update_batch`. |
221 | | struct DistinctMedianAccumulator<T: ArrowNumericType> { |
222 | | data_type: DataType, |
223 | | distinct_values: HashSet<Hashable<T::Native>>, |
224 | | } |
225 | | |
226 | | impl<T: ArrowNumericType> std::fmt::Debug for DistinctMedianAccumulator<T> { |
227 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
228 | 0 | write!(f, "DistinctMedianAccumulator({})", self.data_type) |
229 | 0 | } |
230 | | } |
231 | | |
232 | | impl<T: ArrowNumericType> Accumulator for DistinctMedianAccumulator<T> { |
233 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
234 | 0 | let all_values = self |
235 | 0 | .distinct_values |
236 | 0 | .iter() |
237 | 0 | .map(|x| ScalarValue::new_primitive::<T>(Some(x.0), &self.data_type)) |
238 | 0 | .collect::<Result<Vec<_>>>()?; |
239 | | |
240 | 0 | let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); |
241 | 0 | Ok(vec![ScalarValue::List(arr)]) |
242 | 0 | } |
243 | | |
244 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
245 | 0 | if values.is_empty() { |
246 | 0 | return Ok(()); |
247 | 0 | } |
248 | 0 |
|
249 | 0 | let array = values[0].as_primitive::<T>(); |
250 | 0 | match array.nulls().filter(|x| x.null_count() > 0) { |
251 | 0 | Some(n) => { |
252 | 0 | for idx in n.valid_indices() { |
253 | 0 | self.distinct_values.insert(Hashable(array.value(idx))); |
254 | 0 | } |
255 | | } |
256 | 0 | None => array.values().iter().for_each(|x| { |
257 | 0 | self.distinct_values.insert(Hashable(*x)); |
258 | 0 | }), |
259 | | } |
260 | 0 | Ok(()) |
261 | 0 | } |
262 | | |
263 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
264 | 0 | let array = states[0].as_list::<i32>(); |
265 | 0 | for v in array.iter().flatten() { |
266 | 0 | self.update_batch(&[v])? |
267 | | } |
268 | 0 | Ok(()) |
269 | 0 | } |
270 | | |
271 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
272 | 0 | let d = std::mem::take(&mut self.distinct_values) |
273 | 0 | .into_iter() |
274 | 0 | .map(|v| v.0) |
275 | 0 | .collect::<Vec<_>>(); |
276 | 0 | let median = calculate_median::<T>(d); |
277 | 0 | ScalarValue::new_primitive::<T>(median, &self.data_type) |
278 | 0 | } |
279 | | |
280 | 0 | fn size(&self) -> usize { |
281 | 0 | std::mem::size_of_val(self) |
282 | 0 | + self.distinct_values.capacity() * std::mem::size_of::<T::Native>() |
283 | 0 | } |
284 | | } |
285 | | |
286 | 0 | fn calculate_median<T: ArrowNumericType>( |
287 | 0 | mut values: Vec<T::Native>, |
288 | 0 | ) -> Option<T::Native> { |
289 | 0 | let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); |
290 | | |
291 | 0 | let len = values.len(); |
292 | 0 | if len == 0 { |
293 | 0 | None |
294 | 0 | } else if len % 2 == 0 { |
295 | 0 | let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); |
296 | 0 | let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); |
297 | 0 | let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); |
298 | 0 | Some(median) |
299 | | } else { |
300 | 0 | let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); |
301 | 0 | Some(*median) |
302 | | } |
303 | 0 | } |