Skip to content

Commit

Permalink
Improve variable lookup to ignore exclusive statements
Browse files Browse the repository at this point in the history
  • Loading branch information
david-yz-liu authored and Pierre-Sassoulas committed Aug 2, 2021
1 parent ac95965 commit 8434159
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 3 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ Release date: TBA

* Added support to infer return type of ``typing.cast()``

* Fix variable lookup's handling of exclusive statements

Closes PyCQA/pylint#3711


What's New in astroid 2.6.5?
============================
Expand Down
15 changes: 12 additions & 3 deletions astroid/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,17 +1211,26 @@ def _filter_stmts(self, stmts, frame, offset):
if not (optional_assign or are_exclusive(_stmts[pindex], node)):
del _stmt_parents[pindex]
del _stmts[pindex]

# If self and node are exclusive, then we can ignore node
if are_exclusive(self, node):
continue

if isinstance(node, AssignName):
# Remove all previously stored assignments if:
# 1. node's statement always assigns
# 2. node has the same parent as self (i.e., they're in the same block)
if not optional_assign and stmt.parent is mystmt.parent:
_stmts = []
_stmt_parents = []
elif isinstance(node, DelName):
# Remove all previously stored assignments
_stmts = []
_stmt_parents = []
continue
if not are_exclusive(self, node):
_stmts.append(node)
_stmt_parents.append(stmt.parent)
# Add the new assignment
_stmts.append(node)
_stmt_parents.append(stmt.parent)
return _stmts


Expand Down
270 changes: 270 additions & 0 deletions tests/unittest_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,5 +474,275 @@ def run1():
self.assertEqual(len(stmts), 0)


class LookupControlFlowTest(unittest.TestCase):
"""Tests for lookup capabilities and control flow"""

def test_consecutive_assign(self):
"""When multiple assignment statements are in the same block, only the last one
is returned.
"""
code = """
x = 10
x = 100
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 3)

def test_assign_after_use(self):
"""An assignment statement appearing after the variable is not returned."""
code = """
print(x)
x = 10
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 0)

def test_del_removes_prior(self):
"""Delete statement removes any prior assignments"""
code = """
x = 10
del x
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 0)

def test_del_no_effect_after(self):
"""Delete statement doesn't remove future assignments"""
code = """
x = 10
del x
x = 100
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 4)

def test_if_assign(self):
"""Assignment in if statement is added to lookup results, but does not replace
prior assignments.
"""
code = """
def f(b):
x = 10
if b:
x = 100
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 2)
self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 5])

def test_if_assigns_same_branch(self):
"""When if branch has multiple assignment statements, only the last one
is added.
"""
code = """
def f(b):
x = 10
if b:
x = 100
x = 1000
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 2)
self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 6])

def test_if_assigns_different_branch(self):
"""When different branches have assignment statements, the last one
in each branch is added.
"""
code = """
def f(b):
x = 10
if b == 1:
x = 100
x = 1000
elif b == 2:
x = 3
elif b == 3:
x = 4
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 4)
self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 6, 8, 10])

def test_assign_exclusive(self):
"""When the variable appears inside a branch of an if statement,
no assignment statements from other branches are returned.
"""
code = """
def f(b):
x = 10
if b == 1:
x = 100
x = 1000
elif b == 2:
x = 3
elif b == 3:
x = 4
else:
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 3)

def test_assign_not_exclusive(self):
"""When the variable appears inside a branch of an if statement,
only the last assignment statement in the same branch is returned.
"""
code = """
def f(b):
x = 10
if b == 1:
x = 100
x = 1000
elif b == 2:
x = 3
elif b == 3:
x = 4
print(x)
else:
x = 5
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 10)

def test_if_else(self):
"""When an assignment statement appears in both an if and else branch, both
are added. This does NOT replace an assignment statement appearing before the
if statement. (See issue #213)
"""
code = """
def f(b):
x = 10
if b:
x = 100
else:
x = 1000
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 3)
self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 5, 7])

def test_if_variable_in_condition_1(self):
"""Test lookup works correctly when a variable appears in an if condition."""
code = """
x = 10
if x > 10:
print('a')
elif x > 0:
print('b')
"""
astroid = builder.parse(code)
x_name1, x_name2 = (
n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"
)

_, stmts1 = x_name1.lookup("x")
self.assertEqual(len(stmts1), 1)
self.assertEqual(stmts1[0].lineno, 2)

_, stmts2 = x_name2.lookup("x")
self.assertEqual(len(stmts2), 1)
self.assertEqual(stmts2[0].lineno, 2)

def test_if_variable_in_condition_2(self):
"""Test lookup works correctly when a variable appears in an if condition,
and the variable is reassigned in each branch.
This is based on PyCQA/pylint issue #3711.
"""
code = """
x = 10
if x > 10:
x = 100
elif x > 0:
x = 200
elif x > -10:
x = 300
else:
x = 400
"""
astroid = builder.parse(code)
x_names = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"]

# All lookups should refer only to the initial x = 10.
for x_name in x_names:
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 2)

def test_del_not_exclusive(self):
"""A delete statement in an if statement branch removes all previous
assignment statements when the delete statement is not exclusive with
the variable (e.g., when the variable is used below the if statement).
"""
code = """
def f(b):
x = 10
if b == 1:
x = 100
elif b == 2:
del x
elif b == 3:
x = 4 # Only this assignment statement is returned
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 9)

def test_del_exclusive(self):
"""A delete statement in an if statement branch that is exclusive with the
variable does not remove previous assignment statements.
"""
code = """
def f(b):
x = 10
if b == 1:
x = 100
elif b == 2:
del x
else:
print(x)
"""
astroid = builder.parse(code)
x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0]
_, stmts = x_name.lookup("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 3)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8434159

Please sign in to comment.