Skip to content

Commit

Permalink
[mypyc] Avoid crash when importing unknown module with from import (#…
Browse files Browse the repository at this point in the history
…10550)

Fixes mypyc/mypyc#851

This fixes a bug where code compiled with mypyc would crash on from imports (from x import y) if:
 * y is a module
 * mypy doesn't know that y is a module (due to an ignore_missing_imports configuration option or something else)

The bug was caused by using getattr to import modules (i.e. y = getattr(x, 'y')) and changing this to import x.y as y when it can determine that y is a module. This doesn't work when we don't know that y is a module.

I changed the from import handling to use something similar to the method shown in the __import__ docs. I also removed the special casing of from imports for modules (from x import y where y is a module) mentioned earlier, because these changes make that special casing unnecessary.
  • Loading branch information
pranavrajpal authored Jun 9, 2021
1 parent 4028203 commit 7bb1f37
Show file tree
Hide file tree
Showing 6 changed files with 522 additions and 418 deletions.
38 changes: 33 additions & 5 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
SetAttr, LoadStatic, InitStatic, NAMESPACE_MODULE, RaiseStandardError
)
from mypyc.ir.rtypes import (
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
RType, RTuple, RInstance, c_int_rprimitive, int_rprimitive, dict_rprimitive,
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
str_rprimitive, is_tagged, is_list_rprimitive, is_tuple_rprimitive, c_pyssize_t_rprimitive
)
Expand All @@ -45,7 +45,9 @@
from mypyc.primitives.list_ops import to_list, list_pop_last, list_get_item_unsafe_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op
from mypyc.primitives.misc_ops import import_op, check_unpack_count_op, get_module_dict_op
from mypyc.primitives.misc_ops import (
import_op, check_unpack_count_op, get_module_dict_op, import_extra_args_op
)
from mypyc.crash import catch_errors
from mypyc.options import CompilerOptions
from mypyc.errors import Errors
Expand Down Expand Up @@ -286,19 +288,45 @@ def add_to_non_ext_dict(self, non_ext: NonExtClassInfo,
key_unicode = self.load_str(key)
self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line)

def gen_import_from(self, id: str, line: int, imported: List[str]) -> None:
self.imports[id] = None

globals_dict = self.load_globals_dict()
null = Integer(0, dict_rprimitive, line)
names_to_import = self.new_list_op([self.load_str(name) for name in imported], line)

level = Integer(0, c_int_rprimitive, line)
value = self.call_c(
import_extra_args_op,
[self.load_str(id), globals_dict, null, names_to_import, level],
line,
)
self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE))

def gen_import(self, id: str, line: int) -> None:
self.imports[id] = None

needs_import, out = BasicBlock(), BasicBlock()
first_load = self.load_module(id)
comparison = self.translate_is_op(first_load, self.none_object(), 'is not', line)
self.add_bool_branch(comparison, out, needs_import)
self.check_if_module_loaded(id, line, needs_import, out)

self.activate_block(needs_import)
value = self.call_c(import_op, [self.load_str(id)], line)
self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE))
self.goto_and_activate(out)

def check_if_module_loaded(self, id: str, line: int,
needs_import: BasicBlock, out: BasicBlock) -> None:
"""Generate code that checks if the module `id` has been loaded yet.
Arguments:
id: name of module to check if imported
line: line number that the import occurs on
needs_import: the BasicBlock that is run if the module has not been loaded yet
out: the BasicBlock that is run if the module has already been loaded"""
first_load = self.load_module(id)
comparison = self.translate_is_op(first_load, self.none_object(), 'is not', line)
self.add_bool_branch(comparison, out, needs_import)

def get_module(self, module: str, line: int) -> Value:
# Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :(
mod_dict = self.call_c(get_module_dict_op, [], line)
Expand Down
9 changes: 2 additions & 7 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None:

id = importlib.util.resolve_name('.' * node.relative + node.id, module_package)

builder.gen_import(id, node.line)
imported = [name for name, _ in node.names]
builder.gen_import_from(id, node.line, imported)
module = builder.load_module(id)

# Copy everything into our module's dict.
Expand All @@ -181,12 +182,6 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None:
# This probably doesn't matter much and the code runs basically right.
globals = builder.load_globals_dict()
for name, maybe_as_name in node.names:
# If one of the things we are importing is a module,
# import it as a module also.
fullname = id + '.' + name
if fullname in builder.graph or fullname in module_state.suppressed:
builder.gen_import(fullname, node.line)

as_name = maybe_as_name or name
obj = builder.py_get_attr(module, name, node.line)
builder.gen_method_call(
Expand Down
12 changes: 11 additions & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC, ERR_FALSE
from mypyc.ir.rtypes import (
bool_rprimitive, object_rprimitive, str_rprimitive, object_pointer_rprimitive,
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive,
list_rprimitive,
)
from mypyc.primitives.registry import (
function_op, custom_op, load_address_op, ERR_NEG_INT
Expand Down Expand Up @@ -113,6 +114,15 @@
c_function_name='PyImport_Import',
error_kind=ERR_MAGIC)

# Import with extra arguments (used in from import handling)
import_extra_args_op = custom_op(
arg_types=[str_rprimitive, dict_rprimitive, dict_rprimitive,
list_rprimitive, c_int_rprimitive],
return_type=object_rprimitive,
c_function_name='PyImport_ImportModuleLevelObject',
error_kind=ERR_MAGIC
)

# Get the sys.modules dictionary
get_module_dict_op = custom_op(
arg_types=[],
Expand Down
Loading

0 comments on commit 7bb1f37

Please sign in to comment.