Skip to content

Commit

Permalink
NFC: tidy up InferenceState definition (JuliaLang#45049)
Browse files Browse the repository at this point in the history
This commit sorts its fields in order of their purpose, and refactor the
main constructor accordingly. Also removed the `nargs::Int` field as it
can be easily recovered and not used so often.

Now it's defined as:
```julia
mutable struct InferenceState
    #= information about this method instance =#
    linfo::MethodInstance
    world::UInt
    mod::Module
    sptypes::Vector{Any}
    slottypes::Vector{Any}
    src::CodeInfo

    #= intermediate states for local abstract interpretation =#
    currpc::Int
    ip::BitSetBoundedMinPrioritySet
    handler_at::Vector{Int}
    ssavalue_uses::Vector{BitSet}
    stmt_types::Vector{Union{Nothing, VarTable}}
    stmt_edges::Vector{Union{Nothing, Vector{Any}}}
    stmt_info::Vector{Any}

    #= interprocedural intermediate states for abstract interpretation 
=#
    pclimitations::IdSet{InferenceState}
    limitations::IdSet{InferenceState}
    cycle_backedges::Vector{Tuple{InferenceState, Int}}
    callers_in_cycle::Vector{InferenceState}
    dont_work_on_me::Bool
    parent::Union{Nothing, InferenceState}
    inferred::Bool

    #= results =#
    result::InferenceResult
    valid_worlds::WorldRange
    bestguess
    ipo_effects::Effects

    #= flags =#
    params::InferenceParams
    restrict_abstract_call_sites::Bool
    cached::Bool

    interp::AbstractInterpreter
    ...
end
```
  • Loading branch information
aviatesk authored Apr 22, 2022
1 parent 5b5715a commit 3cff21e
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 85 deletions.
11 changes: 5 additions & 6 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2072,15 +2072,15 @@ function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
return typ
end

function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, slottypes::Vector{Any}, changes::VarTable)
function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nargs::Int, slottypes::Vector{Any}, changes::VarTable)
if !(bestguess Bool) || bestguess === Bool
# give up inter-procedural constraint back-propagation
# when tmerge would widen the result anyways (as an optimization)
rt = widenconditional(rt)
else
if isa(rt, Conditional)
id = slot_id(rt.var)
if 1 id nslots
if 1 id nargs
old_id_type = widenconditional(slottypes[id]) # same as `(states[1]::VarTable)[id].typ`
if (!(rt.vtype old_id_type) || old_id_type rt.vtype) &&
(!(rt.elsetype old_id_type) || old_id_type rt.elsetype)
Expand Down Expand Up @@ -2108,7 +2108,7 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s
# pick up the first "interesting" slot, convert `rt` to its `Conditional`
# TODO: ideally we want `Conditional` and `InterConditional` to convey
# constraints on multiple slots
for slot_id in 1:nslots
for slot_id in 1:nargs
rt = bool_rt_to_conditional(rt, slottypes, changes, slot_id)
rt isa InterConditional && break
end
Expand Down Expand Up @@ -2167,10 +2167,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.dont_work_on_me = true # mark that this function is currently on the stack
W = frame.ip
states = frame.stmt_types
nargs = frame.nargs
def = frame.linfo.def
isva = isa(def, Method) && def.isva
nslots = nargs - isva
nargs = length(frame.result.argtypes) - isva
slottypes = frame.slottypes
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
while !isempty(W)
Expand Down Expand Up @@ -2238,7 +2237,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif isa(stmt, ReturnNode)
bestguess = frame.bestguess
rt = abstract_eval_value(interp, stmt.val, changes, frame)
rt = widenreturn(rt, bestguess, nslots, slottypes, changes)
rt = widenreturn(rt, bestguess, nargs, slottypes, changes)
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
let slot_id = rt.slot
Expand Down
148 changes: 71 additions & 77 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

const LineNum = Int

# The type of a variable load is either a value or an UndefVarError
# (only used in abstractinterpret, doesn't appear in optimize)
struct VarState
Expand Down Expand Up @@ -83,98 +81,91 @@ function in(idx::Int, bsbmp::BitSetBoundedMinPrioritySet)
end

mutable struct InferenceState
params::InferenceParams
result::InferenceResult # remember where to put the result
#= information about this method instance =#
linfo::MethodInstance
sptypes::Vector{Any} # types of static parameter
slottypes::Vector{Any}
world::UInt
mod::Module
currpc::LineNum
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return

# info on the state of inference and the linfo
sptypes::Vector{Any}
slottypes::Vector{Any}
src::CodeInfo
world::UInt
valid_worlds::WorldRange
nargs::Int

#= intermediate states for local abstract interpretation =#
currpc::Int
ip::BitSetBoundedMinPrioritySet # current active instruction pointers
handler_at::Vector{Int} # current exception handler info
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
stmt_types::Vector{Union{Nothing, VarTable}}
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
stmt_info::Vector{Any}
# return type
bestguess #::Type
# current active instruction pointers
ip::BitSetBoundedMinPrioritySet
# current exception handler info
handler_at::Vector{LineNum}
# ssavalue sparsity and restart info
ssavalue_uses::Vector{BitSet}

cycle_backedges::Vector{Tuple{InferenceState, LineNum}} # call-graph backedges connecting from callee to caller

#= interprocedural intermediate states for abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
callers_in_cycle::Vector{InferenceState}
dont_work_on_me::Bool
parent::Union{Nothing, InferenceState}
inferred::Bool # TODO move this to InferenceResult?

