fix error
Huaxin Gao committed Mar 2, 2024
1 parent d061f06 commit b3f8952
Showing 9 changed files with 50 additions and 140 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_planner.rs
@@ -246,7 +246,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
args,
filter,
order_by,
- null_treatment,
+ ..
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
9 changes: 8 additions & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -367,7 +367,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
&self.schema,
&fun.signature(),
)?;
- let expr = Expr::AggregateFunction(expr::AggregateFunction::new(fun, new_expr, distinct, filter, order_by, null_treatment));
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ fun,
+ new_expr,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ ));
Ok(expr)
}
AggregateFunctionDefinition::UDF(fun) => {
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -76,7 +76,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
args,
filter,
order_by,
- null_treatment,
+ ..
}) = expr
{
if filter.is_some() || order_by.is_some() {
27 changes: 19 additions & 8 deletions datafusion/physical-expr/src/aggregate/build_in.rs
@@ -360,13 +360,16 @@ pub fn create_aggregate_expr(
(AggregateFunction::Median, true) => {
return not_impl_err!("MEDIAN(DISTINCT) aggregations are not available");
}
- (AggregateFunction::FirstValue, _) => Arc::new(expressions::FirstValue::new(
- input_phy_exprs[0].clone(),
- name,
- input_phy_types[0].clone(),
- ordering_req.to_vec(),
- ordering_types,
- ).ignore_null(ignore_nulls)),
+ (AggregateFunction::FirstValue, _) => Arc::new(
+ expressions::FirstValue::new(
+ input_phy_exprs[0].clone(),
+ name,
+ input_phy_types[0].clone(),
+ ordering_req.to_vec(),
+ ordering_types,
+ )
+ .ignore_null(ignore_nulls),
+ ),
(AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new(
input_phy_exprs[0].clone(),
name,
@@ -1309,7 +1312,15 @@ mod tests {
"Invalid or wrong number of arguments passed to aggregate: '{name}'"
);
}
- create_aggregate_expr(fun, distinct, &coerced_phy_exprs, &[], input_schema, name, false)
+ create_aggregate_expr(
+ fun,
+ distinct,
+ &coerced_phy_exprs,
+ &[],
+ input_schema,
+ name,
+ false,
+ )
}

// Returns the coerced exprs for each `input_exprs`.
115 changes: 3 additions & 112 deletions datafusion/physical-expr/src/aggregate/first_last.rs
@@ -68,7 +68,7 @@ impl FirstValue {
}
}

- pub fn ignore_null(mut self, ignore_null: bool) -> Self{
+ pub fn ignore_null(mut self, ignore_null: bool) -> Self {
self.ignore_null = ignore_null;
self
}
@@ -735,11 +735,9 @@ mod tests {

use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator};

- use crate::expressions::{col, FirstValue};
- use crate::{AggregateExpr, PhysicalSortExpr};
use arrow::compute::concat;
- use arrow_array::{ArrayRef, Int32Array, Int64Array, RecordBatch};
- use arrow_schema::{DataType, Field, Schema, SortOptions};
+ use arrow_array::{ArrayRef, Int64Array};
+ use arrow_schema::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

@@ -846,111 +844,4 @@

Ok(())
}

- #[test]
- fn first_ignore_null() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
- None,
- Some(2),
- None,
- None,
- Some(3),
- Some(9),
- ]));
- let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
-
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
- let a_expr = col("a", &schema)?;
-
- let agg1 = Arc::new(FirstValue::new(
- a_expr.clone(),
- "first1",
- DataType::Int32,
- vec![],
- vec![],
- ).ignore_null(true));
- let first1 = aggregate(&batch, agg1)?;
- assert_eq!(first1, ScalarValue::Int32(Some(2)));
-
- let agg2 = Arc::new(FirstValue::new(
- a_expr.clone(),
- "first1",
- DataType::Int32,
- vec![],
- vec![],
- ));
- let first2 = aggregate(&batch, agg2)?;
- assert_eq!(first2, ScalarValue::Int32(None));
-
- Ok(())
- }
-
- #[test]
- fn first_ignore_null_with_sort() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
- Some(12),
- None,
- None,
- None,
- Some(10),
- Some(9),
- ]));
- let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
-
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?;
- let a_expr = col("a", &schema)?;
-
- let option_desc = SortOptions {
- descending: false,
- nulls_first: true,
- };
- let sort_expr_a = vec![PhysicalSortExpr {
- expr: a_expr.clone(),
- options: option_desc,
- }];
-
- let agg1 = Arc::new(FirstValue::new(
- a_expr.clone(),
- "first1",
- DataType::Int32,
- sort_expr_a.clone(),
- vec![DataType::Int32],
- ).ignore_null(true));
- let first1 = aggregate(&batch, agg1)?;
- assert_eq!(first1, ScalarValue::Int32(Some(9)));
-
- let agg2 = Arc::new(FirstValue::new(
- a_expr.clone(),
- "first2",
- DataType::Int32,
- sort_expr_a.clone(),
- vec![DataType::Int32],
- ));
- let first2 = aggregate(&batch, agg2)?;
- assert_eq!(first2, ScalarValue::Int32(None));
-
- Ok(())
- }
-
- pub fn aggregate(
- batch: &RecordBatch,
- agg: Arc<dyn AggregateExpr>,
- ) -> Result<ScalarValue> {
- let mut accum = agg.create_accumulator()?;
- let mut expr = agg.expressions();
- if let Some(ordering_req) = agg.order_bys() {
- expr.extend(ordering_req.iter().map(|item| item.expr.clone()));
- }
-
- let values = expr
- .iter()
- .map(|e| {
- e.evaluate(batch)
- .and_then(|v| v.into_array(batch.num_rows()))
- })
- .collect::<Result<Vec<_>>>()?;
-
- accum.update_batch(&values)?;
- accum.evaluate()
- }
}
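
Aside: the `ignore_null` method shown above is a consuming builder, so it can be chained directly onto `FirstValue::new(..)` the way build_in.rs calls it. A minimal, self-contained sketch of that pattern (the `First` struct and its fields below are simplified stand-ins for illustration, not DataFusion's actual FirstValue type):

struct First {
    name: String,
    ignore_null: bool,
}

impl First {
    fn new(name: impl Into<String>) -> Self {
        Self { name: name.into(), ignore_null: false }
    }

    // Takes `self` by value and returns it, so the call chains onto the constructor.
    fn ignore_null(mut self, ignore_null: bool) -> Self {
        self.ignore_null = ignore_null;
        self
    }
}

fn main() {
    // Mirrors the call shape used in build_in.rs: FirstValue::new(..).ignore_null(ignore_nulls)
    let agg = First::new("first1").ignore_null(true);
    println!("{} ignore_null={}", agg.name, agg.ignore_null);
}
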
13 changes: 10 additions & 3 deletions datafusion/physical-expr/src/expressions/mod.rs
@@ -193,9 +193,16 @@ pub(crate) mod tests {
.unwrap();

let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]);
- let agg =
- create_aggregate_expr(&function, distinct, &[input], &[], &schema, "agg", false)
- .unwrap();
+ let agg = create_aggregate_expr(
+ &function,
+ distinct,
+ &[input],
+ &[],
+ &schema,
+ "agg",
+ false,
+ )
+ .unwrap();

let result = aggregate(&batch, agg).unwrap();
assert_eq!(expected, result);
7 changes: 6 additions & 1 deletion datafusion/sql/src/expr/function.rs
@@ -194,7 +194,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.map(Box::new);

return Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
- fun, args, distinct, filter, order_by, null_treatment,
+ fun,
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
)));
};

2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/mod.rs
@@ -216,7 +216,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
agg_func.distinct,
agg_func.filter.clone(),
agg_func.order_by.clone(),
- agg_func.null_treatment.clone(),
+ agg_func.null_treatment,
)), true)
},
_ => (expr, false),
13 changes: 1 addition & 12 deletions datafusion/sql/tests/sql_integration.rs
@@ -24,8 +24,7 @@ use arrow_schema::*;
use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};

use datafusion_common::{
- assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue,
- TableReference,
+ config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference,
};
use datafusion_common::{plan_err, ParamValues};
use datafusion_expr::{
@@ -1288,16 +1287,6 @@ fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() {
);
}

- #[test]
- fn select_simple_aggregate_respect_nulls() {
- let sql = "SELECT MIN(age) RESPECT NULLS FROM person";
- let err = logical_plan(sql).expect_err("query should have failed");
-
- assert_contains!(
- err.strip_backtrace(),
- "This feature is not implemented: Null treatment in aggregate functions is not supported: RESPECT NULLS"
- );
- }
#[test]
fn select_from_typed_string_values() {
quick_test(