Skip to content

Commit

Permalink
Improve documentation, typing compliance, and removed deprecated code
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjm97 committed Sep 8, 2024
1 parent 7ea9791 commit b78111e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 62 deletions.
16 changes: 5 additions & 11 deletions src/tap/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
TupleTypeEnforcer,
define_python_object_encoder,
as_python_object,
fix_py36_copy,
enforce_reproducibility,
PathLike,
)
Expand Down Expand Up @@ -227,7 +226,7 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
# Handle Tuple type (with type args) by extracting types of Tuple elements and enforcing them
elif get_origin(var_type) in (Tuple, tuple) and len(get_args(var_type)) > 0:
loop = False
types = get_args(var_type)
types = list(get_args(var_type))

# Handle Tuple[type, ...]
if len(types) == 2 and types[1] == Ellipsis:
Expand Down Expand Up @@ -331,7 +330,7 @@ def add_subparser(self, flag: str, subparser_type: type, **kwargs) -> None:
self._subparser_buffer.append((flag, subparser_type, kwargs))

def _add_subparsers(self) -> None:
"""Add each of the subparsers to the Tap object. """
"""Add each of the subparsers to the Tap object."""
# Initialize the _subparsers object if not already created
if self._subparsers is None and len(self._subparser_buffer) > 0:
self._subparsers = super(Tap, self).add_subparsers()
Expand All @@ -345,7 +344,7 @@ def add_subparsers(self, **kwargs) -> None:
self._subparsers = super().add_subparsers(**kwargs)

def _configure(self) -> None:
"""Executes the user-defined configuration. """
"""Executes the user-defined configuration."""
# Call the user-defined configuration
self.configure()

Expand Down Expand Up @@ -526,11 +525,7 @@ def _get_class_dict(self) -> Dict[str, Any]:
class_dict = {
var: val
for var, val in class_dict.items()
if not (
var.startswith("_")
or callable(val)
or isinstance(val, (staticmethod, classmethod, property))
)
if not (var.startswith("_") or callable(val) or isinstance(val, (staticmethod, classmethod, property)))
}

return class_dict
Expand Down Expand Up @@ -712,7 +707,6 @@ def __str__(self) -> str:
"""
return pformat(self.as_dict())

@fix_py36_copy
def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType:
"""Deepcopy the Tap object."""
copied = type(self).__new__(type(self))
Expand All @@ -722,7 +716,7 @@ def __deepcopy__(self, memo: Dict[int, Any] = None) -> TapType:

memo[id(self)] = copied

for (k, v) in self.__dict__.items():
for k, v in self.__dict__.items():
copied.__dict__[k] = deepcopy(v, memo)

return copied
Expand Down
69 changes: 18 additions & 51 deletions src/tap/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_argument_name(*name_or_flags) -> str:
return "help"

if len(name_or_flags) > 1:
name_or_flags = [n_or_f for n_or_f in name_or_flags if n_or_f.startswith("--")]
name_or_flags = tuple(n_or_f for n_or_f in name_or_flags if n_or_f.startswith("--"))

if len(name_or_flags) != 1:
raise ValueError(f"There should only be a single canonical name for argument {name_or_flags}!")
Expand Down Expand Up @@ -204,39 +204,33 @@ def get_class_column(tokens: Iterable[tokenize.TokenInfo]) -> int:


def source_line_to_tokens(tokens: Iterable[tokenize.TokenInfo]) -> Dict[int, List[Dict[str, Union[str, int]]]]:
"""
Gets a dictionary mapping from line number to a dictionary of tokens on that line for an object's source code,
given the tokens of the object's source code.
"""
"""Extract a map from each line number to list of mappings providing information about each token."""
line_to_tokens = {}
for token_type, token, (start_line, start_column), (end_line, end_column), line in tokens:
line_to_tokens.setdefault(start_line, []).append({
'token_type': token_type,
'token': token,
'start_line': start_line,
'start_column': start_column,
'end_line': end_line,
'end_column': end_column,
'line': line
})
line_to_tokens.setdefault(start_line, []).append(
{
"token_type": token_type,
"token": token,
"start_line": start_line,
"start_column": start_column,
"end_line": end_line,
"end_column": end_column,
"line": line,
}
)

return line_to_tokens


def get_subsequent_assign_lines(source_cls: str) -> Set[int]:
"""
For all multiline assign statements, get the line numbers after the first line of the assignment,
given the source code of the object.
"""

"""For all multiline assign statements, get the line numbers after the first line in the assignment."""
# Parse source code using ast (with an if statement to avoid indentation errors)
source = f"if True:\n{textwrap.indent(source_cls, ' ')}"
body = ast.parse(source).body[0]

# Set up warning message
parse_warning = (
"Could not parse class source code to extract comments. "
"Comments in the help string may be incorrect."
"Could not parse class source code to extract comments. Comments in the help string may be incorrect."
)

# Check for correct parsing
Expand Down Expand Up @@ -322,7 +316,7 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
token["token"] = token["token"][num_quote_chars:-num_quote_chars]

# Remove the unicode escape sequences (e.g. "\"")
token["token"] = bytes(token["token"], encoding='ascii').decode('unicode-escape')
token["token"] = bytes(token["token"], encoding="ascii").decode("unicode-escape")

# Add the token to the comment, stripping whitespace
variable_to_comment[class_variable]["comment"] += sep + token["token"].strip()
Expand Down Expand Up @@ -351,7 +345,7 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]:
return variable_to_comment


def get_literals(literal: Literal, variable: str) -> Tuple[Callable[[str], Any], List[str]]:
def get_literals(literal: Literal, variable: str) -> Tuple[Callable[[str], Any], List[type]]:
"""Extracts the values from a Literal type and ensures that the values are all primitive types."""
literals = list(get_args(literal))

Expand Down Expand Up @@ -476,7 +470,7 @@ def default(self, obj: Any) -> Any:


class UnpicklableObject:
"""A class that serves as a placeholder for an object that could not be pickled. """
"""A class that serves as a placeholder for an object that could not be pickled."""

def __eq__(self, other):
return isinstance(other, UnpicklableObject)
Expand Down Expand Up @@ -508,33 +502,6 @@ def as_python_object(dct: Any) -> Any:
return dct


def fix_py36_copy(func: Callable) -> Callable:
"""Decorator that fixes functions using Python 3.6 deepcopy of ArgumentParsers.
Based on https://stackoverflow.com/questions/6279305/typeerror-cannot-deepcopy-this-pattern-object
"""
if sys.version_info[:2] > (3, 6):
return func

@wraps(func)
def wrapper(*args, **kwargs):
re_type = type(re.compile(""))
has_prev_val = re_type in copy._deepcopy_dispatch
prev_val = copy._deepcopy_dispatch.get(re_type, None)
copy._deepcopy_dispatch[type(re.compile(""))] = lambda r, _: r

result = func(*args, **kwargs)

if has_prev_val:
copy._deepcopy_dispatch[re_type] = prev_val
else:
del copy._deepcopy_dispatch[re_type]

return result

return wrapper


def enforce_reproducibility(
saved_reproducibility_data: Optional[Dict[str, str]], current_reproducibility_data: Dict[str, str], path: PathLike
) -> None:
Expand Down

0 comments on commit b78111e

Please sign in to comment.