Skip to content

Commit

Permalink
Merge branch 'main' into git-no-remote
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjm97 committed Nov 28, 2024
2 parents 7ccd355 + fe0d2d8 commit 792da8d
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 96 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']

steps:
- uses: actions/checkout@main
Expand Down
2 changes: 1 addition & 1 deletion LICENSE.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright (c) 2022 Jesse Michel and Kyle Swanson
Copyright (c) 2024 Jesse Michel and Kyle Swanson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Running `python square.py --num 2` will print `The square of your number is 4.0.

## Installation

Tap requires Python 3.8+
Tap requires Python 3.9+

To install Tap from PyPI run:

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ dependencies = [
"packaging",
"typing-inspect >= 0.7.1",
]
requires-python = ">=3.8"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Typing :: Typed",
Expand Down
18 changes: 4 additions & 14 deletions src/tap/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
TupleTypeEnforcer,
define_python_object_encoder,
as_python_object,
fix_py36_copy,
enforce_reproducibility,
PathLike,
)
Expand Down Expand Up @@ -227,7 +226,7 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
# Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them
elif get_origin(var_type) in (Tuple, tuple) and len(get_args(var_type)) > 0:
loop = False
types = get_args(var_type)
types = list(get_args(var_type))

# Handle Tuple[type, ...]
if len(types) == 2 and types[1] == Ellipsis:
Expand Down Expand Up @@ -504,7 +503,7 @@ def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> Union
while len(super_classes) > 0:
super_class = super_classes.pop(0)

if super_class not in visited and issubclass(super_class, Tap):
if super_class not in visited and issubclass(super_class, Tap) and super_class is not Tap:
super_dictionary = extract_func(super_class)

# Update only unseen variables to avoid overriding subclass values
Expand All @@ -527,13 +526,7 @@ def _get_class_dict(self) -> Dict[str, Any]:
class_dict = {
var: val
for var, val in class_dict.items()
if not (
var.startswith("_")
or callable(val)
or isinstance(val, staticmethod)
or isinstance(val, classmethod)
or isinstance(val, property)
)
if not (var.startswith("_") or callable(val) or isinstance(val, (staticmethod, classmethod, property)))
}

return class_dict
Expand All @@ -547,9 +540,7 @@ def _get_class_variables(self) -> dict:
class_variable_names = {**self._get_annotations(), **self._get_class_dict()}.keys()

try:
class_variables = self._get_from_self_and_super(
extract_func=lambda super_class: get_class_variables(super_class)
)
class_variables = self._get_from_self_and_super(extract_func=get_class_variables)

# Handle edge-case of source code modification while code is running
variables_to_add = class_variable_names - class_variables.keys()
Expand Down Expand Up @@ -717,7 +708,6 @@ def __str__(self) -> str:
"""
return pformat(self.as_dict())

@fix_py36_copy
def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType:
"""Deepcopy the Tap object."""
copied = type(self).__new__(type(self))
Expand Down
161 changes: 114 additions & 47 deletions src/tap/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from argparse import ArgumentParser, ArgumentTypeError
import ast
from base64 import b64encode, b64decode
import copy
from functools import wraps
Expand All @@ -10,20 +11,24 @@
import re
import subprocess
import sys
import textwrap
import tokenize
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
from typing_inspect import get_args as typing_inspect_get_args, get_origin as typing_inspect_get_origin
import warnings

if sys.version_info >= (3, 10):
from types import UnionType
Expand Down Expand Up @@ -162,7 +167,7 @@ def get_argument_name(*name_or_flags) -> str:
return "help"

if len(name_or_flags) > 1:
name_or_flags = [n_or_f for n_or_f in name_or_flags if n_or_f.startswith("--")]
name_or_flags = tuple(n_or_f for n_or_f in name_or_flags if n_or_f.startswith("--"))

if len(name_or_flags) != 1:
raise ValueError(f"There should only be a single canonical name for argument {name_or_flags}!")
Expand Down Expand Up @@ -201,30 +206,28 @@ def is_positional_arg(*name_or_flags) -> bool:
return not is_option_arg(*name_or_flags)


def tokenize_source(obj: object) -> Generator:
"""Returns a generator for the tokens of the object's source code."""
source = inspect.getsource(obj)
token_generator = tokenize.generate_tokens(StringIO(source).readline)
def tokenize_source(source: str) -> Generator[tokenize.TokenInfo, None, None]:
"""Returns a generator for the tokens of the object's source code, given the source code."""
return tokenize.generate_tokens(StringIO(source).readline)

return token_generator


def get_class_column(obj: type) -> int:
"""Determines the column number for class variables in a class."""
def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:
"""Determines the column number for class variables in a class, given the tokens of the class."""
first_line = 1
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
if token.strip() == "@":
first_line += 1
if start_line <= first_line or token.strip() == "":
continue

return start_column
raise ValueError("Could not find any class variables in the class.")


def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, int]]]]:
"""Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code."""
def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, List[Dict[str, Union[str, int]]]]:
"""Extract a map from each line number to list of mappings providing information about each token."""
line_to_tokens = {}
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokenize_source(obj):
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
line_to_tokens.setdefault(start_line, []).append(
{
"token_type": token_type,
Expand All @@ -240,20 +243,98 @@ def source_line_to_tokens(obj: object) -> Dict[int, List[Dict[str, Union[str, in
return line_to_tokens


def get_subsequent_assign_lines(source_cls: str) -> Tuple[Set[int], Set[int]]:
    """For all multiline assign statements, get the line numbers after the first line in the assignment.

    Line numbers in the returned sets are 1-indexed relative to ``source_cls``.

    :param source_cls: The source code of the class.
    :return: A set of intermediate line numbers for multiline assign statements and a set of final line numbers.
             Both sets are empty (and a warning is issued) if the source cannot be parsed as a single class.
    """
    # Set up warning message
    parse_warning = (
        "Could not parse class source code to extract comments. Comments in the help string may be incorrect."
    )

    # Parse source code using ast (with an if statement to avoid indentation errors)
    source = f"if True:\n{textwrap.indent(source_cls, ' ')}"
    try:
        body = ast.parse(source).body[0]
    except (SyntaxError, ValueError):
        # Unparsable source (or source containing null bytes) is handled like any other parsing problem:
        # warn and report no multiline assignments rather than propagating the error
        warnings.warn(parse_warning)
        return set(), set()

    # Check for correct parsing
    if not isinstance(body, ast.If):
        warnings.warn(parse_warning)
        return set(), set()

    # Extract if body
    if_body = body.body

    # Check for a single body
    if len(if_body) != 1:
        warnings.warn(parse_warning)
        return set(), set()

    # Extract class body
    cls_body = if_body[0]

    # Check for a single class definition
    if not isinstance(cls_body, ast.ClassDef):
        warnings.warn(parse_warning)
        return set(), set()

    # Get line numbers of assign statements
    intermediate_assign_lines = set()
    final_assign_lines = set()
    for node in cls_body.body:
        if not isinstance(node, (ast.Assign, ast.AnnAssign)):
            continue

        # Check if the end line number is found
        if node.end_lineno is None:
            warnings.warn(parse_warning)
            continue

        # Only consider multiline assign statements
        if node.end_lineno > node.lineno:
            # The wrapping "if True:" line shifts every ast line number up by one, so in class coordinates
            # the assignment spans (node.lineno - 1) through (node.end_lineno - 1).
            # The intermediate lines are those strictly between the first and the last line.
            intermediate_assign_lines.update(range(node.lineno, node.end_lineno - 1))

            # The final line of the multiline assign statement (minus 1 for the if statement)
            final_assign_lines.add(node.end_lineno - 1)

    return intermediate_assign_lines, final_assign_lines


def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
"""Returns a dictionary mapping class variables to their additional information (currently just comments)."""
# Get the source code and tokens of the class
source_cls = inspect.getsource(cls)
tokens = tuple(tokenize_source(source_cls))

# Get mapping from line number to tokens
line_to_tokens = source_line_to_tokens(cls)
line_to_tokens = source_line_to_tokens(tokens)

# Get class variable column number
class_variable_column = get_class_column(cls)
class_variable_column = get_class_column(tokens)

# For all multiline assign statements, get the line numbers after the first line of the assignment
# This is used to avoid identifying comments in multiline assign statements
intermediate_assign_lines, final_assign_lines = get_subsequent_assign_lines(source_cls)

# Extract class variables
class_variable = None
variable_to_comment = {}
for tokens in line_to_tokens.values():
for i, token in enumerate(tokens):
for line, tokens in line_to_tokens.items():
# If this is the final line of a multiline assign, extract any potential comments
if line in final_assign_lines:
# Find the comment (if it exists)
for token in tokens:
if token["token_type"] == tokenize.COMMENT:
# Leave out "#" and whitespace from comment
variable_to_comment[class_variable]["comment"] = token["token"][1:].strip()
break
continue

# Skip assign lines after the first line of multiline assign statements
if line in intermediate_assign_lines:
continue

for i, token in enumerate(tokens):
# Skip whitespace
if token["token"].strip() == "":
continue
Expand All @@ -265,8 +346,21 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
and token["token"][:1] in {'"', "'"}
):
sep = " " if variable_to_comment[class_variable]["comment"] else ""

# Identify the quote character (single or double)
quote_char = token["token"][:1]
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip(quote_char).strip()

# Identify the number of quote characters at the start of the string
num_quote_chars = len(token["token"]) - len(token["token"].lstrip(quote_char))

# Remove the number of quote characters at the start of the string and the end of the string
token["token"] = token["token"][num_quote_chars:-num_quote_chars]

# Remove the unicode escape sequences (e.g. "\"")
token["token"] = bytes(token["token"], encoding="ascii").decode("unicode-escape")

# Add the token to the comment, stripping whitespace
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip()

# Match class variable
class_variable = None
Expand All @@ -292,7 +386,7 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
return variable_to_comment


def get_literals(literal: Literal, variable: str) -> Tuple[Callable[[str], Any], List[str]]:
def get_literals(literal: Literal, variable: str) -> Tuple[Callable[[str], Any], List[type]]:
"""Extracts the values from a Literal type and ensures that the values are all primitive types."""
literals = list(get_args(literal))

Expand Down Expand Up @@ -449,33 +543,6 @@ def as_python_object(dct: Any) -> Any:
return dct


def fix_py36_copy(func: Callable) -> Callable:
    """Decorator that fixes functions using Python 3.6 deepcopy of ArgumentParsers.

    Based on https://stackoverflow.com/questions/6279305/typeerror-cannot-deepcopy-this-pattern-object

    On Python versions newer than 3.6 the function is returned unchanged. Otherwise the returned wrapper
    temporarily registers an identity deepcopy handler for compiled regex patterns while ``func`` runs.

    :param func: The function to decorate.
    :return: ``func`` itself on Python > 3.6, otherwise a wrapper around ``func``.
    """
    if sys.version_info[:2] > (3, 6):
        return func

    @wraps(func)
    def wrapper(*args, **kwargs):
        re_type = type(re.compile(""))
        has_prev_val = re_type in copy._deepcopy_dispatch
        prev_val = copy._deepcopy_dispatch.get(re_type, None)

        # Compiled patterns are returned as-is instead of being deep-copied
        copy._deepcopy_dispatch[re_type] = lambda r, _: r

        try:
            return func(*args, **kwargs)
        finally:
            # Always restore the previous dispatch state, even if func raises,
            # so the temporary handler never leaks into unrelated deepcopy calls
            if has_prev_val:
                copy._deepcopy_dispatch[re_type] = prev_val
            else:
                copy._deepcopy_dispatch.pop(re_type, None)

    return wrapper


def enforce_reproducibility(
saved_reproducibility_data: Optional[Dict[str, str]], current_reproducibility_data: Dict[str, str], path: PathLike
) -> None:
Expand Down Expand Up @@ -512,7 +579,7 @@ def enforce_reproducibility(
raise ValueError(f"{no_reproducibility_message}: Uncommitted changes " f"in current args.")


# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8, 3.9, and 3.10
# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9 and 3.10
# https://github.com/ilevkivskyi/typing_inspect/issues/64
# https://github.com/ilevkivskyi/typing_inspect/issues/65
def get_origin(tp: Any) -> Any:
Expand Down
3 changes: 0 additions & 3 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def configure(self):
# tried redirecting stderr using unittest.mock.patch
# VersionTap().parse_args(['--version'])

@unittest.skipIf(sys.version_info < (3, 8), 'action="extend" introduced in argparse in Python 3.8')
def test_actions_extend(self):
class ExtendTap(Tap):
arg = [1, 2]
Expand All @@ -185,7 +184,6 @@ def configure(self):
args = ExtendTap().parse_args("--arg a b --arg a --arg c d".split())
self.assertEqual(args.arg, [1, 2] + "a b a c d".split())

@unittest.skipIf(sys.version_info < (3, 8), 'action="extend" introduced in argparse in Python 3.8')
def test_actions_extend_list(self):
class ExtendListTap(Tap):
arg: List = ["hi"]
Expand All @@ -196,7 +194,6 @@ def configure(self):
args = ExtendListTap().parse_args("--arg yo yo --arg yoyo --arg yo yo".split())
self.assertEqual(args.arg, "hi yo yo yoyo yo yo".split())

@unittest.skipIf(sys.version_info < (3, 8), 'action="extend" introduced in argparse in Python 3.8')
def test_actions_extend_list_int(self):
class ExtendListIntTap(Tap):
arg: List[int] = [0]
Expand Down
Loading

0 comments on commit 792da8d

Please sign in to comment.