-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathsort.py
120 lines (93 loc) · 3.22 KB
/
sort.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Test Runner for Python.
"""
from ast import (
    NodeVisitor,
    ClassDef,
    FunctionDef,
    AsyncFunctionDef,
    parse,
    For,
    AsyncFor,
    While,
    With,
    AsyncWith,
    If,
    Try,
    Attribute,
    Name,
)
from pathlib import Path
from typing import Dict, overload

from .data import Hierarchy, TestInfo
# pylint: disable=invalid-name, no-self-use
class TestOrder(NodeVisitor):
    """
    Visits test_* methods in a file and caches their definition order.

    For every ``test_*`` function found inside a class whose bases literally
    include ``unittest.TestCase``, a ``TestInfo`` is stored in the class-level
    ``_cache`` under the ``::``-joined hierarchy ``root::ClassName::test_name``.
    """
    # Deliberately shared by all instances: the ``lineno`` and
    # ``function_source`` class methods read from it, and a source file is only
    # parsed when the requested test is not cached yet.
    _cache: Dict[Hierarchy, TestInfo] = {}

    def __init__(self, root: Hierarchy) -> None:
        """
        :param root: First component of every hierarchy recorded by this
            visitor (the part of a test id before the first ``::``).
        """
        super().__init__()
        # Stack of enclosing names; a class name is pushed while its body is
        # being visited.
        self._hierarchy = [root]

    @staticmethod
    def _dotted_name(node) -> str:
        """
        Renders a base-class expression such as ``unittest.TestCase`` (or a
        plain ``TestCase``) as a dotted string.

        Returns ``""`` for expressions that are not plain or dotted names,
        e.g. calls or subscripts, instead of raising AttributeError.
        """
        parts = []
        while isinstance(node, Attribute):
            parts.append(node.attr)
            node = node.value
        if not isinstance(node, Name):
            return ""
        parts.append(node.id)
        return ".".join(reversed(parts))

    def visit_ClassDef(self, node: ClassDef) -> None:
        """
        Handles class definitions.

        Only classes listing ``unittest.TestCase`` among their bases are
        descended into; their name is pushed onto the hierarchy while their
        body is visited. Any other base form (``TestCase``, ``object``,
        ``a.b.C``, calls, ...) is tolerated rather than crashing the visit.
        """
        bases = {self._dotted_name(base) for base in node.bases}
        if "unittest.TestCase" in bases:
            self._hierarchy.append(Hierarchy(node.name))
            self.generic_visit(node)
            self._hierarchy.pop()

    @staticmethod
    def _last_statement(stmt):
        """
        Returns the textually last simple statement nested inside ``stmt``.

        Compound statements are descended into; ``with`` is included so that
        subtests written as ``with self.subTest(...)`` blocks are covered.
        Blocks that come after the main body in the source take precedence:
        a non-empty ``else`` for if/for/while, and ``finally`` / ``else`` /
        the last ``except`` handler for try statements.
        """
        while True:
            if isinstance(stmt, Try):
                # A try without a finally/else always has at least one handler.
                block = stmt.finalbody or stmt.orelse or stmt.handlers[-1].body
            elif isinstance(stmt, (For, AsyncFor, While, If)):
                block = stmt.orelse or stmt.body
            elif isinstance(stmt, (With, AsyncWith)):
                block = stmt.body
            else:
                return stmt
            stmt = block[-1]

    @overload
    def _visit_definition(self, node: FunctionDef) -> None:
        ...

    @overload
    def _visit_definition(self, node: AsyncFunctionDef) -> None:
        ...

    def _visit_definition(self, node):
        """
        Records a ``TestInfo`` for ``test_*`` functions, then keeps walking
        into the function body.

        The recorded end line is the start line of the last leaf statement
        of the function.
        """
        if node.name.startswith("test_"):
            last_body = self._last_statement(node.body[-1])
            testinfo = TestInfo(node.lineno, last_body.lineno, 1)
            self._cache[self.get_hierarchy(Hierarchy(node.name))] = testinfo
        self.generic_visit(node)

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        """
        Handles test definitions
        """
        self._visit_definition(node)

    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
        """
        Handles async test definitions
        """
        self._visit_definition(node)

    def get_hierarchy(self, name: Hierarchy) -> Hierarchy:
        """
        Returns the current hierarchy with ``name`` appended, joined by ``::``.
        """
        return Hierarchy("::".join(self._hierarchy + [name]))

    @classmethod
    def lineno(cls, test_id: Hierarchy, source: Path) -> int:
        """
        Returns the line that the given test was defined on.

        :param test_id: ``::``-joined hierarchy of the test.
        :param source: Path of the file the test lives in; it is parsed only
            when ``test_id`` is not cached yet.
        :raises KeyError: if the test is still unknown after parsing.
        """
        if test_id not in cls._cache:
            tree = parse(source.read_text(), source.name)
            cls(Hierarchy(test_id.split("::")[0])).visit(tree)
        return cls._cache[test_id].lineno

    @classmethod
    def function_source(cls, test_id: Hierarchy, source: Path) -> str:
        """
        :param test_id: Hierarchy position of test in AST
        :param source: Path of source code file
        :return: str of the source code of the given test, dedented.
        :raises KeyError: if the test is still unknown after parsing.
        """
        text = source.read_text()
        # Populate the cache *before* reading from it; the lookup used to come
        # first, so tests not already seen via ``lineno`` raised KeyError.
        if test_id not in cls._cache:
            tree = parse(text, source.name)
            cls(Hierarchy(test_id.split("::")[0])).visit(tree)
        testinfo = cls._cache[test_id]
        # AST line numbers are 1-based but list indices are 0-based, so this
        # slice starts on the line after the ``def`` and runs one line past the
        # start of the last statement (presumably to catch a statement that
        # continues onto a second line); a blank trailing line is dropped.
        lines = text.splitlines()[testinfo.lineno: testinfo.end_lineno + 1]
        if lines and not lines[-1]:
            lines.pop()
        # Dedents source one column at a time. ``any(lines)`` stops the loop
        # once no non-empty line remains; without it an empty/all-blank slice
        # loops forever because ``all()`` of an empty sequence is True.
        while any(lines) and all(line.startswith(' ') for line in lines if line):
            lines = [line[1:] if line else line for line in lines]
        return '\n'.join(lines)