From f6ae6c6750e982cc1d8fa361aa8b6da4d58aae85 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Dec 2023 07:04:43 -0500 Subject: [PATCH] Fix group by aliased expression in LogicalPLanBuilder::aggregate (#8629) --- datafusion/core/src/dataframe/mod.rs | 36 ++++++++++++- datafusion/expr/src/logical_plan/builder.rs | 58 ++++++++++++++------- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2ae4a7c21a9c..3c3bcd497b7f 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1769,8 +1769,8 @@ mod tests { let df_results = df.collect().await?; #[rustfmt::skip] - assert_batches_sorted_eq!( - [ "+----+", + assert_batches_sorted_eq!([ + "+----+", "| id |", "+----+", "| 1 |", @@ -1781,6 +1781,38 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_alias() -> Result<()> { + let df = test_table().await?; + + let df = df + // GROUP BY `c2 + 1` + .aggregate(vec![col("c2") + lit(1)], vec![])? + // SELECT `c2 + 1` as c2 + .select(vec![(col("c2") + lit(1)).alias("c2")])? + // GROUP BY c2 as "c2" (alias in expr is not supported by SQL) + .aggregate(vec![col("c2").alias("c2")], vec![])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+", + "| c2 |", + "+----+", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "| 6 |", + "+----+", + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 88310dab82a2..549c25f89bae 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -904,27 +904,11 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - let mut group_expr = normalize_cols(group_expr, &self.plan)?; + let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - // Rewrite groupby exprs according to functional dependencies - let group_by_expr_names = group_expr - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - let schema = self.plan.schema(); - if let Some(target_indices) = - get_target_functional_dependencies(schema, &group_by_expr_names) - { - for idx in target_indices { - let field = schema.field(idx); - let expr = - Expr::Column(Column::new(field.qualifier().cloned(), field.name())); - if !group_expr.contains(&expr) { - group_expr.push(expr); - } - } - } + let group_expr = + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) @@ -1189,6 +1173,42 @@ pub fn build_join_schema( schema.with_functional_dependencies(func_dependencies) } +/// Add additional "synthetic" group by expressions based on functional +/// dependencies. +/// +/// For example, if we are grouping on `[c1]`, and we know from +/// functional dependencies that column `c1` determines `c2`, this function +/// adds `c2` to the group by list. +/// +/// This allows MySQL style selects like +/// `SELECT col FROM t WHERE pk = 5` if col is unique +fn add_group_by_exprs_from_dependencies( + mut group_expr: Vec, + schema: &DFSchemaRef, +) -> Result> { + // Names of the fields produced by the GROUP BY exprs for example, `GROUP BY + // c1 + 1` produces an output field named `"c1 + 1"` + let mut group_by_field_names = group_expr + .iter() + .map(|e| e.display_name()) + .collect::>>()?; + + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_field_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + let expr_name = expr.display_name()?; + if !group_by_field_names.contains(&expr_name) { + group_by_field_names.push(expr_name); + group_expr.push(expr); + } + } + } + Ok(group_expr) +} /// Errors if one or more expressions have equal names. pub(crate) fn validate_unique_names<'a>( node_name: &str,