From a7a647f5a865a335f607f27b4d280434770c4c22 Mon Sep 17 00:00:00 2001 From: Harry Kalogirou Date: Tue, 28 May 2024 16:49:32 +0300 Subject: [PATCH] feat[venom]: optimize branching (#4049) This commit introduces a new pass called `BranchOptimizationPass` that optimizes inefficient branches. More specifically, when a branch is led with a logic inversion `ISZERO` we remove the `ISZERO` and invert the branch targets. --- .../compiler/venom/test_branch_optimizer.py | 54 +++++++++++++++++++ vyper/venom/__init__.py | 2 + vyper/venom/analysis/dfg.py | 4 ++ vyper/venom/passes/branch_optimization.py | 30 +++++++++++ 4 files changed, 90 insertions(+) create mode 100644 tests/unit/compiler/venom/test_branch_optimizer.py create mode 100644 vyper/venom/passes/branch_optimization.py diff --git a/tests/unit/compiler/venom/test_branch_optimizer.py b/tests/unit/compiler/venom/test_branch_optimizer.py new file mode 100644 index 0000000000..b6e806e217 --- /dev/null +++ b/tests/unit/compiler/venom/test_branch_optimizer.py @@ -0,0 +1,54 @@ +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.analysis.dfg import DFGAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRLabel +from vyper.venom.context import IRContext +from vyper.venom.passes.branch_optimization import BranchOptimizationPass +from vyper.venom.passes.make_ssa import MakeSSA + + +def test_simple_jump_case(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), fn) + fn.append_basic_block(br2) + + p1 = bb.append_instruction("param") + op1 = bb.append_instruction("store", p1) + op2 = bb.append_instruction("store", 64) + op3 = bb.append_instruction("add", op1, op2) + jnz_input = bb.append_instruction("iszero", op3) + bb.append_instruction("jnz", jnz_input, br1.label, br2.label) + + br1.append_instruction("add", op3, 10) + br1.append_instruction("stop") + br2.append_instruction("add", op3, p1) + br2.append_instruction("stop") + + term_inst = bb.instructions[-1] + + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + + old_dfg = ac.request_analysis(DFGAnalysis) + assert term_inst not in old_dfg.get_uses(op3), "jnz not using the old condition" + assert term_inst in old_dfg.get_uses(jnz_input), "jnz using the new condition" + + BranchOptimizationPass(ac, fn).run_pass() + + # Test that the jnz targets are inverted and + # the jnz condition updated + assert term_inst.opcode == "jnz" + assert term_inst.operands[0] == op3 + assert term_inst.operands[1] == br2.label + assert term_inst.operands[2] == br1.label + + # Test that the dfg is updated correctly + dfg = ac.request_analysis(DFGAnalysis) + assert dfg is old_dfg, "DFG should not be invalidated by BranchOptimizationPass" + assert term_inst in dfg.get_uses(op3), "jnz not using the new condition" + assert term_inst not in dfg.get_uses(jnz_input), "jnz still using the old condition" diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 6bbcedaade..82901126bc 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -9,6 +9,7 @@ from vyper.venom.context import IRContext from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom +from vyper.venom.passes.branch_optimization import BranchOptimizationPass from vyper.venom.passes.dft import DFTPass from vyper.venom.passes.make_ssa import MakeSSA from vyper.venom.passes.mem2var import Mem2Var @@ -49,6 +50,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None: SCCP(ac, fn).run_pass() StoreElimination(ac, fn).run_pass() SimplifyCFGPass(ac, fn).run_pass() + BranchOptimizationPass(ac, fn).run_pass() RemoveUnusedVariablesPass(ac, fn).run_pass() DFTPass(ac, fn).run_pass() diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py index 2fb172a979..c64fb07fc2 100644 --- a/vyper/venom/analysis/dfg.py +++ b/vyper/venom/analysis/dfg.py @@ -22,6 +22,10 @@ def get_uses(self, op: IRVariable) -> list[IRInstruction]: def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: return self._dfg_outputs.get(op) + def add_use(self, op: IRVariable, inst: IRInstruction): + uses = self._dfg_inputs.setdefault(op, []) + uses.append(inst) + def remove_use(self, op: IRVariable, inst: IRInstruction): uses = self._dfg_inputs.get(op, []) uses.remove(inst) diff --git a/vyper/venom/passes/branch_optimization.py b/vyper/venom/passes/branch_optimization.py new file mode 100644 index 0000000000..354aab7900 --- /dev/null +++ b/vyper/venom/passes/branch_optimization.py @@ -0,0 +1,30 @@ +from vyper.venom.analysis.dfg import DFGAnalysis +from vyper.venom.passes.base_pass import IRPass + + +class BranchOptimizationPass(IRPass): + """ + This pass optimizes branches inverting jnz instructions where appropriate + """ + + def _optimize_branches(self) -> None: + fn = self.function + for bb in fn.get_basic_blocks(): + term_inst = bb.instructions[-1] + if term_inst.opcode != "jnz": + continue + + prev_inst = self.dfg.get_producing_instruction(term_inst.operands[0]) + if prev_inst.opcode == "iszero": + new_cond = prev_inst.operands[0] + term_inst.operands = [new_cond, term_inst.operands[2], term_inst.operands[1]] + + # Since the DFG update is simple we do in place to avoid invalidating the DFG + # and having to recompute it (which is expensive(er)) + self.dfg.remove_use(prev_inst.output, term_inst) + self.dfg.add_use(new_cond, term_inst) + + def run_pass(self): + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) + + self._optimize_branches()