Skip to content

Commit

Permalink
optimizer: fix #43254, avoid infinite CFG traversal in SROA's dominat…
Browse files Browse the repository at this point in the history
…ion analysis (#43265)

Since CFG can be cyclic, the previous implementation of `has_safe_def`
that simply walks predecessors recursively was just wrong.

This commit fixes it by making `has_safe_def` maintain a single set that
keeps the identities of basic blocks that have been examined already.
  • Loading branch information
aviatesk authored Dec 2, 2021
1 parent ec9bb1c commit 45d43f3
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,34 @@ end
# if this load at `idx` have any "safe" `setfield!` calls that define the field
function has_safe_def(
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
newidx::Int, idx::Int, inclusive::Bool = false)
def = first(find_def_for_use(ir, domtree, allblocks, du, idx, inclusive))

# this field is supposed to be defined at the `:new` site (but it's not and thus this load will throw)
newidx::Int, idx::Int)
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx)
# will throw since we already checked this `:new` site doesn't define this field
def == newidx && return false

def 0 && return true # found a "safe" definition

# we may be able to replace this load with `PhiNode` if all the predecessors have "safe" definitions
idxblock = block_for_inst(ir, idx)
for pred in ir.cfg.blocks[idxblock].preds
lastidx = last(ir.cfg.blocks[pred].stmts)
# NOTE `lastidx` isn't a load, thus we can use inclusive coondition within the `find_def_for_use`
has_safe_def(ir, domtree, allblocks, du, newidx, lastidx, true) || return false
# found a "safe" definition
def 0 && return true
# we may still be able to replace this load with `PhiNode`
# examine if all predecessors of `block` have any "safe" definition
block = block_for_inst(ir, idx)
seen = BitSet(block)
worklist = BitSet(ir.cfg.blocks[block].preds)
isempty(worklist) && return false
while !isempty(worklist)
pred = pop!(worklist)
# if this block has already been examined, bail out to avoid infinite cycles
pred in seen && return false
idx = last(ir.cfg.blocks[pred].stmts)
# NOTE `idx` isn't a load, thus we can use inclusive coondition within the `find_def_for_use`
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx, true)
# will throw since we already checked this `:new` site doesn't define this field
def == newidx && return false
push!(seen, pred)
# found a "safe" definition for this predecessor
def 0 && continue
# check for the predecessors of this predecessor
for newpred in ir.cfg.blocks[pred].preds
push!(worklist, newpred)
end
end
return true
end
Expand Down Expand Up @@ -599,8 +613,8 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

# NOTE we use `IdSet{Int}` instead of `BitSet` for `sroa_pass!` since it works on IR after inlining,
# which can be very large sometimes, and analyzed program counters are often very sparse
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

"""
Expand Down Expand Up @@ -897,8 +911,8 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
end
end
for b in phiblocks
n = ir[phinodes[b]]::PhiNode
for p in ir.cfg.blocks[b].preds
n = ir[phinodes[b]]::PhiNode
push!(n.edges, p)
push!(n.values, compute_value_for_block(ir, domtree,
allblocks, du, phinodes, fidx, p))
Expand Down Expand Up @@ -967,7 +981,7 @@ function count_uses(@nospecialize(stmt), uses::Vector{Int})
end
end

function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::SPCSet, phi::Int)
worklist = Int[]
push!(worklist, phi)
while !isempty(worklist)
Expand Down Expand Up @@ -1037,7 +1051,7 @@ function adce_pass!(ir::IRCode)
changed = true
while changed
changed = false
safe_phis = BitSet()
safe_phis = SPCSet()
for phi in all_phis
# Save any phi cycles that have non-phi uses
if compact.used_ssas[phi] - phi_uses[phi] != 0
Expand Down

0 comments on commit 45d43f3

Please sign in to comment.