Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: accelerate type-limits under wide-recursion #31734

Merged
merged 1 commit into from
Apr 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 31 additions & 24 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
nonbot = 0 # the index of the only non-Bottom inference result if > 0
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
multiple_matches = napplicable > 1
for i in 1:napplicable
match = applicable[i]::SimpleVector
method = match[3]::Method
Expand All @@ -80,7 +81,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
if splitunions
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), sv)
rt, edgecycle1, edge = abstract_call_method(method, sig_n, svec(), multiple_matches, sv)
if edge !== nothing
push!(edges, edge)
end
Expand All @@ -89,7 +90,7 @@ function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nosp
this_rt === Any && break
end
else
this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, sv)
this_rt, edgecycle1, edge = abstract_call_method(method, sig, match[2]::SimpleVector, multiple_matches, sv)
edgecycle |= edgecycle1::Bool
if edge !== nothing
push!(edges, edge)
Expand Down Expand Up @@ -227,7 +228,7 @@ function abstract_call_method_with_const_args(@nospecialize(rettype), @nospecial
return result
end

function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, sv::InferenceState)
function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
return Any, false, nothing
end
Expand Down Expand Up @@ -266,30 +267,36 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
if topmost === nothing && method2 === inf_method2
# inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
# otherwise, we don't
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
break
end
end
let parent = infstate.parent
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
if hardlimit
topmost = infstate
edgecycle = true
else
# if this is a soft limit,
# also inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
# otherwise, we don't
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
break
end
end
let parent = infstate.parent
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
end
end
end
end
Expand Down Expand Up @@ -321,7 +328,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
comparison = method.sig
end
# see if the type is actually too big (relative to the caller), and limit it if required
newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)
newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, sv.params.TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)

if newsig !== sig
# continue inference, but note that we've limited parameter complexity
Expand Down
41 changes: 39 additions & 2 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1050,13 +1050,13 @@ copy_dims_out(out) = ()
copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...)
copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...)
@test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 20 < count_specializations(m) < 45, methods(copy_dims_out))
@test all(m -> 4 < count_specializations(m) < 15, methods(copy_dims_out)) # currently about 5

copy_dims_pair(out) = ()
copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...)
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...)
@test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 10 < count_specializations(m) < 35, methods(copy_dims_pair))
@test all(m -> 5 < count_specializations(m) < 15, methods(copy_dims_pair)) # currently about 7

@test isdefined_tfunc(typeof(NamedTuple()), Const(0)) === Const(false)
@test isdefined_tfunc(typeof(NamedTuple()), Const(1)) === Const(false)
Expand Down Expand Up @@ -2348,3 +2348,40 @@ function gen_nodes(qty::Integer) :: AbstractNode
end
end
@test count(==('}'), string(I31663.gen_nodes(50))) == 1275

# issue #31572
struct MixedKeyDict{T<:Tuple} #<: AbstractDict{Any,Any}
dicts::T
end
Base.merge(f::Function, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...)
Base.merge(f, d::MixedKeyDict, others::MixedKeyDict...) = _merge(f, (), d.dicts, (d->d.dicts).(others)...)
function _merge(f, res, d, others...)
ofsametype, remaining = _alloftype(Base.heads(d), ((),), others...)
return _merge(f, (res..., merge(f, ofsametype...)), Base.tail(d), remaining...)
end
_merge(f, res, ::Tuple{}, others...) = _merge(f, res, others...)
_merge(f, res, d) = MixedKeyDict((res..., d...))
_merge(f, res, ::Tuple{}) = MixedKeyDict(res)
function _alloftype(ofdesiredtype::Tuple{Vararg{D}}, accumulated, d::Tuple{D,Vararg}, others...) where D
return _alloftype((ofdesiredtype..., first(d)),
(Base.front(accumulated)..., (last(accumulated)..., Base.tail(d)...), ()),
others...)
end
function _alloftype(ofdesiredtype, accumulated, d, others...)
return _alloftype(ofdesiredtype,
(Base.front(accumulated)..., (last(accumulated)..., first(d))),
Base.tail(d), others...)
end
function _alloftype(ofdesiredtype, accumulated, ::Tuple{}, others...)
return _alloftype(ofdesiredtype,
(accumulated..., ()),
others...)
end
_alloftype(ofdesiredtype, accumulated) = ofdesiredtype, Base.front(accumulated)
let
d = MixedKeyDict((Dict(1 => 3), Dict(4. => 2)))
e = MixedKeyDict((Dict(1 => 7), Dict(5. => 9)))
@test merge(+, d, e).dicts == (Dict(1 => 10), Dict(4.0 => 2, 5.0 => 9))
f = MixedKeyDict((Dict(2 => 7), Dict(5. => 11)))
@test merge(+, d, e, f).dicts == (Dict(1 => 10, 2 => 7), Dict(4.0 => 2, 5.0 => 20))
end