Skip to content

Commit

Permalink
[mypyc] Add a simple irchecking analysis system (#11283)
Browse files Browse the repository at this point in the history
* Introduce basic ir analysis checks

Adds a new module for performing analysis checks on ir as several bugs
in mypy had underlying issues in which the ir produced by mypy was
invalid and guaranteed to have issues at runtime.

For now the checks are relatively simple - the only two supported are
validity checks on basic blocks: that each block is terminated and that
control ops reference a basic block within the same function.

Error reporting is non-existent and instead we are just testing that
the resulting error datatypes are what we expect. In the future we will
need to incorporate the errors into the pretty-printer so that we can
produce an ir dump that references where the error actually is.

In addition it would be useful if we had an ir parser so we could have a
test framework similar to other parts of mypyc in which we simply write
the IR in text format instead of constructing the IR ast from within
python.

* Integrate ircheck analysis into ir tests
  • Loading branch information
jhance authored Nov 4, 2021
1 parent 5bd2641 commit ad7e353
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 7 deletions.
165 changes: 165 additions & 0 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Utilities for checking that internal ir is valid and consistent."""
from typing import List, Union
from mypyc.ir.pprint import format_func
from mypyc.ir.ops import (
OpVisitor, BasicBlock, Op, ControlOp, Goto, Branch, Return, Unreachable,
Assign, AssignMulti, LoadErrorValue, LoadLiteral, GetAttr, SetAttr, LoadStatic,
InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast,
Box, Unbox, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp,
LoadMem, SetMem, GetElementPtr, LoadAddress, KeepAlive
)
from mypyc.ir.func_ir import FuncIR


class FnError:
    """A single validation problem found while checking a function's IR.

    Attributes:
        source: The op or basic block that the problem is attached to.
        desc: Human-readable description of the problem.

    Instances compare equal when both ``source`` and ``desc`` compare equal,
    and they hash consistently with that equality so they can be stored in
    sets or used as dict keys.
    """

    def __init__(self, source: Union[Op, BasicBlock], desc: str) -> None:
        self.source = source
        self.desc = desc

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, FnError)
            and self.source == other.source
            and self.desc == other.desc
        )

    def __hash__(self) -> int:
        # Defining __eq__ without __hash__ makes the class unhashable, so the
        # hash is derived from exactly the fields that __eq__ compares.
        return hash((self.source, self.desc))

    def __repr__(self) -> str:
        return f"FnError(source={self.source}, desc={self.desc})"


def check_func_ir(fn: FuncIR) -> List[FnError]:
    """Validate the IR of a single function and return every error found.

    Two passes are made over ``fn.blocks``. The first reports each block
    whose ``terminated`` flag is false. The second visits every op with an
    ``OpChecker``. Termination errors come first in the returned list,
    followed by the errors collected by the op checker.
    """
    # An unterminated block is reported against its last op, or against the
    # block itself when it has no ops at all.
    termination_errors = [
        FnError(
            source=blk.ops[-1] if blk.ops else blk,
            desc="Block not terminated",
        )
        for blk in fn.blocks
        if not blk.terminated
    ]

    checker = OpChecker(fn)
    for blk in fn.blocks:
        for instruction in blk.ops:
            instruction.accept(checker)

    return termination_errors + checker.errors


class IrCheckException(Exception):
    """Raised by assert_func_ir_valid when IR validation reports errors."""


def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if ``check_func_ir`` finds any error in ``fn``.

    The exception message is a fixed header followed by the pretty-printed
    function, with each error passed to ``format_func`` as a
    ``(source, description)`` pair. Returns None when no errors are found.
    """
    problems = check_func_ir(fn)
    if not problems:
        return

    annotations = [(problem.source, problem.desc) for problem in problems]
    dump = "\n".join(format_func(fn, annotations))
    raise IrCheckException("Internal error: Generated invalid IR: \n" + dump)


class OpChecker(OpVisitor[None]):
    """Op visitor that records validity errors for the ops of one function.

    Errors are accumulated in ``self.errors`` rather than raised. Only
    control-flow targets of ``Goto`` and ``Branch`` are checked; every other
    visit method is a no-op.
    """

    def __init__(self, parent_fn: FuncIR) -> None:
        self.parent_fn = parent_fn
        self.errors: List[FnError] = []
        # Built once so each target lookup is a set membership test instead
        # of a linear scan over the function's block list.
        self._known_blocks = set(parent_fn.blocks)

    def fail(self, source: Op, desc: str) -> None:
        """Record an error with description ``desc`` attached to ``source``."""
        self.errors.append(FnError(source=source, desc=desc))

    def check_control_op_targets(self, op: ControlOp) -> None:
        """Report every target of ``op`` that is not a block of the parent function."""
        for target in op.targets():
            if target not in self._known_blocks:
                self.fail(source=op, desc=f"Invalid control operation target: {target.label}")

    def visit_goto(self, op: Goto) -> None:
        self.check_control_op_targets(op)

    def visit_branch(self, op: Branch) -> None:
        self.check_control_op_targets(op)

    # No checks are implemented for the remaining op kinds.

    def visit_return(self, op: Return) -> None:
        pass

    def visit_unreachable(self, op: Unreachable) -> None:
        pass

    def visit_assign(self, op: Assign) -> None:
        pass

    def visit_assign_multi(self, op: AssignMulti) -> None:
        pass

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        pass

    def visit_load_literal(self, op: LoadLiteral) -> None:
        pass

    def visit_get_attr(self, op: GetAttr) -> None:
        pass

    def visit_set_attr(self, op: SetAttr) -> None:
        pass

    def visit_load_static(self, op: LoadStatic) -> None:
        pass

    def visit_init_static(self, op: InitStatic) -> None:
        pass

    def visit_tuple_get(self, op: TupleGet) -> None:
        pass

    def visit_tuple_set(self, op: TupleSet) -> None:
        pass

    def visit_inc_ref(self, op: IncRef) -> None:
        pass

    def visit_dec_ref(self, op: DecRef) -> None:
        pass

    def visit_call(self, op: Call) -> None:
        pass

    def visit_method_call(self, op: MethodCall) -> None:
        pass

    def visit_cast(self, op: Cast) -> None:
        pass

    def visit_box(self, op: Box) -> None:
        pass

    def visit_unbox(self, op: Unbox) -> None:
        pass

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        pass

    def visit_call_c(self, op: CallC) -> None:
        pass

    def visit_truncate(self, op: Truncate) -> None:
        pass

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass

    def visit_int_op(self, op: IntOp) -> None:
        pass

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        pass

    def visit_load_mem(self, op: LoadMem) -> None:
        pass

    def visit_set_mem(self, op: SetMem) -> None:
        pass

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        pass

    def visit_load_address(self, op: LoadAddress) -> None:
        pass

    def visit_keep_alive(self, op: KeepAlive) -> None:
        pass
31 changes: 24 additions & 7 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utilities for pretty-printing IR in a human-readable form."""

from typing import Any, Dict, List
from collections import defaultdict
from typing import Any, Dict, List, Union, Sequence, Tuple

from typing_extensions import Final

Expand All @@ -10,12 +11,14 @@
LoadStatic, InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast, Box, Unbox,
RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp, LoadMem, SetMem,
GetElementPtr, LoadAddress, Register, Value, OpVisitor, BasicBlock, ControlOp, LoadLiteral,
AssignMulti, KeepAlive
AssignMulti, KeepAlive, Op
)
from mypyc.ir.func_ir import FuncIR, all_values_full
from mypyc.ir.module_ir import ModuleIRs
from mypyc.ir.rtypes import is_bool_rprimitive, is_int_rprimitive, RType

ErrorSource = Union[BasicBlock, Op]


class IRPrettyPrintVisitor(OpVisitor[str]):
"""Internal visitor that pretty-prints ops."""
Expand Down Expand Up @@ -269,7 +272,8 @@ def format_registers(func_ir: FuncIR,


def format_blocks(blocks: List[BasicBlock],
names: Dict[Value, str]) -> List[str]:
names: Dict[Value, str],
source_to_error: Dict[ErrorSource, List[str]]) -> List[str]:
"""Format a list of IR basic blocks into a human-readable form."""
# First label all of the blocks
for i, block in enumerate(blocks):
Expand All @@ -290,30 +294,43 @@ def format_blocks(blocks: List[BasicBlock],
handler_msg = ' (handler for {})'.format(', '.join(labels))

lines.append('L%d:%s' % (block.label, handler_msg))
if block in source_to_error:
for error in source_to_error[block]:
lines.append(f" ERR: {error}")
ops = block.ops
if (isinstance(ops[-1], Goto) and i + 1 < len(blocks)
and ops[-1].label == blocks[i + 1]):
# Hide the last goto if it just goes to the next basic block.
and ops[-1].label == blocks[i + 1]
and not source_to_error.get(ops[-1], [])):
# Hide the last goto if it just goes to the next basic block,
# and there are no associated errors with the op.
ops = ops[:-1]
for op in ops:
line = ' ' + op.accept(visitor)
lines.append(line)
if op in source_to_error:
for error in source_to_error[op]:
lines.append(f" ERR: {error}")

if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
# Each basic block needs to exit somewhere.
lines.append(' [MISSING BLOCK EXIT OPCODE]')
return lines


def format_func(fn: FuncIR) -> List[str]:
def format_func(fn: FuncIR, errors: Sequence[Tuple[ErrorSource, str]] = ()) -> List[str]:
lines = []
cls_prefix = fn.class_name + '.' if fn.class_name else ''
lines.append('def {}{}({}):'.format(cls_prefix, fn.name,
', '.join(arg.name for arg in fn.args)))
names = generate_names_for_ir(fn.arg_regs, fn.blocks)
for line in format_registers(fn, names):
lines.append(' ' + line)
code = format_blocks(fn.blocks, names)

source_to_error = defaultdict(list)
for source, error in errors:
source_to_error[source].append(error)

code = format_blocks(fn.blocks, names, source_to_error)
lines.extend(code)
return lines

Expand Down
95 changes: 95 additions & 0 deletions mypyc/test/test_ircheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import unittest
from typing import List

from mypyc.analysis.ircheck import check_func_ir, FnError
from mypyc.ir.rtypes import none_rprimitive
from mypyc.ir.ops import BasicBlock, Op, Return, Integer, Goto
from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature
from mypyc.ir.pprint import format_func


def assert_has_error(fn: FuncIR, error: FnError) -> None:
    """Assert that checking ``fn`` reports exactly one error, equal to ``error``."""
    assert check_func_ir(fn) == [error]


def assert_no_errors(fn: FuncIR) -> None:
    """Assert that checking ``fn`` reports no errors at all."""
    found = check_func_ir(fn)
    assert not found


# Shared operand for the Return ops built in the tests below:
# an Integer 0 whose rtype is none_rprimitive.
NONE_VALUE = Integer(0, rtype=none_rprimitive)


class TestIrcheck(unittest.TestCase):
    """Tests for check_func_ir and for format_func's error annotations."""

    def setUp(self) -> None:
        # Counter behind the labels handed out by basic_block().
        self.label = 0

    def basic_block(self, ops: List[Op]) -> BasicBlock:
        """Return a new block holding ``ops``, labelled with the next counter value."""
        self.label += 1
        new_block = BasicBlock(self.label)
        new_block.ops = ops
        return new_block

    def func_decl(self, name: str) -> FuncDecl:
        """Return a declaration for a no-argument function with a none return type."""
        signature = FuncSignature(args=[], ret_type=none_rprimitive)
        return FuncDecl(name=name, class_name=None, module_name="module", sig=signature)

    def test_valid_fn(self) -> None:
        only_block = self.basic_block(ops=[Return(value=NONE_VALUE)])
        fn = FuncIR(
            decl=self.func_decl(name="func_1"),
            arg_regs=[],
            blocks=[only_block],
        )
        assert_no_errors(fn)

    def test_block_not_terminated_empty_block(self) -> None:
        empty = self.basic_block([])
        fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[empty])
        expected = FnError(source=empty, desc="Block not terminated")
        assert_has_error(fn, expected)

    def test_valid_goto(self) -> None:
        returning = self.basic_block([Return(value=NONE_VALUE)])
        jumping = self.basic_block([Goto(label=returning)])
        fn = FuncIR(
            decl=self.func_decl(name="func_1"),
            arg_regs=[],
            blocks=[returning, jumping],
        )
        assert_no_errors(fn)

    def test_invalid_goto(self) -> None:
        missing = self.basic_block([Return(value=NONE_VALUE)])
        jump = Goto(label=missing)
        present = self.basic_block([jump])
        # The goto's target block is deliberately left out of the function.
        fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[present])
        expected = FnError(source=jump, desc="Invalid control operation target: 1")
        assert_has_error(fn, expected)

    def test_pprint(self) -> None:
        missing = self.basic_block([Return(value=NONE_VALUE)])
        jump = Goto(label=missing)
        present = self.basic_block([jump])
        # The goto's target block is deliberately left out of the function.
        fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[present])
        rendered = format_func(fn, [(jump, "Invalid control operation target: 1")])
        expected_lines = [
            "def func_1():",
            "L0:",
            " goto L1",
            " ERR: Invalid control operation target: 1",
        ]
        assert rendered == expected_lines
3 changes: 3 additions & 0 deletions mypyc/test/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mypy.test.helpers import assert_string_arrays_equal

from mypyc.options import CompilerOptions
from mypyc.analysis.ircheck import assert_func_ir_valid
from mypyc.ir.func_ir import FuncIR
from mypyc.errors import Errors
from mypyc.irbuild.main import build_ir
Expand Down Expand Up @@ -118,6 +119,8 @@ def build_ir_for_single_file(input_lines: List[str],
raise CompileError(errors.new_messages())

module = list(modules.values())[0]
for fn in module.functions:
assert_func_ir_valid(fn)
return module.functions


Expand Down

0 comments on commit ad7e353

Please sign in to comment.