Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax constprop recursion detection #39918

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
applicable = applicable::Array{Any,1}
napplicable = length(applicable)
rettype = Bottom
edgecycle = false
edgecycle = edgelimited = false
edges = MethodInstance[]
conditionals = nothing # keeps refinement information of call argument types when the return type is boolean
nonbot = 0 # the index of the only non-Bottom inference result if > 0
Expand Down Expand Up @@ -130,8 +130,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if splitunions
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
rt, edgecycle1, edge = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
rt, edgecycle1, edgelimited1, edge = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
edgecycle |= edgecycle1::Bool
edgelimited |= edgelimited1::Bool
if edge !== nothing
push!(edges, edge)
end
Expand All @@ -141,8 +142,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end
else
this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv)
this_rt, edgecycle1, edgelimited1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv)
edgecycle |= edgecycle1::Bool
edgelimited |= edgelimited1::Bool
if edge !== nothing
push!(edges, edge)
end
Expand Down Expand Up @@ -192,7 +194,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# (hopefully without doing too much work), try to do that now
# TODO: refactor this, enable constant propagation for each (union-split) signature
match = applicable[nonbot]::MethodMatch
const_rettype, result = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle)
const_rettype, result = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle, edgelimited)
const_conditional = ignorelimited(const_rettype)
@assert !(const_conditional isa Conditional) "invalid lattice element returned from inter-procedural context"
const_rettype = widenwrappedconditional(const_rettype)
Expand Down Expand Up @@ -374,7 +376,7 @@ function const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::M
return true
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, sv::InferenceState, edgecycle::Bool)
function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nospecialize(rettype), @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, sv::InferenceState, edgecycle::Bool, edgelimited::Bool)
method = match.method
nargs::Int = method.nargs
method.isva && (nargs -= 1)
Expand Down Expand Up @@ -456,7 +458,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
infstate = sv
cyclei = 0
while !(infstate === nothing)
if method === infstate.linfo.def && any(infstate.result.overridden_by_const)
if (edgelimited ? method === infstate.linfo.def : mi === infstate.linfo) && any(infstate.result.overridden_by_const)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to make sense in general, but I'd need to think hard through the specifics here to make sure it is good.

So, here we are checking with mi, which isn't normally a limited object, and thus would normally not be valid here for cycle detection (thus also why we don't use it originally in cycle detection). But I guess the concept is that abstract_call_gf_by_type is expected to deal with that possibility (via setting edgelimited when required), so we won't get here if this would be unrestricted in complexity, and making this === equivalent to <=?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Basically, the intuition is that the same logic that prevents infinite recursion over the type lattice should also restrict infinite recursion of the lattice that adjoins const elements, except that we can get additional recursion by calling the same method instance with different const elements (which is a cycle, but not infinite over the type lattice). There's an example that shows that case added in the tests (I verified the example stack overflows if you replace edgecycle by edgelimited above).

add_remark!(interp, sv, "[constprop] Edge cycle encountered")
return Any, nothing
end
Expand Down Expand Up @@ -488,7 +490,7 @@ const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Ann
function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
add_remark!(interp, sv, "Refusing to infer into `depwarn`")
return Any, false, nothing
return Any, false, false, nothing
end
topmost = nothing
# Limit argument type tuple growth of functions:
Expand All @@ -498,6 +500,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
cyclei = 0
infstate = sv
edgecycle = false
edgelimited = false
# The `method_for_inference_heuristics` will expand the given method's generator if
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
Expand All @@ -517,7 +520,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# we have a self-cycle in the call-graph, but not in the inference graph (typically):
# break this edge now (before we record it) by returning early
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return Any, true, nothing
return Any, true, true, nothing
end
topmost = nothing
edgecycle = true
Expand All @@ -542,7 +545,6 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason to delete these edgecycle = true ?
In these cases there is a cycle, and we still need to try recursion detection in abstract_call_method_with_const_args ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right - these were deleted by accident (leftover from an earlier experiment) - thanks.

break
end
end
Expand All @@ -554,7 +556,6 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if (parent.cached || parent.parent !== nothing) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
end
end
end
Expand Down Expand Up @@ -599,13 +600,14 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# since it's very unlikely that we'll try to inline this,
# or want make an invoke edge to its calling convention return type.
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return Any, true, nothing
return Any, true, true, nothing
end
topmost = topmost::InferenceState
parentframe = topmost.parent
poison_callstack(sv, parentframe === nothing ? topmost : parentframe)
sig = newsig
sparams = svec()
edgelimited = true
end
end

Expand Down Expand Up @@ -637,9 +639,9 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp

rt, edge = typeinf_edge(interp, method, sig, sparams, sv)
if edge === nothing
edgecycle = true
edgecycle = edgelimited = true
end
return rt, edgecycle, edge
return rt, edgecycle, edgelimited, edge
end

# This is only for use with `Conditional`.
Expand Down
27 changes: 27 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3106,3 +3106,30 @@ end == [Int]
let f() = Val(fieldnames(Complex{Int}))
@test @inferred(f()) === Val((:re,:im))
end

# Make sure that const prop doesn't fall into cycles that aren't problematic
# in the type domain
f_recurse(x) = x > 1000000 ? x : f_recurse(x+1)
g_recurse() = f_recurse(1)
Base.return_types(g_recurse, Tuple{})[1] == Int

# issue #39915
function f33915(a_tuple, which_ones)
rest = my_getindex(tail(a_tuple), tail(which_ones))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rest = my_getindex(tail(a_tuple), tail(which_ones))
rest = f33915(tail(a_tuple), tail(which_ones))

if first(which_ones)
(first(a_tuple), rest...)
else
rest
end
end

function f33915(a_tuple::Tuple{}, which_ones::Tuple{})
()
end

function g39915(a_tuple)
f33915(a_tuple, (true, false, true, false))
end

h39915() = g39915((1, 1.0, "a", :a))
Base.return_types(h39915, Tuple{})[1] == Tuple{Int, String}