Skip to content

Commit

Permalink
Scalar functions in ORDER BY unparsing support
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov committed Oct 11, 2024
1 parent 75a953f commit 81fa086
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 39 deletions.
46 changes: 10 additions & 36 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ use datafusion_common::{
internal_err, not_impl_err, Column, DataFusionError, Result, TableReference,
};
use datafusion_expr::{
expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
LogicalPlanBuilder, Projection, SortExpr,
expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr
};
use sqlparser::ast::{self, Ident, SetExpr};
use std::sync::Arc;
Expand All @@ -37,8 +36,7 @@ use super::{
subquery_alias_inner_query_and_columns,
},
utils::{
find_agg_node_within_select, find_window_nodes_within_select,
unproject_window_exprs,
find_agg_node_within_select, find_window_nodes_within_select, unproject_sort_expr, unproject_window_exprs
},
Unparser,
};
Expand Down Expand Up @@ -299,6 +297,7 @@ impl Unparser<'_> {
let filter_expr = self.expr_to_sql(&unprojected)?;
select.having(Some(filter_expr));
} else {

let filter_expr = self.expr_to_sql(&filter.predicate)?;
select.selection(Some(filter_expr));
}
Expand Down Expand Up @@ -345,38 +344,13 @@ impl Unparser<'_> {
);
};

let sort_exprs: &Vec<SortExpr> =
// In case of aggregation there could be columns containing aggregation functions we need to unproject
match find_agg_node_within_select(plan, select.already_projected()) {
Some(agg) => &sort
.expr
.iter()
.map(|sort_expr| {
let mut sort_expr = sort_expr.clone();

// ORDER BY can't have aliases, this indicates that the column was not properly unparsed, update it
if let Expr::Alias(alias) = &sort_expr.expr {
sort_expr.expr = *alias.expr.clone();
}

// Unproject the sort expression if it is a column from the aggregation
if let Expr::Column(c) = &sort_expr.expr {
if c.relation.is_none() && agg.schema.is_column_from_schema(&c) {
sort_expr.expr = unproject_agg_exprs(
&sort_expr.expr,
agg,
None,
)?;
}
}

Ok::<_, DataFusionError>(sort_expr)
})
.collect::<Result<Vec<_>>>()?,
None => &sort.expr,
};
let agg = find_agg_node_within_select(plan, select.already_projected());
// unproject sort expressions
let sort_exprs: Vec<SortExpr> = sort.expr.iter().map(|sort_expr| {
unproject_sort_expr(sort_expr, agg, sort.input.as_ref())
}).collect::<Result<Vec<_>>>()?;

query_ref.order_by(self.sorts_to_sql(sort_exprs)?);
query_ref.order_by(self.sorts_to_sql(&sort_exprs)?);

self.select_to_sql_recursively(
sort.input.as_ref(),
Expand Down Expand Up @@ -680,7 +654,7 @@ impl Unparser<'_> {
}
}

fn sorts_to_sql(&self, sort_exprs: &Vec<SortExpr>) -> Result<Vec<ast::OrderByExpr>> {
fn sorts_to_sql(&self, sort_exprs: &[SortExpr]) -> Result<Vec<ast::OrderByExpr>> {
sort_exprs
.iter()
.map(|sort_expr| self.sort_to_sql(sort_expr))
Expand Down
51 changes: 49 additions & 2 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Column, Result, ScalarValue,
Column, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr,
Window,
};
use sqlparser::ast;

Expand Down Expand Up @@ -190,6 +191,52 @@ fn find_window_expr<'a>(
.find(|expr| expr.schema_name().to_string() == column_name)
}

/// Transforms a Column expression into the actual expression from aggregation or projection if found.
/// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced
/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to
/// the actual expression, such as sum("catalog_returns"."cr_net_loss").
pub(crate) fn unproject_sort_expr(
sort_expr: &SortExpr,
agg: Option<&Aggregate>,
input: &LogicalPlan,
) -> Result<SortExpr> {
let mut sort_expr = sort_expr.clone();

// Remove alias if present, because ORDER BY cannot use aliases
if let Expr::Alias(alias) = &sort_expr.expr {
sort_expr.expr = *alias.expr.clone();
}

let Expr::Column(ref col_ref) = sort_expr.expr else {
return Ok::<_, DataFusionError>(sort_expr);
};

if col_ref.relation.is_some() {
return Ok::<_, DataFusionError>(sort_expr);
};

// In case of aggregation there could be columns containing aggregation functions we need to unproject
if let Some(agg) = agg {
if agg.schema.is_column_from_schema(col_ref) {
let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?;
sort_expr.expr = new_expr;

return Ok::<_, DataFusionError>(sort_expr);
}
}

if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input {
if let Ok(idx) = schema.index_of_column(col_ref) {
if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) {
sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone());
}
}
return Ok::<_, DataFusionError>(sort_expr);
}

Ok::<_, DataFusionError>(sort_expr)
}

/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
pub(crate) fn date_part_to_sql(
unparser: &Unparser,
Expand Down
12 changes: 11 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use arrow_schema::*;
use datafusion_common::{DFSchema, Result, TableReference};
use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf};
use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
use datafusion_functions::unicode;
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
Expand Down Expand Up @@ -643,6 +644,7 @@ fn sql_round_trip(query: &str, expect: &str) {
let context = MockContextProvider {
state: MockSessionState::default()
.with_aggregate_function(sum_udaf())
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone()))
};
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
Expand Down Expand Up @@ -814,14 +816,22 @@ fn test_interval_lhs_lt() {
}

#[test]
fn test_order_by_with_aggr_to_sql() {
fn test_order_by_to_sql() {
// order by aggregation function
sql_round_trip(
r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#,
r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
);

// order by aggregation function alias
sql_round_trip(
r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#,
r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
);

// order by scalar function from projection
sql_round_trip(
r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#,
r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#,
);
}

0 comments on commit 81fa086

Please sign in to comment.