Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix group by aliased expression in LogicalPlanBuilder::aggregate #8629

Merged
merged 1 commit into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1825,8 +1825,8 @@ mod tests {
let df_results = collect(physical_plan, ctx.task_ctx()).await?;

#[rustfmt::skip]
assert_batches_sorted_eq!(
[ "+----+",
assert_batches_sorted_eq!([
"+----+",
"| id |",
"+----+",
"| 1 |",
Expand All @@ -1837,6 +1837,38 @@ mod tests {
Ok(())
}

/// Regression test for grouping by an aliased expression: after projecting
/// `c2 + 1` under the alias `c2`, a second `aggregate` that groups on
/// `col("c2").alias("c2")` must succeed and produce the expected rows.
#[tokio::test]
async fn test_aggregate_alias() -> Result<()> {
    let aliased = test_table()
        .await?
        // First pass: GROUP BY `c2 + 1`
        .aggregate(vec![col("c2") + lit(1)], vec![])?
        // Project the grouped expression: SELECT `c2 + 1` AS c2
        .select(vec![(col("c2") + lit(1)).alias("c2")])?
        // Second pass: GROUP BY c2 AS "c2" (an alias inside a GROUP BY
        // expression is not supported by SQL)
        .aggregate(vec![col("c2").alias("c2")], vec![])?;

    let batches = aliased.collect().await?;

    #[rustfmt::skip]
    let expected = [
        "+----+",
        "| c2 |",
        "+----+",
        "| 2 |",
        "| 3 |",
        "| 4 |",
        "| 5 |",
        "| 6 |",
        "+----+",
    ];
    assert_batches_sorted_eq!(expected, &batches);

    Ok(())
}

#[tokio::test]
async fn test_distinct() -> Result<()> {
let t = test_table().await?;
Expand Down
58 changes: 39 additions & 19 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,27 +904,11 @@ impl LogicalPlanBuilder {
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
let mut group_expr = normalize_cols(group_expr, &self.plan)?;
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;

// Rewrite groupby exprs according to functional dependencies
let group_by_expr_names = group_expr
.iter()
.map(|group_by_expr| group_by_expr.display_name())
.collect::<Result<Vec<_>>>()?;
let schema = self.plan.schema();
if let Some(target_indices) =
get_target_functional_dependencies(schema, &group_by_expr_names)
{
for idx in target_indices {
let field = schema.field(idx);
let expr =
Expr::Column(Column::new(field.qualifier().cloned(), field.name()));
if !group_expr.contains(&expr) {
group_expr.push(expr);
}
}
}
let group_expr =
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::from)
Expand Down Expand Up @@ -1189,6 +1173,42 @@ pub fn build_join_schema(
schema.with_functional_dependencies(func_dependencies)
}

/// Add additional "synthetic" group by expressions based on functional
/// dependencies.
///
/// For example, if we are grouping on `[c1]`, and we know from
/// functional dependencies that column `c1` determines `c2`, this function
/// adds `c2` to the group by list.
///
/// This allows MySQL style selects like
/// `SELECT col FROM t GROUP BY pk` when `pk` functionally determines `col`
fn add_group_by_exprs_from_dependencies(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pulled the logic into its own function so I could document it better

mut group_expr: Vec<Expr>,
schema: &DFSchemaRef,
) -> Result<Vec<Expr>> {
// Names of the fields produced by the GROUP BY exprs. For example,
// `GROUP BY c1 + 1` produces an output field named `"c1 + 1"`.
let mut group_by_field_names = group_expr
.iter()
.map(|e| e.display_name())
.collect::<Result<Vec<_>>>()?;

if let Some(target_indices) =
get_target_functional_dependencies(schema, &group_by_field_names)
{
for idx in target_indices {
let field = schema.field(idx);
let expr =
Expr::Column(Column::new(field.qualifier().cloned(), field.name()));
let expr_name = expr.display_name()?;
if !group_by_field_names.contains(&expr_name) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the change: rather than comparing the `Expr`s themselves, it compares their `display_name` (the names of the fields they will create).

group_by_field_names.push(expr_name);
group_expr.push(expr);
}
}
}
Ok(group_expr)
}
/// Errors if one or more expressions have equal names.
pub(crate) fn validate_unique_names<'a>(
node_name: &str,
Expand Down