Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(optimizer): expand join constructs into SELECT * from subqueries #1560

Merged
merged 10 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions sqlglot/optimizer/qualify_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,42 @@

def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.

Example:
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'

Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
schema: A schema to populate

Returns:
sqlglot.Expression: qualified expression

(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
sequence = itertools.count()

next_name = lambda: f"_q_{next(sequence)}"

for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
# Expand join construct
if isinstance(derived_table, exp.Subquery):
unnested = derived_table.unnest()
if isinstance(unnested, exp.Table):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
Expand Down
6 changes: 6 additions & 0 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
# This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
Expand Down Expand Up @@ -587,6 +590,9 @@ def _traverse_tables(scope):
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)

if isinstance(scope.expression, exp.Table):
expressions.append(scope.expression)

expressions.extend(scope.expression.args.get("laterals") or [])

for expression in expressions:
Expand Down
23 changes: 23 additions & 0 deletions tests/fixtures/optimizer/qualify_tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,26 @@ WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a;

SELECT (SELECT y.c FROM y AS y) FROM x;
SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x;

-------------------------
-- Expand join constructs
-------------------------

-- This is valid in Trino, so we treat the (tbl AS tbl) as a "join construct" per postgres' terminology.
SELECT * FROM (tbl AS tbl) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;

SELECT * FROM ((tbl AS tbl)) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;

SELECT * FROM (((tbl AS tbl))) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0;

SELECT * FROM (tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0;

SELECT * FROM ((tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3)) AS _q_0;
SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0;

SELECT * FROM (tbl1 AS tbl1 JOIN (tbl2 AS tbl2 JOIN tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;
SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN (SELECT * FROM c.db.tbl2 AS tbl2 JOIN c.db.tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1;