Skip to content

Commit

Permalink
Fix py2, header recompilation
Browse files Browse the repository at this point in the history
Refs   #706.
  • Loading branch information
evhub committed Dec 24, 2022
1 parent 9b97a11 commit 67dd447
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 39 deletions.
2 changes: 1 addition & 1 deletion coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,8 +867,8 @@ def getheader(self, which, use_hash=None, polish=True):
"""Get a formatted header."""
header = getheader(
which,
target=self.target,
use_hash=use_hash,
target=self.target,
no_tco=self.no_tco,
strict=self.strict,
no_wrap=self.no_wrap,
Expand Down
83 changes: 56 additions & 27 deletions coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,12 @@ def section(name, newline_before=True):
)


def base_pycondition(target, ver, if_lt=None, if_ge=None, indent=None, newline=False, fallback=""):
def prepare(code, indent=0, **kwargs):
"""Prepare a piece of code for the header."""
return _indent(code, by=indent, strip=True, **kwargs)


def base_pycondition(target, ver, if_lt=None, if_ge=None, indent=None, newline=False, initial_newline=False, fallback=""):
"""Produce code that depends on the Python version for the given target."""
internal_assert(isinstance(ver, tuple), "invalid pycondition version")
internal_assert(if_lt or if_ge, "either if_lt or if_ge must be specified")
Expand Down Expand Up @@ -160,6 +165,8 @@ def base_pycondition(target, ver, if_lt=None, if_ge=None, indent=None, newline=F

if indent is not None:
out = _indent(out, by=indent)
if initial_newline:
out = "\n" + out
if newline:
out += "\n"
return out
Expand Down Expand Up @@ -191,7 +198,7 @@ def __getattr__(self, attr):
COMMENT = Comment()


def process_header_args(which, target, use_hash, no_tco, strict, no_wrap):
def process_header_args(which, use_hash, target, no_tco, strict, no_wrap):
"""Create the dictionary passed to str.format in the header."""
target_startswith = one_num_ver(target)
target_info = get_target_info(target)
Expand Down Expand Up @@ -231,12 +238,13 @@ def process_header_args(which, target, use_hash, no_tco, strict, no_wrap):
''',
indent=1,
),
import_OrderedDict=_indent(
r'''OrderedDict = collections.OrderedDict if _coconut_sys.version_info >= (2, 7) else dict'''
if not target
import_OrderedDict=prepare(
r'''
OrderedDict = collections.OrderedDict if _coconut_sys.version_info >= (2, 7) else dict
''' if not target
else "OrderedDict = collections.OrderedDict" if target_info >= (2, 7)
else "OrderedDict = dict",
by=1,
indent=1,
),
import_collections_abc=pycondition(
(3, 3),
Expand All @@ -248,17 +256,18 @@ def process_header_args(which, target, use_hash, no_tco, strict, no_wrap):
''',
indent=1,
),
set_zip_longest=_indent(
r'''zip_longest = itertools.zip_longest if _coconut_sys.version_info >= (3,) else itertools.izip_longest'''
if not target
set_zip_longest=prepare(
r'''
zip_longest = itertools.zip_longest if _coconut_sys.version_info >= (3,) else itertools.izip_longest
''' if not target
else "zip_longest = itertools.zip_longest" if target_info >= (3,)
else "zip_longest = itertools.izip_longest",
by=1,
indent=1,
),
comma_bytearray=", bytearray" if target_startswith != "3" else "",
lstatic="staticmethod(" if target_startswith != "3" else "",
rstatic=")" if target_startswith != "3" else "",
zip_iter=_indent(
zip_iter=prepare(
r'''
for items in _coconut.iter(_coconut.zip(*self.iters, strict=self.strict) if _coconut_sys.version_info >= (3, 10) else _coconut.zip_longest(*self.iters, fillvalue=_coconut_sentinel) if self.strict else _coconut.zip(*self.iters)):
if self.strict and _coconut_sys.version_info < (3, 10) and _coconut.any(x is _coconut_sentinel for x in items):
Expand All @@ -277,8 +286,7 @@ def process_header_args(which, target, use_hash, no_tco, strict, no_wrap):
raise _coconut.ValueError("zip(..., strict=True) arguments have mismatched lengths")
yield items
''',
by=2,
strip=True,
indent=2,
),
# disabled mocks must have different docstrings so the
# interpreter can tell them apart from the real thing
Expand Down Expand Up @@ -475,7 +483,7 @@ def __lt__(self, other):
tco_comma="_coconut_tail_call, _coconut_tco, " if not no_tco else "",
call_set_names_comma="_coconut_call_set_names, " if target_info < (3, 6) else "",
handle_cls_args_comma="_coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, " if target_startswith != "3" else "",
async_def_anext=_indent(
async_def_anext=prepare(
r'''
async def __anext__(self):
return self.func(await self.aiter.__anext__())
Expand All @@ -496,8 +504,19 @@ async def __anext__(self):
__anext__ = _coconut.asyncio.coroutine(_coconut_anext_ns["__anext__"])
''',
),
by=1,
strip=True,
indent=1,
),
patch_cached_MatchError=pycondition(
(3,),
if_ge=r'''
for _coconut_varname in dir(MatchError):
try:
setattr(_coconut_cached_MatchError, _coconut_varname, getattr(MatchError, _coconut_varname))
except (AttributeError, TypeError):
pass
''',
indent=1,
initial_newline=True,
),
)

Expand Down Expand Up @@ -615,8 +634,12 @@ class you_need_to_install_backports_functools_lru_cache{object}:
# -----------------------------------------------------------------------------------------------------------------------


def getheader(which, target, use_hash, no_tco, strict, no_wrap):
"""Generate the specified header."""
def getheader(which, use_hash, target, no_tco, strict, no_wrap):
"""Generate the specified header.
IMPORTANT: Any new arguments to this function must be duplicated to
header_info and process_header_args.
"""
internal_assert(
which.startswith("package") or which in (
"none", "initial", "__coconut__", "sys", "code", "file",
Expand All @@ -628,12 +651,12 @@ def getheader(which, target, use_hash, no_tco, strict, no_wrap):
if which == "none":
return ""

target_startswith = one_num_ver(target)
target_info = get_target_info(target)

# initial, __coconut__, package:n, sys, code, file

format_dict = process_header_args(which, target, use_hash, no_tco, strict, no_wrap)
target_startswith = one_num_ver(target)
target_info = get_target_info(target)
header_info = tuple_str_of((VERSION, target, no_tco, strict, no_wrap), add_quotes=True)
format_dict = process_header_args(which, use_hash, target, no_tco, strict, no_wrap)

if which == "initial" or which == "__coconut__":
header = '''#!/usr/bin/env python{target_startswith}
Expand Down Expand Up @@ -669,17 +692,20 @@ def getheader(which, target, use_hash, no_tco, strict, no_wrap):

header += "import sys as _coconut_sys\n"

if which.startswith("package") or which == "__coconut__":
header += "_coconut_header_info = " + header_info + "\n"

if which.startswith("package"):
levels_up = int(which[len("package:"):])
coconut_file_dir = "_coconut_os.path.dirname(_coconut_os.path.abspath(__file__))"
for _ in range(levels_up):
coconut_file_dir = "_coconut_os.path.dirname(" + coconut_file_dir + ")"
return header + '''import os as _coconut_os
_coconut_file_dir = {coconut_file_dir}
_coconut_cached_module = _coconut_sys.modules.get({__coconut__})
if _coconut_cached_module is not None and _coconut_os.path.dirname(_coconut_cached_module.__file__) != _coconut_file_dir: # type: ignore
if _coconut_cached_module is not None and getattr(_coconut_cached_module, "_coconut_header_info", None) != _coconut_header_info: # type: ignore
_coconut_sys.modules[{_coconut_cached_module}] = _coconut_cached_module
del _coconut_sys.modules[{__coconut__}]
_coconut_file_dir = {coconut_file_dir}
_coconut_sys.path.insert(0, _coconut_file_dir)
_coconut_module_name = _coconut_os.path.splitext(_coconut_os.path.basename(_coconut_file_dir))[0]
if _coconut_module_name and _coconut_module_name[0].isalpha() and all(c.isalpha() or c.isdigit() for c in _coconut_module_name) and "__init__.py" in _coconut_os.listdir(_coconut_file_dir):
Expand Down Expand Up @@ -710,9 +736,12 @@ def getheader(which, target, use_hash, no_tco, strict, no_wrap):

# __coconut__, code, file

header += '''_coconut_cached_module = _coconut_sys.modules.get({_coconut_cached_module}, _coconut_sys.modules.get({__coconut__}))
_coconut_base_MatchError = Exception if _coconut_cached_module is None else getattr(_coconut_cached_module, "MatchError", Exception)
'''.format(**format_dict)
header += prepare(
'''
_coconut_cached_module = _coconut_sys.modules.get({_coconut_cached_module}, _coconut_sys.modules.get({__coconut__}))
''',
newline=True,
).format(**format_dict)

if target_info >= (3, 7):
header += PY37_HEADER
Expand Down
12 changes: 4 additions & 8 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class _coconut_base_hashable{object}:
def __setstate__(self, setvars):{COMMENT.fixes_unpickling_with_slots}
for k, v in setvars.items():
_coconut.setattr(self, k, v)
class MatchError(_coconut_base_hashable, _coconut_base_MatchError):
class MatchError(_coconut_base_hashable, Exception):
"""Pattern-matching error. Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below}
max_val_repr_len = 500
def __init__(self, pattern=None, value=None):
Expand All @@ -74,13 +74,9 @@ class MatchError(_coconut_base_hashable, _coconut_base_MatchError):
return Exception.__unicode__(self)
def __reduce__(self):
return (self.__class__, (self.pattern, self.value), {lbrace}"_message": self._message{rbrace})
if _coconut_base_MatchError is not Exception:
for _coconut_MatchError_k in dir(MatchError):
try:
setattr(_coconut_base_MatchError, _coconut_MatchError_k, getattr(MatchError, _coconut_MatchError_k))
except (AttributeError, TypeError):
pass
MatchError = _coconut_base_MatchError
_coconut_cached_MatchError = None if _coconut_cached_module is None else getattr(_coconut_cached_module, "MatchError", None)
if _coconut_cached_MatchError is not None:{patch_cached_MatchError}
MatchError = _coconut_cached_MatchError
class _coconut_tail_call{object}:
__slots__ = ("func", "args", "kwargs")
def __init__(self, _coconut_func, *args, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@
VERSION = "2.1.1"
VERSION_NAME = "The Spanish Inquisition"
# False for release, int >= 1 for develop
DEVELOP = 37
DEVELOP = 38
ALPHA = False # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
# UTILITIES:
# -----------------------------------------------------------------------------------------------------------------------


def _indent(code, by=1, tabsize=4, newline=False, strip=False):
def _indent(code, by=1, tabsize=4, strip=False, newline=False, initial_newline=False):
"""Indents every nonempty line of the given code."""
return "".join(
return ("\n" if initial_newline else "") + "".join(
(" " * (tabsize * by) if line.strip() else "") + line
for line in (code.strip() if strip else code).splitlines(True)
) + ("\n" if newline else "")
Expand Down

0 comments on commit 67dd447

Please sign in to comment.