Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(optimizer): expand join constructs into SELECT * from subqueries #1560

Merged
merged 10 commits into from
May 6, 2023
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions sqlglot/optimizer/expand_join_constructs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from sqlglot import exp


def expand_join_constructs(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite parenthesized "join constructs" (*) as equivalent SELECT * subqueries.

    Every `exp.Subquery` node whose `unnest()` result is an `exp.Table` gets its
    `this` argument replaced by `SELECT * FROM <copy of that table>`. All other
    nodes are returned as they are.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS tbl")
        >>> expand_join_constructs(expression).sql()
        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS tbl'

    Args:
        expression: The expression to rewrite.

    Returns:
        The result of running the rewrite over `expression` via `transform`.

    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
    """

    def _wrap_table_in_select(node: exp.Expression) -> exp.Expression:
        # Only subquery wrappers are candidates for the rewrite.
        if not isinstance(node, exp.Subquery):
            return node

        # Look through the wrapper(s); a table at the bottom marks a join construct.
        innermost = node.unnest()
        if not isinstance(innermost, exp.Table):
            return node

        # The table is copied explicitly here and handed to from_ with copy=False.
        star_select = exp.select("*").from_(innermost.copy(), copy=False)
        node.this.replace(star_select)
        return node

    return expression.transform(_wrap_table_in_select)
2 changes: 2 additions & 0 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_join_constructs import expand_join_constructs
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@@ -27,6 +28,7 @@
RULES = (
lower_identities,
qualify_tables,
expand_join_constructs,
isolate_table_selects,
qualify_columns,
pushdown_projections,
2 changes: 1 addition & 1 deletion sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
@@ -510,7 +510,7 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.UDTF):
elif isinstance(scope.expression, (exp.UDTF, exp.Table)):
pass
else:
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
18 changes: 18 additions & 0 deletions tests/fixtures/optimizer/expand_join_constructs.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Statements below come in pairs: an input query followed by the output expected from expand_join_constructs.

-- This is valid in Trino, so we treat the (tbl AS tbl) as a "join construct" per postgres' terminology.
SELECT * FROM (tbl AS tbl) AS _q_0;
SELECT * FROM (SELECT * FROM tbl AS tbl) AS _q_0;

-- Doubly parenthesized table: the expected output has a single SELECT * subquery.
SELECT * FROM ((tbl AS tbl)) AS _q_0;
SELECT * FROM (SELECT * FROM tbl AS tbl) AS _q_0;

-- Triply parenthesized table: still a single SELECT * subquery in the expected output.
SELECT * FROM (((tbl AS tbl))) AS _q_0;
SELECT * FROM (SELECT * FROM tbl AS tbl) AS _q_0;

-- Parenthesized chain of two joins: the whole chain becomes the FROM clause of the SELECT * subquery.
SELECT * FROM (tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0;
SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0;

-- Same join chain wrapped in an extra pair of parentheses: same expected output as above.
SELECT * FROM ((tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3)) AS _q_0;
SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0;

-- Join construct nested inside another join construct: both the inner and the outer one are expanded.
SELECT * FROM (tbl1 AS tbl1 JOIN (tbl2 AS tbl2 JOIN tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;
SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN (SELECT * FROM tbl2 AS tbl2 JOIN tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;
5 changes: 5 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -235,6 +235,11 @@ def test_expand_laterals(self):
execute=True,
)

def test_expand_join_constructs(self):
    """Run the expand_join_constructs rule through check_file under the "expand_join_constructs" name."""
    rule = optimizer.expand_join_constructs.expand_join_constructs
    self.check_file("expand_join_constructs", rule)

def test_expand_multi_table_selects(self):
self.check_file(
"expand_multi_table_selects",