Skip to content

Commit

Permalink
fix/#214/stubborn merge p1 (#246)
Browse files Browse the repository at this point in the history
* refactor: remove AS and TIGHT_WORD_OPERATOR token types

* fix #214: improve rules for stubborn merging p1 operators

* refactor: improve performance regression in stubborn merging

* fix: stubborn merge p0 operators first

* fix: prevent bad stubborn merging of already-merged segments

* fix: fix regression when jinja blocks and brackets are interlaced

* chore: bump primer refs

* fix: update http primer stats
  • Loading branch information
tconbeer authored Aug 21, 2022
1 parent c34fcf2 commit 8e4218f
Show file tree
Hide file tree
Showing 17 changed files with 257 additions and 98 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ All notable changes to this project will be documented in this file.
### Formatting Changes + Bug Fixes
- adds more granularity to operator precedence and will merge lines more aggressively that start with high-precedence operators ([#200](https://github.com/tconbeer/sqlfmt/issues/200))
- improves the formatting of `between ... and ...`, especially in situations where the source includes a line break ([#207](https://github.com/tconbeer/sqlfmt/issues/207))
- improves the consistency of formatting long chains of operators that include parentheses ([#214](https://github.com/tconbeer/sqlfmt/issues/214))
- fixes a bug that caused unnecessary copying of the cache when using multiprocessing. Large projects should see dramatically faster (near-instant) runs once the cache is warm
- fixes a bug that could cause lines with long jinja tags to be one character over the line length limit, and could result in unstable formatting ([#237](https://github.com/tconbeer/sqlfmt/issues/237) - thank you [@nfcampos](https://github.com/nfcampos)!)
- fixes a bug that formatted array literals like they were indexing operations ([#235](https://github.com/tconbeer/sqlfmt/issues/235) - thank you [@nfcampos](https://github.com/nfcampos)!)
Expand Down
4 changes: 2 additions & 2 deletions src/sqlfmt/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def handle_set_operator(
Mostly, when we encounter a set operator (like union) we just want to add
a token with a SET_OPERATOR type. However, EXCEPT is an overloaded
keyword in some dialects (BigQuery) that support `select * except (fields)`.
In this case, except should be a TIGHT_WORD_OPERATOR
In this case, except should be a WORD_OPERATOR
"""
previous_node = analyzer.previous_node
token = Token.from_match(source_string, match, TokenType.SET_OPERATOR)
Expand All @@ -153,7 +153,7 @@ def handle_set_operator(
and prev_token.type == TokenType.STAR
):
token = Token(
type=TokenType.TIGHT_WORD_OPERATOR,
type=TokenType.WORD_OPERATOR,
prefix=token.prefix,
token=token.token,
spos=token.spos,
Expand Down
22 changes: 5 additions & 17 deletions src/sqlfmt/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,21 +274,25 @@ def __init__(self) -> None:
pattern=group(
r"all",
r"any",
r"as",
r"(not\s+)?between",
r"cube",
r"(not\s+)?exists",
r"filter",
r"grouping sets",
r"(not\s+)?in",
r"is(\s+not)?",
r"isnull",
r"(not\s+)?i?like(\s+any)?",
r"over",
r"notnull",
r"(not\s+)?regexp",
r"(not\s+)?rlike",
r"rollup",
r"some",
r"(not\s+)?similar\s+to",
r"using",
r"within\s+group",
)
+ group(r"\W", r"$"),
action=partial(
Expand All @@ -305,25 +309,9 @@ def __init__(self) -> None:
+ group(r"\s+\("),
action=partial(
actions.add_node_to_buffer,
token_type=TokenType.TIGHT_WORD_OPERATOR,
token_type=TokenType.WORD_OPERATOR,
),
),
Rule(
name="agg_modifiers",
priority=922,
pattern=group(r"over", r"within\s+group", r"filter")
+ group(r"\W", r"$"),
action=partial(
actions.add_node_to_buffer,
token_type=TokenType.TIGHT_WORD_OPERATOR,
),
),
Rule(
name="as",
priority=930,
pattern=group(r"as") + group(r"\W", r"$"),
action=partial(actions.add_node_to_buffer, token_type=TokenType.AS),
),
Rule(
name="on",
priority=940,
Expand Down
69 changes: 42 additions & 27 deletions src/sqlfmt/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def create_merged_line(self, lines: List[Line]) -> List[Line]:

return leading_blank_lines + [merged_line] + trailing_blank_lines

def safe_create_merged_line(self, lines: List[Line]) -> List[Line]:
    """
    Non-raising wrapper around create_merged_line.

    Returns whatever create_merged_line(lines) returns. If that call
    raises CannotMergeException, the exception is suppressed and the
    original `lines` list is returned unchanged instead.
    """
    try:
        result = self.create_merged_line(lines)
    except CannotMergeException:
        # merging failed; hand the caller back exactly what it passed in
        result = lines
    return result

@classmethod
def _extract_components(
cls, lines: Iterable[Line]
Expand Down Expand Up @@ -275,46 +281,55 @@ def _maybe_stubbornly_merge(self, segments: List[Segment]) -> List[Segment]:
if len(segments) <= 1:
return segments

stubborn_merge_tier = OperatorPrecedence.COMPARATORS

new_segments = [segments[0]]
for segment in segments[1:]:
prev_operator = self._segment_continues_operator_sequence(
new_segments[-1], max_precedence=stubborn_merge_tier
)

# first stubborn-merge all p0 operators
for i, segment in enumerate(segments[1:], start=1):
if (
# always stubbornly merge P0 operators (e.g., `over`)
self._segment_continues_operator_sequence(
segment, max_precedence=OperatorPrecedence.OTHER_TIGHT
)
# stubbornly merge p1 operators only if they do NOT
# follow another p1 operator AND they open brackets
# and cover multiple lines
or (
not prev_operator
and self._segment_continues_operator_sequence(
segment, max_precedence=stubborn_merge_tier
)
and segment.tail_closes_head
)
):
prev_segment = new_segments.pop()
merged_segments = self._stubbornly_merge(prev_segment, segment)
new_segments.extend(merged_segments)
new_segments = self._stubbornly_merge(new_segments, segment)
else:
new_segments.append(segment)

if len(new_segments) == 1:
return new_segments

# next, stubborn-merge qualifying p1 operators
segments = new_segments
new_segments = [segments[0]]

starts_with_p1_operator = [
self._segment_continues_operator_sequence(
segment, max_precedence=OperatorPrecedence.COMPARATORS
)
for segment in segments
]
for i, segment in enumerate(segments[1:], start=1):
if (
not starts_with_p1_operator[i - 1]
and starts_with_p1_operator[i]
and Segment(self.safe_create_merged_line(segment)).tail_closes_head
):
new_segments = self._stubbornly_merge(new_segments, segment)
else:
new_segments.append(segment)

return new_segments

def _stubbornly_merge(
self, prev_segment: Segment, segment: Segment
self, prev_segments: List[Segment], segment: Segment
) -> List[Segment]:
"""
Attempts several different methods of merging prev_segment and
segment. Returns a list of segments that represent the
best possible merger of those two segments
Attempts several different methods of merging the last segment in
prev_segments and segment. Returns a list of segments that represent the
best possible merger of those segments
"""
new_segments: List[Segment] = []
new_segments = prev_segments.copy()
prev_segment = new_segments.pop()
head, i = segment.head

# try to merge the first line of this segment with the previous segment
Expand All @@ -339,6 +354,6 @@ def _stubbornly_merge(
new_segments.append(prev_segment)
except CannotMergeException:
# give up and just return the original segments
return [prev_segment, segment]

return new_segments
new_segments.extend([prev_segment, segment])
finally:
return new_segments
2 changes: 0 additions & 2 deletions src/sqlfmt/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,6 @@ def is_operator(self) -> bool:
in (
TokenType.OPERATOR,
TokenType.WORD_OPERATOR,
TokenType.TIGHT_WORD_OPERATOR,
TokenType.AS,
TokenType.ON,
TokenType.BOOLEAN_OPERATOR,
TokenType.DOUBLE_COLON,
Expand Down
4 changes: 0 additions & 4 deletions src/sqlfmt/node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ def whitespace(
TokenType.STATEMENT_START,
TokenType.STATEMENT_END,
TokenType.WORD_OPERATOR,
TokenType.TIGHT_WORD_OPERATOR,
TokenType.BOOLEAN_OPERATOR,
TokenType.AS,
TokenType.ON,
TokenType.SEMICOLON,
):
Expand Down Expand Up @@ -225,8 +223,6 @@ def standardize_value(self, token: Token) -> str:
TokenType.STATEMENT_START,
TokenType.STATEMENT_END,
TokenType.WORD_OPERATOR,
TokenType.TIGHT_WORD_OPERATOR,
TokenType.AS,
TokenType.ON,
TokenType.BOOLEAN_OPERATOR,
TokenType.SET_OPERATOR,
Expand Down
53 changes: 33 additions & 20 deletions src/sqlfmt/operator_precedence.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def _function_lookup(
) -> Callable[[Node], "OperatorPrecedence"]:
mapping = {
TokenType.DOUBLE_COLON: lambda x: OperatorPrecedence.DOUBLE_COLON,
TokenType.AS: lambda x: OperatorPrecedence.AS,
TokenType.BRACKET_OPEN: lambda x: OperatorPrecedence.SQUARE_BRACKETS,
TokenType.TIGHT_WORD_OPERATOR: lambda x: OperatorPrecedence.OTHER_TIGHT,
TokenType.ON: lambda x: OperatorPrecedence.ON,
TokenType.STAR: lambda x: OperatorPrecedence.MULTIPLICATION,
TokenType.BOOLEAN_OPERATOR: cls._from_boolean,
Expand Down Expand Up @@ -86,26 +84,41 @@ def _from_operator(node: Node) -> "OperatorPrecedence":

@staticmethod
def _from_word_operator(node: Node) -> "OperatorPrecedence":
membership = [
r"(not\s+)?between",
r"(not\s+)?in",
r"(not\s+)?i?like(\s+any)?",
r"(not\s+)?similar\s+to",
r"(not\s+)?rlike",
r"(not\s+)?regexp",
]
membership_prog = [
re.compile(patt, re.IGNORECASE | re.DOTALL) for patt in membership
]
presence = [r"is(\s+not)?", r"isnull", r"notnull"]
presence_prog = [
re.compile(patt, re.IGNORECASE | re.DOTALL) for patt in presence
mapping = [
(OperatorPrecedence.AS, [r"as"]),
(
OperatorPrecedence.OTHER_TIGHT,
[
r"exclude",
r"replace",
r"except",
r"over",
r"within\s+group",
r"filter",
r"using",
],
),
(
OperatorPrecedence.MEMBERSHIP,
[
r"(not\s+)?between",
r"(not\s+)?in",
r"(not\s+)?i?like(\s+any)?",
r"(not\s+)?similar\s+to",
r"(not\s+)?rlike",
r"(not\s+)?regexp",
],
),
(OperatorPrecedence.PRESENCE, [r"is(\s+not)?", r"isnull", r"notnull"]),
]

if any([prog.match(node.value) for prog in membership_prog]):
return OperatorPrecedence.MEMBERSHIP
elif any([prog.match(node.value) for prog in presence_prog]):
return OperatorPrecedence.PRESENCE
for precedence, pattern_list in mapping:
programs = [
re.compile(f"{pattern}$", flags=re.IGNORECASE)
for pattern in pattern_list
]
if any([prog.match(node.value) for prog in programs]):
return precedence
else:
return OperatorPrecedence.OTHER

Expand Down
20 changes: 14 additions & 6 deletions src/sqlfmt/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,22 @@ def tail_closes_head(self) -> bool:
if len(self) <= 1:
return False

head, _ = self.head
tail, _ = self.tail
head, i = self.head
tail, j = self.tail
if head == tail:
return False
elif (
tail.closes_bracket_from_previous_line
or tail.closes_simple_jinja_block_from_previous_line
) and tail.depth == head.depth:

between_lines = self[i + 1 : -(j + 1)]
if tail.depth == head.depth and (
(
tail.closes_bracket_from_previous_line
and all([line.depth[0] > head.depth[0] for line in between_lines])
)
or (
tail.closes_simple_jinja_block_from_previous_line
and all([line.depth[1] > head.depth[1] for line in between_lines])
)
):
return True
else:
return False
Expand Down
2 changes: 0 additions & 2 deletions src/sqlfmt/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class TokenType(Enum):
COLON = auto()
OPERATOR = auto()
WORD_OPERATOR = auto()
TIGHT_WORD_OPERATOR = auto()
AS = auto()
ON = auto()
BOOLEAN_OPERATOR = auto()
COMMA = auto()
Expand Down
12 changes: 6 additions & 6 deletions src/sqlfmt_primer/primer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="gitlab",
git_url="https://github.com/tconbeer/gitlab-analytics-sqlfmt.git",
git_ref="71aee56", # sqlfmt 8b20379
git_ref="11d30d3", # sqlfmt 3e0f900
expected_changed=4,
expected_unchanged=2413,
expected_errored=0,
Expand All @@ -39,7 +39,7 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="rittman",
git_url="https://github.com/tconbeer/rittman_ra_data_warehouse.git",
git_ref="a15f185", # sqlfmt 8b20379
git_ref="5cab7e0", # sqlfmt 3e0f900
expected_changed=0,
expected_unchanged=307,
expected_errored=4, # true mismatching brackets
Expand All @@ -48,9 +48,9 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="http_archive",
git_url="https://github.com/tconbeer/http_archive_almanac.git",
git_ref="4bcd217", # sqlfmt 53d509b
expected_changed=1,
expected_unchanged=1701,
git_ref="2046b2d", # sqlfmt 3e0f900
expected_changed=0,
expected_unchanged=1702,
expected_errored=0,
sub_directory=Path("sql"),
),
Expand All @@ -75,7 +75,7 @@ def get_projects() -> List[SQLProject]:
SQLProject(
name="dbt_utils",
git_url="https://github.com/tconbeer/dbt-utils.git",
git_ref="f676241", # sqlfmt 8b20379
git_ref="55c9199", # sqlfmt 3e0f900
expected_changed=1,
expected_unchanged=130,
expected_errored=0,
Expand Down
Loading

0 comments on commit 8e4218f

Please sign in to comment.