Skip to content

Commit

Permalink
Merge pull request hylang#1683 from brandonwillard/fix-py27-failed-im…
Browse files Browse the repository at this point in the history
…port-modules

Fix `sys.modules` for failed imports in Python 2.7
  • Loading branch information
Kodiologist authored Oct 16, 2018
2 parents d2319dc + a9763b3 commit 4132adb
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 12 deletions.
15 changes: 12 additions & 3 deletions hy/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def load_module(self, fullname=None):
mod_type == imp.PKG_DIRECTORY and
os.path.isfile(pkg_path)):

if fullname in sys.modules:
was_in_sys = fullname in sys.modules
if was_in_sys:
mod = sys.modules[fullname]
else:
mod = sys.modules.setdefault(
Expand All @@ -311,7 +312,15 @@ def load_module(self, fullname=None):

mod.__name__ = fullname

self.exec_module(mod, fullname=fullname)
try:
self.exec_module(mod, fullname=fullname)
except Exception:
# Follow Python 2.7 logic and only remove a new, bad
# module; otherwise, leave the old--and presumably
# good--module in there.
if not was_in_sys:
del sys.modules[fullname]
raise

if mod is None:
self._reopen()
Expand Down Expand Up @@ -385,7 +394,7 @@ def get_code(self, fullname=None):
self.code = self.byte_compile_hy(fullname)

if self.code is None:
super(HyLoader, self).get_code(fullname=fullname)
super(HyLoader, self).get_code(fullname=fullname)

return self.code

Expand Down
39 changes: 30 additions & 9 deletions tests/importer/test_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import sys
import ast
import imp
import tempfile
import runpy
import importlib
Expand All @@ -15,12 +14,16 @@
import pytest

import hy
from hy._compat import bytes_type
from hy.errors import HyTypeError
from hy.lex import LexException
from hy.compiler import hy_compile
from hy.importer import hy_parse, HyLoader, cache_from_source

try:
from importlib import reload
except ImportError:
from imp import reload


def test_basics():
"Make sure the basics of the importer work"
Expand Down Expand Up @@ -85,6 +88,15 @@ def _import_error_test():
assert _import_error_test() is not None


def test_import_error_cleanup():
"Failed initial imports should not leave dead modules in `sys.modules`."

with pytest.raises(hy.errors.HyMacroExpansionError):
importlib.import_module('tests.resources.fails')

assert 'tests.resources.fails' not in sys.modules


@pytest.mark.skipif(sys.dont_write_bytecode,
reason="Bytecode generation is suppressed")
def test_import_autocompiles():
Expand Down Expand Up @@ -127,7 +139,13 @@ def eval_str(s):


def test_reload():
"""Copied from CPython's `test_import.py`"""
"""Generate a test module, confirm that it imports properly (and puts the
module in `sys.modules`), then modify the module so that it produces an
error when reloaded. Next, fix the error, reload, and check that the
module is updated and working fine. Rinse, repeat.
This test is adapted from CPython's `test_import.py`.
"""

def unlink(filename):
os.unlink(source)
Expand Down Expand Up @@ -160,7 +178,7 @@ def unlink(filename):
f.write("(setv b (// 20 0))")

with pytest.raises(ZeroDivisionError):
imp.reload(mod)
reload(mod)

# But we still expect the module to be in sys.modules.
mod = sys.modules.get(TESTFN)
Expand All @@ -178,23 +196,25 @@ def unlink(filename):
f.write("(setv a 11)")
f.write("(setv b (// 20 1))")

imp.reload(mod)
reload(mod)

mod = sys.modules.get(TESTFN)
assert mod is not None

assert mod.a == 11
assert mod.b == 20

# Now cause a LexException
# Now cause a `LexException`, and confirm that the good module and its
# contents stick around.
unlink(source)

with open(source, "w") as f:
# Missing paren...
f.write("(setv a 11")
f.write("(setv b (// 20 1))")

with pytest.raises(LexException):
imp.reload(mod)
reload(mod)

mod = sys.modules.get(TESTFN)
assert mod is not None
Expand All @@ -209,7 +229,7 @@ def unlink(filename):
f.write("(setv a 12)")
f.write("(setv b (// 10 1))")

imp.reload(mod)
reload(mod)

mod = sys.modules.get(TESTFN)
assert mod is not None
Expand All @@ -219,8 +239,9 @@ def unlink(filename):

finally:
del sys.path[0]
if TESTFN in sys.modules:
del sys.modules[TESTFN]
unlink(source)
del sys.modules[TESTFN]


def test_circular():
Expand Down
5 changes: 5 additions & 0 deletions tests/resources/fails.hy
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"This module produces an error when imported."
(defmacro a-macro [x]
(+ x 1))

(print (a-macro 'blah))

0 comments on commit 4132adb

Please sign in to comment.