diff --git a/base/inference.jl b/base/inference.jl index 7f5e66182ad18f..e47ff5a36e46aa 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -1253,8 +1253,34 @@ end #### recursing into expression #### +# take a Tuple where one or more parameters are Unions +# and return an array such that those Unions are removed +# and `Union{return...} == ty` +function switchtupleunion(ty::ANY) + tparams = (unwrap_unionall(ty)::DataType).parameters + return _switchtupleunion(Any[tparams...], length(tparams), [], ty) +end + +function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, origt::ANY) + if i == 0 + tpl = rewrap_unionall(Tuple{t...}, origt) + push!(tunion, tpl) + else + ti = t[i] + if isa(ti, Union) + for ty in uniontypes(ti::Union) + t[i] = ty + _switchtupleunion(t, i - 1, tunion, origt) + end + t[i] = ti + else + _switchtupleunion(t, i - 1, tunion, origt) + end + end + return tunion +end + function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState) - tm = _topmod(sv) # don't consider more than N methods. this trades off between # compiler performance and generated code performance. # typically, considering many methods means spending lots of time @@ -1282,136 +1308,163 @@ function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState) end min_valid = UInt[typemin(UInt)] max_valid = UInt[typemax(UInt)] - applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid) - rettype = Bottom - if applicable === false - # this means too many methods matched - return Any + splitunions = 1 < countunionsplit(argtypes) <= sv.params.MAX_UNION_SPLITTING + if splitunions + splitsigs = switchtupleunion(argtype) + applicable = Any[] + for sig_n in splitsigs + xapplicable = _methods_by_ftype(sig_n, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid) + xapplicable === false && return Any + append!(applicable, xapplicable) + end + else + applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid) + if applicable === false + # this means too many methods matched + return Any + end end applicable = applicable::Array{Any,1} + napplicable = length(applicable) fullmatch = false - for (m::SimpleVector) in applicable - sig = m[1] - sigtuple = unwrap_unionall(sig)::DataType - method = m[3]::Method - sparams = m[2]::SimpleVector - recomputesvec = false + rettype = Bottom + for i in 1:napplicable + match = applicable[i]::SimpleVector + method = match[3]::Method if !fullmatch && (argtype <: method.sig) fullmatch = true end + sig = match[1] + sigtuple = unwrap_unionall(sig)::DataType + splitunions = 1 < countunionsplit(sigtuple.parameters) * napplicable <= sv.params.MAX_UNION_SPLITTING + if splitunions + splitsigs = switchtupleunion(sig) + for sig_n in splitsigs + rt = abstract_call_method(method, f, sig_n, svec(), sv) + rettype = tmerge(rettype, rt) + rettype === Any && break + end + rettype === Any && break + else + rt = abstract_call_method(method, f, sig, match[2]::SimpleVector, sv) + rettype = tmerge(rettype, rt) + rettype === Any && break + end + end + if !(fullmatch || rettype === Any) + # also need an edge to the method table in case something gets + # added that did not intersect with any existing method + add_mt_backedge(ftname.mt, argtype, sv) + update_valid_age!(min_valid[1], max_valid[1], sv) + end + #print("=> ", rettype, "\n") + return rettype +end - # limit argument type tuple growth - msig = unwrap_unionall(method.sig) - lsig = length(msig.parameters) - ls = length(sigtuple.parameters) - td = type_depth(sig) - mightlimitlength = ls > lsig + 1 - mightlimitdepth = td > 2 - limitlength = false - if mightlimitlength || mightlimitdepth - # TODO: FIXME: this heuristic depends on non-local state making type-inference unpredictable - cyclei = 0 - infstate = sv - while infstate !== nothing - infstate = infstate::InferenceState - if isdefined(infstate.linfo, :def) && method === infstate.linfo.def - if mightlimitlength && ls > length(unwrap_unionall(infstate.linfo.specTypes).parameters) - limitlength = true - end - if mightlimitdepth && td > type_depth(infstate.linfo.specTypes) - # impose limit if we recur and the argument types grow beyond MAX_TYPE_DEPTH - if td > MAX_TYPE_DEPTH - sig = limit_type_depth(sig, 0) - sigtuple = unwrap_unionall(sig) - recomputesvec = true - break - else - p1, p2 = sigtuple.parameters, unwrap_unionall(infstate.linfo.specTypes).parameters - if length(p2) == ls - limitdepth = false - newsig = Vector{Any}(ls) - for i = 1:ls - if p1[i] <: Function && type_depth(p1[i]) > type_depth(p2[i]) && - isa(p1[i],DataType) - # if a Function argument is growing (e.g. nested closures) - # then widen to the outermost function type. without this - # inference fails to terminate on do_quadgk. - newsig[i] = p1[i].name.wrapper - limitdepth = true - else - newsig[i] = limit_type_depth(p1[i], 1) - end - end - if limitdepth - sigtuple = Tuple{newsig...} - sig = rewrap_unionall(sigtuple, sig) - recomputesvec = true - break +function abstract_call_method(method::Method, f::ANY, sig::ANY, sparams::SimpleVector, sv::InferenceState) + sigtuple = unwrap_unionall(sig)::DataType + recomputesvec = false + + # limit argument type tuple growth + msig = unwrap_unionall(method.sig) + lsig = length(msig.parameters) + ls = length(sigtuple.parameters) + td = type_depth(sig) + mightlimitlength = ls > lsig + 1 + mightlimitdepth = td > 2 + limitlength = false + if mightlimitlength || mightlimitdepth + # TODO: FIXME: this heuristic depends on non-local state making type-inference unpredictable + cyclei = 0 + infstate = sv + while infstate !== nothing + infstate = infstate::InferenceState + if isdefined(infstate.linfo, :def) && method === infstate.linfo.def + if mightlimitlength && ls > length(unwrap_unionall(infstate.linfo.specTypes).parameters) + limitlength = true + end + if mightlimitdepth && td > type_depth(infstate.linfo.specTypes) + # impose limit if we recur and the argument types grow beyond MAX_TYPE_DEPTH + if td > MAX_TYPE_DEPTH + sig = limit_type_depth(sig, 0) + sigtuple = unwrap_unionall(sig) + recomputesvec = true + break + else + p1, p2 = sigtuple.parameters, unwrap_unionall(infstate.linfo.specTypes).parameters + if length(p2) == ls + limitdepth = false + newsig = Vector{Any}(ls) + for i = 1:ls + if p1[i] <: Function && type_depth(p1[i]) > type_depth(p2[i]) && + isa(p1[i],DataType) + # if a Function argument is growing (e.g. nested closures) + # then widen to the outermost function type. without this + # inference fails to terminate on do_quadgk. + newsig[i] = p1[i].name.wrapper + limitdepth = true + else + newsig[i] = limit_type_depth(p1[i], 1) end end + if limitdepth + sigtuple = Tuple{newsig...} + sig = rewrap_unionall(sigtuple, sig) + recomputesvec = true + break + end end end end - # iterate through the cycle before walking to the parent - if cyclei < length(infstate.callers_in_cycle) - cyclei += 1 - infstate = infstate.callers_in_cycle[cyclei] - else - cyclei = 0 - infstate = infstate.parent - end end - end - - # limit length based on size of definition signature. - # for example, given function f(T, Any...), limit to 3 arguments - # instead of the default (MAX_TUPLETYPE_LEN) - if limitlength - if !istopfunction(tm, f, :promote_typeof) - fst = sigtuple.parameters[lsig + 1] - allsame = true - # allow specializing on longer arglists if all the trailing - # arguments are the same, since there is no exponential - # blowup in this case. - for i = (lsig + 2):ls - if sigtuple.parameters[i] != fst - allsame = false - break - end - end - if !allsame - sigtuple = limit_tuple_type_n(sigtuple, lsig + 1) - sig = rewrap_unionall(sigtuple, sig) - recomputesvec = true + # iterate through the cycle before walking to the parent + if cyclei < length(infstate.callers_in_cycle) + cyclei += 1 + infstate = infstate.callers_in_cycle[cyclei] + else + cyclei = 0 + infstate = infstate.parent + end + end + end + + # limit length based on size of definition signature. + # for example, given function f(T, Any...), limit to 3 arguments + # instead of the default (MAX_TUPLETYPE_LEN) + if limitlength + tm = _topmod(sv) + if !istopfunction(tm, f, :promote_typeof) + fst = sigtuple.parameters[lsig + 1] + allsame = true + # allow specializing on longer arglists if all the trailing + # arguments are the same, since there is no exponential + # blowup in this case. + for i = (lsig + 2):ls + if sigtuple.parameters[i] != fst + allsame = false + break end end - end - - # if sig changed, may need to recompute the sparams environment - if recomputesvec && !isempty(sparams) - recomputed = ccall(:jl_env_from_type_intersection, Ref{SimpleVector}, (Any, Any), sig, method.sig) - sig = recomputed[1] - if !isa(unwrap_unionall(sig), DataType) # probably Union{} - rettype = Any - break + if !allsame + sigtuple = limit_tuple_type_n(sigtuple, lsig + 1) + sig = rewrap_unionall(sigtuple, sig) + recomputesvec = true end - sparams = recomputed[2]::SimpleVector - end - rt, edge = typeinf_edge(method, sig, sparams, sv) - edge !== nothing && add_backedge!(edge::MethodInstance, sv) - rettype = tmerge(rettype, rt) - if rettype === Any - break end end - if !(fullmatch || rettype === Any) - # also need an edge to the method table in case something gets - # added that did not intersect with any existing method - add_mt_backedge(ftname.mt, argtype, sv) - update_valid_age!(min_valid[1], max_valid[1], sv) + + # if sig changed, may need to recompute the sparams environment + if isa(method.sig, UnionAll) && (recomputesvec || isempty(sparams)) + recomputed = ccall(:jl_env_from_type_intersection, Ref{SimpleVector}, (Any, Any), sig, method.sig) + sig = recomputed[1] + if !isa(unwrap_unionall(sig), DataType) # probably Union{} + return Any + end + sparams = recomputed[2]::SimpleVector end - #print("=> ", rettype, "\n") - return rettype + rt, edge = typeinf_edge(method, sig, sparams, sv) + edge !== nothing && add_backedge!(edge::MethodInstance, sv) + return rt end # determine whether `ex` abstractly evals to constant `c` @@ -1562,6 +1615,9 @@ function abstract_apply(aft::ANY, fargs::Vector{Any}, aargtypes::Vector{Any}, vt return res end +# TODO: this function is a very buggy and poor model of the return_type function +# since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type, +# while this assumes that it is a precisely accurate and exact model of both function return_type_tfunc(argtypes::ANY, vtypes::VarTable, sv::InferenceState) if length(argtypes) == 3 tt = argtypes[3] @@ -2112,8 +2168,10 @@ function issubconditional(a::Conditional, b::Conditional) end function ⊑(a::ANY, b::ANY) - a === NF && return true - b === NF && return false + (a === NF || b === Any) && return true + (a === Any || b === NF) && return false + a === Union{} && return true + b === Union{} && return false if isa(a, Conditional) if isa(b, Conditional) return issubconditional(a, b) @@ -3483,7 +3541,7 @@ function is_self_quoting(x::ANY) return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type) end -function countunionsplit(atypes::Vector{Any}) +function countunionsplit(atypes) nu = 1 for ti in atypes if isa(ti, Union) diff --git a/base/reflection.jl b/base/reflection.jl index eb8e279e119778..410611b2161cf6 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -507,46 +507,9 @@ function _methods_by_ftype(t::ANY, lim::Int, world::UInt) return _methods_by_ftype(t, lim, world, UInt[typemin(UInt)], UInt[typemax(UInt)]) end function _methods_by_ftype(t::ANY, lim::Int, world::UInt, min::Array{UInt,1}, max::Array{UInt,1}) - tp = unwrap_unionall(t).parameters::SimpleVector - nu = 1 - for ti in tp - if isa(ti, Union) - nu *= unionlen(ti::Union) - end - end - if 1 < nu <= 64 - return _methods_by_ftype(Any[tp...], t, length(tp), lim, [], world, min, max) - end - # XXX: the following can return incorrect answers that the above branch would have corrected return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}), t, lim, 0, world, min, max) end -function _methods_by_ftype(t::Array, origt::ANY, i, lim::Integer, matching::Array{Any,1}, - world::UInt, min::Array{UInt,1}, max::Array{UInt,1}) - if i == 0 - world = typemax(UInt) - new = ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}), - rewrap_unionall(Tuple{t...}, origt), lim, 0, world, min, max) - new === false && return false - append!(matching, new::Array{Any,1}) - else - ti = t[i] - if isa(ti, Union) - for ty in uniontypes(ti::Union) - t[i] = ty - if _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max) === false - t[i] = ti - return false - end - end - t[i] = ti - else - return _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max) - end - end - return matching -end - # high-level, more convenient method lookup functions # type for reflecting and pretty-printing a subset of methods