Skip to content

Commit

Permalink
enhance the percentile
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Dec 19, 2023
1 parent 0ddf443 commit dfb84fe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
20 changes: 20 additions & 0 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use datafusion::prelude::*;
use datafusion::execution::context::SessionContext;

use datafusion::assert_batches_eq;
use datafusion_expr::expr::Alias;
use datafusion_expr::{approx_median, cast};

async fn create_test_table() -> Result<DataFrame> {
Expand Down Expand Up @@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {

assert_batches_eq!(expected, &batches);

// the arg2 parameter is a complex expr (an aliased cast), but it can still be evaluated to a literal value
let alias_expr = Expr::Alias(Alias::new(
cast(lit(0.5), DataType::Float32),
None::<&str>,
"arg_2".to_string(),
));
let expr = approx_percentile_cont(col("b"), alias_expr);
let df = create_test_table().await?;
let expected = [
"+--------------------------------------+",
"| APPROX_PERCENTILE_CONT(test.b,arg_2) |",
"+--------------------------------------+",
"| 10 |",
"+--------------------------------------+",
];
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;

assert_batches_eq!(expected, &batches);

Ok(())
}

Expand Down
23 changes: 11 additions & 12 deletions datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::aggregate::tdigest::TryIntoF64;
use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::{format_state_name, Literal};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::{
array::{
Expand All @@ -27,11 +27,13 @@ use arrow::{
},
datatypes::{DataType, Field},
};
use arrow_array::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{
downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError,
Result, ScalarValue,
};
use datafusion_expr::Accumulator;
use datafusion_expr::{Accumulator, ColumnarValue};
use std::{any::Any, iter, sync::Arc};

/// APPROX_PERCENTILE_CONT aggregate expression
Expand Down Expand Up @@ -136,20 +138,17 @@ fn get_float_lit_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema));
let result = expr.evaluate(&empty_batch)?;
match result {
ColumnarValue::Array(_) => {
DataFusionError::Internal(
format!("The expr {:?} can't be evaluated to scalar value", expr)
)
}
ColumnarValue::Scalar(scalar_value) => {
scalar_value
}
ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
"The expr {:?} can't be evaluated to scalar value",
expr
))),
ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
}
}

fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
let lit = get_float_lit_value(expr)?;
let percentile = match lit {
let percentile = match &lit {
ScalarValue::Float32(Some(q)) => *q as f64,
ScalarValue::Float64(Some(q)) => *q,
got => return not_impl_err!(
Expand All @@ -169,7 +168,7 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {

fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
let lit = get_float_lit_value(expr)?;
let max_size = match lit {
let max_size = match &lit {
ScalarValue::UInt8(Some(q)) => *q as usize,
ScalarValue::UInt16(Some(q)) => *q as usize,
ScalarValue::UInt32(Some(q)) => *q as usize,
Expand Down

0 comments on commit dfb84fe

Please sign in to comment.