diff --git a/src/sqlfmt/node_manager.py b/src/sqlfmt/node_manager.py index 70ec5a33..d7679740 100644 --- a/src/sqlfmt/node_manager.py +++ b/src/sqlfmt/node_manager.py @@ -54,6 +54,7 @@ def raise_on_mismatched_bracket(self, token: Token, last_bracket: Node) -> None: "[": "]", "case": "end", "array<": ">", + "map<": ">", "table<": ">", "struct<": ">", } diff --git a/src/sqlfmt/rules/core.py b/src/sqlfmt/rules/core.py index d8cd4c90..ad0b2438 100644 --- a/src/sqlfmt/rules/core.py +++ b/src/sqlfmt/rules/core.py @@ -116,8 +116,8 @@ r"\[", r"\(", r"\{", - # bq usese angle brackets for type definitions for compound types - r"(array|table|struct)\s*<", + # bq/athena uses angle brackets for type definitions for compound types + r"(array|map|table|struct)\s*<", ), action=partial(actions.add_node_to_buffer, token_type=TokenType.BRACKET_OPEN), ), diff --git a/tests/data/unformatted/130_athena_data_types.sql b/tests/data/unformatted/130_athena_data_types.sql new file mode 100644 index 00000000..d7122679 --- /dev/null +++ b/tests/data/unformatted/130_athena_data_types.sql @@ -0,0 +1,9 @@ +-- source: https://github.com/tconbeer/sqlfmt/issues/500 +select +cast( +json_parse(foo) as array< + map>) +from dwh.table +)))))__SQLFMT_OUTPUT__((((( +-- source: https://github.com/tconbeer/sqlfmt/issues/500 +select cast(json_parse(foo) as array>) from dwh.table diff --git a/tests/functional_tests/test_general_formatting.py b/tests/functional_tests/test_general_formatting.py index 73f11a60..43de6c6e 100644 --- a/tests/functional_tests/test_general_formatting.py +++ b/tests/functional_tests/test_general_formatting.py @@ -46,6 +46,7 @@ "unformatted/127_more_comments.sql", "unformatted/128_double_slash_comments.sql", "unformatted/129_duckdb_joins.sql", + "unformatted/130_athena_data_types.sql", "unformatted/200_base_model.sql", "unformatted/201_basic_snapshot.sql", "unformatted/202_unpivot_macro.sql", diff --git a/tests/unit_tests/test_rule.py b/tests/unit_tests/test_rule.py index 839f15b2..e07ce666 100644 --- a/tests/unit_tests/test_rule.py +++ b/tests/unit_tests/test_rule.py @@ -190,6 +190,7 @@ def get_rule(ruleset: List[Rule], rule_name: str) -> Rule: (MAIN, "set_operator", "minus"), (MAIN, "set_operator", "except"), (CORE, "bracket_open", "array<"), + (CORE, "bracket_open", "map<"), (CORE, "bracket_open", "table\n<"), (CORE, "bracket_open", "struct<"), (MAIN, "explain", "explain"),