Skip to content

Commit

Permalink
Merge pull request #23912 from JuliaLang/jn/infer-norecur-more
Browse files Browse the repository at this point in the history
inference: revise recursion detection algorithm
  • Loading branch information
vtjnash authored Oct 10, 2017
2 parents 546a801 + b89e88e commit 10d470d
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 158 deletions.
20 changes: 15 additions & 5 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,18 +628,28 @@ function _collect_indices(indsA, A)
copy!(B, CartesianRange(indices(B)), A, CartesianRange(indsA))
end

# define this as a macro so that the call to Inference
# gets inlined into the caller before recursion detection
# gets a chance to see it, so that recursive calls to the caller
# don't trigger the inference limiter
if isdefined(Core, :Inference)
_default_eltype(@nospecialize itrt) = Core.Inference.return_type(first, Tuple{itrt})
macro default_eltype(itrt)
return quote
Core.Inference.return_type(first, Tuple{$(esc(itrt))})
end
end
else
_default_eltype(@nospecialize itr) = Any
macro default_eltype(itrt)
return :(Any)
end
end

_array_for(::Type{T}, itr, ::HasLength) where {T} = Array{T,1}(Int(length(itr)::Integer))
_array_for(::Type{T}, itr, ::HasShape) where {T} = similar(Array{T}, indices(itr))

function collect(itr::Generator)
isz = iteratorsize(itr.iter)
et = _default_eltype(typeof(itr))
et = @default_eltype(typeof(itr))
if isa(isz, SizeUnknown)
return grow_to!(Array{et,1}(0), itr)
else
Expand All @@ -653,12 +663,12 @@ function collect(itr::Generator)
end

_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
grow_to!(_similar_for(c, _default_eltype(typeof(itr)), itr, isz), itr)
grow_to!(_similar_for(c, @default_eltype(typeof(itr)), itr, isz), itr)

