Skip to content

Commit

Permalink
Minor: support complex expr as the arg in the ApproxPercentileCont fu…
Browse files Browse the repository at this point in the history
…nction (#8580)

* support complex lit expr for the arg

* enchancement the percentile
  • Loading branch information
liukun4515 authored Dec 20, 2023
1 parent b456cf7 commit 1bcaac4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
20 changes: 20 additions & 0 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use datafusion::prelude::*;
use datafusion::execution::context::SessionContext;

use datafusion::assert_batches_eq;
use datafusion_expr::expr::Alias;
use datafusion_expr::{approx_median, cast};

async fn create_test_table() -> Result<DataFrame> {
Expand Down Expand Up @@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {

assert_batches_eq!(expected, &batches);

// the arg2 parameter is a complex expr, but it can be evaluated to the literal value
let alias_expr = Expr::Alias(Alias::new(
cast(lit(0.5), DataType::Float32),
None::<&str>,
"arg_2".to_string(),
));
let expr = approx_percentile_cont(col("b"), alias_expr);
let df = create_test_table().await?;
let expected = [
"+--------------------------------------+",
"| APPROX_PERCENTILE_CONT(test.b,arg_2) |",
"+--------------------------------------+",
"| 10 |",
"+--------------------------------------+",
];
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;

assert_batches_eq!(expected, &batches);

Ok(())
}

Expand Down
45 changes: 21 additions & 24 deletions datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::aggregate::tdigest::TryIntoF64;
use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::{format_state_name, Literal};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::{
array::{
Expand All @@ -27,11 +27,13 @@ use arrow::{
},
datatypes::{DataType, Field},
};
use arrow_array::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{
downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError,
Result, ScalarValue,
};
use datafusion_expr::Accumulator;
use datafusion_expr::{Accumulator, ColumnarValue};
use std::{any::Any, iter, sync::Arc};

/// APPROX_PERCENTILE_CONT aggregate expression
Expand Down Expand Up @@ -131,18 +133,22 @@ impl PartialEq for ApproxPercentileCont {
}
}

fn get_lit_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
let empty_schema = Schema::empty();
let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema));
let result = expr.evaluate(&empty_batch)?;
match result {
ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
"The expr {:?} can't be evaluated to scalar value",
expr
))),
ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
}
}

fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
// Extract the desired percentile literal
let lit = expr
.as_any()
.downcast_ref::<Literal>()
.ok_or_else(|| {
DataFusionError::Internal(
"desired percentile argument must be float literal".to_string(),
)
})?
.value();
let percentile = match lit {
let lit = get_lit_value(expr)?;
let percentile = match &lit {
ScalarValue::Float32(Some(q)) => *q as f64,
ScalarValue::Float64(Some(q)) => *q,
got => return not_impl_err!(
Expand All @@ -161,17 +167,8 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
}

fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
// Extract the desired percentile literal
let lit = expr
.as_any()
.downcast_ref::<Literal>()
.ok_or_else(|| {
DataFusionError::Internal(
"desired percentile argument must be float literal".to_string(),
)
})?
.value();
let max_size = match lit {
let lit = get_lit_value(expr)?;
let max_size = match &lit {
ScalarValue::UInt8(Some(q)) => *q as usize,
ScalarValue::UInt16(Some(q)) => *q as usize,
ScalarValue::UInt32(Some(q)) => *q as usize,
Expand Down

0 comments on commit 1bcaac4

Please sign in to comment.