Skip to content

Commit

Permalink
feat[venom]: optimize branching (#4049)
Browse files Browse the repository at this point in the history
This commit introduces a new pass called `BranchOptimizationPass` 
that optimizes inefficient branches.

More specifically, when a branch is led with a logic inversion `ISZERO` 
we remove the `ISZERO` and invert the branch targets.
  • Loading branch information
harkal authored May 28, 2024
1 parent 96a8384 commit a7a647f
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 0 deletions.
54 changes: 54 additions & 0 deletions tests/unit/compiler/venom/test_branch_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.analysis.dfg import DFGAnalysis
from vyper.venom.basicblock import IRBasicBlock, IRLabel
from vyper.venom.context import IRContext
from vyper.venom.passes.branch_optimization import BranchOptimizationPass
from vyper.venom.passes.make_ssa import MakeSSA


def test_simple_jump_case():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
br2 = IRBasicBlock(IRLabel("else"), fn)
fn.append_basic_block(br2)

p1 = bb.append_instruction("param")
op1 = bb.append_instruction("store", p1)
op2 = bb.append_instruction("store", 64)
op3 = bb.append_instruction("add", op1, op2)
jnz_input = bb.append_instruction("iszero", op3)
bb.append_instruction("jnz", jnz_input, br1.label, br2.label)

br1.append_instruction("add", op3, 10)
br1.append_instruction("stop")
br2.append_instruction("add", op3, p1)
br2.append_instruction("stop")

term_inst = bb.instructions[-1]

ac = IRAnalysesCache(fn)
MakeSSA(ac, fn).run_pass()

old_dfg = ac.request_analysis(DFGAnalysis)
assert term_inst not in old_dfg.get_uses(op3), "jnz not using the old condition"
assert term_inst in old_dfg.get_uses(jnz_input), "jnz using the new condition"

BranchOptimizationPass(ac, fn).run_pass()

# Test that the jnz targets are inverted and
# the jnz condition updated
assert term_inst.opcode == "jnz"
assert term_inst.operands[0] == op3
assert term_inst.operands[1] == br2.label
assert term_inst.operands[2] == br1.label

# Test that the dfg is updated correctly
dfg = ac.request_analysis(DFGAnalysis)
assert dfg is old_dfg, "DFG should not be invalidated by BranchOptimizationPass"
assert term_inst in dfg.get_uses(op3), "jnz not using the new condition"
assert term_inst not in dfg.get_uses(jnz_input), "jnz still using the old condition"
2 changes: 2 additions & 0 deletions vyper/venom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vyper.venom.context import IRContext
from vyper.venom.function import IRFunction
from vyper.venom.ir_node_to_venom import ir_node_to_venom
from vyper.venom.passes.branch_optimization import BranchOptimizationPass
from vyper.venom.passes.dft import DFTPass
from vyper.venom.passes.make_ssa import MakeSSA
from vyper.venom.passes.mem2var import Mem2Var
Expand Down Expand Up @@ -49,6 +50,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None:
SCCP(ac, fn).run_pass()
StoreElimination(ac, fn).run_pass()
SimplifyCFGPass(ac, fn).run_pass()
BranchOptimizationPass(ac, fn).run_pass()
RemoveUnusedVariablesPass(ac, fn).run_pass()
DFTPass(ac, fn).run_pass()

Expand Down
4 changes: 4 additions & 0 deletions vyper/venom/analysis/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def get_uses(self, op: IRVariable) -> list[IRInstruction]:
def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]:
return self._dfg_outputs.get(op)

def add_use(self, op: IRVariable, inst: IRInstruction):
uses = self._dfg_inputs.setdefault(op, [])
uses.append(inst)

def remove_use(self, op: IRVariable, inst: IRInstruction):
uses = self._dfg_inputs.get(op, [])
uses.remove(inst)
Expand Down
30 changes: 30 additions & 0 deletions vyper/venom/passes/branch_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from vyper.venom.analysis.dfg import DFGAnalysis
from vyper.venom.passes.base_pass import IRPass


class BranchOptimizationPass(IRPass):
"""
This pass optimizes branches inverting jnz instructions where appropriate
"""

def _optimize_branches(self) -> None:
fn = self.function
for bb in fn.get_basic_blocks():
term_inst = bb.instructions[-1]
if term_inst.opcode != "jnz":
continue

prev_inst = self.dfg.get_producing_instruction(term_inst.operands[0])
if prev_inst.opcode == "iszero":
new_cond = prev_inst.operands[0]
term_inst.operands = [new_cond, term_inst.operands[2], term_inst.operands[1]]

# Since the DFG update is simple we do in place to avoid invalidating the DFG
# and having to recompute it (which is expensive(er))
self.dfg.remove_use(prev_inst.output, term_inst)
self.dfg.add_use(new_cond, term_inst)

def run_pass(self):
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)

self._optimize_branches()

0 comments on commit a7a647f

Please sign in to comment.