Skip to content

Commit

Permalink
Merge pull request #149 from arnaud-ma/better-performance
Browse files Browse the repository at this point in the history
Improve performance by:
(1) When extracting arguments from superclasses (via inheritance), don't extract arguments from the top-level Tap class.
(2) When extracting docstrings, only read source code once.
  • Loading branch information
martinjm97 authored Sep 8, 2024
2 parents c0d4b75 + 09cb610 commit 7ea9791
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 32 deletions.
10 changes: 3 additions & 7 deletions src/tap/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union
while len(super_classes) > 0:
super_class = super_classes.pop(0)

if super_class not in visited and issubclass(super_class, Tap):
if super_class not in visited and issubclass(super_class, Tap) and super_class is not Tap:
super_dictionary = extract_func(super_class)

# Update only unseen variables to avoid overriding subclass values
Expand All @@ -529,9 +529,7 @@ def _get_class_dict(self) -> Dict[str, Any]:
if not (
var.startswith("_")
or callable(val)
or isinstance(val, staticmethod)
or isinstance(val, classmethod)
or isinstance(val, property)
or isinstance(val, (staticmethod, classmethod, property))
)
}

Expand All @@ -546,9 +544,7 @@ def _get_class_variables(self) -> dict:
class_variable_names = {**self._get_annotations(), **self._get_class_dict()}.keys()

try:
class_variables = self._get_from_self_and_super(
extract_func=lambda super_class: get_class_variables(super_class)
)
class_variables = self._get_from_self_and_super(extract_func=get_class_variables)

# Handle edge-case of source code modification while code is running
variables_to_add = class_variable_names - class_variables.keys()
Expand Down
51 changes: 32 additions & 19 deletions src/tap/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
Expand Down Expand Up @@ -184,29 +185,31 @@ def is_positional_arg(*name_or_flags) -> bool:
return not is_option_arg(*name_or_flags)


def tokenize_source(obj: object) -> Generator:
"""Returns a generator for the tokens of the object's source code."""
source = inspect.getsource(obj)
token_generator = tokenize.generate_tokens(StringIO(source).readline)
return token_generator
def tokenize_source(source: str) -> Generator[tokenize.TokenInfo, None, None]:
"""Returns a generator for the tokens of the object's source code, given the source code."""
return tokenize.generate_tokens(StringIO(source).readline)


def get_class_column(obj: type) -> int:
"""Determines the column number for class variables in a class."""
def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:
"""Determines the column number for class variables in a class, given the tokens of the class."""
first_line = 1
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
if token.strip() == "@":
first_line += 1
if start_line <= first_line or token.strip() == "":
continue

return start_column
raise ValueError("Could not find any class variables in the class.")


def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, int]]]]:
"""Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code."""
def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, List[Dict[str, Union[str, int]]]]:
"""
Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code,
given the tokens of the object's source code.
"""
line_to_tokens = {}
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
line_to_tokens.setdefault(start_line, []).append({
'token_type': token_type,
'token': token,
Expand All @@ -220,13 +223,14 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
return line_to_tokens


def get_subsequent_assign_lines(cls: type) -> Set[int]:
"""For all multiline assign statements, get the line numbers after the first line of the assignment."""
# Get source code of class
source = inspect.getsource(cls)
def get_subsequent_assign_lines(source_cls: str) -> Set[int]:
"""
For all multiline assign statements, get the line numbers after the first line of the assignment,
given the source code of the object.
"""

# Parse source code using ast (with an if statement to avoid indentation errors)
source = f"if True:\n{textwrap.indent(source, ' ')}"
source = f"if True:\n{textwrap.indent(source_cls, ' ')}"
body = ast.parse(source).body[0]

# Set up warning message
Expand Down Expand Up @@ -260,6 +264,11 @@ def get_subsequent_assign_lines(cls: type) -> Set[int]:
assign_lines = set()
for node in cls_body.body:
if isinstance(node, (ast.Assign, ast.AnnAssign)):
# Check if the end line number is found
if node.end_lineno is None:
warnings.warn(parse_warning)
continue

# Get line number of assign statement excluding the first line (and minus 1 for the if statement)
assign_lines |= set(range(node.lineno, node.end_lineno))

Expand All @@ -268,15 +277,19 @@ def get_subsequent_assign_lines(cls: type) -> Set[int]:

def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
# Get the source code and tokens of the class
source_cls = inspect.getsource(cls)
tokens = tuple(tokenize_source(source_cls))

# Get mapping from line number to tokens
line_to_tokens = source_line_to_tokens(cls)
line_to_tokens = source_line_to_tokens(tokens)

# Get class variable column number
class_variable_column = get_class_column(cls)
class_variable_column = get_class_column(tokens)

# For all multiline assign statements, get the line numbers after the first line of the assignment
# This is used to avoid identifying comments in multiline assign statements
subsequent_assign_lines = get_subsequent_assign_lines(cls)
subsequent_assign_lines = get_subsequent_assign_lines(source_cls)

# Extract class variables
class_variable = None
Expand Down
20 changes: 14 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from argparse import ArgumentTypeError
import inspect
import json
import os
import subprocess
Expand All @@ -11,6 +12,7 @@
get_class_column,
get_class_variables,
GitInfo,
tokenize_source,
type_to_str,
get_literals,
TupleTypeEnforcer,
Expand Down Expand Up @@ -145,7 +147,8 @@ def test_column_simple(self):
class SimpleColumn:
arg = 2

self.assertEqual(get_class_column(SimpleColumn), 12)
tokens = tokenize_source(inspect.getsource(SimpleColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_column_comment(self):
class CommentColumn:
Expand All @@ -158,28 +161,32 @@ class CommentColumn:

arg = 2

self.assertEqual(get_class_column(CommentColumn), 12)
tokens = tokenize_source(inspect.getsource(CommentColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_column_space(self):
class SpaceColumn:

arg = 2

self.assertEqual(get_class_column(SpaceColumn), 12)
tokens = tokenize_source(inspect.getsource(SpaceColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_column_method(self):
class FuncColumn:
def func(self):
pass

self.assertEqual(get_class_column(FuncColumn), 12)
tokens = tokenize_source(inspect.getsource(FuncColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_dataclass(self):
@class_decorator
class DataclassColumn:
arg: int = 5

self.assertEqual(get_class_column(DataclassColumn), 12)
tokens = tokenize_source(inspect.getsource(DataclassColumn))
self.assertEqual(get_class_column(tokens), 12)

def test_dataclass_method(self):
def wrapper(f):
Expand All @@ -191,7 +198,8 @@ class DataclassColumn:
def func(self):
pass

self.assertEqual(get_class_column(DataclassColumn), 12)
tokens = tokenize_source(inspect.getsource(DataclassColumn))
self.assertEqual(get_class_column(tokens), 12)


class ClassVariableTests(TestCase):
Expand Down

0 comments on commit 7ea9791

Please sign in to comment.