From 1992f1881e7f2df873b551f054353ba75382ee63 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 14 Sep 2022 19:25:08 +0900 Subject: [PATCH] compiler: simplify backedge calculation --- base/compiler/abstractinterpretation.jl | 1 - base/compiler/inferencestate.jl | 38 ++++--------------------- base/compiler/optimize.jl | 11 ++++--- base/compiler/typeinfer.jl | 24 ++++------------ base/compiler/utilities.jl | 19 +++++++++++-- 5 files changed, 33 insertions(+), 60 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 53c6b2157d05d..92f2610f494fc 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2497,7 +2497,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) for currpc in bbstart:bbend frame.currpc = currpc - empty_backedges!(frame, currpc) stmt = frame.src.code[currpc] # If we're at the end of the basic block ... if currpc == bbend diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index e1d20f01042c4..84946c943ac2f 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -105,7 +105,7 @@ mutable struct InferenceState # TODO: Could keep this sparsely by doing structural liveness analysis ahead of time. bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet ssavaluetypes::Vector{Any} - stmt_edges::Vector{Union{Nothing,Vector{Any}}} + backedges::IdSet{Any} stmt_info::Vector{Any} #= intermediate states for interprocedural abstract interpretation =# @@ -151,7 +151,7 @@ mutable struct InferenceState nssavalues = src.ssavaluetypes::Int ssavalue_uses = find_ssavalue_uses(code, nssavalues) nstmts = length(code) - stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ] + backedges = IdSet{Any}() stmt_info = Any[ nothing for i = 1:nstmts ] nslots = length(src.slotflags) @@ -195,7 +195,7 @@ mutable struct InferenceState frame = new( linfo, world, mod, sptypes, slottypes, src, cfg, - currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info, + currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, backedges, stmt_info, pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred, result, valid_worlds, bestguess, ipo_effects, params, restrict_abstract_call_sites, cached, @@ -488,44 +488,18 @@ end # temporarily accumulate our edges to later add as backedges in the callee function add_backedge!(caller::InferenceState, li::MethodInstance) - edges = get_stmt_edges!(caller) - if edges !== nothing - push!(edges, li) - end + push!(caller.backedges, li) return nothing end function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance) - edges = get_stmt_edges!(caller) - if edges !== nothing - push!(edges, invokesig, li) - end + push!(caller.backedges, InvokeEdge(invokesig, li)) return nothing end # used to temporarily accumulate our no method errors to later add as backedges in the callee method table function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ)) - edges = get_stmt_edges!(caller) - if edges !== nothing - push!(edges, mt, typ) - end - return nothing -end - -function get_stmt_edges!(caller::InferenceState) - if !isa(caller.linfo.def, Method) - return nothing # don't add backedges to toplevel exprs - end - edges = caller.stmt_edges[caller.currpc] - if edges === nothing - edges = caller.stmt_edges[caller.currpc] = [] - end - return edges -end - -function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc) - edges = frame.stmt_edges[currpc] - edges === nothing || empty!(edges) + push!(caller.backedges, MTEdge(typ, mt)) return nothing end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 947096e1c1338..6d4f83f4a7241 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -54,12 +54,12 @@ end ##################### struct EdgeTracker - edges::Vector{Any} + edges::IdSet{Any} valid_worlds::RefValue{WorldRange} - EdgeTracker(edges::Vector{Any}, range::WorldRange) = + EdgeTracker(edges::IdSet{Any}, range::WorldRange) = new(edges, RefValue{WorldRange}(range)) end -EdgeTracker() = EdgeTracker(Any[], 0:typemax(UInt)) +EdgeTracker() = EdgeTracker(IdSet{Any}(), 0:typemax(UInt)) intersect!(et::EdgeTracker, range::WorldRange) = et.valid_worlds[] = intersect(et.valid_worlds[], range) @@ -69,7 +69,7 @@ function add_backedge!(et::EdgeTracker, mi::MethodInstance) return nothing end function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance) - push!(et.edges, invokesig, mi) + push!(et.edges, InvokeEdge(invokesig, mi)) return nothing end @@ -121,9 +121,8 @@ mutable struct OptimizationState cfg::Union{Nothing,CFG} function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter, recompute_cfg::Bool=true) - s_edges = frame.stmt_edges[1]::Vector{Any} inlining = InliningState(params, - EdgeTracker(s_edges, frame.valid_worlds), + EdgeTracker(frame.backedges, frame.valid_worlds), WorldView(code_cache(interp), frame.world), interp) cfg = recompute_cfg ? nothing : frame.cfg diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 6db3c42a6ca54..61c6952e3d103 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -247,9 +247,9 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState) caller.inferred = true end # collect results for the new expanded frame - results = Tuple{InferenceResult, Vector{Any}, Bool}[ + results = Tuple{InferenceResult, IdSet{Any}, Bool}[ ( frames[i].result, - frames[i].stmt_edges[1]::Vector{Any}, + frames[i].backedges, frames[i].cached ) for i in 1:length(frames) ] empty!(frames) @@ -494,22 +494,10 @@ function adjust_effects(sv::InferenceState) return ipo_effects end -# inference completed on `me` -# update the MethodInstance +# inference completed on `me`, now prepare to run optimization passes on fulltree function finish(me::InferenceState, interp::AbstractInterpreter) - # prepare to run optimization passes on fulltree - s_edges = me.stmt_edges[1] - if s_edges === nothing - s_edges = me.stmt_edges[1] = [] - end - for edges in me.stmt_edges - edges === nothing && continue - edges === s_edges && continue - append!(s_edges, edges) - empty!(edges) - end if me.src.edges !== nothing - append!(s_edges, me.src.edges::Vector) + union!(me.backedges, me.src.edges::Vector) me.src.edges = nothing end # inspect whether our inference had a limited result accuracy, @@ -559,7 +547,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter) end # record the backedges -function store_backedges(frame::InferenceResult, edges::Vector{Any}) +function store_backedges(frame::InferenceResult, edges::IdSet{Any}) toplevel = !isa(frame.linfo.def, Method) if !toplevel store_backedges(frame.linfo, edges) @@ -567,7 +555,7 @@ function store_backedges(frame::InferenceResult, edges::Vector{Any}) nothing end -function store_backedges(frame::MethodInstance, edges::Vector{Any}) +function store_backedges(frame::MethodInstance, edges::IdSet{Any}) for (; sig, caller) in BackedgeIterator(edges) if isa(caller, MethodInstance) ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame) diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 88e002a469575..626fead93f04a 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -229,6 +229,18 @@ is_no_constprop(method::Union{Method,CodeInfo}) = method.constprop == 0x02 # backedges # ############# +struct MTEdge + sig # ::Type + mt::Core.MethodTable + MTEdge(@nospecialize(sig), mt::Core.MethodTable) = new(sig, mt) +end + +struct InvokeEdge + invokesig # ::Type + mi::MethodInstance + InvokeEdge(@nospecialize(invokesig), mi::MethodInstance) = new(invokesig, mi) +end + """ BackedgeIterator(backedges::Vector{Any}) @@ -265,6 +277,7 @@ callyou(Float64) from callyou(Any) struct BackedgeIterator backedges::Vector{Any} end +BackedgeIterator(backedges::IdSet{Any}) = BackedgeIterator(collect(backedges)) const empty_backedge_iter = BackedgeIterator(Any[]) @@ -278,9 +291,9 @@ function iterate(iter::BackedgeIterator, i::Int=1) backedges = iter.backedges i > length(backedges) && return nothing item = backedges[i] - isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch - isa(item, Core.MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch - return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls + isa(item, MTEdge) && return BackedgePair(item.sig, item.mt), i+1 # abstract dispatch + isa(item, InvokeEdge) && return BackedgePair(item.invokesig, item.mi), i+1 # `invoke` calls + return BackedgePair(nothing, item::MethodInstance), i+1 # regular dispatch end #########