From a04bdbfec48796afa20049c9d419d6cc5ecbeb7e Mon Sep 17 00:00:00 2001 From: Xuanda Yang Date: Thu, 2 Jul 2020 17:57:43 +0800 Subject: [PATCH] [mypyc] Support ERR_ALWAYS (#9073) Related to mypyc/mypyc#734, with a focus on exceptions related ops. This PR adds a new error kind: ERR_ALWAYS, which indicates the op always fails. It adds temporary false value to ensure such behavior in the exception handling transform and makes the raise op void. --- mypyc/codegen/emitfunc.py | 26 ++++++++----- mypyc/ir/ops.py | 7 +++- mypyc/irbuild/statement.py | 4 +- mypyc/primitives/exc_ops.py | 15 +++----- mypyc/test-data/irbuild-basic.test | 6 +-- mypyc/test-data/irbuild-statements.test | 9 ++--- mypyc/test-data/irbuild-try.test | 49 ++++++++++++------------- mypyc/transform/exceptions.py | 25 ++++++++++--- 8 files changed, 79 insertions(+), 62 deletions(-) diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index ef88b8c21305..6d6b46b277f5 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -130,15 +130,7 @@ def visit_branch(self, op: Branch) -> None: self.emit_line('if ({}) {{'.format(cond)) - if op.traceback_entry is not None: - globals_static = self.emitter.static_name('globals', self.module_name) - self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % ( - self.source_path.replace("\\", "\\\\"), - op.traceback_entry[0], - op.traceback_entry[1], - globals_static)) - if DEBUG_ERRORS: - self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");') + self.emit_traceback(op) self.emit_lines( 'goto %s;' % self.label(op.true), @@ -422,7 +414,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None: self.emitter.emit_line('{} = 0;'.format(self.reg(op))) def visit_call_c(self, op: CallC) -> None: - dest = self.get_dest_assign(op) + if op.is_void: + dest = '' + else: + dest = self.get_dest_assign(op) args = ', '.join(self.reg(arg) for arg in op.args) self.emitter.emit_line("{}{}({});".format(dest, op.function_name, args)) @@ -472,3 +467,14 @@ def emit_dec_ref(self, dest: str, rtype: RType, is_xdec: bool) -> None: def emit_declaration(self, line: str) -> None: self.declarations.emit_line(line) + + def emit_traceback(self, op: Branch) -> None: + if op.traceback_entry is not None: + globals_static = self.emitter.static_name('globals', self.module_name) + self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % ( + self.source_path.replace("\\", "\\\\"), + op.traceback_entry[0], + op.traceback_entry[1], + globals_static)) + if DEBUG_ERRORS: + self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");') diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 2eb53b444130..0344d49af72a 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -295,6 +295,8 @@ def terminated(self) -> bool: ERR_FALSE = 2 # type: Final # Generates negative integer on exception ERR_NEG_INT = 3 # type: Final +# Always fails +ERR_ALWAYS = 4 # type: Final # Hack: using this line number for an op will suppress it in tracebacks NO_TRACEBACK_LINE_NO = -10000 @@ -1167,7 +1169,10 @@ def __init__(self, def to_str(self, env: Environment) -> str: args_str = ', '.join(env.format('%r', arg) for arg in self.args) - return env.format('%r = %s(%s)', self, self.function_name, args_str) + if self.is_void: + return env.format('%s(%s)', self.function_name, args_str) + else: + return env.format('%r = %s(%s)', self, self.function_name, args_str) def sources(self) -> List[Value]: return self.args diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index a3c65a99c7f8..1f669930c634 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -243,7 +243,7 @@ def transform_raise_stmt(builder: IRBuilder, s: RaiseStmt) -> None: return exc = builder.accept(s.expr) - builder.primitive_op(raise_exception_op, [exc], s.line) + builder.call_c(raise_exception_op, [exc], s.line) builder.add(Unreachable()) @@ -614,7 +614,7 @@ def transform_assert_stmt(builder: IRBuilder, a: AssertStmt) -> None: message = builder.accept(a.msg) exc_type = builder.load_module_attr_by_fullname('builtins.AssertionError', a.line) exc = builder.py_call(exc_type, [message], a.line) - builder.primitive_op(raise_exception_op, [exc], a.line) + builder.call_c(raise_exception_op, [exc], a.line) builder.add(Unreachable()) builder.activate_block(ok_block) diff --git a/mypyc/primitives/exc_ops.py b/mypyc/primitives/exc_ops.py index ea79203b8b1f..a42f8d3c0aa4 100644 --- a/mypyc/primitives/exc_ops.py +++ b/mypyc/primitives/exc_ops.py @@ -1,21 +1,18 @@ """Exception-related primitive ops.""" -from mypyc.ir.ops import ERR_NEVER, ERR_FALSE +from mypyc.ir.ops import ERR_NEVER, ERR_FALSE, ERR_ALWAYS from mypyc.ir.rtypes import bool_rprimitive, object_rprimitive, void_rtype, exc_rtuple from mypyc.primitives.registry import ( - simple_emit, call_emit, call_void_emit, call_and_fail_emit, custom_op, + simple_emit, call_emit, call_void_emit, call_and_fail_emit, custom_op, c_custom_op ) # If the argument is a class, raise an instance of the class. Otherwise, assume # that the argument is an exception object, and raise it. -# -# TODO: Making this raise conditionally is kind of hokey. -raise_exception_op = custom_op( +raise_exception_op = c_custom_op( arg_types=[object_rprimitive], - result_type=bool_rprimitive, - error_kind=ERR_FALSE, - format_str='raise_exception({args[0]}); {dest} = 0', - emit=call_and_fail_emit('CPy_Raise')) + return_type=void_rtype, + c_function_name='CPy_Raise', + error_kind=ERR_ALWAYS) # Raise StopIteration exception with the specified value (which can be NULL). set_stop_iteration_value = custom_op( diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 272a28e2b0ec..b47af32ad533 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -1322,24 +1322,22 @@ def foo(): r0 :: object r1 :: str r2, r3 :: object - r4 :: bool L0: r0 = builtins :: module r1 = unicode_1 :: static ('Exception') r2 = getattr r0, r1 r3 = py_call(r2) - raise_exception(r3); r4 = 0 + CPy_Raise(r3) unreachable def bar(): r0 :: object r1 :: str r2 :: object - r3 :: bool L0: r0 = builtins :: module r1 = unicode_1 :: static ('Exception') r2 = getattr r0, r1 - raise_exception(r2); r3 = 0 + CPy_Raise(r2) unreachable [case testModuleTopLevel_toplevel] diff --git a/mypyc/test-data/irbuild-statements.test b/mypyc/test-data/irbuild-statements.test index 3587c8ec3c02..d9732218f684 100644 --- a/mypyc/test-data/irbuild-statements.test +++ b/mypyc/test-data/irbuild-statements.test @@ -614,8 +614,7 @@ def complex_msg(x, s): r4 :: object r5 :: str r6, r7 :: object - r8 :: bool - r9 :: None + r8 :: None L0: r0 = builtins.None :: object r1 = x is not r0 @@ -629,11 +628,11 @@ L2: r5 = unicode_3 :: static ('AssertionError') r6 = getattr r4, r5 r7 = py_call(r6, s) - raise_exception(r7); r8 = 0 + CPy_Raise(r7) unreachable L3: - r9 = None - return r9 + r8 = None + return r8 [case testDelList] def delList() -> None: diff --git a/mypyc/test-data/irbuild-try.test b/mypyc/test-data/irbuild-try.test index 5df1420ce349..f5cee4864957 100644 --- a/mypyc/test-data/irbuild-try.test +++ b/mypyc/test-data/irbuild-try.test @@ -277,14 +277,13 @@ def a(b): r1 :: object r2 :: str r3, r4 :: object - r5 :: bool - r6, r7, r8 :: tuple[object, object, object] - r9 :: str - r10 :: object - r11 :: str - r12, r13 :: object - r14, r15 :: bool - r16 :: None + r5, r6, r7 :: tuple[object, object, object] + r8 :: str + r9 :: object + r10 :: str + r11, r12 :: object + r13, r14 :: bool + r15 :: None L0: L1: if b goto L2 else goto L3 :: bool @@ -294,39 +293,39 @@ L2: r2 = unicode_2 :: static ('Exception') r3 = getattr r1, r2 r4 = py_call(r3, r0) - raise_exception(r4); r5 = 0 + CPy_Raise(r4) unreachable L3: L4: L5: - r7 = :: tuple[object, object, object] - r6 = r7 + r6 = :: tuple[object, object, object] + r5 = r6 goto L7 L6: (handler for L1, L2, L3) - r8 = error_catch - r6 = r8 + r7 = error_catch + r5 = r7 L7: - r9 = unicode_3 :: static ('finally') - r10 = builtins :: module - r11 = unicode_4 :: static ('print') - r12 = getattr r10, r11 - r13 = py_call(r12, r9) - if is_error(r6) goto L9 else goto L8 + r8 = unicode_3 :: static ('finally') + r9 = builtins :: module + r10 = unicode_4 :: static ('print') + r11 = getattr r9, r10 + r12 = py_call(r11, r8) + if is_error(r5) goto L9 else goto L8 L8: - reraise_exc; r14 = 0 + reraise_exc; r13 = 0 unreachable L9: goto L13 L10: (handler for L7, L8) - if is_error(r6) goto L12 else goto L11 + if is_error(r5) goto L12 else goto L11 L11: - restore_exc_info r6 + restore_exc_info r5 L12: - r15 = keep_propagating + r14 = keep_propagating unreachable L13: - r16 = None - return r16 + r15 = None + return r15 [case testWith] from typing import Any diff --git a/mypyc/transform/exceptions.py b/mypyc/transform/exceptions.py index d1f82e56829c..755ba6091663 100644 --- a/mypyc/transform/exceptions.py +++ b/mypyc/transform/exceptions.py @@ -12,10 +12,11 @@ from typing import List, Optional from mypyc.ir.ops import ( - BasicBlock, LoadErrorValue, Return, Branch, RegisterOp, ERR_NEVER, ERR_MAGIC, - ERR_FALSE, ERR_NEG_INT, NO_TRACEBACK_LINE_NO, + BasicBlock, LoadErrorValue, Return, Branch, RegisterOp, LoadInt, ERR_NEVER, ERR_MAGIC, + ERR_FALSE, ERR_NEG_INT, ERR_ALWAYS, NO_TRACEBACK_LINE_NO, Environment ) from mypyc.ir.func_ir import FuncIR +from mypyc.ir.rtypes import bool_rprimitive def insert_exception_handling(ir: FuncIR) -> None: @@ -29,7 +30,7 @@ def insert_exception_handling(ir: FuncIR) -> None: error_label = add_handler_block(ir) break if error_label: - ir.blocks = split_blocks_at_errors(ir.blocks, error_label, ir.traceback_name) + ir.blocks = split_blocks_at_errors(ir.blocks, error_label, ir.traceback_name, ir.env) def add_handler_block(ir: FuncIR) -> BasicBlock: @@ -44,7 +45,8 @@ def add_handler_block(ir: FuncIR) -> BasicBlock: def split_blocks_at_errors(blocks: List[BasicBlock], default_error_handler: BasicBlock, - func_name: Optional[str]) -> List[BasicBlock]: + func_name: Optional[str], + env: Environment) -> List[BasicBlock]: new_blocks = [] # type: List[BasicBlock] # First split blocks on ops that may raise. @@ -60,6 +62,7 @@ def split_blocks_at_errors(blocks: List[BasicBlock], block.error_handler = None for op in ops: + target = op cur_block.ops.append(op) if isinstance(op, RegisterOp) and op.error_kind != ERR_NEVER: # Split @@ -77,14 +80,24 @@ def split_blocks_at_errors(blocks: List[BasicBlock], elif op.error_kind == ERR_NEG_INT: variant = Branch.NEG_INT_EXPR negated = False + elif op.error_kind == ERR_ALWAYS: + variant = Branch.BOOL_EXPR + negated = True + # this is a hack to represent the always fail + # semantics, using a temporary bool with value false + tmp = LoadInt(0, rtype=bool_rprimitive) + cur_block.ops.append(tmp) + env.add_op(tmp) + target = tmp else: assert False, 'unknown error kind %d' % op.error_kind # Void ops can't generate errors since error is always # indicated by a special value stored in a register. - assert not op.is_void, "void op generating errors?" + if op.error_kind != ERR_ALWAYS: + assert not op.is_void, "void op generating errors?" - branch = Branch(op, + branch = Branch(target, true_label=error_label, false_label=new_block, op=variant,