Skip to content

Commit

Permalink
[mypyc] Optimize const int regs during pretty IR printing (#9181)
Browse files Browse the repository at this point in the history
Follow up of #9158, makes IR less verbose.
  • Loading branch information
TH3CHARLie authored Jul 22, 2020
1 parent a5455bd commit c2e20e9
Show file tree
Hide file tree
Showing 20 changed files with 2,441 additions and 3,178 deletions.
2 changes: 1 addition & 1 deletion mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD
from mypyc.ir.class_ir import ClassIR
from mypyc.analysis.const_int import find_constant_integer_registers
from mypyc.ir.const_int import find_constant_integer_registers

# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set.
Expand Down
File renamed without changes.
21 changes: 16 additions & 5 deletions mypyc/ir/func_ir.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Intermediate representation of functions."""
import re

from typing import List, Optional, Sequence, Dict
from typing_extensions import Final
Expand All @@ -10,6 +11,7 @@
DeserMaps, Goto, Branch, Return, Unreachable, BasicBlock, Environment
)
from mypyc.ir.rtypes import RType, deserialize_type
from mypyc.ir.const_int import find_constant_integer_registers
from mypyc.namegen import NameGenerator


Expand Down Expand Up @@ -218,7 +220,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> 'FuncIR':
INVALID_FUNC_DEF = FuncDef('<INVALID_FUNC_DEF>', [], Block([])) # type: Final


def format_blocks(blocks: List[BasicBlock], env: Environment) -> List[str]:
def format_blocks(blocks: List[BasicBlock],
env: Environment,
const_regs: Dict[str, int]) -> List[str]:
"""Format a list of IR basic blocks into a human-readable form."""
# First label all of the blocks
for i, block in enumerate(blocks):
Expand All @@ -244,9 +248,14 @@ def format_blocks(blocks: List[BasicBlock], env: Environment) -> List[str]:
and ops[-1].label == blocks[i + 1]):
# Hide the last goto if it just goes to the next basic block.
ops = ops[:-1]
# load int registers start with 'i'
regex = re.compile(r'\bi[0-9]+\b')
for op in ops:
line = ' ' + op.to_str(env)
lines.append(line)
if op.name not in const_regs:
line = ' ' + op.to_str(env)
line = regex.sub(lambda i: str(const_regs[i.group()]) if i.group() in const_regs
else i.group(), line)
lines.append(line)

if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
# Each basic block needs to exit somewhere.
Expand All @@ -259,8 +268,10 @@ def format_func(fn: FuncIR) -> List[str]:
cls_prefix = fn.class_name + '.' if fn.class_name else ''
lines.append('def {}{}({}):'.format(cls_prefix, fn.name,
', '.join(arg.name for arg in fn.args)))
for line in fn.env.to_lines():
# compute constants
const_regs = find_constant_integer_registers(fn.blocks)
for line in fn.env.to_lines(const_regs):
lines.append(' ' + line)
code = format_blocks(fn.blocks, fn.env)
code = format_blocks(fn.blocks, fn.env, const_regs)
lines.extend(code)
return lines
20 changes: 16 additions & 4 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(self, name: Optional[str] = None) -> None:
self.indexes = OrderedDict() # type: Dict[Value, int]
self.symtable = OrderedDict() # type: OrderedDict[SymbolNode, AssignmentTarget]
self.temp_index = 0
self.temp_load_int_idx = 0
# All names genereted; value is the number of duplicates seen.
self.names = {} # type: Dict[str, int]
self.vars_needing_init = set() # type: Set[Value]
Expand Down Expand Up @@ -198,6 +199,10 @@ def add_op(self, reg: 'RegisterOp') -> None:
"""Record the value of an operation."""
if reg.is_void:
return
if isinstance(reg, LoadInt):
self.add(reg, "i%d" % self.temp_load_int_idx)
self.temp_load_int_idx += 1
return
self.add(reg, 'r%d' % self.temp_index)
self.temp_index += 1

Expand Down Expand Up @@ -232,17 +237,24 @@ def format(self, fmt: str, *args: Any) -> str:
i = n
return ''.join(result)

def to_lines(self) -> List[str]:
def to_lines(self, const_regs: Optional[Dict[str, int]] = None) -> List[str]:
result = []
i = 0
regs = list(self.regs())

if const_regs is None:
const_regs = {}
while i < len(regs):
i0 = i
group = [regs[i0].name]
if regs[i0].name not in const_regs:
group = [regs[i0].name]
else:
group = []
i += 1
continue
while i + 1 < len(regs) and regs[i + 1].type == regs[i0].type:
i += 1
group.append(regs[i].name)
if regs[i].name not in const_regs:
group.append(regs[i].name)
i += 1
result.append('%s :: %s' % (', '.join(group), regs[i0].type))
return result
Expand Down
Loading

0 comments on commit c2e20e9

Please sign in to comment.