Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improve performance #149

Merged
merged 7 commits
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/tap/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union
while len(super_classes) > 0:
super_class = super_classes.pop(0)

if super_class not in visited and issubclass(super_class, Tap):
if super_class not in visited and issubclass(super_class, Tap) and super_class is not Tap:
super_dictionary = extract_func(super_class)

# Update only unseen variables to avoid overriding subclass values
Expand All @@ -529,9 +529,7 @@ def _get_class_dict(self) -> Dict[str, Any]:
if not (
var.startswith("_")
or callable(val)
or isinstance(val, staticmethod)
or isinstance(val, classmethod)
or isinstance(val, property)
or isinstance(val, (staticmethod, classmethod, property))
)
}

Expand All @@ -546,9 +544,7 @@ def _get_class_variables(self) -> dict:
class_variable_names = {**self._get_annotations(), **self._get_class_dict()}.keys()

try:
class_variables = self._get_from_self_and_super(
extract_func=lambda super_class: get_class_variables(super_class)
)
class_variables = self._get_from_self_and_super(extract_func=get_class_variables)

# Handle edge-case of source code modification while code is running
variables_to_add = class_variable_names - class_variables.keys()
Expand Down
51 changes: 32 additions & 19 deletions src/tap/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
Expand Down Expand Up @@ -184,29 +185,31 @@ def is_positional_arg(*name_or_flags) -> bool:
return not is_option_arg(*name_or_flags)


def tokenize_source(obj: object) -> Generator:
"""Returns a generator for the tokens of the object's source code."""
source = inspect.getsource(obj)
token_generator = tokenize.generate_tokens(StringIO(source).readline)
return token_generator
def tokenize_source(source: str) -> Generator[tokenize.TokenInfo, None, None]:
"""Returns a generator for the tokens of the object's source code, given the source code."""
return tokenize.generate_tokens(StringIO(source).readline)


def get_class_column(obj: type) -> int:
"""Determines the column number for class variables in a class."""
def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:
"""Determines the column number for class variables in a class, given the tokens of the class."""
first_line = 1
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
if token.strip() == "@":
first_line += 1
if start_line <= first_line or token.strip() == "":
continue

return start_column
raise ValueError("Could not find any class variables in the class.")


def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, int]]]]:
"""Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code."""
def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, List[Dict[str, Union[str, int]]]]:
"""
Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code,
given the tokens of the object's source code.
"""
line_to_tokens = {}
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
line_to_tokens.setdefault(start_line, []).append({
'token_type': token_type,
'token': token,
Expand All @@ -220,13 +223,14 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
return line_to_tokens


def get_subsequent_assign_lines(cls: type) -> Set[int]:
"""For all multiline assign statements, get the line numbers after the first line of the assignment."""
# Get source code of class
source = inspect.getsource(cls)
def get_subsequent_assign_lines(source_cls: str) -> Set[int]:
"""
For all multiline assign statements, get the line numbers after the first line of the assignment,
given the source code of the object.
"""

# Parse source code using ast (with an if statement to avoid indentation errors)
source = f"if True:\n{textwrap.indent(source, ' ')}"
source = f"if True:\n{textwrap.indent(source_cls, ' ')}"
body = ast.parse(source).body[0]

# Set up warning message
Expand Down Expand Up @@ -260,6 +264,11 @@ def get_subsequent_assign_lines(cls: type) -> Set[int]:
assign_lines = set()
for node in cls_body.body:
if isinstance(node, (ast.Assign, ast.AnnAssign)):
# Check if the end line number is found
if node.end_lineno is None:
warnings.warn(parse_warning)
continue

# Get line number of assign statement excluding the first line (and minus 1 for the if statement)
assign_lines |= set(range(node.lineno, node.end_lineno))

Expand All @@ -268,15 +277,19 @@ def get_subsequent_assign_lines(cls: type) -> Set[int]:

def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
# Get the source code and tokens of the class
source_cls = inspect.getsource(cls)
tokens = tuple(tokenize_source(source_cls))

# Get mapping from line number to tokens
line_to_tokens = source_line_to_tokens(cls)
line_to_tokens = source_line_to_tokens(tokens)

# Get class variable column number
class_variable_column = get_class_column(cls)
class_variable_column = get_class_column(tokens)

# For all multiline assign statements, get the line numbers after the first line of the assignment
# This is used to avoid identifying comments in multiline assign statements
subsequent_assign_lines = get_subsequent_assign_lines(cls)
subsequent_assign_lines = get_subsequent_assign_lines(source_cls)

# Extract class variables
class_variable = None
Expand Down
20 changes: 14 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from argparse import ArgumentTypeError
import inspect
import json
import os
import subprocess
Expand All @@ -11,6 +12,7 @@
get_class_column,
get_class_variables,
GitInfo,
tokenize_source,
type_to_str,
get_literals,
TupleTypeEnforcer,
Expand Down Expand Up @@ -145,7 +147,8 @@ def test_column_simple(self):
class SimpleColumn:
arg = 2

self.assertEqual(get_class_column(SimpleColumn), 12)
tokens = tokenize_source(inspect.getsource(SimpleColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_column_comment(self):
class CommentColumn:
Expand All @@ -158,28 +161,32 @@ class CommentColumn:

arg = 2

self.assertEqual(get_class_column(CommentColumn), 12)
tokens = tokenize_source(inspect.getsource(CommentColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_column_space(self):
class SpaceColumn:

arg = 2

self.assertEqual(get_class_column(SpaceColumn), 12)
tokens = tokenize_source(inspect.getsource(SpaceColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_column_method(self):
class FuncColumn:
def func(self):
pass

self.assertEqual(get_class_column(FuncColumn), 12)
tokens = tokenize_source(inspect.getsource(FuncColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_dataclass(self):
@class_decorator
class DataclassColumn:
arg: int = 5

self.assertEqual(get_class_column(DataclassColumn), 12)
tokens = tokenize_source(inspect.getsource(DataclassColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_dataclass_method(self):
def wrapper(f):
Expand All @@ -191,7 +198,8 @@ class DataclassColumn:
def func(self):
pass

self.assertEqual(get_class_column(DataclassColumn), 12)
tokens = tokenize_source(inspect.getsource(DataclassColumn))
self.assertEqual(get_class_column(tokens), 12)


class ClassVariableTests(TestCase):
Expand Down
Loading