/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/approx_percentile_cont.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::any::Any; |
19 | | use std::fmt::{Debug, Formatter}; |
20 | | use std::sync::Arc; |
21 | | |
22 | | use arrow::array::{Array, RecordBatch}; |
23 | | use arrow::compute::{filter, is_not_null}; |
24 | | use arrow::{ |
25 | | array::{ |
26 | | ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, |
27 | | Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, |
28 | | }, |
29 | | datatypes::DataType, |
30 | | }; |
31 | | use arrow_schema::{Field, Schema}; |
32 | | |
33 | | use datafusion_common::{ |
34 | | downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, |
35 | | DataFusionError, Result, ScalarValue, |
36 | | }; |
37 | | use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; |
38 | | use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; |
39 | | use datafusion_expr::utils::format_state_name; |
40 | | use datafusion_expr::{ |
41 | | Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, |
42 | | Volatility, |
43 | | }; |
44 | | use datafusion_functions_aggregate_common::tdigest::{ |
45 | | TDigest, TryIntoF64, DEFAULT_MAX_SIZE, |
46 | | }; |
47 | | use datafusion_physical_expr_common::physical_expr::PhysicalExpr; |
48 | | |
49 | | create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); |
50 | | |
51 | | /// Computes the approximate percentile continuous of a set of numbers |
52 | 0 | pub fn approx_percentile_cont( |
53 | 0 | expression: Expr, |
54 | 0 | percentile: Expr, |
55 | 0 | centroids: Option<Expr>, |
56 | 0 | ) -> Expr { |
57 | 0 | let args = if let Some(centroids) = centroids { |
58 | 0 | vec![expression, percentile, centroids] |
59 | | } else { |
60 | 0 | vec![expression, percentile] |
61 | | }; |
62 | 0 | approx_percentile_cont_udaf().call(args) |
63 | 0 | } |
64 | | |
65 | | pub struct ApproxPercentileCont { |
66 | | signature: Signature, |
67 | | } |
68 | | |
69 | | impl Debug for ApproxPercentileCont { |
70 | 0 | fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { |
71 | 0 | f.debug_struct("ApproxPercentileCont") |
72 | 0 | .field("name", &self.name()) |
73 | 0 | .field("signature", &self.signature) |
74 | 0 | .finish() |
75 | 0 | } |
76 | | } |
77 | | |
78 | | impl Default for ApproxPercentileCont { |
79 | 0 | fn default() -> Self { |
80 | 0 | Self::new() |
81 | 0 | } |
82 | | } |
83 | | |
84 | | impl ApproxPercentileCont { |
85 | | /// Create a new [`ApproxPercentileCont`] aggregate function. |
86 | 0 | pub fn new() -> Self { |
87 | 0 | let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); |
88 | | // Accept any numeric value paired with a float64 percentile |
89 | 0 | for num in NUMERICS { |
90 | 0 | variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); |
91 | | // Additionally accept an integer number of centroids for T-Digest |
92 | 0 | for int in INTEGERS { |
93 | 0 | variants.push(TypeSignature::Exact(vec![ |
94 | 0 | num.clone(), |
95 | 0 | DataType::Float64, |
96 | 0 | int.clone(), |
97 | 0 | ])) |
98 | | } |
99 | | } |
100 | 0 | Self { |
101 | 0 | signature: Signature::one_of(variants, Volatility::Immutable), |
102 | 0 | } |
103 | 0 | } |
104 | | |
105 | 0 | pub(crate) fn create_accumulator( |
106 | 0 | &self, |
107 | 0 | args: AccumulatorArgs, |
108 | 0 | ) -> Result<ApproxPercentileAccumulator> { |
109 | 0 | let percentile = validate_input_percentile_expr(&args.exprs[1])?; |
110 | 0 | let tdigest_max_size = if args.exprs.len() == 3 { |
111 | 0 | Some(validate_input_max_size_expr(&args.exprs[2])?) |
112 | | } else { |
113 | 0 | None |
114 | | }; |
115 | | |
116 | 0 | let data_type = args.exprs[0].data_type(args.schema)?; |
117 | 0 | let accumulator: ApproxPercentileAccumulator = match data_type { |
118 | 0 | t @ (DataType::UInt8 |
119 | | | DataType::UInt16 |
120 | | | DataType::UInt32 |
121 | | | DataType::UInt64 |
122 | | | DataType::Int8 |
123 | | | DataType::Int16 |
124 | | | DataType::Int32 |
125 | | | DataType::Int64 |
126 | | | DataType::Float32 |
127 | | | DataType::Float64) => { |
128 | 0 | if let Some(max_size) = tdigest_max_size { |
129 | 0 | ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size) |
130 | | }else{ |
131 | 0 | ApproxPercentileAccumulator::new(percentile, t) |
132 | | |
133 | | } |
134 | | } |
135 | 0 | other => { |
136 | 0 | return not_impl_err!( |
137 | 0 | "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" |
138 | 0 | ) |
139 | | } |
140 | | }; |
141 | | |
142 | 0 | Ok(accumulator) |
143 | 0 | } |
144 | | } |
145 | | |
146 | 0 | fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> { |
147 | 0 | let empty_schema = Arc::new(Schema::empty()); |
148 | 0 | let batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); |
149 | 0 | if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? { |
150 | 0 | Ok(s) |
151 | | } else { |
152 | 0 | internal_err!("Didn't expect ColumnarValue::Array") |
153 | | } |
154 | 0 | } |
155 | | |
156 | 0 | fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> { |
157 | 0 | let percentile = match get_scalar_value(expr) |
158 | 0 | .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? { |
159 | 0 | ScalarValue::Float32(Some(value)) => { |
160 | 0 | value as f64 |
161 | | } |
162 | 0 | ScalarValue::Float64(Some(value)) => { |
163 | 0 | value |
164 | | } |
165 | 0 | sv => { |
166 | 0 | return not_impl_err!( |
167 | 0 | "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", |
168 | 0 | sv.data_type() |
169 | 0 | ) |
170 | | } |
171 | | }; |
172 | | |
173 | | // Ensure the percentile is between 0 and 1. |
174 | 0 | if !(0.0..=1.0).contains(&percentile) { |
175 | 0 | return plan_err!( |
176 | 0 | "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" |
177 | 0 | ); |
178 | 0 | } |
179 | 0 | Ok(percentile) |
180 | 0 | } |
181 | | |
182 | 0 | fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> { |
183 | 0 | let max_size = match get_scalar_value(expr) |
184 | 0 | .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? { |
185 | 0 | ScalarValue::UInt8(Some(q)) => q as usize, |
186 | 0 | ScalarValue::UInt16(Some(q)) => q as usize, |
187 | 0 | ScalarValue::UInt32(Some(q)) => q as usize, |
188 | 0 | ScalarValue::UInt64(Some(q)) => q as usize, |
189 | 0 | ScalarValue::Int32(Some(q)) if q > 0 => q as usize, |
190 | 0 | ScalarValue::Int64(Some(q)) if q > 0 => q as usize, |
191 | 0 | ScalarValue::Int16(Some(q)) if q > 0 => q as usize, |
192 | 0 | ScalarValue::Int8(Some(q)) if q > 0 => q as usize, |
193 | 0 | sv => { |
194 | 0 | return not_impl_err!( |
195 | 0 | "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", |
196 | 0 | sv.data_type() |
197 | 0 | ) |
198 | | }, |
199 | | }; |
200 | | |
201 | 0 | Ok(max_size) |
202 | 0 | } |
203 | | |
204 | | impl AggregateUDFImpl for ApproxPercentileCont { |
205 | 0 | fn as_any(&self) -> &dyn Any { |
206 | 0 | self |
207 | 0 | } |
208 | | |
209 | | #[allow(rustdoc::private_intra_doc_links)] |
210 | | /// See [`TDigest::to_scalar_state()`] for a description of the serialised |
211 | | /// state. |
212 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
213 | 0 | Ok(vec![ |
214 | 0 | Field::new( |
215 | 0 | format_state_name(args.name, "max_size"), |
216 | 0 | DataType::UInt64, |
217 | 0 | false, |
218 | 0 | ), |
219 | 0 | Field::new( |
220 | 0 | format_state_name(args.name, "sum"), |
221 | 0 | DataType::Float64, |
222 | 0 | false, |
223 | 0 | ), |
224 | 0 | Field::new( |
225 | 0 | format_state_name(args.name, "count"), |
226 | 0 | DataType::UInt64, |
227 | 0 | false, |
228 | 0 | ), |
229 | 0 | Field::new( |
230 | 0 | format_state_name(args.name, "max"), |
231 | 0 | DataType::Float64, |
232 | 0 | false, |
233 | 0 | ), |
234 | 0 | Field::new( |
235 | 0 | format_state_name(args.name, "min"), |
236 | 0 | DataType::Float64, |
237 | 0 | false, |
238 | 0 | ), |
239 | 0 | Field::new_list( |
240 | 0 | format_state_name(args.name, "centroids"), |
241 | 0 | Field::new("item", DataType::Float64, true), |
242 | 0 | false, |
243 | 0 | ), |
244 | 0 | ]) |
245 | 0 | } |
246 | | |
247 | 0 | fn name(&self) -> &str { |
248 | 0 | "approx_percentile_cont" |
249 | 0 | } |
250 | | |
251 | 0 | fn signature(&self) -> &Signature { |
252 | 0 | &self.signature |
253 | 0 | } |
254 | | |
255 | | #[inline] |
256 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
257 | 0 | Ok(Box::new(self.create_accumulator(acc_args)?)) |
258 | 0 | } |
259 | | |
260 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
261 | 0 | if !arg_types[0].is_numeric() { |
262 | 0 | return plan_err!("approx_percentile_cont requires numeric input types"); |
263 | 0 | } |
264 | 0 | if arg_types.len() == 3 && !arg_types[2].is_integer() { |
265 | 0 | return plan_err!( |
266 | 0 | "approx_percentile_cont requires integer max_size input types" |
267 | 0 | ); |
268 | 0 | } |
269 | 0 | Ok(arg_types[0].clone()) |
270 | 0 | } |
271 | | } |
272 | | |
273 | | #[derive(Debug)] |
274 | | pub struct ApproxPercentileAccumulator { |
275 | | digest: TDigest, |
276 | | percentile: f64, |
277 | | return_type: DataType, |
278 | | } |
279 | | |
280 | | impl ApproxPercentileAccumulator { |
281 | 0 | pub fn new(percentile: f64, return_type: DataType) -> Self { |
282 | 0 | Self { |
283 | 0 | digest: TDigest::new(DEFAULT_MAX_SIZE), |
284 | 0 | percentile, |
285 | 0 | return_type, |
286 | 0 | } |
287 | 0 | } |
288 | | |
289 | 0 | pub fn new_with_max_size( |
290 | 0 | percentile: f64, |
291 | 0 | return_type: DataType, |
292 | 0 | max_size: usize, |
293 | 0 | ) -> Self { |
294 | 0 | Self { |
295 | 0 | digest: TDigest::new(max_size), |
296 | 0 | percentile, |
297 | 0 | return_type, |
298 | 0 | } |
299 | 0 | } |
300 | | |
301 | | // public for approx_percentile_cont_with_weight |
302 | 0 | pub fn merge_digests(&mut self, digests: &[TDigest]) { |
303 | 0 | let digests = digests.iter().chain(std::iter::once(&self.digest)); |
304 | 0 | self.digest = TDigest::merge_digests(digests) |
305 | 0 | } |
306 | | |
307 | | // public for approx_percentile_cont_with_weight |
308 | 0 | pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> { |
309 | 0 | match values.data_type() { |
310 | | DataType::Float64 => { |
311 | 0 | let array = downcast_value!(values, Float64Array); |
312 | 0 | Ok(array |
313 | 0 | .values() |
314 | 0 | .iter() |
315 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
316 | 0 | .collect::<Result<Vec<_>>>()?) |
317 | | } |
318 | | DataType::Float32 => { |
319 | 0 | let array = downcast_value!(values, Float32Array); |
320 | 0 | Ok(array |
321 | 0 | .values() |
322 | 0 | .iter() |
323 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
324 | 0 | .collect::<Result<Vec<_>>>()?) |
325 | | } |
326 | | DataType::Int64 => { |
327 | 0 | let array = downcast_value!(values, Int64Array); |
328 | 0 | Ok(array |
329 | 0 | .values() |
330 | 0 | .iter() |
331 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
332 | 0 | .collect::<Result<Vec<_>>>()?) |
333 | | } |
334 | | DataType::Int32 => { |
335 | 0 | let array = downcast_value!(values, Int32Array); |
336 | 0 | Ok(array |
337 | 0 | .values() |
338 | 0 | .iter() |
339 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
340 | 0 | .collect::<Result<Vec<_>>>()?) |
341 | | } |
342 | | DataType::Int16 => { |
343 | 0 | let array = downcast_value!(values, Int16Array); |
344 | 0 | Ok(array |
345 | 0 | .values() |
346 | 0 | .iter() |
347 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
348 | 0 | .collect::<Result<Vec<_>>>()?) |
349 | | } |
350 | | DataType::Int8 => { |
351 | 0 | let array = downcast_value!(values, Int8Array); |
352 | 0 | Ok(array |
353 | 0 | .values() |
354 | 0 | .iter() |
355 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
356 | 0 | .collect::<Result<Vec<_>>>()?) |
357 | | } |
358 | | DataType::UInt64 => { |
359 | 0 | let array = downcast_value!(values, UInt64Array); |
360 | 0 | Ok(array |
361 | 0 | .values() |
362 | 0 | .iter() |
363 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
364 | 0 | .collect::<Result<Vec<_>>>()?) |
365 | | } |
366 | | DataType::UInt32 => { |
367 | 0 | let array = downcast_value!(values, UInt32Array); |
368 | 0 | Ok(array |
369 | 0 | .values() |
370 | 0 | .iter() |
371 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
372 | 0 | .collect::<Result<Vec<_>>>()?) |
373 | | } |
374 | | DataType::UInt16 => { |
375 | 0 | let array = downcast_value!(values, UInt16Array); |
376 | 0 | Ok(array |
377 | 0 | .values() |
378 | 0 | .iter() |
379 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
380 | 0 | .collect::<Result<Vec<_>>>()?) |
381 | | } |
382 | | DataType::UInt8 => { |
383 | 0 | let array = downcast_value!(values, UInt8Array); |
384 | 0 | Ok(array |
385 | 0 | .values() |
386 | 0 | .iter() |
387 | 0 | .filter_map(|v| v.try_as_f64().transpose()) |
388 | 0 | .collect::<Result<Vec<_>>>()?) |
389 | | } |
390 | 0 | e => internal_err!( |
391 | 0 | "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" |
392 | 0 | ), |
393 | | } |
394 | 0 | } |
395 | | } |
396 | | |
397 | | impl Accumulator for ApproxPercentileAccumulator { |
398 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
399 | 0 | Ok(self.digest.to_scalar_state().into_iter().collect()) |
400 | 0 | } |
401 | | |
402 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
403 | 0 | // Remove any nulls before computing the percentile |
404 | 0 | let mut values = Arc::clone(&values[0]); |
405 | 0 | if values.nulls().is_some() { |
406 | 0 | values = filter(&values, &is_not_null(&values)?)?; |
407 | 0 | } |
408 | 0 | let sorted_values = &arrow::compute::sort(&values, None)?; |
409 | 0 | let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?; |
410 | 0 | self.digest = self.digest.merge_sorted_f64(&sorted_values); |
411 | 0 | Ok(()) |
412 | 0 | } |
413 | | |
414 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
415 | 0 | if self.digest.count() == 0 { |
416 | 0 | return ScalarValue::try_from(self.return_type.clone()); |
417 | 0 | } |
418 | 0 | let q = self.digest.estimate_quantile(self.percentile); |
419 | 0 |
|
420 | 0 | // These acceptable return types MUST match the validation in |
421 | 0 | // ApproxPercentile::create_accumulator. |
422 | 0 | Ok(match &self.return_type { |
423 | 0 | DataType::Int8 => ScalarValue::Int8(Some(q as i8)), |
424 | 0 | DataType::Int16 => ScalarValue::Int16(Some(q as i16)), |
425 | 0 | DataType::Int32 => ScalarValue::Int32(Some(q as i32)), |
426 | 0 | DataType::Int64 => ScalarValue::Int64(Some(q as i64)), |
427 | 0 | DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), |
428 | 0 | DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), |
429 | 0 | DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), |
430 | 0 | DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), |
431 | 0 | DataType::Float32 => ScalarValue::Float32(Some(q as f32)), |
432 | 0 | DataType::Float64 => ScalarValue::Float64(Some(q)), |
433 | 0 | v => unreachable!("unexpected return type {:?}", v), |
434 | | }) |
435 | 0 | } |
436 | | |
437 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
438 | 0 | if states.is_empty() { |
439 | 0 | return Ok(()); |
440 | 0 | } |
441 | | |
442 | 0 | let states = (0..states[0].len()) |
443 | 0 | .map(|index| { |
444 | 0 | states |
445 | 0 | .iter() |
446 | 0 | .map(|array| ScalarValue::try_from_array(array, index)) |
447 | 0 | .collect::<Result<Vec<_>>>() |
448 | 0 | .map(|state| TDigest::from_scalar_state(&state)) |
449 | 0 | }) |
450 | 0 | .collect::<Result<Vec<_>>>()?; |
451 | | |
452 | 0 | self.merge_digests(&states); |
453 | 0 |
|
454 | 0 | Ok(()) |
455 | 0 | } |
456 | | |
457 | 0 | fn size(&self) -> usize { |
458 | 0 | std::mem::size_of_val(self) + self.digest.size() |
459 | 0 | - std::mem::size_of_val(&self.digest) |
460 | 0 | + self.return_type.size() |
461 | 0 | - std::mem::size_of_val(&self.return_type) |
462 | 0 | } |
463 | | } |
464 | | |
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;

    use datafusion_functions_aggregate_common::tdigest::TDigest;

    use crate::approx_percentile_cont::ApproxPercentileAccumulator;

    /// Merging pre-built digests into an accumulator adds up their counts.
    #[test]
    fn test_combine_approx_percentile_accumulator() {
        // 50 digests, each holding the values 1..=1_000 (50_000 values total).
        let digests: Vec<TDigest> = (1..=50)
            .map(|_| {
                let values: Vec<_> = (1..=1_000).map(f64::from).collect();
                TDigest::new(100).merge_unsorted_f64(values)
            })
            .collect();

        let first = TDigest::merge_digests(&digests);
        let second = TDigest::merge_digests(&digests);

        let mut accumulator =
            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);

        accumulator.merge_digests(&[first]);
        assert_eq!(accumulator.digest.count(), 50_000);
        accumulator.merge_digests(&[second]);
        assert_eq!(accumulator.digest.count(), 100_000);
    }
}