Skip to content

Commit

Permalink
refactor: create segment class to wrap a list of lines (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
tconbeer authored Aug 10, 2022
1 parent 28a186c commit f3a7b22
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 228 deletions.
143 changes: 32 additions & 111 deletions src/sqlfmt/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlfmt.line import Line
from sqlfmt.mode import Mode
from sqlfmt.node import Node
from sqlfmt.segment import Segment, create_segments_from_lines
from sqlfmt.token import TokenType


Expand All @@ -21,8 +22,6 @@ def create_merged_line(self, lines: List[Line]) -> List[Line]:
any of the lines violate the rules in _raise_unmergeable.
"""

# if the child is just one below the parent, we're trying to
# merge a single line.
if len(lines) <= 1:
return lines

Expand All @@ -46,7 +45,9 @@ def create_merged_line(self, lines: List[Line]) -> List[Line]:
return leading_blank_lines + [merged_line] + trailing_blank_lines

@classmethod
def _extract_components(cls, lines: List[Line]) -> Tuple[List[Node], List[Comment]]:
def _extract_components(
cls, lines: Iterable[Line]
) -> Tuple[List[Node], List[Comment]]:
"""
Given a list of lines, return 2 components:
1. list of all nodes in those lines, with only a single trailing newline
Expand Down Expand Up @@ -130,7 +131,7 @@ def maybe_merge_lines(self, lines: List[Line]) -> List[Line]:
merged_lines = []
# doesn't fit onto a single line, so split into
# segments at the depth of lines[0]
segments = self._split_into_segments(lines)
segments = create_segments_from_lines(lines)
# if a segment starts with a standalone operator,
# the first two lines of that segment should likely
# be merged before doing anything else
Expand All @@ -156,53 +157,26 @@ def maybe_merge_lines(self, lines: List[Line]) -> List[Line]:
# we need to strip that off so we only segment the
# indented lines
else:
_, i = self._get_first_nonblank_line(lines)
merged_lines.extend(lines[: i + 1])
for segment in self._get_remainder_of_segment(lines, i):
only_segment = segments[0]
_, i = only_segment.head
merged_lines.extend(only_segment[: i + 1])
for segment in only_segment.split_after(i):
merged_lines.extend(self.maybe_merge_lines(segment))
finally:
return merged_lines

@classmethod
def _tail_closes_head(cls, segment: List[Line]) -> bool:
"""
Returns True only if the last line in lines closes a bracket or
simple jinja block that is opened by the first line in lines.
"""
if len(segment) <= 1:
return False

head, _ = cls._get_first_nonblank_line(segment)
tail, _ = cls._get_first_nonblank_line(reversed(segment))
if head == tail:
return False
elif (
tail.closes_bracket_from_previous_line
or tail.closes_simple_jinja_block_from_previous_line
) and tail.depth == head.depth:
return True
else:
return False

@staticmethod
def _get_first_nonblank_line(segment: Iterable[Line]) -> Tuple[Line, int]:
for i, line in enumerate(segment):
if not line.is_blank_line:
return line, i
else:
raise SqlfmtSegmentError("All lines in the segment are empty")

def _fix_standalone_operators(self, segments: List[List[Line]]) -> List[List[Line]]:
def _fix_standalone_operators(self, segments: List[Segment]) -> List[Segment]:
"""
If the first line of a segment is a standalone operator,
we should try to merge the first two lines together before
doing anything else
"""
for segment in segments:
try:
head, i = self._get_first_nonblank_line(segment)
head, i = segment.head
if head.is_standalone_operator:
_, j = self._get_first_nonblank_line(segment[i + 1 :])
remainder_after_operator = Segment(segment[i + 1 :])
_, j = remainder_after_operator.head
try:
merged_lines = self.create_merged_line(segment[: i + j + 2])
segment[: i + j + 2] = merged_lines
Expand All @@ -214,9 +188,9 @@ def _fix_standalone_operators(self, segments: List[List[Line]]) -> List[List[Lin

def _maybe_merge_operators(
self,
segments: List[List[Line]],
segments: List[Segment],
priority: int = 2,
) -> List[List[Line]]:
) -> List[Segment]:
"""
Tries to merge runs of segments that start with operators into previous
segments. Operators have a priority that determines a sort of hierarchy;
Expand All @@ -226,7 +200,7 @@ def _maybe_merge_operators(
if len(segments) <= 1 or priority < 0:
return segments
head = 0
new_segments: List[List[Line]] = []
new_segments: List[Segment] = []

for i, segment in enumerate(segments[1:], start=1):
if not self._segment_continues_operator_sequence(segment, priority):
Expand All @@ -245,14 +219,14 @@ def _maybe_merge_operators(

@classmethod
def _segment_continues_operator_sequence(
cls, segment: List[Line], min_priority: int
cls, segment: Segment, min_priority: int
) -> bool:
"""
Returns true if the first line of the segment is part
of a sequence of operators
"""
try:
line, _ = cls._get_first_nonblank_line(segment)
line, _ = segment.head
except SqlfmtSegmentError:
# if a segment is blank, keep scanning
return True
Expand Down Expand Up @@ -281,8 +255,8 @@ def _operator_priority(token_type: TokenType) -> int:
return 0

def _try_merge_operator_segments(
self, segments: List[List[Line]], priority: int
) -> List[List[Line]]:
self, segments: List[Segment], priority: int
) -> List[Segment]:
"""
Attempts to merge segments into a single line; if that fails,
recurses at a lower operator priority
Expand All @@ -291,13 +265,15 @@ def _try_merge_operator_segments(
return segments

try:
new_segments = [self.create_merged_line(list(itertools.chain(*segments)))]
new_segments = [
Segment(self.create_merged_line(list(itertools.chain(*segments))))
]
except CannotMergeException:
new_segments = self._maybe_merge_operators(segments, priority - 1)
finally:
return new_segments

def _maybe_stubbornly_merge(self, segments: List[List[Line]]) -> List[List[Line]]:
def _maybe_stubbornly_merge(self, segments: List[Segment]) -> List[Segment]:
"""
We prefer some operators, like `as`, `over()`, `exclude()`, and
array or dictionary accessing with `[]` to be
Expand Down Expand Up @@ -329,7 +305,7 @@ def _maybe_stubbornly_merge(self, segments: List[List[Line]]) -> List[List[Line]
and self._segment_continues_operator_sequence(
segment, min_priority=1
)
and self._tail_closes_head(segment)
and segment.tail_closes_head
)
):
prev_segment = new_segments.pop()
Expand All @@ -341,24 +317,24 @@ def _maybe_stubbornly_merge(self, segments: List[List[Line]]) -> List[List[Line]
return new_segments

def _stubbornly_merge(
self, prev_segment: List[Line], segment: List[Line]
) -> List[List[Line]]:
self, prev_segment: Segment, segment: Segment
) -> List[Segment]:
"""
Attempts several different methods of merging prev_segment and
segment. Returns a list of segments that represent the
best possible merger of those two segments
"""
new_segments: List[List[Line]] = []
new_segments: List[Segment] = []
# try to merge the first line of this segment with the previous segment
head, i = self._get_first_nonblank_line(segment)
head, i = segment.head

try:
prev_segment = self.create_merged_line(prev_segment + [head])
prev_segment.extend(segment[i + 1 :])
prev_segment = Segment(self.create_merged_line(prev_segment + [head]))
prev_segment.extend(Segment(segment[i + 1 :]))
new_segments.append(prev_segment)
except CannotMergeException:
# try to add this segment to the last line of the previous segment
last_line, k = self._get_first_nonblank_line(reversed(prev_segment))
last_line, k = prev_segment.tail
try:
new_last_lines = self.create_merged_line([last_line] + segment)
prev_segment[-(k + 1) :] = new_last_lines
Expand All @@ -376,58 +352,3 @@ def _stubbornly_merge(
return [prev_segment, segment]

return new_segments

@classmethod
def _get_remainder_of_segment(
cls, segment: List[Line], idx: int
) -> List[List[Line]]:
"""
Takes a segment and an index, and returns a list of either one or two segments,
composed of the lines of segment[idx+1:], depending on whether the segment
ends with a closing bracket
"""
if cls._tail_closes_head(segment):
_, j = cls._get_first_nonblank_line(reversed(segment))
return [
# the lines between the head and tail
segment[idx + 1 : -(j + 1)],
# the tail line (and trailing whitespace)
segment[-(j + 1) :],
]
else:
return [segment[idx + 1 :]]

def _split_into_segments(self, lines: List[Line]) -> List[List[Line]]:
"""
A segment is a list of consecutive lines that are indented from the
first line.
This method takes a list of lines and returns a list of segments.
Is is basically an unfold/corecursion
"""
if not lines:
return []

target_depth = lines[0].depth
head_is_singleton_operator = lines[0].is_standalone_operator
start_idx = 2 if head_is_singleton_operator else 1
for i, line in enumerate(lines[start_idx:], start=start_idx):
# scan through the lines until we get back to the
# depth of the first line
if line.depth <= target_depth or line.depth[1] < target_depth[1]:
# if this line starts with a closing bracket,
# we want to include that closing bracket
# in the same segment as the first line.
if (
line.closes_bracket_from_previous_line
or line.closes_simple_jinja_block_from_previous_line
or line.is_blank_line
) and line.depth == target_depth:
continue
else:
return [lines[:i]] + self._split_into_segments(lines[i:])
else:
# we've exhausted lines without finding any segments, so return a
# single segment comprising the original list
return [lines]
104 changes: 104 additions & 0 deletions src/sqlfmt/segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import List, Sequence, Tuple

from sqlfmt.exception import SqlfmtSegmentError
from sqlfmt.line import Line


def create_segments_from_lines(lines: Sequence[Line]) -> List["Segment"]:
"""
A segment is a list of consecutive lines that are indented from the
first line.
This method takes a list of lines and returns a list of segments.
Is is basically an unfold/corecursion
"""
if not lines:
return []

target_depth = lines[0].depth
head_is_singleton_operator = lines[0].is_standalone_operator
start_idx = 2 if head_is_singleton_operator else 1
for i, line in enumerate(lines[start_idx:], start=start_idx):
# scan through the lines until we get back to the
# depth of the first line
if line.depth <= target_depth or line.depth[1] < target_depth[1]:
# if this line starts with a closing bracket,
# we want to include that closing bracket
# in the same segment as the first line.
if (
line.closes_bracket_from_previous_line
or line.closes_simple_jinja_block_from_previous_line
or line.is_blank_line
) and line.depth == target_depth:
continue
else:
return [Segment(lines[:i])] + create_segments_from_lines(lines[i:])
else:
# we've exhausted lines without finding any segments, so return a
# single segment comprising the original list
return [Segment(lines)]


class Segment(List[Line]):
@property
def head(self) -> Tuple[Line, int]:
"""
Returns the first nonblank line in the Segment, and the index
of that line
"""
for i, line in enumerate(self):
if not line.is_blank_line:
return line, i
else:
raise SqlfmtSegmentError("All lines in the segment are empty")

@property
def tail(self) -> Tuple[Line, int]:
"""
Returns the last nonblank line in the Segment, and the index
of that line (from the bottom. TODO: make the index more obvious)
"""
for i, line in enumerate(reversed(self)):
if not line.is_blank_line:
return line, i
else:
raise SqlfmtSegmentError("All lines in the segment are empty")

@property
def tail_closes_head(self) -> bool:
"""
Returns True only if the last line in the segment closes a bracket or
simple jinja block that is opened by the first line in the segment.
"""
if len(self) <= 1:
return False

head, _ = self.head
tail, _ = self.tail
if head == tail:
return False
elif (
tail.closes_bracket_from_previous_line
or tail.closes_simple_jinja_block_from_previous_line
) and tail.depth == head.depth:
return True
else:
return False

def split_after(self, idx: int) -> List["Segment"]:
"""
Takes an index, and returns a list of either one or two segments,
composed of the lines of self.lines[idx+1:], depending on whether the segment
ends with a closing bracket
"""
if self.tail_closes_head:
_, j = self.tail
return [
# the lines between the head and tail
Segment(self[idx + 1 : -(j + 1)]),
# the tail line (and trailing whitespace)
Segment(self[-(j + 1) :]),
]
else:
return [Segment(self[idx + 1 :])]
Loading

0 comments on commit f3a7b22

Please sign in to comment.