diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index bf5f5afc5791..d348e28ededa 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -246,6 +246,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, filter, order_by, + null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) @@ -1662,6 +1663,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( args, filter, order_by, + null_treatment, }) => { let args = args .iter() @@ -1689,6 +1691,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; + let ignore_nulls = null_treatment + .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let ordering_reqs = order_by.clone().unwrap_or(vec![]); @@ -1699,6 +1704,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( &ordering_reqs, physical_input_schema, name, + ignore_nulls, )?; (agg_expr, filter, order_by) } diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 84b791a3de05..14bc7a3d4f68 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -321,3 +321,83 @@ async fn test_accumulator_row_accumulator() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_first_value() -> Result<()> { + let session_ctx = SessionContext::new(); + session_ctx + .sql("CREATE TABLE abc AS VALUES (null,2,3), (4,5,6)") + .await? + .collect() + .await?; + + let results1 = session_ctx + .sql("SELECT FIRST_VALUE(column1) ignore nulls FROM abc") + .await? + .collect() + .await?; + let expected1 = [ + "+--------------------------+", + "| FIRST_VALUE(abc.column1) |", + "+--------------------------+", + "| 4 |", + "+--------------------------+", + ]; + assert_batches_eq!(expected1, &results1); + + let results2 = session_ctx + .sql("SELECT FIRST_VALUE(column1) respect nulls FROM abc") + .await? + .collect() + .await?; + let expected2 = [ + "+--------------------------+", + "| FIRST_VALUE(abc.column1) |", + "+--------------------------+", + "| |", + "+--------------------------+", + ]; + assert_batches_eq!(expected2, &results2); + + Ok(()) +} + +#[tokio::test] +async fn test_first_value_with_sort() -> Result<()> { + let session_ctx = SessionContext::new(); + session_ctx + .sql("CREATE TABLE abc AS VALUES (null,2,3), (null,1,6), (4, 5, 5), (1, 4, 7), (2, 3, 8)") + .await? + .collect() + .await?; + + let results1 = session_ctx + .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) ignore nulls FROM abc") + .await? + .collect() + .await?; + let expected1 = [ + "+--------------------------+", + "| FIRST_VALUE(abc.column1) |", + "+--------------------------+", + "| 2 |", + "+--------------------------+", + ]; + assert_batches_eq!(expected1, &results1); + + let results2 = session_ctx + .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) respect nulls FROM abc") + .await? 
+ .collect() + .await?; + let expected2 = [ + "+--------------------------+", + "| FIRST_VALUE(abc.column1) |", + "+--------------------------+", + "| |", + "+--------------------------+", + ]; + assert_batches_eq!(expected2, &results2); + + Ok(()) +} diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 68b123ab1f28..e83d2f1a65f6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -543,6 +543,7 @@ pub struct AggregateFunction { pub filter: Option>, /// Optional ordering pub order_by: Option>, + pub null_treatment: Option, } impl AggregateFunction { @@ -552,6 +553,7 @@ impl AggregateFunction { distinct: bool, filter: Option>, order_by: Option>, + null_treatment: Option, ) -> Self { Self { func_def: AggregateFunctionDefinition::BuiltIn(fun), @@ -559,6 +561,7 @@ impl AggregateFunction { distinct, filter, order_by, + null_treatment, } } @@ -576,6 +579,7 @@ impl AggregateFunction { distinct, filter, order_by, + null_treatment: None, } } } @@ -646,6 +650,7 @@ pub struct WindowFunction { pub order_by: Vec, /// Window frame pub window_frame: window_frame::WindowFrame, + /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, } @@ -1471,9 +1476,13 @@ impl fmt::Display for Expr { ref args, filter, order_by, + null_treatment, .. }) => { fmt_function(f, func_def.name(), *distinct, args, true)?; + if let Some(nt) = null_treatment { + write!(f, " {}", nt)?; + } if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } @@ -1804,6 +1813,7 @@ fn create_name(e: &Expr) -> Result { args, filter, order_by, + null_treatment, }) => { let name = match func_def { AggregateFunctionDefinition::BuiltIn(..) @@ -1823,6 +1833,9 @@ fn create_name(e: &Expr) -> Result { if let Some(order_by) = order_by { info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); }; + if let Some(nt) = null_treatment { + info += &format!(" {}", nt); + } match func_def { AggregateFunctionDefinition::BuiltIn(..) | AggregateFunctionDefinition::Name(..) 
=> { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 78546dddd589..99f44a73c1dd 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -150,6 +150,7 @@ pub fn min(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -161,6 +162,7 @@ pub fn max(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -172,6 +174,7 @@ pub fn sum(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -183,6 +186,7 @@ pub fn array_agg(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -194,6 +198,7 @@ pub fn avg(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -205,6 +210,7 @@ pub fn count(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -261,6 +267,7 @@ pub fn count_distinct(expr: Expr) -> Expr { true, None, None, + None, )) } @@ -313,6 +320,7 @@ pub fn approx_distinct(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -324,6 +332,7 @@ pub fn median(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -335,6 +344,7 @@ pub fn approx_median(expr: Expr) -> Expr { false, None, None, + None, )) } @@ -346,6 +356,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { false, None, None, + None, )) } @@ -361,6 +372,7 @@ pub fn approx_percentile_cont_with_weight( false, None, None, + None, )) } @@ -431,6 +443,7 @@ pub fn stddev(expr: Expr) -> Expr { false, None, None, + None, )) } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 67d48f986f13..1c672851e9b5 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -350,6 +350,7 @@ impl TreeNode for Expr { distinct, filter, order_by, + null_treatment, }) => transform_vec(args, &mut f)? .update_data(|new_args| (new_args, filter, order_by)) .try_transform_node(|(new_args, filter, order_by)| { @@ -368,6 +369,7 @@ impl TreeNode for Expr { distinct, new_filter, new_order_by, + null_treatment, ))) } AggregateFunctionDefinition::UDF(fun) => { diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 93b24d71c496..c07445fa7f48 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -158,6 +158,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { distinct, filter, order_by, + null_treatment, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( @@ -166,6 +167,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { distinct, filter, order_by, + null_treatment, ))) } _ => Transformed::no(old_expr), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 08f49ed15b09..496def95e1bc 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -345,6 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { distinct, filter, order_by, + null_treatment, }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let new_expr = coerce_agg_exprs_for_signature( @@ -355,7 +356,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )?; Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, + fun, + new_expr, + distinct, + filter, + order_by, + null_treatment, ), ))) } @@ -946,6 +952,7 @@ mod test { false, None, None, + None, )); let plan = 
LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation"; @@ -959,6 +966,7 @@ mod test { false, None, None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation"; @@ -976,6 +984,7 @@ mod test { false, None, None, + None, )); let err = Projection::try_new(vec![agg_expr], empty) .err() @@ -998,6 +1007,7 @@ mod test { false, None, None, + None, )); let err = Projection::try_new(vec![agg_expr], empty) diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 8b7a9148b590..28b3ff090fe6 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -545,6 +545,7 @@ mod tests { false, Some(Box::new(col("c").gt(lit(42)))), None, + None, )); let plan = LogicalPlanBuilder::from(table_scan) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 187e510e557d..0666c324d12c 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -97,6 +97,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { false, None, sort_expr.clone(), + None, )) }) .collect::>(); diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 7e6fb6b355ab..07a9d84f7d48 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -76,6 +76,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { args, filter, order_by, + null_treatment: _, }) = expr { if filter.is_some() || order_by.is_some() { @@ -196,6 +197,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { false, None, None, + None, )) .alias(&alias_str), ); @@ -205,6 +207,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { false, None, None, + None, ))) } else { Ok(Expr::AggregateFunction(AggregateFunction::new( @@ -213,6 +216,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { false, // intentional to remove distinct here None, None, + None, ))) } } @@ -471,6 +475,7 @@ mod tests { true, None, None, + None, )), ], )? @@ -535,6 +540,7 @@ mod tests { true, None, None, + None, )), ], )? @@ -597,6 +603,7 @@ mod tests { false, Some(Box::new(col("a").gt(lit(5)))), None, + None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? @@ -619,6 +626,7 @@ mod tests { true, Some(Box::new(col("a").gt(lit(5)))), None, + None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? @@ -641,6 +649,7 @@ mod tests { false, None, Some(vec![col("a")]), + None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? @@ -663,6 +672,7 @@ mod tests { true, None, Some(vec![col("a")]), + None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? @@ -685,6 +695,7 @@ mod tests { true, Some(Box::new(col("a").gt(lit(5)))), Some(vec![col("a")]), + None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? 
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 0aaf0dc0c8c5..846431034c96 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -46,6 +46,7 @@ pub fn create_aggregate_expr( ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, + ignore_nulls: bool, ) -> Result> { let name = name.into(); // get the result data type for this aggregate function @@ -359,13 +360,16 @@ pub fn create_aggregate_expr( (AggregateFunction::Median, true) => { return not_impl_err!("MEDIAN(DISTINCT) aggregations are not available"); } - (AggregateFunction::FirstValue, _) => Arc::new(expressions::FirstValue::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - ordering_req.to_vec(), - ordering_types, - )), + (AggregateFunction::FirstValue, _) => Arc::new( + expressions::FirstValue::new( + input_phy_exprs[0].clone(), + name, + input_phy_types[0].clone(), + ordering_req.to_vec(), + ordering_types, + ) + .with_ignore_nulls(ignore_nulls), + ), (AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new( input_phy_exprs[0].clone(), name, @@ -1308,7 +1312,15 @@ mod tests { "Invalid or wrong number of arguments passed to aggregate: '{name}'" ); } - create_aggregate_expr(fun, distinct, &coerced_phy_exprs, &[], input_schema, name) + create_aggregate_expr( + fun, + distinct, + &coerced_phy_exprs, + &[], + input_schema, + name, + false, + ) } // Returns the coerced exprs for each `input_exprs`. diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index d2bf48551f0d..17dd3ef1206d 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -44,6 +44,7 @@ pub struct FirstValue { expr: Arc, ordering_req: LexOrdering, requirement_satisfied: bool, + ignore_nulls: bool, } impl FirstValue { @@ -63,9 +64,15 @@ impl FirstValue { expr, ordering_req, requirement_satisfied, + ignore_nulls: false, } } + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } + /// Returns the name of the aggregate expression. pub fn name(&self) -> &str { &self.name @@ -134,6 +141,7 @@ impl AggregateExpr for FirstValue { &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), + self.ignore_nulls, ) .map(|acc| { Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ @@ -179,6 +187,7 @@ impl AggregateExpr for FirstValue { &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), + self.ignore_nulls, ) .map(|acc| { Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ @@ -213,6 +222,8 @@ struct FirstValueAccumulator { ordering_req: LexOrdering, // Stores whether incoming data already satisfies the ordering requirement. requirement_satisfied: bool, + // Ignore null values. 
+ ignore_nulls: bool, } impl FirstValueAccumulator { @@ -221,6 +232,7 @@ impl FirstValueAccumulator { data_type: &DataType, ordering_dtypes: &[DataType], ordering_req: LexOrdering, + ignore_nulls: bool, ) -> Result { let orderings = ordering_dtypes .iter() @@ -233,6 +245,7 @@ impl FirstValueAccumulator { orderings, ordering_req, requirement_satisfied, + ignore_nulls, }) } @@ -249,7 +262,18 @@ impl FirstValueAccumulator { }; if self.requirement_satisfied { // Get first entry according to the pre-existing ordering (0th index): - return Ok((!value.is_empty()).then_some(0)); + if self.ignore_nulls { + // If ignoring nulls, find the first non-null value. + for i in 0..value.len() { + if !value.is_null(i) { + return Ok(Some(i)); + } + } + return Ok(None); + } else { + // If not ignoring nulls, return the first value if it exists. + return Ok((!value.is_empty()).then_some(0)); + } } let sort_columns = ordering_values .iter() @@ -259,8 +283,20 @@ impl FirstValueAccumulator { options: Some(req.options), }) .collect::>(); - let indices = lexsort_to_indices(&sort_columns, Some(1))?; - Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + + if self.ignore_nulls { + let indices = lexsort_to_indices(&sort_columns, None)?; + // If ignoring nulls, find the first non-null value. + for index in indices.iter().flatten() { + if !value.is_null(index as usize) { + return Ok(Some(index as usize)); + } + } + Ok(None) + } else { + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + } } fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { @@ -708,7 +744,7 @@ mod tests { #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; let mut last_accumulator = LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; // first value in the tuple is start of the range (inclusive), @@ -748,13 +784,13 @@ mod tests { // FirstValueAccumulator let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; first_accumulator.update_batch(&[arrs[0].clone()])?; let state1 = first_accumulator.state()?; let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; first_accumulator.update_batch(&[arrs[1].clone()])?; let state2 = first_accumulator.state()?; @@ -770,7 +806,7 @@ mod tests { } let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; first_accumulator.merge_batch(&states)?; let merged_state = first_accumulator.state()?; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 8993c630aa49..1c620c22a164 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -213,6 +213,7 @@ mod tests { &[], &schema, "agg", + false, ) .unwrap(); diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index f9896bafca15..26d649f57201 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -193,9 +193,16 @@ 
pub(crate) mod tests { .unwrap(); let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); - let agg = - create_aggregate_expr(&function, distinct, &[input], &[], &schema, "agg") - .unwrap(); + let agg = create_aggregate_expr( + &function, + distinct, + &[input], + &[], + &schema, + "agg", + false, + ) + .unwrap(); let result = aggregate(&batch, agg).unwrap(); assert_eq!(expected, result); diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index c19694aef8b7..54731f0d812b 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -74,6 +74,7 @@ pub fn create_window_expr( &[], input_schema, name, + ignore_nulls, )?; window_expr_from_aggregate_expr( partition_by, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index fb056c78291a..6476afca43bd 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1185,6 +1185,7 @@ pub fn parse_expr( parse_optional_expr(expr.filter.as_deref(), registry, codec)? .map(Box::new), parse_vec_expr(&expr.order_by, registry, codec)?, + None, ))) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 96f750f3d22a..0ee43ffd2716 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -663,6 +663,7 @@ pub fn serialize_expr( ref distinct, ref filter, ref order_by, + null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a20baeb4e941..a4c08d76867d 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -468,6 +468,7 @@ impl AsExecutionPlan for PhysicalPlanNode { &ordering_req, &physical_schema, name.to_string(), + false, ) } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 702ae99babd8..fad50d3ecddc 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1664,6 +1664,7 @@ fn roundtrip_count() { false, None, None, + None, )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1677,6 +1678,7 @@ fn roundtrip_count_distinct() { true, None, None, + None, )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1690,6 +1692,7 @@ fn roundtrip_approx_percentile_cont() { false, None, None, + None, )); let ctx = SessionContext::new(); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ad1d2db70cf4..bcf641e4b5a0 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -57,11 +57,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // required ordering should be defined in OVER clause. let is_function_window = over.is_some(); - match null_treatment { - Some(null_treatment) if !is_function_window => return not_impl_err!("Null treatment in aggregate functions is not supported: {null_treatment}"), - _ => {} - } - let name = if name.0.len() > 1 { // DF doesn't handle compound identifiers // (e.g. 
"foo.bar") for function names yet @@ -199,7 +194,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(Box::new); return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, filter, order_by, + fun, + args, + distinct, + filter, + order_by, + null_treatment, ))); }; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d36d973cbee6..aa0b619167dc 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -216,6 +216,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { agg_func.distinct, agg_func.filter.clone(), agg_func.order_by.clone(), + agg_func.null_treatment, )), true) }, _ => (expr, false), @@ -620,6 +621,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { distinct, None, order_by, + None, ))) // see if we can rewrite it into NTH-VALUE } @@ -770,6 +772,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { args, distinct, order_by, + null_treatment, .. }) => Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, @@ -781,6 +784,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?)), order_by, + null_treatment, ))), _ => plan_err!( "AggregateExpressionWithFilter expression was not an AggregateFunction" diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index db1beb94446b..57e2e1ef06a7 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -24,8 +24,7 @@ use arrow_schema::*; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use datafusion_common::{ - assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, - TableReference, + config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ @@ -1288,16 +1287,6 @@ fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() { ); } -#[test] -fn select_simple_aggregate_respect_nulls() { - let sql = "SELECT MIN(age) RESPECT NULLS FROM person"; - let err = logical_plan(sql).expect_err("query should have failed"); - - assert_contains!( - err.strip_backtrace(), - "This feature is not implemented: Null treatment in aggregate functions is not supported: RESPECT NULLS" - ); -} #[test] fn select_from_typed_string_values() { quick_test( diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index b78c6287746c..266f04580f11 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3323,3 +3323,55 @@ SELECT CAST(a AS INT) FROM t GROUP BY t.a; statement ok DROP TABLE t; + +# Test for ignore null in FIRST_VALUE +statement ok +CREATE TABLE t AS VALUES (null::bigint), (3), (4); + +query I +SELECT FIRST_VALUE(column1) FROM t; +---- +NULL + +query I +SELECT FIRST_VALUE(column1) RESPECT NULLS FROM t; +---- +NULL + +query I +SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t; +---- +3 + +statement ok +DROP TABLE t; + +# Test for ignore null with ORDER BY in FIRST_VALUE +statement ok +CREATE TABLE t AS VALUES (3, 4), (4, 3), (null::bigint, 1), (null::bigint, 1); + +query I +SELECT column1 FROM t ORDER BY column2; +---- +NULL +NULL +4 +3 + +query I +SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; +---- +NULL + +query I +SELECT FIRST_VALUE(column1 ORDER BY column2) RESPECT NULLS FROM t; +---- +NULL + +query I +SELECT FIRST_VALUE(column1 ORDER BY column2) IGNORE NULLS FROM t; +---- +4 + +statement ok +DROP 
TABLE t; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 095806c538d1..ed1e48ca71a6 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -754,7 +754,7 @@ pub async fn from_substrait_agg_func( } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new(fun, args, distinct, filter, order_by), + expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None), ))) } else { not_impl_err!( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 9b29c0c67765..a6a38ab6145c 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -672,7 +672,7 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { + Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { match func_def { AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by {
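Taken together, these changes thread a null-treatment option from the SQL planner down to the FIRST_VALUE accumulator. Below is a minimal sketch (not part of the diff) of the logical expression the planner now builds for `FIRST_VALUE(column1) IGNORE NULLS`, assuming the six-argument `AggregateFunction::new` signature introduced above and the usual `datafusion_expr` re-exports; the function and column names are illustrative only.

// Sketch: constructing FIRST_VALUE(column1) IGNORE NULLS programmatically.
// Paths and the constructor signature follow this diff; treat them as
// assumptions for any other DataFusion version.
use datafusion_expr::aggregate_function::AggregateFunction as BuiltInAggFn;
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::{col, Expr};
use sqlparser::ast::NullTreatment;

fn first_value_ignore_nulls() -> Expr {
    Expr::AggregateFunction(AggregateFunction::new(
        BuiltInAggFn::FirstValue,         // built-in FIRST_VALUE aggregate
        vec![col("column1")],             // aggregate arguments
        false,                            // DISTINCT
        None,                             // FILTER (WHERE ...)
        None,                             // ORDER BY
        Some(NullTreatment::IgnoreNulls), // the new field added in this change
    ))
}

At the physical layer the same flag arrives as the new `ignore_nulls: bool` argument of `create_aggregate_expr`, which `build_in.rs` forwards into `FirstValue` through the `with_ignore_nulls` builder; `LastValue` is not changed by this diff.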