function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
st = start(itr)
if done(itr,st)
return _similar_for(c, _default_eltype(typeof(itr)), itr, isz)
return _similar_for(c, @default_eltype(typeof(itr)), itr, isz)
end
v1, st = next(itr, st)
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
Expand Down
4 changes: 2 additions & 2 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ associative_with_eltype(DT_apply, kv, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv
associative_with_eltype(DT_apply, kv::Generator, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv)
associative_with_eltype(DT_apply, ::Type{Pair{K,V}}) where {K,V} = DT_apply(K, V)()
associative_with_eltype(DT_apply, ::Type) = DT_apply(Any, Any)()
associative_with_eltype(DT_apply::F, kv, t) where {F} = grow_to!(associative_with_eltype(DT_apply, _default_eltype(typeof(kv))), kv)
associative_with_eltype(DT_apply::F, kv, t) where {F} = grow_to!(associative_with_eltype(DT_apply, @default_eltype(typeof(kv))), kv)
function associative_with_eltype(DT_apply::F, kv::Generator, t) where F
T = _default_eltype(typeof(kv))
T = @default_eltype(typeof(kv))
if T <: Union{Pair, Tuple{Any, Any}} && _isleaftype(T)
return associative_with_eltype(DT_apply, kv, T)
end
Expand Down
282 changes: 174 additions & 108 deletions base/inference.jl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ Returns the sum of all elements of `A`, using the Kahan-Babuska-Neumaier compens
summation algorithm for additional accuracy.
"""
function sum_kbn(A)
T = _default_eltype(typeof(A))
T = @default_eltype(typeof(A))
c = r_promote(+, zero(T)::T)
i = start(A)
if done(A, i)
Expand Down
4 changes: 2 additions & 2 deletions base/set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ for sets of arbitrary objects.
"""
Set(itr) = Set{eltype(itr)}(itr)
function Set(g::Generator)
T = _default_eltype(typeof(g))
T = @default_eltype(typeof(g))
(_isleaftype(T) || T === Union{}) || return grow_to!(Set{T}(), g)
return Set{T}(g)
end
Expand Down Expand Up @@ -258,7 +258,7 @@ julia> unique(Real[1, 1.0, 2])
```
"""
function unique(itr)
T = _default_eltype(typeof(itr))
T = @default_eltype(typeof(itr))
out = Vector{T}()
seen = Set{T}()
i = start(itr)
Expand Down
43 changes: 20 additions & 23 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,37 +926,34 @@ end
# vectors/matrices in mixedargs in their orginal order, and such that the result of
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
@inline function capturescalars(f, mixedargs)
let makeargs = _capturescalars(mixedargs...),
parevalf = (passed...) -> f(makeargs(passed...)...),
passedsrcargstup = _capturenonscalars(mixedargs...)
let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...)
parevalf = (passed...) -> f(makeargs(passed...)...)
return (parevalf, passedsrcargstup)
end
end

@inline _capturenonscalars(nonscalararg::SparseVecOrMat, mixedargs...) =
(nonscalararg, _capturenonscalars(mixedargs...)...)
@inline _capturenonscalars(scalararg, mixedargs...) =
_capturenonscalars(mixedargs...)
@inline _capturenonscalars() = ()
nonscalararg(::SparseVecOrMat) = true
nonscalararg(::Any) = false

@inline _capturescalars(nonscalararg::SparseVecOrMat, mixedargs...) =
let f = _capturescalars(mixedargs...)
(head, tail...) -> (head, f(tail...)...) # pass-through
@inline function _capturescalars()
return (), () -> ()
end
@inline function _capturescalars(arg, mixedargs...)
let (rest, f) = _capturescalars(mixedargs...)
if nonscalararg(arg)
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
else
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
end
end
@inline _capturescalars(scalararg, mixedargs...) =
let f = _capturescalars(mixedargs...)
(tail...) -> (scalararg, f(tail...)...) # add scalararg
end
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
if nonscalararg(arg)
return (arg,), (head,) -> (head,) # pass-through
else
return (), () -> (arg,) # add scalararg
end
# TODO: use the implicit version once inference can handle it
# handle too-many-arguments explicitly
@inline function _capturescalars()
too_many_arguments() = ()
too_many_arguments(tail...) = throw(ArgumentError("too many"))
end
#@inline _capturescalars(nonscalararg::SparseVecOrMat) =
# (head,) -> (head,) # pass-through
#@inline _capturescalars(scalararg) =
# () -> (scalararg,) # add scalararg

# NOTE: The following two method definitions work around #19096.
broadcast(f::Tf, ::Type{T}, A::SparseMatrixCSC) where {Tf,T} = broadcast(y -> f(T, y), A)
Expand Down
25 changes: 18 additions & 7 deletions src/rtutils.c
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,12 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
else if (vt == jl_method_instance_type) {
jl_method_instance_t *li = (jl_method_instance_t*)v;
if (jl_is_method(li->def.method)) {
jl_method_t *m = li->def.method;
n += jl_static_show_x(out, (jl_value_t*)m->module, depth);
if (li->specTypes) {
n += jl_printf(out, ".");
n += jl_show_svec(out, ((jl_datatype_t*)jl_unwrap_unionall(li->specTypes))->parameters,
jl_symbol_name(m->name), "(", ")");
n += jl_static_show_func_sig(out, li->specTypes);
}
else {
jl_method_t *m = li->def.method;
n += jl_static_show_x(out, (jl_value_t*)m->module, depth);
n += jl_printf(out, ".%s(?)", jl_symbol_name(m->name));
}
}
Expand Down Expand Up @@ -949,15 +947,15 @@ JL_DLLEXPORT size_t jl_static_show_func_sig(JL_STREAM *s, jl_value_t *type)
if (ftype == NULL)
return jl_static_show(s, type);
size_t n = 0;
if (jl_nparams(ftype)==0 || ftype == ((jl_datatype_t*)ftype)->name->wrapper) {
if (jl_nparams(ftype) == 0 || ftype == ((jl_datatype_t*)ftype)->name->wrapper) {
n += jl_printf(s, "%s", jl_symbol_name(((jl_datatype_t*)ftype)->name->mt->name));
}
else {
n += jl_printf(s, "(::");
n += jl_static_show(s, ftype);
n += jl_printf(s, ")");
}
// TODO: better way to show method parameters
jl_unionall_t *tvars = (jl_unionall_t*)type;
type = jl_unwrap_unionall(type);
if (!jl_is_datatype(type)) {
n += jl_printf(s, " ");
Expand All @@ -984,6 +982,19 @@ JL_DLLEXPORT size_t jl_static_show_func_sig(JL_STREAM *s, jl_value_t *type)
}
}
n += jl_printf(s, ")");
if (jl_is_unionall(tvars)) {
int first = 1;
n += jl_printf(s, " where {");
while (jl_is_unionall(tvars)) {
if (first)
first = 0;
else
n += jl_printf(s, ", ");
n += jl_static_show(s, (jl_value_t*)tvars->var);
tvars = (jl_unionall_t*)tvars->body;
}
n += jl_printf(s, "}");
}
return n;
}

Expand Down
8 changes: 3 additions & 5 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3353,10 +3353,6 @@ end
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) == EmptyIIOtherField13175(EmptyImmutable13175(), 1.0)
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) != EmptyIIOtherField13175(EmptyImmutable13175(), 2.0)

# issue #13183
gg13183(x::X...) where {X} = 1==0 ? gg13183(x, x) : 0
@test gg13183(5) == 0

# issue 8932 (llvm return type legalizer error)
struct Vec3_8932
x::Float32
Expand Down Expand Up @@ -5331,7 +5327,8 @@ module UnionOptimizations
using Test

const boxedunions = [Union{}, Union{String, Void}]
const unboxedunions = [Union{Int8, Void}, Union{Int8, Float16, Void},
const unboxedunions = [Union{Int8, Void},
Union{Int8, Float16, Void},
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128},
Union{Char, Date, Int}]

Expand Down Expand Up @@ -5457,6 +5454,7 @@ t4 = vcat(A23567, t2, t3)
@test t4[11:15] == A23567

for U in unboxedunions
Base.unionlen(U) > 5 && continue # larger values cause subtyping to crash
local U
for N in (1, 2, 3, 4)
A = Array{U}(ntuple(x->0, N)...)
Expand Down
28 changes: 23 additions & 5 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@

# tests for Core.Inference correctness and precision
import Core.Inference: Const, Conditional,
const isleaftype = Core.Inference._isleaftype

# demonstrate some of the type-size limits
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref, Ref, 0) == Ref
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref{Complex{T} where T}, Ref, 0) == Ref{Complex{T} where T}
let comparison = Tuple{X, X} where X<:Tuple
sig = Tuple{X, X} where X<:comparison
ref = Tuple{X, X} where X
@test Core.Inference.limit_type_size(sig, comparison, comparison, 10) == comparison
@test Core.Inference.limit_type_size(sig, ref, comparison, 10) == comparison
@test Core.Inference.limit_type_size(Tuple{sig}, Tuple{ref}, comparison, 10) == Tuple{comparison}
@test Core.Inference.limit_type_size(sig, ref, Tuple{comparison}, 10) == sig
end


# issue 9770
@noinline x9770() = false
Expand Down Expand Up @@ -186,7 +200,6 @@ function find_tvar10930(arg)
end
@test find_tvar10930(Vararg{Int}) === 1

const isleaftype = Base._isleaftype

# issue #12474
@generated function f12474(::Any)
Expand Down Expand Up @@ -980,13 +993,13 @@ copy_dims_out(out) = ()
copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...)
copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...)
@test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 2 < count_specializations(m) < 15, methods(copy_dims_out))
@test all(m -> 10 < count_specializations(m) < 25, methods(copy_dims_out))

copy_dims_pair(out) = ()
copy_dims_pair(out, dim::Int, tail...) = copy_dims_out(out => dim, tail...)
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_out(out => dim, tail...)
copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...)
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...)
@test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 5 < count_specializations(m) < 25, methods(copy_dims_out))
@test all(m -> 5 < count_specializations(m) < 25, methods(copy_dims_pair))

# splatting an ::Any should still allow inference to use types of parameters preceding it
f22364(::Int, ::Any...) = 0
Expand Down Expand Up @@ -1225,3 +1238,8 @@ end
let t = Tuple{Type{T23786{D, N} where N where D<:Tuple{Vararg{Array{T, 1} where T, N} where N}}}
@test Core.Inference.limit_type_depth(t, 4) >: t
end

# issue #13183
_false13183 = false
gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0)
@test gg13183(5) == 0

0 comments on commit 10d470d

Please sign in to comment.