Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[scope] keep track of assignment/access ordering #413

Merged
merged 3 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 75 additions & 9 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@
)


_ASSIGNMENT_LIKE_NODES = (
cst.AnnAssign,
cst.AsName,
cst.Assign,
cst.AugAssign,
cst.ClassDef,
cst.CompFor,
cst.For,
cst.FunctionDef,
cst.Global,
cst.Import,
cst.ImportFrom,
cst.NamedExpr,
cst.Nonlocal,
cst.Parameters,
cst.WithItem,
)


@add_slots
@dataclass(frozen=False)
class Access:
Expand Down Expand Up @@ -68,6 +87,7 @@ def __new__(cls) -> "Tree":
is_type_hint: bool

__assignments: Set["BaseAssignment"]
__index: int

def __init__(
self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool
Expand All @@ -77,6 +97,7 @@ def __init__(
self.is_annotation = is_annotation
self.is_type_hint = is_type_hint
self.__assignments = set()
self.__index = scope._assignment_count

def __hash__(self) -> int:
return id(self)
Expand All @@ -86,11 +107,25 @@ def referents(self) -> Collection["BaseAssignment"]:
"""Return all assignments of the access."""
return self.__assignments

def record_assignment(self, assignment: "BaseAssignment") -> None:
self.__assignments.add(assignment)
@property
def _index(self) -> int:
return self.__index

def record_assignments(self, assignments: Set["BaseAssignment"]) -> None:
self.__assignments |= assignments
def record_assignment(self, assignment: "BaseAssignment") -> None:
if assignment.scope != self.scope or assignment._index < self.__index:
self.__assignments.add(assignment)

def record_assignments(self, name: str) -> None:
assignments = self.scope[name]
# filter out assignments that happened later than this access
previous_assignments = {
assignment
for assignment in assignments
if assignment.scope != self.scope or assignment._index < self.__index
}
if not previous_assignments and assignments:
previous_assignments = self.scope.parent[name]
self.__assignments |= previous_assignments


class BaseAssignment(abc.ABC):
Expand All @@ -109,10 +144,22 @@ def __init__(self, name: str, scope: "Scope") -> None:
self.__accesses = set()

def record_access(self, access: Access) -> None:
self.__accesses.add(access)
if access.scope != self.scope or self._index < access._index:
self.__accesses.add(access)

def record_accesses(self, accesses: Set[Access]) -> None:
self.__accesses |= accesses
later_accesses = {
access
for access in accesses
if access.scope != self.scope or self._index < access._index
}
self.__accesses |= later_accesses
earlier_accesses = accesses - later_accesses
if earlier_accesses and self.scope.parent != self.scope:
# Accesses "earlier" than the relevant assignment should be attached
# to assignments of the same name in the parent
for shadowed_assignment in self.scope.parent[self.name]:
shadowed_assignment.record_accesses(earlier_accesses)

@property
def references(self) -> Collection[Access]:
Expand All @@ -123,18 +170,31 @@ def references(self) -> Collection[Access]:
def __hash__(self) -> int:
return id(self)

@property
def _index(self) -> int:
"""Return an integer that represents the order of assignments in `scope`"""
return -1


class Assignment(BaseAssignment):
"""An assignment records the name, CSTNode and its accesses."""

#: The node of assignment, it could be a :class:`~libcst.Import`, :class:`~libcst.ImportFrom`,
#: :class:`~libcst.Name`, :class:`~libcst.FunctionDef`, or :class:`~libcst.ClassDef`.
node: cst.CSTNode
__index: int

def __init__(self, name: str, scope: "Scope", node: cst.CSTNode) -> None:
def __init__(
self, name: str, scope: "Scope", node: cst.CSTNode, index: int
) -> None:
self.node = node
self.__index = index
super().__init__(name, scope)

@property
def _index(self) -> int:
return self.__index


# even though we don't override the constructor.
class BuiltinAssignment(BaseAssignment):
Expand Down Expand Up @@ -318,16 +378,20 @@ class Scope(abc.ABC):
globals: "GlobalScope"
_assignments: MutableMapping[str, Set[BaseAssignment]]
_accesses: MutableMapping[str, Set[Access]]
_assignment_count: int

def __init__(self, parent: "Scope") -> None:
super().__init__()
self.parent = parent
self.globals = parent.globals
self._assignments = defaultdict(set)
self._accesses = defaultdict(set)
self._assignment_count = 0

def record_assignment(self, name: str, node: cst.CSTNode) -> None:
self._assignments[name].add(Assignment(name=name, scope=self, node=node))
self._assignments[name].add(
Assignment(name=name, scope=self, node=node, index=self._assignment_count)
)

def record_access(self, name: str, access: Access) -> None:
self._accesses[name].add(access)
Expand Down Expand Up @@ -932,7 +996,7 @@ def infer_accesses(self) -> None:
break

scope_name_accesses[(access.scope, name)].add(access)
access.record_assignments(access.scope[name])
access.record_assignments(name)
access.scope.record_access(name, access)

for (scope, name), accesses in scope_name_accesses.items():
Expand All @@ -943,6 +1007,8 @@ def infer_accesses(self) -> None:

def on_leave(self, original_node: cst.CSTNode) -> None:
self.provider.set_metadata(original_node, self.scope)
if isinstance(original_node, _ASSIGNMENT_LIKE_NODES):
self.scope._assignment_count += 1
super().on_leave(original_node)


Expand Down
102 changes: 102 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,3 +1310,105 @@ def test_gen_dotted_names(self) -> None:
)
}
self.assertEqual(names, {"a.b.c", "a.b", "a"})

def test_ordering(self) -> None:
zsol marked this conversation as resolved.
Show resolved Hide resolved
m, scopes = get_scope_metadata_provider(
"""
from a import b
class X:
x = b
b = b
y = b
"""
)
global_scope = scopes[m]
import_stmt = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
first_assignment = list(global_scope.assignments)[0]
assert isinstance(first_assignment, cst.metadata.Assignment)
self.assertEqual(first_assignment.node, import_stmt)
global_refs = list(first_assignment.references)
self.assertEqual(len(global_refs), 2)
class_def = ensure_type(m.body[1], cst.ClassDef)
x = ensure_type(
ensure_type(class_def.body.body[0], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertEqual(x.value, global_refs[0].node)
class_b = ensure_type(
ensure_type(class_def.body.body[1], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertEqual(class_b.value, global_refs[1].node)

class_accesses = list(scopes[x].accesses)
self.assertEqual(len(class_accesses), 3)
self.assertIn(
class_b.targets[0].target,
[
ref.node
for acc in class_accesses
for ref in acc.referents
if isinstance(ref, Assignment)
],
)
y = ensure_type(
ensure_type(class_def.body.body[2], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertIn(y.value, [access.node for access in class_accesses])

def test_ordering_between_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f(a):
print(a)
print(b)
a = 1
b = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
first_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
second_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
self.assertEqual(
first_print.args[0].value,
a_param_refs[0].node,
)
a_global = (
cst.ensure_type(
cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
a_global_assignment = list(scopes[a_global]["a"])[0]
a_global_refs = list(a_global_assignment.references)
self.assertEqual(a_global_refs, [])
b_global = (
cst.ensure_type(
cst.ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
b_global_assignment = list(scopes[b_global]["b"])[0]
b_global_refs = list(b_global_assignment.references)
self.assertEqual(len(b_global_refs), 1)
self.assertEqual(b_global_refs[0].node, second_print.args[0].value)