From 1669d532de7434108f1092f34361166737706ba5 Mon Sep 17 00:00:00 2001
From: Jameson Nash
Date: Fri, 27 Oct 2017 01:09:43 -0400
Subject: [PATCH 1/3] inference: enable constant propagation through function boundaries

Create a separate type for results that holds just the result of running some configuration of inference, and does local caching of the results.
---
 NEWS.md           |   7 +
 base/inference.jl | 563 +++++++++++++++++++++++++++-------------------
 test/inference.jl |   5 +
 3 files changed, 344 insertions(+), 231 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index c280cc864979d..5128b5d7af869 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -418,6 +418,13 @@ Compiler/Runtime improvements
     exceeds the cost of setting up and issuing a subroutine
     call. ([#22210], [#22732])
 
+  * Inference recursion-detection heuristics are now more precise,
+    allowing them to be triggered less often, but being more aggressive when they
+    are triggered to drive the inference computation to a solution ([#23912]).
+
+  * Inference now propagates constants inter-procedurally, and can compute
+    various constant expressions at compile-time ([#24362]).
+
 Deprecated or removed
 ---------------------
 
diff --git a/base/inference.jl b/base/inference.jl
index c229a00b63a0f..69f9db84dad8e 100644
--- a/base/inference.jl
+++ b/base/inference.jl
@@ -9,22 +9,112 @@ const TUPLE_COMPLEXITY_LIMIT_DEPTH = 3
 
 const MAX_INLINE_CONST_SIZE = 256
 
+const empty_vector = Vector{Any}()
+
+mutable struct InferenceResult
+    linfo::MethodInstance
+    args::Vector{Any}
+    result # ::Type, or InferenceState if WIP
+    src::Union{CodeInfo, Void} # if inferred copy is available
+    function InferenceResult(linfo::MethodInstance)
+        if isdefined(linfo, :inferred_const)
+            result = Const(linfo.inferred_const)
+        else
+            result = linfo.rettype
+        end
+        return new(linfo, empty_vector, result, nothing)
+    end
+end
+
+function get_argtypes(result::InferenceResult)
+    result.args === empty_vector || return result.args # already cached
+    linfo = result.linfo
+    toplevel = !isa(linfo.def, Method)
+    atypes::SimpleVector = unwrap_unionall(linfo.specTypes).parameters
+    nargs::Int = toplevel ? 
0 : linfo.def.nargs + args = Vector{Any}(uninitialized, nargs) + if !toplevel && linfo.def.isva + if linfo.specTypes == Tuple + if nargs > 1 + atypes = svec(Any[ Any for i = 1:(nargs - 1) ]..., Tuple.parameters[1]) + end + vararg_type = Tuple + else + vararg_type = rewrap(tupleparam_tail(atypes, nargs), linfo.specTypes) + end + args[nargs] = vararg_type + nargs -= 1 + end + laty = length(atypes) + if laty > 0 + if laty > nargs + laty = nargs + end + local lastatype + atail = laty + for i = 1:laty + atyp = atypes[i] + if i == laty && isvarargtype(atyp) + atyp = unwrap_unionall(atyp).parameters[1] + atail -= 1 + end + if isa(atyp, TypeVar) + atyp = atyp.ub + end + if isa(atyp, DataType) && isdefined(atyp, :instance) + # replace singleton types with their equivalent Const object + atyp = Const(atyp.instance) + elseif isconstType(atyp) + atyp = Const(atyp.parameters[1]) + else + atyp = rewrap_unionall(atyp, linfo.specTypes) + end + i == laty && (lastatype = atyp) + args[i] = atyp + end + for i = (atail + 1):nargs + args[i] = lastatype + end + else + @assert nargs == 0 "invalid specialization of method" # wrong number of arguments + end + result.args = args + return args +end + struct InferenceParams + cache::Vector{InferenceResult} world::UInt # optimization inlining::Bool + ipo_constant_propagation::Bool + aggressive_constant_propagation::Bool inline_cost_threshold::Int # number of CPU cycles beyond which it's not worth inlining inline_nonleaf_penalty::Int # penalty for dynamic dispatch inline_tupleret_bonus::Int # extra willingness for non-isbits tuple return types - # parameters limiting potentially-infinite types (configurable) + # don't consider more than N methods. this trades off between + # compiler performance and generated code performance. + # typically, considering many methods means spending lots of time + # obtaining poor type information. + # It is important for N to be >= the number of methods in the error() + # function, so we can still know that error() is always Bottom. 
MAX_METHODS::Int + # the maximum number of union-tuples to swap / expand + # before computing the set of matching methods + MAX_UNION_SPLITTING::Int + # the maximum number of union-tuples to swap / expand + # when inferring a call to _apply + MAX_APPLY_UNION_ENUM::Int + + # parameters limiting large types MAX_TUPLETYPE_LEN::Int MAX_TUPLE_DEPTH::Int + + # when attempting to inlining _apply, abort the optimization if the tuple + # contains more than this many elements MAX_TUPLE_SPLAT::Int - MAX_UNION_SPLITTING::Int - MAX_APPLY_UNION_ENUM::Int # reasonable defaults function InferenceParams(world::UInt; @@ -38,9 +128,10 @@ struct InferenceParams tuple_splat::Int = 16, union_splitting::Int = 4, apply_union_enum::Int = 8) - return new(world, inlining, inline_cost_threshold, inline_nonleaf_penalty, - inline_tupleret_bonus, max_methods, tupletype_len, - tuple_depth, tuple_splat, union_splitting, apply_union_enum) + return new(Vector{InferenceResult}(), + world, inlining, true, false, inline_cost_threshold, inline_nonleaf_penalty, + inline_tupleret_bonus, max_methods, union_splitting, apply_union_enum, + tupletype_len, tuple_depth, tuple_splat) end end @@ -113,13 +204,14 @@ function rewrap(@nospecialize(t), @nospecialize(u)) end mutable struct InferenceState + params::InferenceParams # describes how to compute the result + result::InferenceResult # remember where to put the result + linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity sp::SimpleVector # static parameters mod::Module currpc::LineNum # info on the state of inference and the linfo - params::InferenceParams - linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity src::CodeInfo min_valid::UInt max_valid::UInt @@ -148,18 +240,17 @@ mutable struct InferenceState const_api::Bool const_ret::Bool - # TODO: put these in InferenceParams (depends on proper multi-methodcache support) + # TODO: move these to InferenceResult / InferenceParams? optimize::Bool cached::Bool limited::Bool - inferred::Bool - dont_work_on_me::Bool # src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results - function InferenceState(linfo::MethodInstance, src::CodeInfo, + function InferenceState(result::InferenceResult, src::CodeInfo, optimize::Bool, cached::Bool, params::InferenceParams) + linfo = result.linfo code = src.code::Array{Any,1} toplevel = !isa(linfo.def, Method) @@ -167,7 +258,7 @@ mutable struct InferenceState # linfo is unspecialized sp = Any[] sig = linfo.def.sig - while isa(sig,UnionAll) + while isa(sig, UnionAll) push!(sp, sig.var) sig = sig.body end @@ -185,66 +276,24 @@ mutable struct InferenceState # initial types nslots = length(src.slotnames) - s_types[1] = Any[ VarState(Bottom, true) for i = 1:nslots ] - src.slottypes = Any[ Bottom for i = 1:nslots ] - - atypes = unwrap_unionall(linfo.specTypes) - nargs::Int = toplevel ? 0 : linfo.def.nargs - la = nargs + argtypes = get_argtypes(result) vararg_type_container = nothing - if la > 0 - if linfo.def.isva - if atypes == Tuple - if la > 1 - atypes = Tuple{Any[Any for i = 1:(la - 1)]..., Tuple.parameters[1]} - end - vararg_type = Tuple - else - vararg_type_container = limit_tuple_depth(params, tupleparam_tail(atypes.parameters, la)) + nargs = length(argtypes) + s_argtypes = VarTable(uninitialized, nslots) + src.slottypes = Vector{Any}(uninitialized, nslots) + for i in 1:nslots + at = (i > nargs) ? 
Bottom : argtypes[i] + if !toplevel && linfo.def.isva && i == nargs + if !(at == Tuple) # would just be a no-op + vararg_type_container = limit_tuple_depth(params, unwrap_unionall(at)) # TODO: should be limiting tuple depth much earlier than here vararg_type = tuple_tfunc(vararg_type_container) # returns a Const object, if applicable - vararg_type = rewrap(vararg_type, linfo.specTypes) - end - s_types[1][la] = VarState(vararg_type, false) - src.slottypes[la] = vararg_type - la -= 1 - end - end - - laty = length(atypes.parameters) - if laty > 0 - if laty > la - laty = la - end - local lastatype - atail = laty - for i = 1:laty - atyp = atypes.parameters[i] - if i == laty && isvarargtype(atyp) - atyp = unwrap_unionall(atyp).parameters[1] - atail -= 1 - end - if isa(atyp, TypeVar) - atyp = atyp.ub - end - if isa(atyp, DataType) && isdefined(atyp, :instance) - # replace singleton types with their equivalent Const object - atyp = Const(atyp.instance) - elseif isconstType(atyp) - atyp = Const(atyp.parameters[1]) - else - atyp = rewrap_unionall(atyp, linfo.specTypes) + at = rewrap(vararg_type, linfo.specTypes) end - i == laty && (lastatype = atyp) - s_types[1][i] = VarState(atyp, false) - src.slottypes[i] = atyp end - for i = (atail + 1):la - s_types[1][i] = VarState(lastatype, false) - src.slottypes[i] = lastatype - end - else - @assert la == 0 # wrong number of arguments + s_argtypes[i] = VarState(at, i > nargs) + src.slottypes[i] = at end + s_types[1] = s_argtypes ssavalue_uses = find_ssavalue_uses(code, nssavalues) ssavalue_defs = find_ssavalue_defs(code, nssavalues) @@ -272,8 +321,9 @@ mutable struct InferenceState max_valid = typemin(UInt) end frame = new( - sp, inmodule, 0, params, - linfo, src, min_valid, max_valid, + params, result, linfo, + sp, inmodule, 0, + src, min_valid, max_valid, nargs, s_types, s_edges, Union{}, W, 1, n, cur_hand, handler_at, n_handlers, @@ -282,26 +332,32 @@ mutable struct InferenceState Vector{InferenceState}(), # callers_in_cycle #=parent=#nothing, false, false, optimize, cached, false, false, false) + result.result = frame + cached && push!(params.cache, result) return frame end end function InferenceState(linfo::MethodInstance, optimize::Bool, cached::Bool, params::InferenceParams) + return InferenceState(InferenceResult(linfo), optimize, cached, params) +end +function InferenceState(result::InferenceResult, + optimize::Bool, cached::Bool, params::InferenceParams) # prepare an InferenceState object for inferring lambda - src = retrieve_code_info(linfo) + src = retrieve_code_info(result.linfo) src === nothing && return nothing if JLOptions().debug_level == 2 # this is a debug build of julia, so let's validate linfo - errors = validate_code(linfo, src) + errors = validate_code(result.linfo, src) if !isempty(errors) for e in errors println(STDERR, "WARNING: Encountered invalid lowered code for method ", - linfo.def, ": ", e) + result.linfo, ": ", e) end end end - return InferenceState(linfo, src, optimize, cached, params) + return InferenceState(result, src, optimize, cached, params) end function get_staged(li::MethodInstance) @@ -380,6 +436,7 @@ function print_callstack(sv::InferenceState) while sv !== nothing print(sv.linfo) sv.limited && print(" [limited]") + !sv.cached && print(" [uncached]") println() for cycle in sv.callers_in_cycle print(' ', cycle.linfo) @@ -1818,17 +1875,10 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci return tunion end -function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), 
sv::InferenceState) - # don't consider more than N methods. this trades off between - # compiler performance and generated code performance. - # typically, considering many methods means spending lots of time - # obtaining poor type information. - # It is important for N to be >= the number of methods in the error() - # function, so we can still know that error() is always Bottom. - # here I picked 4. - argtype = limit_tuple_type(atype, sv.params) - argtypes = unwrap_unionall(argtype).parameters - ft = unwrap_unionall(argtypes[1]) # TODO: ccall jl_first_argument_datatype here +function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState) + atype = limit_tuple_type(atype, sv.params) + atype_params = unwrap_unionall(atype).parameters + ft = unwrap_unionall(atype_params[1]) # TODO: ccall jl_first_argument_datatype here isa(ft, DataType) || return Any # the function being called is unknown. can't properly handle this backedge right now ftname = ft.name isdefined(ftname, :mt) || return Any # not callable. should be Bottom, but can't track this backedge right now @@ -1846,9 +1896,9 @@ function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), sv::In end min_valid = UInt[typemin(UInt)] max_valid = UInt[typemax(UInt)] - splitunions = 1 < countunionsplit(argtypes) <= sv.params.MAX_UNION_SPLITTING + splitunions = 1 < countunionsplit(atype_params) <= sv.params.MAX_UNION_SPLITTING if splitunions - splitsigs = switchtupleunion(argtype) + splitsigs = switchtupleunion(atype) applicable = Any[] for sig_n in splitsigs xapplicable = _methods_by_ftype(sig_n, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid) @@ -1856,7 +1906,7 @@ function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), sv::In append!(applicable, xapplicable) end else - applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid) + applicable = _methods_by_ftype(atype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid) if applicable === false # this means too many methods matched # (assume this will always be true, so we don't compute / update valid age in this case) @@ -1867,6 +1917,7 @@ function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), sv::In applicable = applicable::Array{Any,1} napplicable = length(applicable) rettype = Bottom + edgecycle = false for i in 1:napplicable match = applicable[i]::SimpleVector method = match[3]::Method @@ -1878,18 +1929,29 @@ function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), sv::In if splitunions splitsigs = switchtupleunion(sig) for sig_n in splitsigs - rt = abstract_call_method(method, f, sig_n, svec(), sv) + rt, edgecycle1 = abstract_call_method(method, sig_n, svec(), sv) + edgecycle |= edgecycle1::Bool rettype = tmerge(rettype, rt) rettype === Any && break end rettype === Any && break else - rt = abstract_call_method(method, f, sig, match[2]::SimpleVector, sv) + rt, edgecycle = abstract_call_method(method, sig, match[2]::SimpleVector, sv) rettype = tmerge(rettype, rt) rettype === Any && break end end - if !(rettype === Any) + if napplicable == 1 && !edgecycle && isa(rettype, Type) && sv.params.ipo_constant_propagation + # if there's a possibility we could constant-propagate a better result + # (hopefully without doing too much work), try to do that now + # TODO: it feels like this could be better integrated into abstract_call_method / typeinf_edge + const_rettype = abstract_call_method_with_const_args(argtypes, 
applicable[1]::SimpleVector, sv) + if const_rettype ⊑ rettype + # use the better result, if it's a refinement of rettype + rettype = const_rettype + end + end + if !(rettype === Any) # adding a new method couldn't refine (widen) this type fullmatch = false for i in napplicable:-1:1 match = applicable[i]::SimpleVector @@ -1909,7 +1971,103 @@ function abstract_call_gf_by_type(@nospecialize(f), @nospecialize(atype), sv::In return rettype end -function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(sig), sparams::SimpleVector, sv::InferenceState) +function cache_lookup(code::MethodInstance, argtypes::Vector{Any}, cache::Vector{InferenceResult}) + method = code.def::Method + nargs::Int = method.nargs + method.isva && (nargs -= 1) + for cache_code in cache + # try to search cache first + cache_args = cache_code.args + if cache_code.linfo === code && length(cache_args) >= nargs + cache_match = true + # verify that the trailing args (va) aren't Const + for i in (nargs + 1):length(cache_args) + if isa(cache_args[i], Const) + cache_match = false + break + end + end + cache_match || continue + for i in 1:nargs + a = argtypes[i] + ca = cache_args[i] + # verify that all Const argument types match between the call and cache + if (isa(a, Const) || isa(ca, Const)) && !(a === ca) + cache_match = false + break + end + end + cache_match || continue + return cache_code + end + end + return nothing +end + +function abstract_call_method_with_const_args(argtypes::Vector{Any}, match::SimpleVector, sv::InferenceState) + method = match[3]::Method + nargs::Int = method.nargs + method.isva && (nargs -= 1) + length(argtypes) >= nargs || return Any # probably limit_tuple_type made this non-matching method apparently match + haveconst = false + for i in 1:nargs + a = argtypes[i] + if isa(a, Const) && !isdefined(typeof(a.val), :instance) + if !isleaftype(a.val) # alternately: !isa(a.val, DataType) || !isconstType(Type{a.val}) + # have new information from argtypes that wasn't available from the signature + haveconst = true + break + end + end + end + haveconst || return Any + sig = match[1] + sparams = match[2]::SimpleVector + code = code_for_method(method, sig, sparams, sv.params.world) + code === nothing && return Any + code = code::MethodInstance + # decide if it's likely to be worthwhile + cache_inlineable = false + if isdefined(code, :inferred) + cache_inf = code.inferred + if !(cache_inf === nothing) + cache_src_inferred = ccall(:jl_ast_flag_inferred, Bool, (Any,), cache_inf) + cache_src_inlineable = ccall(:jl_ast_flag_inlineable, Bool, (Any,), cache_inf) + cache_inlineable = cache_src_inferred && cache_src_inlineable + end + end + if !cache_inlineable && !sv.params.aggressive_constant_propagation + # in this case, see if all of the arguments are constants + for i in 1:nargs + a = argtypes[i] + if !isa(a, Const) && !isconstType(a) + return Any + end + end + end + inf_result = cache_lookup(code, argtypes, sv.params.cache) + if inf_result === nothing + inf_result = InferenceResult(code) + atypes = get_argtypes(inf_result) + for i in 1:nargs + a = argtypes[i] + if a isa Const + atypes[i] = a # inject Const argtypes into inference + end + end + frame = InferenceState(inf_result, #=optimize=#true, #=cache=#false, sv.params) + frame.limited = true + frame.parent = sv + push!(sv.params.cache, inf_result) + typeinf(frame) + end + result = inf_result.result + isa(result, InferenceState) && return Any # TODO: is this recursive constant inference? 
+ add_backedge!(inf_result.linfo, sv) + return result +end + +function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, sv::InferenceState) topmost = nothing # Limit argument type tuple growth of functions: # look through the parents list to see if there's a call to the same method @@ -1917,6 +2075,7 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si # Returns the topmost occurrence of that repeated edge. cyclei = 0 infstate = sv + edgecycle = false while !(infstate === nothing) infstate = infstate::InferenceState if method === infstate.linfo.def @@ -1924,6 +2083,7 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si # avoid widening when detecting self-recursion # TODO: merge call cycle and return right away topmost = nothing + edgecycle = true break end if topmost === nothing @@ -1936,6 +2096,7 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si # all items in here are mutual parents of all others if parent.linfo.def === sv.linfo.def topmost = infstate + edgecycle = true break end end @@ -1945,6 +2106,7 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si parent = parent::InferenceState if parent.cached && parent.linfo.def === sv.linfo.def topmost = infstate + edgecycle = true end end end @@ -2006,8 +2168,12 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si end rt, edge = typeinf_edge(method, sig, sparams, sv) - edge !== nothing && add_backedge!(edge::MethodInstance, sv) - return rt + if edge === nothing + edgecycle = true + else + add_backedge!(edge::MethodInstance, sv) + end + return rt, edgecycle end # determine whether `ex` abstractly evals to constant `c` @@ -2147,7 +2313,7 @@ function abstract_apply(@nospecialize(aft), fargs::Vector{Any}, aargtypes::Vecto rt = abstract_call(aft.parameters[1], (), ct, vtypes, sv) else astype = argtypes_to_type(ct) - rt = abstract_call_gf_by_type(nothing, astype, sv) + rt = abstract_call_gf_by_type(nothing, ct, astype, sv) end res = tmerge(res, rt) if res === Any @@ -2160,7 +2326,7 @@ end # TODO: this function is a very buggy and poor model of the return_type function # since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type, # while this assumes that it is a precisely accurate and exact model of both -function return_type_tfunc(@nospecialize(argtypes), vtypes::VarTable, sv::InferenceState) +function return_type_tfunc(argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState) if length(argtypes) == 3 tt = argtypes[3] if isa(tt, Const) || (isType(tt) && !has_free_typevars(tt)) @@ -2176,7 +2342,7 @@ function return_type_tfunc(@nospecialize(argtypes), vtypes::VarTable, sv::Infere elseif isconstType(aft) rt = abstract_call(aft.parameters[1], (), argtypes_vec, vtypes, sv) else - rt = abstract_call_gf_by_type(nothing, astype, sv) + rt = abstract_call_gf_by_type(nothing, argtypes_vec, astype, sv) end if isa(rt, Const) # output was computed to be constant @@ -2198,7 +2364,7 @@ function return_type_tfunc(@nospecialize(argtypes), vtypes::VarTable, sv::Infere return NF end -function pure_eval_call(@nospecialize(f), @nospecialize(argtypes), @nospecialize(atype), sv::InferenceState) +function pure_eval_call(@nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState) for i = 2:length(argtypes) a = argtypes[i] if !(isa(a,Const) || isconstType(a)) @@ -2231,17 +2397,6 @@ end argtypes_to_type(argtypes::Array{Any,1}) = 
Tuple{anymap(widenconst, argtypes)...} -_Pair_name = nothing -function Pair_name() - global _Pair_name - if _Pair_name === nothing - if isdefined(Main, :Base) && isdefined(Main.Base, :Pair) - _Pair_name = Main.Base.Pair.body.body.name - end - end - return _Pair_name -end - _typename(a) = Union{} _typename(a::Vararg) = Any _typename(a::TypeVar) = Any @@ -2396,7 +2551,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt elseif length(fargs) == 2 && istopfunction(tm, f, :!) aty = argtypes[2] if isa(aty, Conditional) - abstract_call_gf_by_type(f, Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)` + abstract_call_gf_by_type(f, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)` return Conditional(aty.var, aty.elsetype, aty.vtype) end elseif length(fargs) == 3 && istopfunction(tm, f, :!==) @@ -2415,37 +2570,6 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt return rty end - if la>2 && argtypes[3] ⊑ Int - at2 = widenconst(argtypes[2]) - if la==3 && at2 <: SimpleVector && istopfunction(tm, f, :getindex) - if isa(argtypes[2], Const) && isa(argtypes[3], Const) - svecval = argtypes[2].val - idx = argtypes[3].val - if isa(idx, Int) && 1 <= idx <= length(svecval) && - isassigned(svecval, idx) - return Const(getindex(svecval, idx)) - end - end - elseif (at2 <: Tuple || - (isa(at2, DataType) && (at2::DataType).name === Pair_name())) - # allow tuple indexing functions to take advantage of constant - # index arguments. - if istopfunction(tm, f, :getindex) && la==3 - return getfield_tfunc(argtypes[2], argtypes[3]) - elseif istopfunction(tm, f, :next) && la==3 - t1 = widenconst(getfield_tfunc(argtypes[2], argtypes[3])) - return t1===Bottom ? Bottom : Tuple{t1, Int} - elseif istopfunction(tm, f, :indexed_next) && la==4 - t1 = widenconst(getfield_tfunc(argtypes[2], argtypes[3])) - return t1===Bottom ? Bottom : Tuple{t1, Int} - end - end - elseif la==2 && argtypes[2] ⊑ SimpleVector && istopfunction(tm, f, :length) - if isa(argtypes[2], Const) - return Const(length(argtypes[2].val)) - end - end - atype = argtypes_to_type(argtypes) t = pure_eval_call(f, argtypes, atype, sv) t !== false && return t @@ -2469,11 +2593,11 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt if a1 ⊑ basenumtype ftimes = Main.Base.:* ta1 = widenconst(a1) - abstract_call_gf_by_type(ftimes, Tuple{typeof(ftimes), ta1, ta1}, sv) + abstract_call_gf_by_type(ftimes, Any[ftimes, a1, a1], Tuple{typeof(ftimes), ta1, ta1}, sv) end end end - return abstract_call_gf_by_type(f, atype, sv) + return abstract_call_gf_by_type(f, argtypes, atype, sv) end function abstract_eval_call(e::Expr, vtypes::VarTable, sv::InferenceState) @@ -2498,7 +2622,7 @@ function abstract_eval_call(e::Expr, vtypes::VarTable, sv::InferenceState) end # non-constant function, but type is known if (isleaftype(ft) || ft <: Type) && !(ft <: Builtin) && !(ft <: IntrinsicFunction) - return abstract_call_gf_by_type(nothing, argtypes_to_type(argtypes), sv) + return abstract_call_gf_by_type(nothing, argtypes, argtypes_to_type(argtypes), sv) end return Any end @@ -3123,20 +3247,24 @@ end # returned instead. 
function resolve_call_cycle!(linfo::MethodInstance, parent::InferenceState) frame = parent + uncached = false while isa(frame, InferenceState) + uncached |= !frame.cached # ensure we never add an uncached frame to a cycle if frame.linfo === linfo + uncached && return true merge_call_chain!(parent, frame, frame) return frame end for caller in frame.callers_in_cycle if caller.linfo === linfo + uncached && return true merge_call_chain!(parent, frame, caller) return caller end end frame = frame.parent end - return nothing + return false end # build (and start inferring) the inference frame for the linfo @@ -3170,15 +3298,16 @@ function typeinf_edge(method::Method, @nospecialize(atypes), sparams::SimpleVect if !caller.cached && caller.parent === nothing # this caller exists to return to the user # (if we asked resolve_call_cyle, it might instead detect that there is a cycle that it can't merge) - frame = nothing + frame = false else frame = resolve_call_cycle!(code, caller) end - if frame === nothing + if frame === false # completely new code.inInference = true frame = InferenceState(code, #=optimize=#true, #=cached=#true, caller.params) # always optimize and cache edge targets if frame === nothing + # can't get the source for this, so we know nothing code.inInference = false return Any, nothing end @@ -3187,6 +3316,9 @@ function typeinf_edge(method::Method, @nospecialize(atypes), sparams::SimpleVect end typeinf(frame) return frame.bestguess, frame.inferred ? frame.linfo : nothing + elseif frame === true + # unresolvable cycle + return Any, nothing end frame = frame::InferenceState return frame.bestguess, nothing @@ -3277,7 +3409,8 @@ function typeinf_ext(linfo::MethodInstance, world::UInt) else # toplevel lambda - infer directly ccall(:jl_typeinf_begin, Void, ()) - frame = InferenceState(linfo, linfo.inferred::CodeInfo, + result = InferenceResult(linfo) + frame = InferenceState(result, linfo.inferred::CodeInfo, true, true, InferenceParams(world)) typeinf(frame) ccall(:jl_typeinf_end, Void, ()) @@ -3484,6 +3617,7 @@ function typeinf(frame::InferenceState) @assert !(caller.dont_work_on_me) caller.dont_work_on_me = true end + # complete the computation of the src optimizations for caller in frame.callers_in_cycle optimize(caller) if frame.min_valid < caller.min_valid @@ -3493,6 +3627,7 @@ function typeinf(frame::InferenceState) frame.max_valid = caller.max_valid end end + # update and store in the global cache for caller in frame.callers_in_cycle caller.min_valid = frame.min_valid end @@ -3560,7 +3695,7 @@ function optimize(me::InferenceState) # run optimization passes on fulltree force_noinline = true - if me.limited && me.parent !== nothing + if me.limited && me.cached && me.parent !== nothing # a top parent will be cached still, but not this intermediate work me.cached = false me.linfo.inInference = false @@ -3734,11 +3869,18 @@ function finish(me::InferenceState) if cache !== me.linfo me.linfo.inInference = false me.linfo = cache + me.result.linfo = cache end end me.linfo.inInference = false end + # finish updating the result struct + if me.src.inlineable + me.result.src = me.src # stash a copy of the code (for inlining) + end + me.result.result = me.bestguess # record type, and that wip is done and me.linfo can be used as a backedge + # update all of the callers with real backedges by traversing the temporary list of backedges for (i, _) in me.backedges add_backedge!(me.linfo, i) @@ -4589,29 +4731,8 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector 
isa(methsp[i], TypeVar) && return NF end - # some gf have special tfunc, meaning they wouldn't have been inferred yet - # check the same conditions from abstract_call to detect this case - force_infer = false - if !isdefined(method, :generator) - if method.module == _topmod(method.module) || (isdefined(Main, :Base) && method.module == Main.Base) - la = length(atypes) - if (la==3 && (method.name == :getindex || method.name == :next)) || - (la==4 && method.name == :indexed_next) - if atypes[3] ⊑ Int - at2 = widenconst(atypes[2]) - if (at2 <: Tuple || at2 <: SimpleVector || - (isa(at2, DataType) && (at2::DataType).name === Pair_name())) - force_infer = true - end - end - elseif la == 2 && method.name == :length && atypes[2] ⊑ SimpleVector - force_infer = true - end - end - end - # see if the method has been previously inferred (and cached) - linfo = code_for_method(method, metharg, methsp, sv.params.world, !force_infer) # Union{Void, MethodInstance} + linfo = code_for_method(method, metharg, methsp, sv.params.world, true) # Union{Void, MethodInstance} isa(linfo, MethodInstance) || return invoke_NF(argexprs0, e.typ, atypes0, sv, atype_unlimited, invoke_data) linfo = linfo::MethodInstance @@ -4621,59 +4742,46 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector return inline_as_constant(linfo.inferred_const, argexprs, sv, invoke_data) end - # see if the method has a current InferenceState frame - # or existing inferred code info - frame = nothing # Union{Void, InferenceState} - inferred = nothing # Union{Void, CodeInfo} - if force_infer && la > 2 && isa(atypes[3], Const) - # Since we inferred this with the information that atypes[3]::Const, - # must inline with that same information. - # We do that by overriding the argument type, - # while ensuring we don't cache that information - # This isn't particularly important for `getindex`, - # as we'll be able to fix that up at the end of inlinable when we verify the return type. - # But `next` and `indexed_next` make tuples which would end up burying some of that information in the AST - # where we can't easily correct it afterwards. 
- frame = InferenceState(linfo, #=optimize=#true, #=cache=#false, sv.params) - frame.stmt_types[1][3] = VarState(atypes[3], false) - typeinf(frame) + # see if the method has a InferenceResult in the current cache + # or an existing inferred code info store in `.inferred` + haveconst = false + for i in 1:length(atypes) + a = atypes[i] + if isa(a, Const) && !isdefined(typeof(a.val), :instance) + if !isleaftype(a.val) # alternately: !isa(a.val, DataType) || !isconstType(Type{a.val}) + # have new information from argtypes that wasn't available from the signature + haveconst = true + break + end + end + end + if haveconst + inf_result = cache_lookup(linfo, atypes, sv.params.cache) # Union{Void, InferenceResult} else - if isdefined(linfo, :inferred) && linfo.inferred !== nothing - # use cache - inferred = linfo.inferred - elseif force_infer - # create inferred code on-demand - # but if we decided in the past not to try to infer this particular signature - # (due to signature coarsening in abstract_call_gf_by_type) - # don't infer it now, as attempting to force it now would be a bad idea (non terminating) - frame = typeinf_frame(linfo, #=optimize=#true, #=cache=#true, sv.params) - end - end - - # compute the return value - if isa(frame, InferenceState) && !frame.src.inferred - frame = nothing - end - if isa(frame, InferenceState) - linfo = frame.linfo - inferred = frame.src - if frame.const_api # handle like jlcall_api == 2 - if frame.inferred || !frame.cached + inf_result = nothing + end + if isa(inf_result, InferenceResult) && isa(inf_result.src, CodeInfo) + linfo = inf_result.linfo + result = inf_result.result + if (inf_result.src::CodeInfo).pure + if isa(result, Const) + inferred_const = result.val + elseif isconstType(result) + inferred_const = result.parameters[1] + end + if @isdefined inferred_const add_backedge!(linfo, sv) - else - add_backedge!(frame, sv) - end - if isa(frame.bestguess, Const) - inferred_const = (frame.bestguess::Const).val - else - @assert isconstType(frame.bestguess) - inferred_const = frame.bestguess.parameters[1] + return inline_as_constant(inferred_const, argexprs, sv, invoke_data) end - return inline_as_constant(inferred_const, argexprs, sv, invoke_data) end - rettype = widenconst(frame.bestguess) - else + inferred = inf_result.src + rettype = widenconst(result) + elseif isdefined(linfo, :inferred) + inferred = linfo.inferred rettype = linfo.rettype + else + rettype = Any + inferred = nothing end # check that the code is inlineable @@ -4688,26 +4796,19 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector invoke_data) end + # create the backedge + add_backedge!(linfo, sv) + + # prepare the code object for mutation if isa(inferred, CodeInfo) src = inferred ast = copy_exprargs(inferred.code) else - src = ccall(:jl_uncompress_ast, Any, (Any, Any), method, inferred)::CodeInfo + src = ccall(:jl_uncompress_ast, Any, (Any, Any), method, inferred::Vector{UInt8})::CodeInfo ast = src.code end ast = ast::Array{Any,1} - - # create the backedge - if isa(frame, InferenceState) && !frame.inferred && frame.cached - # in this case, the actual backedge linfo hasn't been computed - # yet, but will be when inference on the frame finishes - add_backedge!(frame, sv) - else - add_backedge!(linfo, sv) - end - nm = length(unwrap_unionall(metharg).parameters) - body = Expr(:block) body.args = ast diff --git a/test/inference.jl b/test/inference.jl index 3894e43569add..bfc197a686161 100644 --- a/test/inference.jl +++ b/test/inference.jl @@ -1291,3 +1291,8 @@ let 
T1 = Array{Float64}, T2 = Array{_1,2} where _1 rt = Base.return_types(g, (Union{Ref{Array{Float64}}, Ref{Array{Float32}}},))[1] @test rt >: Union{Type{Array{Float64}}, Type{Array{Float32}}} end + +# Demonstrate IPO constant propagation (#24362) +f_constant(x) = convert(Int, x) +g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD") +@test @inferred g_test_constant() From e4112d2d3e0dfa273110e8bd0ec56ce860f5c5e6 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Tue, 7 Nov 2017 10:57:17 -0500 Subject: [PATCH 2/3] inference: fix various minor inconsistencies --- base/inference.jl | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/base/inference.jl b/base/inference.jl index 69f9db84dad8e..ae13c8604e84e 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -2474,7 +2474,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt end return isa(rt, TypeVar) ? rt.ub : rt elseif f === Core.kwfunc - if length(fargs) == 2 + if length(argtypes) == 2 ft = widenconst(argtypes[2]) if isa(ft, DataType) && isdefined(ft.name, :mt) && isdefined(ft.name.mt, :kwsorter) return Const(ft.name.mt.kwsorter) @@ -2485,10 +2485,10 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt lb = Union{} ub = Any ub_certain = lb_certain = true - if length(fargs) >= 2 && isa(argtypes[2], Const) + if length(argtypes) >= 2 && isa(argtypes[2], Const) nv = argtypes[2].val ubidx = 3 - if length(fargs) >= 4 + if length(argtypes) >= 4 ubidx = 4 if isa(argtypes[3], Const) lb = argtypes[3].val @@ -2499,7 +2499,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt return TypeVar end end - if length(fargs) >= ubidx + if length(argtypes) >= ubidx if isa(argtypes[ubidx], Const) ub = argtypes[ubidx].val elseif isType(argtypes[ubidx]) @@ -2514,7 +2514,7 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt end return TypeVar elseif f === UnionAll - if length(fargs) == 3 + if length(argtypes) == 3 canconst = true if isa(argtypes[3], Const) body = argtypes[3].val @@ -2548,13 +2548,15 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt if rt_rt !== NF return rt_rt end - elseif length(fargs) == 2 && istopfunction(tm, f, :!) + elseif length(argtypes) == 2 && istopfunction(tm, f, :!) 
+        # handle Conditional propagation through !Bool
         aty = argtypes[2]
         if isa(aty, Conditional)
             abstract_call_gf_by_type(f, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
             return Conditional(aty.var, aty.elsetype, aty.vtype)
         end
-    elseif length(fargs) == 3 && istopfunction(tm, f, :!==)
+    elseif length(argtypes) == 3 && istopfunction(tm, f, :!==)
+        # mark !== as exactly a negated call to ===
         rty = abstract_call((===), fargs, argtypes, vtypes, sv)
         if isa(rty, Conditional)
             return Conditional(rty.var, rty.elsetype, rty.vtype) # swap if-else
@@ -2562,12 +2564,30 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt
             return Const(rty.val === false)
         end
         return rty
-    elseif length(fargs) == 3 && istopfunction(tm, f, :(>:))
+    elseif length(argtypes) == 3 && istopfunction(tm, f, :(>:))
+        # mark issupertype as an exact alias for issubtype
         # swap T1 and T2 arguments and call <:
-        fargs = Any[<:, fargs[3], fargs[2]]
+        if length(fargs) == 3
+            fargs = Any[<:, fargs[3], fargs[2]]
+        else
+            fargs = ()
+        end
         argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
         rty = abstract_call(<:, fargs, argtypes, vtypes, sv)
         return rty
+    elseif length(argtypes) == 2 && isa(argtypes[2], Const) && isa(argtypes[2].val, SimpleVector) && istopfunction(tm, f, :length)
+        # mark length(::SimpleVector) as @pure
+        return Const(length(argtypes[2].val))
+    elseif length(argtypes) == 3 && isa(argtypes[2], Const) && isa(argtypes[3], Const) &&
+            isa(argtypes[2].val, SimpleVector) && isa(argtypes[3].val, Int) && istopfunction(tm, f, :getindex)
+        # mark getindex(::SimpleVector, i::Int) as @pure
+        svecval = argtypes[2].val::SimpleVector
+        idx = argtypes[3].val::Int
+        if 1 <= idx <= length(svecval) && isassigned(svecval, idx)
+            return Const(getindex(svecval, idx))
+        end
+    elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
+        return typename_static(argtypes[2])
     end
 
     atype = argtypes_to_type(argtypes)
@@ -2576,8 +2596,6 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt
     if istopfunction(tm, f, :typejoin) || f === return_type
         return Type # don't try to infer these function edges directly -- it won't actually come up with anything useful
-    elseif length(argtypes) == 2 && istopfunction(tm, f, :typename)
-        return typename_static(argtypes[2])
     end
 
     if sv.params.inlining

From 62167c8e2b94d133df72900708fb66da05ef6e41 Mon Sep 17 00:00:00 2001
From: Jameson Nash
Date: Tue, 7 Nov 2017 12:33:41 -0500
Subject: [PATCH 3/3] inference: enable evaluation of pure intrinsics

---
 base/inference.jl | 37 +++++++++++++++++++++++++++--------
 test/inference.jl |  3 +++
 2 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/base/inference.jl b/base/inference.jl
index ae13c8604e84e..6b5f867eabf7f 100644
--- a/base/inference.jl
+++ b/base/inference.jl
@@ -1768,6 +1768,12 @@ function builtin_tfunction(@nospecialize(f), argtypes::Array{Any,1},
         return Any
     end
     if isa(f, IntrinsicFunction)
+        if is_pure_intrinsic_infer(f) && all(a -> isa(a, Const), argtypes)
+            argvals = anymap(a -> a.val, argtypes)
+            try
+                return Const(f(argvals...))
+            end
+        end
         iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1
         if iidx < 0 || iidx > length(t_ifunc)
             # invalid intrinsic
@@ -4235,23 +4241,38 @@ const _pure_builtins = Any[tuple, svec, fieldtype, apply_type, ===, isa, typeof,
 # known effect-free calls (might not be affect-free)
 const _pure_builtins_volatile = Any[getfield, arrayref, isdefined, Core.sizeof]
 
-function is_pure_intrinsic(f::IntrinsicFunction)
+# whether `f` is 
pure for Inference +function is_pure_intrinsic_infer(f::IntrinsicFunction) + return !(f === Intrinsics.pointerref || # this one is volatile + f === Intrinsics.pointerset || # this one is never effect-free + f === Intrinsics.llvmcall || # this one is never effect-free + f === Intrinsics.arraylen || # this one is volatile + f === Intrinsics.sqrt_llvm || # this one may differ at runtime (by a few ulps) + f === Intrinsics.cglobal) # cglobal lookup answer changes at runtime +end + +# whether `f` is pure for Optimizations +function is_pure_intrinsic_optim(f::IntrinsicFunction) return !(f === Intrinsics.pointerref || # this one is volatile f === Intrinsics.pointerset || # this one is never effect-free f === Intrinsics.llvmcall || # this one is never effect-free - f === Intrinsics.checked_sdiv_int || + f === Intrinsics.arraylen || # this one is volatile + f === Intrinsics.checked_sdiv_int || # these may throw errors f === Intrinsics.checked_udiv_int || f === Intrinsics.checked_srem_int || f === Intrinsics.checked_urem_int || - f === Intrinsics.sqrt_llvm || f === Intrinsics.cglobal) # cglobal throws an error for symbol-not-found end function is_pure_builtin(@nospecialize(f)) - return (contains_is(_pure_builtins, f) || - contains_is(_pure_builtins_volatile, f) || - (isa(f,IntrinsicFunction) && is_pure_intrinsic(f)) || - f === return_type) + if isa(f, IntrinsicFunction) + return is_pure_intrinsic_optim(f) + elseif isa(f, Builtin) + return (contains_is(_pure_builtins, f) || + contains_is(_pure_builtins_volatile, f)) + else + return f === return_type + end end function statement_effect_free(@nospecialize(e), src::CodeInfo, mod::Module) @@ -4592,7 +4613,7 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector (isbits(val) && Core.sizeof(val) <= MAX_INLINE_CONST_SIZE && (contains_is(_pure_builtins, f) || (f === getfield && effect_free(e, sv.src, sv.mod, false)) || - (isa(f,IntrinsicFunction) && is_pure_intrinsic(f))))) + (isa(f, IntrinsicFunction) && is_pure_intrinsic_optim(f))))) return inline_as_constant(val, argexprs, sv, nothing) end end diff --git a/test/inference.jl b/test/inference.jl index bfc197a686161..e930eb70b2629 100644 --- a/test/inference.jl +++ b/test/inference.jl @@ -1296,3 +1296,6 @@ end f_constant(x) = convert(Int, x) g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD") @test @inferred g_test_constant() + +f_pure_add() = (1 + 1 == 2) ? true : "FAIL" +@test @inferred f_pure_add()
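
A quick, illustrative way to observe the new behavior from the REPL (not part of
the patches; it assumes a julia build with all three commits applied). Without
inter-procedural constant propagation, inference can only conclude a
Union{Bool, String}-like return type for g_test_constant, since it has no way to
know that f_constant(3) == 3 holds; with this series that branch folds away, and
likewise 1 + 1 == 2 is now folded by the pure-intrinsic evaluation:

    f_constant(x) = convert(Int, x)
    g_test_constant() = (f_constant(3) == 3 && f_constant(4) == 4 ? true : "BAD")
    Base.return_types(g_test_constant, Tuple{})  # expected: Any[Bool]

    f_pure_add() = (1 + 1 == 2) ? true : "FAIL"
    Base.return_types(f_pure_add, Tuple{})       # expected: Any[Bool]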