From 3cff21e725097673f969c19f8f0992c9a0838ab3 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Fri, 22 Apr 2022 10:46:49 +0900 Subject: [PATCH] NFC: tidy up `InferenceState` definition (#45049) This commit sorts its fields in order of their purpose, and refactor the main constructor accordingly. Also removed the `nargs::Int` field as it can be easily recovered and not used so often. Now it's defined as: ```julia mutable struct InferenceState #= information about this method instance =# linfo::MethodInstance world::UInt mod::Module sptypes::Vector{Any} slottypes::Vector{Any} src::CodeInfo #= intermediate states for local abstract interpretation =# currpc::Int ip::BitSetBoundedMinPrioritySet handler_at::Vector{Int} ssavalue_uses::Vector{BitSet} stmt_types::Vector{Union{Nothing, VarTable}} stmt_edges::Vector{Union{Nothing, Vector{Any}}} stmt_info::Vector{Any} #= interprocedural intermediate states for abstract interpretation =# pclimitations::IdSet{InferenceState} limitations::IdSet{InferenceState} cycle_backedges::Vector{Tuple{InferenceState, Int}} callers_in_cycle::Vector{InferenceState} dont_work_on_me::Bool parent::Union{Nothing, InferenceState} inferred::Bool #= results =# result::InferenceResult valid_worlds::WorldRange bestguess ipo_effects::Effects #= flags =# params::InferenceParams restrict_abstract_call_sites::Bool cached::Bool interp::AbstractInterpreter ... end ``` --- base/compiler/abstractinterpretation.jl | 11 +- base/compiler/inferencestate.jl | 148 ++++++++++++------------ base/compiler/typeinfer.jl | 3 +- 3 files changed, 77 insertions(+), 85 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index fcfe6d4797db9..e09585f419976 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2072,7 +2072,7 @@ function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) return typ end -function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, slottypes::Vector{Any}, changes::VarTable) +function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nargs::Int, slottypes::Vector{Any}, changes::VarTable) if !(bestguess ⊑ Bool) || bestguess === Bool # give up inter-procedural constraint back-propagation # when tmerge would widen the result anyways (as an optimization) @@ -2080,7 +2080,7 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s else if isa(rt, Conditional) id = slot_id(rt.var) - if 1 ≤ id ≤ nslots + if 1 ≤ id ≤ nargs old_id_type = widenconditional(slottypes[id]) # same as `(states[1]::VarTable)[id].typ` if (!(rt.vtype ⊑ old_id_type) || old_id_type ⊑ rt.vtype) && (!(rt.elsetype ⊑ old_id_type) || old_id_type ⊑ rt.elsetype) @@ -2108,7 +2108,7 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s # pick up the first "interesting" slot, convert `rt` to its `Conditional` # TODO: ideally we want `Conditional` and `InterConditional` to convey # constraints on multiple slots - for slot_id in 1:nslots + for slot_id in 1:nargs rt = bool_rt_to_conditional(rt, slottypes, changes, slot_id) rt isa InterConditional && break end @@ -2167,10 +2167,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) frame.dont_work_on_me = true # mark that this function is currently on the stack W = frame.ip states = frame.stmt_types - nargs = frame.nargs def = frame.linfo.def isva = isa(def, Method) && def.isva - nslots = nargs - isva + nargs = length(frame.result.argtypes) - isva slottypes = frame.slottypes ssavaluetypes = frame.src.ssavaluetypes::Vector{Any} while !isempty(W) @@ -2238,7 +2237,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) elseif isa(stmt, ReturnNode) bestguess = frame.bestguess rt = abstract_eval_value(interp, stmt.val, changes, frame) - rt = widenreturn(rt, bestguess, nslots, slottypes, changes) + rt = widenreturn(rt, bestguess, nargs, slottypes, changes) # narrow representation of bestguess slightly to prepare for tmerge with rt if rt isa InterConditional && bestguess isa Const let slot_id = rt.slot diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 0790b18bf83bd..24423deef8623 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -1,7 +1,5 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license -const LineNum = Int - # The type of a variable load is either a value or an UndefVarError # (only used in abstractinterpret, doesn't appear in optimize) struct VarState @@ -83,98 +81,91 @@ function in(idx::Int, bsbmp::BitSetBoundedMinPrioritySet) end mutable struct InferenceState - params::InferenceParams - result::InferenceResult # remember where to put the result + #= information about this method instance =# linfo::MethodInstance - sptypes::Vector{Any} # types of static parameter - slottypes::Vector{Any} + world::UInt mod::Module - currpc::LineNum - pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue - limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return - - # info on the state of inference and the linfo + sptypes::Vector{Any} + slottypes::Vector{Any} src::CodeInfo - world::UInt - valid_worlds::WorldRange - nargs::Int + + #= intermediate states for local abstract interpretation =# + currpc::Int + ip::BitSetBoundedMinPrioritySet # current active instruction pointers + handler_at::Vector{Int} # current exception handler info + ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info stmt_types::Vector{Union{Nothing, VarTable}} stmt_edges::Vector{Union{Nothing, Vector{Any}}} stmt_info::Vector{Any} - # return type - bestguess #::Type - # current active instruction pointers - ip::BitSetBoundedMinPrioritySet - # current exception handler info - handler_at::Vector{LineNum} - # ssavalue sparsity and restart info - ssavalue_uses::Vector{BitSet} - - cycle_backedges::Vector{Tuple{InferenceState, LineNum}} # call-graph backedges connecting from callee to caller + + #= interprocedural intermediate states for abstract interpretation =# + pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue + limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return + cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller callers_in_cycle::Vector{InferenceState} + dont_work_on_me::Bool parent::Union{Nothing, InferenceState} + inferred::Bool # TODO move this to InferenceResult? - # TODO: move these to InferenceResult / Params? - cached::Bool - inferred::Bool - dont_work_on_me::Bool + #= results =# + result::InferenceResult # remember where to put the result + valid_worlds::WorldRange + bestguess #::Type + ipo_effects::Effects + #= flags =# + params::InferenceParams # Whether to restrict inference of abstract call sites to avoid excessive work # Set by default for toplevel frame. restrict_abstract_call_sites::Bool - - # Inferred purity flags - ipo_effects::Effects + cached::Bool # TODO move this to InferenceResult? # The interpreter that created this inference state. Not looked at by # NativeInterpreter. But other interpreters may use this to detect cycles interp::AbstractInterpreter # src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results - function InferenceState(result::InferenceResult, src::CodeInfo, - cache::Symbol, interp::AbstractInterpreter) - (; def) = linfo = result.linfo - code = src.code::Vector{Any} - - params = InferenceParams(interp) - - sp = sptypes_from_meth_instance(linfo::MethodInstance) - - nssavalues = src.ssavaluetypes::Int - src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ] - stmt_info = Any[ nothing for i = 1:length(code) ] + function InferenceState(result::InferenceResult, + src::CodeInfo, cache::Symbol, interp::AbstractInterpreter) + linfo = result.linfo + world = get_world_counter(interp) + def = linfo.def + mod = isa(def, Method) ? def.module : def + sptypes = sptypes_from_meth_instance(linfo) + code = src.code::Vector{Any} nstmts = length(code) - s_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ] - s_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ] + currpc = 1 + ip = BitSetBoundedMinPrioritySet(nstmts) + handler_at = compute_trycatch(code, ip.elems) + push!(ip, 1) + nssavalues = src.ssavaluetypes::Int + ssavalue_uses = find_ssavalue_uses(code, nssavalues) + stmt_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ] + stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ] + stmt_info = Any[ nothing for i = 1:nstmts ] - # initial types nslots = length(src.slotflags) + slottypes = Vector{Any}(undef, nslots) argtypes = result.argtypes nargs = length(argtypes) - s_argtypes = VarTable(undef, nslots) - slottypes = Vector{Any}(undef, nslots) + stmt_types[1] = stmt_type1 = VarTable(undef, nslots) for i in 1:nslots - at = (i > nargs) ? Bottom : argtypes[i] - s_argtypes[i] = VarState(at, i > nargs) - slottypes[i] = at + argtyp = (i > nargs) ? Bottom : argtypes[i] + stmt_type1[i] = VarState(argtyp, i > nargs) + slottypes[i] = argtyp end - s_types[1] = s_argtypes - ssavalue_uses = find_ssavalue_uses(code, nssavalues) - - # exception handlers - ip = BitSetBoundedMinPrioritySet(nstmts) - handler_at = compute_trycatch(src.code, ip.elems) - push!(ip, 1) - - # `throw` block deoptimization - params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at) - - mod = isa(def, Method) ? def.module : def - valid_worlds = WorldRange(src.min_world, - src.max_world == typemax(UInt) ? get_world_counter() : src.max_world) + pclimitations = IdSet{InferenceState}() + limitations = IdSet{InferenceState}() + cycle_backedges = Vector{Tuple{InferenceState,Int}}() + callers_in_cycle = Vector{InferenceState}() + dont_work_on_me = false + parent = nothing + inferred = false + valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world) + bestguess = Bottom # TODO: Currently, any :inbounds declaration taints consistency, # because we cannot be guaranteed whether or not boundschecks # will be eliminated and if they are, we cannot be guaranteed @@ -184,24 +175,27 @@ mutable struct InferenceState inbounds = inbounds_option() inbounds_taints_consistency = !(inbounds === :on || (inbounds === :default && !any_inbounds(code))) consistent = inbounds_taints_consistency ? TRISTATE_UNKNOWN : ALWAYS_TRUE + ipo_effects = Effects(EFFECTS_TOTAL; consistent, inbounds_taints_consistency) + params = InferenceParams(interp) + restrict_abstract_call_sites = isa(linfo.def, Module) @assert cache === :no || cache === :local || cache === :global + cached = cache === :global + frame = new( - params, result, linfo, - sp, slottypes, mod, #=currpc=#0, - #=pclimitations=#IdSet{InferenceState}(), #=limitations=#IdSet{InferenceState}(), - src, get_world_counter(interp), valid_worlds, - nargs, s_types, s_edges, stmt_info, - #=bestguess=#Union{}, ip, handler_at, ssavalue_uses, - #=cycle_backedges=#Vector{Tuple{InferenceState,LineNum}}(), - #=callers_in_cycle=#Vector{InferenceState}(), - #=parent=#nothing, - #=cached=#cache === :global, - #=inferred=#false, #=dont_work_on_me=#false, #=restrict_abstract_call_sites=# isa(linfo.def, Module), - #=ipo_effects=#Effects(EFFECTS_TOTAL; consistent, inbounds_taints_consistency), + linfo, world, mod, sptypes, slottypes, src, + currpc, ip, handler_at, ssavalue_uses, stmt_types, stmt_edges, stmt_info, + pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred, + result, valid_worlds, bestguess, ipo_effects, + params, restrict_abstract_call_sites, cached, interp) + + # some more setups + src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ] + params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at) result.result = frame cache !== :no && push!(get_inference_cache(interp), result) + return frame end end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 51bc4d7afa50e..c76849d599c46 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -40,7 +40,7 @@ function _typeinf_identifier(frame::Core.Compiler.InferenceState) frame.world, copy(frame.sptypes), copy(frame.slottypes), - frame.nargs, + length(frame.result.argtypes), ) return mi_info end @@ -665,7 +665,6 @@ function type_annotate!(sv::InferenceState, run_optimizer::Bool) # remove dead code optimization # and compute which variables may be used undef states = sv.stmt_types - nargs = sv.nargs nslots = length(states[1]::VarTable) undefs = fill(false, nslots) body = src.code::Array{Any,1}