From 4a8741b2b86cf6fb8bdca8cd2d4d5baad5cd0d5c Mon Sep 17 00:00:00 2001
From: liukun4515
Date: Tue, 19 Dec 2023 19:10:39 +0800
Subject: [PATCH] Enhance the percentile argument handling

---
 .../tests/dataframe/dataframe_functions.rs  | 20 ++++++++++++++++
 .../src/aggregate/approx_percentile_cont.rs | 23 +++++++++----------
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 9677003ec226f..fe56fc22ea8cc 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -31,6 +31,7 @@ use datafusion::prelude::*;
 use datafusion::execution::context::SessionContext;
 
 use datafusion::assert_batches_eq;
+use datafusion_expr::expr::Alias;
 use datafusion_expr::{approx_median, cast};
 
 async fn create_test_table() -> Result<DataFrame> {
@@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
 
     assert_batches_eq!(expected, &batches);
 
+    // the percentile argument is a complex expression, but it can still be evaluated to a literal value
+    let alias_expr = Expr::Alias(Alias::new(
+        cast(lit(0.5), DataType::Float32),
+        None::<&str>,
+        "arg_2".to_string(),
+    ));
+    let expr = approx_percentile_cont(col("b"), alias_expr);
+    let df = create_test_table().await?;
+    let expected = [
+        "+--------------------------------------+",
+        "| APPROX_PERCENTILE_CONT(test.b,arg_2) |",
+        "+--------------------------------------+",
+        "| 10                                   |",
+        "+--------------------------------------+",
+    ];
+    let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
+
+    assert_batches_eq!(expected, &batches);
+
     Ok(())
 }
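The test above exercises a percentile argument that is not a bare literal. As a rough standalone sketch (not part of the patch), the same expression can be built and inspected with the logical `Expr` API; the `main` wrapper and the `println!` calls are purely illustrative:

```rust
use datafusion::arrow::datatypes::DataType;
use datafusion::prelude::*;
use datafusion_expr::{cast, expr::Alias};

fn main() {
    // CAST(0.5 AS Float32) AS arg_2: not a bare literal, but it still folds to
    // a constant once evaluated, which is what the new test relies on.
    let percentile = Expr::Alias(Alias::new(
        cast(lit(0.5), DataType::Float32),
        None::<&str>,
        "arg_2".to_string(),
    ));

    // The aggregate expression built by the test; the alias name is what shows
    // up in the result column name (`APPROX_PERCENTILE_CONT(test.b,arg_2)`).
    let agg = approx_percentile_cont(col("b"), percentile);
    println!("{agg}");

    // The plain-literal form, for contrast.
    println!("{}", approx_percentile_cont(col("b"), lit(0.5)));
}
```

Printing the expressions only shows that the alias wraps the cast rather than replacing it; the actual reduction to `0.5` happens in the physical-expr change below.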
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 81b93afa8131a..d1374cc02bdc0 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -18,7 +18,7 @@
 use crate::aggregate::tdigest::TryIntoF64;
 use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
 use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::{format_state_name, Literal};
+use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 use arrow::{
     array::{
@@ -27,11 +27,13 @@ use arrow::{
     },
     datatypes::{DataType, Field},
 };
+use arrow_array::RecordBatch;
+use arrow_schema::Schema;
 use datafusion_common::{
     downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError,
     Result, ScalarValue,
 };
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, ColumnarValue};
 use std::{any::Any, iter, sync::Arc};
 
 /// APPROX_PERCENTILE_CONT aggregate expression
@@ -136,20 +138,17 @@ fn get_float_lit_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
     let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema));
     let result = expr.evaluate(&empty_batch)?;
     match result {
-        ColumnarValue::Array(_) => {
-            DataFusionError::Internal(
-                format!("The expr {:?} can't be evaluated to scalar value", expr)
-            )
-        }
-        ColumnarValue::Scalar(scalar_value) => {
-            scalar_value
-        }
+        ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
+            "The expr {:?} can't be evaluated to scalar value",
+            expr
+        ))),
+        ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
     }
 }
 
 fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
     let lit = get_float_lit_value(expr)?;
-    let percentile = match lit {
+    let percentile = match &lit {
         ScalarValue::Float32(Some(q)) => *q as f64,
         ScalarValue::Float64(Some(q)) => *q,
         got => return not_impl_err!(
@@ -169,7 +168,7 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
 
 fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
     let lit = get_float_lit_value(expr)?;
-    let max_size = match lit {
+    let max_size = match &lit {
         ScalarValue::UInt8(Some(q)) => *q as usize,
         ScalarValue::UInt16(Some(q)) => *q as usize,
         ScalarValue::UInt32(Some(q)) => *q as usize,
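The core of the physical-expr change is the evaluate-against-an-empty-batch step in `get_float_lit_value`. Below is a minimal, self-contained sketch of the same idea written against the public `datafusion_physical_expr` API instead of the in-crate `crate::` paths; the `const_eval` helper name is made up for illustration and is not part of the patch:

```rust
use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::PhysicalExpr;

/// Reduce a column-free physical expression to a `ScalarValue` by evaluating
/// it against a placeholder batch that has no columns and no rows.
fn const_eval(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
    let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
    match expr.evaluate(&empty_batch)? {
        ColumnarValue::Scalar(value) => Ok(value),
        ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
            "The expr {expr:?} can't be evaluated to scalar value"
        ))),
    }
}

fn main() -> Result<()> {
    // A literal is the simplest column-free expression; a cast of a literal
    // folds to a scalar the same way.
    let expr: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(0.5))));
    assert_eq!(const_eval(&expr)?, ScalarValue::Float64(Some(0.5)));
    Ok(())
}
```

Because the placeholder batch has no columns, a constant expression such as the test's aliased cast reduces to a `ColumnarValue::Scalar`, while anything that evaluates to an array is rejected with the internal error shown in the hunk, keeping the percentile and max-size arguments restricted to constants.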