Skip to content

Commit

Permalink
#2004 approx percentile with weight (#2031)
Browse files Browse the repository at this point in the history
* Add new aggregate function in multiple places

* implement new aggregator and test case

* rename to SessionContext (follow latest change on master branch)

* fix clippy

* fix clippy

* fix error message and add test cases for error ones
  • Loading branch information
jychen7 authored Mar 24, 2022
1 parent d3c45c2 commit e8ed603
Show file tree
Hide file tree
Showing 12 changed files with 438 additions and 90 deletions.
5 changes: 5 additions & 0 deletions datafusion-expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub enum AggregateFunction {
Correlation,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
ApproxPercentileContWithWeight,
/// ApproxMedian
ApproxMedian,
}
Expand Down Expand Up @@ -86,6 +88,9 @@ impl FromStr for AggregateFunction {
"covar_pop" => AggregateFunction::CovariancePop,
"corr" => AggregateFunction::Correlation,
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
"approx_percentile_cont_with_weight" => {
AggregateFunction::ApproxPercentileContWithWeight
}
"approx_median" => AggregateFunction::ApproxMedian,
_ => {
return Err(DataFusionError::Plan(format!(
Expand Down
13 changes: 13 additions & 0 deletions datafusion-expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,19 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
}
}

/// Calculate an approximation of the specified `percentile` for `expr` and `weight_expr`.
pub fn approx_percentile_cont_with_weight(
expr: Expr,
weight_expr: Expr,
percentile: Expr,
) -> Expr {
Expr::AggregateFunction {
fun: aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
distinct: false,
args: vec![expr, weight_expr, percentile],
}
}

// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
// varying arity functions
/// Create an convenience function representing a unary scalar function
Expand Down
21 changes: 21 additions & 0 deletions datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxPercentileContWithWeight => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
return Err(DataFusionError::Plan(format!(
"The weight argument for {:?} does not support inputs of type {:?}.",
agg_fun, input_types[1]
)));
}
if !matches!(input_types[2], DataType::Float64) {
return Err(DataFusionError::Plan(format!(
"The percentile argument for {:?} must be Float64, not {:?}.",
agg_fun, input_types[2]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxMedian => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
Expand Down
156 changes: 108 additions & 48 deletions datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
// under the License.

use super::{format_state_name, Literal};
use crate::{tdigest::TDigest, AggregateExpr, PhysicalExpr};
use crate::tdigest::TryIntoOrderedF64;
use crate::{
tdigest::{TDigest, DEFAULT_MAX_SIZE},
AggregateExpr, PhysicalExpr,
};
use arrow::{
array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Expand All @@ -28,6 +32,7 @@ use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::Accumulator;
use ordered_float::OrderedFloat;
use std::{any::Any, iter, sync::Arc};

/// Return `true` if `arg_type` is of a [`DataType`] that the
Expand Down Expand Up @@ -102,6 +107,30 @@ impl ApproxPercentileCont {
percentile,
})
}

pub(crate) fn create_plain_accumulator(&self) -> Result<ApproxPercentileAccumulator> {
let accumulator: ApproxPercentileAccumulator = match &self.input_data_type {
t @ (DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64) => {
ApproxPercentileAccumulator::new(self.percentile, t.clone())
}
other => {
return Err(DataFusionError::NotImplemented(format!(
"Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented",
other
)))
}
};
Ok(accumulator)
}
}

impl AggregateExpr for ApproxPercentileCont {
Expand Down Expand Up @@ -156,27 +185,8 @@ impl AggregateExpr for ApproxPercentileCont {
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let accumulator: Box<dyn Accumulator> = match &self.input_data_type {
t @ (DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64) => {
Box::new(ApproxPercentileAccumulator::new(self.percentile, t.clone()))
}
other => {
return Err(DataFusionError::NotImplemented(format!(
"Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented",
other
)))
}
};
Ok(accumulator)
let accumulator = self.create_plain_accumulator()?;
Ok(Box::new(accumulator))
}

fn name(&self) -> &str {
Expand All @@ -194,75 +204,125 @@ pub struct ApproxPercentileAccumulator {
impl ApproxPercentileAccumulator {
pub fn new(percentile: f64, return_type: DataType) -> Self {
Self {
digest: TDigest::new(100),
digest: TDigest::new(DEFAULT_MAX_SIZE),
percentile,
return_type,
}
}
}

impl Accumulator for ApproxPercentileAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(self.digest.to_scalar_state())
pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
self.digest = TDigest::merge_digests(digests);
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
debug_assert_eq!(
values.len(),
1,
"invalid number of values in batch percentile update"
);
let values = &values[0];

self.digest = match values.data_type() {
pub(crate) fn convert_to_ordered_float(
values: &ArrayRef,
) -> Result<Vec<OrderedFloat<f64>>> {
match values.data_type() {
DataType::Float64 => {
let array = values.as_any().downcast_ref::<Float64Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Float32 => {
let array = values.as_any().downcast_ref::<Float32Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int64 => {
let array = values.as_any().downcast_ref::<Int64Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int32 => {
let array = values.as_any().downcast_ref::<Int32Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int16 => {
let array = values.as_any().downcast_ref::<Int16Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::Int8 => {
let array = values.as_any().downcast_ref::<Int8Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt64 => {
let array = values.as_any().downcast_ref::<UInt64Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt32 => {
let array = values.as_any().downcast_ref::<UInt32Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt16 => {
let array = values.as_any().downcast_ref::<UInt16Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
DataType::UInt8 => {
let array = values.as_any().downcast_ref::<UInt8Array>().unwrap();
self.digest.merge_unsorted(array.values().iter().cloned())?
Ok(array
.values()
.iter()
.filter_map(|v| v.try_as_f64().transpose())
.collect::<Result<Vec<_>>>()?)
}
e => {
return Err(DataFusionError::Internal(format!(
"APPROX_PERCENTILE_CONT is not expected to receive the type {:?}",
e
)));
}
};
}
}
}

impl Accumulator for ApproxPercentileAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(self.digest.to_scalar_state())
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
debug_assert_eq!(
values.len(),
1,
"invalid number of values in batch percentile update"
);
let values = &values[0];
let unsorted_values =
ApproxPercentileAccumulator::convert_to_ordered_float(values)?;
self.digest = self.digest.merge_unsorted_f64(unsorted_values);
Ok(())
}

Expand Down Expand Up @@ -302,7 +362,7 @@ impl Accumulator for ApproxPercentileAccumulator {
.chain(iter::once(Ok(self.digest.clone())))
.collect::<Result<Vec<_>>>()?;

self.digest = TDigest::merge_digests(&states);
self.merge_digests(&states);

Ok(())
}
Expand Down
Loading

0 comments on commit e8ed603

Please sign in to comment.