Skip to content

Commit

Permalink
Properly project grouping set expressions (#6777)
Browse files Browse the repository at this point in the history
* update version to 26.0.0

* update Cargo.lock

* changelog

* prettier

* update changelog

* VTX-1613: update ignore rule

* VTX-1613: revert

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: handle grouping sets

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: debug

* VTX-1613: cleanup & fix for grouping set

* VTX-1613: cleanup

* VTX-1613: cleanup

* VTX-1613: cleanup

* VTX-1613: cleanup

* VTX-1613: cleanup

* VTX-1613: cleanup

* VTX-1613: cleanup

* VTX-1613: fix import

---------

Co-authored-by: Andy Grove <[email protected]>
  • Loading branch information
fsdvh and andygrove authored Jun 28, 2023
1 parent b9ecfc5 commit a76b09e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 9 deletions.
8 changes: 4 additions & 4 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

72 changes: 67 additions & 5 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,7 @@ impl CommonSubexprEliminate {

let mut proj_exprs = vec![];
for expr in &new_group_expr {
let out_col: Column =
expr.to_field(&new_input_schema)?.qualified_column();
proj_exprs.push(Expr::Column(out_col));
extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
}
for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) {
if expr_rewritten == expr_orig {
Expand Down Expand Up @@ -488,6 +486,22 @@ fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalP
)
}

fn extract_expressions(
expr: &Expr,
schema: &DFSchema,
result: &mut Vec<Expr>,
) -> Result<()> {
if let Expr::GroupingSet(groupings) = expr {
for e in groupings.distinct_expr() {
result.push(Expr::Column(e.to_field(schema)?.qualified_column()))
}
} else {
result.push(Expr::Column(expr.to_field(schema)?.qualified_column()));
}

Ok(())
}

/// Which type of [expressions](Expr) should be considered for rewriting?
#[derive(Debug, Clone, Copy)]
enum ExprMask {
Expand Down Expand Up @@ -773,8 +787,8 @@ mod test {
avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
};
use datafusion_expr::{
AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, Volatility,
grouping_set, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction,
Signature, StateTypeFunction, Volatility,
};

use crate::optimizer::OptimizerContext;
Expand Down Expand Up @@ -1251,4 +1265,52 @@ mod test {

Ok(())
}

#[test]
fn test_extract_expressions_from_grouping_set() -> Result<()> {
let mut result = Vec::with_capacity(3);
let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
let schema = DFSchema::new_with_metadata(
vec![
DFField::new_unqualified("a", DataType::Int32, false),
DFField::new_unqualified("b", DataType::Int32, false),
DFField::new_unqualified("c", DataType::Int32, false),
],
HashMap::default(),
)?;
extract_expressions(&grouping, &schema, &mut result)?;

assert!(result.len() == 3);
Ok(())
}

#[test]
fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
let mut result = Vec::with_capacity(2);
let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
let schema = DFSchema::new_with_metadata(
vec![
DFField::new_unqualified("a", DataType::Int32, false),
DFField::new_unqualified("b", DataType::Int32, false),
],
HashMap::default(),
)?;
extract_expressions(&grouping, &schema, &mut result)?;

assert!(result.len() == 2);
Ok(())
}

#[test]
fn test_extract_expressions_from_col() -> Result<()> {
let mut result = Vec::with_capacity(1);
let schema = DFSchema::new_with_metadata(
vec![DFField::new_unqualified("a", DataType::Int32, false)],
HashMap::default(),
)?;
extract_expressions(&col("a"), &schema, &mut result)?;

assert!(result.len() == 1);
Ok(())
}
}

0 comments on commit a76b09e

Please sign in to comment.