diff --git a/src/sqlfmt/rules/__init__.py b/src/sqlfmt/rules/__init__.py index b7011b68..32c1d28c 100644 --- a/src/sqlfmt/rules/__init__.py +++ b/src/sqlfmt/rules/__init__.py @@ -3,13 +3,20 @@ from sqlfmt import actions from sqlfmt.exception import StopRulesetLexing from sqlfmt.rule import Rule -from sqlfmt.rules.common import NEWLINE, SQL_COMMENT, SQL_QUOTED_EXP, group +from sqlfmt.rules.common import ( + ALTER_WAREHOUSE, + CREATE_FUNCTION, + CREATE_WAREHOUSE, + NEWLINE, + SQL_COMMENT, + SQL_QUOTED_EXP, + group, +) from sqlfmt.rules.core import CORE as CORE from sqlfmt.rules.function import FUNCTION as FUNCTION from sqlfmt.rules.grant import GRANT as GRANT from sqlfmt.rules.jinja import JINJA as JINJA # noqa - -# from sqlfmt.rules.warehouse import WAREHOUSE as WAREHOUSE +from sqlfmt.rules.warehouse import WAREHOUSE as WAREHOUSE from sqlfmt.token import TokenType MAIN = [ @@ -196,18 +203,29 @@ Rule( name="create_function", priority=2020, - pattern=group( - ( - r"create(\s+or\s+replace)?(\s+temp(orary)?)?(\s+secure)?(\s+table)?" - r"\s+function(\s+if\s+not\s+exists)?" + pattern=group(CREATE_FUNCTION) + group(r"\W", r"$"), + action=partial( + actions.handle_nonreserved_keyword, + action=partial( + actions.lex_ruleset, + new_ruleset=FUNCTION, + stop_exception=StopRulesetLexing, ), + ), + ), + Rule( + name="create_warehouse", + priority=2030, + pattern=group( + CREATE_WAREHOUSE, + ALTER_WAREHOUSE, ) + group(r"\W", r"$"), action=partial( actions.handle_nonreserved_keyword, action=partial( actions.lex_ruleset, - new_ruleset=FUNCTION, + new_ruleset=WAREHOUSE, stop_exception=StopRulesetLexing, ), ), diff --git a/src/sqlfmt/rules/common.py b/src/sqlfmt/rules/common.py index bdaa0d96..1b3771f9 100644 --- a/src/sqlfmt/rules/common.py +++ b/src/sqlfmt/rules/common.py @@ -32,4 +32,5 @@ def group(*choices: str) -> str: r"create(\s+or\s+replace)?(\s+temp(orary)?)?(\s+secure)?(\s+table)?" r"\s+function(\s+if\s+not\s+exists)?" ) -CREATE_WAREHOUSE = r"create(\s+or\s+replace)?warehouse(\s+if\s+not\s+exists)?" +CREATE_WAREHOUSE = r"create(\s+or\s+replace)?\s+warehouse(\s+if\s+not\s+exists)?" +ALTER_WAREHOUSE = r"alter\s+warehouse(\s+if\s+exists)?" diff --git a/src/sqlfmt/rules/function.py b/src/sqlfmt/rules/function.py index 92a580ef..2761793d 100644 --- a/src/sqlfmt/rules/function.py +++ b/src/sqlfmt/rules/function.py @@ -51,7 +51,6 @@ r"rows", r"support", r"set", - r"as", # snowflake r"comment", r"imports", diff --git a/src/sqlfmt/rules/warehouse.py b/src/sqlfmt/rules/warehouse.py new file mode 100644 index 00000000..54eb0edd --- /dev/null +++ b/src/sqlfmt/rules/warehouse.py @@ -0,0 +1,50 @@ +from functools import partial + +from sqlfmt import actions +from sqlfmt.rule import Rule +from sqlfmt.rules.common import ALTER_WAREHOUSE, CREATE_WAREHOUSE, group +from sqlfmt.rules.core import CORE +from sqlfmt.token import TokenType + +WAREHOUSE = [ + *CORE, + Rule( + name="unterm_keyword", + priority=1300, + pattern=group( + CREATE_WAREHOUSE, + ALTER_WAREHOUSE, + # objectProperties + r"(with\s+|(un)?set\s+)?" + + group( + r"warehouse_type", + r"warehouse_size", + r"max_cluster_count", + r"min_cluster_count", + r"scaling_policy", + r"auto_suspend", + r"auto_resume", + r"initially_suspended", + r"resource_monitor", + r"comment", + r"enable_query_acceleration", + r"query_acceleration_max_scale_factor", + r"tag", + ), + # objectParams + r"(set\s+)?" + + group( + r"max_concurrency_level", + r"statement_queued_timeout_in_seconds", + r"statement_timeout_in_seconds", + ), + # alter + r"suspend", + r"resume(\s+if\s+suspended)?", + r"abort\s+all\s+queries", + r"rename\s+to", + ) + + group(r"\W", r"$"), + action=partial(actions.add_node_to_buffer, token_type=TokenType.UNTERM_KEYWORD), + ), +] diff --git a/tests/data/unformatted/410_create_warehouse.sql b/tests/data/unformatted/410_create_warehouse.sql new file mode 100644 index 00000000..e0216d67 --- /dev/null +++ b/tests/data/unformatted/410_create_warehouse.sql @@ -0,0 +1,44 @@ +create or replace warehouse foo +warehouse_size='XLARGE' +warehouse_type='SNOWPARK-OPTIMIZED' +max_cluster_count=6; +create + warehouse if not exists foo + with warehouse_size = 'X5LARGE' + AUTO_SUSPEND = 100 + AUTO_RESUME = FALSE + INITIALLY_SUSPENDED = TRUE; + +alter warehouse if exists foo set warehouse_size='XSMALL'; alter warehouse if exists foo set tag 'foobar'='baz', 'another_really_long_tag_name'='really_very_long_tag_value_quxxxxxxxxxxxxxxxxxxx', 'bar'='baz'; + +alter warehouse foo rename to bar; +alter warehouse bar resume if suspended; +)))))__SQLFMT_OUTPUT__((((( +create or replace warehouse foo +warehouse_size = 'XLARGE' +warehouse_type = 'SNOWPARK-OPTIMIZED' +max_cluster_count = 6 +; +create warehouse if not exists foo +with warehouse_size = 'X5LARGE' +auto_suspend = 100 +auto_resume = false +initially_suspended = true +; + +alter warehouse if exists foo +set warehouse_size = 'XSMALL' +; +alter warehouse if exists foo +set tag + 'foobar' = 'baz', + 'another_really_long_tag_name' = 'really_very_long_tag_value_quxxxxxxxxxxxxxxxxxxx', + 'bar' = 'baz' +; + +alter warehouse foo +rename to bar +; +alter warehouse bar +resume if suspended +; diff --git a/tests/functional_tests/test_general_formatting.py b/tests/functional_tests/test_general_formatting.py index b667a23b..3ea89da4 100644 --- a/tests/functional_tests/test_general_formatting.py +++ b/tests/functional_tests/test_general_formatting.py @@ -60,6 +60,7 @@ "unformatted/404_create_function_pg_examples.sql", "unformatted/405_create_function_snowflake_examples.sql", "unformatted/406_create_function_bq_examples.sql", + "unformatted/410_create_warehouse.sql", ], ) def test_formatting(p: str) -> None: diff --git a/tests/unit_tests/test_rule.py b/tests/unit_tests/test_rule.py index 151e3c91..94a3d9c4 100644 --- a/tests/unit_tests/test_rule.py +++ b/tests/unit_tests/test_rule.py @@ -4,7 +4,7 @@ import pytest from sqlfmt.rule import Rule -from sqlfmt.rules import CORE, FUNCTION, GRANT, JINJA, MAIN +from sqlfmt.rules import CORE, FUNCTION, GRANT, JINJA, MAIN, WAREHOUSE def get_rule(ruleset: List[Rule], rule_name: str) -> Rule: @@ -275,6 +275,22 @@ def get_rule(ruleset: List[Rule], rule_name: str) -> Rule: (FUNCTION, "unterm_keyword", "not null"), (FUNCTION, "unterm_keyword", "handler"), (FUNCTION, "unterm_keyword", "packages"), + (MAIN, "create_warehouse", "create warehouse if not exists"), + (MAIN, "create_warehouse", "alter warehouse if exists"), + (WAREHOUSE, "unterm_keyword", "create warehouse if not exists"), + (WAREHOUSE, "unterm_keyword", "create or replace warehouse"), + (WAREHOUSE, "unterm_keyword", "warehouse_type"), + (WAREHOUSE, "unterm_keyword", "with warehouse_size"), + (WAREHOUSE, "unterm_keyword", "set warehouse_size"), + (WAREHOUSE, "unterm_keyword", "max_cluster_count"), + (WAREHOUSE, "unterm_keyword", "min_cluster_count"), + (WAREHOUSE, "unterm_keyword", "auto_suspend"), + (WAREHOUSE, "unterm_keyword", "auto_resume"), + (WAREHOUSE, "unterm_keyword", "alter warehouse if exists"), + (WAREHOUSE, "unterm_keyword", "rename to"), + (WAREHOUSE, "unterm_keyword", "set tag"), + (WAREHOUSE, "unterm_keyword", "resume if suspended"), + (WAREHOUSE, "unterm_keyword", "unset scaling_policy"), ], ) def test_regex_exact_match(