From a851ecf1cc24a6b867d40087d8e890b9307137c1 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 22 Feb 2024 19:16:30 -0800 Subject: [PATCH] Support IGNORE NULLS for LAG window function (#9221) * WIP lag/lead ignore nulls * Support IGNORE NULLS for LAG function * fmt * comments * remove comments * Add new tests, minor changes, trigger evalaute_all * Make algorithm pruning friendly --------- Co-authored-by: Mustafa Akur --- datafusion/core/src/dataframe/mod.rs | 1 + .../core/src/physical_optimizer/test_utils.rs | 1 + datafusion/core/src/physical_planner.rs | 6 + datafusion/core/tests/dataframe/mod.rs | 1 + .../core/tests/fuzz_cases/window_fuzz.rs | 3 + datafusion/expr/src/expr.rs | 18 +++ datafusion/expr/src/tree_node/expr.rs | 2 + datafusion/expr/src/udwf.rs | 1 + datafusion/expr/src/utils.rs | 10 ++ .../src/analyzer/count_wildcard_rule.rs | 3 + .../optimizer/src/analyzer/type_coercion.rs | 2 + .../optimizer/src/push_down_projection.rs | 2 + .../physical-expr/src/window/lead_lag.rs | 88 ++++++++++++-- datafusion/physical-plan/src/windows/mod.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 6 + datafusion/proto/src/logical_plan/to_proto.rs | 2 + .../proto/src/physical_plan/from_proto.rs | 1 + .../tests/cases/roundtrip_logical_plan.rs | 6 + datafusion/sql/src/expr/function.rs | 16 ++- datafusion/sqllogictest/test_files/window.slt | 107 ++++++++++++++++++ .../substrait/src/logical_plan/consumer.rs | 1 + .../substrait/src/logical_plan/producer.rs | 1 + 22 files changed, 272 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 4ec16ac942b2..e407c477ae4c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1685,6 +1685,7 @@ mod tests { vec![col("aggregate_test_100.c2")], vec![], WindowFrame::new(None), + None, )); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index ca7fb78d21b1..3898fb6345f0 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -245,6 +245,7 @@ pub fn bounded_window_exec( &sort_exprs, Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), + false, ) .unwrap()], input.clone(), diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index dabf0a91b2d3..23ac7e08cad8 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -100,6 +100,7 @@ use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; +use sqlparser::ast::NullTreatment; fn create_function_physical_name( fun: &str, @@ -1581,6 +1582,7 @@ pub fn create_window_expr_with_name( partition_by, order_by, window_frame, + null_treatment, }) => { let args = args .iter() @@ -1605,6 +1607,9 @@ pub fn create_window_expr_with_name( } let window_frame = Arc::new(window_frame.clone()); + let ignore_nulls = null_treatment + .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + == NullTreatment::IgnoreNulls; windows::create_window_expr( fun, name, @@ -1613,6 +1618,7 @@ pub fn create_window_expr_with_name( &order_by, window_frame, physical_input_schema, + ignore_nulls, ) } other => plan_err!("Invalid window expression '{other:?}'"), diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs 
index f650e9e39d88..b08b2b8fc7a2 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -182,6 +182,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), ), + None, ))])? .explain(false, false)? .collect() diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index d22d0c0f2ee0..609d26c9c253 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -281,6 +281,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { &orderby_exprs, Arc::new(window_frame), schema.as_ref(), + false, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![window_expr], @@ -642,6 +643,7 @@ async fn run_window_test( &orderby_exprs, Arc::new(window_frame.clone()), schema.as_ref(), + false, ) .unwrap()], exec1, @@ -664,6 +666,7 @@ async fn run_window_test( &orderby_exprs, Arc::new(window_frame.clone()), schema.as_ref(), + false, ) .unwrap()], exec2, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 09de4b708de9..f40ccb6cdb58 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -30,6 +30,7 @@ use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; +use sqlparser::ast::NullTreatment; use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; @@ -646,6 +647,7 @@ pub struct WindowFunction { pub order_by: Vec<Expr>, /// Window frame pub window_frame: window_frame::WindowFrame, + pub null_treatment: Option<NullTreatment>, } impl WindowFunction { @@ -656,6 +658,7 @@ impl WindowFunction { partition_by: Vec<Expr>, order_by: Vec<Expr>, window_frame: window_frame::WindowFrame, + null_treatment: Option<NullTreatment>, ) -> Self { Self { fun, @@ -663,6 +666,7 @@ impl WindowFunction { partition_by, order_by, window_frame, + null_treatment, } } } @@ -1440,8 +1444,14 @@ impl fmt::Display for Expr { partition_by, order_by, window_frame, + null_treatment, }) => { fmt_function(f, &fun.to_string(), false, args, true)?; + + if let Some(nt) = null_treatment { + write!(f, "{}", nt)?; + } + if !partition_by.is_empty() { write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; } @@ -1768,15 +1778,23 @@ fn create_name(e: &Expr) -> Result<String> { window_frame, partition_by, order_by, + null_treatment, }) => { let mut parts: Vec<String> = vec![create_function_name(&fun.to_string(), false, args)?]; + + if let Some(nt) = null_treatment { + parts.push(format!("{}", nt)); + } + if !partition_by.is_empty() { parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by))); } + if !order_by.is_empty() { parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by))); } + parts.push(format!("{window_frame}")); Ok(parts.join(" ")) } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index add15b3d7ad7..def25ed9242f 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -283,12 +283,14 @@ impl TreeNode for Expr { partition_by, order_by, window_frame, + null_treatment, }) => Expr::WindowFunction(WindowFunction::new( fun, transform_vec(args, &mut transform)?, transform_vec(partition_by, &mut transform)?, transform_vec(order_by, &mut transform)?,
window_frame, + null_treatment, )), Expr::AggregateFunction(AggregateFunction { args, diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 953483408865..7e3eb6c001a1 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -130,6 +130,7 @@ impl WindowUDF { partition_by, order_by, window_frame, + null_treatment: None, }) } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index e855554f3687..2fda81d8896f 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1255,6 +1255,7 @@ mod tests { vec![], vec![], WindowFrame::new(None), + None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), @@ -1262,6 +1263,7 @@ mod tests { vec![], vec![], WindowFrame::new(None), + None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), @@ -1269,6 +1271,7 @@ mod tests { vec![], vec![], WindowFrame::new(None), + None, )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), @@ -1276,6 +1279,7 @@ mod tests { vec![], vec![], WindowFrame::new(None), + None, )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1298,6 +1302,7 @@ mod tests { vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(Some(false)), + None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), @@ -1305,6 +1310,7 @@ mod tests { vec![], vec![], WindowFrame::new(None), + None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), @@ -1312,6 +1318,7 @@ mod tests { vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(Some(false)), + None, )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), @@ -1319,6 +1326,7 @@ mod tests { vec![], vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], WindowFrame::new(Some(false)), + None, )); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; @@ -1353,6 +1361,7 @@ mod tests { Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), ], WindowFrame::new(Some(false)), + None, )), Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), @@ -1364,6 +1373,7 @@ mod tests { Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), ], WindowFrame::new(Some(false)), + None, )), ]; let expected = vec![ diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 35a859783239..9242e68562c6 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -128,6 +128,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { partition_by, order_by, window_frame, + null_treatment, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { Expr::WindowFunction(expr::WindowFunction { @@ -138,6 +139,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { partition_by, order_by, window_frame, + null_treatment, }) } @@ -351,6 +353,7 @@ mod tests { 
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), ), + None, ))])? .project(vec![count(wildcard())])? .build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index fba77047dd74..8cdb4d7dbdf6 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -392,6 +392,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { partition_by, order_by, window_frame, + null_treatment, }) => { let window_frame = coerce_window_frame(window_frame, &self.schema, &order_by)?; @@ -414,6 +415,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { partition_by, order_by, window_frame, + null_treatment, )); Ok(expr) } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 6a003ecb5fa8..8b7a9148b590 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -587,6 +587,7 @@ mod tests { vec![col("test.b")], vec![], WindowFrame::new(None), + None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( @@ -595,6 +596,7 @@ mod tests { vec![], vec![], WindowFrame::new(None), + None, )); let col1 = col(max1.display_name()?); let col2 = col(max2.display_name()?); diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 6a33f26ca126..6e1aad575f6a 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -23,10 +23,14 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; +use arrow_array::Array; +use datafusion_common::{ + arrow_datafusion_err, exec_datafusion_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::PartitionEvaluator; use std::any::Any; use std::cmp::min; +use std::collections::VecDeque; use std::ops::{Neg, Range}; use std::sync::Arc; @@ -39,6 +43,7 @@ pub struct WindowShift { shift_offset: i64, expr: Arc<dyn PhysicalExpr>, default_value: Option<ScalarValue>, + ignore_nulls: bool, } impl WindowShift { @@ -60,6 +65,7 @@ pub fn lead( expr: Arc<dyn PhysicalExpr>, shift_offset: Option<i64>, default_value: Option<ScalarValue>, + ignore_nulls: bool, ) -> WindowShift { WindowShift { name, @@ -67,6 +73,7 @@ pub fn lead( shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1), expr, default_value, + ignore_nulls, } } @@ -77,6 +84,7 @@ pub fn lag( expr: Arc<dyn PhysicalExpr>, shift_offset: Option<i64>, default_value: Option<ScalarValue>, + ignore_nulls: bool, ) -> WindowShift { WindowShift { name, @@ -84,6 +92,7 @@ pub fn lag( shift_offset: shift_offset.unwrap_or(1), expr, default_value, + ignore_nulls, } } @@ -110,6 +119,8 @@ impl BuiltInWindowFunctionExpr for WindowShift { Ok(Box::new(WindowShiftEvaluator { shift_offset: self.shift_offset, default_value: self.default_value.clone(), + ignore_nulls: self.ignore_nulls, + non_null_offsets: VecDeque::new(), })) } @@ -120,6 +131,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { shift_offset: -self.shift_offset, expr: self.expr.clone(), default_value: self.default_value.clone(), + ignore_nulls: self.ignore_nulls, })) } } @@ -128,6 +140,16 @@ impl BuiltInWindowFunctionExpr for WindowShift { pub(crate) struct WindowShiftEvaluator { shift_offset: i64, default_value: Option<ScalarValue>, + ignore_nulls: bool, + // VecDeque contains the offset values between non-null entries
+ non_null_offsets: VecDeque<usize>, } + +impl WindowShiftEvaluator { + fn is_lag(&self) -> bool { + // Mode is LAG when shift_offset is positive + self.shift_offset > 0 + } } fn create_empty_array( @@ -182,9 +204,13 @@ fn shift_with_default_value( impl PartitionEvaluator for WindowShiftEvaluator { fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> { - if self.shift_offset > 0 { - let offset = self.shift_offset as usize; - let start = idx.saturating_sub(offset); + if self.is_lag() { + let start = if self.non_null_offsets.len() == self.shift_offset as usize { + let offset: usize = self.non_null_offsets.iter().sum(); + idx.saturating_sub(offset + 1) + } else { + 0 + }; let end = idx + 1; Ok(Range { start, end }) } else { @@ -196,7 +222,7 @@ impl PartitionEvaluator for WindowShiftEvaluator { fn is_causal(&self) -> bool { // Lagging windows are causal by definition: - self.shift_offset > 0 + self.is_lag() } fn evaluate( @@ -204,17 +230,57 @@ impl PartitionEvaluator for WindowShiftEvaluator { values: &[ArrayRef], range: &Range<usize>, ) -> Result<ScalarValue> { + // TODO: try to get rid of the i64/usize conversion + // TODO: do not recalculate the default value on every call + // TODO: support LEAD mode for IGNORE NULLS let array = &values[0]; let dtype = array.data_type(); + let len = array.len() as i64; // LAG mode - let idx = if self.shift_offset > 0 { + let mut idx = if self.is_lag() { range.end as i64 - self.shift_offset - 1 } else { // LEAD mode range.start as i64 - self.shift_offset }; - if idx < 0 || idx as usize >= array.len() { + // Support LAG only for now, as LEAD requires some more brainstorming first + // LAG with IGNORE NULLS is calculated as the current row index - offset, counting only non-NULL rows + // If the current row index points to a NULL value, the row is NOT counted + if self.ignore_nulls && self.is_lag() { + // Find the non-NULL row index that is shifted by the offset relative to the current row index + idx = if self.non_null_offsets.len() == self.shift_offset as usize { + let total_offset: usize = self.non_null_offsets.iter().sum(); + (range.end - 1 - total_offset) as i64 + } else { + -1 + }; + + // Keep track of offset values between non-null entries + if array.is_valid(range.end - 1) { + // Non-null entry: add a new offset value + self.non_null_offsets.push_back(1); + if self.non_null_offsets.len() > self.shift_offset as usize { + // We do not need to keep track of more offset values than the lag amount. + self.non_null_offsets.pop_front(); + } + } else if !self.non_null_offsets.is_empty() { + // Entry is null: increment the offset value of the last entry. + let end_idx = self.non_null_offsets.len() - 1; + self.non_null_offsets[end_idx] += 1; + } + } else if self.ignore_nulls && !self.is_lag() { + // IGNORE NULLS and LEAD mode.
+ return Err(exec_datafusion_err!( + "IGNORE NULLS mode for LEAD is not supported for BoundedWindowAggExec" + )); + } + + // Set the default value if + // - the index is out of window bounds + // OR + // - in ignore nulls mode, the current value is null but within window bounds + if idx < 0 || idx >= len || (self.ignore_nulls && array.is_null(idx as usize)) { get_default_value(self.default_value.as_ref(), dtype) } else { ScalarValue::try_from_array(array, idx as usize) @@ -226,6 +292,11 @@ impl PartitionEvaluator for WindowShiftEvaluator { values: &[ArrayRef], _num_rows: usize, ) -> Result<ArrayRef> { + if self.ignore_nulls { + return Err(exec_datafusion_err!( + "IGNORE NULLS mode for LAG and LEAD is not supported for WindowAggExec" + )); + } // LEAD, LAG window functions take single column, values will have size 1 let value = &values[0]; shift_with_default_value(value, self.shift_offset, self.default_value.as_ref()) @@ -279,6 +350,7 @@ mod tests { Arc::new(Column::new("c3", 0)), None, None, + false, ), [ Some(-2), @@ -301,6 +373,7 @@ mod tests { Arc::new(Column::new("c3", 0)), None, None, + false, ), [ None, @@ -323,6 +396,7 @@ mod tests { Arc::new(Column::new("c3", 0)), None, Some(ScalarValue::Int32(Some(100))), + false, ), [ Some(100), diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 693d20e90a66..bf6ed925356c 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -55,6 +55,7 @@ pub use datafusion_physical_expr::window::{ }; /// Create a physical expression for window function +#[allow(clippy::too_many_arguments)] pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, @@ -63,6 +64,7 @@ pub fn create_window_expr( order_by: &[PhysicalSortExpr], window_frame: Arc<WindowFrame>, input_schema: &Schema, + ignore_nulls: bool, ) -> Result<Arc<dyn WindowExpr>> { Ok(match fun { WindowFunctionDefinition::AggregateFunction(fun) => { @@ -83,7 +85,7 @@ pub fn create_window_expr( } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, + create_built_in_window_expr(fun, args, input_schema, name, ignore_nulls)?, partition_by, order_by, window_frame, @@ -159,6 +161,7 @@ fn create_built_in_window_expr( args: &[Arc<dyn PhysicalExpr>], input_schema: &Schema, name: String, + ignore_nulls: bool, ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> { // need to get the types into an owned vec for some reason let input_types: Vec<_> = args @@ -208,6 +211,7 @@ fn create_built_in_window_expr( arg, shift_offset, default_value, + ignore_nulls, )) } BuiltInWindowFunction::Lead => { @@ -222,6 +226,7 @@ fn create_built_in_window_expr( arg, shift_offset, default_value, + ignore_nulls, )) } BuiltInWindowFunction::NthValue => { @@ -671,6 +676,7 @@ mod tests { &[], Arc::new(WindowFrame::new(None)), schema.as_ref(), + false, )?], blocking_exec, vec![], diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f1ee84a8221d..2554018a9273 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1100,6 +1100,8 @@ pub fn parse_expr( "missing window frame during deserialization".to_string(), ) })?; + // TODO: support proto for null treatment + let null_treatment = None; regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { @@ -1114,6 +1116,7 @@ pub fn parse_expr( partition_by, order_by, window_frame, + None ))) }
window_expr_node::WindowFunction::BuiltInFunction(i) => { @@ -1133,6 +1136,7 @@ pub fn parse_expr( partition_by, order_by, window_frame, + null_treatment ))) } window_expr_node::WindowFunction::Udaf(udaf_name) => { @@ -1148,6 +1152,7 @@ pub fn parse_expr( partition_by, order_by, window_frame, + None, ))) } window_expr_node::WindowFunction::Udwf(udwf_name) => { @@ -1163,6 +1168,7 @@ pub fn parse_expr( partition_by, order_by, window_frame, + None, ))) } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a6348e909cb0..ccadbb217a58 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -606,6 +606,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref partition_by, ref order_by, ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, }) => { let window_function = match fun { WindowFunctionDefinition::AggregateFunction(fun) => { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 628ee5ad9b7a..af0aa485c348 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -176,6 +176,7 @@ pub fn parse_physical_window_expr( &order_by, Arc::new(window_frame), input_schema, + false, ) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 81f59975476f..6ca757908159 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1718,6 +1718,7 @@ fn roundtrip_window() { vec![col("col1")], vec![col("col2")], WindowFrame::new(Some(false)), + None, )); // 2. with default window_frame @@ -1729,6 +1730,7 @@ fn roundtrip_window() { vec![col("col1")], vec![col("col2")], WindowFrame::new(Some(false)), + None, )); // 3. with window_frame with row numbers @@ -1746,6 +1748,7 @@ fn roundtrip_window() { vec![col("col1")], vec![col("col2")], range_number_frame, + None, )); // 4. test with AggregateFunction @@ -1761,6 +1764,7 @@ fn roundtrip_window() { vec![col("col1")], vec![col("col2")], row_number_frame.clone(), + None, )); // 5. test with AggregateUDF @@ -1812,6 +1816,7 @@ fn roundtrip_window() { vec![col("col1")], vec![col("col2")], row_number_frame.clone(), + None, )); ctx.register_udaf(dummy_agg); @@ -1887,6 +1892,7 @@ fn roundtrip_window() { vec![col("col1")], vec![col("col2")], row_number_frame, + None, )); ctx.register_udwf(dummy_window_udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 64b8d6957d2b..f56138066cb6 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -52,8 +52,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_by, } = function; - if let Some(null_treatment) = null_treatment { - return not_impl_err!("Null treatment in aggregate functions is not supported: {null_treatment}"); + // If function is a window function (it has an OVER clause), + // it shouldn't have ordering requirement as function argument + // required ordering should be defined in OVER clause. 
+ let is_function_window = over.is_some(); + + match null_treatment { + Some(null_treatment) if !is_function_window => return not_impl_err!("Null treatment in aggregate functions is not supported: {null_treatment}"), + _ => {} } let name = if name.0.len() > 1 { @@ -120,10 +126,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))); }; - // If function is a window function (it has an OVER clause), - // it shouldn't have ordering requirement as function argument - // required ordering should be defined in OVER clause. - let is_function_window = over.is_some(); if !order_by.is_empty() && is_function_window { return plan_err!( "Aggregate ORDER BY is not implemented for window functions" @@ -198,6 +200,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { partition_by, order_by, window_frame, + null_treatment, )) } _ => Expr::WindowFunction(expr::WindowFunction::new( @@ -206,6 +209,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { partition_by, order_by, window_frame, + null_treatment, )), }; return Ok(expr); diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 9276f6e1e325..8d6b314747bb 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4102,3 +4102,110 @@ ProjectionExec: expr=[ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRE ----------CoalesceBatchesExec: target_batch_size=4096 ------------FilterExec: a@0 = 1 --------------MemoryExec: partitions=1, partition_sizes=[1] + +# LAG window function IGNORE/RESPECT NULLS support with ascending order and default offset 1 +query TTTTTT +select lag(a) ignore nulls over (order by id) as x, + lag(a, 1, null) ignore nulls over (order by id) as x1, + lag(a, 1, 'def') ignore nulls over (order by id) as x2, + lag(a) respect nulls over (order by id) as x3, + lag(a, 1, null) respect nulls over (order by id) as x4, + lag(a, 1, 'def') respect nulls over (order by id) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +---- +NULL NULL def NULL NULL def +NULL NULL def NULL NULL NULL +b b b b b b +b b b NULL NULL NULL + +# LAG window function IGNORE/RESPECT NULLS support with descending order and default offset 1 +query TTTTTT +select lag(a) ignore nulls over (order by id desc) as x, + lag(a, 1, null) ignore nulls over (order by id desc) as x1, + lag(a, 1, 'def') ignore nulls over (order by id desc) as x2, + lag(a) respect nulls over (order by id desc) as x3, + lag(a, 1, null) respect nulls over (order by id desc) as x4, + lag(a, 1, 'def') respect nulls over (order by id desc) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +---- +NULL NULL def NULL NULL def +x x x x x x +x x x NULL NULL NULL +b b b b b b + +# LAG window function IGNORE/RESPECT NULLS support with ascending order and nondefault offset +query TTTT +select lag(a, 2, null) ignore nulls over (order by id) as x1, + lag(a, 2, 'def') ignore nulls over (order by id) as x2, + lag(a, 2, null) respect nulls over (order by id) as x4, + lag(a, 2, 'def') respect nulls over (order by id) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +---- +NULL def NULL def +NULL def NULL def +NULL def NULL NULL +NULL def b b + +# LAG window function IGNORE/RESPECT NULLS support with descending order and nondefault offset +query TTTT 
+select lag(a, 2, null) ignore nulls over (order by id desc) as x1, + lag(a, 2, 'def') ignore nulls over (order by id desc) as x2, + lag(a, 2, null) respect nulls over (order by id desc) as x4, + lag(a, 2, 'def') respect nulls over (order by id desc) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') +---- +NULL def NULL def +NULL def NULL def +NULL def x x +x x NULL NULL + +# LAG window function IGNORE/RESPECT NULLS support with descending order and nondefault offset. +# To trigger WindowAggExec, we add a sum window function with an unbounded window frame. +statement error Execution error: IGNORE NULLS mode for LAG and LEAD is not supported for WindowAggExec +select lag(a, 2, null) ignore nulls over (order by id desc) as x1, + lag(a, 2, 'def') ignore nulls over (order by id desc) as x2, + lag(a, 2, null) respect nulls over (order by id desc) as x4, + lag(a, 2, 'def') respect nulls over (order by id desc) as x5, + sum(id) over (order by id desc ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_id +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') + +# LEAD window function IGNORE/RESPECT NULLS support with descending order and nondefault offset +statement error Execution error: IGNORE NULLS mode for LEAD is not supported for BoundedWindowAggExec +select lead(a, 2, null) ignore nulls over (order by id desc) as x1, + lead(a, 2, 'def') ignore nulls over (order by id desc) as x2, + lead(a, 2, null) respect nulls over (order by id desc) as x4, + lead(a, 2, 'def') respect nulls over (order by id desc) as x5 +from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x') + +statement ok +set datafusion.execution.batch_size = 1000; + +query I +SELECT LAG(c1, 2) IGNORE NULLS OVER() +FROM null_cases +ORDER BY c2 +LIMIT 5; +---- +78 +63 +3 +24 +14 + +# The result should be the same as above when the lag algorithm works with pruned data. +# Decreasing the batch size causes data to be produced in smaller chunks at the source. +# Hence the sliding window algorithm is used during the calculations. +statement ok +set datafusion.execution.batch_size = 1; + +query I +SELECT LAG(c1, 2) IGNORE NULLS OVER() +FROM null_cases +ORDER BY c2 +LIMIT 5; +---- +78 +63 +3 +24 +14 diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 58a741c63401..23a7ee05d73e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -978,6 +978,7 @@ pub async fn from_substrait_rex( from_substrait_bound(&window.lower_bound, true)?, from_substrait_bound(&window.upper_bound, false)?, ), + null_treatment: None, }))) } Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index fc9517c90a45..9b29c0c67765 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1115,6 +1115,7 @@ pub fn to_substrait_rex( partition_by, order_by, window_frame, + null_treatment: _, }) => { // function reference let function_anchor = _register_function(fun.to_string(), extension_info);