diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 38d87e934e5f..85af9023fb46 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -174,7 +174,7 @@ message WindowExprNode { // udaf = 3 } LogicalExprNode expr = 4; - // repeated LogicalExprNode partition_by = 5; + repeated LogicalExprNode partition_by = 5; repeated LogicalExprNode order_by = 6; // repeated LogicalExprNode filter = 7; oneof window_frame { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 36a37a1e472c..86daeb063c47 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -910,6 +910,12 @@ impl TryInto for &protobuf::LogicalExprNode { .window_function .as_ref() .ok_or_else(|| proto_error("Received empty window function"))?; + let partition_by = expr + .partition_by + .iter() + .map(|e| e.try_into()) + .into_iter() + .collect::, _>>()?; let order_by = expr .order_by .iter() @@ -940,6 +946,7 @@ impl TryInto for &protobuf::LogicalExprNode { AggregateFunction::from(aggr_function), ), args: vec![parse_required_expr(&expr.expr)?], + partition_by, order_by, window_frame, }) @@ -960,6 +967,7 @@ impl TryInto for &protobuf::LogicalExprNode { BuiltInWindowFunction::from(built_in_function), ), args: vec![parse_required_expr(&expr.expr)?], + partition_by, order_by, window_frame, }) diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index fb1383daab3a..5d996843d624 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1006,6 +1006,7 @@ impl TryInto for &Expr { Expr::WindowFunction { ref fun, ref args, + ref partition_by, ref order_by, ref window_frame, .. @@ -1023,6 +1024,10 @@ impl TryInto for &Expr { } }; let arg = &args[0]; + let partition_by = partition_by + .iter() + .map(|e| e.try_into()) + .collect::, _>>()?; let order_by = order_by .iter() .map(|e| e.try_into()) @@ -1035,6 +1040,7 @@ impl TryInto for &Expr { let window_expr = Box::new(protobuf::WindowExprNode { expr: Some(Box::new(arg.try_into()?)), window_function: Some(window_function), + partition_by, order_by, window_frame, }); diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 5fcc971527c6..b319d5b25f12 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -236,7 +236,9 @@ impl TryInto> for &protobuf::PhysicalPlanNode { Expr::WindowFunction { fun, args, + partition_by, order_by, + window_frame, .. } => { let arg = df_planner @@ -248,9 +250,15 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .map_err(|e| { BallistaError::General(format!("{:?}", e)) })?; + if !partition_by.is_empty() { + return Err(BallistaError::NotImplemented("Window function with partition by is not yet implemented".to_owned())); + } if !order_by.is_empty() { return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned())); } + if window_frame.is_some() { + return Err(BallistaError::NotImplemented("Window function with window frame is not yet implemented".to_owned())); + } let window_expr = create_window_expr( &fun, &[arg], diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index d5c92dbd2143..58dba16f02ef 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -194,6 +194,8 @@ pub enum Expr { fun: window_functions::WindowFunction, /// List of expressions to feed to the functions as arguments args: Vec, + /// List of partition by expressions + partition_by: Vec, /// List of order by expressions order_by: Vec, /// Window frame @@ -588,10 +590,18 @@ impl Expr { Expr::ScalarUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::WindowFunction { args, order_by, .. } => { + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { let visitor = args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = partition_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; let visitor = order_by .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; @@ -733,11 +743,13 @@ impl Expr { Expr::WindowFunction { args, fun, + partition_by, order_by, window_frame, } => Expr::WindowFunction { args: rewrite_vec(args, rewriter)?, fun, + partition_by: rewrite_vec(partition_by, rewriter)?, order_by: rewrite_vec(order_by, rewriter)?, window_frame, }, diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 25cf9e33d2ca..3344dce1d81d 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -687,11 +687,7 @@ impl LogicalPlan { LogicalPlan::Window { ref window_expr, .. } => { - write!( - f, - "WindowAggr: windowExpr=[{:?}] partitionBy=[]", - window_expr - ) + write!(f, "WindowAggr: windowExpr=[{:?}]", window_expr) } LogicalPlan::Aggregate { ref group_expr, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 65c95bee20d4..e707d30bc9ac 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -36,6 +36,7 @@ use crate::{ const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__"; const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__"; +const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__"; const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__"; /// Recursively walk a list of expression trees, collecting the unique set of column @@ -258,9 +259,16 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]), Expr::ScalarFunction { args, .. } => Ok(args.clone()), Expr::ScalarUDF { args, .. } => Ok(args.clone()), - Expr::WindowFunction { args, order_by, .. } => { + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { let mut expr_list: Vec = vec![]; expr_list.extend(args.clone()); + expr_list.push(lit(WINDOW_PARTITION_MARKER)); + expr_list.extend(partition_by.clone()); expr_list.push(lit(WINDOW_SORT_MARKER)); expr_list.extend(order_by.clone()); Ok(expr_list) @@ -340,7 +348,20 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { Expr::WindowFunction { fun, window_frame, .. } => { - let index = expressions + let partition_index = expressions + .iter() + .position(|expr| { + matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str))) + if str == WINDOW_PARTITION_MARKER) + }) + .ok_or_else(|| { + DataFusionError::Internal( + "Ill-formed window function expressions: unexpected marker" + .to_owned(), + ) + })?; + + let sort_index = expressions .iter() .position(|expr| { matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str))) @@ -351,12 +372,21 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { "Ill-formed window function expressions".to_owned(), ) })?; - Ok(Expr::WindowFunction { - fun: fun.clone(), - args: expressions[..index].to_vec(), - order_by: expressions[index + 1..].to_vec(), - window_frame: *window_frame, - }) + + if partition_index >= sort_index { + Err(DataFusionError::Internal( + "Ill-formed window function expressions: partition index too large" + .to_owned(), + )) + } else { + Ok(Expr::WindowFunction { + fun: fun.clone(), + args: expressions[..partition_index].to_vec(), + partition_by: expressions[partition_index + 1..sort_index].to_vec(), + order_by: expressions[sort_index + 1..].to_vec(), + window_frame: *window_frame, + }) + } } Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { fun: fun.clone(), diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 7df0068c5f54..53f22ecaf3f2 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1122,52 +1122,53 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // then, window function if let Some(window) = &function.over { - if window.partition_by.is_empty() { - let order_by = window - .order_by - .iter() - .map(|e| self.order_by_to_sort_expr(e)) - .into_iter() - .collect::>>()?; - let window_frame = window - .window_frame - .as_ref() - .map(|window_frame| window_frame.clone().try_into()) - .transpose()?; - let fun = window_functions::WindowFunction::from_str(&name); - if let Ok(window_functions::WindowFunction::AggregateFunction( + let partition_by = window + .partition_by + .iter() + .map(|e| self.sql_expr_to_logical_expr(e)) + .into_iter() + .collect::>>()?; + let order_by = window + .order_by + .iter() + .map(|e| self.order_by_to_sort_expr(e)) + .into_iter() + .collect::>>()?; + let window_frame = window + .window_frame + .as_ref() + .map(|window_frame| window_frame.clone().try_into()) + .transpose()?; + let fun = window_functions::WindowFunction::from_str(&name)?; + match fun { + window_functions::WindowFunction::AggregateFunction( aggregate_fun, - )) = fun - { + ) => { return Ok(Expr::WindowFunction { fun: window_functions::WindowFunction::AggregateFunction( aggregate_fun.clone(), ), args: self .aggregate_fn_to_expr(&aggregate_fun, function)?, + partition_by, order_by, window_frame, }); - } else if let Ok( - window_functions::WindowFunction::BuiltInWindowFunction( - window_fun, - ), - ) = fun - { + } + window_functions::WindowFunction::BuiltInWindowFunction( + window_fun, + ) => { return Ok(Expr::WindowFunction { fun: window_functions::WindowFunction::BuiltInWindowFunction( window_fun, ), args: self.function_args_to_expr(function)?, + partition_by, order_by, window_frame, }); } } - return Err(DataFusionError::NotImplemented(format!( - "Unsupported OVER clause ({})", - window - ))); } // next, aggregate built-ins @@ -2775,7 +2776,7 @@ mod tests { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(order_id)\ - \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2785,7 +2786,7 @@ mod tests { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders"; let expected = "\ Projection: #order_id AS oid, #MAX(order_id) AS max_oid\ - \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2795,7 +2796,7 @@ mod tests { let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(qty Multiply Float64(1.1))\ - \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2806,20 +2807,29 @@ mod tests { "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------- + /// WindowAgg (cost=69.83..87.33 rows=1000 width=8) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` #[test] - fn over_partition_by_not_supported() { - let sql = - "SELECT order_id, MAX(delivered) OVER (PARTITION BY order_id) from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "NotImplemented(\"Unsupported OVER clause (PARTITION BY order_id)\")", - format!("{:?}", err) - ); + fn over_partition_by() { + let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); } /// psql result @@ -2839,9 +2849,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2852,9 +2862,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2865,9 +2875,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2878,9 +2888,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ Projection: #order_id, #MAX(qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id DESC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2903,9 +2913,9 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id Plus Int64(1) ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2929,10 +2939,10 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[SUM(#qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2956,10 +2966,10 @@ mod tests { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[SUM(#qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -2987,15 +2997,108 @@ mod tests { let expected = "\ Sort: #order_id ASC NULLS FIRST\ \n Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ - \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ - \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[SUM(#qty)]]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ - \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ \n TableScan: orders projection=None"; quick_test(sql, expected); } + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------- + /// WindowAgg (cost=69.83..89.83 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_partition_by_order_by() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------- + /// WindowAgg (cost=69.83..89.83 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_partition_by_order_by_no_dup() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------------------- + /// WindowAgg (cost=142.16..162.16 rows=1000 width=16) + /// -> Sort (cost=142.16..144.66 rows=1000 width=12) + /// Sort Key: qty, order_id + /// -> WindowAgg (cost=69.83..92.33 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_partition_by_order_by_mix_up() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty), MIN(qty) OVER (PARTITION BY qty ORDER BY order_id) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ + \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ----------------------------------------------------------------------------- + /// WindowAgg (cost=69.83..109.83 rows=1000 width=24) + /// -> WindowAgg (cost=69.83..92.33 rows=1000 width=20) + /// -> Sort (cost=69.83..72.33 rows=1000 width=16) + /// Sort Key: order_id, qty, price + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=16) + /// ``` + /// FIXME: for now we are not detecting prefix of sorting keys in order to save one sort exec phase + #[test] + fn over_partition_by_order_by_mix_up_prefix() { + let sql = + "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty), MIN(qty) OVER (PARTITION BY order_id, qty ORDER BY price) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST, #price ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + #[test] fn only_union_all_supported() { let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders"; diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 848fb3ee31fc..5e9b9526ea83 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -239,6 +239,7 @@ where Expr::WindowFunction { fun, args, + partition_by, order_by, window_frame, } => Ok(Expr::WindowFunction { @@ -247,6 +248,10 @@ where .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, + partition_by: partition_by + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, order_by: order_by .iter() .map(|e| clone_with_replacement(e, replacement_fn)) @@ -432,19 +437,38 @@ pub(crate) fn resolve_aliases_to_exprs( }) } +type WindowSortKey = Vec; + +fn generate_sort_key(partition_by: &[Expr], order_by: &[Expr]) -> WindowSortKey { + let mut sort_key = vec![]; + partition_by.iter().for_each(|e| { + let e = e.clone().sort(true, true); + if !sort_key.contains(&e) { + sort_key.push(e); + } + }); + order_by.iter().for_each(|e| { + if !sort_key.contains(&e) { + sort_key.push(e.clone()); + } + }); + sort_key +} + /// group a slice of window expression expr by their order by expressions pub(crate) fn group_window_expr_by_sort_keys( window_expr: &[Expr], -) -> Result)>> { +) -> Result)>> { let mut result = vec![]; window_expr.iter().try_for_each(|expr| match expr { - Expr::WindowFunction { order_by, .. } => { + Expr::WindowFunction { partition_by, order_by, .. } => { + let sort_key = generate_sort_key(partition_by, order_by); if let Some((_, values)) = result.iter_mut().find( - |group: &&mut (&[Expr], Vec<&Expr>)| matches!(group, (key, _) if key == order_by), + |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key), ) { values.push(expr); } else { - result.push((order_by, vec![expr])) + result.push((sort_key, vec![expr])) } Ok(()) } @@ -466,7 +490,7 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { let result = group_window_expr_by_sort_keys(&[])?; - let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![]; + let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![]; assert_eq!(expected, result); Ok(()) } @@ -476,32 +500,35 @@ mod tests { let max1 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], + partition_by: vec![], order_by: vec![], window_frame: None, }; - // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs)?; - let key = &[]; - let expected: Vec<(&[Expr], Vec<&Expr>)> = + let key = vec![]; + let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![(key, vec![&max1, &max2, &min3, &sum4])]; assert_eq!(expected, result); Ok(()) @@ -527,24 +554,28 @@ mod tests { let max1 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![age_asc.clone(), name_desc.clone()], window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![], window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], + partition_by: vec![], order_by: vec![age_asc.clone(), name_desc.clone()], window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], + partition_by: vec![], order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], window_frame: None, }; @@ -552,11 +583,11 @@ mod tests { let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs)?; - let key1 = &[age_asc.clone(), name_desc.clone()]; - let key2 = &[]; - let key3 = &[name_desc, age_asc, created_at_desc]; + let key1 = vec![age_asc.clone(), name_desc.clone()]; + let key2 = vec![]; + let key3 = vec![name_desc, age_asc, created_at_desc]; - let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![ + let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ (key1, vec![&max1, &min3]), (key2, vec![&max2]), (key3, vec![&sum4]), @@ -571,6 +602,7 @@ mod tests { Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], + partition_by: vec![], order_by: vec![ Expr::Sort { expr: Box::new(col("age")), @@ -588,6 +620,7 @@ mod tests { Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], + partition_by: vec![], order_by: vec![ Expr::Sort { expr: Box::new(col("name")),