From 4a9c52e1c30a1622520a6325a3c2dac18e358f32 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Thu, 27 May 2021 05:05:27 +0900 Subject: [PATCH] inference: Relax constprop recursion detection (#40561) At the moment, we restrict const prop whenever we see a cycle in methods being called. However, I think this condition can be relaxed slightly: In particular, if the type complexity limiting did not decide to limit the growth of the type in question, I think it should be fine to constant prop as long as there is no cycle in *method instances* (rather than just methods). Fixes #39915, replaces #39918 Co-authored-by: Keno Fisher --- base/compiler/abstractinterpretation.jl | 32 +++++++++++++++---------- test/compiler/inference.jl | 22 +++++++++++++++++ 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 465948c610c42d..553344bbde0f00 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -137,12 +137,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), if splitunions splitsigs = switchtupleunion(sig) for sig_n in splitsigs - rt, edgecycle, edge = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv) + rt, edgecycle, edgelimited, edge = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv) if edge !== nothing push!(edges, edge) end this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] - const_rt, const_result = abstract_call_method_with_const_args(interp, rt, f, this_argtypes, match, sv, edgecycle, false) + const_rt, const_result = abstract_call_method_with_const_args(interp, rt, f, this_argtypes, match, sv, edgecycle, edgelimited, false) if const_rt !== rt && const_rt ⊑ rt rt = const_rt end @@ -156,14 +156,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end end else - this_rt, edgecycle, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) + this_rt, edgecycle, edgelimited, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) if edge !== nothing push!(edges, edge) end # try constant propagation with argtypes for this match # this is in preparation for inlining, or improving the return result this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] - const_this_rt, const_result = abstract_call_method_with_const_args(interp, this_rt, f, this_argtypes, match, sv, edgecycle, false) + const_this_rt, const_result = abstract_call_method_with_const_args(interp, this_rt, f, this_argtypes, match, sv, edgecycle, edgelimited, false) if const_this_rt !== this_rt && const_this_rt ⊑ this_rt this_rt = const_this_rt end @@ -312,7 +312,7 @@ const RECURSION_MSG = "Bounded recursion detected. Call was widened to force con function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState) if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base add_remark!(interp, sv, "Refusing to infer into `depwarn`") - return Any, false, nothing + return Any, false, false, nothing end topmost = nothing # Limit argument type tuple growth of functions: @@ -320,6 +320,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp # and from the same method. # Returns the topmost occurrence of that repeated edge. edgecycle = false + edgelimited = false # The `method_for_inference_heuristics` will expand the given method's generator if # necessary in order to retrieve this field from the generated `CodeInfo`, if it exists. # The other `CodeInfo`s we inspect will already have this field inflated, so we just @@ -380,7 +381,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp # we have a self-cycle in the call-graph, but not in the inference graph (typically): # break this edge now (before we record it) by returning early # (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases) - return Any, true, nothing + return Any, true, true, nothing end topmost = nothing edgecycle = true @@ -429,7 +430,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp # since it's very unlikely that we'll try to inline this, # or want make an invoke edge to its calling convention return type. # (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases) - return Any, true, nothing + return Any, true, true, nothing end add_remark!(interp, sv, RECURSION_MSG) topmost = topmost::InferenceState @@ -437,6 +438,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp poison_callstack(sv, parentframe === nothing ? topmost : parentframe) sig = newsig sparams = svec() + edgelimited = true end end @@ -468,14 +470,14 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp rt, edge = typeinf_edge(interp, method, sig, sparams, sv) if edge === nothing - edgecycle = true + edgecycle = edgelimited = true end - return rt, edgecycle, edge + return rt, edgecycle, edgelimited, edge end function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, - sv::InferenceState, edgecycle::Bool, + sv::InferenceState, edgecycle::Bool, edgelimited::Bool, va_override::Bool) mi = maybe_get_const_prop_profitable(interp, rettype, f, argtypes, match, sv, edgecycle) mi === nothing && return Any, nothing @@ -486,7 +488,11 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp # if there might be a cycle, check to make sure we don't end up # calling ourselves here. if edgecycle && _any(InfStackUnwind(sv)) do infstate - return match.method === infstate.linfo.def && any(infstate.result.overridden_by_const) + # if the type complexity limiting didn't decide to limit the call signature (`edgelimited = false`) + # we can relax the cycle detection by comparing `MethodInstance`s and allow inference to + # propagate different constant elements if the recursion is finite over the lattice + return (edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) && + any(infstate.result.overridden_by_const) end add_remark!(interp, sv, "[constprop] Edge cycle encountered") return Any, nothing @@ -1228,7 +1234,7 @@ end function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState) pushfirst!(argtypes, closure.env) sig = argtypes_to_type(argtypes) - rt, edgecycle, edge = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv) + rt, edgecycle, edgelimited, edge = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv) edge !== nothing && add_backedge!(edge, sv) tt = closure.typ sigT = unwrap_unionall(tt).parameters[1] @@ -1236,7 +1242,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part info = OpaqueClosureCallInfo(match) if !edgecycle const_rettype, result = abstract_call_method_with_const_args(interp, rt, closure, argtypes, - match, sv, edgecycle, closure.isva) + match, sv, edgecycle, edgelimited, closure.isva) if const_rettype ⊑ rt rt = const_rettype end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index ff48b428852afd..8cb247bf958ebe 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3288,3 +3288,25 @@ end == [Union{Some{Float64}, Some{Int}, Some{UInt8}}] true end end + +# Make sure that const prop doesn't fall into cycles that aren't problematic +# in the type domain +f_recurse(x) = x > 1000000 ? x : f_recurse(x+1) +@test Base.return_types() do + f_recurse(1) +end |> first === Int + +# issue #39915 +function f33915(a_tuple, which_ones) + rest = f33915(Base.tail(a_tuple), Base.tail(which_ones)) + if first(which_ones) + (first(a_tuple), rest...) + else + rest + end +end +f33915(a_tuple::Tuple{}, which_ones::Tuple{}) = () +g39915(a_tuple) = f33915(a_tuple, (true, false, true, false)) +@test Base.return_types() do + g39915((1, 1.0, "a", :a)) +end |> first === Tuple{Int, String}