From 046f11ed278c5f23fec60a697a5636e78304922d Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Tue, 1 Jun 2021 06:28:19 +0900 Subject: [PATCH] AbstractInterpreter: refactor for `abstract_call_method` -> `abstract_call_method_with_const_args` chain (#41020) This PR refactors the `abstract_call_method` -> `abstract_call_method_with_const_args` chain, and simplifies the signature of `abstract_call_method_with_const_args`: the newly defined `MethodCallResult` struct wraps a result and context information of `abstract_method_call`, and is passed and consumed by the succeeding `abstract_call_method_with_const_args`. Although our constant-propagation heuristic will be likely to change in the future (as in #40561) and so the signature of `abstract_call_method_with_const_args` is very unstable, hopefully this PR makes it a bit more stable. As an additional benefit, now an external `AbstractInterpreter` can use the context information of `abstract_method_call` (especially `edge::MethodInstance`) within `maybe_get_const_prop_profitable`. --- base/compiler/abstractinterpretation.jl | 73 +++++++++++++++---------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 553344bbde0f0..9f4ca7ac7a4f7 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -137,12 +137,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), if splitunions splitsigs = switchtupleunion(sig) for sig_n in splitsigs - rt, edgecycle, edgelimited, edge = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv) + result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv) + rt, edge = result.rt, result.edge if edge !== nothing push!(edges, edge) end this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] - const_rt, const_result = abstract_call_method_with_const_args(interp, rt, f, this_argtypes, match, sv, edgecycle, edgelimited, false) + const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) if const_rt !== rt && const_rt ⊑ rt rt = const_rt end @@ -156,14 +157,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end end else - this_rt, edgecycle, edgelimited, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) + result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) + this_rt, edge = result.rt, result.edge if edge !== nothing push!(edges, edge) end # try constant propagation with argtypes for this match # this is in preparation for inlining, or improving the return result this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] - const_this_rt, const_result = abstract_call_method_with_const_args(interp, this_rt, f, this_argtypes, match, sv, edgecycle, edgelimited, false) + const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) if const_this_rt !== this_rt && const_this_rt ⊑ this_rt this_rt = const_this_rt end @@ -312,7 +314,7 @@ const RECURSION_MSG = "Bounded recursion detected. Call was widened to force con function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState) if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base add_remark!(interp, sv, "Refusing to infer into `depwarn`") - return Any, false, false, nothing + return MethodCallResult(Any, false, false, nothing) end topmost = nothing # Limit argument type tuple growth of functions: @@ -381,7 +383,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp # we have a self-cycle in the call-graph, but not in the inference graph (typically): # break this edge now (before we record it) by returning early # (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases) - return Any, true, true, nothing + return MethodCallResult(Any, true, true, nothing) end topmost = nothing edgecycle = true @@ -430,7 +432,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp # since it's very unlikely that we'll try to inline this, # or want make an invoke edge to its calling convention return type. # (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases) - return Any, true, true, nothing + return MethodCallResult(Any, true, true, nothing) end add_remark!(interp, sv, RECURSION_MSG) topmost = topmost::InferenceState @@ -472,14 +474,27 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp if edge === nothing edgecycle = edgelimited = true end - return rt, edgecycle, edgelimited, edge + return MethodCallResult(rt, edgecycle, edgelimited, edge) end -function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), +# keeps result and context information of abstract method call, will be used by succeeding constant-propagation +struct MethodCallResult + rt + edgecycle::Bool + edgelimited::Bool + edge::Union{Nothing,MethodInstance} + function MethodCallResult(@nospecialize(rt), + edgecycle::Bool, + edgelimited::Bool, + edge::Union{Nothing,MethodInstance}) + return new(rt, edgecycle, edgelimited, edge) + end +end + +function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult, @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, - sv::InferenceState, edgecycle::Bool, edgelimited::Bool, - va_override::Bool) - mi = maybe_get_const_prop_profitable(interp, rettype, f, argtypes, match, sv, edgecycle) + sv::InferenceState, va_override::Bool) + mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv) mi === nothing && return Any, nothing # try constant prop' inf_cache = get_inference_cache(interp) @@ -487,12 +502,12 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp if inf_result === nothing # if there might be a cycle, check to make sure we don't end up # calling ourselves here. - if edgecycle && _any(InfStackUnwind(sv)) do infstate - # if the type complexity limiting didn't decide to limit the call signature (`edgelimited = false`) + if result.edgecycle && _any(InfStackUnwind(sv)) do infstate + # if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`) # we can relax the cycle detection by comparing `MethodInstance`s and allow inference to # propagate different constant elements if the recursion is finite over the lattice - return (edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) && - any(infstate.result.overridden_by_const) + return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) && + any(infstate.result.overridden_by_const) end add_remark!(interp, sv, "[constprop] Edge cycle encountered") return Any, nothing @@ -513,17 +528,17 @@ end # if there's a possibility we could get a better result (hopefully without doing too much work) # returns `MethodInstance` with constant arguments, returns nothing otherwise -function maybe_get_const_prop_profitable(interp::AbstractInterpreter, @nospecialize(rettype), +function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult, @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, - sv::InferenceState, edgecycle::Bool) - const_prop_entry_heuristic(interp, rettype, sv, edgecycle) || return nothing + sv::InferenceState) + const_prop_entry_heuristic(interp, result, sv) || return nothing method = match.method nargs::Int = method.nargs method.isva && (nargs -= 1) if length(argtypes) < nargs return nothing end - const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, rettype) || return nothing + const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, result.rt) || return nothing allconst = is_allconst(argtypes) force = force_const_prop(interp, f, method) force || const_prop_function_heuristic(interp, f, argtypes, nargs, allconst) || return nothing @@ -541,9 +556,9 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, @nospecial return mi end -function const_prop_entry_heuristic(interp::AbstractInterpreter, @nospecialize(rettype), sv::InferenceState, edgecycle::Bool) - call_result_unused(sv) && edgecycle && return false - return is_improvable(rettype) && InferenceParams(interp).ipo_constant_propagation +function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodCallResult, sv::InferenceState) + call_result_unused(sv) && result.edgecycle && return false + return is_improvable(result.rt) && InferenceParams(interp).ipo_constant_propagation end # see if propagating constants may be worthwhile @@ -1234,20 +1249,20 @@ end function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState) pushfirst!(argtypes, closure.env) sig = argtypes_to_type(argtypes) - rt, edgecycle, edgelimited, edge = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv) + (; rt, edge) = result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv) edge !== nothing && add_backedge!(edge, sv) tt = closure.typ sigT = unwrap_unionall(tt).parameters[1] match = MethodMatch(sig, Core.svec(), closure.source::Method, sig <: rewrap_unionall(sigT, tt)) info = OpaqueClosureCallInfo(match) - if !edgecycle - const_rettype, result = abstract_call_method_with_const_args(interp, rt, closure, argtypes, - match, sv, edgecycle, edgelimited, closure.isva) + if !result.edgecycle + const_rettype, const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes, + match, sv, closure.isva) if const_rettype ⊑ rt rt = const_rettype end - if result !== nothing - info = ConstCallInfo(info, Union{Nothing,InferenceResult}[result]) + if const_result !== nothing + info = ConstCallInfo(info, Union{Nothing,InferenceResult}[const_result]) end end return CallMeta(rt, info)