Skip to content

Commit

Permalink
Support IGNORE NULLS for LAG window function (apache#9221)
Browse files Browse the repository at this point in the history
* WIP lag/lead ignore nulls

* Support IGNORE NULLS for LAG function

* fmt

* comments

* remove comments

* Add new tests, minor changes, trigger evalaute_all

* Make algorithm pruning friendly

---------

Co-authored-by: Mustafa Akur <[email protected]>
  • Loading branch information
comphead and mustafasrepo authored Feb 23, 2024
1 parent 02c948d commit a851ecf
Show file tree
Hide file tree
Showing 22 changed files with 272 additions and 14 deletions.
1 change: 1 addition & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,7 @@ mod tests {
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(None),
None,
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ pub fn bounded_window_exec(
&sort_exprs,
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
false,
)
.unwrap()],
input.clone(),
Expand Down
6 changes: 6 additions & 0 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
use itertools::{multiunzip, Itertools};
use log::{debug, trace};
use sqlparser::ast::NullTreatment;

fn create_function_physical_name(
fun: &str,
Expand Down Expand Up @@ -1581,6 +1582,7 @@ pub fn create_window_expr_with_name(
partition_by,
order_by,
window_frame,
null_treatment,
}) => {
let args = args
.iter()
Expand All @@ -1605,6 +1607,9 @@ pub fn create_window_expr_with_name(
}

let window_frame = Arc::new(window_frame.clone());
let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
windows::create_window_expr(
fun,
name,
Expand All @@ -1613,6 +1618,7 @@ pub fn create_window_expr_with_name(
&order_by,
window_frame,
physical_input_schema,
ignore_nulls,
)
}
other => plan_err!("Invalid window expression '{other:?}'"),
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
None,
))])?
.explain(false, false)?
.collect()
Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
&orderby_exprs,
Arc::new(window_frame),
schema.as_ref(),
false,
)?;
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
vec![window_expr],
Expand Down Expand Up @@ -642,6 +643,7 @@ async fn run_window_test(
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
false,
)
.unwrap()],
exec1,
Expand All @@ -664,6 +666,7 @@ async fn run_window_test(
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
false,
)
.unwrap()],
exec2,
Expand Down
18 changes: 18 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use arrow::datatypes::DataType;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, DFSchema, OwnedTableReference};
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
use sqlparser::ast::NullTreatment;
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Display, Formatter, Write};
Expand Down Expand Up @@ -646,6 +647,7 @@ pub struct WindowFunction {
pub order_by: Vec<Expr>,
/// Window frame
pub window_frame: window_frame::WindowFrame,
pub null_treatment: Option<NullTreatment>,
}

impl WindowFunction {
Expand All @@ -656,13 +658,15 @@ impl WindowFunction {
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
fun,
args,
partition_by,
order_by,
window_frame,
null_treatment,
}
}
}
Expand Down Expand Up @@ -1440,8 +1444,14 @@ impl fmt::Display for Expr {
partition_by,
order_by,
window_frame,
null_treatment,
}) => {
fmt_function(f, &fun.to_string(), false, args, true)?;

if let Some(nt) = null_treatment {
write!(f, "{}", nt)?;
}

if !partition_by.is_empty() {
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
}
Expand Down Expand Up @@ -1768,15 +1778,23 @@ fn create_name(e: &Expr) -> Result<String> {
window_frame,
partition_by,
order_by,
null_treatment,
}) => {
let mut parts: Vec<String> =
vec![create_function_name(&fun.to_string(), false, args)?];

if let Some(nt) = null_treatment {
parts.push(format!("{}", nt));
}

if !partition_by.is_empty() {
parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by)));
}

if !order_by.is_empty() {
parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by)));
}

parts.push(format!("{window_frame}"));
Ok(parts.join(" "))
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,14 @@ impl TreeNode for Expr {
partition_by,
order_by,
window_frame,
null_treatment,
}) => Expr::WindowFunction(WindowFunction::new(
fun,
transform_vec(args, &mut transform)?,
transform_vec(partition_by, &mut transform)?,
transform_vec(order_by, &mut transform)?,
window_frame,
null_treatment,
)),
Expr::AggregateFunction(AggregateFunction {
args,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ impl WindowUDF {
partition_by,
order_by,
window_frame,
null_treatment: None,
})
}

Expand Down
10 changes: 10 additions & 0 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1255,27 +1255,31 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
Expand All @@ -1298,27 +1302,31 @@ mod tests {
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(Some(false)),
None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(Some(false)),
None,
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
WindowFrame::new(Some(false)),
None,
));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
Expand Down Expand Up @@ -1353,6 +1361,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
WindowFrame::new(Some(false)),
None,
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
Expand All @@ -1364,6 +1373,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
WindowFrame::new(Some(false)),
None,
)),
];
let expected = vec![
Expand Down
3 changes: 3 additions & 0 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Expr::WindowFunction(expr::WindowFunction {
Expand All @@ -138,6 +139,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
})
}

Expand Down Expand Up @@ -351,6 +353,7 @@ mod tests {
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
None,
))])?
.project(vec![count(wildcard())])?
.build()?;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
}) => {
let window_frame =
coerce_window_frame(window_frame, &self.schema, &order_by)?;
Expand All @@ -414,6 +415,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
));
Ok(expr)
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/push_down_projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ mod tests {
vec![col("test.b")],
vec![],
WindowFrame::new(None),
None,
));

let max2 = Expr::WindowFunction(expr::WindowFunction::new(
Expand All @@ -595,6 +596,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
None,
));
let col1 = col(max1.display_name()?);
let col2 = col(max2.display_name()?);
Expand Down
Loading

0 comments on commit a851ecf

Please sign in to comment.