From 69f498e3f5237f3546c92e5735b8c8d4d8410a96 Mon Sep 17 00:00:00 2001 From: Dom Date: Thu, 6 Jan 2022 16:45:11 +0000 Subject: [PATCH 01/10] feat: implement TDigest for approx quantile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a [TDigest] implementation providing approximate quantile estimations of large inputs using a small amount of (bounded) memory. A TDigest is most accurate near either "end" of the quantile range (that is, 0.1, 0.9, 0.95, etc) due to the use of a scalaing function that increases resolution at the tails. The paper claims single digit part per million errors for q ≤ 0.001 or q ≥ 0.999 using 100 centroids, and in practice I have found accuracy to be more than acceptable for an apprixmate function across the entire quantile range. The implementation is a modified copy of https://github.com/MnO2/t-digest, itself a Rust port of [Facebook's C++ implementation]. Both Facebook's implementation, and Mn02's Rust port are Apache 2.0 licensed. [TDigest]: https://arxiv.org/abs/1902.04023 [Facebook's C++ implementation]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h --- datafusion/src/physical_plan/mod.rs | 1 + datafusion/src/physical_plan/tdigest/mod.rs | 818 ++++++++++++++++++++ 2 files changed, 819 insertions(+) create mode 100644 datafusion/src/physical_plan/tdigest/mod.rs diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 216d4a65e639..66d913d8b24a 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -661,6 +661,7 @@ pub mod repartition; pub mod sorts; pub mod stream; pub mod string_expressions; +pub(crate) mod tdigest; pub mod type_coercion; pub mod udaf; pub mod udf; diff --git a/datafusion/src/physical_plan/tdigest/mod.rs b/datafusion/src/physical_plan/tdigest/mod.rs new file mode 100644 index 000000000000..86d84f9f96fa --- /dev/null +++ b/datafusion/src/physical_plan/tdigest/mod.rs @@ -0,0 +1,818 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with this +// work for additional information regarding copyright ownership. The ASF +// licenses this file to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//! An implementation of the [TDigest sketch algorithm] providing approximate +//! quantile calculations. +//! +//! The TDigest code in this module is modified from +//! https://github.com/MnO2/t-digest, itself a rust reimplementation of +//! [Facebook's Folly TDigest] implementation. +//! +//! Alterations include reduction of runtime heap allocations, broader type +//! support, (de-)serialisation support, reduced type conversions and null value +//! tolerance. +//! +//! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023 +//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h + +use arrow::datatypes::DataType; +use ordered_float::OrderedFloat; +use std::cmp::Ordering; + +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +// Cast a non-null [`ScalarValue::Float64`] to an [`OrderedFloat`], or +// panic. +macro_rules! cast_scalar_f64 { + ($value:expr ) => { + match &$value { + ScalarValue::Float64(Some(v)) => OrderedFloat::from(*v), + v => panic!("invalid type {:?}", v), + } + }; +} + +/// This trait is implemented for each type a [`TDigest`] can operate on, +/// allowing it to support both numerical rust types (obtained from +/// `PrimitiveArray` instances), and [`ScalarValue`] instances. +pub(crate) trait TryIntoOrderedF64 { + /// A fallible conversion of a possibly null `self` into a [`OrderedFloat`]. + /// + /// If `self` is null, this method must return `Ok(None)`. + /// + /// If `self` cannot be coerced to the desired type, this method must return + /// an `Err` variant. + fn try_as_f64(&self) -> Result>>; +} + +/// Generate an infallible conversion from `type` to an [`OrderedFloat`]. +macro_rules! impl_try_ordered_f64 { + ($type:ty) => { + impl TryIntoOrderedF64 for $type { + fn try_as_f64(&self) -> Result>> { + Ok(Some(OrderedFloat::from(*self as f64))) + } + } + }; +} + +impl_try_ordered_f64!(f64); +impl_try_ordered_f64!(f32); +impl_try_ordered_f64!(i64); +impl_try_ordered_f64!(i32); +impl_try_ordered_f64!(i16); +impl_try_ordered_f64!(i8); +impl_try_ordered_f64!(u64); +impl_try_ordered_f64!(u32); +impl_try_ordered_f64!(u16); +impl_try_ordered_f64!(u8); + +impl TryIntoOrderedF64 for ScalarValue { + fn try_as_f64(&self) -> Result>> { + match self { + ScalarValue::Float32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Float64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + + got => { + return Err(DataFusionError::NotImplemented(format!( + "Support for 'APPROX_QUANTILE' for data type {} is not implemented", + got + ))) + } + } + } +} + +/// Centroid implementation to the cluster mentioned in the paper. +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct Centroid { + mean: OrderedFloat, + weight: OrderedFloat, +} + +impl PartialOrd for Centroid { + fn partial_cmp(&self, other: &Centroid) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Centroid { + fn cmp(&self, other: &Centroid) -> Ordering { + self.mean.cmp(&other.mean) + } +} + +impl Centroid { + pub(crate) fn new( + mean: impl Into>, + weight: impl Into>, + ) -> Self { + Centroid { + mean: mean.into(), + weight: weight.into(), + } + } + + #[inline] + pub(crate) fn mean(&self) -> OrderedFloat { + self.mean + } + + #[inline] + pub(crate) fn weight(&self) -> OrderedFloat { + self.weight + } + + pub(crate) fn add( + &mut self, + sum: impl Into>, + weight: impl Into>, + ) -> f64 { + let new_sum = sum.into() + self.weight * self.mean; + let new_weight = self.weight + weight.into(); + self.weight = new_weight; + self.mean = new_sum / new_weight; + new_sum.into_inner() + } +} + +impl Default for Centroid { + fn default() -> Self { + Centroid { + mean: OrderedFloat::from(0.0), + weight: OrderedFloat::from(1.0), + } + } +} + +/// T-Digest to be operated on. +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct TDigest { + centroids: Vec, + max_size: usize, + sum: OrderedFloat, + count: OrderedFloat, + max: OrderedFloat, + min: OrderedFloat, +} + +impl TDigest { + pub(crate) fn new(max_size: usize) -> Self { + TDigest { + centroids: Vec::new(), + max_size, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } + + #[inline] + pub(crate) fn count(&self) -> f64 { + self.count.into_inner() + } + + #[inline] + pub(crate) fn max(&self) -> f64 { + self.max.into_inner() + } + + #[inline] + pub(crate) fn min(&self) -> f64 { + self.min.into_inner() + } + + #[inline] + pub(crate) fn max_size(&self) -> usize { + self.max_size + } +} + +impl Default for TDigest { + fn default() -> Self { + TDigest { + centroids: Vec::new(), + max_size: 100, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } +} + +impl TDigest { + fn k_to_q(k: f64, d: f64) -> OrderedFloat { + let k_div_d = k / d; + if k_div_d >= 0.5 { + let base = 1.0 - k_div_d; + 1.0 - 2.0 * base * base + } else { + 2.0 * k_div_d * k_div_d + } + .into() + } + + fn clamp( + v: OrderedFloat, + lo: OrderedFloat, + hi: OrderedFloat, + ) -> OrderedFloat { + if v > hi { + hi + } else if v < lo { + lo + } else { + v + } + } + + pub(crate) fn merge_unsorted( + &self, + unsorted_values: impl IntoIterator, + ) -> Result { + let mut values = unsorted_values + .into_iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?; + + values.sort(); + + Ok(self.merge_sorted_f64(&values)) + } + + fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat]) -> TDigest { + debug_assert!(is_sorted(&sorted_values), "unsorted input to TDigest"); + + if sorted_values.is_empty() { + return self.clone(); + } + + let mut result = TDigest::new(self.max_size()); + result.count = OrderedFloat::from(self.count() + (sorted_values.len() as f64)); + + let maybe_min = OrderedFloat::from(*sorted_values.first().unwrap()); + let maybe_max = OrderedFloat::from(*sorted_values.last().unwrap()); + + if self.count() > 0.0 { + result.min = std::cmp::min(self.min, maybe_min); + result.max = std::cmp::max(self.max, maybe_max); + } else { + result.min = maybe_min; + result.max = maybe_max; + } + + let mut compressed: Vec = Vec::with_capacity(self.max_size); + + let mut k_limit: f64 = 1.0; + let mut q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + + let mut iter_centroids = self.centroids.iter().peekable(); + let mut iter_sorted_values = sorted_values.iter().peekable(); + + let mut curr: Centroid = if let Some(c) = iter_centroids.peek() { + let curr = **iter_sorted_values.peek().unwrap(); + if c.mean() < curr { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let mut weight_so_far = curr.weight(); + + let mut sums_to_merge = OrderedFloat::from(0.0); + let mut weights_to_merge = OrderedFloat::from(0.0); + + while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() { + let next: Centroid = if let Some(c) = iter_centroids.peek() { + if iter_sorted_values.peek().is_none() + || c.mean() < **iter_sorted_values.peek().unwrap() + { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let next_sum = next.mean() * next.weight(); + weight_so_far += next.weight(); + + if weight_so_far <= q_limit_times_count { + sums_to_merge += next_sum; + weights_to_merge += next.weight(); + } else { + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + sums_to_merge = 0.0.into(); + weights_to_merge = 0.0.into(); + + compressed.push(curr.clone()); + q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + curr = next; + } + } + + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + compressed.push(curr); + compressed.shrink_to_fit(); + compressed.sort(); + + result.centroids = compressed; + result + } + + fn external_merge( + centroids: &mut Vec, + first: usize, + middle: usize, + last: usize, + ) { + let mut result: Vec = Vec::with_capacity(centroids.len()); + + let mut i = first; + let mut j = middle; + + while i < middle && j < last { + match centroids[i].cmp(¢roids[j]) { + Ordering::Less => { + result.push(centroids[i].clone()); + i += 1; + } + Ordering::Greater => { + result.push(centroids[j].clone()); + j += 1; + } + Ordering::Equal => { + result.push(centroids[i].clone()); + i += 1; + } + } + } + + while i < middle { + result.push(centroids[i].clone()); + i += 1; + } + + while j < last { + result.push(centroids[j].clone()); + j += 1; + } + + i = first; + for centroid in result.into_iter() { + centroids[i] = centroid; + i += 1; + } + } + + // Merge multiple T-Digests + pub(crate) fn merge_digests(digests: &[TDigest]) -> TDigest { + let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum(); + if n_centroids == 0 { + return TDigest::default(); + } + + let max_size = digests.first().unwrap().max_size; + let mut centroids: Vec = Vec::with_capacity(n_centroids); + let mut starts: Vec = Vec::with_capacity(digests.len()); + + let mut count: f64 = 0.0; + let mut min = OrderedFloat::from(std::f64::INFINITY); + let mut max = OrderedFloat::from(std::f64::NEG_INFINITY); + + let mut start: usize = 0; + for digest in digests.into_iter() { + starts.push(start); + + let curr_count: f64 = digest.count(); + if curr_count > 0.0 { + min = std::cmp::min(min, digest.min); + max = std::cmp::max(max, digest.max); + count += curr_count; + for centroid in &digest.centroids { + centroids.push(centroid.clone()); + start += 1; + } + } + } + + let mut digests_per_block: usize = 1; + while digests_per_block < starts.len() { + for i in (0..starts.len()).step_by(digests_per_block * 2) { + if i + digests_per_block < starts.len() { + let first = starts[i]; + let middle = starts[i + digests_per_block]; + let last = if i + 2 * digests_per_block < starts.len() { + starts[i + 2 * digests_per_block] + } else { + centroids.len() + }; + + debug_assert!(first <= middle && middle <= last); + Self::external_merge(&mut centroids, first, middle, last); + } + } + + digests_per_block *= 2; + } + + let mut result = TDigest::new(max_size); + let mut compressed: Vec = Vec::with_capacity(max_size); + + let mut k_limit: f64 = 1.0; + let mut q_limit_times_count = + Self::k_to_q(k_limit, max_size as f64) * (count as f64); + + let mut iter_centroids = centroids.iter_mut(); + let mut curr = iter_centroids.next().unwrap(); + let mut weight_so_far = curr.weight(); + let mut sums_to_merge = OrderedFloat::from(0.0); + let mut weights_to_merge = OrderedFloat::from(0.0); + + for centroid in iter_centroids { + weight_so_far += centroid.weight(); + + if weight_so_far <= q_limit_times_count { + sums_to_merge += centroid.mean() * centroid.weight(); + weights_to_merge += centroid.weight(); + } else { + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + sums_to_merge = OrderedFloat::from(0.0); + weights_to_merge = OrderedFloat::from(0.0); + compressed.push(curr.clone()); + q_limit_times_count = + Self::k_to_q(k_limit, max_size as f64) * (count as f64); + k_limit += 1.0; + curr = centroid; + } + } + + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + compressed.push(curr.clone()); + compressed.shrink_to_fit(); + compressed.sort(); + + result.count = OrderedFloat::from(count as f64); + result.min = min; + result.max = max; + result.centroids = compressed; + result + } + + /// To estimate the value located at `q` quantile + pub(crate) fn estimate_quantile(&self, q: f64) -> f64 { + if self.centroids.is_empty() { + return 0.0; + } + + let count_ = self.count; + let rank = OrderedFloat::from(q) * count_; + + let mut pos: usize; + let mut t; + if q > 0.5 { + if q >= 1.0 { + return self.max(); + } + + pos = 0; + t = count_; + + for (k, centroid) in self.centroids.iter().enumerate().rev() { + t -= centroid.weight(); + + if rank >= t { + pos = k; + break; + } + } + } else { + if q <= 0.0 { + return self.min(); + } + + pos = self.centroids.len() - 1; + t = OrderedFloat::from(0.0); + + for (k, centroid) in self.centroids.iter().enumerate() { + if rank < t + centroid.weight() { + pos = k; + break; + } + + t += centroid.weight(); + } + } + + let mut delta = OrderedFloat::from(0.0); + let mut min = self.min; + let mut max = self.max; + + if self.centroids.len() > 1 { + if pos == 0 { + delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean(); + max = self.centroids[pos + 1].mean(); + } else if pos == (self.centroids.len() - 1) { + delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean(); + min = self.centroids[pos - 1].mean(); + } else { + delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean()) + / 2.0; + min = self.centroids[pos - 1].mean(); + max = self.centroids[pos + 1].mean(); + } + } + + let value = self.centroids[pos].mean() + + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta; + Self::clamp(value, min, max).into_inner() + } + + /// This method decomposes the [`TDigest`] and its [`Centroid`] instances + /// into a series of primitive scalar values. + /// + /// First the values of the TDigest are packed, followed by the variable + /// number of centroids packed into a [`ScalarValue::List`] of + /// [`ScalarValue::Float64`]: + /// + /// ```text + /// + /// ┌────────┬────────┬────────┬───────┬────────┬────────┐ + /// │max_size│ sum │ count │ max │ min │centroid│ + /// └────────┴────────┴────────┴───────┴────────┴────────┘ + /// │ + /// ┌─────────────────────┘ + /// ▼ + /// ┌ List ───┐ + /// │┌ ─ ─ ─ ┐│ + /// │ mean │ + /// │├ ─ ─ ─ ┼│─ ─ Centroid 1 + /// │ weight │ + /// │└ ─ ─ ─ ┘│ + /// │ │ + /// │┌ ─ ─ ─ ┐│ + /// │ mean │ + /// │├ ─ ─ ─ ┼│─ ─ Centroid 2 + /// │ weight │ + /// │└ ─ ─ ─ ┘│ + /// │ │ + /// ... + /// + /// ``` + /// + /// The [`TDigest::from_scalar_state()`] method reverses this processes, + /// consuming the output of this method and returning an unpacked + /// [`TDigest`]. + pub(crate) fn to_scalar_state(&self) -> Vec { + // Gather up all the centroids + let centroids: Vec<_> = self + .centroids + .iter() + .flat_map(|c| [c.mean().into_inner(), c.weight().into_inner()]) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + + vec![ + ScalarValue::UInt64(Some(self.max_size as u64)), + ScalarValue::Float64(Some(self.sum.into_inner())), + ScalarValue::Float64(Some(self.count.into_inner())), + ScalarValue::Float64(Some(self.max.into_inner())), + ScalarValue::Float64(Some(self.min.into_inner())), + ScalarValue::List(Some(Box::new(centroids)), Box::new(DataType::Float64)), + ] + } + + /// Unpack the serialised state of a [`TDigest`] produced by + /// [`Self::to_scalar_state()`]. + /// + /// # Correctness + /// + /// Providing input to this method that was not obtained from + /// [`Self::to_scalar_state()`] results in undefined behaviour and may + /// panic. + pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self { + assert_eq!(state.len(), 6, "invalid TDigest state"); + + let max_size = match &state[0] { + ScalarValue::UInt64(Some(v)) => *v as usize, + v => panic!("invalid max_size type {:?}", v), + }; + + let centroids: Vec<_> = match &state[5] { + ScalarValue::List(Some(c), d) if **d == DataType::Float64 => c + .chunks(2) + .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) + .collect(), + v => panic!("invalid centroids type {:?}", v), + }; + + let max = cast_scalar_f64!(&state[3]); + let min = cast_scalar_f64!(&state[4]); + assert!(max >= min); + + Self { + max_size, + sum: cast_scalar_f64!(state[1]), + count: cast_scalar_f64!(&state[2]), + max, + min, + centroids, + } + } +} + +#[cfg(debug_assertions)] +fn is_sorted(values: &[OrderedFloat]) -> bool { + values.windows(2).all(|w| w[0] <= w[1]) +} + +#[cfg(test)] +mod tests { + use std::iter; + + use super::*; + + // A macro to assert the specified `quantile` estimated by `t` is within the + // allowable relative error bound. + macro_rules! assert_error_bounds { + ($t:ident, quantile = $quantile:literal, want = $want:literal) => { + assert_error_bounds!( + $t, + quantile = $quantile, + want = $want, + allowable_error = 0.01 + ) + }; + ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => { + let ans = $t.estimate_quantile($quantile); + let expected: f64 = $want; + let percentage: f64 = (expected - ans).abs() / expected; + assert!( + percentage < $re, + "relative error {} is more than {}% (got quantile {}, want {})", + percentage, + $re, + ans, + expected + ); + }; + } + + macro_rules! assert_state_roundtrip { + ($t:ident) => { + let state = $t.to_scalar_state(); + let other = TDigest::from_scalar_state(&state); + assert_eq!($t, other); + }; + } + + #[test] + fn test_int64_uniform() { + let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v))); + + let t = TDigest::new(100); + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.1, want = 100.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_error_bounds!(t, quantile = 0.9, want = 900.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_int64_uniform_with_nulls() { + let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v))); + // Prepend some NULLs + let values = iter::repeat(ScalarValue::Int64(None)) + .take(10) + .chain(values); + // Append some more NULLs + let values = values.chain(iter::repeat(ScalarValue::Int64(None)).take(10)); + + let t = TDigest::new(100); + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.1, want = 100.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_error_bounds!(t, quantile = 0.9, want = 900.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_centroid_addition_regression() { + //https://github.com/MnO2/t-digest/pull/1 + + let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0]; + let mut t = TDigest::new(10); + + for v in vals { + t = t.merge_unsorted([ScalarValue::Float64(Some(v))]).unwrap(); + } + + assert_error_bounds!(t, quantile = 0.5, want = 1.0); + assert_error_bounds!(t, quantile = 0.95, want = 2.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_unsorted_against_uniform_distro() { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0); + assert_error_bounds!(t, quantile = 0.99, want = 990_000.0); + assert_error_bounds!(t, quantile = 0.01, want = 10_000.0); + assert_error_bounds!(t, quantile = 0.0, want = 1.0); + assert_error_bounds!(t, quantile = 0.5, want = 500_000.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_unsorted_against_skewed_distro() { + let t = TDigest::new(100); + let mut values: Vec<_> = (1..=600_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + for _ in 0..400_000 { + values.push(ScalarValue::Float64(Some(1_000_000.0))); + } + + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0); + assert_error_bounds!(t, quantile = 0.01, want = 10_000.0); + assert_error_bounds!(t, quantile = 0.5, want = 500_000.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_digests() { + let mut digests: Vec = Vec::new(); + + for _ in 1..=100 { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + let t = t.merge_unsorted(values).unwrap(); + digests.push(t) + } + + let t = TDigest::merge_digests(&digests); + + assert_error_bounds!(t, quantile = 1.0, want = 1000.0); + assert_error_bounds!(t, quantile = 0.99, want = 990.0); + assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2); + assert_error_bounds!(t, quantile = 0.0, want = 1.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_state_roundtrip!(t); + } +} From d9a7be238ba8eb60c93c10ddd8c31bc668f16f7e Mon Sep 17 00:00:00 2001 From: Dom Date: Fri, 7 Jan 2022 18:22:37 +0000 Subject: [PATCH 02/10] feat: approx_quantile aggregation Adds the ApproxQuantile physical expression, plumbing & test cases. The function signature is: approx_quantile(column, quantile) Where column can be any numeric type (that can be cast to a float64) and quantile is a float64 literal between 0 and 1. --- datafusion/src/physical_plan/aggregates.rs | 62 +++- .../coercion_rule/aggregate_rule.rs | 114 +++++-- .../expressions/approx_quantile.rs | 294 ++++++++++++++++++ .../src/physical_plan/expressions/mod.rs | 2 + datafusion/tests/sql/aggregates.rs | 89 ++++++ 5 files changed, 538 insertions(+), 23 deletions(-) create mode 100644 datafusion/src/physical_plan/expressions/approx_quantile.rs diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index f7beb76df3bc..4bcc413fe239 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -27,7 +27,7 @@ //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. use super::{ - functions::{Signature, Volatility}, + functions::{Signature, TypeSignature, Volatility}, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; @@ -80,6 +80,8 @@ pub enum AggregateFunction { CovariancePop, /// Correlation Correlation, + /// Approximate quantile function + ApproxQuantile, } impl fmt::Display for AggregateFunction { @@ -110,6 +112,7 @@ impl FromStr for AggregateFunction { "covar_samp" => AggregateFunction::Covariance, "covar_pop" => AggregateFunction::CovariancePop, "corr" => AggregateFunction::Correlation, + "approx_quantile" => AggregateFunction::ApproxQuantile, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -157,6 +160,7 @@ pub fn return_type( coerced_data_types[0].clone(), true, )))), + AggregateFunction::ApproxQuantile => Ok(DataType::Float64), } } @@ -331,6 +335,19 @@ pub fn create_aggregate_expr( "CORR(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::ApproxQuantile, false) => { + Arc::new(expressions::ApproxQuantile::new( + // Pass in the desired quantile expr + coerced_phy_exprs, + name, + return_type, + )?) + } + (AggregateFunction::ApproxQuantile, true) => { + return Err(DataFusionError::NotImplemented( + "approx_quantile(DISTINCT) aggregations are not available".to_string(), + )); + } }) } @@ -389,17 +406,25 @@ pub fn signature(fun: &AggregateFunction) -> Signature { AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } + AggregateFunction::ApproxQuantile => Signature::one_of( + // Accept any numeric value paired with a float64 quantile + NUMERICS + .iter() + .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) + .collect(), + Volatility::Immutable, + ), } } #[cfg(test)] mod tests { use super::*; - use crate::error::Result; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg, - DistinctCount, Max, Min, Stddev, Sum, Variance, + ApproxDistinct, ApproxQuantile, ArrayAgg, Avg, Correlation, Count, Covariance, + DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; + use crate::{error::Result, scalar::ScalarValue}; #[test] fn test_count_arragg_approx_expr() -> Result<()> { @@ -513,6 +538,35 @@ mod tests { Ok(()) } + #[test] + fn test_agg_approx_quantile_phy_expr() { + for data_type in NUMERICS { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + ), + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), + ]; + let result_agg_phy_exprs = create_aggregate_expr( + &AggregateFunction::ApproxQuantile, + false, + &input_phy_exprs[..], + &input_schema, + "c1", + ) + .expect("failed to create aggregate expr"); + + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, false), + result_agg_phy_exprs.field().unwrap() + ); + } + } + #[test] fn test_min_max_expr() -> Result<()> { let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index c151fb70a084..9e5b3957c7a7 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -17,7 +17,6 @@ //! Support the coercion rule for aggregate function. -use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ @@ -27,6 +26,10 @@ use crate::physical_plan::expressions::{ }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; +use crate::{ + arrow::datatypes::Schema, + physical_plan::expressions::is_approx_quantile_supported_arg_type, +}; use arrow::datatypes::DataType; use std::ops::Deref; use std::sync::Arc; @@ -38,24 +41,9 @@ pub(crate) fn coerce_types( input_types: &[DataType], signature: &Signature, ) -> Result> { - match signature.type_signature { - TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != agg_count { - return Err(DataFusionError::Plan(format!( - "The function {:?} expects {:?} arguments, but {:?} were provided", - agg_fun, - agg_count, - input_types.len() - ))); - } - } - _ => { - return Err(DataFusionError::Internal(format!( - "Aggregate functions do not support this {:?}", - signature - ))); - } - }; + // Validate input_types matches (at least one of) the func signature. + check_arg_count(agg_fun, input_types, &signature.type_signature)?; + match agg_fun { AggregateFunction::Count | AggregateFunction::ApproxDistinct => { Ok(input_types.to_vec()) @@ -151,7 +139,75 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::ApproxQuantile => { + if !is_approx_quantile_supported_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + if !matches!(input_types[1], DataType::Float64) { + return Err(DataFusionError::Plan(format!( + "The quantile argument for {:?} must be Float64, not {:?}.", + agg_fun, input_types[1] + ))); + } + Ok(input_types.to_vec()) + } + } +} + +/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// +/// This method DOES NOT validate the argument types - only that (at least one, +/// in the case of [`TypeSignature::OneOf`]) signature matches the desired +/// number of input types. +fn check_arg_count( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &TypeSignature, +) -> Result<()> { + match signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != *agg_count { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + agg_count, + input_types.len() + ))); + } + } + TypeSignature::Exact(types) => { + if types.len() != input_types.len() { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + types.len(), + input_types.len() + ))); + } + } + TypeSignature::OneOf(variants) => { + let ok = variants + .iter() + .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); + if !ok { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not accept {:?} function arguments.", + agg_fun, + input_types.len() + ))); + } + } + _ => { + return Err(DataFusionError::Internal(format!( + "Aggregate functions do not support this {:?}", + signature + ))); + } } + Ok(()) } fn get_min_max_result_type(input_types: &[DataType]) -> Result> { @@ -267,5 +323,25 @@ mod tests { assert_eq!(*input_type, result.unwrap()); } } + + // ApproxQuantile input types + let input_types = vec![ + vec![DataType::Int8, DataType::Float64], + vec![DataType::Int16, DataType::Float64], + vec![DataType::Int32, DataType::Float64], + vec![DataType::Int64, DataType::Float64], + vec![DataType::UInt8, DataType::Float64], + vec![DataType::UInt16, DataType::Float64], + vec![DataType::UInt32, DataType::Float64], + vec![DataType::UInt64, DataType::Float64], + vec![DataType::Float32, DataType::Float64], + vec![DataType::Float64, DataType::Float64], + ]; + for input_type in &input_types { + let signature = aggregates::signature(&AggregateFunction::ApproxQuantile); + let result = + coerce_types(&AggregateFunction::ApproxQuantile, input_type, &signature); + assert_eq!(*input_type, result.unwrap()); + } } } diff --git a/datafusion/src/physical_plan/expressions/approx_quantile.rs b/datafusion/src/physical_plan/expressions/approx_quantile.rs new file mode 100644 index 000000000000..211fae592ad5 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/approx_quantile.rs @@ -0,0 +1,294 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{ + ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::{DataType, Field}, +}; + +use crate::{ + error::DataFusionError, + physical_plan::{tdigest::TDigest, Accumulator, AggregateExpr, PhysicalExpr}, + scalar::ScalarValue, +}; + +use crate::error::Result; + +use super::{format_state_name, Literal}; + +/// Return `true` if `arg_type` is of a [`DataType`] that the [`ApproxQuantile`] +/// aggregation can operate on. +pub fn is_approx_quantile_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +/// APPROX_QUANTILE aggregate expression +#[derive(Debug)] +pub struct ApproxQuantile { + name: String, + input_data_type: DataType, + expr: Arc, + quantile: f64, +} + +impl ApproxQuantile { + /// Create a new ApproxQuantile aggregate function. + pub fn new( + expr: Vec>, + name: impl Into, + input_data_type: DataType, + ) -> Result { + // Arguments should be [ColumnExpr, DesiredQuantileLiteral] + debug_assert_eq!(expr.len(), 2); + + // Extract the desired quantile literal + let lit = expr[1] + .as_any() + .downcast_ref::() + .ok_or(DataFusionError::Internal( + "desired quantile argument must be float literal".to_string(), + ))? + .value(); + let quantile = match lit { + ScalarValue::Float32(Some(q)) => *q as f64, + ScalarValue::Float64(Some(q)) => *q as f64, + got => return Err(DataFusionError::NotImplemented(format!( + "Quantile value for 'APPROX_QUANTILE' must be Float32 or Float64 literal (got data type {})", + got + ))) + }; + + Ok(Self { + name: name.into(), + input_data_type, + // The physical expr to evaluate during accumulation + expr: expr[0].clone(), + quantile, + }) + } +} + +impl AggregateExpr for ApproxQuantile { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, false)) + } + + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + &format_state_name(&self.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "count"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "max"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "min"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "centroids"), + DataType::List(Box::new(Field::new("item", DataType::Float64, true))), + false, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn create_accumulator(&self) -> Result> { + let accumulator: Box = match &self.input_data_type { + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 => { + Box::new(ApproxQuantileAccumulator::new(self.quantile)) + } + other => { + return Err(DataFusionError::NotImplemented(format!( + "Support for 'APPROX_QUANTILE' for data type {} is not implemented", + other + ))) + } + }; + Ok(accumulator) + } + + fn name(&self) -> &str { + &self.name + } +} + +#[derive(Debug)] +pub struct ApproxQuantileAccumulator { + digest: TDigest, + quantile: f64, +} + +impl ApproxQuantileAccumulator { + pub fn new(quantile: f64) -> Self { + Self { + digest: TDigest::new(100), + quantile, + } + } +} + +impl Accumulator for ApproxQuantileAccumulator { + fn state(&self) -> Result> { + Ok(self.digest.to_scalar_state()) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + debug_assert_eq!( + values.len(), + 1, + "invalid number of values in quantile update" + ); + + self.digest = self.digest.merge_unsorted([values[0].clone()])?; + Ok(()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + debug_assert_eq!( + values.len(), + 1, + "invalid number of values in batch quantile update" + ); + let values = &values[0]; + + self.digest = match values.data_type() { + DataType::Float64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Float32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int16 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int8 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt16 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt8 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + e => { + return Err(DataFusionError::Internal(format!( + "APPROX_QUANTILE is not expected to receive the type {:?}", + e + ))); + } + }; + + Ok(()) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + debug_assert_eq!( + states.len(), + 6, + "invalid number of state fields for quantile accumulator" + ); + + let other = TDigest::from_scalar_state(states); + self.digest = TDigest::merge_digests(&[self.digest.clone(), other]); + + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Float64(Some( + self.digest.estimate_quantile(self.quantile), + ))) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + todo!() + } +} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index ca14d7fa1a8d..ce4c0e7cdaf2 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; mod approx_distinct; +mod approx_quantile; mod array_agg; mod average; #[macro_use] @@ -64,6 +65,7 @@ pub mod helpers { } pub use approx_distinct::ApproxDistinct; +pub use approx_quantile::{is_approx_quantile_supported_arg_type, ApproxQuantile}; pub use array_agg::ArrayAgg; pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 9d72752b091d..94dc5c4ad45a 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -354,6 +354,95 @@ async fn csv_query_approx_count() -> Result<()> { Ok(()) } +// This test executes the APPROX_QUANTILE aggregation against the test data, +// asserting the estimated quantiles are ±5% their actual values. +// +// Actual quantiles calculated with: +// +// ```r +// read_csv("./testing/data/csv/aggregate_test_100.csv") |> +// select_if(is.numeric) |> +// summarise_all(~ quantile(., c(0.1, 0.5, 0.9))) +// ``` +// +// Giving: +// +// ```text +// c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 +// +// 1 1 -95.3 -22925. -1882606710 -7.25e18 18.9 2671. 472608672. 1.83e18 0.109 0.0714 +// 2 3 15.5 4599 377164262 1.13e18 134. 30634 2365817608. 9.30e18 0.491 0.551 +// 3 5 102. 25334. 1991374996. 7.37e18 231 57518. 3776538487. 1.61e19 0.834 0.946 +// ``` +// +// Column `c12` is omitted due to a large relative error (~10%) due to the small +// float values. +#[tokio::test] +async fn csv_query_approx_quantile() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + + // Generate an assertion that the estimated $quantile value for $column is + // within 5% of the $actual quantile value. + macro_rules! quantile_test { + ($ctx:ident, column=$column:literal, quantile=$quantile:literal, actual=$actual:literal) => { + let sql = format!("SELECT (ABS(1 - approx_quantile({}, {}) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual); + let actual = execute_to_batches(&mut ctx, &sql).await; + // + // "+------+", + // "| q |", + // "+------+", + // "| true |", + // "+------+", + // + let want = ["+------+", "| q |", "+------+", "| true |", "+------+"]; + assert_batches_eq!(want, &actual); + }; + } + + quantile_test!(ctx, column = "c2", quantile = 0.1, actual = 1.0); + quantile_test!(ctx, column = "c2", quantile = 0.5, actual = 3.0); + quantile_test!(ctx, column = "c2", quantile = 0.9, actual = 5.0); + //////////////////////////////////// + quantile_test!(ctx, column = "c3", quantile = 0.1, actual = -95.3); + quantile_test!(ctx, column = "c3", quantile = 0.5, actual = 15.5); + quantile_test!(ctx, column = "c3", quantile = 0.9, actual = 102.0); + //////////////////////////////////// + quantile_test!(ctx, column = "c4", quantile = 0.1, actual = -22925.0); + quantile_test!(ctx, column = "c4", quantile = 0.5, actual = 4599.0); + quantile_test!(ctx, column = "c4", quantile = 0.9, actual = 25334.0); + //////////////////////////////////// + quantile_test!(ctx, column = "c5", quantile = 0.1, actual = -1882606710.0); + quantile_test!(ctx, column = "c5", quantile = 0.5, actual = 377164262.0); + quantile_test!(ctx, column = "c5", quantile = 0.9, actual = 1991374996.0); + //////////////////////////////////// + quantile_test!(ctx, column = "c6", quantile = 0.1, actual = -7.25e18); + quantile_test!(ctx, column = "c6", quantile = 0.5, actual = 1.13e18); + quantile_test!(ctx, column = "c6", quantile = 0.9, actual = 7.37e18); + //////////////////////////////////// + quantile_test!(ctx, column = "c7", quantile = 0.1, actual = 18.9); + quantile_test!(ctx, column = "c7", quantile = 0.5, actual = 134.0); + quantile_test!(ctx, column = "c7", quantile = 0.9, actual = 231.0); + //////////////////////////////////// + quantile_test!(ctx, column = "c8", quantile = 0.1, actual = 2671.0); + quantile_test!(ctx, column = "c8", quantile = 0.5, actual = 30634.0); + quantile_test!(ctx, column = "c8", quantile = 0.9, actual = 57518.0); + //////////////////////////////////// + quantile_test!(ctx, column = "c9", quantile = 0.1, actual = 472608672.0); + quantile_test!(ctx, column = "c9", quantile = 0.5, actual = 2365817608.0); + quantile_test!(ctx, column = "c9", quantile = 0.9, actual = 3776538487.0); + //////////////////////////////////// + quantile_test!(ctx, column = "c10", quantile = 0.1, actual = 1.83e18); + quantile_test!(ctx, column = "c10", quantile = 0.5, actual = 9.30e18); + quantile_test!(ctx, column = "c10", quantile = 0.9, actual = 1.61e19); + //////////////////////////////////// + quantile_test!(ctx, column = "c11", quantile = 0.1, actual = 0.109); + quantile_test!(ctx, column = "c11", quantile = 0.5, actual = 0.491); + quantile_test!(ctx, column = "c11", quantile = 0.9, actual = 0.834); + + Ok(()) +} + #[tokio::test] async fn query_count_without_from() -> Result<()> { let mut ctx = ExecutionContext::new(); From 0cbacd12d50316e59e69a14e154e9920970da6a2 Mon Sep 17 00:00:00 2001 From: Dom Date: Sat, 8 Jan 2022 11:03:55 +0000 Subject: [PATCH 03/10] feat: approx_quantile dataframe function Adds the approx_quantile() dataframe function, and exports it in the prelude. --- datafusion/src/logical_plan/expr.rs | 9 +++++++++ datafusion/src/logical_plan/mod.rs | 14 +++++++------- datafusion/src/prelude.rs | 12 ++++++------ datafusion/tests/dataframe_functions.rs | 20 ++++++++++++++++++++ 4 files changed, 42 insertions(+), 13 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 98c296939bc5..5c47cdb3e951 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1647,6 +1647,15 @@ pub fn approx_distinct(expr: Expr) -> Expr { } } +/// Calculate an approximation of the specified `quantile` for `expr`. +pub fn approx_quantile(expr: Expr, quantile: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxQuantile, + distinct: false, + args: vec![expr, quantile], + } +} + // TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many // varying arity functions /// Create an convenience function representing a unary scalar function diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 56fec3cf1a0c..058f714d3116 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -36,13 +36,13 @@ pub use builder::{ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr, - bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr, - combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, - create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, - initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, - max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, - regexp_match, regexp_replace, repeat, replace, replace_col, reverse, + abs, acos, and, approx_distinct, approx_quantile, array, ascii, asin, atan, avg, + binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, + columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, + create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, + floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, + lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, + or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index abc75829ea17..ab503761b2f4 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -30,10 +30,10 @@ pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::execution::options::AvroReadOptions; pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ - array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, - count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length, - lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, - sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, - Column, JoinType, Partitioning, + approx_quantile, array, ascii, avg, bit_length, btrim, character_length, chr, col, + concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list, + initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length, + random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, + sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, + translate, trim, upper, Column, JoinType, Partitioning, }; diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index b8efc9815636..02a3d3aa3547 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -153,6 +153,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_approx_quantile() -> Result<()> { + let expr = approx_quantile(col("b"), lit(0.5)); + + let expected = vec![ + "+-------------------------------------+", + "| APPROXQUANTILE(test.b,Float64(0.5)) |", + "+-------------------------------------+", + "| 10 |", + "+-------------------------------------+", + ]; + + let df = create_test_table()?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + + Ok(()) +} + #[tokio::test] async fn test_fn_character_length() -> Result<()> { let expr = character_length(col("a")); From c41517829ab0badb3185e74f46775c5442e709f9 Mon Sep 17 00:00:00 2001 From: Dom Date: Mon, 10 Jan 2022 14:15:01 +0000 Subject: [PATCH 04/10] refactor: bastilla approx_quantile support Adds bastilla wire encoding for approx_quantile. Adding support for this required modifying the AggregateExprNode proto message to support propigating multiple LogicalExprNode aggregate arguments - all the existing aggregations take a single argument, so this wasn't needed before. This commit adds "repeated" to the expr field, which I believe is backwards compatible as described here: https://developers.google.com/protocol-buffers/docs/proto3#updating Specifically, adding "repeated" to an existing message field: "For ... message fields, optional is compatible with repeated" No existing tests needed fixing, and a new roundtrip test is included that covers the change to allow multiple expr. --- ballista/rust/core/proto/ballista.proto | 3 ++- .../rust/core/src/serde/logical_plan/from_proto.rs | 6 +++++- ballista/rust/core/src/serde/logical_plan/mod.rs | 14 ++++++++++++++ .../rust/core/src/serde/logical_plan/to_proto.rs | 14 ++++++++++---- ballista/rust/core/src/serde/mod.rs | 3 +++ 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 15a7342d7b14..82bae8b2bd49 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -176,11 +176,12 @@ enum AggregateFunction { STDDEV=11; STDDEV_POP=12; CORRELATION=13; + APPROX_QUANTILE = 14; } message AggregateExprNode { AggregateFunction aggr_function = 1; - LogicalExprNode expr = 2; + repeated LogicalExprNode expr = 2; } enum BuiltInWindowFunction { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 568485591425..044f823251a8 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1065,7 +1065,11 @@ impl TryInto for &protobuf::LogicalExprNode { Ok(Expr::AggregateFunction { fun, - args: vec![parse_required_expr(&expr.expr)?], + args: expr + .expr + .iter() + .map(|e| e.try_into()) + .collect::, _>>()?, distinct: false, //TODO }) } diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index c09b8a57d4aa..8851d51b04b9 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -38,6 +38,7 @@ mod roundtrip_tests { scalar::ScalarValue, sql::parser::FileType, }; + use datafusion::{logical_plan::Repartition, physical_plan::aggregates}; use protobuf::arrow_type; use std::{convert::TryInto, sync::Arc}; @@ -1001,4 +1002,17 @@ mod roundtrip_tests { Ok(()) } + + #[test] + fn roundtrip_approx_quantile() -> Result<()> { + let test_expr = Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxQuantile, + args: vec![col("bananas"), lit(0.42)], + distinct: false, + }; + + roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr); + + Ok(()) + } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index eb5d8102de42..a473c4bcd947 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1074,6 +1074,9 @@ impl TryInto for &Expr { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct } + AggregateFunction::ApproxQuantile => { + protobuf::AggregateFunction::ApproxQuantile + } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, @@ -1099,11 +1102,13 @@ impl TryInto for &Expr { } }; - let arg = &args[0]; - let aggregate_expr = Box::new(protobuf::AggregateExprNode { + let aggregate_expr = protobuf::AggregateExprNode { aggr_function: aggr_function.into(), - expr: Some(Box::new(arg.try_into()?)), - }); + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + }; Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::AggregateExpr(aggregate_expr)), }) @@ -1334,6 +1339,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, + AggregateFunction::ApproxQuantile => Self::ApproxQuantile, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 4026273a9eb7..ac9b73d1debf 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -129,6 +129,9 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation, + protobuf::AggregateFunction::ApproxQuantile => { + AggregateFunction::ApproxQuantile + } } } } From 85af3433c460cb76eae1fb15f37dae48886cbb7c Mon Sep 17 00:00:00 2001 From: Dom Date: Tue, 11 Jan 2022 19:49:15 +0000 Subject: [PATCH 05/10] refactor: use input type as return type Casts the calculated quantile value to the same type as the input data. --- datafusion/src/physical_plan/aggregates.rs | 4 +-- .../expressions/approx_quantile.rs | 34 ++++++++++++++----- datafusion/tests/sql/aggregates.rs | 2 +- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 4bcc413fe239..28c43fc27231 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -160,7 +160,7 @@ pub fn return_type( coerced_data_types[0].clone(), true, )))), - AggregateFunction::ApproxQuantile => Ok(DataType::Float64), + AggregateFunction::ApproxQuantile => Ok(coerced_data_types[0].clone()), } } @@ -561,7 +561,7 @@ mod tests { assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( - Field::new("c1", DataType::Float64, false), + Field::new("c1", data_type.clone(), false), result_agg_phy_exprs.field().unwrap() ); } diff --git a/datafusion/src/physical_plan/expressions/approx_quantile.rs b/datafusion/src/physical_plan/expressions/approx_quantile.rs index 211fae592ad5..dc57fa08e27d 100644 --- a/datafusion/src/physical_plan/expressions/approx_quantile.rs +++ b/datafusion/src/physical_plan/expressions/approx_quantile.rs @@ -105,7 +105,7 @@ impl AggregateExpr for ApproxQuantile { } fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, false)) + Ok(Field::new(&self.name, self.input_data_type.clone(), false)) } /// See [`TDigest::to_scalar_state()`] for a description of the serialised @@ -151,7 +151,9 @@ impl AggregateExpr for ApproxQuantile { fn create_accumulator(&self) -> Result> { let accumulator: Box = match &self.input_data_type { - DataType::UInt8 + t + @ + (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 @@ -160,8 +162,8 @@ impl AggregateExpr for ApproxQuantile { | DataType::Int32 | DataType::Int64 | DataType::Float32 - | DataType::Float64 => { - Box::new(ApproxQuantileAccumulator::new(self.quantile)) + | DataType::Float64) => { + Box::new(ApproxQuantileAccumulator::new(self.quantile, t.clone())) } other => { return Err(DataFusionError::NotImplemented(format!( @@ -182,13 +184,15 @@ impl AggregateExpr for ApproxQuantile { pub struct ApproxQuantileAccumulator { digest: TDigest, quantile: f64, + return_type: DataType, } impl ApproxQuantileAccumulator { - pub fn new(quantile: f64) -> Self { + pub fn new(quantile: f64, return_type: DataType) -> Self { Self { digest: TDigest::new(100), quantile, + return_type, } } } @@ -283,9 +287,23 @@ impl Accumulator for ApproxQuantileAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::Float64(Some( - self.digest.estimate_quantile(self.quantile), - ))) + let q = self.digest.estimate_quantile(self.quantile); + + // These acceptable return types MUST match the validation in + // ApproxQuantile::create_accumulator. + Ok(match &self.return_type { + DataType::Int8 => ScalarValue::Int8(Some(q as i8)), + DataType::Int16 => ScalarValue::Int16(Some(q as i16)), + DataType::Int32 => ScalarValue::Int32(Some(q as i32)), + DataType::Int64 => ScalarValue::Int64(Some(q as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), + DataType::Float32 => ScalarValue::Float32(Some(q as f32)), + DataType::Float64 => ScalarValue::Float64(Some(q as f64)), + v => unreachable!("unexpected return type {:?}", v), + }) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 94dc5c4ad45a..58adc9164c1b 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -386,7 +386,7 @@ async fn csv_query_approx_quantile() -> Result<()> { // within 5% of the $actual quantile value. macro_rules! quantile_test { ($ctx:ident, column=$column:literal, quantile=$quantile:literal, actual=$actual:literal) => { - let sql = format!("SELECT (ABS(1 - approx_quantile({}, {}) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual); + let sql = format!("SELECT (ABS(1 - CAST(approx_quantile({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual); let actual = execute_to_batches(&mut ctx, &sql).await; // // "+------+", From e8f8e3f86315d6b9b9dbc200625cd7d5a3945f59 Mon Sep 17 00:00:00 2001 From: Dom Dwyer Date: Wed, 26 Jan 2022 10:53:10 +0000 Subject: [PATCH 06/10] fixup! refactor: bastilla approx_quantile support --- ballista/rust/core/src/serde/logical_plan/mod.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 8851d51b04b9..1bbafbe78b3a 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -24,21 +24,18 @@ mod roundtrip_tests { use super::super::{super::error::Result, protobuf}; use crate::error::BallistaError; use core::panic; - use datafusion::arrow::datatypes::UnionMode; - use datafusion::logical_plan::Repartition; use datafusion::{ - arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}, + arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}, datasource::object_store::local::LocalFileSystem, logical_plan::{ col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder, - Partitioning, ToDFSchema, + Partitioning, Repartition, ToDFSchema, }, - physical_plan::functions::BuiltinScalarFunction::Sqrt, + physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt}, prelude::*, scalar::ScalarValue, sql::parser::FileType, }; - use datafusion::{logical_plan::Repartition, physical_plan::aggregates}; use protobuf::arrow_type; use std::{convert::TryInto, sync::Arc}; From faa8094900e3d5107221af1d32e8741ed13389bc Mon Sep 17 00:00:00 2001 From: Dom Dwyer Date: Thu, 27 Jan 2022 20:27:13 +0000 Subject: [PATCH 07/10] refactor: rebase onto main --- .../expressions/approx_quantile.rs | 49 ++++++++----------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/approx_quantile.rs b/datafusion/src/physical_plan/expressions/approx_quantile.rs index dc57fa08e27d..ba497c93e993 100644 --- a/datafusion/src/physical_plan/expressions/approx_quantile.rs +++ b/datafusion/src/physical_plan/expressions/approx_quantile.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{any::Any, sync::Arc}; +use std::{any::Any, iter, sync::Arc}; use arrow::{ array::{ @@ -151,9 +151,7 @@ impl AggregateExpr for ApproxQuantile { fn create_accumulator(&self) -> Result> { let accumulator: Box = match &self.input_data_type { - t - @ - (DataType::UInt8 + t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 @@ -202,17 +200,6 @@ impl Accumulator for ApproxQuantileAccumulator { Ok(self.digest.to_scalar_state()) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - debug_assert_eq!( - values.len(), - 1, - "invalid number of values in quantile update" - ); - - self.digest = self.digest.merge_unsorted([values[0].clone()])?; - Ok(()) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { debug_assert_eq!( values.len(), @@ -273,19 +260,6 @@ impl Accumulator for ApproxQuantileAccumulator { Ok(()) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - debug_assert_eq!( - states.len(), - 6, - "invalid number of state fields for quantile accumulator" - ); - - let other = TDigest::from_scalar_state(states); - self.digest = TDigest::merge_digests(&[self.digest.clone(), other]); - - Ok(()) - } - fn evaluate(&self) -> Result { let q = self.digest.estimate_quantile(self.quantile); @@ -307,6 +281,23 @@ impl Accumulator for ApproxQuantileAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - todo!() + if states.is_empty() { + return Ok(()); + }; + + let states = (0..states[0].len()) + .map(|index| { + states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>() + .map(|state| TDigest::from_scalar_state(&state)) + }) + .chain(iter::once(Ok(self.digest.clone()))) + .collect::>>()?; + + self.digest = TDigest::merge_digests(&states); + + Ok(()) } } From 03a5eff338565b66007349d303f5d0c29ba662ef Mon Sep 17 00:00:00 2001 From: Dom Dwyer Date: Thu, 27 Jan 2022 20:27:22 +0000 Subject: [PATCH 08/10] refactor: validate quantile value Ensures the quantile values is between 0 and 1, emitting a plan error if not. --- datafusion/src/physical_plan/aggregates.rs | 26 ++++++++++++++++++- .../expressions/approx_quantile.rs | 8 ++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 28c43fc27231..620e344c2744 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -547,7 +547,7 @@ mod tests { Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), ]; let result_agg_phy_exprs = create_aggregate_expr( &AggregateFunction::ApproxQuantile, @@ -567,6 +567,30 @@ mod tests { } } + #[test] + fn test_agg_approx_quantile_invalid_phy_expr() { + for data_type in NUMERICS { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + ), + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), + ]; + let err = create_aggregate_expr( + &AggregateFunction::ApproxQuantile, + false, + &input_phy_exprs[..], + &input_schema, + "c1", + ) + .expect_err("should fail due to invalid quantile"); + + assert!(matches!(err, DataFusionError::Plan(_))); + } + } + #[test] fn test_min_max_expr() -> Result<()> { let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; diff --git a/datafusion/src/physical_plan/expressions/approx_quantile.rs b/datafusion/src/physical_plan/expressions/approx_quantile.rs index ba497c93e993..95dbd81ac42d 100644 --- a/datafusion/src/physical_plan/expressions/approx_quantile.rs +++ b/datafusion/src/physical_plan/expressions/approx_quantile.rs @@ -89,6 +89,14 @@ impl ApproxQuantile { ))) }; + // Ensure the quantile is between 0 and 1. + if !(0.0..=1.0).contains(&quantile) { + return Err(DataFusionError::Plan(format!( + "Quantile value must be between 0.0 and 1.0, {} is invalid", + quantile + ))); + } + Ok(Self { name: name.into(), input_data_type, From c216f48158ad48a7813e60d2675b9a01297a3e7c Mon Sep 17 00:00:00 2001 From: Dom Dwyer Date: Sat, 29 Jan 2022 11:54:39 +0000 Subject: [PATCH 09/10] refactor: rename to approx_percentile_cont --- ballista/rust/core/proto/ballista.proto | 2 +- .../rust/core/src/serde/logical_plan/mod.rs | 4 +- .../core/src/serde/logical_plan/to_proto.rs | 6 +- ballista/rust/core/src/serde/mod.rs | 4 +- datafusion/src/logical_plan/expr.rs | 8 +- datafusion/src/logical_plan/mod.rs | 4 +- datafusion/src/physical_plan/aggregates.rs | 39 +++++----- .../coercion_rule/aggregate_rule.rs | 20 +++-- ..._quantile.rs => approx_percentile_cont.rs} | 62 +++++++-------- .../src/physical_plan/expressions/mod.rs | 6 +- datafusion/src/physical_plan/tdigest/mod.rs | 2 +- datafusion/src/prelude.rs | 4 +- datafusion/tests/dataframe_functions.rs | 14 ++-- datafusion/tests/sql/aggregates.rs | 76 +++++++++---------- 14 files changed, 129 insertions(+), 122 deletions(-) rename datafusion/src/physical_plan/expressions/{approx_quantile.rs => approx_percentile_cont.rs} (83%) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 82bae8b2bd49..fb006e532ff3 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -176,7 +176,7 @@ enum AggregateFunction { STDDEV=11; STDDEV_POP=12; CORRELATION=13; - APPROX_QUANTILE = 14; + APPROX_PERCENTILE_CONT = 14; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 1bbafbe78b3a..c00e3e42912a 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -1001,9 +1001,9 @@ mod roundtrip_tests { } #[test] - fn roundtrip_approx_quantile() -> Result<()> { + fn roundtrip_approx_percentile_cont() -> Result<()> { let test_expr = Expr::AggregateFunction { - fun: aggregates::AggregateFunction::ApproxQuantile, + fun: aggregates::AggregateFunction::ApproxPercentileCont, args: vec![col("bananas"), lit(0.42)], distinct: false, }; diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index a473c4bcd947..4b13ce577cfb 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1074,8 +1074,8 @@ impl TryInto for &Expr { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct } - AggregateFunction::ApproxQuantile => { - protobuf::AggregateFunction::ApproxQuantile + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, @@ -1339,7 +1339,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, - AggregateFunction::ApproxQuantile => Self::ApproxQuantile, + AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index ac9b73d1debf..64a60dc4da5d 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -129,8 +129,8 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation, - protobuf::AggregateFunction::ApproxQuantile => { - AggregateFunction::ApproxQuantile + protobuf::AggregateFunction::ApproxPercentileCont => { + AggregateFunction::ApproxPercentileCont } } } diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 5c47cdb3e951..a1e51e07422e 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1647,12 +1647,12 @@ pub fn approx_distinct(expr: Expr) -> Expr { } } -/// Calculate an approximation of the specified `quantile` for `expr`. -pub fn approx_quantile(expr: Expr, quantile: Expr) -> Expr { +/// Calculate an approximation of the specified `percentile` for `expr`. +pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { Expr::AggregateFunction { - fun: aggregates::AggregateFunction::ApproxQuantile, + fun: aggregates::AggregateFunction::ApproxPercentileCont, distinct: false, - args: vec![expr, quantile], + args: vec![expr, percentile], } } diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 058f714d3116..06c6bf90c790 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -36,8 +36,8 @@ pub use builder::{ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, approx_distinct, approx_quantile, array, ascii, asin, atan, avg, - binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, + abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, + avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 620e344c2744..8b6a5e21caac 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -80,8 +80,8 @@ pub enum AggregateFunction { CovariancePop, /// Correlation Correlation, - /// Approximate quantile function - ApproxQuantile, + /// Approximate continuous percentile function + ApproxPercentileCont, } impl fmt::Display for AggregateFunction { @@ -112,7 +112,7 @@ impl FromStr for AggregateFunction { "covar_samp" => AggregateFunction::Covariance, "covar_pop" => AggregateFunction::CovariancePop, "corr" => AggregateFunction::Correlation, - "approx_quantile" => AggregateFunction::ApproxQuantile, + "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -160,7 +160,7 @@ pub fn return_type( coerced_data_types[0].clone(), true, )))), - AggregateFunction::ApproxQuantile => Ok(coerced_data_types[0].clone()), + AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), } } @@ -335,17 +335,18 @@ pub fn create_aggregate_expr( "CORR(DISTINCT) aggregations are not available".to_string(), )); } - (AggregateFunction::ApproxQuantile, false) => { - Arc::new(expressions::ApproxQuantile::new( - // Pass in the desired quantile expr + (AggregateFunction::ApproxPercentileCont, false) => { + Arc::new(expressions::ApproxPercentileCont::new( + // Pass in the desired percentile expr coerced_phy_exprs, name, return_type, )?) } - (AggregateFunction::ApproxQuantile, true) => { + (AggregateFunction::ApproxPercentileCont, true) => { return Err(DataFusionError::NotImplemented( - "approx_quantile(DISTINCT) aggregations are not available".to_string(), + "approx_percentile_cont(DISTINCT) aggregations are not available" + .to_string(), )); } }) @@ -406,8 +407,8 @@ pub fn signature(fun: &AggregateFunction) -> Signature { AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - AggregateFunction::ApproxQuantile => Signature::one_of( - // Accept any numeric value paired with a float64 quantile + AggregateFunction::ApproxPercentileCont => Signature::one_of( + // Accept any numeric value paired with a float64 percentile NUMERICS .iter() .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) @@ -421,8 +422,8 @@ pub fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::physical_plan::expressions::{ - ApproxDistinct, ApproxQuantile, ArrayAgg, Avg, Correlation, Count, Covariance, - DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, + ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count, + Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; use crate::{error::Result, scalar::ScalarValue}; @@ -539,7 +540,7 @@ mod tests { } #[test] - fn test_agg_approx_quantile_phy_expr() { + fn test_agg_approx_percentile_phy_expr() { for data_type in NUMERICS { let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]); @@ -550,7 +551,7 @@ mod tests { Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), ]; let result_agg_phy_exprs = create_aggregate_expr( - &AggregateFunction::ApproxQuantile, + &AggregateFunction::ApproxPercentileCont, false, &input_phy_exprs[..], &input_schema, @@ -558,7 +559,7 @@ mod tests { ) .expect("failed to create aggregate expr"); - assert!(result_agg_phy_exprs.as_any().is::()); + assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( Field::new("c1", data_type.clone(), false), @@ -568,7 +569,7 @@ mod tests { } #[test] - fn test_agg_approx_quantile_invalid_phy_expr() { + fn test_agg_approx_percentile_invalid_phy_expr() { for data_type in NUMERICS { let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]); @@ -579,13 +580,13 @@ mod tests { Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), ]; let err = create_aggregate_expr( - &AggregateFunction::ApproxQuantile, + &AggregateFunction::ApproxPercentileCont, false, &input_phy_exprs[..], &input_schema, "c1", ) - .expect_err("should fail due to invalid quantile"); + .expect_err("should fail due to invalid percentile"); assert!(matches!(err, DataFusionError::Plan(_))); } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index 9e5b3957c7a7..bae2de74c7b7 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -28,7 +28,7 @@ use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; use crate::{ arrow::datatypes::Schema, - physical_plan::expressions::is_approx_quantile_supported_arg_type, + physical_plan::expressions::is_approx_percentile_cont_supported_arg_type, }; use arrow::datatypes::DataType; use std::ops::Deref; @@ -139,8 +139,8 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::ApproxQuantile => { - if !is_approx_quantile_supported_arg_type(&input_types[0]) { + AggregateFunction::ApproxPercentileCont => { + if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( "The function {:?} does not support inputs of type {:?}.", agg_fun, input_types[0] @@ -148,7 +148,7 @@ pub(crate) fn coerce_types( } if !matches!(input_types[1], DataType::Float64) { return Err(DataFusionError::Plan(format!( - "The quantile argument for {:?} must be Float64, not {:?}.", + "The percentile argument for {:?} must be Float64, not {:?}.", agg_fun, input_types[1] ))); } @@ -324,7 +324,7 @@ mod tests { } } - // ApproxQuantile input types + // ApproxPercentileCont input types let input_types = vec![ vec![DataType::Int8, DataType::Float64], vec![DataType::Int16, DataType::Float64], @@ -338,9 +338,13 @@ mod tests { vec![DataType::Float64, DataType::Float64], ]; for input_type in &input_types { - let signature = aggregates::signature(&AggregateFunction::ApproxQuantile); - let result = - coerce_types(&AggregateFunction::ApproxQuantile, input_type, &signature); + let signature = + aggregates::signature(&AggregateFunction::ApproxPercentileCont); + let result = coerce_types( + &AggregateFunction::ApproxPercentileCont, + input_type, + &signature, + ); assert_eq!(*input_type, result.unwrap()); } } diff --git a/datafusion/src/physical_plan/expressions/approx_quantile.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs similarity index 83% rename from datafusion/src/physical_plan/expressions/approx_quantile.rs rename to datafusion/src/physical_plan/expressions/approx_percentile_cont.rs index 95dbd81ac42d..f1632ac45d04 100644 --- a/datafusion/src/physical_plan/expressions/approx_quantile.rs +++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs @@ -35,9 +35,9 @@ use crate::error::Result; use super::{format_state_name, Literal}; -/// Return `true` if `arg_type` is of a [`DataType`] that the [`ApproxQuantile`] -/// aggregation can operate on. -pub fn is_approx_quantile_supported_arg_type(arg_type: &DataType) -> bool { +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`ApproxPercentileCont`] aggregation can operate on. +pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, DataType::UInt8 @@ -53,47 +53,47 @@ pub fn is_approx_quantile_supported_arg_type(arg_type: &DataType) -> bool { ) } -/// APPROX_QUANTILE aggregate expression +/// APPROX_PERCENTILE_CONT aggregate expression #[derive(Debug)] -pub struct ApproxQuantile { +pub struct ApproxPercentileCont { name: String, input_data_type: DataType, expr: Arc, - quantile: f64, + percentile: f64, } -impl ApproxQuantile { - /// Create a new ApproxQuantile aggregate function. +impl ApproxPercentileCont { + /// Create a new [`ApproxPercentileCont`] aggregate function. pub fn new( expr: Vec>, name: impl Into, input_data_type: DataType, ) -> Result { - // Arguments should be [ColumnExpr, DesiredQuantileLiteral] + // Arguments should be [ColumnExpr, DesiredPercentileLiteral] debug_assert_eq!(expr.len(), 2); - // Extract the desired quantile literal + // Extract the desired percentile literal let lit = expr[1] .as_any() .downcast_ref::() .ok_or(DataFusionError::Internal( - "desired quantile argument must be float literal".to_string(), + "desired percentile argument must be float literal".to_string(), ))? .value(); - let quantile = match lit { + let percentile = match lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q as f64, got => return Err(DataFusionError::NotImplemented(format!( - "Quantile value for 'APPROX_QUANTILE' must be Float32 or Float64 literal (got data type {})", + "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", got ))) }; - // Ensure the quantile is between 0 and 1. - if !(0.0..=1.0).contains(&quantile) { + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { return Err(DataFusionError::Plan(format!( - "Quantile value must be between 0.0 and 1.0, {} is invalid", - quantile + "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid", + percentile ))); } @@ -102,12 +102,12 @@ impl ApproxQuantile { input_data_type, // The physical expr to evaluate during accumulation expr: expr[0].clone(), - quantile, + percentile, }) } } -impl AggregateExpr for ApproxQuantile { +impl AggregateExpr for ApproxPercentileCont { fn as_any(&self) -> &dyn Any { self } @@ -169,11 +169,11 @@ impl AggregateExpr for ApproxQuantile { | DataType::Int64 | DataType::Float32 | DataType::Float64) => { - Box::new(ApproxQuantileAccumulator::new(self.quantile, t.clone())) + Box::new(ApproxPercentileAccumulator::new(self.percentile, t.clone())) } other => { return Err(DataFusionError::NotImplemented(format!( - "Support for 'APPROX_QUANTILE' for data type {} is not implemented", + "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented", other ))) } @@ -187,23 +187,23 @@ impl AggregateExpr for ApproxQuantile { } #[derive(Debug)] -pub struct ApproxQuantileAccumulator { +pub struct ApproxPercentileAccumulator { digest: TDigest, - quantile: f64, + percentile: f64, return_type: DataType, } -impl ApproxQuantileAccumulator { - pub fn new(quantile: f64, return_type: DataType) -> Self { +impl ApproxPercentileAccumulator { + pub fn new(percentile: f64, return_type: DataType) -> Self { Self { digest: TDigest::new(100), - quantile, + percentile, return_type, } } } -impl Accumulator for ApproxQuantileAccumulator { +impl Accumulator for ApproxPercentileAccumulator { fn state(&self) -> Result> { Ok(self.digest.to_scalar_state()) } @@ -212,7 +212,7 @@ impl Accumulator for ApproxQuantileAccumulator { debug_assert_eq!( values.len(), 1, - "invalid number of values in batch quantile update" + "invalid number of values in batch percentile update" ); let values = &values[0]; @@ -259,7 +259,7 @@ impl Accumulator for ApproxQuantileAccumulator { } e => { return Err(DataFusionError::Internal(format!( - "APPROX_QUANTILE is not expected to receive the type {:?}", + "APPROX_PERCENTILE_CONT is not expected to receive the type {:?}", e ))); } @@ -269,10 +269,10 @@ impl Accumulator for ApproxQuantileAccumulator { } fn evaluate(&self) -> Result { - let q = self.digest.estimate_quantile(self.quantile); + let q = self.digest.estimate_quantile(self.percentile); // These acceptable return types MUST match the validation in - // ApproxQuantile::create_accumulator. + // ApproxPercentile::create_accumulator. Ok(match &self.return_type { DataType::Int8 => ScalarValue::Int8(Some(q as i8)), DataType::Int16 => ScalarValue::Int16(Some(q as i16)), diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index ce4c0e7cdaf2..9344fbd6b1bc 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -26,7 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; mod approx_distinct; -mod approx_quantile; +mod approx_percentile_cont; mod array_agg; mod average; #[macro_use] @@ -65,7 +65,9 @@ pub mod helpers { } pub use approx_distinct::ApproxDistinct; -pub use approx_quantile::{is_approx_quantile_supported_arg_type, ApproxQuantile}; +pub use approx_percentile_cont::{ + is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont, +}; pub use array_agg::ArrayAgg; pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; diff --git a/datafusion/src/physical_plan/tdigest/mod.rs b/datafusion/src/physical_plan/tdigest/mod.rs index 86d84f9f96fa..cd7cdf0499fe 100644 --- a/datafusion/src/physical_plan/tdigest/mod.rs +++ b/datafusion/src/physical_plan/tdigest/mod.rs @@ -98,7 +98,7 @@ impl TryIntoOrderedF64 for ScalarValue { got => { return Err(DataFusionError::NotImplemented(format!( - "Support for 'APPROX_QUANTILE' for data type {} is not implemented", + "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented", got ))) } diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index ab503761b2f4..0aff006c7896 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -30,8 +30,8 @@ pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::execution::options::AvroReadOptions; pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ - approx_quantile, array, ascii, avg, bit_length, btrim, character_length, chr, col, - concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list, + approx_percentile_cont, array, ascii, avg, bit_length, btrim, character_length, chr, + col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index 02a3d3aa3547..d5118b30d2af 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -154,15 +154,15 @@ async fn test_fn_btrim_with_chars() -> Result<()> { } #[tokio::test] -async fn test_fn_approx_quantile() -> Result<()> { - let expr = approx_quantile(col("b"), lit(0.5)); +async fn test_fn_approx_percentile_cont() -> Result<()> { + let expr = approx_percentile_cont(col("b"), lit(0.5)); let expected = vec![ - "+-------------------------------------+", - "| APPROXQUANTILE(test.b,Float64(0.5)) |", - "+-------------------------------------+", - "| 10 |", - "+-------------------------------------+", + "+-------------------------------------------+", + "| APPROXPERCENTILECONT(test.b,Float64(0.5)) |", + "+-------------------------------------------+", + "| 10 |", + "+-------------------------------------------+", ]; let df = create_test_table()?; diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 58adc9164c1b..736a00318ac7 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -354,8 +354,8 @@ async fn csv_query_approx_count() -> Result<()> { Ok(()) } -// This test executes the APPROX_QUANTILE aggregation against the test data, -// asserting the estimated quantiles are ±5% their actual values. +// This test executes the APPROX_PERCENTILE_CONT aggregation against the test +// data, asserting the estimated quantiles are ±5% their actual values. // // Actual quantiles calculated with: // @@ -378,15 +378,15 @@ async fn csv_query_approx_count() -> Result<()> { // Column `c12` is omitted due to a large relative error (~10%) due to the small // float values. #[tokio::test] -async fn csv_query_approx_quantile() -> Result<()> { +async fn csv_query_approx_percentile_cont() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - // Generate an assertion that the estimated $quantile value for $column is - // within 5% of the $actual quantile value. - macro_rules! quantile_test { - ($ctx:ident, column=$column:literal, quantile=$quantile:literal, actual=$actual:literal) => { - let sql = format!("SELECT (ABS(1 - CAST(approx_quantile({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual); + // Generate an assertion that the estimated $percentile value for $column is + // within 5% of the $actual percentile value. + macro_rules! percentile_test { + ($ctx:ident, column=$column:literal, percentile=$percentile:literal, actual=$actual:literal) => { + let sql = format!("SELECT (ABS(1 - CAST(approx_percentile_cont({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $percentile, $actual); let actual = execute_to_batches(&mut ctx, &sql).await; // // "+------+", @@ -400,45 +400,45 @@ async fn csv_query_approx_quantile() -> Result<()> { }; } - quantile_test!(ctx, column = "c2", quantile = 0.1, actual = 1.0); - quantile_test!(ctx, column = "c2", quantile = 0.5, actual = 3.0); - quantile_test!(ctx, column = "c2", quantile = 0.9, actual = 5.0); + percentile_test!(ctx, column = "c2", percentile = 0.1, actual = 1.0); + percentile_test!(ctx, column = "c2", percentile = 0.5, actual = 3.0); + percentile_test!(ctx, column = "c2", percentile = 0.9, actual = 5.0); //////////////////////////////////// - quantile_test!(ctx, column = "c3", quantile = 0.1, actual = -95.3); - quantile_test!(ctx, column = "c3", quantile = 0.5, actual = 15.5); - quantile_test!(ctx, column = "c3", quantile = 0.9, actual = 102.0); + percentile_test!(ctx, column = "c3", percentile = 0.1, actual = -95.3); + percentile_test!(ctx, column = "c3", percentile = 0.5, actual = 15.5); + percentile_test!(ctx, column = "c3", percentile = 0.9, actual = 102.0); //////////////////////////////////// - quantile_test!(ctx, column = "c4", quantile = 0.1, actual = -22925.0); - quantile_test!(ctx, column = "c4", quantile = 0.5, actual = 4599.0); - quantile_test!(ctx, column = "c4", quantile = 0.9, actual = 25334.0); + percentile_test!(ctx, column = "c4", percentile = 0.1, actual = -22925.0); + percentile_test!(ctx, column = "c4", percentile = 0.5, actual = 4599.0); + percentile_test!(ctx, column = "c4", percentile = 0.9, actual = 25334.0); //////////////////////////////////// - quantile_test!(ctx, column = "c5", quantile = 0.1, actual = -1882606710.0); - quantile_test!(ctx, column = "c5", quantile = 0.5, actual = 377164262.0); - quantile_test!(ctx, column = "c5", quantile = 0.9, actual = 1991374996.0); + percentile_test!(ctx, column = "c5", percentile = 0.1, actual = -1882606710.0); + percentile_test!(ctx, column = "c5", percentile = 0.5, actual = 377164262.0); + percentile_test!(ctx, column = "c5", percentile = 0.9, actual = 1991374996.0); //////////////////////////////////// - quantile_test!(ctx, column = "c6", quantile = 0.1, actual = -7.25e18); - quantile_test!(ctx, column = "c6", quantile = 0.5, actual = 1.13e18); - quantile_test!(ctx, column = "c6", quantile = 0.9, actual = 7.37e18); + percentile_test!(ctx, column = "c6", percentile = 0.1, actual = -7.25e18); + percentile_test!(ctx, column = "c6", percentile = 0.5, actual = 1.13e18); + percentile_test!(ctx, column = "c6", percentile = 0.9, actual = 7.37e18); //////////////////////////////////// - quantile_test!(ctx, column = "c7", quantile = 0.1, actual = 18.9); - quantile_test!(ctx, column = "c7", quantile = 0.5, actual = 134.0); - quantile_test!(ctx, column = "c7", quantile = 0.9, actual = 231.0); + percentile_test!(ctx, column = "c7", percentile = 0.1, actual = 18.9); + percentile_test!(ctx, column = "c7", percentile = 0.5, actual = 134.0); + percentile_test!(ctx, column = "c7", percentile = 0.9, actual = 231.0); //////////////////////////////////// - quantile_test!(ctx, column = "c8", quantile = 0.1, actual = 2671.0); - quantile_test!(ctx, column = "c8", quantile = 0.5, actual = 30634.0); - quantile_test!(ctx, column = "c8", quantile = 0.9, actual = 57518.0); + percentile_test!(ctx, column = "c8", percentile = 0.1, actual = 2671.0); + percentile_test!(ctx, column = "c8", percentile = 0.5, actual = 30634.0); + percentile_test!(ctx, column = "c8", percentile = 0.9, actual = 57518.0); //////////////////////////////////// - quantile_test!(ctx, column = "c9", quantile = 0.1, actual = 472608672.0); - quantile_test!(ctx, column = "c9", quantile = 0.5, actual = 2365817608.0); - quantile_test!(ctx, column = "c9", quantile = 0.9, actual = 3776538487.0); + percentile_test!(ctx, column = "c9", percentile = 0.1, actual = 472608672.0); + percentile_test!(ctx, column = "c9", percentile = 0.5, actual = 2365817608.0); + percentile_test!(ctx, column = "c9", percentile = 0.9, actual = 3776538487.0); //////////////////////////////////// - quantile_test!(ctx, column = "c10", quantile = 0.1, actual = 1.83e18); - quantile_test!(ctx, column = "c10", quantile = 0.5, actual = 9.30e18); - quantile_test!(ctx, column = "c10", quantile = 0.9, actual = 1.61e19); + percentile_test!(ctx, column = "c10", percentile = 0.1, actual = 1.83e18); + percentile_test!(ctx, column = "c10", percentile = 0.5, actual = 9.30e18); + percentile_test!(ctx, column = "c10", percentile = 0.9, actual = 1.61e19); //////////////////////////////////// - quantile_test!(ctx, column = "c11", quantile = 0.1, actual = 0.109); - quantile_test!(ctx, column = "c11", quantile = 0.5, actual = 0.491); - quantile_test!(ctx, column = "c11", quantile = 0.9, actual = 0.834); + percentile_test!(ctx, column = "c11", percentile = 0.1, actual = 0.109); + percentile_test!(ctx, column = "c11", percentile = 0.5, actual = 0.491); + percentile_test!(ctx, column = "c11", percentile = 0.9, actual = 0.834); Ok(()) } From 3612493c5521031e2bac93f15ae9ad7dc780abc2 Mon Sep 17 00:00:00 2001 From: Dom Dwyer Date: Mon, 31 Jan 2022 11:10:16 +0000 Subject: [PATCH 10/10] refactor: clippy lints --- .../physical_plan/expressions/approx_percentile_cont.rs | 8 +++++--- datafusion/src/physical_plan/tdigest/mod.rs | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs index f1632ac45d04..cba30ee481ab 100644 --- a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs +++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs @@ -76,9 +76,11 @@ impl ApproxPercentileCont { let lit = expr[1] .as_any() .downcast_ref::() - .ok_or(DataFusionError::Internal( - "desired percentile argument must be float literal".to_string(), - ))? + .ok_or_else(|| { + DataFusionError::Internal( + "desired percentile argument must be float literal".to_string(), + ) + })? .value(); let percentile = match lit { ScalarValue::Float32(Some(q)) => *q as f64, diff --git a/datafusion/src/physical_plan/tdigest/mod.rs b/datafusion/src/physical_plan/tdigest/mod.rs index cd7cdf0499fe..6780adc84cd1 100644 --- a/datafusion/src/physical_plan/tdigest/mod.rs +++ b/datafusion/src/physical_plan/tdigest/mod.rs @@ -266,7 +266,7 @@ impl TDigest { } fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat]) -> TDigest { - debug_assert!(is_sorted(&sorted_values), "unsorted input to TDigest"); + debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest"); if sorted_values.is_empty() { return self.clone(); @@ -275,8 +275,8 @@ impl TDigest { let mut result = TDigest::new(self.max_size()); result.count = OrderedFloat::from(self.count() + (sorted_values.len() as f64)); - let maybe_min = OrderedFloat::from(*sorted_values.first().unwrap()); - let maybe_max = OrderedFloat::from(*sorted_values.last().unwrap()); + let maybe_min = *sorted_values.first().unwrap(); + let maybe_max = *sorted_values.last().unwrap(); if self.count() > 0.0 { result.min = std::cmp::min(self.min, maybe_min); @@ -418,7 +418,7 @@ impl TDigest { let mut max = OrderedFloat::from(std::f64::NEG_INFINITY); let mut start: usize = 0; - for digest in digests.into_iter() { + for digest in digests.iter() { starts.push(start); let curr_count: f64 = digest.count();