From e22ef67b083330a6ded1a2b7540e908727acc6ea Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 24 Jan 2022 15:57:58 +0900 Subject: [PATCH] optimizer: simple array SROA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements a simple Julia-level array allocation elimination on top of #43888. ```julia julia> code_typed((String,String)) do s, t a = Vector{Base.RefValue{String}}(undef, 2) a[1] = Ref(s) a[2] = Ref(t) return a[1][] end ``` ```diff diff --git a/master b/pr index 9c8da14380..5b63d08190 100644 --- a/master +++ b/pr @@ -1,11 +1,4 @@ 1-element Vector{Any}: CodeInfo( -1 ─ %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Base.RefValue{String}}, svec(Any, Int64), 0, :(:ccall), Vector{Base.RefValue{String}}, 2, 2))::Vector{Base.RefValue{String}} -│ %2 = %new(Base.RefValue{String}, s)::Base.RefValue{String} -│ Base.arrayset(true, %1, %2, 1)::Vector{Base.RefValue{String}} -│ %4 = %new(Base.RefValue{String}, t)::Base.RefValue{String} -│ Base.arrayset(true, %1, %4, 2)::Vector{Base.RefValue{String}} -│ %6 = Base.arrayref(true, %1, 1)::Base.RefValue{String} -│ %7 = Base.getfield(%6, :x)::String -└── return %7 +1 ─ return s ) => String ``` Still this array SROA handle is very limited and able to handle only trivial examples (though I confirmed this version already eliminates few array allocations during sysimg build). For those who interested, I added some discussions on array optimization [here](https://aviatesk.github.io/EscapeAnalysis.jl/dev/#EA-Array-Analysis). --- base/compiler/optimize.jl | 14 +- base/compiler/ssair/passes.jl | 462 +++++++++++++++++++++++----------- test/compiler/codegen.jl | 8 +- test/compiler/irpasses.jl | 102 +++++++- 4 files changed, 431 insertions(+), 155 deletions(-) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index c9d88c23cfc7ca..c3522ff134d516 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -268,8 +268,7 @@ end function foreigncall_effect_free(stmt::Expr, src::Union{IRCode,IncrementalCompact}) args = stmt.args - name = args[1] - isa(name, QuoteNode) && (name = name.value) + name = normalize(args[1]) isa(name, Symbol) || return false ndims = alloc_array_ndims(name) if ndims !== nothing @@ -295,6 +294,17 @@ function alloc_array_ndims(name::Symbol) return nothing end +normalize(@nospecialize x) = isa(x, QuoteNode) ? x.value : x + +function is_array_alloc(@nospecialize stmt) + isa(stmt, Expr) || return false + if isexpr(stmt, :foreigncall) + name = normalize(stmt.args[1]) + return isa(name, Symbol) && alloc_array_ndims(name) !== nothing + end + return false +end + function alloc_array_no_throw(args::Vector{Any}, ndims::Int, src::Union{IRCode,IncrementalCompact}) length(args) ≥ ndims+6 || return false atype = instanceof_tfunc(argextype(args[6], src))[1] diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 54a7ed862de12c..325ec40b1aa1c3 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -527,6 +527,9 @@ function sroa_pass!(ir::IRCode, nargs::Int) anymutability = true end continue + elseif is_array_alloc(stmt) + anymutability = true + continue # elseif is_known_call(stmt, setfield!, compact) # 4 <= length(stmt.args) <= 5 || continue # if length(stmt.args) == 5 @@ -662,7 +665,7 @@ function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preser return newex end -import .EscapeAnalysis: EscapeInfo, IndexableFields, LivenessSet, getaliases +import .EscapeAnalysis: EscapeInfo, IndexableFields, IndexableElements, LivenessSet, getaliases function sroa_mutables!(ir::IRCode, nargs::Int) # Compute domtree now, needed below, now that we have finished compacting the IR. @@ -676,12 +679,12 @@ function sroa_mutables!(ir::IRCode, nargs::Int) eliminated = BitSet() revisit = Tuple{#=related=#Vector{SSAValue}, #=Liveness=#LivenessSet}[] all_preserved = true - newpreserves = nothing + newpreserves = IdDict{Int,Vector{Any}}() while !isempty(wset) idx = pop!(wset) ssa = SSAValue(idx) stmt = ir[ssa][:inst] - isexpr(stmt, :new) || continue + isexpr(stmt, :new) || is_array_alloc(stmt) || continue einfo = estate[ssa] is_load_forwardable(einfo) || continue aliases = getaliases(ssa, estate) @@ -695,141 +698,43 @@ function sroa_mutables!(ir::IRCode, nargs::Int) delete!(wset, alias.id) end end - finfos = (einfo.AliasInfo::IndexableFields).infos - nfields = length(finfos) - - # Partition defuses by field - fdefuses = Vector{FieldDefUse}(undef, nfields) - for i = 1:nfields - finfo = finfos[i] - fdu = FieldDefUse() - for pc in finfo - if pc > 0 - push!(fdu.uses, GetfieldLoad(pc)) # use (getfield call) - else - push!(fdu.defs, -pc) # def (setfield! call or :new expression) - end - end - fdefuses[i] = fdu - end - Liveness = einfo.Liveness - for livepc in Liveness - livestmt = ir[SSAValue(livepc)][:inst] - if is_known_call(livestmt, Core.ifelse, ir) - # the succeeding domination analysis doesn't account for conditional branching - # by ifelse branching at this moment - @goto next_itr - elseif is_known_call(livestmt, isdefined, ir) - args = livestmt.args - length(args) ≥ 3 || continue - obj = args[2] - obj in related || continue - fld = args[3] - fldval = try_compute_field(ir, fld) - fldval === nothing && continue - typ = unwrap_unionall(widenconst(argextype(obj, ir))) - isa(typ, DataType) || continue - fldidx = try_compute_fieldidx(typ, fldval) - fldidx === nothing && continue - push!(fdefuses[fldidx].uses, IsdefinedUse(livepc)) - elseif isexpr(livestmt, :foreigncall) # preserve use (otherwise not is_load_forwarded) - for fidx in 1:nfields - push!(fdefuses[fidx].uses, PreserveUse(livepc)) - end - end - end - - for fidx in 1:nfields - fdu = fdefuses[fidx] - isempty(fdu.uses) && @goto next_use - # check if all uses have safe definitions first, otherwise we should bail out - # since then we may fail to form new ϕ-nodes - ldu = compute_live_ins(ir.cfg, fdu) - if isempty(ldu.live_in_bbs) - phiblocks = Int[] - else - phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) - end - allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) - for use in fdu.uses - isa(use, IsdefinedUse) && continue - if isa(use, PreserveUse) && isempty(fdu.defs) - # nothing to preserve, just ignore this use (may happen when there are unintialized fields) - continue - end - if !has_safe_def(ir, domtree, allblocks, fdu, getuseidx(use)) - all_preserved = false - @goto next_use - end - end - phinodes = IdDict{Int, SSAValue}() - for b in phiblocks - phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), - NewInstruction(PhiNode(), Any)) - end - # Now go through all uses and rewrite them - for use in fdu.uses - if isa(use, GetfieldLoad) - use = getuseidx(use) - ir[SSAValue(use)][:inst] = compute_value_for_use( - ir, domtree, allblocks, fdu, phinodes, fidx, use) - push!(eliminated, use) - elseif all_preserved && isa(use, PreserveUse) - if newpreserves === nothing - newpreserves = IdDict{Int,Vector{Any}}() - end - # record this `use` as replaceable no matter if we preserve new value or not - use = getuseidx(use) - newvalues = get!(()->Any[], newpreserves, use) - isempty(fdu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) - newval = compute_value_for_use( - ir, domtree, allblocks, fdu, phinodes, fidx, use) - if !isbitstype(widenconst(argextype(newval, ir))) - push!(newvalues, newval) - end - elseif isa(use, IsdefinedUse) - use = getuseidx(use) - if has_safe_def(ir, domtree, allblocks, fdu, use) - ir[SSAValue(use)][:inst] = true - push!(eliminated, use) - end - else - throw("unexpected use") - end - end - for b in phiblocks - ϕssa = phinodes[b] - n = ir[ϕssa][:inst]::PhiNode - t = Bottom - for p in ir.cfg.blocks[b].preds - push!(n.edges, p) - v = compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, p) - push!(n.values, v) - if t !== Any - t = tmerge(t, argextype(v, ir)) - end - end - ir[ϕssa][:type] = t - end - @label next_use + AliasInfo = einfo.AliasInfo + if isa(AliasInfo, IndexableFields) + all_preserved &= load_forward_object!(ir, domtree, + eliminated, revisit, + newpreserves, related, + AliasInfo, einfo.Liveness) + else + all_preserved &= load_forward_array!(ir, domtree, + eliminated, revisit, + newpreserves, related, + AliasInfo::IndexableElements, einfo.Liveness) end - push!(revisit, (related, Liveness)) - @label next_itr end # remove dead setfield! and :new allocs deadssas = IdSet{SSAValue}() - if all_preserved && newpreserves !== nothing + if all_preserved preserved = keys(newpreserves) else preserved = EMPTY_PRESERVED_SSAS end mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved) for ssa in deadssas + # stmt = ir[ssa][:inst] + # if is_known_call(stmt, setfield!, ir) + # println("[EASROA] eliminated setfield!: ", ir.argtypes) + # elseif isexpr(stmt, :new) + # println("[EASROA] eliminated object alloc: ", ir.argtypes) + # elseif is_known_call(stmt, arrayset, ir) + # println("[EASROA] eliminated arrayset: ", ir.argtypes) + # elseif is_array_alloc(stmt) + # println("[EASROA] eliminated array alloc: ", ir.argtypes) + # end ir[ssa][:inst] = nothing end - if all_preserved && newpreserves !== nothing + if all_preserved deadssas = Int[ssa.id for ssa in deadssas] for (idx, newuses) in newpreserves ir[SSAValue(idx)][:inst] = form_new_preserves( @@ -840,20 +745,258 @@ function sroa_mutables!(ir::IRCode, nargs::Int) return ir end +function load_forward_object!(ir::IRCode, domtree::DomTree, + eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}}, + newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue}, + AliasInfo::IndexableFields, Liveness::LivenessSet) + finfos = AliasInfo.infos + nfields = length(finfos) + + # Partition defuses by field + all_preserved = true + fdefuses = Vector{IndexedDefUse}(undef, nfields) + for i = 1:nfields + finfo = finfos[i] + idu = IndexedDefUse() + for pc in finfo + if pc > 0 + push!(idu.uses, LoadUse(pc)) # use (getfield call) + else + push!(idu.defs, -pc) # def (setfield! call or :new expression) + end + end + fdefuses[i] = idu + end + + for livepc in Liveness + livestmt = ir[SSAValue(livepc)][:inst] + if is_known_call(livestmt, Core.ifelse, ir) + # the succeeding domination analysis doesn't account for conditional branching + # by ifelse branching at this moment + return false + elseif is_known_call(livestmt, isdefined, ir) + args = livestmt.args + length(args) ≥ 3 || continue + obj = args[2] + obj in related || continue + fld = args[3] + fldval = try_compute_field(ir, fld) + fldval === nothing && continue + typ = unwrap_unionall(widenconst(argextype(obj, ir))) + isa(typ, DataType) || continue + fldidx = try_compute_fieldidx(typ, fldval) + fldidx === nothing && continue + push!(fdefuses[fldidx].uses, IsdefinedUse(livepc)) + elseif isexpr(livestmt, :foreigncall) # preserve use (otherwise not is_load_forwarded) + for fidx in 1:nfields + push!(fdefuses[fidx].uses, PreserveUse(livepc)) + end + end + end + + for fidx in 1:nfields + idu = fdefuses[fidx] + isempty(idu.uses) && @goto next_use + # check if all uses have safe definitions first, otherwise we should bail out + # since then we may fail to form new ϕ-nodes + ldu = compute_live_ins(ir.cfg, idu) + if isempty(ldu.live_in_bbs) + phiblocks = Int[] + else + phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) + end + allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) + for use in idu.uses + isa(use, IsdefinedUse) && continue + if isa(use, PreserveUse) && isempty(idu.defs) + # nothing to preserve, just ignore this use (may happen when there are unintialized fields) + continue + end + if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use)) + all_preserved = false + @goto next_use + end + end + phinodes = IdDict{Int, SSAValue}() + for b in phiblocks + phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), + NewInstruction(PhiNode(), Any)) + end + # Now go through all uses and rewrite them + for use in idu.uses + if isa(use, LoadUse) + use = getuseidx(use) + ir[SSAValue(use)][:inst] = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, fidx, use) + push!(eliminated, use) + elseif isa(use, PreserveUse) + all_preserved || continue + # record this `use` as replaceable no matter if we preserve new value or not + use = getuseidx(use) + newvalues = get!(()->Any[], newpreserves, use) + isempty(idu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) + newval = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, fidx, use) + if !isbitstype(widenconst(argextype(newval, ir))) + push!(newvalues, newval) + end + elseif isa(use, IsdefinedUse) + use = getuseidx(use) + if has_safe_def(ir, domtree, allblocks, idu, use) + ir[SSAValue(use)][:inst] = true + push!(eliminated, use) + end + else + throw("load_forward_object!: unexpected use") + end + end + for b in phiblocks + ϕssa = phinodes[b] + n = ir[ϕssa][:inst]::PhiNode + t = Bottom + for p in ir.cfg.blocks[b].preds + push!(n.edges, p) + v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, p) + push!(n.values, v) + if t !== Any + t = tmerge(t, argextype(v, ir)) + end + end + ir[ϕssa][:type] = t + end + @label next_use + end + push!(revisit, (related, Liveness)) + + return all_preserved +end + +# TODO is_array_isassigned folding? +function load_forward_array!(ir::IRCode, domtree::DomTree, + eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}}, + newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue}, + AliasInfo::IndexableElements, Liveness::LivenessSet) + elminfos = AliasInfo.infos + elmkeys = keys(elminfos) + + # Partition defuses by index + all_preserved = true + edefuses = IdDict{Int,IndexedDefUse}() + for eidx in elmkeys + einfo = elminfos[eidx] + idu = IndexedDefUse() + for pc in einfo + if pc > 0 + push!(idu.uses, LoadUse(pc)) # use (arrayref call) + else + push!(idu.defs, -pc) # def (arrayset call) + end + end + edefuses[eidx] = idu + end + + for livepc in Liveness + livestmt = ir[SSAValue(livepc)][:inst] + if is_known_call(livestmt, Core.ifelse, ir) + # the succeeding domination analysis doesn't account for conditional branching + # by ifelse branching at this moment + return false + elseif isexpr(livestmt, :foreigncall) # preserve use (otherwise not is_load_forwarded) + for eidx in elmkeys + push!(edefuses[eidx].uses, PreserveUse(livepc)) + end + end + end + + for eidx in elmkeys + idu = edefuses[eidx] + isempty(idu.uses) && @goto next_use + # check if all uses have safe definitions first, otherwise we should bail out + # since then we may fail to form new ϕ-nodes + ldu = compute_live_ins(ir.cfg, idu) + if isempty(ldu.live_in_bbs) + phiblocks = Int[] + else + phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) + end + allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) + for use in idu.uses + isa(use, IsdefinedUse) && continue + if isa(use, PreserveUse) && isempty(idu.defs) + # nothing to preserve, just ignore this use (may happen when there are unintialized fields) + continue + end + if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use)) + all_preserved = false + @goto next_use + end + end + phinodes = IdDict{Int, SSAValue}() + for b in phiblocks + phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), + NewInstruction(PhiNode(), Any)) + end + # Now go through all uses and rewrite them + for use in idu.uses + if isa(use, LoadUse) + use = getuseidx(use) + ir[SSAValue(use)][:inst] = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, eidx, use) + push!(eliminated, use) + elseif isa(use, PreserveUse) + all_preserved || continue + # record this `use` as replaceable no matter if we preserve new value or not + use = getuseidx(use) + newvalues = get!(()->Any[], newpreserves, use) + isempty(idu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) + newval = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, eidx, use) + if !isbitstype(widenconst(argextype(newval, ir))) + push!(newvalues, newval) + end + elseif isa(use, IsdefinedUse) + use = getuseidx(use) + ir[SSAValue(use)][:inst] = false + push!(eliminated, use) + else + throw("load_forward_array!: unexpected use") + end + end + for b in phiblocks + ϕssa = phinodes[b] + n = ir[ϕssa][:inst]::PhiNode + t = Bottom + for p in ir.cfg.blocks[b].preds + push!(n.edges, p) + v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, eidx, p) + push!(n.values, v) + if t !== Any + t = tmerge(t, argextype(v, ir)) + end + end + ir[ϕssa][:type] = t + end + @label next_use + end + push!(revisit, (related, Liveness)) + + return all_preserved +end + const EMPTY_PRESERVED_SSAS = keys(IdDict{Int,Vector{Any}}()) const PreservedSets = typeof(EMPTY_PRESERVED_SSAS) function is_load_forwardable(x::EscapeInfo) AliasInfo = x.AliasInfo - return isa(AliasInfo, IndexableFields) + return isa(AliasInfo, IndexableFields) || isa(AliasInfo, IndexableElements) end -struct FieldDefUse +struct IndexedDefUse uses::Vector{Any} defs::Vector{Int} end -FieldDefUse() = FieldDefUse(Any[], Int[]) -struct GetfieldLoad +IndexedDefUse() = IndexedDefUse(Any[], Int[]) +struct LoadUse idx::Int end struct PreserveUse @@ -863,7 +1006,7 @@ struct IsdefinedUse idx::Int end function getuseidx(@nospecialize use) - if isa(use, GetfieldLoad) + if isa(use, LoadUse) return use.idx elseif isa(use, PreserveUse) return use.idx @@ -873,17 +1016,17 @@ function getuseidx(@nospecialize use) throw("getuseidx: unexpected use") end -function compute_live_ins(cfg::CFG, fdu::FieldDefUse) - uses = Int[getuseidx(use) for use in fdu.uses] - compute_live_ins(cfg, fdu.defs, uses) +function compute_live_ins(cfg::CFG, idu::IndexedDefUse) + uses = Int[getuseidx(use) for use in idu.uses] + compute_live_ins(cfg, idu.defs, uses) end # even when the allocation contains an uninitialized field, we try an extra effort to check # if this load at `idx` have any "safe" `setfield!` calls that define the field # try to find function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, use::Int) - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + idu::IndexedDefUse, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use) dfu === nothing && return false def = dfu[1] def ≠ 0 && return true # found a "safe" definition @@ -899,7 +1042,7 @@ function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, pred in seen && return false use = last(ir.cfg.blocks[pred].stmts) # NOTE this `use` isn't a load, and so the inclusive condition can be used - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use, true) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use, true) dfu === nothing && return false def = dfu[1] push!(seen, pred) @@ -914,12 +1057,12 @@ end # find the first dominating def for the given use function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, use::Int, inclusive::Bool=false) + idu::IndexedDefUse, use::Int, inclusive::Bool=false) useblock = block_for_inst(ir.cfg, use) curblock = find_curblock(domtree, allblocks, useblock) curblock === nothing && return nothing local def = 0 - for idx in fdu.defs + for idx in idu.defs if block_for_inst(ir.cfg, idx) == curblock if curblock != useblock # Find the last def in this block @@ -948,15 +1091,15 @@ function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) end function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + idu::IndexedDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use) @assert dfu !== nothing "has_safe_def condition unsatisfied" def, useblock, curblock = dfu if def == 0 if !haskey(phinodes, curblock) # If this happens, we need to search the predecessors for defs. Which # one doesn't matter - if it did, we'd have had a phinode - return compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) + return compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) end # The use is the phinode return phinodes[curblock] @@ -966,11 +1109,11 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) + idu::IndexedDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) curblock = find_curblock(domtree, allblocks, curblock) @assert curblock !== nothing "has_safe_def condition unsatisfied" def = 0 - for stmt in fdu.defs + for stmt in idu.defs if block_for_inst(ir.cfg, stmt) == curblock def = max(def, stmt) end @@ -982,9 +1125,12 @@ function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) ex = ir[SSAValue(def)][:inst] if isexpr(ex, :new) return ex.args[1+fidx] - else - @assert is_known_call(ex, setfield!, ir) "invalid load forwarding" + elseif is_known_call(ex, setfield!, ir) + return ex.args[4] + elseif is_known_call(ex, arrayset, ir) return ex.args[4] + else + throw("invalid load forwarding") end end @@ -1053,6 +1199,34 @@ function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, end end return false + elseif is_known_call(stmt, arrayset, ir) + @assert length(stmt.args) ≥ 4 "invalid escape analysis" + ary = stmt.args[3] + val = stmt.args[4] + if isa(ary, SSAValue) + if ary in related + push!(eliminable, ssa) + @goto next_live + end + if isa(val, SSAValue) && val in related + if ary in deadssas + push!(eliminable, ssa) + @goto next_live + end + for new_revisit_idx in wset + if ary in revisit[new_revisit_idx][1] + delete!(wset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved, wset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + end + end + return false elseif isexpr(stmt, :foreigncall) livepc in preserved && @goto next_live return false diff --git a/test/compiler/codegen.jl b/test/compiler/codegen.jl index 7469dc74c8156e..ba0efff348b904 100644 --- a/test/compiler/codegen.jl +++ b/test/compiler/codegen.jl @@ -548,27 +548,27 @@ end # main use case function f1(cond) val = [1] - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f1, Tuple{Bool}, true, false, false)) # stack allocated objects (JuliaLang/julia#34241) function f3(cond) val = ([1],) - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f3, Tuple{Bool}, true, false, false)) # unions of immutables (JuliaLang/julia#39501) function f2(cond) val = cond ? 1 : 1f0 - GC.@preserve val begin end + GC.@preserve val begin val end end @test !occursin("llvm.julia.gc_preserve_begin", get_llvm(f2, Tuple{Bool}, true, false, false)) # make sure the fix for the above doesn't regress #34241 function f4(cond) val = cond ? ([1],) : ([1f0],) - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f4, Tuple{Bool}, true, false, false)) end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 1d5cf115ae2577..d1bc4cddff5e46 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -71,7 +71,7 @@ end # ============== import Core: CodeInfo, Argument, SSAValue -import Core.Compiler: argextype, singleton_type, widenconst +import Core.Compiler: argextype, singleton_type, widenconst, is_array_alloc argextype(@nospecialize args...) = argextype(args..., Any[]) code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::CodeInfo @@ -90,9 +90,16 @@ function iscall((src, f)::Tuple{CodeInfo,Function}, @nospecialize(x)) end iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) -is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code) -is_scalar_replaced(src::CodeInfo) = - is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code) +is_load_forwarded(src::CodeInfo) = + !any(iscall((src, getfield)), src.code) && !any(iscall((src, Core.arrayset)), src.code) +function is_scalar_replaced(src::CodeInfo) + is_load_forwarded(src) || return false + any(iscall((src, setfield!)), src.code) && return false + any(isnew, src.code) && return false + any(iscall((src, Core.arrayset)), src.code) && return false + any(is_array_alloc, src.code) && return false + return true +end function is_load_forwarded(@nospecialize(T), src::CodeInfo) for i in 1:length(src.code) @@ -865,7 +872,7 @@ function isdefined_elim() return arr end let src = code_typed1(isdefined_elim) - @test is_scalar_replaced(src) + @test count(isnew, src.code) == 0 # eliminates closure constructs end @test isdefined_elim() == Any[] @@ -888,6 +895,91 @@ let src = code_typed1() do @test count(isnew, src.code) == 1 end +# array SROA +# ---------- + +let src = code_typed1((Any,)) do s + a = Vector{Any}(undef, 1) + a[1] = s + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Any[nothing] + a[1] = s + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((String,String)) do s, t + a = Vector{Any}(undef, 2) + a[1] = Ref(s) + a[2] = Ref(t) + return a[1] + end + @test count(isnew, src.code) == 1 +end +let src = code_typed1((String,)) do s + a = Vector{Base.RefValue{String}}(undef, 1) + a[1] = Ref(s) + return a[1][] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((String,String)) do s, t + a = Vector{Base.RefValue{String}}(undef, 2) + a[1] = Ref(s) + a[2] = Ref(t) + return a[1][] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Vector{Any}[Any[nothing]] + a[1][1] = s + return a[1][1] + end + @test_broken is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any)) do c, s, t + a = Any[nothing] + if c + a[1] = s + else + a[1] = t + end + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any,Any,Any,)) do c, s1, s2, t1, t2 + if c + a = Vector{Any}(undef, 2) + a[1] = s1 + a[2] = s2 + else + a = Vector{Any}(undef, 2) + a[1] = t1 + a[2] = t2 + end + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any)) do c, s, t + # XXX this implicitly forms tuple to getfield chains + # and SROA on it produces complicated control flow + if c + a = Any[s] + else + a = Any[t] + end + return a[1] + end + @test_broken is_scalar_replaced(src) +end + # ----------------------- mutable struct Foo30594; x::Float64; end