Skip to content

Commit

Permalink
fix: add unit test for coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer committed Nov 22, 2024
1 parent c5810c3 commit 551f31e
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/sqlfmt/comment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __str__(self) -> str:
if (
self.is_multiline
or self.formatting_disabled
or self.is_databricks_type_hint
or self.is_databricks_query_hint
):
return self.token.token
else:
Expand Down Expand Up @@ -90,7 +90,7 @@ def is_c_style(self) -> bool:
return self.token.token.startswith("/*")

@property
def is_databricks_type_hint(self) -> bool:
def is_databricks_query_hint(self) -> bool:
return self.token.token.startswith("/*+")

@property
Expand Down
2 changes: 1 addition & 1 deletion src/sqlfmt/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _extract_components(
"Can't merge lines with inline comments and other comments"
)
elif any(
[comment.is_databricks_type_hint for comment in line.comments]
[comment.is_databricks_query_hint for comment in line.comments]
):
raise CannotMergeException(
"Can't merge lines with a databricks type hint comment"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
select /*+ my query hint */
1
)))))__SQLFMT_OUTPUT__(((((
select
1
31 changes: 31 additions & 0 deletions tests/unit_tests/test_comment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from sqlfmt.comment import Comment
from sqlfmt.node import Node
from sqlfmt.node_manager import NodeManager
from sqlfmt.token import Token, TokenType

Expand Down Expand Up @@ -60,6 +61,30 @@ def multiline_comment() -> Comment:
return comment


@pytest.fixture
def datbricks_query_hint_comment() -> Comment:
n = Node(
Token(
type=TokenType.UNTERM_KEYWORD,
prefix="",
token="select",
spos=0,
epos=6,
),
previous_node=None,
prefix="",
value="select",
open_brackets=[],
open_jinja_blocks=[],
formatting_disabled=[],
)
t = Token(
type=TokenType.COMMENT, prefix=" ", token="/*+ hint here */", spos=6, epos=23
)
comment = Comment(t, is_standalone=False, previous_node=n)
return comment


@pytest.fixture
def fmt_disabled_comment() -> Comment:
t = Token(type=TokenType.FMT_OFF, prefix="", token="--fmt: off", spos=0, epos=10)
Expand All @@ -81,11 +106,13 @@ def test_get_marker(
short_mysql_comment: Comment,
nospace_comment: Comment,
short_js_comment: Comment,
datbricks_query_hint_comment: Comment,
) -> None:
assert short_comment._get_marker() == ("--", 3)
assert short_mysql_comment._get_marker() == ("#", 2)
assert short_js_comment._get_marker() == ("//", 3)
assert nospace_comment._get_marker() == ("--", 2)
assert datbricks_query_hint_comment._get_marker() == ("/*", 2)


def test_comment_parts(
Expand Down Expand Up @@ -210,3 +237,7 @@ def test_no_wrap_long_jinja_comments() -> None:
rendered = comment.render_standalone(88, "")

assert rendered == comment_str + "\n"


def test_no_add_space_databricks_hint(datbricks_query_hint_comment: Comment) -> None:
assert str(datbricks_query_hint_comment) == datbricks_query_hint_comment.token.token
12 changes: 12 additions & 0 deletions tests/unit_tests/test_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,15 @@ def test_do_not_merge_operator_sequences_across_commas(merger: LineMerger) -> No
merged_lines = merger.maybe_merge_lines(raw_query.lines)
result_string = "".join([str(line) for line in merged_lines])
assert result_string == expected_string


def test_do_not_merge_databricks_query_hints(merger: LineMerger) -> None:
source_string, expected_string = read_test_data(
"unit_tests/test_merger/test_no_merge_databricks_query_hints.sql"
)
raw_query = merger.mode.dialect.initialize_analyzer(
merger.mode.line_length
).parse_query(source_string)
merged_lines = merger.maybe_merge_lines(raw_query.lines)
result_string = "".join([str(line) for line in merged_lines])
assert result_string == expected_string

0 comments on commit 551f31e

Please sign in to comment.