Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor unsupported DDL handling #650

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
- sqlfmt no longer adds a space between the function name and parens for `filter()`, `isnull()`, and `rlike('foo', 'bar')` (but it also permits `filter ()`, `isnull ()`, and `rlike ('foo')` to support dialects where those are operators, not function names) ([#641](https://github.com/tconbeer/sqlfmt/issues/641), [#478](https://github.com/tconbeer/sqlfmt/issues/478) - thank you [@williamscs](https://github.com/williamscs), [@hongtron](https://github.com/hongtron), and [@chwiese](https://github.com/chwiese)!).
- sqlfmt now supports Spark type-hinted numeric literals like `32y` and `+3.2e6bd` and will not introduce a space between the digits and their type suffix ([#640](https://github.com/tconbeer/sqlfmt/issues/640) - thank you [@ShaneMazur](https://github.com/ShaneMazur)!).
- sqlfmt now supports Databricks query hint comments like `/*+ COALESCE(3) */` ([#639](https://github.com/tconbeer/sqlfmt/issues/639) - thank you [@wr-atlas](https://github.com/wr-atlas)!).
- sqlfmt now no-ops instead of raising an error when it encounters `create row access policy` statements with `grant` sub-statements (it also handles unsupported DDL more robustly in general) ([#633](https://github.com/tconbeer/sqlfmt/issues/633)).

## [0.23.3] - 2024-11-12

Expand Down
11 changes: 0 additions & 11 deletions src/sqlfmt/node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,6 @@ def disable_formatting(
):
formatting_disabled.pop()

# formatting can be disabled because of unsupported
# ddl. When we hit a semicolon we need to pop
# all of the formatting disabled tokens caused by ddl
# off the stack
if token.type is TokenType.SEMICOLON:
while (
formatting_disabled
and "fmt:" not in formatting_disabled[-1].token.lower()
):
formatting_disabled.pop()

return formatting_disabled

def append_newline(self, line: Line) -> None:
Expand Down
8 changes: 6 additions & 2 deletions src/sqlfmt/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sqlfmt.rules.function import FUNCTION as FUNCTION
from sqlfmt.rules.grant import GRANT as GRANT
from sqlfmt.rules.jinja import JINJA as JINJA # noqa
from sqlfmt.rules.unsupported import UNSUPPORTED as UNSUPPORTED
from sqlfmt.rules.warehouse import WAREHOUSE as WAREHOUSE
from sqlfmt.token import TokenType

Expand Down Expand Up @@ -77,7 +78,7 @@
r"interval",
r"is(\s+not)?(\s+distinct\s+from)?",
r"isnull",
r"(not\s+)?(r|i)?like(\s+(any|all))?",
r"(not\s+)?i?like(\s+(any|all))?",
r"over",
r"(un)?pivot",
r"notnull",
Expand Down Expand Up @@ -362,7 +363,10 @@
+ group(r"\W", r"$"),
action=partial(
actions.handle_nonreserved_top_level_keyword,
action=partial(actions.add_node_to_buffer, token_type=TokenType.FMT_OFF),
action=partial(
actions.lex_ruleset,
new_ruleset=UNSUPPORTED,
),
),
),
]
35 changes: 20 additions & 15 deletions src/sqlfmt/rules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlfmt.rules.jinja import JINJA
from sqlfmt.token import TokenType

CORE = [
ALWAYS = [
Rule(
name="fmt_off",
priority=0,
Expand Down Expand Up @@ -64,10 +64,27 @@
pattern=group(r"\*/"),
action=actions.raise_sqlfmt_bracket_error,
),
Rule(
name="semicolon",
priority=350,
pattern=group(r";"),
action=actions.handle_semicolon,
),
Rule(
name="newline",
priority=9000,
pattern=group(NEWLINE),
action=actions.handle_newline,
),
]


CORE = [
*ALWAYS,
Rule(
# see https://spark.apache.org/docs/latest/sql-ref-literals.html#integral-literal-syntax
name="spark_int_literals",
priority=349,
priority=400,
pattern=group(
r"(\+|-)?\d+(l|s|y)",
),
Expand All @@ -77,19 +94,13 @@
# the (bd|d|f) groups add support for Spark fractional literals
# https://spark.apache.org/docs/latest/sql-ref-literals.html#fractional-literals-syntax
name="number",
priority=350,
priority=401,
pattern=group(
r"(\+|-)?\d+(\.\d*)?(e(\+|-)?\d+)?(bd|d|f)?",
r"(\+|-)?\.\d+(e(\+|-)?\d+)?(bd|d|f)?",
),
action=actions.handle_number,
),
Rule(
name="semicolon",
priority=400,
pattern=group(r";"),
action=actions.handle_semicolon,
),
Rule(
name="star",
priority=410,
Expand Down Expand Up @@ -210,10 +221,4 @@
pattern=group(r"\w+"),
action=partial(actions.add_node_to_buffer, token_type=TokenType.NAME),
),
Rule(
name="newline",
priority=9000,
pattern=group(NEWLINE),
action=actions.handle_newline,
),
]
20 changes: 20 additions & 0 deletions src/sqlfmt/rules/unsupported.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from functools import partial

from sqlfmt import actions
from sqlfmt.rule import Rule
from sqlfmt.rules.common import NEWLINE, group
from sqlfmt.rules.core import ALWAYS
from sqlfmt.token import TokenType

# Ruleset for lexing statements that sqlfmt does not format. Each run of
# source text is captured as one opaque DATA token rather than being split
# into finer-grained SQL tokens.
UNSUPPORTED = [
    # Rules shared with the CORE ruleset (defined in sqlfmt.rules.core).
    *ALWAYS,
    Rule(
        name="unsupported_line",
        priority=1000,
        # First group: one or more characters that are neither a semicolon
        # nor a newline, matched lazily. Second group: the terminator that
        # must follow -- a semicolon, a newline, or the end of the input.
        # NOTE(review): presumably only the first group becomes the token
        # and the terminator is left for another rule -- confirm against
        # how Rule / the actions consume the match.
        pattern=(
            group(r"[^;\n]+?")
            + group(r";", NEWLINE, r"$")
        ),
        # The matched text is added to the buffer as a DATA node; the call
        # is routed through handle_reserved_keyword, which receives the
        # add-node action as its `action` argument.
        action=partial(
            actions.handle_reserved_keyword,
            action=partial(
                actions.add_node_to_buffer,
                token_type=TokenType.DATA,
            ),
        ),
    ),
]
2 changes: 1 addition & 1 deletion src/sqlfmt_primer/primer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="dbt_utils",
git_url="https://github.com/tconbeer/dbt-utils.git",
git_ref="c62b99f", # sqlfmt 6e9615c
git_ref="3e8412a", # sqlfmt 717530c
expected_changed=0,
expected_unchanged=131,
expected_errored=0,
Expand Down
4 changes: 4 additions & 0 deletions tests/data/preformatted/401_create_row_access_policy.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
create or replace row access policy foo
on foo.bar.baz
grant to ('user1', 'user2')
filter using ( foo = 'bar' )
1 change: 1 addition & 0 deletions tests/functional_tests/test_general_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"preformatted/302_jinjafmt_multiline_str.sql",
"preformatted/303_jinjafmt_more_mutliline_str.sql",
"preformatted/400_create_table.sql",
"preformatted/401_create_row_access_policy.sql",
"unformatted/100_select_case.sql",
"unformatted/101_multiline.sql",
"unformatted/102_lots_of_comments.sql",
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ def test_handle_unsupported_ddl(default_analyzer: Analyzer) -> None:
query = default_analyzer.parse_query(source_string=source_string.lstrip())
assert len(query.lines) == 3
first_create_line = query.lines[0]
assert len(first_create_line.nodes) == 9
assert first_create_line.nodes[0].token.type is TokenType.FMT_OFF
assert len(first_create_line.nodes) == 3 # data, semicolon, newline
assert first_create_line.nodes[0].token.type is TokenType.DATA
assert first_create_line.nodes[-2].token.type is TokenType.SEMICOLON

select_line = query.lines[1]
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_disabled_formatting(default_mode: Mode) -> None:
assert create_publication_line.formatting_disabled
assert create_publication_line.nodes
create_token = create_publication_line.nodes[0].token
assert create_token.type is TokenType.FMT_OFF
assert create_token.type is TokenType.DATA
assert create_token in create_publication_line.nodes[0].formatting_disabled
assert len(create_publication_line.nodes[0].formatting_disabled) == 3
semicolon_node = create_publication_line.nodes[-2]
Expand Down