diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 15a7342d7b14..fb006e532ff3 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -176,11 +176,12 @@ enum AggregateFunction {
   STDDEV=11;
   STDDEV_POP=12;
   CORRELATION=13;
+  APPROX_PERCENTILE_CONT = 14;
 }
 
 message AggregateExprNode {
   AggregateFunction aggr_function = 1;
-  LogicalExprNode expr = 2;
+  repeated LogicalExprNode expr = 2;
 }
 
 enum BuiltInWindowFunction {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 568485591425..044f823251a8 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -1065,7 +1065,11 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
                 Ok(Expr::AggregateFunction {
                     fun,
-                    args: vec![parse_required_expr(&expr.expr)?],
+                    args: expr
+                        .expr
+                        .iter()
+                        .map(|e| e.try_into())
+                        .collect::<Result<Vec<_>, _>>()?,
                     distinct: false, //TODO
                 })
             }
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs
index c09b8a57d4aa..c00e3e42912a 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -24,16 +24,14 @@ mod roundtrip_tests {
     use super::super::{super::error::Result, protobuf};
     use crate::error::BallistaError;
     use core::panic;
-    use datafusion::arrow::datatypes::UnionMode;
-    use datafusion::logical_plan::Repartition;
     use datafusion::{
-        arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
+        arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode},
         datasource::object_store::local::LocalFileSystem,
         logical_plan::{
             col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder,
-            Partitioning, ToDFSchema,
+            Partitioning, Repartition, ToDFSchema,
         },
-        physical_plan::functions::BuiltinScalarFunction::Sqrt,
+        physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt},
         prelude::*,
         scalar::ScalarValue,
         sql::parser::FileType,
@@ -1001,4 +999,17 @@ mod roundtrip_tests {
 
         Ok(())
     }
+
+    #[test]
+    fn roundtrip_approx_percentile_cont() -> Result<()> {
+        let test_expr = Expr::AggregateFunction {
+            fun: aggregates::AggregateFunction::ApproxPercentileCont,
+            args: vec![col("bananas"), lit(0.42)],
+            distinct: false,
+        };
+
+        roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);
+
+        Ok(())
+    }
 }
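`roundtrip_test!` hides the serde mechanics; roughly, it asserts that encoding the logical expression to protobuf and decoding it back is lossless. A sketch of that property under that assumption, using the surrounding test module's imports (not the macro's literal expansion):

```rust
use std::convert::TryInto;

// Sketch only: with `expr` now `repeated` in the proto, both aggregate
// arguments (the column and the percentile literal) survive the round trip.
fn assert_roundtrip(test_expr: Expr) -> Result<()> {
    let proto: protobuf::LogicalExprNode = (&test_expr).try_into()?;
    let round_tripped: Expr = (&proto).try_into()?;
    // Compare via Debug formatting, since Expr comparison is what the
    // roundtrip tests effectively rely on.
    assert_eq!(format!("{:?}", test_expr), format!("{:?}", round_tripped));
    Ok(())
}
```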
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index eb5d8102de42..4b13ce577cfb 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1074,6 +1074,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
                     AggregateFunction::ApproxDistinct => {
                         protobuf::AggregateFunction::ApproxDistinct
                     }
+                    AggregateFunction::ApproxPercentileCont => {
+                        protobuf::AggregateFunction::ApproxPercentileCont
+                    }
                     AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
                     AggregateFunction::Min => protobuf::AggregateFunction::Min,
                     AggregateFunction::Max => protobuf::AggregateFunction::Max,
@@ -1099,11 +1102,13 @@
                     }
                 };
 
-                let arg = &args[0];
-                let aggregate_expr = Box::new(protobuf::AggregateExprNode {
+                let aggregate_expr = protobuf::AggregateExprNode {
                     aggr_function: aggr_function.into(),
-                    expr: Some(Box::new(arg.try_into()?)),
-                });
+                    expr: args
+                        .iter()
+                        .map(|v| v.try_into())
+                        .collect::<Result<Vec<_>, _>>()?,
+                };
 
                 Ok(protobuf::LogicalExprNode {
                     expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
                 })
@@ -1334,6 +1339,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::Stddev => Self::Stddev,
             AggregateFunction::StddevPop => Self::StddevPop,
             AggregateFunction::Correlation => Self::Correlation,
+            AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
         }
     }
 }
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index 4026273a9eb7..64a60dc4da5d 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -129,6 +129,9 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
             protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
             protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation,
+            protobuf::AggregateFunction::ApproxPercentileCont => {
+                AggregateFunction::ApproxPercentileCont
+            }
         }
     }
 }
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 98c296939bc5..a1e51e07422e 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -1647,6 +1647,15 @@ pub fn approx_distinct(expr: Expr) -> Expr {
     }
 }
 
+/// Calculate an approximation of the specified `percentile` for `expr`.
+pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
+    Expr::AggregateFunction {
+        fun: aggregates::AggregateFunction::ApproxPercentileCont,
+        distinct: false,
+        args: vec![expr, percentile],
+    }
+}
+
 // TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
 // varying arity functions
 /// Create a convenience function representing a unary scalar function
diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs
index 56fec3cf1a0c..06c6bf90c790 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -36,13 +36,13 @@ pub use builder::{
 pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
 pub use display::display_schema;
 pub use expr::{
-    abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr,
-    bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr,
-    combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf,
-    create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list,
-    initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim,
-    max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
-    regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
+    abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
+    avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col,
+    columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct,
+    create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields,
+    floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2,
+    lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length,
+    or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
     rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
     signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
     translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
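With the helper exported here (and re-exported from the prelude later in this diff), the DataFrame API can build the aggregate without spelling out `Expr::AggregateFunction` by hand. A usage sketch, assuming a registered table `test` with a numeric column `b`:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

async fn approx_median_of_b(ctx: &mut ExecutionContext) -> Result<()> {
    // The percentile argument must be a Float64 literal in [0.0, 1.0];
    // anything else is rejected when the physical expression is built.
    let df = ctx.table("test")?;
    let df = df.aggregate(vec![], vec![approx_percentile_cont(col("b"), lit(0.5))])?;
    df.show().await?;
    Ok(())
}
```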
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index f7beb76df3bc..8b6a5e21caac 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -27,7 +27,7 @@
 //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
 
 use super::{
-    functions::{Signature, Volatility},
+    functions::{Signature, TypeSignature, Volatility},
     Accumulator, AggregateExpr, PhysicalExpr,
 };
 use crate::error::{DataFusionError, Result};
@@ -80,6 +80,8 @@ pub enum AggregateFunction {
     CovariancePop,
     /// Correlation
     Correlation,
+    /// Approximate continuous percentile function
+    ApproxPercentileCont,
 }
 
 impl fmt::Display for AggregateFunction {
@@ -110,6 +112,7 @@ impl FromStr for AggregateFunction {
             "covar_samp" => AggregateFunction::Covariance,
             "covar_pop" => AggregateFunction::CovariancePop,
             "corr" => AggregateFunction::Correlation,
+            "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
             _ => {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {}",
@@ -157,6 +160,7 @@ pub fn return_type(
             coerced_data_types[0].clone(),
             true,
         )))),
+        AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
     }
 }
@@ -331,6 +335,20 @@ pub fn create_aggregate_expr(
                 "CORR(DISTINCT) aggregations are not available".to_string(),
             ));
         }
+        (AggregateFunction::ApproxPercentileCont, false) => {
+            Arc::new(expressions::ApproxPercentileCont::new(
+                // Pass in the desired percentile expr
+                coerced_phy_exprs,
+                name,
+                return_type,
+            )?)
+        }
+        (AggregateFunction::ApproxPercentileCont, true) => {
+            return Err(DataFusionError::NotImplemented(
+                "approx_percentile_cont(DISTINCT) aggregations are not available"
+                    .to_string(),
+            ));
+        }
     })
 }
@@ -389,17 +407,25 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
         AggregateFunction::Correlation => {
             Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
         }
+        AggregateFunction::ApproxPercentileCont => Signature::one_of(
+            // Accept any numeric value paired with a float64 percentile
+            NUMERICS
+                .iter()
+                .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
+                .collect(),
+            Volatility::Immutable,
+        ),
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::error::Result;
     use crate::physical_plan::expressions::{
-        ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg,
-        DistinctCount, Max, Min, Stddev, Sum, Variance,
+        ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count,
+        Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
     };
+    use crate::{error::Result, scalar::ScalarValue};
 
     #[test]
     fn test_count_arragg_approx_expr() -> Result<()> {
@@ -513,6 +539,59 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_agg_approx_percentile_phy_expr() {
+        for data_type in NUMERICS {
+            let input_schema =
+                Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
+            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
+                Arc::new(
+                    expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
+                ),
+                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))),
+            ];
+            let result_agg_phy_exprs = create_aggregate_expr(
+                &AggregateFunction::ApproxPercentileCont,
+                false,
+                &input_phy_exprs[..],
+                &input_schema,
+                "c1",
+            )
+            .expect("failed to create aggregate expr");
+
+            assert!(result_agg_phy_exprs.as_any().is::<ApproxPercentileCont>());
+            assert_eq!("c1", result_agg_phy_exprs.name());
+            assert_eq!(
+                Field::new("c1", data_type.clone(), false),
+                result_agg_phy_exprs.field().unwrap()
+            );
+        }
+    }
+
+    #[test]
+    fn test_agg_approx_percentile_invalid_phy_expr() {
+        for data_type in NUMERICS {
+            let input_schema =
+                Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
+            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
+                Arc::new(
+                    expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
+                ),
+                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
+            ];
+            let err = create_aggregate_expr(
+                &AggregateFunction::ApproxPercentileCont,
+                false,
+                &input_phy_exprs[..],
+                &input_schema,
+                "c1",
+            )
+            .expect_err("should fail due to invalid percentile");
+
+            assert!(matches!(err, DataFusionError::Plan(_)));
+        }
+    }
+
     #[test]
     fn test_min_max_expr() -> Result<()> {
         let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];
diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
index c151fb70a084..bae2de74c7b7 100644
--- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
+++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
@@ -17,7 +17,6 @@
 
 //! Support the coercion rule for aggregate function.
 
-use crate::arrow::datatypes::Schema;
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::aggregates::AggregateFunction;
 use crate::physical_plan::expressions::{
@@ -27,6 +26,10 @@ use crate::physical_plan::expressions::{
 };
 use crate::physical_plan::functions::{Signature, TypeSignature};
 use crate::physical_plan::PhysicalExpr;
+use crate::{
+    arrow::datatypes::Schema,
+    physical_plan::expressions::is_approx_percentile_cont_supported_arg_type,
+};
 use arrow::datatypes::DataType;
 use std::ops::Deref;
 use std::sync::Arc;
@@ -38,24 +41,9 @@ pub(crate) fn coerce_types(
     input_types: &[DataType],
     signature: &Signature,
 ) -> Result<Vec<DataType>> {
-    match signature.type_signature {
-        TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
-            if input_types.len() != agg_count {
-                return Err(DataFusionError::Plan(format!(
-                    "The function {:?} expects {:?} arguments, but {:?} were provided",
-                    agg_fun,
-                    agg_count,
-                    input_types.len()
-                )));
-            }
-        }
-        _ => {
-            return Err(DataFusionError::Internal(format!(
-                "Aggregate functions do not support this {:?}",
-                signature
-            )));
-        }
-    };
+    // Validate input_types matches (at least one of) the func signature.
+    check_arg_count(agg_fun, input_types, &signature.type_signature)?;
+
     match agg_fun {
         AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
             Ok(input_types.to_vec())
@@ -151,7 +139,75 @@ pub(crate) fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
+        AggregateFunction::ApproxPercentileCont => {
+            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            if !matches!(input_types[1], DataType::Float64) {
+                return Err(DataFusionError::Plan(format!(
+                    "The percentile argument for {:?} must be Float64, not {:?}.",
+                    agg_fun, input_types[1]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+    }
+}
+
+/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
+///
+/// This method DOES NOT validate the argument types - only that (at least one,
+/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
+/// number of input types.
+fn check_arg_count(
+    agg_fun: &AggregateFunction,
+    input_types: &[DataType],
+    signature: &TypeSignature,
+) -> Result<()> {
+    match signature {
+        TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
+            if input_types.len() != *agg_count {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} expects {:?} arguments, but {:?} were provided",
+                    agg_fun,
+                    agg_count,
+                    input_types.len()
+                )));
+            }
+        }
+        TypeSignature::Exact(types) => {
+            if types.len() != input_types.len() {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} expects {:?} arguments, but {:?} were provided",
+                    agg_fun,
+                    types.len(),
+                    input_types.len()
+                )));
+            }
+        }
+        TypeSignature::OneOf(variants) => {
+            let ok = variants
+                .iter()
+                .any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
+            if !ok {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not accept {:?} function arguments.",
+                    agg_fun,
+                    input_types.len()
+                )));
+            }
+        }
+        _ => {
+            return Err(DataFusionError::Internal(format!(
+                "Aggregate functions do not support this {:?}",
+                signature
+            )));
+        }
+    }
+    Ok(())
 }
 
 fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
@@ -267,5 +323,29 @@ mod tests {
                 assert_eq!(*input_type, result.unwrap());
             }
         }
+
+        // ApproxPercentileCont input types
+        let input_types = vec![
+            vec![DataType::Int8, DataType::Float64],
+            vec![DataType::Int16, DataType::Float64],
+            vec![DataType::Int32, DataType::Float64],
+            vec![DataType::Int64, DataType::Float64],
+            vec![DataType::UInt8, DataType::Float64],
+            vec![DataType::UInt16, DataType::Float64],
+            vec![DataType::UInt32, DataType::Float64],
+            vec![DataType::UInt64, DataType::Float64],
+            vec![DataType::Float32, DataType::Float64],
+            vec![DataType::Float64, DataType::Float64],
+        ];
+        for input_type in &input_types {
+            let signature =
+                aggregates::signature(&AggregateFunction::ApproxPercentileCont);
+            let result = coerce_types(
+                &AggregateFunction::ApproxPercentileCont,
+                input_type,
+                &signature,
+            );
+            assert_eq!(*input_type, result.unwrap());
+        }
     }
 }
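Because `TypeSignature::OneOf` is new ground for aggregate coercion, it may help to see what the generated signature looks like. A small sketch (assuming `signature` stays `pub` and `Signature` derives `Debug`):

```rust
use datafusion::physical_plan::aggregates::{self, AggregateFunction};

fn main() {
    // One TypeSignature::Exact variant per numeric input type, each paired
    // with the mandatory Float64 percentile argument:
    // OneOf([Exact([Int8, Float64]), Exact([Int16, Float64]), ...,
    //        Exact([Float64, Float64])])
    let sig = aggregates::signature(&AggregateFunction::ApproxPercentileCont);
    println!("{:?}", sig);

    // check_arg_count accepts the arity if *any* variant matches, so a one-
    // or three-argument call is rejected before any type validation runs.
}
```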
diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs
new file mode 100644
index 000000000000..cba30ee481ab
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs
@@ -0,0 +1,313 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{any::Any, iter, sync::Arc};
+
+use arrow::{
+    array::{
+        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
+        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+    },
+    datatypes::{DataType, Field},
+};
+
+use crate::{
+    error::DataFusionError,
+    physical_plan::{tdigest::TDigest, Accumulator, AggregateExpr, PhysicalExpr},
+    scalar::ScalarValue,
+};
+
+use crate::error::Result;
+
+use super::{format_state_name, Literal};
+
+/// Return `true` if `arg_type` is of a [`DataType`] that the
+/// [`ApproxPercentileCont`] aggregation can operate on.
+pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64
+    )
+}
+
+/// APPROX_PERCENTILE_CONT aggregate expression
+#[derive(Debug)]
+pub struct ApproxPercentileCont {
+    name: String,
+    input_data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+    percentile: f64,
+}
+
+impl ApproxPercentileCont {
+    /// Create a new [`ApproxPercentileCont`] aggregate function.
+    pub fn new(
+        expr: Vec<Arc<dyn PhysicalExpr>>,
+        name: impl Into<String>,
+        input_data_type: DataType,
+    ) -> Result<Self> {
+        // Arguments should be [ColumnExpr, DesiredPercentileLiteral]
+        debug_assert_eq!(expr.len(), 2);
+
+        // Extract the desired percentile literal
+        let lit = expr[1]
+            .as_any()
+            .downcast_ref::<Literal>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "desired percentile argument must be float literal".to_string(),
+                )
+            })?
+            .value();
+        let percentile = match lit {
+            ScalarValue::Float32(Some(q)) => *q as f64,
+            ScalarValue::Float64(Some(q)) => *q as f64,
+            got => return Err(DataFusionError::NotImplemented(format!(
+                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
+                got
+            )))
+        };
+
+        // Ensure the percentile is between 0 and 1.
+        if !(0.0..=1.0).contains(&percentile) {
+            return Err(DataFusionError::Plan(format!(
+                "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid",
+                percentile
+            )));
+        }
+
+        Ok(Self {
+            name: name.into(),
+            input_data_type,
+            // The physical expr to evaluate during accumulation
+            expr: expr[0].clone(),
+            percentile,
+        })
+    }
+}
+
+impl AggregateExpr for ApproxPercentileCont {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, self.input_data_type.clone(), false))
+    }
+
+    /// See [`TDigest::to_scalar_state()`] for a description of the serialised
+    /// state.
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                &format_state_name(&self.name, "max_size"),
+                DataType::UInt64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "sum"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "count"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "max"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "min"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "centroids"),
+                DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
+                false,
+            ),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        let accumulator: Box<dyn Accumulator> = match &self.input_data_type {
+            t @ (DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64) => {
+                Box::new(ApproxPercentileAccumulator::new(self.percentile, t.clone()))
+            }
+            other => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented",
+                    other
+                )))
+            }
+        };
+        Ok(accumulator)
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[derive(Debug)]
+pub struct ApproxPercentileAccumulator {
+    digest: TDigest,
+    percentile: f64,
+    return_type: DataType,
+}
+
+impl ApproxPercentileAccumulator {
+    pub fn new(percentile: f64, return_type: DataType) -> Self {
+        Self {
+            digest: TDigest::new(100),
+            percentile,
+            return_type,
+        }
+    }
+}
+
+impl Accumulator for ApproxPercentileAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(self.digest.to_scalar_state())
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        debug_assert_eq!(
+            values.len(),
+            1,
+            "invalid number of values in batch percentile update"
+        );
+        let values = &values[0];
+
+        self.digest = match values.data_type() {
+            DataType::Float64 => {
+                let array = values.as_any().downcast_ref::<Float64Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::Float32 => {
+                let array = values.as_any().downcast_ref::<Float32Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::Int64 => {
+                let array = values.as_any().downcast_ref::<Int64Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::Int32 => {
+                let array = values.as_any().downcast_ref::<Int32Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::Int16 => {
+                let array = values.as_any().downcast_ref::<Int16Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::Int8 => {
+                let array = values.as_any().downcast_ref::<Int8Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::UInt64 => {
+                let array = values.as_any().downcast_ref::<UInt64Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::UInt32 => {
+                let array = values.as_any().downcast_ref::<UInt32Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::UInt16 => {
+                let array = values.as_any().downcast_ref::<UInt16Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            DataType::UInt8 => {
+                let array = values.as_any().downcast_ref::<UInt8Array>().unwrap();
+                self.digest.merge_unsorted(array.values().iter().cloned())?
+            }
+            e => {
+                return Err(DataFusionError::Internal(format!(
+                    "APPROX_PERCENTILE_CONT is not expected to receive the type {:?}",
+                    e
+                )));
+            }
+        };
+
+        Ok(())
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let q = self.digest.estimate_quantile(self.percentile);
+
+        // These acceptable return types MUST match the validation in
+        // ApproxPercentileCont::create_accumulator.
+        Ok(match &self.return_type {
+            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
+            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
+            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
+            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
+            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
+            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
+            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
+            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
+            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
+            DataType::Float64 => ScalarValue::Float64(Some(q as f64)),
+            v => unreachable!("unexpected return type {:?}", v),
+        })
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        };
+
+        let states = (0..states[0].len())
+            .map(|index| {
+                states
+                    .iter()
+                    .map(|array| ScalarValue::try_from_array(array, index))
+                    .collect::<Result<Vec<ScalarValue>>>()
+                    .map(|state| TDigest::from_scalar_state(&state))
+            })
+            .chain(iter::once(Ok(self.digest.clone())))
+            .collect::<Result<Vec<_>>>()?;
+
+        self.digest = TDigest::merge_digests(&states);
+
+        Ok(())
+    }
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index ca14d7fa1a8d..9344fbd6b1bc 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions};
 use arrow::record_batch::RecordBatch;
 
 mod approx_distinct;
+mod approx_percentile_cont;
 mod array_agg;
 mod average;
 #[macro_use]
@@ -64,6 +65,9 @@ pub mod helpers {
 }
 
 pub use approx_distinct::ApproxDistinct;
+pub use approx_percentile_cont::{
+    is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont,
+};
 pub use array_agg::ArrayAgg;
 pub(crate) use average::is_avg_support_arg_type;
 pub use average::{avg_return_type, Avg, AvgAccumulator};
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index 216d4a65e639..66d913d8b24a 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -661,6 +661,7 @@ pub mod repartition;
 pub mod sorts;
 pub mod stream;
 pub mod string_expressions;
+pub(crate) mod tdigest;
 pub mod type_coercion;
 pub mod udaf;
 pub mod udf;
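Before the t-digest internals below, a sketch of how the pieces above fit together at execution time. These are crate-internal types and the engine normally drives these calls, so this is illustrative only:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Float64Array};
use arrow::datatypes::DataType;
use datafusion::error::Result;

fn sketch() -> Result<()> {
    // Accumulator configured for the 0.5 percentile over Float64 input.
    let mut acc = ApproxPercentileAccumulator::new(0.5, DataType::Float64);

    // Feed a batch of values; each batch folds into the running TDigest.
    let values: ArrayRef = Arc::new(Float64Array::from(
        (1..=1000).map(|v| v as f64).collect::<Vec<_>>(),
    ));
    acc.update_batch(&[values])?;

    // ≈ ScalarValue::Float64(Some(500.0)), within the t-digest error bounds.
    println!("{:?}", acc.evaluate()?);
    Ok(())
}
```

Partial aggregates additionally exchange `state()` (the serialised digest) and recombine it via `merge_batch`, which is what keeps the memory footprint bounded per partition.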
diff --git a/datafusion/src/physical_plan/tdigest/mod.rs b/datafusion/src/physical_plan/tdigest/mod.rs
new file mode 100644
index 000000000000..6780adc84cd1
--- /dev/null
+++ b/datafusion/src/physical_plan/tdigest/mod.rs
@@ -0,0 +1,818 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with this
+// work for additional information regarding copyright ownership.  The ASF
+// licenses this file to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+//! An implementation of the [TDigest sketch algorithm] providing approximate
+//! quantile calculations.
+//!
+//! The TDigest code in this module is modified from
+//! <https://github.com/MnO2/t-digest>, itself a rust reimplementation of
+//! [Facebook's Folly TDigest] implementation.
+//!
+//! Alterations include reduction of runtime heap allocations, broader type
+//! support, (de-)serialisation support, reduced type conversions and null value
+//! tolerance.
+//!
+//! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023
+//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
+
+use arrow::datatypes::DataType;
+use ordered_float::OrderedFloat;
+use std::cmp::Ordering;
+
+use crate::{
+    error::{DataFusionError, Result},
+    scalar::ScalarValue,
+};
+
+// Cast a non-null [`ScalarValue::Float64`] to an [`OrderedFloat<f64>`], or
+// panic.
+macro_rules! cast_scalar_f64 {
+    ($value:expr ) => {
+        match &$value {
+            ScalarValue::Float64(Some(v)) => OrderedFloat::from(*v),
+            v => panic!("invalid type {:?}", v),
+        }
+    };
+}
+
+/// This trait is implemented for each type a [`TDigest`] can operate on,
+/// allowing it to support both numerical rust types (obtained from
+/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
+pub(crate) trait TryIntoOrderedF64 {
+    /// A fallible conversion of a possibly null `self` into a [`OrderedFloat<f64>`].
+    ///
+    /// If `self` is null, this method must return `Ok(None)`.
+    ///
+    /// If `self` cannot be coerced to the desired type, this method must return
+    /// an `Err` variant.
+    fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>>;
+}
+
+/// Generate an infallible conversion from `type` to an [`OrderedFloat<f64>`].
+macro_rules! impl_try_ordered_f64 {
+    ($type:ty) => {
+        impl TryIntoOrderedF64 for $type {
+            fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>> {
+                Ok(Some(OrderedFloat::from(*self as f64)))
+            }
+        }
+    };
+}
+
+impl_try_ordered_f64!(f64);
+impl_try_ordered_f64!(f32);
+impl_try_ordered_f64!(i64);
+impl_try_ordered_f64!(i32);
+impl_try_ordered_f64!(i16);
+impl_try_ordered_f64!(i8);
+impl_try_ordered_f64!(u64);
+impl_try_ordered_f64!(u32);
+impl_try_ordered_f64!(u16);
+impl_try_ordered_f64!(u8);
+
+impl TryIntoOrderedF64 for ScalarValue {
+    fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>> {
+        match self {
+            ScalarValue::Float32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::Float64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::Int8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::Int16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::Int32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::Int64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::UInt8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::UInt16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::UInt32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+            ScalarValue::UInt64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))),
+
+            got => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented",
+                    got
+                )))
+            }
+        }
+    }
+}
+
+/// Centroid implementation to the cluster mentioned in the paper.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub(crate) struct Centroid {
+    mean: OrderedFloat<f64>,
+    weight: OrderedFloat<f64>,
+}
+
+impl PartialOrd for Centroid {
+    fn partial_cmp(&self, other: &Centroid) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for Centroid {
+    fn cmp(&self, other: &Centroid) -> Ordering {
+        self.mean.cmp(&other.mean)
+    }
+}
+
+impl Centroid {
+    pub(crate) fn new(
+        mean: impl Into<OrderedFloat<f64>>,
+        weight: impl Into<OrderedFloat<f64>>,
+    ) -> Self {
+        Centroid {
+            mean: mean.into(),
+            weight: weight.into(),
+        }
+    }
+
+    #[inline]
+    pub(crate) fn mean(&self) -> OrderedFloat<f64> {
+        self.mean
+    }
+
+    #[inline]
+    pub(crate) fn weight(&self) -> OrderedFloat<f64> {
+        self.weight
+    }
+
+    pub(crate) fn add(
+        &mut self,
+        sum: impl Into<OrderedFloat<f64>>,
+        weight: impl Into<OrderedFloat<f64>>,
+    ) -> f64 {
+        let new_sum = sum.into() + self.weight * self.mean;
+        let new_weight = self.weight + weight.into();
+
+        self.weight = new_weight;
+        self.mean = new_sum / new_weight;
+
+        new_sum.into_inner()
+    }
+}
+
+impl Default for Centroid {
+    fn default() -> Self {
+        Centroid {
+            mean: OrderedFloat::from(0.0),
+            weight: OrderedFloat::from(1.0),
+        }
+    }
+}
+
+/// T-Digest to be operated on.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub(crate) struct TDigest {
+    centroids: Vec<Centroid>,
+    max_size: usize,
+    sum: OrderedFloat<f64>,
+    count: OrderedFloat<f64>,
+    max: OrderedFloat<f64>,
+    min: OrderedFloat<f64>,
+}
+
+impl TDigest {
+    pub(crate) fn new(max_size: usize) -> Self {
+        TDigest {
+            centroids: Vec::new(),
+            max_size,
+            sum: OrderedFloat::from(0.0),
+            count: OrderedFloat::from(0.0),
+            max: OrderedFloat::from(std::f64::NAN),
+            min: OrderedFloat::from(std::f64::NAN),
+        }
+    }
+
+    #[inline]
+    pub(crate) fn count(&self) -> f64 {
+        self.count.into_inner()
+    }
+
+    #[inline]
+    pub(crate) fn max(&self) -> f64 {
+        self.max.into_inner()
+    }
+
+    #[inline]
+    pub(crate) fn min(&self) -> f64 {
+        self.min.into_inner()
+    }
+
+    #[inline]
+    pub(crate) fn max_size(&self) -> usize {
+        self.max_size
+    }
+}
+
+impl Default for TDigest {
+    fn default() -> Self {
+        TDigest {
+            centroids: Vec::new(),
+            max_size: 100,
+            sum: OrderedFloat::from(0.0),
+            count: OrderedFloat::from(0.0),
+            max: OrderedFloat::from(std::f64::NAN),
+            min: OrderedFloat::from(std::f64::NAN),
+        }
+    }
+}
+
+impl TDigest {
+    fn k_to_q(k: f64, d: f64) -> OrderedFloat<f64> {
+        let k_div_d = k / d;
+        if k_div_d >= 0.5 {
+            let base = 1.0 - k_div_d;
+            1.0 - 2.0 * base * base
+        } else {
+            2.0 * k_div_d * k_div_d
+        }
+        .into()
+    }
+
+    fn clamp(
+        v: OrderedFloat<f64>,
+        lo: OrderedFloat<f64>,
+        hi: OrderedFloat<f64>,
+    ) -> OrderedFloat<f64> {
+        if v > hi {
+            hi
+        } else if v < lo {
+            lo
+        } else {
+            v
+        }
+    }
+
+    pub(crate) fn merge_unsorted(
+        &self,
+        unsorted_values: impl IntoIterator<Item = impl TryIntoOrderedF64>,
+    ) -> Result<TDigest> {
+        let mut values = unsorted_values
+            .into_iter()
+            .filter_map(|v| v.try_as_f64().transpose())
+            .collect::<Result<Vec<_>>>()?;
+
+        values.sort();
+
+        Ok(self.merge_sorted_f64(&values))
+    }
+
+    fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat<f64>]) -> TDigest {
+        #[cfg(debug_assertions)]
+        debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");
+
+        if sorted_values.is_empty() {
+            return self.clone();
+        }
+
+        let mut result = TDigest::new(self.max_size());
+        result.count = OrderedFloat::from(self.count() + (sorted_values.len() as f64));
+
+        let maybe_min = *sorted_values.first().unwrap();
+        let maybe_max = *sorted_values.last().unwrap();
+
+        if self.count() > 0.0 {
+            result.min = std::cmp::min(self.min, maybe_min);
+            result.max = std::cmp::max(self.max, maybe_max);
+        } else {
+            result.min = maybe_min;
+            result.max = maybe_max;
+        }
+
+        let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
+
+        let mut k_limit: f64 = 1.0;
+        let mut q_limit_times_count =
+            Self::k_to_q(k_limit, self.max_size as f64) * result.count();
+        k_limit += 1.0;
+
+        let mut iter_centroids = self.centroids.iter().peekable();
+        let mut iter_sorted_values = sorted_values.iter().peekable();
+
+        let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
+            let curr = **iter_sorted_values.peek().unwrap();
+            if c.mean() < curr {
+                iter_centroids.next().unwrap().clone()
+            } else {
+                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+            }
+        } else {
+            Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+        };
+
+        let mut weight_so_far = curr.weight();
+
+        let mut sums_to_merge = OrderedFloat::from(0.0);
+        let mut weights_to_merge = OrderedFloat::from(0.0);
+
+        while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() {
+            let next: Centroid = if let Some(c) = iter_centroids.peek() {
+                if iter_sorted_values.peek().is_none()
+                    || c.mean() < **iter_sorted_values.peek().unwrap()
+                {
+                    iter_centroids.next().unwrap().clone()
+                } else {
+                    Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+                }
+            } else {
+                Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+            };
+
+            let next_sum = next.mean() * next.weight();
+            weight_so_far += next.weight();
+
+            if weight_so_far <= q_limit_times_count {
+                sums_to_merge += next_sum;
+                weights_to_merge += next.weight();
+            } else {
+                result.sum = OrderedFloat::from(
+                    result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge),
+                );
+                sums_to_merge = 0.0.into();
+                weights_to_merge = 0.0.into();
+
+                compressed.push(curr.clone());
+                q_limit_times_count =
+                    Self::k_to_q(k_limit, self.max_size as f64) * result.count();
+                k_limit += 1.0;
+                curr = next;
+            }
+        }
+
+        result.sum = OrderedFloat::from(
+            result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge),
+        );
+        compressed.push(curr);
+        compressed.shrink_to_fit();
+        compressed.sort();
+
+        result.centroids = compressed;
+        result
+    }
+
+    fn external_merge(
+        centroids: &mut Vec<Centroid>,
+        first: usize,
+        middle: usize,
+        last: usize,
+    ) {
+        let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len());
+
+        let mut i = first;
+        let mut j = middle;
+
+        while i < middle && j < last {
+            match centroids[i].cmp(&centroids[j]) {
+                Ordering::Less => {
+                    result.push(centroids[i].clone());
+                    i += 1;
+                }
+                Ordering::Greater => {
+                    result.push(centroids[j].clone());
+                    j += 1;
+                }
+                Ordering::Equal => {
+                    result.push(centroids[i].clone());
+                    i += 1;
+                }
+            }
+        }
+
+        while i < middle {
+            result.push(centroids[i].clone());
+            i += 1;
+        }
+
+        while j < last {
+            result.push(centroids[j].clone());
+            j += 1;
+        }
+
+        i = first;
+        for centroid in result.into_iter() {
+            centroids[i] = centroid;
+            i += 1;
+        }
+    }
+
+    // Merge multiple T-Digests
+    pub(crate) fn merge_digests(digests: &[TDigest]) -> TDigest {
+        let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
+        if n_centroids == 0 {
+            return TDigest::default();
+        }
+
+        let max_size = digests.first().unwrap().max_size;
+        let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
+        let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
+
+        let mut count: f64 = 0.0;
+        let mut min = OrderedFloat::from(std::f64::INFINITY);
+        let mut max = OrderedFloat::from(std::f64::NEG_INFINITY);
+
+        let mut start: usize = 0;
+        for digest in digests.iter() {
+            starts.push(start);
+
+            let curr_count: f64 = digest.count();
+            if curr_count > 0.0 {
+                min = std::cmp::min(min, digest.min);
+                max = std::cmp::max(max, digest.max);
+                count += curr_count;
+                for centroid in &digest.centroids {
+                    centroids.push(centroid.clone());
+                    start += 1;
+                }
+            }
+        }
+
+        let mut digests_per_block: usize = 1;
+        while digests_per_block < starts.len() {
+            for i in (0..starts.len()).step_by(digests_per_block * 2) {
+                if i + digests_per_block < starts.len() {
+                    let first = starts[i];
+                    let middle = starts[i + digests_per_block];
+                    let last = if i + 2 * digests_per_block < starts.len() {
+                        starts[i + 2 * digests_per_block]
+                    } else {
+                        centroids.len()
+                    };
+
+                    debug_assert!(first <= middle && middle <= last);
+                    Self::external_merge(&mut centroids, first, middle, last);
+                }
+            }
+
+            digests_per_block *= 2;
+        }
+
+        let mut result = TDigest::new(max_size);
+        let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
+
+        let mut k_limit: f64 = 1.0;
+        let mut q_limit_times_count =
+            Self::k_to_q(k_limit, max_size as f64) * (count as f64);
+
+        let mut iter_centroids = centroids.iter_mut();
+        let mut curr = iter_centroids.next().unwrap();
+        let mut weight_so_far = curr.weight();
+        let mut sums_to_merge = OrderedFloat::from(0.0);
+        let mut weights_to_merge = OrderedFloat::from(0.0);
+
+        for centroid in iter_centroids {
+            weight_so_far += centroid.weight();
+
+            if weight_so_far <= q_limit_times_count {
+                sums_to_merge += centroid.mean() * centroid.weight();
+                weights_to_merge += centroid.weight();
+            } else {
+                result.sum = OrderedFloat::from(
+                    result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge),
+                );
+                sums_to_merge = OrderedFloat::from(0.0);
+                weights_to_merge = OrderedFloat::from(0.0);
+
+                compressed.push(curr.clone());
+                q_limit_times_count =
+                    Self::k_to_q(k_limit, max_size as f64) * (count as f64);
+                k_limit += 1.0;
+                curr = centroid;
+            }
+        }
+
+        result.sum = OrderedFloat::from(
+            result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge),
+        );
+        compressed.push(curr.clone());
+        compressed.shrink_to_fit();
+        compressed.sort();
+
+        result.count = OrderedFloat::from(count as f64);
+        result.min = min;
+        result.max = max;
+        result.centroids = compressed;
+        result
+    }
+
+    /// To estimate the value located at `q` quantile
+    pub(crate) fn estimate_quantile(&self, q: f64) -> f64 {
+        if self.centroids.is_empty() {
+            return 0.0;
+        }
+
+        let count_ = self.count;
+        let rank = OrderedFloat::from(q) * count_;
+
+        let mut pos: usize;
+        let mut t;
+        if q > 0.5 {
+            if q >= 1.0 {
+                return self.max();
+            }
+
+            pos = 0;
+            t = count_;
+
+            for (k, centroid) in self.centroids.iter().enumerate().rev() {
+                t -= centroid.weight();
+
+                if rank >= t {
+                    pos = k;
+                    break;
+                }
+            }
+        } else {
+            if q <= 0.0 {
+                return self.min();
+            }
+
+            pos = self.centroids.len() - 1;
+            t = OrderedFloat::from(0.0);
+
+            for (k, centroid) in self.centroids.iter().enumerate() {
+                if rank < t + centroid.weight() {
+                    pos = k;
+                    break;
+                }
+
+                t += centroid.weight();
+            }
+        }
+
+        let mut delta = OrderedFloat::from(0.0);
+        let mut min = self.min;
+        let mut max = self.max;
+
+        if self.centroids.len() > 1 {
+            if pos == 0 {
+                delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean();
+                max = self.centroids[pos + 1].mean();
+            } else if pos == (self.centroids.len() - 1) {
+                delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean();
+                min = self.centroids[pos - 1].mean();
+            } else {
+                delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean())
+                    / 2.0;
+                min = self.centroids[pos - 1].mean();
+                max = self.centroids[pos + 1].mean();
+            }
+        }
+
+        let value = self.centroids[pos].mean()
+            + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;
+        Self::clamp(value, min, max).into_inner()
+    }
+
+    /// This method decomposes the [`TDigest`] and its [`Centroid`] instances
+    /// into a series of primitive scalar values.
+    ///
+    /// First the values of the TDigest are packed, followed by the variable
+    /// number of centroids packed into a [`ScalarValue::List`] of
+    /// [`ScalarValue::Float64`]:
+    ///
+    /// ```text
+    ///
+    ///    ┌────────┬────────┬────────┬───────┬────────┬────────┐
+    ///    │max_size│  sum   │ count  │  max  │  min   │centroid│
+    ///    └────────┴────────┴────────┴───────┴────────┴────────┘
+    ///                                                     │
+    ///                               ┌─────────────────────┘
+    ///                               ▼
+    ///                          ┌ List ───┐
+    ///                          │┌ ─ ─ ─ ┐│
+    ///                          │  mean   │
+    ///                          │├ ─ ─ ─ ┼│─ ─ Centroid 1
+    ///                          │ weight  │
+    ///                          │└ ─ ─ ─ ┘│
+    ///                          │         │
+    ///                          │┌ ─ ─ ─ ┐│
+    ///                          │  mean   │
+    ///                          │├ ─ ─ ─ ┼│─ ─ Centroid 2
+    ///                          │ weight  │
+    ///                          │└ ─ ─ ─ ┘│
+    ///                          │         │
+    ///                              ...
+    ///
+    /// ```
+    ///
+    /// The [`TDigest::from_scalar_state()`] method reverses this processes,
+    /// consuming the output of this method and returning an unpacked
+    /// [`TDigest`].
+    pub(crate) fn to_scalar_state(&self) -> Vec<ScalarValue> {
+        // Gather up all the centroids
+        let centroids: Vec<_> = self
+            .centroids
+            .iter()
+            .flat_map(|c| [c.mean().into_inner(), c.weight().into_inner()])
+            .map(|v| ScalarValue::Float64(Some(v)))
+            .collect();
+
+        vec![
+            ScalarValue::UInt64(Some(self.max_size as u64)),
+            ScalarValue::Float64(Some(self.sum.into_inner())),
+            ScalarValue::Float64(Some(self.count.into_inner())),
+            ScalarValue::Float64(Some(self.max.into_inner())),
+            ScalarValue::Float64(Some(self.min.into_inner())),
+            ScalarValue::List(Some(Box::new(centroids)), Box::new(DataType::Float64)),
+        ]
+    }
+
+    /// Unpack the serialised state of a [`TDigest`] produced by
+    /// [`Self::to_scalar_state()`].
+    ///
+    /// # Correctness
+    ///
+    /// Providing input to this method that was not obtained from
+    /// [`Self::to_scalar_state()`] results in undefined behaviour and may
+    /// panic.
+    pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self {
+        assert_eq!(state.len(), 6, "invalid TDigest state");
+
+        let max_size = match &state[0] {
+            ScalarValue::UInt64(Some(v)) => *v as usize,
+            v => panic!("invalid max_size type {:?}", v),
+        };
+
+        let centroids: Vec<_> = match &state[5] {
+            ScalarValue::List(Some(c), d) if **d == DataType::Float64 => c
+                .chunks(2)
+                .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1])))
+                .collect(),
+            v => panic!("invalid centroids type {:?}", v),
+        };
+
+        let max = cast_scalar_f64!(&state[3]);
+        let min = cast_scalar_f64!(&state[4]);
+        assert!(max >= min);
+
+        Self {
+            max_size,
+            sum: cast_scalar_f64!(state[1]),
+            count: cast_scalar_f64!(&state[2]),
+            max,
+            min,
+            centroids,
+        }
+    }
+}
+
+#[cfg(debug_assertions)]
+fn is_sorted(values: &[OrderedFloat<f64>]) -> bool {
+    values.windows(2).all(|w| w[0] <= w[1])
+}
+
+#[cfg(test)]
+mod tests {
+    use std::iter;
+
+    use super::*;
+
+    // A macro to assert the specified `quantile` estimated by `t` is within the
+    // allowable relative error bound.
+    macro_rules! assert_error_bounds {
+        ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
+            assert_error_bounds!(
+                $t,
+                quantile = $quantile,
+                want = $want,
+                allowable_error = 0.01
+            )
+        };
+        ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => {
+            let ans = $t.estimate_quantile($quantile);
+            let expected: f64 = $want;
+            let percentage: f64 = (expected - ans).abs() / expected;
+            assert!(
+                percentage < $re,
+                "relative error {} is more than {}% (got quantile {}, want {})",
+                percentage,
+                $re,
+                ans,
+                expected
+            );
+        };
+    }
+    macro_rules! assert_state_roundtrip {
+        ($t:ident) => {
+            let state = $t.to_scalar_state();
+            let other = TDigest::from_scalar_state(&state);
+            assert_eq!($t, other);
+        };
+    }
+
+    #[test]
+    fn test_int64_uniform() {
+        let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v)));
+
+        let t = TDigest::new(100);
+        let t = t.merge_unsorted(values).unwrap();
+
+        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
+        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
+        assert_state_roundtrip!(t);
+    }
+
+    #[test]
+    fn test_int64_uniform_with_nulls() {
+        let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v)));
+        // Prepend some NULLs
+        let values = iter::repeat(ScalarValue::Int64(None))
+            .take(10)
+            .chain(values);
+        // Append some more NULLs
+        let values = values.chain(iter::repeat(ScalarValue::Int64(None)).take(10));
+
+        let t = TDigest::new(100);
+        let t = t.merge_unsorted(values).unwrap();
+
+        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
+        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
+        assert_state_roundtrip!(t);
+    }
+
+    #[test]
+    fn test_centroid_addition_regression() {
+        // https://github.com/MnO2/t-digest/pull/1
+
+        let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0];
+        let mut t = TDigest::new(10);
+
+        for v in vals {
+            t = t.merge_unsorted([ScalarValue::Float64(Some(v))]).unwrap();
+        }
+
+        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
+        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
+        assert_state_roundtrip!(t);
+    }
+
+    #[test]
+    fn test_merge_unsorted_against_uniform_distro() {
+        let t = TDigest::new(100);
+        let values: Vec<_> = (1..=1_000_000)
+            .map(f64::from)
+            .map(|v| ScalarValue::Float64(Some(v)))
+            .collect();
+
+        let t = t.merge_unsorted(values).unwrap();
+
+        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
+        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
+        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
+        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
+        assert_state_roundtrip!(t);
+    }
+
+    #[test]
+    fn test_merge_unsorted_against_skewed_distro() {
+        let t = TDigest::new(100);
+        let mut values: Vec<_> = (1..=600_000)
+            .map(f64::from)
+            .map(|v| ScalarValue::Float64(Some(v)))
+            .collect();
+        for _ in 0..400_000 {
+            values.push(ScalarValue::Float64(Some(1_000_000.0)));
+        }
+
+        let t = t.merge_unsorted(values).unwrap();
+
+        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
+        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
+        assert_state_roundtrip!(t);
+    }
+
+    #[test]
+    fn test_merge_digests() {
+        let mut digests: Vec<TDigest> = Vec::new();
+
+        for _ in 1..=100 {
+            let t = TDigest::new(100);
+            let values: Vec<_> = (1..=1_000)
+                .map(f64::from)
+                .map(|v| ScalarValue::Float64(Some(v)))
+                .collect();
+            let t = t.merge_unsorted(values).unwrap();
+            digests.push(t)
+        }
+
+        let t = TDigest::merge_digests(&digests);
+
+        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
+        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
+        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
+        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
+        assert_state_roundtrip!(t);
+    }
+}
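For orientation, a crate-internal sketch of the `TDigest` surface the accumulator relies on (names as defined above; the approximate values assume the error bounds the tests exercise):

```rust
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

fn sketch() -> Result<()> {
    // max_size = 100 bounds the number of centroids (and therefore memory)
    // regardless of input cardinality.
    let t = TDigest::new(100);
    let t = t.merge_unsorted((1i64..=10_000).map(|v| ScalarValue::Int64(Some(v))))?;

    let p99 = t.estimate_quantile(0.99); // ≈ 9_900.0

    // Partial aggregates ship their digest as scalar state and re-merge it.
    let state = t.to_scalar_state();
    let merged = TDigest::merge_digests(&[TDigest::from_scalar_state(&state)]);
    assert!((merged.estimate_quantile(0.99) - p99).abs() / p99 < 0.01);
    Ok(())
}
```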
diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs
index abc75829ea17..0aff006c7896 100644
--- a/datafusion/src/prelude.rs
+++ b/datafusion/src/prelude.rs
@@ -30,10 +30,10 @@
 pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
 pub use crate::execution::options::AvroReadOptions;
 pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions};
 pub use crate::logical_plan::{
-    array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws,
-    count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length,
-    lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_match,
-    regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384,
-    sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper,
-    Column, JoinType, Partitioning,
+    approx_percentile_cont, array, ascii, avg, bit_length, btrim, character_length, chr,
+    col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list,
+    initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length,
+    random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim,
+    sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex,
+    translate, trim, upper, Column, JoinType, Partitioning,
 };
diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs
index b8efc9815636..d5118b30d2af 100644
--- a/datafusion/tests/dataframe_functions.rs
+++ b/datafusion/tests/dataframe_functions.rs
@@ -153,6 +153,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_fn_approx_percentile_cont() -> Result<()> {
+    let expr = approx_percentile_cont(col("b"), lit(0.5));
+
+    let expected = vec![
+        "+-------------------------------------------+",
+        "| APPROXPERCENTILECONT(test.b,Float64(0.5)) |",
+        "+-------------------------------------------+",
+        "| 10                                        |",
+        "+-------------------------------------------+",
+    ];
+
+    let df = create_test_table()?;
+    let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
+
+    assert_batches_eq!(expected, &batches);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn test_fn_character_length() -> Result<()> {
     let expr = character_length(col("a"));
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs
index 9d72752b091d..736a00318ac7 100644
--- a/datafusion/tests/sql/aggregates.rs
+++ b/datafusion/tests/sql/aggregates.rs
@@ -354,6 +354,95 @@ async fn csv_query_approx_count() -> Result<()> {
     Ok(())
 }
 
+// This test executes the APPROX_PERCENTILE_CONT aggregation against the test
+// data, asserting the estimated quantiles are within ±5% of their actual values.
+//
+// Actual quantiles calculated with:
+//
+// ```r
+// read_csv("./testing/data/csv/aggregate_test_100.csv") |>
+//     select_if(is.numeric) |>
+//     summarise_all(~ quantile(., c(0.1, 0.5, 0.9)))
+// ```
+//
+// Giving:
+//
+// ```text
+//      c2    c3      c4           c5       c6     c7     c8          c9     c10    c11    c12
+//
+// 1     1 -95.3 -22925. -1882606710 -7.25e18  18.9  2671.  472608672. 1.83e18 0.109 0.0714
+// 2     3  15.5    4599   377164262  1.13e18 134.  30634  2365817608. 9.30e18 0.491 0.551
+// 3     5 102.   25334.  1991374996. 7.37e18 231   57518. 3776538487. 1.61e19 0.834 0.946
+// ```
+//
+// Column `c12` is omitted due to a large relative error (~10%) caused by the
+// small float values.
+#[tokio::test]
+async fn csv_query_approx_percentile_cont() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+
+    // Generate an assertion that the estimated $percentile value for $column is
+    // within 5% of the $actual percentile value.
+    macro_rules! percentile_test {
+        ($ctx:ident, column=$column:literal, percentile=$percentile:literal, actual=$actual:literal) => {
+            let sql = format!("SELECT (ABS(1 - CAST(approx_percentile_cont({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $percentile, $actual);
+            let actual = execute_to_batches(&mut ctx, &sql).await;
+            //
+            // "+------+",
+            // "| q    |",
+            // "+------+",
+            // "| true |",
+            // "+------+",
+            //
+            let want = ["+------+", "| q    |", "+------+", "| true |", "+------+"];
+            assert_batches_eq!(want, &actual);
+        };
+    }
+
+    percentile_test!(ctx, column = "c2", percentile = 0.1, actual = 1.0);
+    percentile_test!(ctx, column = "c2", percentile = 0.5, actual = 3.0);
+    percentile_test!(ctx, column = "c2", percentile = 0.9, actual = 5.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c3", percentile = 0.1, actual = -95.3);
+    percentile_test!(ctx, column = "c3", percentile = 0.5, actual = 15.5);
+    percentile_test!(ctx, column = "c3", percentile = 0.9, actual = 102.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c4", percentile = 0.1, actual = -22925.0);
+    percentile_test!(ctx, column = "c4", percentile = 0.5, actual = 4599.0);
+    percentile_test!(ctx, column = "c4", percentile = 0.9, actual = 25334.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c5", percentile = 0.1, actual = -1882606710.0);
+    percentile_test!(ctx, column = "c5", percentile = 0.5, actual = 377164262.0);
+    percentile_test!(ctx, column = "c5", percentile = 0.9, actual = 1991374996.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c6", percentile = 0.1, actual = -7.25e18);
+    percentile_test!(ctx, column = "c6", percentile = 0.5, actual = 1.13e18);
+    percentile_test!(ctx, column = "c6", percentile = 0.9, actual = 7.37e18);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c7", percentile = 0.1, actual = 18.9);
+    percentile_test!(ctx, column = "c7", percentile = 0.5, actual = 134.0);
+    percentile_test!(ctx, column = "c7", percentile = 0.9, actual = 231.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c8", percentile = 0.1, actual = 2671.0);
+    percentile_test!(ctx, column = "c8", percentile = 0.5, actual = 30634.0);
+    percentile_test!(ctx, column = "c8", percentile = 0.9, actual = 57518.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c9", percentile = 0.1, actual = 472608672.0);
+    percentile_test!(ctx, column = "c9", percentile = 0.5, actual = 2365817608.0);
+    percentile_test!(ctx, column = "c9", percentile = 0.9, actual = 3776538487.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c10", percentile = 0.1, actual = 1.83e18);
+    percentile_test!(ctx, column = "c10", percentile = 0.5, actual = 9.30e18);
+    percentile_test!(ctx, column = "c10", percentile = 0.9, actual = 1.61e19);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c11", percentile = 0.1, actual = 0.109);
+    percentile_test!(ctx, column = "c11", percentile = 0.5, actual = 0.491);
+    percentile_test!(ctx, column = "c11", percentile = 0.9, actual = 0.834);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn query_count_without_from() -> Result<()> {
     let mut ctx = ExecutionContext::new();
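Finally, the same feature from the SQL surface. A sketch assuming the `aggregate_test_100` fixture is registered as in the tests above, and that `ExecutionContext::sql` is async in this release:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

async fn sketch(ctx: &mut ExecutionContext) -> Result<()> {
    // Approximate median of c2; the percentile literal must be in [0.0, 1.0].
    let df = ctx
        .sql("SELECT approx_percentile_cont(c2, 0.5) FROM aggregate_test_100")
        .await?;
    df.show().await?;
    Ok(())
}
```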