-
-
Notifications
You must be signed in to change notification settings - Fork 804
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat[venom]: optimize branching (#4049)
This commit introduces a new pass called `BranchOptimizationPass` that optimizes inefficient branches. More specifically, when a branch is led with a logic inversion `ISZERO` we remove the `ISZERO` and invert the branch targets.
- Loading branch information
Showing
4 changed files
with
90 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from vyper.venom.analysis.analysis import IRAnalysesCache | ||
from vyper.venom.analysis.dfg import DFGAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock, IRLabel | ||
from vyper.venom.context import IRContext | ||
from vyper.venom.passes.branch_optimization import BranchOptimizationPass | ||
from vyper.venom.passes.make_ssa import MakeSSA | ||
|
||
|
||
def test_simple_jump_case(): | ||
ctx = IRContext() | ||
fn = ctx.create_function("_global") | ||
|
||
bb = fn.get_basic_block() | ||
|
||
br1 = IRBasicBlock(IRLabel("then"), fn) | ||
fn.append_basic_block(br1) | ||
br2 = IRBasicBlock(IRLabel("else"), fn) | ||
fn.append_basic_block(br2) | ||
|
||
p1 = bb.append_instruction("param") | ||
op1 = bb.append_instruction("store", p1) | ||
op2 = bb.append_instruction("store", 64) | ||
op3 = bb.append_instruction("add", op1, op2) | ||
jnz_input = bb.append_instruction("iszero", op3) | ||
bb.append_instruction("jnz", jnz_input, br1.label, br2.label) | ||
|
||
br1.append_instruction("add", op3, 10) | ||
br1.append_instruction("stop") | ||
br2.append_instruction("add", op3, p1) | ||
br2.append_instruction("stop") | ||
|
||
term_inst = bb.instructions[-1] | ||
|
||
ac = IRAnalysesCache(fn) | ||
MakeSSA(ac, fn).run_pass() | ||
|
||
old_dfg = ac.request_analysis(DFGAnalysis) | ||
assert term_inst not in old_dfg.get_uses(op3), "jnz not using the old condition" | ||
assert term_inst in old_dfg.get_uses(jnz_input), "jnz using the new condition" | ||
|
||
BranchOptimizationPass(ac, fn).run_pass() | ||
|
||
# Test that the jnz targets are inverted and | ||
# the jnz condition updated | ||
assert term_inst.opcode == "jnz" | ||
assert term_inst.operands[0] == op3 | ||
assert term_inst.operands[1] == br2.label | ||
assert term_inst.operands[2] == br1.label | ||
|
||
# Test that the dfg is updated correctly | ||
dfg = ac.request_analysis(DFGAnalysis) | ||
assert dfg is old_dfg, "DFG should not be invalidated by BranchOptimizationPass" | ||
assert term_inst in dfg.get_uses(op3), "jnz not using the new condition" | ||
assert term_inst not in dfg.get_uses(jnz_input), "jnz still using the old condition" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from vyper.venom.analysis.dfg import DFGAnalysis | ||
from vyper.venom.passes.base_pass import IRPass | ||
|
||
|
||
class BranchOptimizationPass(IRPass): | ||
""" | ||
This pass optimizes branches inverting jnz instructions where appropriate | ||
""" | ||
|
||
def _optimize_branches(self) -> None: | ||
fn = self.function | ||
for bb in fn.get_basic_blocks(): | ||
term_inst = bb.instructions[-1] | ||
if term_inst.opcode != "jnz": | ||
continue | ||
|
||
prev_inst = self.dfg.get_producing_instruction(term_inst.operands[0]) | ||
if prev_inst.opcode == "iszero": | ||
new_cond = prev_inst.operands[0] | ||
term_inst.operands = [new_cond, term_inst.operands[2], term_inst.operands[1]] | ||
|
||
# Since the DFG update is simple we do in place to avoid invalidating the DFG | ||
# and having to recompute it (which is expensive(er)) | ||
self.dfg.remove_use(prev_inst.output, term_inst) | ||
self.dfg.add_use(new_cond, term_inst) | ||
|
||
def run_pass(self): | ||
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) | ||
|
||
self._optimize_branches() |