diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 3ca8b29cf772f0..60393ac36db2a6 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -42,9 +42,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), fullmatch = Bool[] if splitunions splitsigs = switchtupleunion(atype) + split_argtypes = switchtupleunion(argtypes) applicable = Any[] + # arrays like `argtypes`, including constants, for each match + applicable_argtypes = Vector{Any}[] infos = MethodMatchInfo[] - for sig_n in splitsigs + for j in 1:length(splitsigs) + sig_n = splitsigs[j] mt = ccall(:jl_method_table_for, Any, (Any,), sig_n) if mt === nothing add_remark!(interp, sv, "Could not identify method table for call") @@ -58,6 +62,10 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end push!(infos, MethodMatchInfo(matches)) append!(applicable, matches) + for _ in 1:length(matches) + push!(applicable_argtypes, split_argtypes[j]) + end + # @assert argtypes_to_type(split_argtypes[j]) === sig_n "invalid union split" valid_worlds = intersect(valid_worlds, matches.valid_worlds) thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches) found = false @@ -93,15 +101,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), info = MethodMatchInfo(matches) applicable = matches.matches valid_worlds = matches.valid_worlds + applicable_argtypes = nothing end update_valid_age!(sv, valid_worlds) applicable = applicable::Array{Any,1} napplicable = length(applicable) rettype = Bottom edgecycle = false - edges = Any[] - nonbot = 0 # the index of the only non-Bottom inference result if > 0 - seen = 0 # number of signatures actually inferred + edges = MethodInstance[] istoplevel = sv.linfo.def isa Module multiple_matches = napplicable > 1 @@ -124,8 +131,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, 
@nospecialize(f), break end sigtuple = unwrap_unionall(sig)::DataType - splitunions = false this_rt = Bottom + splitunions = false # TODO: splitunions = 1 < unionsplitcost(sigtuple.parameters) * napplicable <= InferenceParams(interp).MAX_UNION_SPLITTING # currently this triggers a bug in inference recursion detection if splitunions @@ -136,8 +143,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), push!(edges, edge) end edgecycle |= edgecycle1::Bool + this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] + const_rt = abstract_call_method_with_const_args(interp, rt, f, this_argtypes, match, sv, edgecycle) + if const_rt !== rt && const_rt ⊑ rt + rt = const_rt + end this_rt = tmerge(this_rt, rt) - this_rt === Any && break + if this_rt === Any + break + end end else this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) @@ -145,33 +159,19 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), if edge !== nothing push!(edges, edge) end - end - if this_rt !== Bottom - if nonbot === 0 - nonbot = i - else - nonbot = -1 + this_argtypes = applicable_argtypes === nothing ? 
argtypes : applicable_argtypes[i] + const_this_rt = abstract_call_method_with_const_args(interp, this_rt, f, this_argtypes, match, sv, edgecycle) + if const_this_rt !== this_rt && const_this_rt ⊑ this_rt + this_rt = const_this_rt end end - seen += 1 rettype = tmerge(rettype, this_rt) - rettype === Any && break - end - # try constant propagation if only 1 method is inferred to non-Bottom - # this is in preparation for inlining, or improving the return result - is_unused = call_result_unused(sv) - if nonbot > 0 && seen == napplicable && (!edgecycle || !is_unused) && - is_improvable(rettype) && InferenceParams(interp).ipo_constant_propagation - # if there's a possibility we could constant-propagate a better result - # (hopefully without doing too much work), try to do that now - # TODO: it feels like this could be better integrated into abstract_call_method / typeinf_edge - const_rettype = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle) - if const_rettype ⊑ rettype - # use the better result, if it's a refinement of rettype - rettype = const_rettype + if rettype === Any + break end end - if is_unused && !(rettype === Bottom) + + if call_result_unused(sv) && !(rettype === Bottom) add_remark!(interp, sv, "Call result type was widened because the return value is unused") # We're mainly only here because the optimizer might want this code, # but we ourselves locally don't typically care about it locally @@ -207,135 +207,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(rettype, info) end - -function const_prop_profitable(@nospecialize(arg)) - # have new information from argtypes that wasn't available from the signature - if isa(arg, PartialStruct) - for b in arg.fields - isconstType(b) && return true - const_prop_profitable(b) && return true - end - elseif !isa(arg, Const) || (isa(arg.val, Symbol) || isa(arg.val, Type) || (!isa(arg.val, String) && 
!ismutable(arg.val))) - # don't consider mutable values or Strings useful constants - return true - end - return false -end - -# This is a heuristic to avoid trying to const prop through complicated functions -# where we would spend a lot of time, but are probably unliekly to get an improved -# result anyway. -function const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) - # Peek at the inferred result for the function to determine if the optimizer - # was able to cut it down to something simple (inlineable in particular). - # If so, there's a good chance we might be able to const prop all the way - # through and learn something new. - code = get(code_cache(interp), mi, nothing) - declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source) - cache_inlineable = declared_inline - if isdefined(code, :inferred) && !cache_inlineable - cache_inf = code.inferred - if !(cache_inf === nothing) - cache_src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), cache_inf) - cache_src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), cache_inf) - cache_inlineable = cache_src_inferred && cache_src_inlineable - end - end - if !cache_inlineable - return false - end - return true -end - -function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, sv::InferenceState, edgecycle::Bool) - method = match.method - nargs::Int = method.nargs - method.isva && (nargs -= 1) - length(argtypes) >= nargs || return Any - haveconst = false - allconst = true - # see if any or all of the arguments are constant and propagating constants may be worthwhile - for a in argtypes - a = widenconditional(a) - if allconst && !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) - allconst = false - end - if !haveconst && has_nontrivial_const_info(a) && const_prop_profitable(a) - haveconst = true - end - if haveconst && 
!allconst - break - end - end - haveconst || improvable_via_constant_propagation(rettype) || return Any - if nargs > 1 - if istopfunction(f, :getindex) || istopfunction(f, :setindex!) - arrty = argtypes[2] - # don't propagate constant index into indexing of non-constant array - if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) - return Any - elseif arrty ⊑ Array - return Any - end - elseif istopfunction(f, :iterate) - itrty = argtypes[2] - if itrty ⊑ Array - return Any - end - end - end - if !allconst && (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || - istopfunction(f, :(==)) || istopfunction(f, :!=) || - istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) || - istopfunction(f, :<<) || istopfunction(f, :>>)) - return Any - end - force_inference = allconst || InferenceParams(interp).aggressive_constant_propagation - if istopfunction(f, :getproperty) || istopfunction(f, :setproperty!) - force_inference = true - end - mi = specialize_method(match, !force_inference) - mi === nothing && return Any - mi = mi::MethodInstance - # decide if it's likely to be worthwhile - if !force_inference && !const_prop_heuristic(interp, method, mi) - return Any - end - inf_cache = get_inference_cache(interp) - inf_result = cache_lookup(mi, argtypes, inf_cache) - if inf_result === nothing - if edgecycle - # if there might be a cycle, check to make sure we don't end up - # calling ourselves here. 
- infstate = sv - cyclei = 0 - while !(infstate === nothing) - if method === infstate.linfo.def && any(infstate.result.overridden_by_const) - return Any - end - if cyclei < length(infstate.callers_in_cycle) - cyclei += 1 - infstate = infstate.callers_in_cycle[cyclei] - else - cyclei = 0 - infstate = infstate.parent - end - end - end - inf_result = InferenceResult(mi, argtypes) - frame = InferenceState(inf_result, #=cache=#false, interp) - frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it - frame.parent = sv - push!(inf_cache, inf_result) - typeinf(interp, frame) || return Any - end - result = inf_result.result - # if constant inference hits a cycle, just bail out - isa(result, InferenceState) && return Any - add_backedge!(inf_result.linfo, sv) - return result -end - const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result." function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState) @@ -494,6 +365,168 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp return rt, edgecycle, edge end +function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), + @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + sv::InferenceState, edgecycle::Bool) + mi = maybe_get_const_prop_profitable(interp, rettype, f, argtypes, match, sv, edgecycle) + mi === nothing && return Any + # try constant prop' + inf_cache = get_inference_cache(interp) + inf_result = cache_lookup(mi, argtypes, inf_cache) + if inf_result === nothing + if edgecycle + # if there might be a cycle, check to make sure we don't end up + # calling ourselves here. 
+ infstate = sv + cyclei = 0 + while !(infstate === nothing) + if match.method === infstate.linfo.def && any(infstate.result.overridden_by_const) + return Any + end + if cyclei < length(infstate.callers_in_cycle) + cyclei += 1 + infstate = infstate.callers_in_cycle[cyclei] + else + cyclei = 0 + infstate = infstate.parent + end + end + end + inf_result = InferenceResult(mi, argtypes) + frame = InferenceState(inf_result, #=cache=#false, interp) + frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it + frame.parent = sv + push!(inf_cache, inf_result) + typeinf(interp, frame) || return Any + end + result = inf_result.result + # if constant inference hits a cycle, just bail out + isa(result, InferenceState) && return Any + add_backedge!(mi, sv) + return result +end + +# if there's a possibility we could get a better result (hopefully without doing too much work) +# returns `MethodInstance` with constant arguments, returns nothing otherwise +function maybe_get_const_prop_profitable(interp::AbstractInterpreter, @nospecialize(rettype), + @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + sv::InferenceState, edgecycle::Bool) + const_prop_entry_heuristic(interp, rettype, sv, edgecycle) || return nothing + method = match.method + nargs::Int = method.nargs + method.isva && (nargs -= 1) + length(argtypes) >= nargs || return nothing + const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, rettype) || return nothing + allconst = is_allconst(argtypes) + const_prop_function_heuristic(interp, f, argtypes, nargs, allconst) || return nothing + force = force_const_prop(interp, f, allconst) + mi = specialize_method(match, !force) + mi === nothing && return nothing + mi = mi::MethodInstance + force || const_prop_methodinstance_heuristic(interp, method, mi) || return nothing + return mi +end + +function const_prop_entry_heuristic(interp::AbstractInterpreter, @nospecialize(rettype), 
sv::InferenceState, edgecycle::Bool) + call_result_unused(sv) && edgecycle && return false + return is_improvable(rettype) && InferenceParams(interp).ipo_constant_propagation +end + +# see if propagating constants may be worthwhile +function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any}) + for a in argtypes + a = widenconditional(a) + if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a) + return true + end + end + return false +end + +function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype)) + return improvable_via_constant_propagation(rettype) +end + +function is_allconst(argtypes::Vector{Any}) + for a in argtypes + a = widenconditional(a) + if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) + return false + end + end + return true +end + +function is_const_prop_profitable_arg(@nospecialize(arg)) + # have new information from argtypes that wasn't available from the signature + if isa(arg, PartialStruct) + for b in arg.fields + isconstType(b) && return true + is_const_prop_profitable_arg(b) && return true + end + elseif !isa(arg, Const) || (isa(arg.val, Symbol) || isa(arg.val, Type) || (!isa(arg.val, String) && !ismutable(arg.val))) + # don't consider mutable values or Strings useful constants + return true + end + return false +end + +function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool) + if nargs > 1 + if istopfunction(f, :getindex) || istopfunction(f, :setindex!) 
+ arrty = argtypes[2] + # don't propagate constant index into indexing of non-constant array + if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) + return false + elseif arrty ⊑ Array + return false + end + elseif istopfunction(f, :iterate) + itrty = argtypes[2] + if itrty ⊑ Array + return false + end + end + end + if !allconst && (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || + istopfunction(f, :(==)) || istopfunction(f, :!=) || + istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) || + istopfunction(f, :<<) || istopfunction(f, :>>)) + return false + end + return true +end + +function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), allconst::Bool) + (allconst || InferenceParams(interp).aggressive_constant_propagation) && return true + return istopfunction(f, :getproperty) || istopfunction(f, :setproperty!) +end + +# This is a heuristic to avoid trying to const prop through complicated functions +# where we would spend a lot of time, but are probably unlikely to get an improved +# result anyway. +function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) + # Peek at the inferred result for the function to determine if the optimizer + # was able to cut it down to something simple (inlineable in particular). + # If so, there's a good chance we might be able to const prop all the way + # through and learn something new.
+ code = get(code_cache(interp), mi, nothing) + declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source) + cache_inlineable = declared_inline + if isdefined(code, :inferred) && !cache_inlineable + cache_inf = code.inferred + if !(cache_inf === nothing) + cache_src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), cache_inf) + cache_src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), cache_inf) + cache_inlineable = cache_src_inferred && cache_src_inlineable + end + end + if !cache_inlineable + return false + end + return true +end + # This is only for use with `Conditional`. # In general, usage of this is wrong. function ssa_def_slot(@nospecialize(arg), sv::InferenceState) diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index d5db2ab70ef1f4..5401b52582f964 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -180,10 +180,16 @@ function switchtupleunion(@nospecialize(ty)) return _switchtupleunion(Any[tparams...], length(tparams), [], ty) end +switchtupleunion(t::Vector{Any}) = _switchtupleunion(t, length(t), [], nothing) + function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt)) if i == 0 - tpl = rewrap_unionall(Tuple{t...}, origt) - push!(tunion, tpl) + if origt === nothing + push!(tunion, copy(t)) + else + tpl = rewrap_unionall(Tuple{t...}, origt) + push!(tunion, tpl) + end else ti = t[i] if isa(ti, Union) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 062d8c81cf7762..e6af1db5dab1fe 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3011,3 +3011,67 @@ g38888() = S38888(Base.inferencebarrier(3), nothing) f_inf_error_bottom(x::Vector) = isempty(x) ? 
error(x[1]) : x @test Core.Compiler.return_type(f_inf_error_bottom, Tuple{Vector{Any}}) == Vector{Any} + +@testset "constant prop' for union split signature" begin + anonymous_module() = Core.eval(@__MODULE__, :(module $(gensym()) end))::Module + + # indexing into tuples really relies on constant prop', and we will get looser result + # (`Union{Int,String,Char}`) if constant prop' doesn't happen for splitunion signatures + tt = (Union{Tuple{Int,String},Tuple{Int,Char}},) + @test Base.return_types(tt) do t + getindex(t, 1) + end == Any[Int] + @test Base.return_types(tt) do t + getindex(t, 2) + end == Any[Union{String,Char}] + @test Base.return_types(tt) do t + a, b = t + a + end == Any[Int] + @test Base.return_types(tt) do t + a, b = t + b + end == Any[Union{String,Char}] + + @test (@eval anonymous_module() begin + struct F32 + val::Float32 + _v::Int + end + struct F64 + val::Float64 + _v::Int + end + Base.return_types((Union{F32,F64},)) do f + f.val + end + end) == Any[Union{Float32,Float64}] + + @test (@eval anonymous_module() begin + struct F32 + val::Float32 + _v + end + struct F64 + val::Float64 + _v + end + Base.return_types((Union{F32,F64},)) do f + f.val + end + end) == Any[Union{Float32,Float64}] + + @test Base.return_types((Union{Tuple{Nothing,Any,Any},Tuple{Nothing,Any}},)) do t + getindex(t, 1) + end == Any[Nothing] + + # issue #37610 + @test Base.return_types((typeof(("foo" => "bar", "baz" => nothing)), Int)) do a, i + y = iterate(a, i) + if y !== nothing + (k, v), st = y + return k, v + end + return y + end == Any[Union{Nothing, Tuple{String, Union{Nothing, String}}}] +end