diff --git a/CHANGELOG.md b/CHANGELOG.md
index fae9e59a..d0594ea5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
 
 ## [Unreleased]
 
+### Formatting Changes and Bug Fixes
+
+- sqlfmt will now parse unquoted reserved keywords as names if they are qualified by a period, e.g., `foo.select` or `foo.case` ([#599](https://github.com/tconbeer/sqlfmt/issues/599) - thank you [@matthieucan](https://github.com/matthieucan)!).
+
 ## [0.22.0] - 2024-07-25
 
 ### Formatting Changes and Bug Fixes
diff --git a/src/sqlfmt/actions.py b/src/sqlfmt/actions.py
index 383c48ef..216c91c9 100644
--- a/src/sqlfmt/actions.py
+++ b/src/sqlfmt/actions.py
@@ -308,7 +308,41 @@ def handle_number(analyzer: "Analyzer", source_string: str, match: re.Match) ->
     )
 
 
-def handle_nonreserved_keyword(
+def handle_reserved_keyword(
+    analyzer: "Analyzer",
+    source_string: str,
+    match: re.Match,
+    action: Callable[["Analyzer", str, re.Match], None],
+) -> None:
+    """
+    Reserved keywords can be used in most dialects as table or column
+    names without quoting if they are qualified.
+    https://github.com/tconbeer/sqlfmt/issues/599
+
+    This checks if the previous token is a period, and if so, lexes
+    this token as a name; otherwise this action executes the passed
+    action (which likely adds the node as some kind of keyword).
+    """
+    if analyzer.previous_node is None:
+        action(analyzer, source_string, match)
+        return
+
+    previous_token, _ = get_previous_token(analyzer.previous_node)
+    if previous_token is not None and previous_token.type is TokenType.DOT:
+        token = Token.from_match(source_string, match, token_type=TokenType.NAME)
+        if not token.prefix:
+            node = analyzer.node_manager.create_node(
+                token=token, previous_node=analyzer.previous_node
+            )
+            analyzer.node_buffer.append(node)
+            analyzer.pos = token.epos
+            return
+
+    # in all other cases, this is a keyword.
+    action(analyzer, source_string, match)
+
+
+def handle_nonreserved_top_level_keyword(
     analyzer: "Analyzer",
     source_string: str,
     match: re.Match,
@@ -317,6 +351,10 @@ def handle_nonreserved_keyword(
     """
     Checks to see if we're at depth 0 (assuming this is a name);
     if so, then take the passed action, otherwise lex it as a name.
+
+    For example, this allows us to lex these differently:
+    explain select 1;
+    select explain, 1;
     """
     token = Token.from_match(source_string, match, token_type=TokenType.NAME)
     node = analyzer.node_manager.create_node(
@@ -326,7 +364,9 @@
         analyzer.node_buffer.append(node)
         analyzer.pos = token.epos
     else:
-        action(analyzer, source_string, match)
+        handle_reserved_keyword(
+            analyzer=analyzer, source_string=source_string, match=match, action=action
+        )
 
 
 def lex_ruleset(
diff --git a/src/sqlfmt/rules/__init__.py b/src/sqlfmt/rules/__init__.py
index e4eb943d..490808cb 100644
--- a/src/sqlfmt/rules/__init__.py
+++ b/src/sqlfmt/rules/__init__.py
@@ -25,7 +25,10 @@
         priority=1000,
         pattern=group(r"case") + group(r"\W", r"$"),
         action=partial(
-            actions.add_node_to_buffer, token_type=TokenType.STATEMENT_START
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.STATEMENT_START
+            ),
         ),
     ),
     Rule(
@@ -33,9 +36,12 @@
         priority=1010,
         pattern=group(r"end") + group(r"\W", r"$"),
         action=partial(
-            actions.safe_add_node_to_buffer,
-            token_type=TokenType.STATEMENT_END,
-            fallback_token_type=TokenType.NAME,
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.safe_add_node_to_buffer,
+                token_type=TokenType.STATEMENT_END,
+                fallback_token_type=TokenType.NAME,
+            ),
         ),
     ),
     Rule(
@@ -65,7 +71,12 @@
             r"within\s+group",
         )
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR
+            ),
+        ),
     ),
     Rule(
         name="star_replace_exclude",
@@ -76,8 +87,10 @@
         )
         + group(r"\s+\("),
         action=partial(
-            actions.add_node_to_buffer,
-            token_type=TokenType.WORD_OPERATOR,
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR
+            ),
         ),
     ),
     Rule(
@@ -88,13 +101,21 @@
         name="join_using",
         priority=1110,
         pattern=group(r"using") + group(r"\s*\("),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR
+            ),
+        ),
     ),
     Rule(
         name="on",
         priority=1120,
         pattern=group(r"on") + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.ON),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(actions.add_node_to_buffer, token_type=TokenType.ON),
+        ),
     ),
     Rule(
         name="boolean_operator",
@@ -106,8 +127,10 @@
         )
         + group(r"\W", r"$"),
         action=partial(
-            actions.add_node_to_buffer,
-            token_type=TokenType.BOOLEAN_OPERATOR,
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.BOOLEAN_OPERATOR
+            ),
         ),
     ),
     Rule(
@@ -154,7 +177,12 @@
             r"returning",
         )
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
+            ),
+        ),
     ),
     Rule(
         name="frame_clause",
@@ -162,7 +190,12 @@
         pattern=group(r"(range|rows|groups)\s+")
         + group(r"(between\s+)?((unbounded|\d+)\s+(preceding|following)|current\s+row)")
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
+            ),
+        ),
     ),
     Rule(
         # BQ arrays use an offset(n) function for
@@ -172,7 +205,12 @@
         name="offset_keyword",
         priority=1310,
         pattern=group(r"offset") + group(r"\s+", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
+            ),
+        ),
     ),
     Rule(
         name="set_operator",
@@ -181,7 +219,10 @@
             r"(union|intersect|except|minus)(\s+(all|distinct))?(\s+by\s+name)?",
         )
         + group(r"\W", r"$"),
-        action=actions.handle_set_operator,
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=actions.handle_set_operator,
+        ),
     ),
     Rule(
         name="explain",
@@ -189,7 +230,7 @@
         pattern=group(r"explain(\s+(analyze|verbose|using\s+(tabular|json|text)))?")
         + group(r"\W", r"$"),
         action=partial(
-            actions.handle_nonreserved_keyword,
+            actions.handle_nonreserved_top_level_keyword,
             action=partial(
                 actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
             ),
@@ -200,7 +241,7 @@
         priority=2010,
         pattern=group(r"grant", r"revoke") + group(r"\W", r"$"),
         action=partial(
-            actions.handle_nonreserved_keyword,
+            actions.handle_nonreserved_top_level_keyword,
             action=partial(actions.lex_ruleset, new_ruleset=GRANT),
         ),
     ),
@@ -209,7 +250,7 @@
         priority=2015,
         pattern=group(CREATE_CLONABLE + r"\s+.+?\s+clone") + group(r"\W", r"$"),
         action=partial(
-            actions.handle_nonreserved_keyword,
+            actions.handle_nonreserved_top_level_keyword,
             action=partial(
                 actions.lex_ruleset,
                 new_ruleset=CLONE,
@@ -221,7 +262,7 @@
         priority=2020,
         pattern=group(CREATE_FUNCTION, ALTER_DROP_FUNCTION) + group(r"\W", r"$"),
         action=partial(
-            actions.handle_nonreserved_keyword,
+            actions.handle_nonreserved_top_level_keyword,
             action=partial(
                 actions.lex_ruleset,
                 new_ruleset=FUNCTION,
@@ -237,7 +278,7 @@
         )
         + group(r"\W", r"$"),
         action=partial(
-            actions.handle_nonreserved_keyword,
+            actions.handle_nonreserved_top_level_keyword,
             action=partial(
                 actions.lex_ruleset,
                 new_ruleset=WAREHOUSE,
@@ -301,7 +342,7 @@
         + r"(?!\()"
         + group(r"\W", r"$"),
         action=partial(
-            actions.handle_nonreserved_keyword,
+            actions.handle_nonreserved_top_level_keyword,
             action=partial(actions.add_node_to_buffer, token_type=TokenType.FMT_OFF),
         ),
     ),
diff --git a/src/sqlfmt/rules/clone.py b/src/sqlfmt/rules/clone.py
index 1022ffc3..bc052bcd 100644
--- a/src/sqlfmt/rules/clone.py
+++ b/src/sqlfmt/rules/clone.py
@@ -12,7 +12,12 @@
         name="unterm_keyword",
         priority=1300,
         pattern=group(CREATE_CLONABLE, r"clone") + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
+            ),
+        ),
     ),
     Rule(
         name="word_operator",
@@ -22,6 +27,11 @@
             r"before",
         )
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR
+            ),
+        ),
     ),
 ]
diff --git a/src/sqlfmt/rules/function.py b/src/sqlfmt/rules/function.py
index 4402345f..d65993ce 100644
--- a/src/sqlfmt/rules/function.py
+++ b/src/sqlfmt/rules/function.py
@@ -15,7 +15,10 @@
             r"as",
         )
         + group(r"\W", r"$"),
-        action=actions.handle_ddl_as,
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=actions.handle_ddl_as,
+        ),
     ),
     Rule(
         name="word_operator",
@@ -27,7 +30,12 @@
             r"runtime_version",
         )
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.WORD_OPERATOR
+            ),
+        ),
     ),
     Rule(
         name="unterm_keyword",
@@ -85,6 +93,11 @@
             r"(re)?set(\s+all)?",
         )
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
+            ),
+        ),
     ),
 ]
diff --git a/src/sqlfmt/rules/grant.py b/src/sqlfmt/rules/grant.py
index b6516efe..cce9128e 100644
--- a/src/sqlfmt/rules/grant.py
+++ b/src/sqlfmt/rules/grant.py
@@ -23,6 +23,11 @@
             r"restrict",
         )
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
+            ),
+        ),
     ),
 ]
diff --git a/src/sqlfmt/rules/warehouse.py b/src/sqlfmt/rules/warehouse.py
index 54eb0edd..e7644366 100644
--- a/src/sqlfmt/rules/warehouse.py
+++ b/src/sqlfmt/rules/warehouse.py
@@ -45,6 +45,11 @@
             r"rename\s+to",
         )
         + group(r"\W", r"$"),
-        action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD),
+        action=partial(
+            actions.handle_reserved_keyword,
+            action=partial(
+                actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD
+            ),
+        ),
     ),
 ]
diff --git a/tests/data/preformatted/008_reserved_names.sql b/tests/data/preformatted/008_reserved_names.sql
new file mode 100644
index 00000000..17ceb4e6
--- /dev/null
+++ b/tests/data/preformatted/008_reserved_names.sql
@@ -0,0 +1,21 @@
+with
+    "select" as (select * from foo.select),
+    "case" as (select * from foo.case),
+    "end" as (select * from foo.end),
+    "as" as (select * from foo.as),
+    "interval" as (select * from foo.interval),
+    "exclude" as (select * from foo.exclude),
+    "using" as (select * from foo.using),
+    "on" as (select * from foo.on),
+    "and" as (select * from foo.and),
+    "limit" as (select * from foo.limit),
+    "over" as (select * from foo.over),
+    "between" as (select * from foo.between),
+    "union" as (select * from foo.union),
+    "explain" as (select * from foo.explain),
+    "grant" as (select * from foo.grant),
+    "create" as (select * from foo.create),
+    "alter" as (select * from foo.alter),
+    "truncate" as (select * from foo.truncate),
+    "drop" as (select * from foo.drop)
+select 1
diff --git a/tests/functional_tests/test_general_formatting.py b/tests/functional_tests/test_general_formatting.py
index d0bfae9f..d11e0a2b 100644
--- a/tests/functional_tests/test_general_formatting.py
+++ b/tests/functional_tests/test_general_formatting.py
@@ -15,6 +15,7 @@
         "preformatted/005_fmt_off.sql",
         "preformatted/006_fmt_off_447.sql",
         "preformatted/007_fmt_off_comments.sql",
+        "preformatted/008_reserved_names.sql",
         "preformatted/301_multiline_jinjafmt.sql",
         "preformatted/302_jinjafmt_multiline_str.sql",
         "preformatted/400_create_table.sql",
diff --git a/tests/unit_tests/test_actions.py b/tests/unit_tests/test_actions.py
index 05491bb4..cc9f960c 100644
--- a/tests/unit_tests/test_actions.py
+++ b/tests/unit_tests/test_actions.py
@@ -603,3 +603,55 @@ def test_handle_nested_dictionary_in_jinja_expression(
         token_type=TokenType.JINJA_EXPRESSION,
     )
     assert jinja_analyzer.pos == 355
+
+
+def test_handle_reserved_keywords(default_analyzer: Analyzer) -> None:
+    source_string = """
+    select case;
+    select foo.case;
+    select foo.select;
+    interval;
+    foo.interval;
+    explain;
+    foo.explain;
+    """
+    query = default_analyzer.parse_query(source_string=source_string.lstrip())
+    assert len(query.lines) == 7
+    case_line = query.lines[0]
+    assert len(case_line.nodes) == 4
+    assert case_line.nodes[0].token.type is TokenType.UNTERM_KEYWORD
+    assert case_line.nodes[1].token.type is TokenType.STATEMENT_START
+
+    case_name_line = query.lines[1]
+    assert len(case_name_line.nodes) == 6
+    assert case_name_line.nodes[0].token.type is TokenType.UNTERM_KEYWORD
+    assert case_name_line.nodes[1].token.type is TokenType.NAME
+    assert case_name_line.nodes[2].token.type is TokenType.DOT
+    assert case_name_line.nodes[3].token.type is TokenType.NAME
+
+    select_name_line = query.lines[2]
+    assert len(select_name_line.nodes) == 6
+    assert select_name_line.nodes[0].token.type is TokenType.UNTERM_KEYWORD
+    assert select_name_line.nodes[1].token.type is TokenType.NAME
+    assert select_name_line.nodes[2].token.type is TokenType.DOT
+    assert select_name_line.nodes[3].token.type is TokenType.NAME
+
+    interval_line = query.lines[3]
+    assert len(interval_line.nodes) == 3
+    assert interval_line.nodes[0].token.type is TokenType.WORD_OPERATOR
+
+    interval_name_line = query.lines[4]
+    assert len(interval_name_line.nodes) == 5
+    assert interval_name_line.nodes[0].token.type is TokenType.NAME
+    assert interval_name_line.nodes[1].token.type is TokenType.DOT
+    assert interval_name_line.nodes[2].token.type is TokenType.NAME
+
+    explain_line = query.lines[5]
+    assert len(explain_line.nodes) == 3
+    assert explain_line.nodes[0].token.type is TokenType.UNTERM_KEYWORD
+
+    explain_name_line = query.lines[6]
+    assert len(explain_name_line.nodes) == 5
+    assert explain_name_line.nodes[0].token.type is TokenType.NAME
+    assert explain_name_line.nodes[1].token.type is TokenType.DOT
+    assert explain_name_line.nodes[2].token.type is TokenType.NAME
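
Note on the rule this patch implements: a reserved word is demoted to a plain name exactly when the token immediately before it is a DOT, and, per `handle_reserved_keyword` above, only when no whitespace separates them (`token.prefix` must be empty), so `foo. select` still lexes `select` as a keyword. The following is a self-contained toy sketch of that rule; all names here are hypothetical illustrations, not sqlfmt's real tokenizer or token types.

import re

# Toy lexer for illustration only; it mimics the dot-qualification check,
# not sqlfmt's actual rule-based analyzer.
RESERVED = {"select", "case", "end", "from", "on", "using", "interval", "explain"}
TOKEN_RE = re.compile(r"\s*([A-Za-z_]\w*|\*|[.,;()])")


def lex(sql: str) -> list[tuple[str, str]]:
    tokens: list[tuple[str, str]] = []
    pos = 0
    while (m := TOKEN_RE.match(sql, pos)) is not None:
        text = m.group(1)
        pos = m.end()
        if text == ".":
            tokens.append(("DOT", text))
        elif text.lower() in RESERVED:
            # the crux of the fix: a reserved word right after a DOT is a name
            if tokens and tokens[-1][0] == "DOT":
                tokens.append(("NAME", text))
            else:
                tokens.append(("KEYWORD", text))
        else:
            tokens.append(("NAME", text))
    return tokens


print(lex("select * from foo.select"))
# [('KEYWORD', 'select'), ('NAME', '*'), ('KEYWORD', 'from'),
#  ('NAME', 'foo'), ('DOT', '.'), ('NAME', 'select')]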
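
To observe the change end to end, here is a minimal sketch that runs a dot-qualified reserved word through the formatter via sqlfmt's public Python API. It assumes `Mode` and `format_string` are importable from `sqlfmt.api`, as in the documented API for recent releases; check your installed version if the names differ.

from sqlfmt.api import Mode, format_string

mode = Mode()

# `foo.select` qualifies the reserved word `select` with a period; with this
# patch it is lexed as a name, so the query parses and formats cleanly
# instead of tripping the keyword rules.
source = 'with "select" as (select * from foo.select) select 1\n'
print(format_string(source, mode=mode))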