diff --git a/tests/unit/compiler/venom/test_convert_basicblock_simple.py b/tests/unit/compiler/venom/test_convert_basicblock_simple.py index fdaa341a81..406ce5f6ff 100644 --- a/tests/unit/compiler/venom/test_convert_basicblock_simple.py +++ b/tests/unit/compiler/venom/test_convert_basicblock_simple.py @@ -8,7 +8,9 @@ def test_simple(): venom = ir_node_to_venom(ir_node) assert venom is not None - bb = venom.basic_blocks[0] + fn = list(venom.functions.values())[0] + + bb = fn.entry assert bb.instructions[0].opcode == "calldatasize" assert bb.instructions[1].opcode == "calldatacopy" diff --git a/tests/unit/compiler/venom/test_dominator_tree.py b/tests/unit/compiler/venom/test_dominator_tree.py index c5b7404b58..29f86df221 100644 --- a/tests/unit/compiler/venom/test_dominator_tree.py +++ b/tests/unit/compiler/venom/test_dominator_tree.py @@ -2,18 +2,19 @@ from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet -from vyper.venom.analysis import calculate_cfg +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.analysis.dominators import DominatorTreeAnalysis from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRLiteral, IRVariable -from vyper.venom.dominators import DominatorTree +from vyper.venom.context import IRContext from vyper.venom.function import IRFunction from vyper.venom.passes.make_ssa import MakeSSA def _add_bb( - ctx: IRFunction, label: IRLabel, cfg_outs: [IRLabel], bb: Optional[IRBasicBlock] = None + fn: IRFunction, label: IRLabel, cfg_outs: list[IRLabel], bb: Optional[IRBasicBlock] = None ) -> IRBasicBlock: - bb = bb if bb is not None else IRBasicBlock(label, ctx) - ctx.append_basic_block(bb) + bb = bb if bb is not None else IRBasicBlock(label, fn) + fn.append_basic_block(bb) cfg_outs_len = len(cfg_outs) if cfg_outs_len == 0: bb.append_instruction("stop") @@ -29,27 +30,27 @@ def _add_bb( def _make_test_ctx(): lab = [IRLabel(str(i)) for i in range(0, 9)] - ctx = IRFunction(lab[1]) + ctx = IRContext() + fn = ctx.create_function(lab[1].value) - bb1 = ctx.basic_blocks[0] - bb1.append_instruction("jmp", lab[2]) + fn.entry.append_instruction("jmp", lab[2]) - _add_bb(ctx, lab[7], []) - _add_bb(ctx, lab[6], [lab[7], lab[2]]) - _add_bb(ctx, lab[5], [lab[6], lab[3]]) - _add_bb(ctx, lab[4], [lab[6]]) - _add_bb(ctx, lab[3], [lab[5]]) - _add_bb(ctx, lab[2], [lab[3], lab[4]]) + _add_bb(fn, lab[7], []) + _add_bb(fn, lab[6], [lab[7], lab[2]]) + _add_bb(fn, lab[5], [lab[6], lab[3]]) + _add_bb(fn, lab[4], [lab[6]]) + _add_bb(fn, lab[3], [lab[5]]) + _add_bb(fn, lab[2], [lab[3], lab[4]]) - return ctx + return fn def test_deminator_frontier_calculation(): - ctx = _make_test_ctx() - bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + fn = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [fn.get_basic_block(str(i)) for i in range(1, 8)] - calculate_cfg(ctx) - dom = DominatorTree.build_dominator_tree(ctx, bb1) + ac = IRAnalysesCache(fn) + dom = ac.request_analysis(DominatorTreeAnalysis) df = dom.dominator_frontiers assert len(df[bb1]) == 0, df[bb1] @@ -62,12 +63,13 @@ def test_deminator_frontier_calculation(): def test_phi_placement(): - ctx = _make_test_ctx() - bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [ctx.get_basic_block(str(i)) for i in range(1, 8)] + fn = _make_test_ctx() + bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [fn.get_basic_block(str(i)) for i in range(1, 8)] x = IRVariable("%x") bb1.insert_instruction(IRInstruction("mload", [IRLiteral(0)], x), 0) bb2.insert_instruction(IRInstruction("add", [x, IRLiteral(1)], x), 0) bb7.insert_instruction(IRInstruction("mstore", [x, IRLiteral(0)]), 0) - MakeSSA().run_pass(ctx, bb1) + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() diff --git a/tests/unit/compiler/venom/test_duplicate_operands.py b/tests/unit/compiler/venom/test_duplicate_operands.py index 7cc58e6f5c..44c4ed0404 100644 --- a/tests/unit/compiler/venom/test_duplicate_operands.py +++ b/tests/unit/compiler/venom/test_duplicate_operands.py @@ -1,6 +1,6 @@ from vyper.compiler.settings import OptimizationLevel from vyper.venom import generate_assembly_experimental -from vyper.venom.function import IRFunction +from vyper.venom.context import IRContext def test_duplicate_operands(): @@ -15,8 +15,9 @@ def test_duplicate_operands(): Should compile to: [PUSH1, 10, DUP1, DUP1, DUP1, ADD, MUL, STOP] """ - ctx = IRFunction() - bb = ctx.get_basic_block() + ctx = IRContext() + fn = ctx.create_function("test") + bb = fn.get_basic_block() op = bb.append_instruction("store", 10) sum_ = bb.append_instruction("add", op, op) bb.append_instruction("mul", sum_, op) diff --git a/tests/unit/compiler/venom/test_make_ssa.py b/tests/unit/compiler/venom/test_make_ssa.py index da3a143b30..9cea1a20a4 100644 --- a/tests/unit/compiler/venom/test_make_ssa.py +++ b/tests/unit/compiler/venom/test_make_ssa.py @@ -1,22 +1,23 @@ -from vyper.venom.analysis import calculate_cfg, calculate_liveness +from vyper.venom.analysis.analysis import IRAnalysesCache from vyper.venom.basicblock import IRBasicBlock, IRLabel -from vyper.venom.function import IRFunction +from vyper.venom.context import IRContext from vyper.venom.passes.make_ssa import MakeSSA def test_phi_case(): - ctx = IRFunction(IRLabel("_global")) + ctx = IRContext() + fn = ctx.create_function("_global") - bb = ctx.get_basic_block() + bb = fn.get_basic_block() - bb_cont = IRBasicBlock(IRLabel("condition"), ctx) - bb_then = IRBasicBlock(IRLabel("then"), ctx) - bb_else = IRBasicBlock(IRLabel("else"), ctx) - bb_if_exit = IRBasicBlock(IRLabel("if_exit"), ctx) - ctx.append_basic_block(bb_cont) - ctx.append_basic_block(bb_then) - ctx.append_basic_block(bb_else) - ctx.append_basic_block(bb_if_exit) + bb_cont = IRBasicBlock(IRLabel("condition"), fn) + bb_then = IRBasicBlock(IRLabel("then"), fn) + bb_else = IRBasicBlock(IRLabel("else"), fn) + bb_if_exit = IRBasicBlock(IRLabel("if_exit"), fn) + fn.append_basic_block(bb_cont) + fn.append_basic_block(bb_then) + fn.append_basic_block(bb_else) + fn.append_basic_block(bb_if_exit) v = bb.append_instruction("mload", 64) bb_cont.append_instruction("jnz", v, bb_then.label, bb_else.label) @@ -30,11 +31,10 @@ def test_phi_case(): bb.append_instruction("jmp", bb_cont.label) - calculate_cfg(ctx) - MakeSSA().run_pass(ctx, ctx.basic_blocks[0]) - calculate_liveness(ctx) + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() - condition_block = ctx.get_basic_block("condition") + condition_block = fn.get_basic_block("condition") assert len(condition_block.instructions) == 2 phi_inst = condition_block.instructions[0] diff --git a/tests/unit/compiler/venom/test_multi_entry_block.py b/tests/unit/compiler/venom/test_multi_entry_block.py index 95b6e62daf..313fbb3ebd 100644 --- a/tests/unit/compiler/venom/test_multi_entry_block.py +++ b/tests/unit/compiler/venom/test_multi_entry_block.py @@ -1,44 +1,48 @@ -from vyper.venom.analysis import calculate_cfg -from vyper.venom.function import IRBasicBlock, IRFunction, IRLabel +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.context import IRContext +from vyper.venom.function import IRBasicBlock, IRLabel from vyper.venom.passes.normalization import NormalizationPass def test_multi_entry_block_1(): - ctx = IRFunction() + ctx = IRContext() + fn = ctx.create_function("__global") finish_label = IRLabel("finish") target_label = IRLabel("target") - block_1_label = IRLabel("block_1", ctx) + block_1_label = IRLabel("block_1", fn) - bb = ctx.get_basic_block() + bb = fn.get_basic_block() op = bb.append_instruction("store", 10) acc = bb.append_instruction("add", op, op) bb.append_instruction("jnz", acc, finish_label, block_1_label) - block_1 = IRBasicBlock(block_1_label, ctx) - ctx.append_basic_block(block_1) + block_1 = IRBasicBlock(block_1_label, fn) + fn.append_basic_block(block_1) acc = block_1.append_instruction("add", acc, op) op = block_1.append_instruction("store", 10) block_1.append_instruction("mstore", acc, op) block_1.append_instruction("jnz", acc, finish_label, target_label) - target_bb = IRBasicBlock(target_label, ctx) - ctx.append_basic_block(target_bb) + target_bb = IRBasicBlock(target_label, fn) + fn.append_basic_block(target_bb) target_bb.append_instruction("mul", acc, acc) target_bb.append_instruction("jmp", finish_label) - finish_bb = IRBasicBlock(finish_label, ctx) - ctx.append_basic_block(finish_bb) + finish_bb = IRBasicBlock(finish_label, fn) + fn.append_basic_block(finish_bb) finish_bb.append_instruction("stop") - calculate_cfg(ctx) - assert not ctx.normalized, "CFG should not be normalized" + ac = IRAnalysesCache(fn) + ac.request_analysis(CFGAnalysis) + assert not fn.normalized, "CFG should not be normalized" - NormalizationPass().run_pass(ctx) + NormalizationPass(ac, fn).run_pass() - assert ctx.normalized, "CFG should be normalized" + assert fn.normalized, "CFG should be normalized" - finish_bb = ctx.get_basic_block(finish_label.value) + finish_bb = fn.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" @@ -47,50 +51,52 @@ def test_multi_entry_block_1(): # more complicated one def test_multi_entry_block_2(): - ctx = IRFunction() + ctx = IRContext() + fn = ctx.create_function("__global") finish_label = IRLabel("finish") target_label = IRLabel("target") - block_1_label = IRLabel("block_1", ctx) - block_2_label = IRLabel("block_2", ctx) + block_1_label = IRLabel("block_1", fn) + block_2_label = IRLabel("block_2", fn) - bb = ctx.get_basic_block() + bb = fn.get_basic_block() op = bb.append_instruction("store", 10) acc = bb.append_instruction("add", op, op) bb.append_instruction("jnz", acc, finish_label, block_1_label) - block_1 = IRBasicBlock(block_1_label, ctx) - ctx.append_basic_block(block_1) + block_1 = IRBasicBlock(block_1_label, fn) + fn.append_basic_block(block_1) acc = block_1.append_instruction("add", acc, op) op = block_1.append_instruction("store", 10) block_1.append_instruction("mstore", acc, op) block_1.append_instruction("jnz", acc, target_label, finish_label) - block_2 = IRBasicBlock(block_2_label, ctx) - ctx.append_basic_block(block_2) + block_2 = IRBasicBlock(block_2_label, fn) + fn.append_basic_block(block_2) acc = block_2.append_instruction("add", acc, op) op = block_2.append_instruction("store", 10) block_2.append_instruction("mstore", acc, op) # switch the order of the labels, for fun and profit block_2.append_instruction("jnz", acc, finish_label, target_label) - target_bb = IRBasicBlock(target_label, ctx) - ctx.append_basic_block(target_bb) + target_bb = IRBasicBlock(target_label, fn) + fn.append_basic_block(target_bb) target_bb.append_instruction("mul", acc, acc) target_bb.append_instruction("jmp", finish_label) - finish_bb = IRBasicBlock(finish_label, ctx) - ctx.append_basic_block(finish_bb) + finish_bb = IRBasicBlock(finish_label, fn) + fn.append_basic_block(finish_bb) finish_bb.append_instruction("stop") - calculate_cfg(ctx) - assert not ctx.normalized, "CFG should not be normalized" + ac = IRAnalysesCache(fn) + ac.request_analysis(CFGAnalysis) + assert not fn.normalized, "CFG should not be normalized" - NormalizationPass().run_pass(ctx) + NormalizationPass(ac, fn).run_pass() - assert ctx.normalized, "CFG should be normalized" + assert fn.normalized, "CFG should be normalized" - finish_bb = ctx.get_basic_block(finish_label.value) + finish_bb = fn.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" @@ -98,40 +104,42 @@ def test_multi_entry_block_2(): def test_multi_entry_block_with_dynamic_jump(): - ctx = IRFunction() + ctx = IRContext() + fn = ctx.create_function("__global") finish_label = IRLabel("finish") target_label = IRLabel("target") - block_1_label = IRLabel("block_1", ctx) + block_1_label = IRLabel("block_1", fn) - bb = ctx.get_basic_block() + bb = fn.get_basic_block() op = bb.append_instruction("store", 10) acc = bb.append_instruction("add", op, op) bb.append_instruction("djmp", acc, finish_label, block_1_label) - block_1 = IRBasicBlock(block_1_label, ctx) - ctx.append_basic_block(block_1) + block_1 = IRBasicBlock(block_1_label, fn) + fn.append_basic_block(block_1) acc = block_1.append_instruction("add", acc, op) op = block_1.append_instruction("store", 10) block_1.append_instruction("mstore", acc, op) block_1.append_instruction("jnz", acc, finish_label, target_label) - target_bb = IRBasicBlock(target_label, ctx) - ctx.append_basic_block(target_bb) + target_bb = IRBasicBlock(target_label, fn) + fn.append_basic_block(target_bb) target_bb.append_instruction("mul", acc, acc) target_bb.append_instruction("jmp", finish_label) - finish_bb = IRBasicBlock(finish_label, ctx) - ctx.append_basic_block(finish_bb) + finish_bb = IRBasicBlock(finish_label, fn) + fn.append_basic_block(finish_bb) finish_bb.append_instruction("stop") - calculate_cfg(ctx) - assert not ctx.normalized, "CFG should not be normalized" + ac = IRAnalysesCache(fn) + ac.request_analysis(CFGAnalysis) + assert not fn.normalized, "CFG should not be normalized" - NormalizationPass().run_pass(ctx) - assert ctx.normalized, "CFG should be normalized" + NormalizationPass(ac, fn).run_pass() + assert fn.normalized, "CFG should be normalized" - finish_bb = ctx.get_basic_block(finish_label.value) + finish_bb = fn.get_basic_block(finish_label.value) cfg_in = list(finish_bb.cfg_in) assert cfg_in[0].label.value == "target", "Should contain target" assert cfg_in[1].label.value == "__global_split_finish", "Should contain __global_split_finish" diff --git a/tests/unit/compiler/venom/test_sccp.py b/tests/unit/compiler/venom/test_sccp.py index 8102a0d89c..37a8bc9000 100644 --- a/tests/unit/compiler/venom/test_sccp.py +++ b/tests/unit/compiler/venom/test_sccp.py @@ -1,24 +1,26 @@ +from vyper.venom.analysis.analysis import IRAnalysesCache from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable -from vyper.venom.function import IRFunction +from vyper.venom.context import IRContext from vyper.venom.passes.make_ssa import MakeSSA from vyper.venom.passes.sccp import SCCP from vyper.venom.passes.sccp.sccp import LatticeEnum def test_simple_case(): - ctx = IRFunction(IRLabel("_global")) + ctx = IRContext() + fn = ctx.create_function("_global") - bb = ctx.get_basic_block() + bb = fn.get_basic_block() p1 = bb.append_instruction("param") op1 = bb.append_instruction("store", 32) op2 = bb.append_instruction("store", 64) op3 = bb.append_instruction("add", op1, op2) bb.append_instruction("return", p1, op3) - make_ssa_pass = MakeSSA() - make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) - sccp = SCCP(make_ssa_pass.dom) - sccp.run_pass(ctx, ctx.basic_blocks[0]) + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + sccp = SCCP(ac, fn) + sccp.run_pass() assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM assert sccp.lattice[IRVariable("%2")].value == 32 @@ -27,14 +29,15 @@ def test_simple_case(): def test_cont_jump_case(): - ctx = IRFunction(IRLabel("_global")) + ctx = IRContext() + fn = ctx.create_function("_global") - bb = ctx.get_basic_block() + bb = fn.get_basic_block() - br1 = IRBasicBlock(IRLabel("then"), ctx) - ctx.append_basic_block(br1) - br2 = IRBasicBlock(IRLabel("else"), ctx) - ctx.append_basic_block(br2) + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), fn) + fn.append_basic_block(br2) p1 = bb.append_instruction("param") op1 = bb.append_instruction("store", 32) @@ -47,10 +50,10 @@ def test_cont_jump_case(): br2.append_instruction("add", op3, p1) br2.append_instruction("stop") - make_ssa_pass = MakeSSA() - make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) - sccp = SCCP(make_ssa_pass.dom) - sccp.run_pass(ctx, ctx.basic_blocks[0]) + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + sccp = SCCP(ac, fn) + sccp.run_pass() assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM assert sccp.lattice[IRVariable("%2")].value == 32 @@ -61,16 +64,17 @@ def test_cont_jump_case(): def test_cont_phi_case(): - ctx = IRFunction(IRLabel("_global")) + ctx = IRContext() + fn = ctx.create_function("_global") - bb = ctx.get_basic_block() + bb = fn.get_basic_block() - br1 = IRBasicBlock(IRLabel("then"), ctx) - ctx.append_basic_block(br1) - br2 = IRBasicBlock(IRLabel("else"), ctx) - ctx.append_basic_block(br2) - join = IRBasicBlock(IRLabel("join"), ctx) - ctx.append_basic_block(join) + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), fn) + fn.append_basic_block(br2) + join = IRBasicBlock(IRLabel("join"), fn) + fn.append_basic_block(join) p1 = bb.append_instruction("param") op1 = bb.append_instruction("store", 32) @@ -85,11 +89,10 @@ def test_cont_phi_case(): join.append_instruction("return", op4, p1) - make_ssa_pass = MakeSSA() - make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) - - sccp = SCCP(make_ssa_pass.dom) - sccp.run_pass(ctx, ctx.basic_blocks[0]) + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + sccp = SCCP(ac, fn) + sccp.run_pass() assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM assert sccp.lattice[IRVariable("%2")].value == 32 @@ -101,16 +104,17 @@ def test_cont_phi_case(): def test_cont_phi_const_case(): - ctx = IRFunction(IRLabel("_global")) + ctx = IRContext() + fn = ctx.create_function("_global") - bb = ctx.get_basic_block() + bb = fn.get_basic_block() - br1 = IRBasicBlock(IRLabel("then"), ctx) - ctx.append_basic_block(br1) - br2 = IRBasicBlock(IRLabel("else"), ctx) - ctx.append_basic_block(br2) - join = IRBasicBlock(IRLabel("join"), ctx) - ctx.append_basic_block(join) + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), fn) + fn.append_basic_block(br2) + join = IRBasicBlock(IRLabel("join"), fn) + fn.append_basic_block(join) p1 = bb.append_instruction("store", 1) op1 = bb.append_instruction("store", 32) @@ -125,10 +129,10 @@ def test_cont_phi_const_case(): join.append_instruction("return", op4, p1) - make_ssa_pass = MakeSSA() - make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) - sccp = SCCP(make_ssa_pass.dom) - sccp.run_pass(ctx, ctx.basic_blocks[0]) + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + sccp = SCCP(ac, fn) + sccp.run_pass() assert sccp.lattice[IRVariable("%1")].value == 1 assert sccp.lattice[IRVariable("%2")].value == 32 diff --git a/vyper/utils.py b/vyper/utils.py index 01ae37e213..600f5552ab 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -578,25 +578,3 @@ def annotate_source_code( cleanup_lines += [""] * (num_lines - len(cleanup_lines)) return "\n".join(cleanup_lines) - - -def ir_pass(func): - """ - Decorator for IR passes. This decorator will run the pass repeatedly until - no more changes are made. - """ - - def wrapper(*args, **kwargs): - count = 0 - - while True: - changes = func(*args, **kwargs) or 0 - if isinstance(changes, list) or isinstance(changes, set): - changes = len(changes) - count += changes - if changes == 0: - break - - return count - - return wrapper diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index a60f679a76..4e13a220ef 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -5,17 +5,15 @@ from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness -from vyper.venom.bb_optimizer import ( - ir_pass_optimize_empty_blocks, - ir_pass_optimize_unused_variables, - ir_pass_remove_unreachable_blocks, -) +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.analysis.liveness import LivenessAnalysis +from vyper.venom.context import IRContext from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom from vyper.venom.passes.dft import DFTPass from vyper.venom.passes.make_ssa import MakeSSA from vyper.venom.passes.mem2var import Mem2Var +from vyper.venom.passes.remove_unused_variables import RemoveUnusedVariablesPass from vyper.venom.passes.sccp import SCCP from vyper.venom.passes.simplify_cfg import SimplifyCFGPass from vyper.venom.venom_to_assembly import VenomCompiler @@ -24,8 +22,8 @@ def generate_assembly_experimental( - runtime_code: IRFunction, - deploy_code: Optional[IRFunction] = None, + runtime_code: IRContext, + deploy_code: Optional[IRContext] = None, optimize: OptimizationLevel = DEFAULT_OPT_LEVEL, ) -> list[str]: # note: VenomCompiler is sensitive to the order of these! @@ -38,73 +36,26 @@ def generate_assembly_experimental( return compiler.generate_evm(optimize == OptimizationLevel.NONE) -def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: +def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None: # Run passes on Venom IR # TODO: Add support for optimization levels - ir_pass_optimize_empty_blocks(ctx) - ir_pass_remove_unreachable_blocks(ctx) + ac = IRAnalysesCache(fn) - internals = [ - bb - for bb in ctx.basic_blocks - if bb.label.value.startswith("internal") and len(bb.cfg_in) == 0 - ] + SimplifyCFGPass(ac, fn).run_pass() + Mem2Var(ac, fn).run_pass() + MakeSSA(ac, fn).run_pass() + SCCP(ac, fn).run_pass() - SimplifyCFGPass().run_pass(ctx, ctx.basic_blocks[0]) - for entry in internals: - SimplifyCFGPass().run_pass(ctx, entry) + SimplifyCFGPass(ac, fn).run_pass() + RemoveUnusedVariablesPass(ac, fn).run_pass() + DFTPass(ac, fn).run_pass() - dfg = DFG.build_dfg(ctx) - Mem2Var().run_pass(ctx, ctx.basic_blocks[0], dfg) - for entry in internals: - Mem2Var().run_pass(ctx, entry, dfg) - make_ssa_pass = MakeSSA() - make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) - - cfg_dirty = False - sccp_pass = SCCP(make_ssa_pass.dom) - sccp_pass.run_pass(ctx, ctx.basic_blocks[0]) - cfg_dirty |= sccp_pass.cfg_dirty - - for entry in internals: - make_ssa_pass.run_pass(ctx, entry) - sccp_pass = SCCP(make_ssa_pass.dom) - sccp_pass.run_pass(ctx, entry) - cfg_dirty |= sccp_pass.cfg_dirty - - calculate_cfg(ctx) - SimplifyCFGPass().run_pass(ctx, ctx.basic_blocks[0]) - - calculate_cfg(ctx) - calculate_liveness(ctx) - - while True: - changes = 0 - - changes += ir_pass_optimize_empty_blocks(ctx) - changes += ir_pass_remove_unreachable_blocks(ctx) - - calculate_liveness(ctx) - - changes += ir_pass_optimize_unused_variables(ctx) - - calculate_cfg(ctx) - calculate_liveness(ctx) - - changes += DFTPass().run_pass(ctx) - - calculate_cfg(ctx) - calculate_liveness(ctx) - - if changes == 0: - break - - -def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> IRFunction: +def generate_ir(ir: IRnode, optimize: OptimizationLevel) -> IRContext: # Convert "old" IR to "new" IR ctx = ir_node_to_venom(ir) - _run_passes(ctx, optimize) + for fn in ctx.functions.values(): + _run_passes(fn, optimize) return ctx diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py deleted file mode 100644 index 8e4c24fea3..0000000000 --- a/vyper/venom/analysis.py +++ /dev/null @@ -1,198 +0,0 @@ -from typing import Optional - -from vyper.exceptions import CompilerPanic -from vyper.utils import OrderedSet -from vyper.venom.basicblock import ( - BB_TERMINATORS, - CFG_ALTERING_INSTRUCTIONS, - IRBasicBlock, - IRInstruction, - IRVariable, -) -from vyper.venom.function import IRFunction - - -def calculate_cfg(ctx: IRFunction) -> None: - """ - Calculate (cfg) inputs for each basic block. - """ - for bb in ctx.basic_blocks: - bb.cfg_in = OrderedSet() - bb.cfg_out = OrderedSet() - bb.out_vars = OrderedSet() - - for bb in ctx.basic_blocks: - assert len(bb.instructions) > 0, "Basic block should not be empty" - last_inst = bb.instructions[-1] - assert last_inst.opcode in BB_TERMINATORS, f"Last instruction should be a terminator {bb}" - - for inst in bb.instructions: - if inst.opcode in CFG_ALTERING_INSTRUCTIONS: - ops = inst.get_label_operands() - for op in ops: - ctx.get_basic_block(op.value).add_cfg_in(bb) - - # Fill in the "out" set for each basic block - for bb in ctx.basic_blocks: - for in_bb in bb.cfg_in: - in_bb.add_cfg_out(bb) - - -def _reset_liveness(ctx: IRFunction) -> None: - for bb in ctx.basic_blocks: - bb.out_vars = OrderedSet() - for inst in bb.instructions: - inst.liveness = OrderedSet() - - -def _calculate_liveness(bb: IRBasicBlock) -> bool: - """ - Compute liveness of each instruction in the basic block. - Returns True if liveness changed - """ - orig_liveness = bb.instructions[0].liveness.copy() - liveness = bb.out_vars.copy() - for instruction in reversed(bb.instructions): - ins = instruction.get_inputs() - outs = instruction.get_outputs() - - if ins or outs: - # perf: only copy if changed - liveness = liveness.copy() - liveness.update(ins) - liveness.dropmany(outs) - - instruction.liveness = liveness - - return orig_liveness != bb.instructions[0].liveness - - -def _calculate_out_vars(bb: IRBasicBlock) -> bool: - """ - Compute out_vars of basic block. - Returns True if out_vars changed - """ - out_vars = bb.out_vars.copy() - for out_bb in bb.cfg_out: - target_vars = input_vars_from(bb, out_bb) - bb.out_vars = bb.out_vars.union(target_vars) - return out_vars != bb.out_vars - - -def calculate_liveness(ctx: IRFunction) -> None: - _reset_liveness(ctx) - while True: - changed = False - for bb in ctx.basic_blocks: - changed |= _calculate_out_vars(bb) - changed |= _calculate_liveness(bb) - - if not changed: - break - - -def calculate_dup_requirements(ctx: IRFunction) -> None: - for bb in ctx.basic_blocks: - last_liveness = bb.out_vars - for inst in reversed(bb.instructions): - inst.dup_requirements = OrderedSet() - ops = inst.get_inputs() - for op in ops: - if op in last_liveness: - inst.dup_requirements.add(op) - last_liveness = inst.liveness - - -# calculate the input variables into self from source -def input_vars_from(source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: - liveness = target.instructions[0].liveness.copy() - assert isinstance(liveness, OrderedSet) - - for inst in target.instructions: - if inst.opcode == "phi": - # we arbitrarily choose one of the arguments to be in the - # live variables set (dependent on how we traversed into this - # basic block). the argument will be replaced by the destination - # operand during instruction selection. - # for instance, `%56 = phi %label1 %12 %label2 %14` - # will arbitrarily choose either %12 or %14 to be in the liveness - # set, and then during instruction selection, after this instruction, - # %12 will be replaced by %56 in the liveness set - - # bad path into this phi node - if source.label not in inst.operands: - raise CompilerPanic(f"unreachable: {inst} from {source.label}") - - for label, var in inst.phi_operands: - if label == source.label: - liveness.add(var) - else: - if var in liveness: - liveness.remove(var) - - return liveness - - -# DataFlow Graph -# this could be refactored into its own file, but it's only used here -# for now -class DFG: - _dfg_inputs: dict[IRVariable, list[IRInstruction]] - _dfg_outputs: dict[IRVariable, IRInstruction] - - def __init__(self): - self._dfg_inputs = dict() - self._dfg_outputs = dict() - - # return uses of a given variable - def get_uses(self, op: IRVariable) -> list[IRInstruction]: - return self._dfg_inputs.get(op, []) - - # the instruction which produces this variable. - def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: - return self._dfg_outputs.get(op) - - @property - def outputs(self) -> dict[IRVariable, IRInstruction]: - return self._dfg_outputs - - @classmethod - def build_dfg(cls, ctx: IRFunction) -> "DFG": - dfg = cls() - - # Build DFG - - # %15 = add %13 %14 - # %16 = iszero %15 - # dfg_outputs of %15 is (%15 = add %13 %14) - # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] - for bb in ctx.basic_blocks: - for inst in bb.instructions: - operands = inst.get_inputs() - res = inst.get_outputs() - - for op in operands: - inputs = dfg._dfg_inputs.setdefault(op, []) - inputs.append(inst) - - for op in res: # type: ignore - dfg._dfg_outputs[op] = inst - - return dfg - - def as_graph(self) -> str: - """ - Generate a graphviz representation of the dfg - """ - lines = ["digraph dfg_graph {"] - for var, inputs in self._dfg_inputs.items(): - for input in inputs: - for op in input.get_outputs(): - if isinstance(op, IRVariable): - lines.append(f' " {var.name} " -> " {op.name} "') - - lines.append("}") - return "\n".join(lines) - - def __repr__(self) -> str: - return self.as_graph() diff --git a/vyper/venom/analysis/__init__.py b/vyper/venom/analysis/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vyper/venom/analysis/analysis.py b/vyper/venom/analysis/analysis.py new file mode 100644 index 0000000000..f154993925 --- /dev/null +++ b/vyper/venom/analysis/analysis.py @@ -0,0 +1,75 @@ +from typing import Type + +from vyper.venom.function import IRFunction + + +class IRAnalysis: + """ + Base class for all Venom IR analyses. + """ + + function: "IRFunction" + analyses_cache: "IRAnalysesCache" + + def __init__(self, analyses_cache: "IRAnalysesCache", function: IRFunction): + self.analyses_cache = analyses_cache + self.function = function + + def analyze(self, *args, **kwargs): + """ + Override this method to perform the analysis. + """ + raise NotImplementedError + + def invalidate(self): + """ + Override this method to respond to an invalidation request, and possibly + invalidate any other analyses that depend on this one. + """ + pass + + +class IRAnalysesCache: + """ + A cache for IR analyses. + """ + + function: IRFunction + analyses_cache: dict[Type[IRAnalysis], IRAnalysis] + + def __init__(self, function: IRFunction): + self.analyses_cache = {} + self.function = function + + def request_analysis(self, analysis_cls: Type[IRAnalysis], *args, **kwargs): + """ + Request a specific analysis to be run on the IR. The result is cached and + returned if the analysis has already been run. + """ + assert issubclass(analysis_cls, IRAnalysis), f"{analysis_cls} is not an IRAnalysis" + if analysis_cls in self.analyses_cache: + return self.analyses_cache[analysis_cls] + analysis = analysis_cls(self, self.function) + analysis.analyze(*args, **kwargs) + + self.analyses_cache[analysis_cls] = analysis + return analysis + + def invalidate_analysis(self, analysis_cls: Type[IRAnalysis]): + """ + Invalidate a specific analysis. This will remove the analysis from the cache. + """ + assert issubclass(analysis_cls, IRAnalysis), f"{analysis_cls} is not an IRAnalysis" + analysis = self.analyses_cache.pop(analysis_cls, None) + if analysis is not None: + analysis.invalidate() + + def force_analysis(self, analysis_cls: Type[IRAnalysis], *args, **kwargs): + """ + Force a specific analysis to be run on the IR even if it has already been run, + and is cached. + """ + assert issubclass(analysis_cls, IRAnalysis), f"{analysis_cls} is not an IRAnalysis" + if analysis_cls in self.analyses_cache: + self.invalidate_analysis(analysis_cls) + return self.request_analysis(analysis_cls, *args, **kwargs) diff --git a/vyper/venom/analysis/cfg.py b/vyper/venom/analysis/cfg.py new file mode 100644 index 0000000000..2a521ab131 --- /dev/null +++ b/vyper/venom/analysis/cfg.py @@ -0,0 +1,41 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis.analysis import IRAnalysis +from vyper.venom.basicblock import BB_TERMINATORS, CFG_ALTERING_INSTRUCTIONS + + +class CFGAnalysis(IRAnalysis): + """ + Compute control flow graph information for each basic block in the function. + """ + + def analyze(self) -> None: + fn = self.function + for bb in fn.basic_blocks: + bb.cfg_in = OrderedSet() + bb.cfg_out = OrderedSet() + bb.out_vars = OrderedSet() + + for bb in fn.basic_blocks: + assert len(bb.instructions) > 0, "Basic block should not be empty" + last_inst = bb.instructions[-1] + assert ( + last_inst.opcode in BB_TERMINATORS + ), f"Last instruction should be a terminator {bb}" + + for inst in bb.instructions: + if inst.opcode in CFG_ALTERING_INSTRUCTIONS: + ops = inst.get_label_operands() + for op in ops: + fn.get_basic_block(op.value).add_cfg_in(bb) + + # Fill in the "out" set for each basic block + for bb in fn.basic_blocks: + for in_bb in bb.cfg_in: + in_bb.add_cfg_out(bb) + + def invalidate(self): + from vyper.venom.analysis.dominators import DominatorTreeAnalysis + from vyper.venom.analysis.liveness import LivenessAnalysis + + self.analyses_cache.invalidate_analysis(DominatorTreeAnalysis) + self.analyses_cache.invalidate_analysis(LivenessAnalysis) diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py new file mode 100644 index 0000000000..8b113e74bc --- /dev/null +++ b/vyper/venom/analysis/dfg.py @@ -0,0 +1,63 @@ +from typing import Optional + +from vyper.venom.analysis.analysis import IRAnalysesCache, IRAnalysis +from vyper.venom.basicblock import IRInstruction, IRVariable +from vyper.venom.function import IRFunction + + +class DFGAnalysis(IRAnalysis): + _dfg_inputs: dict[IRVariable, list[IRInstruction]] + _dfg_outputs: dict[IRVariable, IRInstruction] + + def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): + super().__init__(analyses_cache, function) + self._dfg_inputs = dict() + self._dfg_outputs = dict() + + # return uses of a given variable + def get_uses(self, op: IRVariable) -> list[IRInstruction]: + return self._dfg_inputs.get(op, []) + + # the instruction which produces this variable. + def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: + return self._dfg_outputs.get(op) + + @property + def outputs(self) -> dict[IRVariable, IRInstruction]: + return self._dfg_outputs + + def analyze(self): + # Build DFG + + # %15 = add %13 %14 + # %16 = iszero %15 + # dfg_outputs of %15 is (%15 = add %13 %14) + # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] + for bb in self.function.basic_blocks: + for inst in bb.instructions: + operands = inst.get_inputs() + res = inst.get_outputs() + + for op in operands: + inputs = self._dfg_inputs.setdefault(op, []) + inputs.append(inst) + + for op in res: # type: ignore + self._dfg_outputs[op] = inst + + def as_graph(self) -> str: + """ + Generate a graphviz representation of the dfg + """ + lines = ["digraph dfg_graph {"] + for var, inputs in self._dfg_inputs.items(): + for input in inputs: + for op in input.get_outputs(): + if isinstance(op, IRVariable): + lines.append(f' " {var.name} " -> " {op.name} "') + + lines.append("}") + return "\n".join(lines) + + def __repr__(self) -> str: + return self.as_graph() diff --git a/vyper/venom/dominators.py b/vyper/venom/analysis/dominators.py similarity index 93% rename from vyper/venom/dominators.py rename to vyper/venom/analysis/dominators.py index b69c17e1d8..c0b149d880 100644 --- a/vyper/venom/dominators.py +++ b/vyper/venom/analysis/dominators.py @@ -1,17 +1,19 @@ from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet +from vyper.venom.analysis.analysis import IRAnalysis +from vyper.venom.analysis.cfg import CFGAnalysis from vyper.venom.basicblock import IRBasicBlock from vyper.venom.function import IRFunction -class DominatorTree: +class DominatorTreeAnalysis(IRAnalysis): """ Dominator tree implementation. This class computes the dominator tree of a function and provides methods to query the tree. The tree is computed using the Lengauer-Tarjan algorithm. """ - ctx: IRFunction + fn: IRFunction entry_block: IRBasicBlock dfs_order: dict[IRBasicBlock, int] dfs_walk: list[IRBasicBlock] @@ -20,18 +22,12 @@ class DominatorTree: dominated: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] dominator_frontiers: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] - @classmethod - def build_dominator_tree(cls, ctx, entry): - ret = DominatorTree() - ret.compute(ctx, entry) - return ret - - def compute(self, ctx: IRFunction, entry: IRBasicBlock): + def analyze(self): """ Compute the dominator tree. """ - self.ctx = ctx - self.entry_block = entry + self.fn = self.function + self.entry_block = self.fn.entry self.dfs_order = {} self.dfs_walk = [] self.dominators = {} @@ -39,6 +35,8 @@ def compute(self, ctx: IRFunction, entry: IRBasicBlock): self.dominated = {} self.dominator_frontiers = {} + self.analyses_cache.request_analysis(CFGAnalysis) + self._compute_dfs(self.entry_block, OrderedSet()) self._compute_dominators() self._compute_idoms() @@ -155,7 +153,7 @@ def as_graph(self) -> str: Generate a graphviz representation of the dominator tree. """ lines = ["digraph dominator_tree {"] - for bb in self.ctx.basic_blocks: + for bb in self.fn.basic_blocks: if bb == self.entry_block: continue idom = self.immediate_dominator(bb) diff --git a/vyper/venom/analysis/dup_requirements.py b/vyper/venom/analysis/dup_requirements.py new file mode 100644 index 0000000000..015c7c5871 --- /dev/null +++ b/vyper/venom/analysis/dup_requirements.py @@ -0,0 +1,15 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis.analysis import IRAnalysis + + +class DupRequirementsAnalysis(IRAnalysis): + def analyze(self): + for bb in self.function.basic_blocks: + last_liveness = bb.out_vars + for inst in reversed(bb.instructions): + inst.dup_requirements = OrderedSet() + ops = inst.get_inputs() + for op in ops: + if op in last_liveness: + inst.dup_requirements.add(op) + last_liveness = inst.liveness diff --git a/vyper/venom/analysis/liveness.py b/vyper/venom/analysis/liveness.py new file mode 100644 index 0000000000..95853e57aa --- /dev/null +++ b/vyper/venom/analysis/liveness.py @@ -0,0 +1,90 @@ +from vyper.exceptions import CompilerPanic +from vyper.utils import OrderedSet +from vyper.venom.analysis.analysis import IRAnalysis +from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRVariable + + +class LivenessAnalysis(IRAnalysis): + """ + Compute liveness information for each instruction in the function. + """ + + def analyze(self): + self.analyses_cache.request_analysis(CFGAnalysis) + self._reset_liveness() + while True: + changed = False + for bb in self.function.basic_blocks: + changed |= self._calculate_out_vars(bb) + changed |= self._calculate_liveness(bb) + + if not changed: + break + + def _reset_liveness(self) -> None: + for bb in self.function.basic_blocks: + bb.out_vars = OrderedSet() + for inst in bb.instructions: + inst.liveness = OrderedSet() + + def _calculate_liveness(self, bb: IRBasicBlock) -> bool: + """ + Compute liveness of each instruction in the basic block. + Returns True if liveness changed + """ + orig_liveness = bb.instructions[0].liveness.copy() + liveness = bb.out_vars.copy() + for instruction in reversed(bb.instructions): + ins = instruction.get_inputs() + outs = instruction.get_outputs() + + if ins or outs: + # perf: only copy if changed + liveness = liveness.copy() + liveness.update(ins) + liveness.dropmany(outs) + + instruction.liveness = liveness + + return orig_liveness != bb.instructions[0].liveness + + def _calculate_out_vars(self, bb: IRBasicBlock) -> bool: + """ + Compute out_vars of basic block. + Returns True if out_vars changed + """ + out_vars = bb.out_vars.copy() + for out_bb in bb.cfg_out: + target_vars = self.input_vars_from(bb, out_bb) + bb.out_vars = bb.out_vars.union(target_vars) + return out_vars != bb.out_vars + + # calculate the input variables into self from source + def input_vars_from(self, source: IRBasicBlock, target: IRBasicBlock) -> OrderedSet[IRVariable]: + liveness = target.instructions[0].liveness.copy() + assert isinstance(liveness, OrderedSet) + + for inst in target.instructions: + if inst.opcode == "phi": + # we arbitrarily choose one of the arguments to be in the + # live variables set (dependent on how we traversed into this + # basic block). the argument will be replaced by the destination + # operand during instruction selection. + # for instance, `%56 = phi %label1 %12 %label2 %14` + # will arbitrarily choose either %12 or %14 to be in the liveness + # set, and then during instruction selection, after this instruction, + # %12 will be replaced by %56 in the liveness set + + # bad path into this phi node + if source.label not in inst.operands: + raise CompilerPanic(f"unreachable: {inst} from {source.label}") + + for label, var in inst.phi_operands: + if label == source.label: + liveness.add(var) + else: + if var in liveness: + liveness.remove(var) + + return liveness diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py deleted file mode 100644 index 284a1f1b9c..0000000000 --- a/vyper/venom/bb_optimizer.py +++ /dev/null @@ -1,93 +0,0 @@ -from vyper.utils import ir_pass -from vyper.venom.analysis import calculate_cfg -from vyper.venom.basicblock import IRInstruction, IRLabel -from vyper.venom.function import IRFunction - - -def _optimize_unused_variables(ctx: IRFunction) -> set[IRInstruction]: - """ - Remove unused variables. - """ - removeList = set() - for bb in ctx.basic_blocks: - for i, inst in enumerate(bb.instructions[:-1]): - if inst.volatile: - continue - next_liveness = bb.instructions[i + 1].liveness - if (inst.output and inst.output not in next_liveness) or inst.opcode == "nop": - removeList.add(inst) - - bb.instructions = [inst for inst in bb.instructions if inst not in removeList] - - return removeList - - -def _optimize_empty_basicblocks(ctx: IRFunction) -> int: - """ - Remove empty basic blocks. - """ - count = 0 - i = 0 - while i < len(ctx.basic_blocks): - bb = ctx.basic_blocks[i] - i += 1 - if len(bb.instructions) > 0: - continue - - replaced_label = bb.label - replacement_label = ctx.basic_blocks[i].label if i < len(ctx.basic_blocks) else None - if replacement_label is None: - continue - - # Try to preserve symbol labels - if replaced_label.is_symbol: - replaced_label, replacement_label = replacement_label, replaced_label - ctx.basic_blocks[i].label = replacement_label - - for bb2 in ctx.basic_blocks: - for inst in bb2.instructions: - for op in inst.operands: - if isinstance(op, IRLabel) and op.value == replaced_label.value: - op.value = replacement_label.value - - ctx.basic_blocks.remove(bb) - i -= 1 - count += 1 - - return count - - -def _daisychain_empty_basicblocks(ctx: IRFunction) -> int: - count = 0 - i = 0 - while i < len(ctx.basic_blocks): - bb = ctx.basic_blocks[i] - i += 1 - if bb.is_terminated: - continue - - if i < len(ctx.basic_blocks) - 1: - bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) - else: - bb.append_instruction("stop") - - count += 1 - - return count - - -@ir_pass -def ir_pass_optimize_empty_blocks(ctx: IRFunction) -> int: - changes = _optimize_empty_basicblocks(ctx) - calculate_cfg(ctx) - return changes - - -@ir_pass -def ir_pass_remove_unreachable_blocks(ctx: IRFunction) -> int: - return ctx.remove_unreachable_blocks() - - -@ir_pass -def ir_pass_optimize_unused_variables(ctx: IRFunction) -> int: - return len(_optimize_unused_variables(ctx)) diff --git a/vyper/venom/context.py b/vyper/venom/context.py new file mode 100644 index 0000000000..2e35967dfe --- /dev/null +++ b/vyper/venom/context.py @@ -0,0 +1,67 @@ +from typing import Optional + +from vyper.venom.basicblock import IRInstruction, IRLabel, IROperand +from vyper.venom.function import IRFunction + + +class IRContext: + functions: dict[IRLabel, IRFunction] + ctor_mem_size: Optional[int] + immutables_len: Optional[int] + data_segment: list[IRInstruction] + last_label: int + + def __init__(self) -> None: + self.functions = {} + self.ctor_mem_size = None + self.immutables_len = None + self.data_segment = [] + self.last_label = 0 + + def add_function(self, fn: IRFunction) -> None: + fn.ctx = self + self.functions[fn.name] = fn + + def create_function(self, name: str) -> IRFunction: + label = IRLabel(name, True) + fn = IRFunction(label, self) + self.add_function(fn) + return fn + + def get_function(self, name: IRLabel) -> IRFunction: + if name in self.functions: + return self.functions[name] + raise Exception(f"Function {name} not found in context") + + def get_next_label(self, suffix: str = "") -> IRLabel: + if suffix != "": + suffix = f"_{suffix}" + self.last_label += 1 + return IRLabel(f"{self.last_label}{suffix}") + + def chain_basic_blocks(self) -> None: + """ + Chain basic blocks together. This is necessary for the IR to be valid, and is done after + the IR is generated. + """ + for fn in self.functions.values(): + fn.chain_basic_blocks() + + def append_data(self, opcode: str, args: list[IROperand]) -> None: + """ + Append data + """ + self.data_segment.append(IRInstruction(opcode, args)) # type: ignore + + def __repr__(self) -> str: + s = ["IRContext:"] + for fn in self.functions.values(): + s.append(fn.__repr__()) + s.append("\n") + + if len(self.data_segment) > 0: + s += "\nData segment:\n" + for inst in self.data_segment: + s += f"{inst}\n" + + return "\n".join(s) diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 8756642f80..556be28246 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -2,16 +2,7 @@ from vyper.codegen.ir_node import IRnode from vyper.utils import OrderedSet -from vyper.venom.basicblock import ( - CFG_ALTERING_INSTRUCTIONS, - IRBasicBlock, - IRInstruction, - IRLabel, - IROperand, - IRVariable, -) - -GLOBAL_LABEL = IRLabel("__global") +from vyper.venom.basicblock import CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRLabel, IRVariable class IRFunction: @@ -20,12 +11,9 @@ class IRFunction: """ name: IRLabel # symbol name - entry_points: list[IRLabel] # entry points + ctx: "IRContext" # type: ignore # noqa: F821 args: list - ctor_mem_size: Optional[int] - immutables_len: Optional[int] basic_blocks: list[IRBasicBlock] - data_segment: list[IRInstruction] last_label: int last_variable: int @@ -34,37 +22,23 @@ class IRFunction: _error_msg_stack: list[str] _bb_index: dict[str, int] - def __init__(self, name: IRLabel = None) -> None: - if name is None: - name = GLOBAL_LABEL + def __init__(self, name: IRLabel, ctx: "IRContext" = None) -> None: # type: ignore # noqa: F821 + self.ctx = ctx self.name = name - self.entry_points = [] self.args = [] - self.ctor_mem_size = None - self.immutables_len = None self.basic_blocks = [] - self.data_segment = [] - self.last_label = 0 + self.last_variable = 0 self._ast_source_stack = [] self._error_msg_stack = [] self._bb_index = {} - self.add_entry_point(name) self.append_basic_block(IRBasicBlock(name, self)) - def add_entry_point(self, label: IRLabel) -> None: - """ - Add entry point. - """ - self.entry_points.append(label) - - def remove_entry_point(self, label: IRLabel) -> None: - """ - Remove entry point. - """ - self.entry_points.remove(label) + @property + def entry(self) -> IRBasicBlock: + return self.basic_blocks[0] def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: """ @@ -86,7 +60,9 @@ def _get_basicblock_index(self, label: str): # do a reindex self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks)) # sanity check - no duplicate labels - assert len(self._bb_index) == len(self.basic_blocks) + assert len(self._bb_index) == len( + self.basic_blocks + ), f"Duplicate labels in function '{self.name}' {self._bb_index} {self.basic_blocks}" return self._bb_index[label] def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: @@ -122,12 +98,6 @@ def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: """ return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] - def get_next_label(self, suffix: str = "") -> IRLabel: - if suffix != "": - suffix = f"_{suffix}" - self.last_label += 1 - return IRLabel(f"{self.last_label}{suffix}") - def get_next_variable(self) -> IRVariable: self.last_variable += 1 return IRVariable(f"%{self.last_variable}") @@ -176,9 +146,7 @@ def _compute_reachability(self) -> None: bb.reachable = OrderedSet() bb.is_reachable = False - for entry in self.entry_points: - entry_bb = self.get_basic_block(entry.value) - self._compute_reachability_from(entry_bb) + self._compute_reachability_from(self.entry) def _compute_reachability_from(self, bb: IRBasicBlock) -> None: """ @@ -188,18 +156,12 @@ def _compute_reachability_from(self, bb: IRBasicBlock) -> None: return bb.is_reachable = True for inst in bb.instructions: - if inst.opcode in CFG_ALTERING_INSTRUCTIONS or inst.opcode == "invoke": + if inst.opcode in CFG_ALTERING_INSTRUCTIONS: for op in inst.get_label_operands(): out_bb = self.get_basic_block(op.value) bb.reachable.add(out_bb) self._compute_reachability_from(out_bb) - def append_data(self, opcode: str, args: list[IROperand]) -> None: - """ - Append data - """ - self.data_segment.append(IRInstruction(opcode, args)) # type: ignore - @property def normalized(self) -> bool: """ @@ -243,10 +205,28 @@ def ast_source(self) -> Optional[IRnode]: def error_msg(self) -> Optional[str]: return self._error_msg_stack[-1] if len(self._error_msg_stack) > 0 else None + def chain_basic_blocks(self) -> None: + """ + Chain basic blocks together. If a basic block is not terminated, jump to the next one. + Otherwise, append a stop instruction. This is necessary for the IR to be valid, and is + done after the IR is generated. + """ + for i, bb in enumerate(self.basic_blocks): + if not bb.is_terminated: + if len(self.basic_blocks) - 1 > i: + # TODO: revisit this. When contructor calls internal functions they + # are linked to the last ctor block. Should separate them before this + # so we don't have to handle this here + if self.basic_blocks[i + 1].label.value.startswith("internal"): + bb.append_instruction("stop") + else: + bb.append_instruction("jmp", self.basic_blocks[i + 1].label) + else: + bb.append_instruction("exit") + def copy(self): new = IRFunction(self.name) new.basic_blocks = self.basic_blocks.copy() - new.data_segment = self.data_segment.copy() new.last_label = self.last_label new.last_variable = self.last_variable return new @@ -281,8 +261,4 @@ def __repr__(self) -> str: str = f"IRFunction: {self.name}\n" for bb in self.basic_blocks: str += f"{bb}\n" - if len(self.data_segment) > 0: - str += "Data segment:\n" - for inst in self.data_segment: - str += f"{inst}\n" return str.strip() diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index bcf27bbb0c..b4465e9f7b 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -13,6 +13,7 @@ IROperand, IRVariable, ) +from vyper.venom.context import IRContext from vyper.venom.function import IRFunction # Instructions that are mapped to their inverse @@ -107,61 +108,51 @@ SymbolTable = dict[str, Optional[IROperand]] _global_symbols: SymbolTable = {} +MAIN_ENTRY_LABEL_NAME = "__main_entry" # convert IRnode directly to venom -def ir_node_to_venom(ir: IRnode) -> IRFunction: +def ir_node_to_venom(ir: IRnode) -> IRContext: global _global_symbols _global_symbols = {} - ctx = IRFunction() - _convert_ir_bb(ctx, ir, {}) + ctx = IRContext() + fn = ctx.create_function(MAIN_ENTRY_LABEL_NAME) - # Patch up basic blocks. Connect unterminated blocks to the next with - # a jump. terminate final basic block with STOP. - for i, bb in enumerate(ctx.basic_blocks): - if not bb.is_terminated: - if len(ctx.basic_blocks) - 1 > i: - # TODO: revisit this. When contructor calls internal functions they - # are linked to the last ctor block. Should separate them before this - # so we don't have to handle this here - if ctx.basic_blocks[i + 1].label.value.startswith("internal"): - bb.append_instruction("stop") - else: - bb.append_instruction("jmp", ctx.basic_blocks[i + 1].label) - else: - bb.append_instruction("exit") + _convert_ir_bb(fn, ir, {}) + + ctx.chain_basic_blocks() return ctx -def _append_jmp(ctx: IRFunction, label: IRLabel) -> None: - bb = ctx.get_basic_block() +def _append_jmp(fn: IRFunction, label: IRLabel) -> None: + bb = fn.get_basic_block() if bb.is_terminated: - bb = IRBasicBlock(ctx.get_next_label("jmp_target"), ctx) - ctx.append_basic_block(bb) + bb = IRBasicBlock(fn.ctx.get_next_label("jmp_target"), fn) + fn.append_basic_block(bb) bb.append_instruction("jmp", label) -def _new_block(ctx: IRFunction) -> IRBasicBlock: - bb = IRBasicBlock(ctx.get_next_label(), ctx) - bb = ctx.append_basic_block(bb) +def _new_block(fn: IRFunction) -> IRBasicBlock: + bb = IRBasicBlock(fn.ctx.get_next_label(), fn) + bb = fn.append_basic_block(bb) return bb -def _append_return_args(ctx: IRFunction, ofst: int = 0, size: int = 0): - bb = ctx.get_basic_block() +def _append_return_args(fn: IRFunction, ofst: int = 0, size: int = 0): + bb = fn.get_basic_block() if bb.is_terminated: - bb = IRBasicBlock(ctx.get_next_label("exit_to"), ctx) - ctx.append_basic_block(bb) + bb = IRBasicBlock(fn.ctx.get_next_label("exit_to"), fn) + fn.append_basic_block(bb) ret_ofst = IRVariable("ret_ofst") ret_size = IRVariable("ret_size") bb.append_instruction("store", ofst, ret=ret_ofst) bb.append_instruction("store", size, ret=ret_size) -def _handle_self_call(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> Optional[IRVariable]: +def _handle_self_call(fn: IRFunction, ir: IRnode, symbols: SymbolTable) -> Optional[IRVariable]: setup_ir = ir.args[1] goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] target_label = goto_ir.args[0].value # goto @@ -169,11 +160,11 @@ def _handle_self_call(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> Opti ret_args: list[IROperand] = [IRLabel(target_label)] # type: ignore if setup_ir != goto_ir: - _convert_ir_bb(ctx, setup_ir, symbols) + _convert_ir_bb(fn, setup_ir, symbols) - return_buf = _convert_ir_bb(ctx, return_buf_ir, symbols) + return_buf = _convert_ir_bb(fn, return_buf_ir, symbols) - bb = ctx.get_basic_block() + bb = fn.get_basic_block() if len(goto_ir.args) > 2: ret_args.append(return_buf) # type: ignore @@ -183,10 +174,10 @@ def _handle_self_call(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> Opti def _handle_internal_func( - ctx: IRFunction, ir: IRnode, does_return_data: bool, symbols: SymbolTable -): - bb = IRBasicBlock(IRLabel(ir.args[0].args[0].value, True), ctx) # type: ignore - bb = ctx.append_basic_block(bb) + fn: IRFunction, ir: IRnode, does_return_data: bool, symbols: SymbolTable +) -> IRFunction: + fn = fn.ctx.create_function(ir.args[0].args[0].value) + bb = fn.get_basic_block() # return buffer if does_return_data: @@ -197,27 +188,29 @@ def _handle_internal_func( symbols["return_pc"] = bb.append_instruction("param") bb.instructions[-1].annotation = "return_pc" - _convert_ir_bb(ctx, ir.args[0].args[2], symbols) + _convert_ir_bb(fn, ir.args[0].args[2], symbols) + + return fn def _convert_ir_simple_node( - ctx: IRFunction, ir: IRnode, symbols: SymbolTable + fn: IRFunction, ir: IRnode, symbols: SymbolTable ) -> Optional[IRVariable]: # execute in order - args = _convert_ir_bb_list(ctx, ir.args, symbols) + args = _convert_ir_bb_list(fn, ir.args, symbols) # reverse output variables for stack args.reverse() - return ctx.get_basic_block().append_instruction(ir.value, *args) # type: ignore + return fn.get_basic_block().append_instruction(ir.value, *args) # type: ignore _break_target: Optional[IRBasicBlock] = None _continue_target: Optional[IRBasicBlock] = None -def _convert_ir_bb_list(ctx, ir, symbols): +def _convert_ir_bb_list(fn, ir, symbols): ret = [] for ir_node in ir: - venom = _convert_ir_bb(ctx, ir_node, symbols) + venom = _convert_ir_bb(fn, ir_node, symbols) ret.append(venom) return ret @@ -229,31 +222,32 @@ def _convert_ir_bb_list(ctx, ir, symbols): def pop_source_on_return(func): @functools.wraps(func) def pop_source(*args, **kwargs): - ctx = args[0] + fn = args[0] ret = func(*args, **kwargs) - ctx.pop_source() + fn.pop_source() return ret return pop_source @pop_source_on_return -def _convert_ir_bb(ctx, ir, symbols): +def _convert_ir_bb(fn, ir, symbols): assert isinstance(ir, IRnode), ir global _break_target, _continue_target, current_func, var_list, _global_symbols - ctx.push_source(ir) + ctx = fn.ctx + fn.push_source(ir) if ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: org_value = ir.value ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] - new_var = _convert_ir_simple_node(ctx, ir, symbols) + new_var = _convert_ir_simple_node(fn, ir, symbols) ir.value = org_value - return ctx.get_basic_block().append_instruction("iszero", new_var) + return fn.get_basic_block().append_instruction("iszero", new_var) elif ir.value in PASS_THROUGH_INSTRUCTIONS: - return _convert_ir_simple_node(ctx, ir, symbols) + return _convert_ir_simple_node(fn, ir, symbols) elif ir.value == "return": - ctx.get_basic_block().append_instruction( + fn.get_basic_block().append_instruction( "return", IRVariable("ret_size"), IRVariable("ret_ofst") ) elif ir.value == "deploy": @@ -264,7 +258,7 @@ def _convert_ir_bb(ctx, ir, symbols): if len(ir.args) == 0: return None if ir.is_self_call: - return _handle_self_call(ctx, ir, symbols) + return _handle_self_call(fn, ir, symbols) elif ir.args[0].value == "label": current_func = ir.args[0].args[0].value is_external = current_func.startswith("external") @@ -275,68 +269,68 @@ def _convert_ir_bb(ctx, ir, symbols): does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args _global_symbols = {} symbols = {} - _handle_internal_func(ctx, ir, does_return_data, symbols) + new_fn = _handle_internal_func(fn, ir, does_return_data, symbols) for ir_node in ir.args[1:]: - ret = _convert_ir_bb(ctx, ir_node, symbols) + ret = _convert_ir_bb(new_fn, ir_node, symbols) return ret elif is_external: _global_symbols = {} - ret = _convert_ir_bb(ctx, ir.args[0], symbols) - _append_return_args(ctx) + ret = _convert_ir_bb(fn, ir.args[0], symbols) + _append_return_args(fn) else: - bb = ctx.get_basic_block() + bb = fn.get_basic_block() if bb.is_terminated: - bb = IRBasicBlock(ctx.get_next_label("seq"), ctx) - ctx.append_basic_block(bb) - ret = _convert_ir_bb(ctx, ir.args[0], symbols) + bb = IRBasicBlock(ctx.get_next_label("seq"), fn) + fn.append_basic_block(bb) + ret = _convert_ir_bb(fn, ir.args[0], symbols) for ir_node in ir.args[1:]: - ret = _convert_ir_bb(ctx, ir_node, symbols) + ret = _convert_ir_bb(fn, ir_node, symbols) return ret elif ir.value == "if": cond = ir.args[0] # convert the condition - cont_ret = _convert_ir_bb(ctx, cond, symbols) - cond_block = ctx.get_basic_block() + cont_ret = _convert_ir_bb(fn, cond, symbols) + cond_block = fn.get_basic_block() saved_global_symbols = _global_symbols.copy() - then_block = IRBasicBlock(ctx.get_next_label("then"), ctx) - else_block = IRBasicBlock(ctx.get_next_label("else"), ctx) + then_block = IRBasicBlock(ctx.get_next_label("then"), fn) + else_block = IRBasicBlock(ctx.get_next_label("else"), fn) # convert "then" cond_symbols = symbols.copy() - ctx.append_basic_block(then_block) - then_ret_val = _convert_ir_bb(ctx, ir.args[1], cond_symbols) + fn.append_basic_block(then_block) + then_ret_val = _convert_ir_bb(fn, ir.args[1], cond_symbols) if isinstance(then_ret_val, IRLiteral): - then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) + then_ret_val = fn.get_basic_block().append_instruction("store", then_ret_val) - then_block_finish = ctx.get_basic_block() + then_block_finish = fn.get_basic_block() # convert "else" cond_symbols = symbols.copy() _global_symbols = saved_global_symbols.copy() - ctx.append_basic_block(else_block) + fn.append_basic_block(else_block) else_ret_val = None if len(ir.args) == 3: - else_ret_val = _convert_ir_bb(ctx, ir.args[2], cond_symbols) + else_ret_val = _convert_ir_bb(fn, ir.args[2], cond_symbols) if isinstance(else_ret_val, IRLiteral): assert isinstance(else_ret_val.value, int) # help mypy - else_ret_val = ctx.get_basic_block().append_instruction("store", else_ret_val) + else_ret_val = fn.get_basic_block().append_instruction("store", else_ret_val) - else_block_finish = ctx.get_basic_block() + else_block_finish = fn.get_basic_block() # finish the condition block cond_block.append_instruction("jnz", cont_ret, then_block.label, else_block.label) # exit bb - exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), ctx) - exit_bb = ctx.append_basic_block(exit_bb) + exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), fn) + exit_bb = fn.append_basic_block(exit_bb) - if_ret = ctx.get_next_variable() + if_ret = fn.get_next_variable() if then_ret_val is not None and else_ret_val is not None: then_block_finish.append_instruction("store", then_ret_val, ret=if_ret) else_block_finish.append_instruction("store", else_ret_val, ret=if_ret) @@ -352,28 +346,28 @@ def _convert_ir_bb(ctx, ir, symbols): return if_ret elif ir.value == "with": - ret = _convert_ir_bb(ctx, ir.args[1], symbols) # initialization + ret = _convert_ir_bb(fn, ir.args[1], symbols) # initialization - ret = ctx.get_basic_block().append_instruction("store", ret) + ret = fn.get_basic_block().append_instruction("store", ret) sym = ir.args[0] with_symbols = symbols.copy() with_symbols[sym.value] = ret - return _convert_ir_bb(ctx, ir.args[2], with_symbols) # body + return _convert_ir_bb(fn, ir.args[2], with_symbols) # body elif ir.value == "goto": - _append_jmp(ctx, IRLabel(ir.args[0].value)) + _append_jmp(fn, IRLabel(ir.args[0].value)) elif ir.value == "djump": - args = [_convert_ir_bb(ctx, ir.args[0], symbols)] + args = [_convert_ir_bb(fn, ir.args[0], symbols)] for target in ir.args[1:]: args.append(IRLabel(target.value)) - ctx.get_basic_block().append_instruction("djmp", *args) - _new_block(ctx) + fn.get_basic_block().append_instruction("djmp", *args) + _new_block(fn) elif ir.value == "set": sym = ir.args[0] - arg_1 = _convert_ir_bb(ctx, ir.args[1], symbols) - ctx.get_basic_block().append_instruction("store", arg_1, ret=symbols[sym.value]) + arg_1 = _convert_ir_bb(fn, ir.args[1], symbols) + fn.get_basic_block().append_instruction("store", arg_1, ret=symbols[sym.value]) elif ir.value == "symbol": return IRLabel(ir.args[0].value, True) elif ir.value == "data": @@ -386,29 +380,29 @@ def _convert_ir_bb(ctx, ir, symbols): elif isinstance(c.value, bytes): ctx.append_data("db", [c.value]) # type: ignore elif isinstance(c, IRnode): - data = _convert_ir_bb(ctx, c, symbols) + data = _convert_ir_bb(fn, c, symbols) ctx.append_data("db", [data]) # type: ignore elif ir.value == "label": label = IRLabel(ir.args[0].value, True) - bb = ctx.get_basic_block() + bb = fn.get_basic_block() if not bb.is_terminated: bb.append_instruction("jmp", label) - bb = IRBasicBlock(label, ctx) - ctx.append_basic_block(bb) + bb = IRBasicBlock(label, fn) + fn.append_basic_block(bb) code = ir.args[2] if code.value == "pass": bb.append_instruction("exit") else: - _convert_ir_bb(ctx, code, symbols) + _convert_ir_bb(fn, code, symbols) elif ir.value == "exit_to": - args = _convert_ir_bb_list(ctx, ir.args[1:], symbols) + args = _convert_ir_bb_list(fn, ir.args[1:], symbols) var_list = args - _append_return_args(ctx, *var_list) - bb = ctx.get_basic_block() + _append_return_args(fn, *var_list) + bb = fn.get_basic_block() if bb.is_terminated: - bb = IRBasicBlock(ctx.get_next_label("exit_to"), ctx) - ctx.append_basic_block(bb) - bb = ctx.get_basic_block() + bb = IRBasicBlock(ctx.get_next_label("exit_to"), fn) + fn.append_basic_block(bb) + bb = fn.get_basic_block() label = IRLabel(ir.args[0].value) if label.value == "return_pc": @@ -418,17 +412,17 @@ def _convert_ir_bb(ctx, ir, symbols): bb.append_instruction("jmp", label) elif ir.value == "dload": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols) - bb = ctx.get_basic_block() + arg_0 = _convert_ir_bb(fn, ir.args[0], symbols) + bb = fn.get_basic_block() src = bb.append_instruction("add", arg_0, IRLabel("code_end")) bb.append_instruction("dloadbytes", 32, src, MemoryPositions.FREE_VAR_SPACE) return bb.append_instruction("mload", MemoryPositions.FREE_VAR_SPACE) elif ir.value == "dloadbytes": - dst, src_offset, len_ = _convert_ir_bb_list(ctx, ir.args, symbols) + dst, src_offset, len_ = _convert_ir_bb_list(fn, ir.args, symbols) - bb = ctx.get_basic_block() + bb = fn.get_basic_block() src = bb.append_instruction("add", src_offset, IRLabel("code_end")) bb.append_instruction("dloadbytes", len_, src, dst) return None @@ -436,14 +430,14 @@ def _convert_ir_bb(ctx, ir, symbols): elif ir.value == "mstore": # some upstream code depends on reversed order of evaluation -- # to fix upstream. - val, ptr = _convert_ir_bb_list(ctx, reversed(ir.args), symbols) + val, ptr = _convert_ir_bb_list(fn, reversed(ir.args), symbols) - return ctx.get_basic_block().append_instruction("mstore", val, ptr) + return fn.get_basic_block().append_instruction("mstore", val, ptr) elif ir.value == "ceil32": x = ir.args[0] expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) - return _convert_ir_bb(ctx, expanded, symbols) + return _convert_ir_bb(fn, expanded, symbols) elif ir.value == "select": cond, a, b = ir.args expanded = IRnode.from_list( @@ -459,7 +453,7 @@ def _convert_ir_bb(ctx, ir, symbols): ], ] ) - return _convert_ir_bb(ctx, expanded, symbols) + return _convert_ir_bb(fn, expanded, symbols) elif ir.value == "repeat": def emit_body_blocks(): @@ -467,12 +461,12 @@ def emit_body_blocks(): old_targets = _break_target, _continue_target _break_target, _continue_target = exit_block, incr_block saved_global_symbols = _global_symbols.copy() - _convert_ir_bb(ctx, body, symbols.copy()) + _convert_ir_bb(fn, body, symbols.copy()) _break_target, _continue_target = old_targets _global_symbols = saved_global_symbols sym = ir.args[0] - start, end, _ = _convert_ir_bb_list(ctx, ir.args[1:4], symbols) + start, end, _ = _convert_ir_bb_list(fn, ir.args[1:4], symbols) assert ir.args[3].is_literal, "repeat bound expected to be literal" @@ -486,15 +480,15 @@ def emit_body_blocks(): body = ir.args[4] - entry_block = IRBasicBlock(ctx.get_next_label("repeat"), ctx) - cond_block = IRBasicBlock(ctx.get_next_label("condition"), ctx) - body_block = IRBasicBlock(ctx.get_next_label("body"), ctx) - incr_block = IRBasicBlock(ctx.get_next_label("incr"), ctx) - exit_block = IRBasicBlock(ctx.get_next_label("exit"), ctx) + entry_block = IRBasicBlock(ctx.get_next_label("repeat"), fn) + cond_block = IRBasicBlock(ctx.get_next_label("condition"), fn) + body_block = IRBasicBlock(ctx.get_next_label("body"), fn) + incr_block = IRBasicBlock(ctx.get_next_label("incr"), fn) + exit_block = IRBasicBlock(ctx.get_next_label("exit"), fn) - bb = ctx.get_basic_block() + bb = fn.get_basic_block() bb.append_instruction("jmp", entry_block.label) - ctx.append_basic_block(entry_block) + fn.append_basic_block(entry_block) counter_var = entry_block.append_instruction("store", start) symbols[sym.value] = counter_var @@ -505,52 +499,52 @@ def emit_body_blocks(): xor_ret = cond_block.append_instruction("xor", counter_var, end) cont_ret = cond_block.append_instruction("iszero", xor_ret) - ctx.append_basic_block(cond_block) + fn.append_basic_block(cond_block) - ctx.append_basic_block(body_block) + fn.append_basic_block(body_block) if bound: xor_ret = body_block.append_instruction("xor", counter_var, bound) body_block.append_instruction("assert", xor_ret) emit_body_blocks() - body_end = ctx.get_basic_block() + body_end = fn.get_basic_block() if body_end.is_terminated is False: body_end.append_instruction("jmp", incr_block.label) - ctx.append_basic_block(incr_block) + fn.append_basic_block(incr_block) incr_block.insert_instruction( IRInstruction("add", [counter_var, IRLiteral(1)], counter_var) ) incr_block.append_instruction("jmp", cond_block.label) - ctx.append_basic_block(exit_block) + fn.append_basic_block(exit_block) cond_block.append_instruction("jnz", cont_ret, exit_block.label, body_block.label) elif ir.value == "break": assert _break_target is not None, "Break with no break target" - ctx.get_basic_block().append_instruction("jmp", _break_target.label) - ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + fn.get_basic_block().append_instruction("jmp", _break_target.label) + fn.append_basic_block(IRBasicBlock(ctx.get_next_label(), fn)) elif ir.value == "continue": assert _continue_target is not None, "Continue with no contrinue target" - ctx.get_basic_block().append_instruction("jmp", _continue_target.label) - ctx.append_basic_block(IRBasicBlock(ctx.get_next_label(), ctx)) + fn.get_basic_block().append_instruction("jmp", _continue_target.label) + fn.append_basic_block(IRBasicBlock(ctx.get_next_label(), fn)) elif ir.value in NOOP_INSTRUCTIONS: pass elif isinstance(ir.value, str) and ir.value.startswith("log"): - args = reversed(_convert_ir_bb_list(ctx, ir.args, symbols)) + args = reversed(_convert_ir_bb_list(fn, ir.args, symbols)) topic_count = int(ir.value[3:]) assert topic_count >= 0 and topic_count <= 4, "invalid topic count" - ctx.get_basic_block().append_instruction("log", topic_count, *args) + fn.get_basic_block().append_instruction("log", topic_count, *args) elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): - _convert_ir_opcode(ctx, ir, symbols) + _convert_ir_opcode(fn, ir, symbols) elif isinstance(ir.value, str): if ir.value.startswith("$alloca") and ir.value not in _global_symbols: alloca = ir.passthrough_metadata["alloca"] - ptr = ctx.get_basic_block().append_instruction("alloca", alloca.offset, alloca.size) + ptr = fn.get_basic_block().append_instruction("alloca", alloca.offset, alloca.size) _global_symbols[ir.value] = ptr elif ir.value.startswith("$palloca") and ir.value not in _global_symbols: alloca = ir.passthrough_metadata["alloca"] - ptr = ctx.get_basic_block().append_instruction("store", alloca.offset) + ptr = fn.get_basic_block().append_instruction("store", alloca.offset) _global_symbols[ir.value] = ptr return _global_symbols.get(ir.value) or symbols.get(ir.value) @@ -562,10 +556,10 @@ def emit_body_blocks(): return None -def _convert_ir_opcode(ctx: IRFunction, ir: IRnode, symbols: SymbolTable) -> None: +def _convert_ir_opcode(fn: IRFunction, ir: IRnode, symbols: SymbolTable) -> None: opcode = ir.value.upper() # type: ignore inst_args = [] for arg in ir.args: if isinstance(arg, IRnode): - inst_args.append(_convert_ir_bb(ctx, arg, symbols)) - ctx.get_basic_block().append_instruction(opcode, *inst_args) + inst_args.append(_convert_ir_bb(fn, arg, symbols)) + fn.get_basic_block().append_instruction(opcode, *inst_args) diff --git a/vyper/venom/passes/base_pass.py b/vyper/venom/passes/base_pass.py index 1851b35233..4d1bfe9647 100644 --- a/vyper/venom/passes/base_pass.py +++ b/vyper/venom/passes/base_pass.py @@ -1,21 +1,18 @@ +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.function import IRFunction + + class IRPass: """ - Decorator for IR passes. This decorator will run the pass repeatedly - until no more changes are made. + Base class for all Venom IR passes. """ - def run_pass(self, *args, **kwargs): - count = 0 + function: IRFunction + analyses_cache: IRAnalysesCache - for _ in range(1000): - changes_count = self._run_pass(*args, **kwargs) or 0 - count += changes_count - if changes_count == 0: - break - else: - raise Exception("Too many iterations in IR pass!", self.__class__) + def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): + self.function = function + self.analyses_cache = analyses_cache - return count - - def _run_pass(self, *args, **kwargs): + def run_pass(self, *args, **kwargs): raise NotImplementedError(f"Not implemented! {self.__class__}.run_pass()") diff --git a/vyper/venom/passes/constant_propagation.py b/vyper/venom/passes/constant_propagation.py deleted file mode 100644 index 94b556124e..0000000000 --- a/vyper/venom/passes/constant_propagation.py +++ /dev/null @@ -1,13 +0,0 @@ -from vyper.utils import ir_pass -from vyper.venom.basicblock import IRBasicBlock -from vyper.venom.function import IRFunction - - -def _process_basic_block(ctx: IRFunction, bb: IRBasicBlock): - pass - - -@ir_pass -def ir_pass_constant_propagation(ctx: IRFunction): - for bb in ctx.basic_blocks: - _process_basic_block(ctx, bb) diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index 994ab9d70d..e4e27ed813 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -1,11 +1,12 @@ from vyper.utils import OrderedSet -from vyper.venom.analysis import DFG +from vyper.venom.analysis.dfg import DFGAnalysis from vyper.venom.basicblock import BB_TERMINATORS, IRBasicBlock, IRInstruction, IRVariable from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass class DFTPass(IRPass): + function: IRFunction inst_order: dict[IRInstruction, int] inst_order_num: int @@ -50,7 +51,7 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset: self.inst_order[inst] = self.inst_order_num + offset def _process_basic_block(self, bb: IRBasicBlock) -> None: - self.ctx.append_basic_block(bb) + self.function.append_basic_block(bb) for inst in bb.instructions: inst.fence_id = self.fence_id @@ -67,15 +68,14 @@ def _process_basic_block(self, bb: IRBasicBlock) -> None: bb.instructions.sort(key=lambda x: self.inst_order[x]) - def _run_pass(self, ctx: IRFunction) -> None: - self.ctx = ctx - self.dfg = DFG.build_dfg(ctx) + def run_pass(self) -> None: + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) self.fence_id = 0 self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() - basic_blocks = ctx.basic_blocks - ctx.basic_blocks = [] + basic_blocks = self.function.basic_blocks + self.function.basic_blocks = [] for bb in basic_blocks: self._process_basic_block(bb) diff --git a/vyper/venom/passes/make_ssa.py b/vyper/venom/passes/make_ssa.py index 91611482a2..fd7861930a 100644 --- a/vyper/venom/passes/make_ssa.py +++ b/vyper/venom/passes/make_ssa.py @@ -1,8 +1,8 @@ from vyper.utils import OrderedSet -from vyper.venom.analysis import calculate_cfg, calculate_liveness +from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis.dominators import DominatorTreeAnalysis +from vyper.venom.analysis.liveness import LivenessAnalysis from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IROperand, IRVariable -from vyper.venom.dominators import DominatorTree -from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -11,24 +11,22 @@ class MakeSSA(IRPass): This pass converts the function into Static Single Assignment (SSA) form. """ - dom: DominatorTree + dom: DominatorTreeAnalysis defs: dict[IRVariable, OrderedSet[IRBasicBlock]] - def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> int: - self.ctx = ctx + def run_pass(self): + fn = self.function - calculate_cfg(ctx) - self.dom = DominatorTree.build_dominator_tree(ctx, entry) + self.analyses_cache.request_analysis(CFGAnalysis) + self.dom = self.analyses_cache.request_analysis(DominatorTreeAnalysis) + self.analyses_cache.request_analysis(LivenessAnalysis) - calculate_liveness(ctx) self._add_phi_nodes() self.var_name_counters = {var.name: 0 for var in self.defs.keys()} self.var_name_stacks = {var.name: [0] for var in self.defs.keys()} - self._rename_vars(entry) - self._remove_degenerate_phis(entry) - - return 0 + self._rename_vars(fn.entry) + self._remove_degenerate_phis(fn.entry) def _add_phi_nodes(self): """ diff --git a/vyper/venom/passes/mem2var.py b/vyper/venom/passes/mem2var.py index 9d74dfec0b..0ad9c411f1 100644 --- a/vyper/venom/passes/mem2var.py +++ b/vyper/venom/passes/mem2var.py @@ -1,5 +1,7 @@ from vyper.utils import OrderedSet -from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness +from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis.dfg import DFGAnalysis +from vyper.venom.analysis.liveness import LivenessAnalysis from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -11,20 +13,13 @@ class Mem2Var(IRPass): It does yet do any memory aliasing analysis, so it is conservative. """ - ctx: IRFunction + function: IRFunction defs: dict[IRVariable, OrderedSet[IRBasicBlock]] - dfg: DFG - def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock, dfg: DFG) -> int: - self.ctx = ctx - self.dfg = dfg - - calculate_cfg(ctx) - - dfg = DFG.build_dfg(ctx) - self.dfg = dfg - - calculate_liveness(ctx) + def run_pass(self): + self.analyses_cache.request_analysis(CFGAnalysis) + dfg = self.analyses_cache.request_analysis(DFGAnalysis) + self.analyses_cache.request_analysis(LivenessAnalysis) self.var_name_count = 0 for var, inst in dfg.outputs.items(): @@ -32,9 +27,10 @@ def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock, dfg: DFG) -> int: continue self._process_alloca_var(dfg, var) - return 0 + self.analyses_cache.invalidate_analysis(DFGAnalysis) + self.analyses_cache.invalidate_analysis(LivenessAnalysis) - def _process_alloca_var(self, dfg: DFG, var: IRVariable): + def _process_alloca_var(self, dfg: DFGAnalysis, var: IRVariable): """ Process alloca allocated variable. If it is only used by mstore/mload/return instructions, it is promoted to a stack variable. Otherwise, it is left as is. @@ -57,7 +53,7 @@ def _process_alloca_var(self, dfg: DFG, var: IRVariable): inst.operands = [IRVariable(var_name)] elif inst.opcode == "return": bb = inst.parent - new_var = self.ctx.get_next_variable() + new_var = self.function.get_next_variable() idx = bb.instructions.index(inst) bb.insert_instruction( IRInstruction("mstore", [IRVariable(var_name), inst.operands[1]], new_var), diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 9ca8127b2d..83c565b1be 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -1,6 +1,6 @@ -from vyper.venom.analysis import calculate_cfg +from vyper.exceptions import CompilerPanic +from vyper.venom.analysis.cfg import CFGAnalysis from vyper.venom.basicblock import IRBasicBlock, IRLabel -from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -27,14 +27,15 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB # Create an intermediary basic block and append it source = in_bb.label.value target = bb.label.value + fn = self.function split_label = IRLabel(f"{source}_split_{target}") in_terminal = in_bb.instructions[-1] in_terminal.replace_label_operands({bb.label: split_label}) - split_bb = IRBasicBlock(split_label, self.ctx) + split_bb = IRBasicBlock(split_label, fn) split_bb.append_instruction("jmp", bb.label) - self.ctx.append_basic_block(split_bb) + fn.append_basic_block(split_bb) for inst in bb.instructions: if inst.opcode != "phi": @@ -44,24 +45,34 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, in_bb: IRBasicBlock) -> IRB inst.operands[i] = split_bb.label # Update the labels in the data segment - for inst in self.ctx.data_segment: + for inst in fn.ctx.data_segment: if inst.opcode == "db" and inst.operands[0] == bb.label: inst.operands[0] = split_bb.label return split_bb - def _run_pass(self, ctx: IRFunction) -> int: - self.ctx = ctx + def _run_pass(self) -> int: + fn = self.function self.changes = 0 + self.analyses_cache.request_analysis(CFGAnalysis) + # Split blocks that need splitting - for bb in ctx.basic_blocks: + for bb in fn.basic_blocks: if len(bb.cfg_in) > 1: self._split_basic_block(bb) # If we made changes, recalculate the cfg if self.changes > 0: - calculate_cfg(ctx) - ctx.remove_unreachable_blocks() + self.analyses_cache.force_analysis(CFGAnalysis) + fn.remove_unreachable_blocks() return self.changes + + def run_pass(self): + fn = self.function + for _ in range(len(fn.basic_blocks) * 2): + if self._run_pass() == 0: + break + else: + raise CompilerPanic("Normalization pass did not converge") diff --git a/vyper/venom/passes/remove_unused_variables.py b/vyper/venom/passes/remove_unused_variables.py new file mode 100644 index 0000000000..b7fb3abbf0 --- /dev/null +++ b/vyper/venom/passes/remove_unused_variables.py @@ -0,0 +1,22 @@ +from vyper.venom.analysis.dfg import DFGAnalysis +from vyper.venom.analysis.liveness import LivenessAnalysis +from vyper.venom.passes.base_pass import IRPass + + +class RemoveUnusedVariablesPass(IRPass): + def run_pass(self): + removeList = set() + + self.analyses_cache.request_analysis(LivenessAnalysis) + + for bb in self.function.basic_blocks: + for i, inst in enumerate(bb.instructions[:-1]): + if inst.volatile: + continue + next_liveness = bb.instructions[i + 1].liveness + if (inst.output and inst.output not in next_liveness) or inst.opcode == "nop": + removeList.add(inst) + + bb.instructions = [inst for inst in bb.instructions if inst not in removeList] + + self.analyses_cache.invalidate_analysis(DFGAnalysis) diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 7dfca8edd4..7f3fc7e03e 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -5,6 +5,9 @@ from vyper.exceptions import CompilerPanic, StaticAssertionException from vyper.utils import OrderedSet +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.analysis.dominators import DominatorTreeAnalysis from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -13,7 +16,6 @@ IROperand, IRVariable, ) -from vyper.venom.dominators import DominatorTree from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass from vyper.venom.passes.sccp.eval import ARITHMETIC_OPS @@ -49,28 +51,30 @@ class SCCP(IRPass): with their constant values. """ - ctx: IRFunction - dom: DominatorTree + fn: IRFunction + dom: DominatorTreeAnalysis uses: dict[IRVariable, OrderedSet[IRInstruction]] lattice: Lattice work_list: list[WorkListItem] cfg_dirty: bool cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] - def __init__(self, dom: DominatorTree): - self.dom = dom + def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): + super().__init__(analyses_cache, function) self.lattice = {} self.work_list: list[WorkListItem] = [] self.cfg_dirty = False - def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> int: - self.ctx = ctx - self._compute_uses(self.dom) - self._calculate_sccp(entry) + def run_pass(self): + self.fn = self.function + self.dom = self.analyses_cache.request_analysis(DominatorTreeAnalysis) + self._compute_uses() + self._calculate_sccp(self.fn.entry) self._propagate_constants() # self._propagate_variables() - return 0 + + self.analyses_cache.invalidate_analysis(CFGAnalysis) def _calculate_sccp(self, entry: IRBasicBlock): """ @@ -83,9 +87,9 @@ def _calculate_sccp(self, entry: IRBasicBlock): and the work list. The `_propagate_constants()` method is responsible for updating the IR with the constant values. """ - self.cfg_in_exec = {bb: OrderedSet() for bb in self.ctx.basic_blocks} + self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.basic_blocks} - dummy = IRBasicBlock(IRLabel("__dummy_start"), self.ctx) + dummy = IRBasicBlock(IRLabel("__dummy_start"), self.fn) self.work_list.append(FlowWorkItem(dummy, entry)) # Initialize the lattice with TOP values for all variables @@ -143,7 +147,7 @@ def _visit_phi(self, inst: IRInstruction): assert inst.opcode == "phi", "Can't visit non phi instruction" in_vars: list[LatticeItem] = [] for bb_label, var in inst.phi_operands: - bb = self.ctx.get_basic_block(bb_label.name) + bb = self.fn.get_basic_block(bb_label.name) if bb not in self.cfg_in_exec[inst.parent]: continue in_vars.append(self.lattice[var]) @@ -162,7 +166,7 @@ def _visit_expr(self, inst: IRInstruction): self.lattice[inst.output] = self.lattice[inst.operands[0]] # type: ignore self._add_ssa_work_items(inst) elif opcode == "jmp": - target = self.ctx.get_basic_block(inst.operands[0].value) + target = self.fn.get_basic_block(inst.operands[0].value) self.work_list.append(FlowWorkItem(inst.parent, target)) elif opcode == "jnz": lat = self.lattice[inst.operands[0]] @@ -172,17 +176,17 @@ def _visit_expr(self, inst: IRInstruction): self.work_list.append(FlowWorkItem(inst.parent, out_bb)) else: if _meet(lat, IRLiteral(0)) == LatticeEnum.BOTTOM: - target = self.ctx.get_basic_block(inst.operands[1].name) + target = self.fn.get_basic_block(inst.operands[1].name) self.work_list.append(FlowWorkItem(inst.parent, target)) if _meet(lat, IRLiteral(1)) == LatticeEnum.BOTTOM: - target = self.ctx.get_basic_block(inst.operands[2].name) + target = self.fn.get_basic_block(inst.operands[2].name) self.work_list.append(FlowWorkItem(inst.parent, target)) elif opcode == "djmp": lat = self.lattice[inst.operands[0]] assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}" if lat == LatticeEnum.BOTTOM: for op in inst.operands[1:]: - target = self.ctx.get_basic_block(op.name) + target = self.fn.get_basic_block(op.name) self.work_list.append(FlowWorkItem(inst.parent, target)) elif isinstance(lat, IRLiteral): raise CompilerPanic("Unimplemented djmp with literal") @@ -239,14 +243,14 @@ def _add_ssa_work_items(self, inst: IRInstruction): for target_inst in self._get_uses(inst.output): # type: ignore self.work_list.append(SSAWorkListItem(target_inst)) - def _compute_uses(self, dom: DominatorTree): + def _compute_uses(self): """ This method computes the uses for each variable in the IR. It iterates over the dominator tree and collects all the instructions that use each variable. """ self.uses = {} - for bb in dom.dfs_walk: + for bb in self.dom.dfs_walk: for var, insts in bb.get_uses().items(): self._get_uses(var).update(insts) diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index bebf2acd32..bb5233eba0 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -1,8 +1,7 @@ from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet -from vyper.venom.basicblock import IRBasicBlock -from vyper.venom.bb_optimizer import ir_pass_remove_unreachable_blocks -from vyper.venom.function import IRFunction +from vyper.venom.analysis.cfg import CFGAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.passes.base_pass import IRPass @@ -31,7 +30,7 @@ def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): break inst.operands[inst.operands.index(b.label)] = a.label - self.ctx.basic_blocks.remove(b) + self.function.basic_blocks.remove(b) def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb = b.cfg_out.first() @@ -45,7 +44,7 @@ def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) - self.ctx.basic_blocks.remove(b) + self.function.basic_blocks.remove(b) def _collapse_chained_blocks_r(self, bb: IRBasicBlock): """ @@ -83,12 +82,60 @@ def _collapse_chained_blocks(self, entry: IRBasicBlock): self.visited = OrderedSet() self._collapse_chained_blocks_r(entry) - def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> None: - self.ctx = ctx + def _optimize_empty_basicblocks(self) -> int: + """ + Remove empty basic blocks. + """ + fn = self.function + count = 0 + i = 0 + while i < len(fn.basic_blocks): + bb = fn.basic_blocks[i] + i += 1 + if len(bb.instructions) > 0: + continue + + replaced_label = bb.label + replacement_label = fn.basic_blocks[i].label if i < len(fn.basic_blocks) else None + if replacement_label is None: + continue + + # Try to preserve symbol labels + if replaced_label.is_symbol: + replaced_label, replacement_label = replacement_label, replaced_label + fn.basic_blocks[i].label = replacement_label + + for bb2 in fn.basic_blocks: + for inst in bb2.instructions: + for op in inst.operands: + if isinstance(op, IRLabel) and op.value == replaced_label.value: + op.value = replacement_label.value + + fn.basic_blocks.remove(bb) + i -= 1 + count += 1 + + return count - for _ in range(len(ctx.basic_blocks)): # essentially `while True` + def run_pass(self): + fn = self.function + entry = fn.entry + + for _ in range(len(fn.basic_blocks)): + changes = self._optimize_empty_basicblocks() + changes += fn.remove_unreachable_blocks() + if changes == 0: + break + else: + raise CompilerPanic("Too many iterations removing empty basic blocks") + + self.analyses_cache.force_analysis(CFGAnalysis) + + for _ in range(len(fn.basic_blocks)): # essentially `while True` self._collapse_chained_blocks(entry) - if ir_pass_remove_unreachable_blocks(ctx) == 0: + if fn.remove_unreachable_blocks() == 0: break else: raise CompilerPanic("Too many iterations collapsing chained blocks") + + self.analyses_cache.invalidate_analysis(CFGAnalysis) diff --git a/vyper/venom/passes/stack_reorder.py b/vyper/venom/passes/stack_reorder.py index b32ec4abde..a92fe0e626 100644 --- a/vyper/venom/passes/stack_reorder.py +++ b/vyper/venom/passes/stack_reorder.py @@ -1,13 +1,12 @@ from vyper.utils import OrderedSet from vyper.venom.basicblock import IRBasicBlock -from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass class StackReorderPass(IRPass): visited: OrderedSet - def _reorder_stack(self, bb: IRBasicBlock): + def _reorder_stack(self): pass def _visit(self, bb: IRBasicBlock): @@ -18,7 +17,7 @@ def _visit(self, bb: IRBasicBlock): for bb_out in bb.cfg_out: self._visit(bb_out) - def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock): - self.ctx = ctx + def _run_pass(self): + entry = self.function.entry self.visited = OrderedSet() self._visit(entry) diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index e99b0a95b7..26bad8882c 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -11,12 +11,9 @@ optimize_assembly, ) from vyper.utils import MemoryPositions, OrderedSet -from vyper.venom.analysis import ( - calculate_cfg, - calculate_dup_requirements, - calculate_liveness, - input_vars_from, -) +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.analysis.dup_requirements import DupRequirementsAnalysis +from vyper.venom.analysis.liveness import LivenessAnalysis from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -25,7 +22,7 @@ IROperand, IRVariable, ) -from vyper.venom.function import IRFunction +from vyper.venom.context import IRContext from vyper.venom.passes.normalization import NormalizationPass from vyper.venom.stack_model import StackModel @@ -127,12 +124,13 @@ def apply_line_numbers(inst: IRInstruction, asm) -> list[str]: # with the assembler. My suggestion is to let this be for now, and we can # refactor it later when we are finished phasing out the old IR. class VenomCompiler: - ctxs: list[IRFunction] + ctxs: list[IRContext] label_counter = 0 visited_instructions: OrderedSet # {IRInstruction} visited_basicblocks: OrderedSet # {IRBasicBlock} + liveness_analysis: LivenessAnalysis - def __init__(self, ctxs: list[IRFunction]): + def __init__(self, ctxs: list[IRContext]): self.ctxs = ctxs self.label_counter = 0 self.visited_instructions = OrderedSet() @@ -146,22 +144,17 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: asm: list[Any] = [] top_asm = asm - # Before emitting the assembly, we need to make sure that the - # CFG is normalized. Calling calculate_cfg() will denormalize IR (reset) - # so it should not be called after calling NormalizationPass().run_pass(). - # Liveness is then computed for the normalized IR, and we can proceed to - # assembly generation. - # This is a side-effect of how dynamic jumps are temporarily being used - # to support the O(1) dispatcher. -> look into calculate_cfg() for ctx in self.ctxs: - NormalizationPass().run_pass(ctx) - calculate_cfg(ctx) - calculate_liveness(ctx) - calculate_dup_requirements(ctx) + for fn in ctx.functions.values(): + ac = IRAnalysesCache(fn) + + NormalizationPass(ac, fn).run_pass() + self.liveness_analysis = ac.request_analysis(LivenessAnalysis) + ac.request_analysis(DupRequirementsAnalysis) - assert ctx.normalized, "Non-normalized CFG!" + assert fn.normalized, "Non-normalized CFG!" - self._generate_evm_for_basicblock_r(asm, ctx.basic_blocks[0], StackModel()) + self._generate_evm_for_basicblock_r(asm, fn.entry, StackModel()) # TODO make this property on IRFunction asm.extend(["_sym__ctor_exit", "JUMPDEST"]) @@ -321,7 +314,7 @@ def clean_stack_from_cfg_in( to_pop = OrderedSet[IRVariable]() for in_bb in basicblock.cfg_in: # inputs is the input variables we need from in_bb - inputs = input_vars_from(in_bb, basicblock) + inputs = self.liveness_analysis.input_vars_from(in_bb, basicblock) # layout is the output stack layout for in_bb (which works # for all possible cfg_outs from the in_bb). @@ -405,7 +398,7 @@ def _generate_evm_for_instruction( # prepare stack for jump into another basic block assert inst.parent and isinstance(inst.parent.cfg_out, OrderedSet) b = next(iter(inst.parent.cfg_out)) - target_stack = input_vars_from(inst.parent, b) + target_stack = self.liveness_analysis.input_vars_from(inst.parent, b) # TODO optimize stack reordering at entry and exit from basic blocks # NOTE: stack in general can contain multiple copies of the same variable, # however we are safe in the case of jmp/djmp/jnz as it's not going to