Skip to content

Commit

Permalink
Support chained comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Oct 18, 2024
1 parent 0ce2316 commit ef239e5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 30 deletions.
61 changes: 31 additions & 30 deletions rewrite/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,41 +1198,42 @@ def __sort_call_arguments(self, call: ast.Call) -> List[Union[ast.expr, ast.keyw


def visit_Compare(self, node):
if len(node.ops) != 1:
raise NotImplementedError("Multiple comparisons are not yet supported")

prefix = self.__whitespace()
left = self.__convert(node.left)
op = self.__convert_binary_operator(node.ops[0])

if isinstance(op.element, j.Binary.Type):
return j.Binary(
random_id(),
prefix,
Markers.EMPTY,
left,
op,
self.__convert(node.comparators[0]),
self.__map_type(node)
)
else:
if op.element == py.Binary.Type.IsNot:
negation = self.__source_before('not')
elif op.element == py.Binary.Type.NotIn:
negation = self.__source_before('in')
for i in range(len(node.ops)):
op = self.__convert_binary_operator(node.ops[i])

if isinstance(op.element, j.Binary.Type):
left = j.Binary(
random_id(),
Space.EMPTY,
Markers.EMPTY,
left,
op,
self.__convert(node.comparators[i]),
self.__map_type(node)
)
else:
negation = None
if op.element == py.Binary.Type.IsNot:
negation = self.__source_before('not')
elif op.element == py.Binary.Type.NotIn:
negation = self.__source_before('in')
else:
negation = None

return py.Binary(
random_id(),
prefix,
Markers.EMPTY,
left,
op,
negation,
self.__convert(node.comparators[0]),
self.__map_type(node)
)
left = py.Binary(
random_id(),
prefix,
Markers.EMPTY,
left,
op,
negation,
self.__convert(node.comparators[0]),
self.__map_type(node)
)

return left.with_prefix(prefix)


def __convert_binary_operator(self, op) -> Union[JLeftPadded[j.Binary.Type], JLeftPadded[py.Binary.Type]]:
Expand Down
5 changes: 5 additions & 0 deletions rewrite/tests/python/all/binary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ def test_comparison_ops():
rewrite_run(python("assert 2 > 1"))
# language=python
rewrite_run(python("assert 2 >= 1"))


def test_chained_comparison():
# language=python
rewrite_run(python("assert 1 < 2 <= 3 >=0"))

0 comments on commit ef239e5

Please sign in to comment.