From 4e3d6af3ff6237c83a60f836244b0225ff54d35b Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Mon, 15 Apr 2019 17:25:21 -0400 Subject: [PATCH] inference: accelerate type-limits under wide-recursion when we hit union-splitting, we need to ensure type limits are very aggressive and preferably also independent of the height of the recursion chain fix #31572 --- base/compiler/abstractinterpretation.jl | 55 ++++++++++++++----------- test/compiler/inference.jl | 41 +++++++++++++++++- 2 files changed, 70 insertions(+), 26 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 1aa80b8be10a3..f0b210b0d28c2 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -63,6 +63,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp nonbot = 0 # the index of the only non-Bottom inference result if > 0 seen = 0 # number of signatures actually inferred istoplevel = sv.linfo.def isa Module + any_splitunions = napplicable > 1 for i in 1:napplicable match = applicable[i]::SimpleVector method = match[3]::Method @@ -80,7 +81,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp if splitunions splitsigs = switchtupleunion(sig) for sig_n in splitsigs - rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), sv) + rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), any_splitunions, sv) if edge !== nothing push!(edges, edge) end @@ -89,7 +90,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp this_rt === Any && break end else - this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, sv) + this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, any_splitunions, sv) edgecycle |= edgecycle1::Bool if edge !== nothing push!(edges, edge) @@ -227,7 +228,7 @@ function abstract_call_method_with_const_args(@nospecialize(rettype), @nospecial return result end -function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, sv::InferenceState) +function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState) if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base return Any, false, nothing end @@ -266,30 +267,36 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing} if topmost === nothing && method2 === inf_method2 - # inspect the parent of this edge, - # to see if they are the same Method as sv - # in which case we'll need to ensure it is convergent - # otherwise, we don't - for parent in infstate.callers_in_cycle - # check in the cycle list first - # all items in here are mutual parents of all others - parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match - parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing} - if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 - topmost = infstate - edgecycle = true - break - end - end - let parent = infstate.parent - # then check the parent link - if topmost === nothing && parent !== nothing - parent = parent::InferenceState + if hardlimit + topmost = infstate + edgecycle = true + else + # if this is a soft limit, + # also inspect the parent of this edge, + # to see if they are the same Method as sv + # in which case we'll need to ensure it is convergent + # otherwise, we don't + for parent in infstate.callers_in_cycle + # check in the cycle list first + # all items in here are mutual parents of all others parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing} - if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 + if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 topmost = infstate edgecycle = true + break + end + end + let parent = infstate.parent + # then check the parent link + if topmost === nothing && parent !== nothing + parent = parent::InferenceState + parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match + parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing} + if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 + topmost = infstate + edgecycle = true + end end end end @@ -321,7 +328,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl comparison = method.sig end # see if the type is actually too big (relative to the caller), and limit it if required - newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len) + newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len) if newsig !== sig # continue inference, but note that we've limited parameter complexity diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 74211a8453746..c56ed5df2ff16 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1050,13 +1050,13 @@ copy_dims_out(out) = () copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...) copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...) @test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}] -@test all(m -> 20 < count_specializations(m) < 45, methods(copy_dims_out)) +@test all(m -> 4 < count_specializations(m) < 15, methods(copy_dims_out)) # currently about 5 copy_dims_pair(out) = () copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...) copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...) @test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}] -@test all(m -> 10 < count_specializations(m) < 35, methods(copy_dims_pair)) +@test all(m -> 5 < count_specializations(m) < 15, methods(copy_dims_pair)) # currently about 7 @test isdefined_tfunc(typeof(NamedTuple()), Const(0)) === Const(false) @test isdefined_tfunc(typeof(NamedTuple()), Const(1)) === Const(false) @@ -2348,3 +2348,40 @@ function gen_nodes(qty::Integer) :: AbstractNode end end @test count(==('}'), string(I31663.gen_nodes(50))) == 1275 + +# issue #31572 +struct MixedKeyDict{T<:Tuple} #<: AbstractDict{Any,Any} + dicts::T +end +Base.merge(f::Function, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...) +Base.merge(f, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...) +function _merge(f, res, d, others...) + ofsametype, remaining = _alloftype(Base.heads(d), ((),), others...) + return _merge(f, (res..., merge(f, ofsametype...)), Base.tail(d), remaining...) +end +_merge(f, res, ::Tuple{}, others...) = _merge(f, res, others...) +_merge(f, res, d) = MixedKeyDict((res..., d...)) +_merge(f, res, ::Tuple{}) = MixedKeyDict(res) +function _alloftype(ofdesiredtype::Tuple{Vararg{D}}, accumulated, d::Tuple{D,Vararg}, others...) where D + return _alloftype((ofdesiredtype..., first(d)), + (Base.front(accumulated)..., (last(accumulated)..., Base.tail(d)...), ()), + others...) +end +function _alloftype(ofdesiredtype, accumulated, d, others...) + return _alloftype(ofdesiredtype, + (Base.front(accumulated)..., (last(accumulated)..., first(d))), + Base.tail(d), others...) +end +function _alloftype(ofdesiredtype, accumulated, ::Tuple{}, others...) + return _alloftype(ofdesiredtype, + (accumulated..., ()), + others...) +end +_alloftype(ofdesiredtype, accumulated) = ofdesiredtype, Base.front(accumulated) +let + d = MixedKeyDict((Dict(1 => 3), Dict(4. => 2))) + e = MixedKeyDict((Dict(1 => 7), Dict(5. => 9))) + @test merge(+, d, e).dicts == (Dict(1 => 10), Dict(4.0 => 2, 5.0 => 9)) + f = MixedKeyDict((Dict(2 => 7), Dict(5. => 11))) + @test merge(+, d, e, f).dicts == (Dict(1 => 10, 2 => 7), Dict(4.0 => 2, 5.0 => 20)) +end