diff --git a/src/sqlfmt/comment.py b/src/sqlfmt/comment.py index a995bcd..5afc3d6 100644 --- a/src/sqlfmt/comment.py +++ b/src/sqlfmt/comment.py @@ -27,7 +27,7 @@ def __str__(self) -> str: if ( self.is_multiline or self.formatting_disabled - or self.is_databricks_type_hint + or self.is_databricks_query_hint ): return self.token.token else: @@ -90,7 +90,7 @@ def is_c_style(self) -> bool: return self.token.token.startswith("/*") @property - def is_databricks_type_hint(self) -> bool: + def is_databricks_query_hint(self) -> bool: return self.token.token.startswith("/*+") @property diff --git a/src/sqlfmt/merger.py b/src/sqlfmt/merger.py index 9736ce1..b2916c9 100644 --- a/src/sqlfmt/merger.py +++ b/src/sqlfmt/merger.py @@ -78,7 +78,7 @@ def _extract_components( "Can't merge lines with inline comments and other comments" ) elif any( - [comment.is_databricks_type_hint for comment in line.comments] + [comment.is_databricks_query_hint for comment in line.comments] ): raise CannotMergeException( "Can't merge lines with a databricks type hint comment" diff --git a/tests/data/unit_tests/test_merger/test_no_merge_databricks_query_hints.sql b/tests/data/unit_tests/test_merger/test_no_merge_databricks_query_hints.sql new file mode 100644 index 0000000..144380f --- /dev/null +++ b/tests/data/unit_tests/test_merger/test_no_merge_databricks_query_hints.sql @@ -0,0 +1,5 @@ +select /*+ my query hint */ + 1 +)))))__SQLFMT_OUTPUT__((((( +select + 1 diff --git a/tests/unit_tests/test_comment.py b/tests/unit_tests/test_comment.py index b6d994c..53ba67a 100644 --- a/tests/unit_tests/test_comment.py +++ b/tests/unit_tests/test_comment.py @@ -2,6 +2,7 @@ import pytest from sqlfmt.comment import Comment +from sqlfmt.node import Node from sqlfmt.node_manager import NodeManager from sqlfmt.token import Token, TokenType @@ -60,6 +61,30 @@ def multiline_comment() -> Comment: return comment +@pytest.fixture +def datbricks_query_hint_comment() -> Comment: + n = Node( + Token( + type=TokenType.UNTERM_KEYWORD, + prefix="", + token="select", + spos=0, + epos=6, + ), + previous_node=None, + prefix="", + value="select", + open_brackets=[], + open_jinja_blocks=[], + formatting_disabled=[], + ) + t = Token( + type=TokenType.COMMENT, prefix=" ", token="/*+ hint here */", spos=6, epos=23 + ) + comment = Comment(t, is_standalone=False, previous_node=n) + return comment + + @pytest.fixture def fmt_disabled_comment() -> Comment: t = Token(type=TokenType.FMT_OFF, prefix="", token="--fmt: off", spos=0, epos=10) @@ -81,11 +106,13 @@ def test_get_marker( short_mysql_comment: Comment, nospace_comment: Comment, short_js_comment: Comment, + datbricks_query_hint_comment: Comment, ) -> None: assert short_comment._get_marker() == ("--", 3) assert short_mysql_comment._get_marker() == ("#", 2) assert short_js_comment._get_marker() == ("//", 3) assert nospace_comment._get_marker() == ("--", 2) + assert datbricks_query_hint_comment._get_marker() == ("/*", 2) def test_comment_parts( @@ -210,3 +237,7 @@ def test_no_wrap_long_jinja_comments() -> None: rendered = comment.render_standalone(88, "") assert rendered == comment_str + "\n" + + +def test_no_add_space_databricks_hint(datbricks_query_hint_comment: Comment) -> None: + assert str(datbricks_query_hint_comment) == datbricks_query_hint_comment.token.token diff --git a/tests/unit_tests/test_merger.py b/tests/unit_tests/test_merger.py index c42d2af..fb22708 100644 --- a/tests/unit_tests/test_merger.py +++ b/tests/unit_tests/test_merger.py @@ -560,3 +560,15 @@ def test_do_not_merge_operator_sequences_across_commas(merger: LineMerger) -> No merged_lines = merger.maybe_merge_lines(raw_query.lines) result_string = "".join([str(line) for line in merged_lines]) assert result_string == expected_string + + +def test_do_not_merge_databricks_query_hints(merger: LineMerger) -> None: + source_string, expected_string = read_test_data( + "unit_tests/test_merger/test_no_merge_databricks_query_hints.sql" + ) + raw_query = merger.mode.dialect.initialize_analyzer( + merger.mode.line_length + ).parse_query(source_string) + merged_lines = merger.maybe_merge_lines(raw_query.lines) + result_string = "".join([str(line) for line in merged_lines]) + assert result_string == expected_string