Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple columns in order by clause in for ARRAYAGG #1228

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,26 @@ class TranslateWithinGroup extends ir.Rule[ir.LogicalPlan] {

override def apply(plan: ir.LogicalPlan): ir.LogicalPlan = {
plan transformAllExpressions {
case ir.WithinGroup(ir.CallFunction("ARRAY_AGG", args), sorts) => sortArray(args.head, sorts.head)
case ir.WithinGroup(ir.CallFunction("ARRAY_AGG", args), sorts) => sortArray(args.head, sorts)
case ir.WithinGroup(ir.CallFunction("LISTAGG", args), sorts) =>
ir.ArrayJoin(sortArray(args.head, sorts.head), args(1))
ir.ArrayJoin(sortArray(args.head, sorts), args(1))
}
}

private def sortArray(arg: ir.Expression, sort: ir.SortOrder): ir.Expression = {
if (sameReference(arg, sort.expr)) {
val sortOrder = if (sort.direction == ir.Descending) { Some(ir.Literal(false)) }
private def sortArray(arg: ir.Expression, sort: Seq[ir.SortOrder]): ir.Expression = {
if (sort.size == 1 && sameReference(arg, sort.head.expr)) {
val sortOrder = if (sort.head.direction == ir.Descending) { Some(ir.Literal(false)) }
else { None }
ir.SortArray(ir.CollectList(arg), sortOrder)
} else {

val namedStructFunc = ir.CreateNamedStruct(Seq(ir.Literal("value"), arg) ++ sort.zipWithIndex.flatMap {
case (s, index) =>
Seq(ir.Literal(s"sort_by_$index"), s.expr)
})

ir.ArrayTransform(
ir.ArraySort(
ir.CollectList(ir.CreateNamedStruct(Seq(ir.Literal("value"), arg, ir.Literal("sort_by"), sort.expr))),
sortingLambda(sort.direction)),
ir.ArraySort(ir.CollectList(namedStructFunc), sortingLambda(sort)),
ir.LambdaFunction(ir.Dot(ir.Id("s"), ir.Id("value")), Seq(ir.UnresolvedNamedLambdaVariable(Seq("s")))))
}
}
Expand All @@ -36,17 +39,23 @@ class TranslateWithinGroup extends ir.Rule[ir.LogicalPlan] {
case _ => false
}

private def sortingLambda(dir: ir.SortDirection): ir.Expression = {
private def sortingLambda(sort: Seq[ir.SortOrder]): ir.Expression = {
ir.LambdaFunction(
ir.Case(
None,
Seq(
ir.WhenBranch(
ir.LessThan(ir.Dot(ir.Id("left"), ir.Id("sort_by")), ir.Dot(ir.Id("right"), ir.Id("sort_by"))),
if (dir == ir.Ascending) ir.Literal(-1) else ir.Literal(1)),
ir.WhenBranch(
ir.GreaterThan(ir.Dot(ir.Id("left"), ir.Id("sort_by")), ir.Dot(ir.Id("right"), ir.Id("sort_by"))),
if (dir == ir.Ascending) ir.Literal(1) else ir.Literal(-1))),
sort.zipWithIndex.flatMap { case (s, index) =>
Seq(
ir.WhenBranch(
ir.LessThan(
ir.Dot(ir.Id("left"), ir.Id(s"sort_by_$index")),
ir.Dot(ir.Id("right"), ir.Id(s"sort_by_$index"))),
if (s.direction == ir.Ascending) ir.Literal(-1) else ir.Literal(1)),
ir.WhenBranch(
ir.GreaterThan(
ir.Dot(ir.Id("left"), ir.Id(s"sort_by_$index")),
ir.Dot(ir.Id("right"), ir.Id(s"sort_by_$index"))),
if (s.direction == ir.Ascending) ir.Literal(1) else ir.Literal(-1)))
},
Some(ir.Literal(0))),
Seq(UnresolvedNamedLambdaVariable(Seq("left")), UnresolvedNamedLambdaVariable(Seq("right"))))
}
Expand Down
52 changes: 30 additions & 22 deletions src/databricks/labs/remorph/snow/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import re

from sqlglot import expressions as exp
from sqlglot.dialects import hive
from sqlglot.dialects import databricks as org_databricks
from sqlglot.dialects import hive
from sqlglot.dialects.dialect import if_sql
from sqlglot.dialects.dialect import rename_func
from sqlglot.errors import ParseError, UnsupportedError
from sqlglot.errors import UnsupportedError
from sqlglot.helper import apply_index_offset, csv
from sqlglot.dialects.dialect import if_sql

from databricks.labs.remorph.snow import lca_utils, local_expression

Expand Down Expand Up @@ -350,30 +350,30 @@ def _get_within_group_params(
) -> local_expression.WithinGroupParams:
has_distinct = isinstance(expr.this, exp.Distinct)
agg_col = expr.this.expressions[0] if has_distinct else expr.this
order_expr = within_group.expression
order_col = order_expr.expressions[0].this
desc = order_expr.expressions[0].args.get("desc")
is_order_asc = not desc or exp.false() == desc
# In Snow, if both DISTINCT and WITHIN GROUP are specified, both must refer to the same column.
# Ref: https://docs.snowflake.com/en/sql-reference/functions/array_agg#usage-notes
# TODO: Check the same restriction applies for other source dialects to be added in the future
if has_distinct and agg_col != order_col:
raise ParseError("If both DISTINCT and WITHIN GROUP are specified, both must refer to the same column.")
order_clause = within_group.expression
order_cols = []
for e in order_clause.expressions:
desc = e.args.get("desc")
is_order_a = not desc or exp.false() == desc
order_cols.append((e.this, is_order_a))
return local_expression.WithinGroupParams(
agg_col=agg_col,
order_col=order_col,
is_order_asc=is_order_asc,
order_cols=order_cols,
)


def _create_named_struct_for_cmp(agg_col, order_col) -> exp.Expression:
def _create_named_struct_for_cmp(wg_params: local_expression.WithinGroupParams) -> exp.Expression:
agg_col = wg_params.agg_col
order_kv = []
for i, (col, _) in enumerate(wg_params.order_cols):
order_kv.extend([exp.Literal(this=f"sort_by_{i}", is_string=True), col])

named_struct_func = exp.Anonymous(
this="named_struct",
expressions=[
exp.Literal(this="value", is_string=True),
agg_col,
exp.Literal(this="sort_by", is_string=True),
order_col,
*order_kv,
],
)
return named_struct_func
Expand Down Expand Up @@ -515,16 +515,24 @@ def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return sql

wg_params = _get_within_group_params(expression, within_group)
if wg_params.agg_col == wg_params.order_col:
return f"SORT_ARRAY({sql}{'' if wg_params.is_order_asc else ', FALSE'})"
if len(wg_params.order_cols) == 1:
order_col, is_order_asc = wg_params.order_cols[0]
if wg_params.agg_col == order_col:
return f"SORT_ARRAY({sql}{'' if is_order_asc else ', FALSE'})"

named_struct_func = _create_named_struct_for_cmp(wg_params)
comparisons = []
for i, (_, is_order_asc) in enumerate(wg_params.order_cols):
comparisons.append(
f"WHEN left.sort_by_{i} < right.sort_by_{i} THEN {'-1' if is_order_asc else '1'} "
f"WHEN left.sort_by_{i} > right.sort_by_{i} THEN {'1' if is_order_asc else '-1'}"
)

named_struct_func = _create_named_struct_for_cmp(wg_params.agg_col, wg_params.order_col)
array_sort = self.func(
"ARRAY_SORT",
self.func("ARRAY_AGG", named_struct_func),
f"""(left, right) -> CASE
WHEN left.sort_by < right.sort_by THEN {'-1' if wg_params.is_order_asc else '1'}
WHEN left.sort_by > right.sort_by THEN {'1' if wg_params.is_order_asc else '-1'}
{' '.join(comparisons)}
ELSE 0
END""",
)
Expand Down
3 changes: 1 addition & 2 deletions src/databricks/labs/remorph/snow/local_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ class ToArray(Func):
@dataclass
class WithinGroupParams:
agg_col: exp.Column
order_col: exp.Column
is_order_asc: bool
order_cols: list[tuple[exp.Column, bool]] # List of (column, is ascending)


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ SELECT
ARRAY_JOIN(
TRANSFORM(
ARRAY_SORT(
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by', col2)),
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col2)),
(left, right) -> CASE
WHEN left.sort_by < right.sort_by THEN 1
WHEN left.sort_by > right.sort_by THEN -1
WHEN left.sort_by_0 < right.sort_by_0 THEN 1
WHEN left.sort_by_0 > right.sort_by_0 THEN -1
ELSE 0
END
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ SELECT
col2,
TRANSFORM(
ARRAY_SORT(
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by', col3)),
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3)),
(left, right) -> CASE
WHEN left.sort_by < right.sort_by THEN 1
WHEN left.sort_by > right.sort_by THEN -1
WHEN left.sort_by_0 < right.sort_by_0 THEN 1
WHEN left.sort_by_0 > right.sort_by_0 THEN -1
ELSE 0
END
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ ORDER BY col2 DESC;
col2,
TRANSFORM(
ARRAY_SORT(
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by', col3)),
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3)),
(left, right) -> CASE
WHEN left.sort_by < right.sort_by THEN -1
WHEN left.sort_by > right.sort_by THEN 1
WHEN left.sort_by_0 < right.sort_by_0 THEN -1
WHEN left.sort_by_0 > right.sort_by_0 THEN 1
ELSE 0
END
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ WITH cte AS (
id,
TRANSFORM(
ARRAY_SORT(
ARRAY_AGG(NAMED_STRUCT('value', tag, 'sort_by', item_count)),
ARRAY_AGG(NAMED_STRUCT('value', tag, 'sort_by_0', item_count)),
(left, right) -> CASE
WHEN left.sort_by < right.sort_by THEN 1
WHEN left.sort_by > right.sort_by THEN -1
WHEN left.sort_by_0 < right.sort_by_0 THEN 1
WHEN left.sort_by_0 > right.sort_by_0 THEN -1
ELSE 0
END
),
Expand Down
32 changes: 32 additions & 0 deletions tests/resources/functional/snowflake/arrays/test_arrayagg_8.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- snowflake sql:
SELECT
col2,
ARRAYAGG(col4) WITHIN GROUP (ORDER BY col3, col5)
FROM test_table
WHERE col3 > 450000
GROUP BY col2
ORDER BY col2 DESC;

-- databricks sql:
SELECT
col2,
TRANSFORM(
ARRAY_SORT(
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3, 'sort_by_1', col5)),
(left, right) -> CASE
WHEN left.sort_by_0 < right.sort_by_0 THEN -1
WHEN left.sort_by_0 > right.sort_by_0 THEN 1
WHEN left.sort_by_1 < right.sort_by_1 THEN -1
WHEN left.sort_by_1 > right.sort_by_1 THEN 1
ELSE 0
END
),
s -> s.value
)
FROM test_table
WHERE
col3 > 450000
GROUP BY
col2
ORDER BY
col2 DESC NULLS FIRST;
32 changes: 32 additions & 0 deletions tests/resources/functional/snowflake/arrays/test_arrayagg_9.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- snowflake sql:
SELECT
col2,
ARRAYAGG(col4) WITHIN GROUP (ORDER BY col3, col5 DESC)
FROM test_table
WHERE col3 > 450000
GROUP BY col2
ORDER BY col2 DESC;

-- databricks sql:
SELECT
col2,
TRANSFORM(
ARRAY_SORT(
ARRAY_AGG(NAMED_STRUCT('value', col4, 'sort_by_0', col3, 'sort_by_1', col5)),
(left, right) -> CASE
WHEN left.sort_by_0 < right.sort_by_0 THEN -1
WHEN left.sort_by_0 > right.sort_by_0 THEN 1
WHEN left.sort_by_1 < right.sort_by_1 THEN 1
WHEN left.sort_by_1 > right.sort_by_1 THEN -1
ELSE 0
END
),
s -> s.value
)
FROM test_table
WHERE
col3 > 450000
GROUP BY
col2
ORDER BY
col2 DESC NULLS FIRST;

This file was deleted.

Loading