diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index c6913dd077d60..cf53caf6552d0 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -6,29 +6,6 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I return singleton_type(ft) === func end -""" - du::SSADefUse - -This struct keeps track of all uses of some mutable struct allocated in the current function: -- `du.uses::Vector{Int}` are all instances of `getfield` on the struct -- `du.defs::Vector{Int}` are all instances of `setfield!` on the struct -The terminology refers to the uses/defs of the "slot bundle" that the mutable struct represents. - -In addition we keep track of all instances of a `:foreigncall` that preserves of this mutable -struct in `du.ccall_preserve_uses`. Somewhat counterintuitively, we don't actually need to -make sure that the struct itself is live (or even allocated) at a `ccall` site. -If there are no other places where the struct escapes (and thus e.g. where its address is taken), -it need not be allocated. We do however, need to make sure to preserve any elements of this struct. -""" -struct SSADefUse - uses::Vector{Int} - defs::Vector{Int} - ccall_preserve_uses::Vector{Int} -end -SSADefUse() = SSADefUse(Int[], Int[], Int[]) - -compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses) - # assume `stmt == getfield(obj, field, ...)` or `stmt == setfield!(obj, field, val, ...)` try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) = try_compute_field(ir, stmt.args[3]) @@ -55,112 +32,6 @@ function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::E return try_compute_fieldidx(typ, field) end -function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) - # TODO: This can be much faster by looking at current level and only - # searching for those blocks in a sorted order - while !(curblock in allblocks) - curblock = domtree.idoms_bb[curblock] - end - return curblock -end - -function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) - ex = ir[SSAValue(def)][:inst] - if isexpr(ex, :new) - return ex.args[1+fidx] - else - @assert isa(ex, Expr) - # The use is whatever the setfield was - return ex.args[4] - end -end - -function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) - curblock = find_curblock(domtree, allblocks, curblock) - def = 0 - for stmt in du.defs - if block_for_inst(ir.cfg, stmt) == curblock - def = max(def, stmt) - end - end - def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx) -end - -function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) - def, useblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use) - if def == 0 - if !haskey(phinodes, curblock) - # If this happens, we need to search the predecessors for defs. Which - # one doesn't matter - if it did, we'd have had a phinode - return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) - end - # The use is the phinode - return phinodes[curblock] - else - return val_for_def_expr(ir, def, fidx) - end -end - -# even when the allocation contains an uninitialized field, we try an extra effort to check -# if this load at `idx` have any "safe" `setfield!` calls that define the field -function has_safe_def( - ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, - newidx::Int, idx::Int) - def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx) - # will throw since we already checked this `:new` site doesn't define this field - def == newidx && return false - # found a "safe" definition - def ≠ 0 && return true - # we may still be able to replace this load with `PhiNode` - # examine if all predecessors of `block` have any "safe" definition - block = block_for_inst(ir, idx) - seen = BitSet(block) - worklist = BitSet(ir.cfg.blocks[block].preds) - isempty(worklist) && return false - while !isempty(worklist) - pred = pop!(worklist) - # if this block has already been examined, bail out to avoid infinite cycles - pred in seen && return false - idx = last(ir.cfg.blocks[pred].stmts) - # NOTE `idx` isn't a load, thus we can use inclusive coondition within the `find_def_for_use` - def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx, true) - # will throw since we already checked this `:new` site doesn't define this field - def == newidx && return false - push!(seen, pred) - # found a "safe" definition for this predecessor - def ≠ 0 && continue - # check for the predecessors of this predecessor - for newpred in ir.cfg.blocks[pred].preds - push!(worklist, newpred) - end - end - return true -end - -# find the first dominating def for the given use -function find_def_for_use( - ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use::Int, inclusive::Bool=false) - useblock = block_for_inst(ir.cfg, use) - curblock = find_curblock(domtree, allblocks, useblock) - local def = 0 - for idx in du.defs - if block_for_inst(ir.cfg, idx) == curblock - if curblock != useblock - # Find the last def in this block - def = max(def, idx) - else - # Find the last def before our use - if inclusive - def = max(def, idx ≤ use ? idx : 0) - else - def = max(def, idx < use ? idx : 0) - end - end - end - end - return def, useblock, curblock -end - function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint)) if isa(val, Union{OldSSAValue, SSAValue}) val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) @@ -643,19 +514,24 @@ a result of succeeding dead code elimination. """ function sroa_pass!(ir::IRCode, nargs::Int) compact = IncrementalCompact(ir) - defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + anymutability = false for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue - is_setfield = false field_ordering = :unspecified - if is_known_call(stmt, setfield!, compact) - 4 <= length(stmt.args) <= 5 || continue - is_setfield = true - if length(stmt.args) == 5 - field_ordering = argextype(stmt.args[5], compact) + if isexpr(stmt, :new) + typ = unwrap_unionall(widenconst(argextype(SSAValue(idx), compact))) + if ismutabletype(typ) + # mutable SROA is performed later, mark it now + anymutability = true end + continue + # elseif is_known_call(stmt, setfield!, compact) + # 4 <= length(stmt.args) <= 5 || continue + # if length(stmt.args) == 5 + # field_ordering = argextype(stmt.args[5], compact) + # end elseif is_known_call(stmt, getfield, compact) 3 <= length(stmt.args) <= 5 || continue if length(stmt.args) == 5 @@ -671,40 +547,21 @@ function sroa_pass!(ir::IRCode, nargs::Int) for pidx in (6+nccallargs):length(stmt.args) preserved_arg = stmt.args[pidx] isa(preserved_arg, SSAValue) || continue - let intermediaries = SPCSet() - callback = function (@nospecialize(pi), @nospecialize(ssa)) - push!(intermediaries, ssa.id) - return false - end - def = simple_walk(compact, preserved_arg, callback) - isa(def, SSAValue) || continue - defidx = def.id - def = compact[defidx] - if is_known_call(def, tuple, compact) + def = simple_walk(compact, preserved_arg) + isa(def, SSAValue) || continue + defidx = def.id + def = compact[defidx] + if is_known_call(def, tuple, compact) + record_immutable_preserve!(new_preserves, def, compact) + push!(preserved, preserved_arg.id) + elseif isexpr(def, :new) + typ = unwrap_unionall(widenconst(argextype(SSAValue(defidx), compact))) + if typ isa DataType + ismutabletype(typ) && continue # mutable SROA is performed later record_immutable_preserve!(new_preserves, def, compact) push!(preserved, preserved_arg.id) - continue - elseif isexpr(def, :new) - typ = widenconst(argextype(SSAValue(defidx), compact)) - if isa(typ, UnionAll) - typ = unwrap_unionall(typ) - end - if typ isa DataType && !ismutabletype(typ) - record_immutable_preserve!(new_preserves, def, compact) - push!(preserved, preserved_arg.id) - continue - end - else - continue - end - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() end - mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse())) - push!(defuse.ccall_preserve_uses, idx) - union!(mid, intermediaries) end - continue end if !isempty(new_preserves) compact[idx] = form_new_preserves(stmt, preserved, new_preserves) @@ -723,7 +580,7 @@ function sroa_pass!(ir::IRCode, nargs::Int) continue end - # analyze this `getfield` / `setfield!` call + # analyze this `getfield` call field = try_compute_field_stmt(compact, stmt) field === nothing && continue @@ -741,32 +598,7 @@ function sroa_pass!(ir::IRCode, nargs::Int) continue end - # analyze this mutable struct here for the later pass - if ismutabletype(struct_typ) - isa(val, SSAValue) || continue - let intermediaries = SPCSet() - callback = function (@nospecialize(pi), @nospecialize(ssa)) - push!(intermediaries, ssa.id) - return false - end - def = simple_walk(compact, val, callback) - # Mutable stuff here - isa(def, SSAValue) || continue - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() - end - mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) - if is_setfield - push!(defuse.defs, idx) - else - push!(defuse.uses, idx) - end - union!(mid, intermediaries) - end - continue - elseif is_setfield - continue # invalid `setfield!` call, but just ignore here - end + ismutabletype(struct_typ) && continue # mutable SROA is performed later # perform SROA on immutable structs here on @@ -804,184 +636,409 @@ function sroa_pass!(ir::IRCode, nargs::Int) end non_dce_finish!(compact) - if defuses !== nothing - # now go through analyzed mutable structs and see which ones we can eliminate - # NOTE copy the use count here, because `simple_dce!` may modify it and we need it - # consistent with the state of the IR here (after tracking `PhiNode` arguments, - # but before the DCE) for our predicate within `sroa_mutables!`, but we also - # try an extra effort using a callback so that reference counts are updated - used_ssas = copy(compact.used_ssas) - simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) - ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas, nargs) - return ir - else - simple_dce!(compact) - return complete(compact) + simple_dce!(compact) + ir = complete(compact) + anymutability && sroa_mutables!(ir, nargs) + return ir +end + +function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) + newex = Expr(:foreigncall) + nccallargs = length(origex.args[3]::SimpleVector) + for i in 1:(6+nccallargs-1) + push!(newex.args, origex.args[i]) end + for i in (6+nccallargs):length(origex.args) + x = origex.args[i] + # don't need to preserve intermediaries + if isa(x, SSAValue) && x.id in intermediates + continue + end + push!(newex.args, x) + end + for i in 1:length(new_preserves) + push!(newex.args, new_preserves[i]) + end + return newex end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, nargs::Int) - # initialization of domtree is delayed to avoid the expensive computation in many cases - local domtree = nothing +import .EscapeAnalysis: EscapeInfo, Indexable, getaliases + +function sroa_mutables!(ir::IRCode, nargs::Int) + # Compute domtree now, needed below, now that we have finished compacting the IR. + # This needs to be after we iterate through the IR with `IncrementalCompact` + # because removing dead blocks can invalidate the domtree. + # TODO initialization of the domtree can be delayed to avoid the expensive computation + # in cases when there are no loads to be forwarded + @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks) estate = analyze_escapes(ir, nargs) - for (idx, (intermediaries, defuse)) in defuses - intermediaries = collect(intermediaries) - # Check if there are any uses we did not account for. If so, the variable - # escapes and we cannot eliminate the allocation. This works, because we're guaranteed - # not to include any intermediaries that have dead uses. As a result, missing uses will only ever - # show up in the nuses_total count. - nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) - nuses = 0 - for idx in intermediaries - nuses += used_ssas[idx] - end - nuses_total = used_ssas[idx] + nuses - length(intermediaries) - nleaves == nuses_total || continue - # Find the type for this allocation - defexpr = ir[SSAValue(idx)][:inst] - isexpr(defexpr, :new) || continue - newidx = idx - typ = ir.stmts[newidx][:type] - if isa(typ, UnionAll) - typ = unwrap_unionall(typ) + wset = BitSet(1:length(ir.stmts)+length(ir.new_nodes.stmts)) + eliminated = BitSet() + revisit = Tuple{#=related=#Vector{SSAValue}, #=liveness=#BitSet}[] + all_preserved = true + newpreserves = nothing + while !isempty(wset) + idx = pop!(wset) + ssa = SSAValue(idx) + stmt = ir[ssa][:inst] + isexpr(stmt, :new) || continue + einfo = estate[ssa] + is_load_forwardable(einfo) || continue + aliases = getaliases(ssa, estate) + if aliases === nothing + related = SSAValue[ssa] + else + related = SSAValue[] + for alias in aliases + @assert isa(alias, SSAValue) "invalid escape analysis" + push!(related, alias) + delete!(wset, alias.id) + end end - # Could still end up here if we tried to setfield! on an immutable, which would - # error at runtime, but is not illegal to have in the IR. - ismutabletype(typ) || continue - typ = typ::DataType + finfos = (einfo.AliasInfo::Indexable).infos + nfields = length(finfos) + # Partition defuses by field - fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] - all_forwarded = true - for use in defuse.uses - stmt = ir[SSAValue(use)][:inst] # == `getfield` call - # We may have discovered above that this use is dead - # after the getfield elim of immutables. In that case, - # it would have been deleted. That's fine, just ignore - # the use in that case. - if stmt === nothing - all_forwarded = false - continue + fdefuses = Vector{FieldDefUse}(undef, nfields) + for i = 1:nfields + finfo = finfos[i] + fdu = FieldDefUse() + for pc in finfo + if pc > 0 + push!(fdu.uses, GetfieldLoad(pc)) # use (getfield call) + else + push!(fdu.defs, -pc) # def (setfield! call or :new expression) + end end - field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) - field === nothing && @goto skip - push!(fielddefuse[field].uses, use) + fdefuses[i] = fdu end - for def in defuse.defs - stmt = ir[SSAValue(def)][:inst]::Expr # == `setfield!` call - field = try_compute_fieldidx_stmt(ir, stmt, typ) - field === nothing && @goto skip - isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error - push!(fielddefuse[field].defs, def) + + Liveness = einfo.Liveness + for livepc in Liveness + livestmt = ir[SSAValue(livepc)][:inst] + if is_known_call(livestmt, Core.ifelse, ir) + # the succeeding domination analysis doesn't account for conditional branching + # by ifelse branching at this moment + @goto next_itr + elseif is_known_call(livestmt, isdefined, ir) + args = livestmt.args + length(args) ≥ 3 || continue + obj = args[2] + obj in related || continue + fld = args[3] + fldval = try_compute_field(ir, fld) + fldval === nothing && continue + typ = unwrap_unionall(widenconst(argextype(obj, ir))) + isa(typ, DataType) || continue + fldidx = try_compute_fieldidx(typ, fldval) + fldidx === nothing && continue + push!(fdefuses[fldidx].uses, IsdefinedUse(livepc)) + elseif isexpr(livestmt, :foreigncall) # preserve use (otherwise not is_load_forwarded) + for fidx in 1:nfields + push!(fdefuses[fidx].uses, PreserveUse(livepc)) + end + end end - # Check that the defexpr has defined values for all the fields - # we're accessing. In the future, we may want to relax this, - # but we should come up with semantics for well defined semantics - # for uninitialized fields first. - ndefuse = length(fielddefuse) - blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse) - for fidx in 1:ndefuse - du = fielddefuse[fidx] - isempty(du.uses) && continue - push!(du.defs, newidx) - ldu = compute_live_ins(ir.cfg, du) + + for fidx in 1:nfields + fdu = fdefuses[fidx] + isempty(fdu.uses) && @goto next_use + # check if all uses have safe definitions first, otherwise we should bail out + # since then we may fail to form new ϕ-nodes + ldu = compute_live_ins(ir.cfg, fdu) if isempty(ldu.live_in_bbs) phiblocks = Int[] else - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) end - allblocks = sort(vcat(phiblocks, ldu.def_bbs)) - blocks[fidx] = phiblocks, allblocks - if fidx + 1 > length(defexpr.args) - for use in du.uses - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip + allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) + for use in fdu.uses + isa(use, IsdefinedUse) && continue + if !has_safe_def(ir, domtree, allblocks, fdu, getuseidx(use)) + all_preserved = false + @goto next_use end end - end - is_load_forwardable(estate[SSAValue(idx)]) || println("[EA] bad EA: ", ir.argtypes[1:nargs], " at ", idx) - # Everything accounted for. Go field by field and perform idf: - # Compute domtree now, needed below, now that we have finished compacting the IR. - # This needs to be after we iterate through the IR with `IncrementalCompact` - # because removing dead blocks can invalidate the domtree. - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing : - IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses))) - for fidx in 1:ndefuse - du = fielddefuse[fidx] - ftyp = fieldtype(typ, fidx) - if !isempty(du.uses) - phiblocks, allblocks = blocks[fidx] - phinodes = IdDict{Int, SSAValue}() - for b in phiblocks - phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), - NewInstruction(PhiNode(), ftyp)) - end - # Now go through all uses and rewrite them - for stmt in du.uses - ir[SSAValue(stmt)][:inst] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt) - end - if !isbitstype(ftyp) - if preserve_uses !== nothing - for (use, list) in preserve_uses - push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use)) - end + phinodes = IdDict{Int, SSAValue}() + for b in phiblocks + phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), + NewInstruction(PhiNode(), Any)) + end + # Now go through all uses and rewrite them + for use in fdu.uses + if isa(use, GetfieldLoad) + use = getuseidx(use) + ir[SSAValue(use)][:inst] = compute_value_for_use( + ir, domtree, allblocks, fdu, phinodes, fidx, use) + push!(eliminated, use) + elseif all_preserved && isa(use, PreserveUse) + if newpreserves === nothing + newpreserves = IdDict{Int,Vector{Any}}() end - end - for b in phiblocks - n = ir[phinodes[b]][:inst]::PhiNode - for p in ir.cfg.blocks[b].preds - push!(n.edges, p) - push!(n.values, compute_value_for_block(ir, domtree, - allblocks, du, phinodes, fidx, p)) + use = getuseidx(use) + push!(get!(()->Any[], newpreserves, use), compute_value_for_use( + ir, domtree, allblocks, fdu, phinodes, fidx, use)) + elseif isa(use, IsdefinedUse) + use = getuseidx(use) + if has_safe_def(ir, domtree, allblocks, fdu, use) + ir[SSAValue(use)][:inst] = true + push!(eliminated, use) end + else + throw("unknown FieldDefUse") end end - for stmt in du.defs - stmt == newidx && continue - ir[SSAValue(stmt)][:inst] = nothing + for b in phiblocks + ϕssa = phinodes[b] + n = ir[ϕssa][:inst]::PhiNode + t = Bottom + for p in ir.cfg.blocks[b].preds + push!(n.edges, p) + v = compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, p) + push!(n.values, v) + if t !== Any + t = tmerge(t, argextype(v, ir)) + end + end + ir[ϕssa][:type] = t end + @label next_use end - preserve_uses === nothing && continue - if all_forwarded - # this means all ccall preserves have been replaced with forwarded loads - # so we can potentially eliminate the allocation, otherwise we must preserve - # the whole allocation. - push!(intermediaries, newidx) + push!(revisit, (related, Liveness)) + @label next_itr + end + + # remove dead setfield! and :new allocs + deadssas = IdSet{SSAValue}() + mark_dead_ssas!(ir, deadssas, revisit, eliminated) + for ssa in deadssas + ir[ssa][:inst] = nothing + end + if all_preserved && newpreserves !== nothing + deadssas = Int[ssa.id for ssa in deadssas] + for (idx, newuses) in newpreserves + ir[SSAValue(idx)][:inst] = form_new_preserves( + ir[SSAValue(idx)][:inst]::Expr, deadssas, newuses) end - # Insert the new preserves - for (use, new_preserves) in preserve_uses - ir[SSAValue(use)][:inst] = form_new_preserves(ir[SSAValue(use)][:inst]::Expr, intermediaries, new_preserves) + end + + return ir +end + +function is_load_forwardable(x::EscapeInfo) + AliasInfo = x.AliasInfo + return isa(AliasInfo, Indexable) && !AliasInfo.array +end + +struct FieldDefUse + uses::Vector{Any} + defs::Vector{Int} +end +FieldDefUse() = FieldDefUse(Any[], Int[]) +struct GetfieldLoad + idx::Int +end +struct PreserveUse + idx::Int +end +struct IsdefinedUse + idx::Int +end +function getuseidx(@nospecialize use) + if isa(use, GetfieldLoad) + return use.idx + elseif isa(use, PreserveUse) + return use.idx + elseif isa(use, IsdefinedUse) + return use.idx + end + throw("unknown FieldDefUse") +end + +function compute_live_ins(cfg::CFG, fdu::FieldDefUse) + uses = Int[getuseidx(use) for use in fdu.uses] + compute_live_ins(cfg, fdu.defs, uses) +end + +# even when the allocation contains an uninitialized field, we try an extra effort to check +# if this load at `idx` have any "safe" `setfield!` calls that define the field +# try to find +function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + dfu === nothing && return false + def = dfu[1] + def ≠ 0 && return true # found a "safe" definition + # we may still be able to replace this load with `PhiNode` -- examine if all predecessors of + # this `block` have any "safe" definition + block = block_for_inst(ir, use) + seen = BitSet(block) + worklist = BitSet(ir.cfg.blocks[block].preds) + isempty(worklist) && return false + while !isempty(worklist) + pred = pop!(worklist) + # if this block has already been examined, bail out to avoid infinite cycles + pred in seen && return false + use = last(ir.cfg.blocks[pred].stmts) + # NOTE this `use` isn't a load, and so the inclusive condition can be used + dfu = find_def_for_use(ir, domtree, allblocks, fdu, use, true) + dfu === nothing && return false + def = dfu[1] + push!(seen, pred) + def ≠ 0 && continue # found a "safe" definition for this predecessor + # if not, check for the predecessors of this predecessor + for newpred in ir.cfg.blocks[pred].preds + push!(worklist, newpred) end + end + return true +end - @label skip +# find the first dominating def for the given use +function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, use::Int, inclusive::Bool=false) + useblock = block_for_inst(ir.cfg, use) + curblock = find_curblock(domtree, allblocks, useblock) + curblock === nothing && return nothing + local def = 0 + for idx in fdu.defs + if block_for_inst(ir.cfg, idx) == curblock + if curblock != useblock + # Find the last def in this block + def = max(def, idx) + else + # Find the last def before our use + if inclusive + def = max(def, idx ≤ use ? idx : 0) + else + def = max(def, idx < use ? idx : 0) + end + end + end end + return def, useblock, curblock end -function is_load_forwardable(x::EscapeAnalysis.EscapeInfo) - AliasInfo = x.AliasInfo - return isa(AliasInfo, EscapeAnalysis.Indexable) && !AliasInfo.array +function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) + # TODO: This can be much faster by looking at current level and only + # searching for those blocks in a sorted order + while !(curblock in allblocks) + curblock = domtree.idoms_bb[curblock] + curblock == 0 && return nothing + end + return curblock end -function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) - newex = Expr(:foreigncall) - nccallargs = length(origex.args[3]::SimpleVector) - for i in 1:(6+nccallargs-1) - push!(newex.args, origex.args[i]) +function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + @assert dfu !== nothing "has_safe_def condition unsatisfied" + def, useblock, curblock = dfu + if def == 0 + if !haskey(phinodes, curblock) + # If this happens, we need to search the predecessors for defs. Which + # one doesn't matter - if it did, we'd have had a phinode + return compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) + end + # The use is the phinode + return phinodes[curblock] + else + return val_for_def_expr(ir, def, fidx) end - for i in (6+nccallargs):length(origex.args) - x = origex.args[i] - # don't need to preserve intermediaries - if isa(x, SSAValue) && x.id in intermediates - continue +end + +function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) + curblock = find_curblock(domtree, allblocks, curblock) + @assert curblock !== nothing "has_safe_def condition unsatisfied" + def = 0 + for stmt in fdu.defs + if block_for_inst(ir.cfg, stmt) == curblock + def = max(def, stmt) end - push!(newex.args, x) end - for i in 1:length(new_preserves) - push!(newex.args, new_preserves[i]) + return def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx) +end + +function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) + ex = ir[SSAValue(def)][:inst] + if isexpr(ex, :new) + return ex.args[1+fidx] + else + @assert is_known_call(ex, setfield!, ir) "invalid load forwarding" + return ex.args[4] end - return newex +end + +function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, + revisit::Vector{Tuple{Vector{SSAValue},BitSet}}, eliminated::BitSet) + wset = BitSet(1:length(revisit)) + while !isempty(wset) + revisit_idx = pop!(wset) + mark_dead_ssas!(ir, deadssas, revisit, eliminated, wset, revisit_idx) + end +end + +function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, + revisit::Vector{Tuple{Vector{SSAValue},BitSet}}, eliminated::BitSet, wset::BitSet, revisit_idx::Int) + related, Liveness = revisit[revisit_idx] + eliminable = SSAValue[] + for livepc in Liveness + livepc in eliminated && @goto next_live + ssa = SSAValue(livepc) + stmt = ir[ssa][:inst] + if isexpr(stmt, :new) + ssa in deadssas && @goto next_live + for new_revisit_idx in wset + if ssa in revisit[new_revisit_idx][1] + delete!(wset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, revisit, eliminated, wset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + return false + elseif is_known_call(stmt, setfield!, ir) + @assert length(stmt.args) ≥ 4 "invalid escape analysis" + obj = stmt.args[2] + val = stmt.args[4] + if isa(obj, SSAValue) + if obj in related + push!(eliminable, ssa) + @goto next_live + end + if isa(val, SSAValue) && val in related + if obj in deadssas + push!(eliminable, ssa) + @goto next_live + end + for new_revisit_idx in wset + if obj in revisit[new_revisit_idx][1] + delete!(wset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, revisit, eliminated, wset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + end + end + return false + elseif isexpr(stmt, :foreigncall) + @goto next_live + else + return false + end + @label next_live + end + for ssa in related; push!(deadssas, ssa); end + for ssa in eliminable; push!(deadssas, ssa); end + return true end """ diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index cbbf4375541e1..d169d5d23b3d6 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -68,49 +68,107 @@ let m = Meta.@lower 1 + 1 end # Tests for SROA +# ============== -import Core.Compiler: argextype, singleton_type -const EMPTY_SPTYPES = Any[] +import Core: CodeInfo, Argument, SSAValue +import Core.Compiler: argextype, singleton_type, widenconst -code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo +argextype(@nospecialize args...) = argextype(args..., Any[]) +code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::CodeInfo get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code # check if `x` is a statement with a given `head` isnew(@nospecialize x) = Meta.isexpr(x, :new) +isreturn(@nospecialize x) = isa(x, Core.ReturnNode) # check if `x` is a dynamic call of a given function iscall(y) = @nospecialize(x) -> iscall(y, x) -function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x)) +function iscall((src, f)::Tuple{CodeInfo,Function}, @nospecialize(x)) return iscall(x) do @nospecialize x - singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f + singleton_type(argextype(x, src)) === f end end iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) +is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code) +is_scalar_replaced(src::CodeInfo) = + is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code) + +function is_load_forwarded(@nospecialize(T), src::CodeInfo) + for i in 1:length(src.code) + x = src.code[i] + if iscall((src, getfield), x) + widenconst(argextype(x.args[1], src)) <: T && return false + end + end + return true +end +function is_scalar_replaced(@nospecialize(T), src::CodeInfo) + is_load_forwarded(T, src) || return false + for i in 1:length(src.code) + x = src.code[i] + if iscall((src, setfield!), x) + widenconst(argextype(x.args[1], src)) <: T && return false + elseif isnew(x) + widenconst(argextype(SSAValue(i), src)) <: T && return false + end + end + return true +end + struct ImmutableXYZ; x; y; z; end mutable struct MutableXYZ; x; y; z; end +struct ImmutableOuter{T}; x::T; y::T; z::T; end +mutable struct MutableOuter{T}; x::T; y::T; z::T; end +struct ImmutableRef{T}; x::T; end +Base.getindex(r::ImmutableRef) = r.x +mutable struct SafeRef{T}; x::T; end +Base.getindex(s::SafeRef) = getfield(s, 1) +Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x) + +# simple immutability +# ------------------- -# should optimize away very basic cases let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = (x, y, z) + xyz[1], xyz[2], xyz[3] + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end + +# simple mutability +# ----------------- + let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end - -# should handle simple mutabilities let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) xyz.y = 42 xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=x=# Core.Argument(2), 42, #=x=# Core.Argument(4)] @@ -121,19 +179,23 @@ let src = code_typed1((Any,Any,Any)) do x, y, z xyz.x, xyz.z = xyz.z, xyz.x xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] end end -# circumvent uninitialized fields as far as there is a solid `setfield!` definition + +# uninitialized fields +# -------------------- + +# safe cases let src = code_typed1() do r = Ref{Any}() r[] = 42 return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -145,7 +207,7 @@ let src = code_typed1((Bool,)) do cond return r[] end end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -156,7 +218,7 @@ let src = code_typed1((Bool,)) do cond end return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z r = Ref{Any}() @@ -171,7 +233,16 @@ let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z end return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) +end + +# unsafe cases +let src = code_typed1() do + r = Ref{Any}() + return r[] + end + @test count(isnew, src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -181,7 +252,9 @@ let src = code_typed1((Bool,)) do cond return r[] end # N.B. `r` should be allocated since `cond` might be `false` and then it will be thrown - @test any(isnew, src.code) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 end let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y r = Ref{Any}() @@ -195,12 +268,95 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y return r[] end # N.B. `r` should be allocated since `c2` might be `false` and then it will be thrown - @test any(isnew, src.code) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 end -# should include a simple alias analysis -struct ImmutableOuter{T}; x::T; y::T; z::T; end -mutable struct MutableOuter{T}; x::T; y::T; z::T; end +# load forwarding +# --------------- +# even if allocation can't be eliminated + +# safe cases +for T in (ImmutableRef{Any}, Ref{Any}) + let src = @eval code_typed1((Bool,Any,)) do c, a + r = $T(a) + if c + return r[] + else + return r + end + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + end + let src = @eval code_typed1((Bool,String,)) do c, a + r = $T(a) + if c + return r[]::String # adce_pass! will further eliminate this type assert call also + else + return r + end + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + @test !any(iscall((src, typeassert)), src.code) + end + let src = @eval code_typed1((Bool,Any,)) do c, a + r = $T(a) + if c + return r[] + else + throw(r) + end + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + end +end +let src = code_typed1((Bool,Any,Any)) do c, a, b + r = Ref{Any}(a) + if c + return r[] + end + r[] = b + return r + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 + @test count(src.code) do @nospecialize x + isreturn(x) && x.val === Argument(3) # a + end == 1 +end + +# unsafe case +let src = code_typed1((Bool,Any,Any)) do c, a, b + r = Ref{Any}(a) + r[] = b + @noinline some_escape!(r) + return r[] + end + @test !is_load_forwarded(src) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 +end +let src = code_typed1((Bool,String,Regex)) do c, a, b + r1 = Ref{Any}(a) + r2 = Ref{Any}(b) + return ifelse(c, r1, r2)[] + end + r = only(findall(isreturn, src.code)) + v = (src.code[r]::Core.ReturnNode).val + @test v !== Argument(3) # a + @test v !== Argument(4) # b + @test_broken is_load_forwarded(src) # ideally +end + +# aliased load forwarding +# ----------------------- + +# OK: immutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) @@ -228,7 +384,6 @@ let src = code_typed1((Any,Any,Any)) do x, y, z end end -# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well # OK: mutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) @@ -236,14 +391,14 @@ let src = code_typed1((Any,Any,Any)) do x, y, z v = t[1].x v, v, v end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] @@ -254,48 +409,477 @@ let # this is a simple end to end test case, which demonstrates allocation elimi # NOTE this test case isn't so robust and might be subject to future changes of the broadcasting implementation, # in that case you don't really need to stick to keeping this test case around simple_sroa(s) = broadcast(identity, Ref(s)) + let src = code_typed1(simple_sroa, (String,)) + @test is_scalar_replaced(src) + end s = Base.inferencebarrier("julia")::String simple_sroa(s) # NOTE don't hard-code `"julia"` in `@allocated` clause and make sure to execute the # compiled code for `simple_sroa`, otherwise everything can be folded even without SROA @test @allocated(simple_sroa(s)) == 0 end -# FIXME: immutable(mutable(...)) case +let # some insanely nested example + src = code_typed1((Int,)) do x + (Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][] + end + @test is_scalar_replaced(src) +end + +# OK: immutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end -# FIXME: mutable(mutable(...)) case + +# OK: mutable(mutable(...)) case +# new chain let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = z, y, x + outer = MutableOuter(xyz, xyz, xyz) + outer.x.x, outer.y.y, outer.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = xyz.z, xyz.y, xyz.x + outer = MutableOuter(xyz, xyz, xyz) + outer.x.x, outer.y.y, outer.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = z, y, x + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +# setfield! chain +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = Ref{MutableXYZ}() + outer[] = xyz + return outer[].x, outer[].y, outer[].z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = Ref{MutableXYZ}() + outer[] = xyz + xyz.z = 42 + return outer[].x, outer[].y, outer[].z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), 42] + end +end + +# ϕ-allocation elimination +# ------------------------ + +# safe cases +let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 +end +let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z + if cond1 + ϕ = Ref{Any}(x) + elseif cond2 + ϕ = Ref{Any}(y) + else + ϕ = Ref{Any}(z) + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(4) in x.values && + #=y=# Core.Argument(5) in x.values && + #=z=# Core.Argument(6) in x.values + end == 1 +end +let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + ϕ[] = z + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.ReturnNode) && + #=z=# Core.Argument(5) === x.val + end == 1 +end +let src = code_typed1((Bool,Any,Any,)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + out1 = ϕ[] + else + ϕ = Ref{Any}(y) + out1 = ϕ[] + end + out2 = ϕ[] + out1, out2 + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 2 +end +let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + ϕ[] = z + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 +end +let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = Ref{Any}(x) + out1 = ϕ[] + else + ϕ = Ref{Any}(y) + out1 = ϕ[] + ϕ[] = z + end + out2 = ϕ[] + out1, out2 + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 +end +let src = code_typed1((Bool,Any,Any)) do cond, x, y + # these allocation form multiple ϕ-nodes + if cond + ϕ2 = ϕ1 = Ref{Any}(x) + else + ϕ2 = ϕ1 = Ref{Any}(y) + end + ϕ1[], ϕ2[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 +end +let src = code_typed1((Bool,String,)) do cond, x + # these allocation form multiple ϕ-nodes + if cond + ϕ2 = ϕ1 = Ref{Any}("foo") + else + ϕ2 = ϕ1 = Ref{Any}("bar") + end + ϕ2[] = x + y = ϕ1[] # => x + return y + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.ReturnNode) && + #=x=# x.val === Core.Argument(3) + end == 1 +end + +# unsafe cases +let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + some_escape!(ϕ) + ϕ[] + end + @test count(isnew, src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 +end +let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + some_escape!(ϕ) + else + ϕ = Ref{Any}(y) + end + ϕ[] + end + @test count(isnew, src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 +end +let src = code_typed1((Bool,Any,)) do cond, x + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}() + end + ϕ[] + end + @test count(isnew, src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 +end +let src = code_typed1((Bool,Any)) do c, a + local r + if c + r = Ref{Any}(a) + end + (r::Base.RefValue{Any})[] + end + @test count(isnew, src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 end -# should work nicely with inlining to optimize away a complicated case +function mutable_ϕ_elim(x, xs) + r = Ref(x) + for x in xs + r = Ref(x) + end + return r[] +end +let src = code_typed1(mutable_ϕ_elim, (String, Vector{String})) + @test is_scalar_replaced(src) + + xs = String[string(gensym()) for _ in 1:100] + mutable_ϕ_elim("init", xs) + @test @allocated(mutable_ϕ_elim("init", xs)) == 0 +end + +# demonstrate the power of our field / alias analysis with realistic end to end examples # adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B -struct Point - x::Float64 - y::Float64 -end -#=@inline=# add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y) -function compute_points() - a = Point(1.5, 2.5) - b = Point(2.25, 4.75) - for i in 0:(100000000-1) +abstract type AbstractPoint{T} end +struct Point{T} <: AbstractPoint{T} + x::T + y::T +end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute_point(T, n, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(n-1) a = add(add(a, b), b) end a.x, a.y end -let src = code_typed1(compute_points) +function compute_point(n, a, b) + for i in 0:(n-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute_point!(n, a, b) + for i in 0:(n-1) + a′ = add(add(a, b), b) + a.x = a′.x + a.y = a′.y + end +end + +let # immutable case + src = code_typed1((Int,)) do n + compute_point(Point, n, 1+.5, 2+.5, 2+.25, 4+.75) + end + @test is_scalar_replaced(Point, src) + src = code_typed1((Int,)) do n + compute_point(Point, n, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + end + @test is_scalar_replaced(Point, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) + + # mutable case + src = code_typed1((Int,)) do n + compute_point(MPoint, n, 1+.5, 2+.5, 2+.25, 4+.75) + end + @test is_scalar_replaced(MPoint, src) + src = code_typed1((Int,)) do n + compute_point(MPoint, n, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + end + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +compute_point(MPoint, 10, 1+.5, 2+.5, 2+.25, 4+.75) +compute_point(MPoint, 10, 1+.5im, 2+.5im, 2+.25im, 4+.75im) +@test @allocated(compute_point(MPoint, 10000, 1+.5, 2+.5, 2+.25, 4+.75)) == 0 +@test @allocated(compute_point(MPoint, 10000, 1+.5im, 2+.5im, 2+.25im, 4+.75im)) == 0 + +let # immutable case + src = code_typed1((Int,)) do n + compute_point(n, Point(1+.5, 2+.5), Point(2+.25, 4+.75)) + end + @test is_scalar_replaced(Point, src) + src = code_typed1((Int,)) do n + compute_point(n, Point(1+.5im, 2+.5im), Point(2+.25im, 4+.75im)) + end + @test is_scalar_replaced(Point, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) + + # mutable case + src = code_typed1((Int,)) do n + compute_point(n, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) + end + @test is_scalar_replaced(MPoint, src) + src = code_typed1((Int,)) do n + compute_point(n, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + end + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +compute_point(10, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) +compute_point(10, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) +@test @allocated(compute_point(10000, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75))) == 0 +@test @allocated(compute_point(10000, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im))) == 0 + +let # mutable case + src = code_typed1(compute_point!, (Int,MPoint{Float64},MPoint{Float64})) + @test is_scalar_replaced(MPoint, src) + src = code_typed1(compute_point!, (Int,MPoint{ComplexF64},MPoint{ComplexF64})) + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +let + af, bf = MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75) + ac, bc = MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im) + compute_point!(10, af, bf) + compute_point!(10, ac, bc) + @test @allocated(compute_point!(10000, af, bf)) == 0 + @test @allocated(compute_point!(10000, ac, bc)) == 0 +end + +# isdefined elimination +# --------------------- + +let src = code_typed1((Any,)) do a + r = Ref{Any}() + r[] = a + if isassigned(r) + return r[] + end + return nothing + end + @test is_scalar_replaced(src) +end + +callit(f, args...) = f(args...) +function isdefined_elim() + local arr::Vector{Any} + callit() do + arr = Any[] + end + return arr +end +let src = code_typed1(isdefined_elim) + @test is_scalar_replaced(src) +end +@test isdefined_elim() == Any[] + +# preserve elimination +# -------------------- + +let src = code_typed1((String,)) do s + ccall(:some_ccall, Cint, (Ptr{String},), Ref(s)) + end @test !any(isnew, src.code) end +# ----------------------- + mutable struct Foo30594; x::Float64; end Base.copy(x::Foo30594) = Foo30594(x.x) function add!(p::Foo30594, off::Foo30594) @@ -625,14 +1209,8 @@ let # `sroa_pass!` should work with constant globals return sin(x) end |> only |> first end - @test !any(src.code) do @nospecialize(stmt) - Meta.isexpr(stmt, :call) || return false - ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES) - return Core.Compiler.widenconst(ft) == typeof(getfield) - end - @test !any(src.code) do @nospecialize(stmt) - return Meta.isexpr(stmt, :new) - end + @test !any(iscall((src, getfield)), src.code) + @test !any(isnew, src.code) # mutable pass src = @eval Module() begin @@ -643,14 +1221,8 @@ let # `sroa_pass!` should work with constant globals return sin(x) end |> only |> first end - @test !any(src.code) do @nospecialize(stmt) - Meta.isexpr(stmt, :call) || return false - ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES) - return Core.Compiler.widenconst(ft) == typeof(getfield) - end - @test !any(src.code) do @nospecialize(stmt) - return Meta.isexpr(stmt, :new) - end + @test !any(iscall((src, getfield)), src.code) + @test !any(isnew, src.code) end let @@ -666,11 +1238,7 @@ let end |> only |> first end # eliminate `typeassert(x2.x, Foo)` - @test all(src.code) do @nospecialize stmt - Meta.isexpr(stmt, :call) || return true - ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES) - return Core.Compiler.widenconst(ft) !== typeof(typeassert) - end + @test !any(iscall((src, typeassert)), src.code) end let