diff --git a/base/inference.jl b/base/inference.jl index b97162661fe4b..b7658d1ddb6ff 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -114,13 +114,12 @@ end mutable struct InferenceState sp::SimpleVector # static parameters - label_counter::Int # index of the current highest label for this function mod::Module currpc::LineNum # info on the state of inference and the linfo params::InferenceParams - linfo::MethodInstance # used here for the tuple (specTypes, env, Method) + linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity src::CodeInfo min_valid::UInt max_valid::UInt @@ -162,7 +161,6 @@ mutable struct InferenceState function InferenceState(linfo::MethodInstance, src::CodeInfo, optimize::Bool, cached::Bool, params::InferenceParams) code = src.code::Array{Any,1} - nl = label_counter(code) + 1 toplevel = !isa(linfo.def, Method) if !toplevel && isempty(linfo.sparam_vals) && !isempty(linfo.def.sparam_syms) @@ -274,7 +272,7 @@ mutable struct InferenceState max_valid = typemin(UInt) end frame = new( - sp, nl, inmodule, 0, params, + sp, inmodule, 0, params, linfo, src, min_valid, max_valid, nargs, s_types, s_edges, Union{}, W, 1, n, @@ -311,6 +309,71 @@ function get_staged(li::MethodInstance) end +mutable struct OptimizationState + linfo::MethodInstance + vararg_type_container #::Type + backedges::Vector{Any} + src::CodeInfo + mod::Module + nargs::Int + next_label::Int # index of the current highest label for this function + min_valid::UInt + max_valid::UInt + params::InferenceParams + function OptimizationState(frame::InferenceState) + s_edges = frame.stmt_edges[1] + if s_edges === () + s_edges = [] + frame.stmt_edges[1] = s_edges + end + next_label = label_counter(frame.src.code) + 1 + return new(frame.linfo, frame.vararg_type_container, + s_edges::Vector{Any}, + frame.src, frame.mod, frame.nargs, + next_label, frame.min_valid, frame.max_valid, + frame.params) + end + function OptimizationState(linfo::MethodInstance, src::CodeInfo, + params::InferenceParams) + # prepare src for running optimization passes + # if it isn't already + nssavalues = src.ssavaluetypes + if nssavalues isa Int + src.ssavaluetypes = Any[ Any for i = 1:nssavalues ] + end + if src.slottypes === nothing + nslots = length(src.slotnames) + src.slottypes = Any[ Any for i = 1:nslots ] + end + s_edges = [] + # cache some useful state computations + toplevel = !isa(linfo.def, Method) + if !toplevel + meth = linfo.def + inmodule = meth.module + nargs = meth.nargs + else + inmodule = linfo.def::Module + nargs = 0 + end + next_label = label_counter(src.code) + 1 + vararg_type_container = nothing # if you want something more accurate, set it yourself :P + return new(linfo, vararg_type_container, + s_edges::Vector{Any}, + src, inmodule, nargs, + next_label, + min_world(linfo), max_world(linfo), + params) + end +end + +function OptimizationState(linfo::MethodInstance, params::InferenceParams) + src = retrieve_code_info(linfo) + src === nothing && return nothing + return OptimizationState(linfo, src, params) +end + + #### debugging utilities #### function print_callstack(sv::InferenceState) @@ -378,8 +441,9 @@ function contains_is(itr, @nospecialize(x)) return false end -anymap(f::Function, a::Array{Any,1}) = Any[ f(a[i]) for i=1:length(a) ] +anymap(f::Function, a::Array{Any,1}) = Any[ f(a[i]) for i in 1:length(a) ] +_topmod(sv::OptimizationState) = _topmod(sv.mod) _topmod(sv::InferenceState) = _topmod(sv.mod) _topmod(m::Module) = ccall(:jl_base_relative_to, Any, (Any,), m)::Module @@ -409,7 +473,7 @@ function tupleparam_tail(t::SimpleVector, n) return Tuple{t[n:lt]...} end -function is_specializable_vararg_slot(arg, sv::InferenceState) +function is_specializable_vararg_slot(@nospecialize(arg), sv::Union{InferenceState, OptimizationState}) return (isa(arg, Slot) && slot_id(arg) == sv.nargs && isa(sv.vararg_type_container, DataType)) end @@ -2819,19 +2883,31 @@ end #### helper functions for typeinf initialization and looping #### +# scan body for the value of the largest referenced label function label_counter(body::Vector{Any}) - l = -1 + l = 0 for b in body - if isa(b, LabelNode) && b.label > l - l = b.label + label = 0 + if isa(b, GotoNode) + label = b.label::Int + elseif isa(b, LabelNode) + label = b.label + elseif isa(b, Expr) && b.head == :gotoifnot + label = b.args[2]::Int + elseif isa(b, Expr) && b.head == :enter + label = b.args[1]::Int + end + if label > l + l = label end end return l end -genlabel(sv) = LabelNode(sv.label_counter += 1) +genlabel(sv::OptimizationState) = LabelNode(sv.next_label += 1) -function get_label_map(body::Vector{Any}, sv::InferenceState) - labelmap = zeros(Int, sv.label_counter) +function get_label_map(body::Vector{Any}) + nlabels = label_counter(body) + labelmap = zeros(Int, nlabels) for i = 1:length(body) el = body[i] if isa(el, LabelNode) @@ -2880,7 +2956,7 @@ function find_ssavalue_defs(body::Vector{Any}, nvals::Int) return defs end -function newvar!(sv::InferenceState, @nospecialize(typ)) +function newvar!(sv::OptimizationState, @nospecialize(typ)) id = length(sv.src.ssavaluetypes) push!(sv.src.ssavaluetypes, typ) return SSAValue(id) @@ -2893,11 +2969,24 @@ coverage_enabled() = (JLOptions().code_coverage != 0) function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState) sv.min_valid = max(sv.min_valid, min_valid) sv.max_valid = min(sv.max_valid, max_valid) - @assert !isa(sv.linfo.def, Method) || !sv.cached || sv.min_valid <= sv.params.world <= sv.max_valid "invalid age range update" + @assert(!isa(sv.linfo.def, Method) || + !sv.cached || + sv.min_valid <= sv.params.world <= sv.max_valid, + "invalid age range update") + nothing +end +function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState) + sv.min_valid = max(sv.min_valid, min_valid) + sv.max_valid = min(sv.max_valid, max_valid) + @assert(!isa(sv.linfo.def, Method) || + (sv.min_valid == typemax(UInt) && sv.max_valid == typemin(UInt)) || + sv.min_valid <= sv.params.world <= sv.max_valid, + "invalid age range update") nothing end update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(edge.min_valid, edge.max_valid, sv) update_valid_age!(li::MethodInstance, sv::InferenceState) = update_valid_age!(min_world(li), max_world(li), sv) +update_valid_age!(li::MethodInstance, sv::OptimizationState) = update_valid_age!(min_world(li), max_world(li), sv) # temporarily accumulate our edges to later add as backedges in the callee function add_backedge!(li::MethodInstance, caller::InferenceState) @@ -2910,6 +2999,13 @@ function add_backedge!(li::MethodInstance, caller::InferenceState) nothing end +function add_backedge!(li::MethodInstance, caller::OptimizationState) + isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs + push!(caller.backedges, li) + update_valid_age!(li, caller) + nothing +end + # temporarily accumulate our no method errors to later add as backedges in the callee method table function add_mt_backedge(mt::MethodTable, @nospecialize(typ), caller::InferenceState) isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs @@ -3334,7 +3430,6 @@ function typeinf_work(frame::InferenceState) end function typeinf(frame::InferenceState) - typeinf_work(frame) # If the current frame is part of a cycle, solve the cycle before finishing @@ -3449,30 +3544,33 @@ function optimize(me::InferenceState) # run optimization passes on fulltree force_noinline = true if me.optimize + opt = OptimizationState(me) # This pass is required for the AST to be valid in codegen # if any `SSAValue` is created by type inference. Ref issue #6068 # This (and `reindex_labels!`) needs to be run for `!me.optimize` # if we start to create `SSAValue` in type inference when not # optimizing and use unoptimized IR in codegen. - gotoifnot_elim_pass!(me) - inlining_pass!(me, me.src.propagate_inbounds) + gotoifnot_elim_pass!(opt) + inlining_pass!(opt, opt.src.propagate_inbounds) # Clean up after inlining - gotoifnot_elim_pass!(me) - basic_dce_pass!(me) - void_use_elim_pass!(me) + gotoifnot_elim_pass!(opt) + basic_dce_pass!(opt) + void_use_elim_pass!(opt) # Compute escape information # and elide unnecessary allocations - alloc_elim_pass!(me) - getfield_elim_pass!(me) + alloc_elim_pass!(opt) + getfield_elim_pass!(opt) # Clean up for `alloc_elim_pass!` and `getfield_elim_pass!` - void_use_elim_pass!(me) + void_use_elim_pass!(opt) # Pop metadata before label reindexing - let code = me.src.code::Array{Any,1} + let code = opt.src.code::Array{Any,1} meta_elim_pass!(code, coverage_enabled()) filter!(x -> x !== nothing, code) force_noinline = popmeta!(code, :noinline)[1] end - reindex_labels!(me) + reindex_labels!(opt) + me.min_valid = opt.min_valid + me.max_valid = opt.max_valid elseif me.cached && me.parent !== nothing # top parent will be cached still, but not this intermediate work me.cached = false @@ -3482,8 +3580,9 @@ function optimize(me::InferenceState) # convert all type information into the form consumed by the code-generator widen_all_consts!(me.src) - if isa(me.bestguess, Const) || isconstType(me.bestguess) - me.const_ret = true + # compute inlining and other related properties + me.const_ret = (isa(me.bestguess, Const) || isconstType(me.bestguess)) + if me.const_ret proven_pure = false # must be proven pure to use const_api; otherwise we might skip throwing errors # (issue #20704) @@ -3556,7 +3655,6 @@ end # inference completed on `me` # update the MethodInstance and notify the edges function finish(me::InferenceState) - me.currpc = 1 # used by add_backedge if me.cached toplevel = !isa(me.linfo.def, Method) if !toplevel @@ -3620,6 +3718,7 @@ function finish(me::InferenceState) me.linfo = cache end end + me.linfo.inInference = false end # update all of the callers with real backedges by traversing the temporary list of backedges @@ -3628,7 +3727,6 @@ function finish(me::InferenceState) end # finalize and record the linfo result - me.cached && (me.linfo.inInference = false) me.inferred = true nothing end @@ -3734,7 +3832,7 @@ function type_annotate!(sv::InferenceState) annotate_slot_load!(expr, st_i, sv, undefs) elseif isa(expr, Slot) id = slot_id(expr) - if st_i[slot_id(expr)].undef + if st_i[id].undef # find used-undef variables in statement position undefs[id] = true end @@ -3767,7 +3865,7 @@ function type_annotate!(sv::InferenceState) # must mean that the target is unreachable. Later optimization passes will # assume that all branches lead to labels that exist, so we must replace # the node with the branch condition (which may have side effects). - labelmap = get_label_map(body, sv) + labelmap = get_label_map(body) for i in 1:length(body) expr = body[i] if isa(expr, Expr) && expr.head === :gotoifnot @@ -4096,7 +4194,7 @@ struct InvokeData texpr end -function inline_as_constant(@nospecialize(val), argexprs, sv::InferenceState, @nospecialize(invoke_data)) +function inline_as_constant(@nospecialize(val), argexprs::Vector{Any}, sv::OptimizationState, @nospecialize(invoke_data)) if invoke_data === nothing invoke_fexpr = nothing invoke_texpr = nothing @@ -4136,7 +4234,7 @@ function countunionsplit(atypes) return nu end -function get_spec_lambda(@nospecialize(atypes), sv, @nospecialize(invoke_data)) +function get_spec_lambda(@nospecialize(atypes), sv::OptimizationState, @nospecialize(invoke_data)) if invoke_data === nothing return ccall(:jl_get_spec_lambda, Any, (Any, UInt), atypes, sv.params.world) else @@ -4147,7 +4245,7 @@ function get_spec_lambda(@nospecialize(atypes), sv, @nospecialize(invoke_data)) end end -function linearize_args!(args::Vector{Any}, atypes::Vector{Any}, stmts::Vector{Any}, sv::InferenceState) +function linearize_args!(args::Vector{Any}, atypes::Vector{Any}, stmts::Vector{Any}, sv::OptimizationState) # linearize the IR by moving the arguments to SSA position na = length(args) @assert length(atypes) == na @@ -4166,7 +4264,7 @@ function linearize_args!(args::Vector{Any}, atypes::Vector{Any}, stmts::Vector{A return newargs end -function invoke_NF(argexprs, @nospecialize(etype), atypes::Vector{Any}, sv::InferenceState, +function invoke_NF(argexprs, @nospecialize(etype), atypes::Vector{Any}, sv::OptimizationState, @nospecialize(atype_unlimited), @nospecialize(invoke_data)) # converts a :call to :invoke nu = countunionsplit(atypes) @@ -4298,7 +4396,7 @@ end # we can estimate the total size of the enclosing function after inlining. function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector{Any}, pending_stmt::Vector{Any}, boundscheck::Symbol, - sv::InferenceState) + sv::OptimizationState) argexprs = e.args if (f === typeassert || ft ⊑ typeof(typeassert)) && length(atypes)==3 @@ -4542,15 +4640,17 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector end # compute the return value + if isa(frame, InferenceState) && !frame.src.inferred + frame = nothing + end if isa(frame, InferenceState) - frame = frame::InferenceState linfo = frame.linfo inferred = frame.src if frame.const_api # handle like jlcall_api == 2 if frame.inferred || !frame.cached - add_backedge!(frame.linfo, sv) + add_backedge!(linfo, sv) else - add_backedge!(frame, sv, 0) + add_backedge!(frame, sv) end if isa(frame.bestguess, Const) inferred_const = (frame.bestguess::Const).val @@ -4590,7 +4690,7 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector if isa(frame, InferenceState) && !frame.inferred && frame.cached # in this case, the actual backedge linfo hasn't been computed # yet, but will be when inference on the frame finishes - add_backedge!(frame, sv, 0) + add_backedge!(frame, sv) else add_backedge!(linfo, sv) end @@ -4666,27 +4766,24 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector # make labels / goto statements unique # relocate inlining information - newlabels = zeros(Int,label_counter(body.args)+1) + newlabels = zeros(Int, label_counter(body.args)) for i = 1:length(body.args) a = body.args[i] - if isa(a,LabelNode) - a = a::LabelNode + if isa(a, LabelNode) newlabel = genlabel(sv) - newlabels[a.label+1] = newlabel.label + newlabels[a.label] = newlabel.label body.args[i] = newlabel end end for i = 1:length(body.args) a = body.args[i] - if isa(a,GotoNode) - a = a::GotoNode - body.args[i] = GotoNode(newlabels[a.label+1]) - elseif isa(a,Expr) - a = a::Expr + if isa(a, GotoNode) + body.args[i] = GotoNode(newlabels[a.label]) + elseif isa(a, Expr) if a.head === :enter - a.args[1] = newlabels[a.args[1]+1] + a.args[1] = newlabels[a.args[1]::Int] elseif a.head === :gotoifnot - a.args[2] = newlabels[a.args[2]+1] + a.args[2] = newlabels[a.args[2]::Int] end end end @@ -4935,13 +5032,13 @@ function mk_getfield(texpr, i, T) return e end -function mk_tuplecall(args, sv::InferenceState) +function mk_tuplecall(args, sv::OptimizationState) e = Expr(:call, top_tuple, args...) e.typ = tuple_tfunc(Tuple{Any[widenconst(exprtype(x, sv.src, sv.mod)) for x in args]...}) return e end -function inlining_pass!(sv::InferenceState, propagate_inbounds::Bool) +function inlining_pass!(sv::OptimizationState, propagate_inbounds::Bool) # Also handles bounds check elision: # # 1. If check_bounds is always on, set `Expr(:boundscheck)` true @@ -4999,7 +5096,7 @@ const corenumtype = Union{Int32, Int64, Float32, Float64} # return inlined replacement for `e`, inserting new needed statements # at index `ins` in `stmts`. -function inlining_pass(e::Expr, sv::InferenceState, stmts::Vector{Any}, ins, boundscheck::Symbol) +function inlining_pass(e::Expr, sv::OptimizationState, stmts::Vector{Any}, ins, boundscheck::Symbol) if e.head === :meta # ignore meta nodes return e @@ -5479,7 +5576,7 @@ symequal(x::Slot , y::Slot) = x.id === y.id symequal(@nospecialize(x) , @nospecialize(y)) = x === y function occurs_outside_getfield(@nospecialize(e), @nospecialize(sym), - sv::InferenceState, field_count::Int, @nospecialize(field_names)) + sv::OptimizationState, field_count::Int, @nospecialize(field_names)) if e === sym || (isa(e, Slot) && isa(sym, Slot) && slot_id(e) == slot_id(sym)) return true end @@ -5541,7 +5638,7 @@ function occurs_outside_getfield(@nospecialize(e), @nospecialize(sym), return false end -function void_use_elim_pass!(sv::InferenceState) +function void_use_elim_pass!(sv::OptimizationState) # Remove top level SSAValue and slots that is `!usedUndef`. # Also remove some `nothing` while we are at it.... not_void_use = function (@nospecialize(ex),) @@ -5686,20 +5783,20 @@ end # does the same job as alloc_elim_pass for allocations inline in getfields # TODO can probably be removed when we switch to a linear IR -function getfield_elim_pass!(sv::InferenceState) +function getfield_elim_pass!(sv::OptimizationState) body = sv.src.code nssavalues = length(sv.src.ssavaluetypes) - sv.ssavalue_defs = find_ssavalue_defs(body, nssavalues) - sv.ssavalue_uses = find_ssavalue_uses(body, nssavalues) + ssa_defs = find_ssavalue_defs(body, nssavalues) + ssa_uses = find_ssavalue_uses(body, nssavalues) for i = 1:length(body) - body[i] = _getfield_elim_pass!(body[i], sv) + body[i] = _getfield_elim_pass!(body[i], ssa_defs, ssa_uses, sv) end end -function _getfield_elim_pass!(e::Expr, sv::InferenceState) +function _getfield_elim_pass!(e::Expr, ssa_defs::Vector{LineNum}, ssa_uses::Vector{IntSet}, sv::OptimizationState) nargs = length(e.args) for i = 1:nargs - e.args[i] = _getfield_elim_pass!(e.args[i], sv) + e.args[i] = _getfield_elim_pass!(e.args[i], ssa_defs, ssa_uses, sv) end if is_known_call(e, getfield, sv.src, sv.mod) && (nargs == 3 || nargs == 4) && @@ -5710,11 +5807,11 @@ function _getfield_elim_pass!(e::Expr, sv::InferenceState) single_use = true while isa(e1, SSAValue) if single_use - if length(sv.ssavalue_uses[e1.id + 1]) > 1 + if length(ssa_uses[e1.id + 1]) > 1 single_use = false end end - def = sv.ssavalue_defs[e1.id + 1] + def = ssa_defs[e1.id + 1] stmt = sv.src.code[def]::Expr e1 = stmt.args[2] end @@ -5765,12 +5862,12 @@ function _getfield_elim_pass!(e::Expr, sv::InferenceState) return e end -_getfield_elim_pass!(@nospecialize(e), sv) = e +_getfield_elim_pass!(@nospecialize(e), ssa_defs::Vector{LineNum}, ssa_uses::Vector{IntSet}, sv::OptimizationState) = e # check if e is a successful allocation of an struct # if it is, returns (n,f) such that it is always valid to call # getfield(..., 1 <= x <= n) or getfield(..., x in f) on the result -function is_allocation(@nospecialize(e), sv::InferenceState) +function is_allocation(@nospecialize(e), sv::OptimizationState) isa(e, Expr) || return false if is_known_call(e, tuple, sv.src, sv.mod) return (length(e.args)-1,()) @@ -5793,7 +5890,7 @@ function is_allocation(@nospecialize(e), sv::InferenceState) end # Replace branches with constant conditions with unconditional branches -function gotoifnot_elim_pass!(sv::InferenceState) +function gotoifnot_elim_pass!(sv::OptimizationState) body = sv.src.code i = 1 while i < length(body) @@ -5820,9 +5917,9 @@ function gotoifnot_elim_pass!(sv::InferenceState) end # basic dead-code-elimination of unreachable statements -function basic_dce_pass!(sv::InferenceState) +function basic_dce_pass!(sv::OptimizationState) body = sv.src.code - labelmap = get_label_map(body, sv) + labelmap = get_label_map(body) reachable = IntSet() W = IntSet() push!(W, 1) @@ -5831,11 +5928,10 @@ function basic_dce_pass!(sv::InferenceState) pc in reachable && continue push!(reachable, pc) expr = body[pc] - pc += 1 + pc´ = pc + 1 # next program-counter (after executing instruction) if isa(expr, GotoNode) - pc = labelmap[expr.label] + pc´ = labelmap[expr.label] elseif isa(expr, Expr) - label = 0 if expr.head === :gotoifnot push!(W, labelmap[expr.args[2]::Int]) elseif expr.head === :enter @@ -5844,7 +5940,7 @@ function basic_dce_pass!(sv::InferenceState) continue end end - pc <= length(body) && push!(W, pc) + pc´ <= length(body) && push!(W, pc´) end for i in 1:length(body) expr = body[i] @@ -5859,7 +5955,7 @@ end # eliminate allocation of unnecessary objects # that are only used as arguments to safe getfield calls -function alloc_elim_pass!(sv::InferenceState) +function alloc_elim_pass!(sv::OptimizationState) body = sv.src.code bexpr = Expr(:block) bexpr.args = body @@ -6005,7 +6101,7 @@ function delete_void_use!(body, var::Slot, i0) return ndel end -function replace_getfield!(e::Expr, tupname, vals, field_names, sv::InferenceState) +function replace_getfield!(e::Expr, tupname, vals, field_names, sv::OptimizationState) for i = 1:length(e.args) a = e.args[i] if !isa(a, Expr) @@ -6063,9 +6159,9 @@ function replace_getfield!(e::Expr, tupname, vals, field_names, sv::InferenceSta end # fix label numbers to always equal the statement index of the label -function reindex_labels!(sv::InferenceState) +function reindex_labels!(sv::OptimizationState) body = sv.src.code - mapping = get_label_map(body, sv) + mapping = get_label_map(body) for i = 1:length(body) el = body[i] # For goto and enter, the statement and the target has to be diff --git a/test/inference.jl b/test/inference.jl index cc1f51411450e..4d7e69c2f14c2 100644 --- a/test/inference.jl +++ b/test/inference.jl @@ -1243,3 +1243,18 @@ end _false13183 = false gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0) @test gg13183(5) == 0 + +# test the external OptimizationState constructor +let linfo = get_linfo(Base.convert, Tuple{Type{Int64}, Int32}), + world = typemax(UInt), + opt = Core.Inference.OptimizationState(linfo, Core.Inference.InferenceParams(world)) + # make sure the state of the properties look reasonable + @test opt.src !== linfo.def.source + @test length(opt.src.slotflags) == length(opt.src.slotnames) == length(opt.src.slottypes) + @test opt.src.ssavaluetypes isa Vector{Any} + @test !opt.src.inferred + @test opt.mod === Base + @test opt.max_valid === typemax(UInt) + @test opt.min_valid === Core.Inference.min_world(opt.linfo) > 2 + @test opt.nargs == 3 +end