# TODO: move these to InferenceResult / Params?
cached::Bool
inferred::Bool
dont_work_on_me::Bool
#= results =#
result::InferenceResult # remember where to put the result
valid_worlds::WorldRange
bestguess #::Type
ipo_effects::Effects

#= flags =#
params::InferenceParams
# Whether to restrict inference of abstract call sites to avoid excessive work
# Set by default for toplevel frame.
restrict_abstract_call_sites::Bool

# Inferred purity flags
ipo_effects::Effects
cached::Bool # TODO move this to InferenceResult?

# The interpreter that created this inference state. Not looked at by
# NativeInterpreter. But other interpreters may use this to detect cycles
interp::AbstractInterpreter

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult, src::CodeInfo,
cache::Symbol, interp::AbstractInterpreter)
(; def) = linfo = result.linfo
code = src.code::Vector{Any}

params = InferenceParams(interp)

sp = sptypes_from_meth_instance(linfo::MethodInstance)

nssavalues = src.ssavaluetypes::Int
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
stmt_info = Any[ nothing for i = 1:length(code) ]
function InferenceState(result::InferenceResult,
src::CodeInfo, cache::Symbol, interp::AbstractInterpreter)
linfo = result.linfo
world = get_world_counter(interp)
def = linfo.def
mod = isa(def, Method) ? def.module : def
sptypes = sptypes_from_meth_instance(linfo)

code = src.code::Vector{Any}
nstmts = length(code)
s_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ]
s_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
currpc = 1
ip = BitSetBoundedMinPrioritySet(nstmts)
handler_at = compute_trycatch(code, ip.elems)
push!(ip, 1)
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
stmt_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ]
stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
stmt_info = Any[ nothing for i = 1:nstmts ]

# initial types
nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
argtypes = result.argtypes
nargs = length(argtypes)
s_argtypes = VarTable(undef, nslots)
slottypes = Vector{Any}(undef, nslots)
stmt_types[1] = stmt_type1 = VarTable(undef, nslots)
for i in 1:nslots
at = (i > nargs) ? Bottom : argtypes[i]
s_argtypes[i] = VarState(at, i > nargs)
slottypes[i] = at
argtyp = (i > nargs) ? Bottom : argtypes[i]
stmt_type1[i] = VarState(argtyp, i > nargs)
slottypes[i] = argtyp
end
s_types[1] = s_argtypes

ssavalue_uses = find_ssavalue_uses(code, nssavalues)

# exception handlers
ip = BitSetBoundedMinPrioritySet(nstmts)
handler_at = compute_trycatch(src.code, ip.elems)
push!(ip, 1)

# `throw` block deoptimization
params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)

mod = isa(def, Method) ? def.module : def
valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
pclimitations = IdSet{InferenceState}()
limitations = IdSet{InferenceState}()
cycle_backedges = Vector{Tuple{InferenceState,Int}}()
callers_in_cycle = Vector{InferenceState}()
dont_work_on_me = false
parent = nothing
inferred = false

valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
bestguess = Bottom
# TODO: Currently, any :inbounds declaration taints consistency,
# because we cannot be guaranteed whether or not boundschecks
# will be eliminated and if they are, we cannot be guaranteed
Expand All @@ -184,24 +175,27 @@ mutable struct InferenceState
inbounds = inbounds_option()
inbounds_taints_consistency = !(inbounds === :on || (inbounds === :default && !any_inbounds(code)))
consistent = inbounds_taints_consistency ? TRISTATE_UNKNOWN : ALWAYS_TRUE
ipo_effects = Effects(EFFECTS_TOTAL; consistent, inbounds_taints_consistency)

params = InferenceParams(interp)
restrict_abstract_call_sites = isa(linfo.def, Module)
@assert cache === :no || cache === :local || cache === :global
cached = cache === :global

frame = new(
params, result, linfo,
sp, slottypes, mod, #=currpc=#0,
#=pclimitations=#IdSet{InferenceState}(), #=limitations=#IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
#=bestguess=#Union{}, ip, handler_at, ssavalue_uses,
#=cycle_backedges=#Vector{Tuple{InferenceState,LineNum}}(),
#=callers_in_cycle=#Vector{InferenceState}(),
#=parent=#nothing,
#=cached=#cache === :global,
#=inferred=#false, #=dont_work_on_me=#false, #=restrict_abstract_call_sites=# isa(linfo.def, Module),
#=ipo_effects=#Effects(EFFECTS_TOTAL; consistent, inbounds_taints_consistency),
linfo, world, mod, sptypes, slottypes, src,
currpc, ip, handler_at, ssavalue_uses, stmt_types, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred,
result, valid_worlds, bestguess, ipo_effects,
params, restrict_abstract_call_sites, cached,
interp)

# some more setups
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)
result.result = frame
cache !== :no && push!(get_inference_cache(interp), result)

return frame
end
end
Expand Down
3 changes: 1 addition & 2 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function _typeinf_identifier(frame::Core.Compiler.InferenceState)
frame.world,
copy(frame.sptypes),
copy(frame.slottypes),
frame.nargs,
length(frame.result.argtypes),
)
return mi_info
end
Expand Down Expand Up @@ -665,7 +665,6 @@ function type_annotate!(sv::InferenceState, run_optimizer::Bool)
# remove dead code optimization
# and compute which variables may be used undef
states = sv.stmt_types
nargs = sv.nargs
nslots = length(states[1]::VarTable)
undefs = fill(false, nslots)
body = src.code::Array{Any,1}
Expand Down

0 comments on commit 3cff21e

Please sign in to comment.