Skip to content

Commit

Permalink
[mypyc] Support ERR_ALWAYS (#9073)
Browse files Browse the repository at this point in the history
Related to mypyc/mypyc#734, with a focus on exceptions related ops.

This PR adds a new error kind: ERR_ALWAYS, which indicates the op always fails.

It adds temporary false value to ensure such behavior in the exception handling
transform and makes the raise op void.
  • Loading branch information
TH3CHARLie authored Jul 2, 2020
1 parent eae1860 commit a04bdbf
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 62 deletions.
26 changes: 16 additions & 10 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,7 @@ def visit_branch(self, op: Branch) -> None:

self.emit_line('if ({}) {{'.format(cond))

if op.traceback_entry is not None:
globals_static = self.emitter.static_name('globals', self.module_name)
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
self.source_path.replace("\\", "\\\\"),
op.traceback_entry[0],
op.traceback_entry[1],
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
self.emit_traceback(op)

self.emit_lines(
'goto %s;' % self.label(op.true),
Expand Down Expand Up @@ -422,7 +414,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
self.emitter.emit_line('{} = 0;'.format(self.reg(op)))

def visit_call_c(self, op: CallC) -> None:
dest = self.get_dest_assign(op)
if op.is_void:
dest = ''
else:
dest = self.get_dest_assign(op)
args = ', '.join(self.reg(arg) for arg in op.args)
self.emitter.emit_line("{}{}({});".format(dest, op.function_name, args))

Expand Down Expand Up @@ -472,3 +467,14 @@ def emit_dec_ref(self, dest: str, rtype: RType, is_xdec: bool) -> None:

def emit_declaration(self, line: str) -> None:
self.declarations.emit_line(line)

def emit_traceback(self, op: Branch) -> None:
if op.traceback_entry is not None:
globals_static = self.emitter.static_name('globals', self.module_name)
self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
self.source_path.replace("\\", "\\\\"),
op.traceback_entry[0],
op.traceback_entry[1],
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')
7 changes: 6 additions & 1 deletion mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def terminated(self) -> bool:
ERR_FALSE = 2 # type: Final
# Generates negative integer on exception
ERR_NEG_INT = 3 # type: Final
# Always fails
ERR_ALWAYS = 4 # type: Final

# Hack: using this line number for an op will suppress it in tracebacks
NO_TRACEBACK_LINE_NO = -10000
Expand Down Expand Up @@ -1167,7 +1169,10 @@ def __init__(self,

def to_str(self, env: Environment) -> str:
args_str = ', '.join(env.format('%r', arg) for arg in self.args)
return env.format('%r = %s(%s)', self, self.function_name, args_str)
if self.is_void:
return env.format('%s(%s)', self.function_name, args_str)
else:
return env.format('%r = %s(%s)', self, self.function_name, args_str)

def sources(self) -> List[Value]:
return self.args
Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def transform_raise_stmt(builder: IRBuilder, s: RaiseStmt) -> None:
return

exc = builder.accept(s.expr)
builder.primitive_op(raise_exception_op, [exc], s.line)
builder.call_c(raise_exception_op, [exc], s.line)
builder.add(Unreachable())


Expand Down Expand Up @@ -614,7 +614,7 @@ def transform_assert_stmt(builder: IRBuilder, a: AssertStmt) -> None:
message = builder.accept(a.msg)
exc_type = builder.load_module_attr_by_fullname('builtins.AssertionError', a.line)
exc = builder.py_call(exc_type, [message], a.line)
builder.primitive_op(raise_exception_op, [exc], a.line)
builder.call_c(raise_exception_op, [exc], a.line)
builder.add(Unreachable())
builder.activate_block(ok_block)

Expand Down
15 changes: 6 additions & 9 deletions mypyc/primitives/exc_ops.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
"""Exception-related primitive ops."""

from mypyc.ir.ops import ERR_NEVER, ERR_FALSE
from mypyc.ir.ops import ERR_NEVER, ERR_FALSE, ERR_ALWAYS
from mypyc.ir.rtypes import bool_rprimitive, object_rprimitive, void_rtype, exc_rtuple
from mypyc.primitives.registry import (
simple_emit, call_emit, call_void_emit, call_and_fail_emit, custom_op,
simple_emit, call_emit, call_void_emit, call_and_fail_emit, custom_op, c_custom_op
)

# If the argument is a class, raise an instance of the class. Otherwise, assume
# that the argument is an exception object, and raise it.
#
# TODO: Making this raise conditionally is kind of hokey.
raise_exception_op = custom_op(
raise_exception_op = c_custom_op(
arg_types=[object_rprimitive],
result_type=bool_rprimitive,
error_kind=ERR_FALSE,
format_str='raise_exception({args[0]}); {dest} = 0',
emit=call_and_fail_emit('CPy_Raise'))
return_type=void_rtype,
c_function_name='CPy_Raise',
error_kind=ERR_ALWAYS)

# Raise StopIteration exception with the specified value (which can be NULL).
set_stop_iteration_value = custom_op(
Expand Down
6 changes: 2 additions & 4 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -1322,24 +1322,22 @@ def foo():
r0 :: object
r1 :: str
r2, r3 :: object
r4 :: bool
L0:
r0 = builtins :: module
r1 = unicode_1 :: static ('Exception')
r2 = getattr r0, r1
r3 = py_call(r2)
raise_exception(r3); r4 = 0
CPy_Raise(r3)
unreachable
def bar():
r0 :: object
r1 :: str
r2 :: object
r3 :: bool
L0:
r0 = builtins :: module
r1 = unicode_1 :: static ('Exception')
r2 = getattr r0, r1
raise_exception(r2); r3 = 0
CPy_Raise(r2)
unreachable

[case testModuleTopLevel_toplevel]
Expand Down
9 changes: 4 additions & 5 deletions mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,7 @@ def complex_msg(x, s):
r4 :: object
r5 :: str
r6, r7 :: object
r8 :: bool
r9 :: None
r8 :: None
L0:
r0 = builtins.None :: object
r1 = x is not r0
Expand All @@ -629,11 +628,11 @@ L2:
r5 = unicode_3 :: static ('AssertionError')
r6 = getattr r4, r5
r7 = py_call(r6, s)
raise_exception(r7); r8 = 0
CPy_Raise(r7)
unreachable
L3:
r9 = None
return r9
r8 = None
return r8

[case testDelList]
def delList() -> None:
Expand Down
49 changes: 24 additions & 25 deletions mypyc/test-data/irbuild-try.test
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,13 @@ def a(b):
r1 :: object
r2 :: str
r3, r4 :: object
r5 :: bool
r6, r7, r8 :: tuple[object, object, object]
r9 :: str
r10 :: object
r11 :: str
r12, r13 :: object
r14, r15 :: bool
r16 :: None
r5, r6, r7 :: tuple[object, object, object]
r8 :: str
r9 :: object
r10 :: str
r11, r12 :: object
r13, r14 :: bool
r15 :: None
L0:
L1:
if b goto L2 else goto L3 :: bool
Expand All @@ -294,39 +293,39 @@ L2:
r2 = unicode_2 :: static ('Exception')
r3 = getattr r1, r2
r4 = py_call(r3, r0)
raise_exception(r4); r5 = 0
CPy_Raise(r4)
unreachable
L3:
L4:
L5:
r7 = <error> :: tuple[object, object, object]
r6 = r7
r6 = <error> :: tuple[object, object, object]
r5 = r6
goto L7
L6: (handler for L1, L2, L3)
r8 = error_catch
r6 = r8
r7 = error_catch
r5 = r7
L7:
r9 = unicode_3 :: static ('finally')
r10 = builtins :: module
r11 = unicode_4 :: static ('print')
r12 = getattr r10, r11
r13 = py_call(r12, r9)
if is_error(r6) goto L9 else goto L8
r8 = unicode_3 :: static ('finally')
r9 = builtins :: module
r10 = unicode_4 :: static ('print')
r11 = getattr r9, r10
r12 = py_call(r11, r8)
if is_error(r5) goto L9 else goto L8
L8:
reraise_exc; r14 = 0
reraise_exc; r13 = 0
unreachable
L9:
goto L13
L10: (handler for L7, L8)
if is_error(r6) goto L12 else goto L11
if is_error(r5) goto L12 else goto L11
L11:
restore_exc_info r6
restore_exc_info r5
L12:
r15 = keep_propagating
r14 = keep_propagating
unreachable
L13:
r16 = None
return r16
r15 = None
return r15

[case testWith]
from typing import Any
Expand Down
25 changes: 19 additions & 6 deletions mypyc/transform/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from typing import List, Optional

from mypyc.ir.ops import (
BasicBlock, LoadErrorValue, Return, Branch, RegisterOp, ERR_NEVER, ERR_MAGIC,
ERR_FALSE, ERR_NEG_INT, NO_TRACEBACK_LINE_NO,
BasicBlock, LoadErrorValue, Return, Branch, RegisterOp, LoadInt, ERR_NEVER, ERR_MAGIC,
ERR_FALSE, ERR_NEG_INT, ERR_ALWAYS, NO_TRACEBACK_LINE_NO, Environment
)
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.rtypes import bool_rprimitive


def insert_exception_handling(ir: FuncIR) -> None:
Expand All @@ -29,7 +30,7 @@ def insert_exception_handling(ir: FuncIR) -> None:
error_label = add_handler_block(ir)
break
if error_label:
ir.blocks = split_blocks_at_errors(ir.blocks, error_label, ir.traceback_name)
ir.blocks = split_blocks_at_errors(ir.blocks, error_label, ir.traceback_name, ir.env)


def add_handler_block(ir: FuncIR) -> BasicBlock:
Expand All @@ -44,7 +45,8 @@ def add_handler_block(ir: FuncIR) -> BasicBlock:

def split_blocks_at_errors(blocks: List[BasicBlock],
default_error_handler: BasicBlock,
func_name: Optional[str]) -> List[BasicBlock]:
func_name: Optional[str],
env: Environment) -> List[BasicBlock]:
new_blocks = [] # type: List[BasicBlock]

# First split blocks on ops that may raise.
Expand All @@ -60,6 +62,7 @@ def split_blocks_at_errors(blocks: List[BasicBlock],
block.error_handler = None

for op in ops:
target = op
cur_block.ops.append(op)
if isinstance(op, RegisterOp) and op.error_kind != ERR_NEVER:
# Split
Expand All @@ -77,14 +80,24 @@ def split_blocks_at_errors(blocks: List[BasicBlock],
elif op.error_kind == ERR_NEG_INT:
variant = Branch.NEG_INT_EXPR
negated = False
elif op.error_kind == ERR_ALWAYS:
variant = Branch.BOOL_EXPR
negated = True
# this is a hack to represent the always fail
# semantics, using a temporary bool with value false
tmp = LoadInt(0, rtype=bool_rprimitive)
cur_block.ops.append(tmp)
env.add_op(tmp)
target = tmp
else:
assert False, 'unknown error kind %d' % op.error_kind

# Void ops can't generate errors since error is always
# indicated by a special value stored in a register.
assert not op.is_void, "void op generating errors?"
if op.error_kind != ERR_ALWAYS:
assert not op.is_void, "void op generating errors?"

branch = Branch(op,
branch = Branch(target,
true_label=error_label,
false_label=new_block,
op=variant,
Expand Down

0 comments on commit a04bdbf

Please sign in to comment.