Skip to content

Commit

Permalink
compiler: simplify backedge calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Sep 15, 2022
1 parent 98fe82a commit 1992f18
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 60 deletions.
1 change: 0 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2497,7 +2497,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

for currpc in bbstart:bbend
frame.currpc = currpc
empty_backedges!(frame, currpc)
stmt = frame.src.code[currpc]
# If we're at the end of the basic block ...
if currpc == bbend
Expand Down
38 changes: 6 additions & 32 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ mutable struct InferenceState
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Union{Nothing,Vector{Any}}}
backedges::IdSet{Any}
stmt_info::Vector{Any}

#= intermediate states for interprocedural abstract interpretation =#
Expand Down Expand Up @@ -151,7 +151,7 @@ mutable struct InferenceState
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
backedges = IdSet{Any}()
stmt_info = Any[ nothing for i = 1:nstmts ]

nslots = length(src.slotflags)
Expand Down Expand Up @@ -195,7 +195,7 @@ mutable struct InferenceState

frame = new(
linfo, world, mod, sptypes, slottypes, src, cfg,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, backedges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred,
result, valid_worlds, bestguess, ipo_effects,
params, restrict_abstract_call_sites, cached,
Expand Down Expand Up @@ -488,44 +488,18 @@ end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
end
push!(caller.backedges, li)
return nothing
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, li)
end
push!(caller.backedges, InvokeEdge(invokesig, li))
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
end
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
edges = frame.stmt_edges[currpc]
edges === nothing || empty!(edges)
push!(caller.backedges, MTEdge(typ, mt))
return nothing
end

Expand Down
11 changes: 5 additions & 6 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ end
#####################

struct EdgeTracker
edges::Vector{Any}
edges::IdSet{Any}
valid_worlds::RefValue{WorldRange}
EdgeTracker(edges::Vector{Any}, range::WorldRange) =
EdgeTracker(edges::IdSet{Any}, range::WorldRange) =
new(edges, RefValue{WorldRange}(range))
end
EdgeTracker() = EdgeTracker(Any[], 0:typemax(UInt))
EdgeTracker() = EdgeTracker(IdSet{Any}(), 0:typemax(UInt))

intersect!(et::EdgeTracker, range::WorldRange) =
et.valid_worlds[] = intersect(et.valid_worlds[], range)
Expand All @@ -69,7 +69,7 @@ function add_backedge!(et::EdgeTracker, mi::MethodInstance)
return nothing
end
function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
push!(et.edges, invokesig, mi)
push!(et.edges, InvokeEdge(invokesig, mi))
return nothing
end

Expand Down Expand Up @@ -121,9 +121,8 @@ mutable struct OptimizationState
cfg::Union{Nothing,CFG}
function OptimizationState(frame::InferenceState, params::OptimizationParams,
interp::AbstractInterpreter, recompute_cfg::Bool=true)
s_edges = frame.stmt_edges[1]::Vector{Any}
inlining = InliningState(params,
EdgeTracker(s_edges, frame.valid_worlds),
EdgeTracker(frame.backedges, frame.valid_worlds),
WorldView(code_cache(interp), frame.world),
interp)
cfg = recompute_cfg ? nothing : frame.cfg
Expand Down
24 changes: 6 additions & 18 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
caller.inferred = true
end
# collect results for the new expanded frame
results = Tuple{InferenceResult, Vector{Any}, Bool}[
results = Tuple{InferenceResult, IdSet{Any}, Bool}[
( frames[i].result,
frames[i].stmt_edges[1]::Vector{Any},
frames[i].backedges,
frames[i].cached )
for i in 1:length(frames) ]
empty!(frames)
Expand Down Expand Up @@ -494,22 +494,10 @@ function adjust_effects(sv::InferenceState)
return ipo_effects
end

# inference completed on `me`
# update the MethodInstance
# inference completed on `me`, now prepare to run optimization passes on fulltree
function finish(me::InferenceState, interp::AbstractInterpreter)
# prepare to run optimization passes on fulltree
s_edges = me.stmt_edges[1]
if s_edges === nothing
s_edges = me.stmt_edges[1] = []
end
for edges in me.stmt_edges
edges === nothing && continue
edges === s_edges && continue
append!(s_edges, edges)
empty!(edges)
end
if me.src.edges !== nothing
append!(s_edges, me.src.edges::Vector)
union!(me.backedges, me.src.edges::Vector)
me.src.edges = nothing
end
# inspect whether our inference had a limited result accuracy,
Expand Down Expand Up @@ -559,15 +547,15 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
end

# record the backedges
function store_backedges(frame::InferenceResult, edges::Vector{Any})
function store_backedges(frame::InferenceResult, edges::IdSet{Any})
toplevel = !isa(frame.linfo.def, Method)
if !toplevel
store_backedges(frame.linfo, edges)
end
nothing
end

function store_backedges(frame::MethodInstance, edges::Vector{Any})
function store_backedges(frame::MethodInstance, edges::IdSet{Any})
for (; sig, caller) in BackedgeIterator(edges)
if isa(caller, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame)
Expand Down
19 changes: 16 additions & 3 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,18 @@ is_no_constprop(method::Union{Method,CodeInfo}) = method.constprop == 0x02
# backedges #
#############

struct MTEdge
sig # ::Type
mt::Core.MethodTable
MTEdge(@nospecialize(sig), mt::Core.MethodTable) = new(sig, mt)
end

struct InvokeEdge
invokesig # ::Type
mi::MethodInstance
InvokeEdge(@nospecialize(invokesig), mi::MethodInstance) = new(invokesig, mi)
end

"""
BackedgeIterator(backedges::Vector{Any})
Expand Down Expand Up @@ -265,6 +277,7 @@ callyou(Float64) from callyou(Any)
struct BackedgeIterator
backedges::Vector{Any}
end
BackedgeIterator(backedges::IdSet{Any}) = BackedgeIterator(collect(backedges))

const empty_backedge_iter = BackedgeIterator(Any[])

Expand All @@ -278,9 +291,9 @@ function iterate(iter::BackedgeIterator, i::Int=1)
backedges = iter.backedges
i > length(backedges) && return nothing
item = backedges[i]
isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch
isa(item, Core.MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch
return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls
isa(item, MTEdge) && return BackedgePair(item.sig, item.mt), i+1 # abstract dispatch
isa(item, InvokeEdge) && return BackedgePair(item.invokesig, item.mi), i+1 # `invoke` calls
return BackedgePair(nothing, item::MethodInstance), i+1 # regular dispatch
end

#########
Expand Down

0 comments on commit 1992f18

Please sign in to comment.