Skip to content

Commit

Permalink
Refactor!: use a dictionary for query modifier search
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jul 3, 2023
1 parent fe69102 commit df4448d
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 43 deletions.
9 changes: 5 additions & 4 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ class Parser(parser.Parser):

QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"settings": lambda self: self._parse_csv(self._parse_conjunction)
if self._match(TokenType.SETTINGS)
else None,
"format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None,
TokenType.SETTINGS: (
"settings",
lambda self: self._advance() or self._parse_csv(self._parse_conjunction),
),
TokenType.FORMAT: ("format", lambda self: self._advance() or self._parse_id_var()),
}

def _parse_conjunction(self) -> t.Optional[exp.Expression]:
Expand Down
14 changes: 0 additions & 14 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,6 @@ class Parser(parser.Parser):
),
}

QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
"distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
"sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
}

def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
Expand Down Expand Up @@ -429,10 +422,3 @@ def datatype_sql(self, expression: exp.DataType) -> str:
expression = exp.DataType.build(expression.this)

return super().datatype_sql(expression)

def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return super().after_having_modifiers(expression) + [
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]
8 changes: 7 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,10 @@ def sql(
return expression

if key:
return self.sql(expression.args.get(key))
value = expression.args.get(key)
if value:
return self.sql(value)
return ""

if self._cache is not None:
expression_id = hash(expression)
Expand Down Expand Up @@ -1600,6 +1603,9 @@ def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]

def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
Expand Down
67 changes: 43 additions & 24 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,19 +737,29 @@ class Parser(metaclass=_Parser):
}

QUERY_MODIFIER_PARSERS = {
"joins": lambda self: list(iter(self._parse_join, None)),
"laterals": lambda self: list(iter(self._parse_lateral, None)),
"match": lambda self: self._parse_match_recognize(),
"where": lambda self: self._parse_where(),
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
"windows": lambda self: self._parse_window_clause(),
"order": lambda self: self._parse_order(),
"limit": lambda self: self._parse_limit(),
"offset": lambda self: self._parse_offset(),
"locks": lambda self: self._parse_locks(),
"sample": lambda self: self._parse_table_sample(as_modifier=True),
TokenType.MATCH_RECOGNIZE: ("match", lambda self: self._parse_match_recognize()),
TokenType.WHERE: ("where", lambda self: self._parse_where()),
TokenType.GROUP_BY: ("group", lambda self: self._parse_group()),
TokenType.HAVING: ("having", lambda self: self._parse_having()),
TokenType.QUALIFY: ("qualify", lambda self: self._parse_qualify()),
TokenType.WINDOW: ("windows", lambda self: self._parse_window_clause()),
TokenType.ORDER_BY: ("order", lambda self: self._parse_order()),
TokenType.LIMIT: ("limit", lambda self: self._parse_limit()),
TokenType.FETCH: ("limit", lambda self: self._parse_limit()),
TokenType.OFFSET: ("offset", lambda self: self._parse_offset()),
TokenType.FOR: ("locks", lambda self: self._parse_locks()),
TokenType.LOCK: ("locks", lambda self: self._parse_locks()),
TokenType.TABLE_SAMPLE: ("sample", lambda self: self._parse_table_sample(as_modifier=True)),
TokenType.USING: ("sample", lambda self: self._parse_table_sample(as_modifier=True)),
TokenType.CLUSTER_BY: (
"cluster",
lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
),
TokenType.DISTRIBUTE_BY: (
"distribute",
lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
),
TokenType.SORT_BY: ("sort", lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY)),
}

SET_PARSERS = {
Expand Down Expand Up @@ -2037,15 +2047,24 @@ def _parse_query_modifiers(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
if isinstance(this, self.MODIFIABLES):
for key, parser in self.QUERY_MODIFIER_PARSERS.items():
expression = parser(self)

if expression:
if key == "limit":
offset = expression.args.pop("offset", None)
if offset:
this.set("offset", exp.Offset(expression=offset))
this.set(key, expression)
for join in iter(self._parse_join, None):
this.append("joins", join)
for lateral in iter(self._parse_lateral, None):
this.append("laterals", lateral)

while True:
if self._match_set(self.QUERY_MODIFIER_PARSERS, advance=False):
key, parser = self.QUERY_MODIFIER_PARSERS[self._curr.token_type]
expression = parser(self)

if expression:
this.set(key, expression)
if key == "limit":
offset = expression.args.pop("offset", None)
if offset:
this.set("offset", exp.Offset(expression=offset))
continue
break
return this

def _parse_hint(self) -> t.Optional[exp.Hint]:
Expand Down Expand Up @@ -2508,8 +2527,8 @@ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Table
kind=kind,
)

def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
return list(iter(self._parse_pivot, None))
def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
return list(iter(self._parse_pivot, None)) or None

# https://duckdb.org/docs/sql/statements/pivot
def _parse_simplified_pivot(self) -> exp.Pivot:
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla
SELECT partition FROM a
SELECT end FROM a
SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1
SELECT * FROM x WHERE a GROUP BY a HAVING b SORT BY s ORDER BY c LIMIT d
SELECT LEFT.FOO FROM BLA AS LEFT
SELECT RIGHT.FOO FROM BLA AS RIGHT
SELECT LEFT FROM LEFT LEFT JOIN RIGHT RIGHT JOIN LEFT
Expand Down

0 comments on commit df4448d

Please sign in to comment.