From 104538f4a74cef086d958a72653206b130e6235e Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Thu, 27 Sep 2018 20:34:26 -0400 Subject: [PATCH] fix #29269, type intersection bug in union parameters with typevars also fixes #25752 --- base/missing.jl | 4 ++ base/some.jl | 3 +- src/subtype.c | 71 +++++++++++++-------------- stdlib/LinearAlgebra/src/symmetric.jl | 9 +++- test/ambiguous.jl | 7 +++ test/subtype.jl | 19 ++++++- 6 files changed, 73 insertions(+), 40 deletions(-) diff --git a/base/missing.jl b/base/missing.jl index 672ae602d9fff..384771ca6ed40 100644 --- a/base/missing.jl +++ b/base/missing.jl @@ -27,6 +27,7 @@ nonmissingtype(::Type{Any}) = Any for U in (:Nothing, :Missing) @eval begin promote_rule(::Type{$U}, ::Type{T}) where {T} = Union{T, $U} + promote_rule(::Type{Union{S,$U}}, ::Type{Any}) where {S} = Any promote_rule(::Type{Union{S,$U}}, ::Type{T}) where {T,S} = Union{promote_type(T, S), $U} promote_rule(::Type{Any}, ::Type{$U}) = Any promote_rule(::Type{$U}, ::Type{Any}) = Any @@ -37,13 +38,16 @@ end promote_rule(::Type{Union{Nothing, Missing}}, ::Type{Any}) = Any promote_rule(::Type{Union{Nothing, Missing}}, ::Type{T}) where {T} = Union{Nothing, Missing, T} +promote_rule(::Type{Union{Nothing, Missing, S}}, ::Type{Any}) where {S} = Any promote_rule(::Type{Union{Nothing, Missing, S}}, ::Type{T}) where {T,S} = Union{Nothing, Missing, promote_type(T, S)} +convert(::Type{Union{T, Missing}}, x::Union{T, Missing}) where {T} = x convert(::Type{Union{T, Missing}}, x) where {T} = convert(T, x) # To fix ambiguities convert(::Type{Missing}, ::Missing) = missing convert(::Type{Union{Nothing, Missing}}, x::Union{Nothing, Missing}) = x +convert(::Type{Union{Nothing, Missing, T}}, x::Union{Nothing, Missing, T}) where {T} = x convert(::Type{Union{Nothing, Missing}}, x) = throw(MethodError(convert, (Union{Nothing, Missing}, x))) # To print more appropriate message than "T not defined" diff --git a/base/some.jl b/base/some.jl index fe4ced04e0bb4..c2f51bdf6a45b 100644 --- a/base/some.jl +++ b/base/some.jl @@ -18,9 +18,10 @@ promote_rule(::Type{Some{T}}, ::Type{Nothing}) where {T} = Union{Some{T}, Nothin convert(::Type{Some{T}}, x::Some) where {T} = Some{T}(convert(T, x.value)) convert(::Type{Union{Some{T}, Nothing}}, x::Some) where {T} = convert(Some{T}, x) +convert(::Type{Union{T, Nothing}}, x::Union{T, Nothing}) where {T} = x convert(::Type{Union{T, Nothing}}, x::Any) where {T} = convert(T, x) -convert(::Type{Nothing}, x::Any) = throw(MethodError(convert, (Nothing, x))) convert(::Type{Nothing}, x::Nothing) = nothing +convert(::Type{Nothing}, x::Any) = throw(MethodError(convert, (Nothing, x))) function show(io::IO, x::Some) if get(io, :typeinfo, Any) == typeof(x) diff --git a/src/subtype.c b/src/subtype.c index 2a5666951870d..be2a22b6ab7e8 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -521,7 +521,7 @@ static int var_outside(jl_stenv_t *e, jl_tvar_t *x, jl_tvar_t *y) return 0; } -static jl_value_t *intersect_ufirst(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth); +static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth); // check that type var `b` is <: `a`, and update b's upper bound. static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param) @@ -539,7 +539,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param) // for this to work we need to compute issub(left,right) before issub(right,left), // since otherwise the issub(a, bb.ub) check in var_gt becomes vacuous. if (e->intersection) { - jl_value_t *ub = intersect_ufirst(bb->ub, a, e, bb->depth0); + jl_value_t *ub = intersect_aside(bb->ub, a, e, bb->depth0); if (ub != (jl_value_t*)b) bb->ub = ub; } @@ -1328,16 +1328,32 @@ JL_DLLEXPORT int jl_isa(jl_value_t *x, jl_value_t *t) static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param); +static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e); + +// intersect in nested union environment, similar to subtype_ccheck +static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth) +{ + jl_value_t *res; + int savedepth = e->invdepth; + jl_unionstate_t oldRunions = e->Runions; + e->invdepth = depth; + + res = intersect_all(x, y, e); + + e->Runions = oldRunions; + e->invdepth = savedepth; + return res; +} + static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t *e, int8_t R, int param) { if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) { - jl_value_t *a=NULL, *b=NULL, *save=NULL; jl_savedenv_t se; - JL_GC_PUSH3(&a, &b, &save); - save_env(e, &save, &se); - a = R ? intersect(x, u->a, e, param) : intersect(u->a, x, e, param); - restore_env(e, NULL, &se); - b = R ? intersect(x, u->b, e, param) : intersect(u->b, x, e, param); - free(se.buf); + jl_value_t *a=NULL, *b=NULL; + JL_GC_PUSH2(&a, &b); + jl_unionstate_t oldRunions = e->Runions; + a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e); + b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e); + e->Runions = oldRunions; jl_value_t *i = simple_join(a,b); JL_GC_POP(); return i; @@ -1347,21 +1363,6 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t return R ? intersect(x, choice, e, param) : intersect(choice, x, e, param); } -static jl_value_t *intersect_ufirst(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth) -{ - jl_value_t *res; - int savedepth = e->invdepth; - e->invdepth = depth; - if (jl_is_uniontype(x) && jl_is_typevar(y)) - res = intersect_union(y, (jl_uniontype_t*)x, e, 0, 0); - else if (jl_is_typevar(x) && jl_is_uniontype(y)) - res = intersect_union(x, (jl_uniontype_t*)y, e, 1, 0); - else - res = intersect(x, y, e, 0); - e->invdepth = savedepth; - return res; -} - // set a variable to a non-type constant static jl_value_t *set_var_to_const(jl_varbinding_t *bb, jl_value_t *v JL_MAYBE_UNROOTED, jl_varbinding_t *othervar) { @@ -1386,13 +1387,11 @@ static jl_value_t *set_var_to_const(jl_varbinding_t *bb, jl_value_t *v JL_MAYBE_ static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e) { - jl_value_t *root=NULL; jl_savedenv_t se; int ret=0; + jl_value_t *root=NULL; jl_savedenv_t se; JL_GC_PUSH1(&root); save_env(e, &root, &se); - if (subtype_in_env(a, b, e)) - ret = 1; - else - restore_env(e, root, &se); + int ret = subtype_in_env(a, b, e); + restore_env(e, root, &se); free(se.buf); JL_GC_POP(); return ret; @@ -1402,7 +1401,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int { jl_varbinding_t *bb = lookup(e, b); if (bb == NULL) - return R ? intersect_ufirst(a, b->ub, e, 0) : intersect_ufirst(b->ub, a, e, 0); + return R ? intersect_aside(a, b->ub, e, 0) : intersect_aside(b->ub, a, e, 0); if (bb->lb == bb->ub && jl_is_typevar(bb->lb)) return intersect(a, bb->lb, e, param); if (!jl_is_type(a) && !jl_is_typevar(a)) @@ -1410,7 +1409,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int int d = bb->depth0; jl_value_t *root=NULL; jl_savedenv_t se; if (param == 2) { - jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d); + jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d); JL_GC_PUSH2(&ub, &root); save_env(e, &root, &se); int issub = subtype_in_env(bb->lb, ub, e); @@ -1448,10 +1447,10 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int if (try_subtype_in_env(bb->ub, a, e)) return (jl_value_t*)b; } - return R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d); + return R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d); } else if (bb->concrete || bb->constraintkind == 1) { - jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d); + jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d); JL_GC_PUSH1(&ub); if (ub == jl_bottom_type || !subtype_in_env(bb->lb, a, e)) { JL_GC_POP(); @@ -1471,7 +1470,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int return a; } assert(bb->constraintkind == 3); - jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d); + jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d); if (ub == jl_bottom_type) return jl_bottom_type; if (jl_is_typevar(a)) @@ -1492,7 +1491,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int root = NULL; JL_GC_PUSH2(&root, &ub); save_env(e, &root, &se); - jl_value_t *ii = R ? intersect_ufirst(a, bb->lb, e, d) : intersect_ufirst(bb->lb, a, e, d); + jl_value_t *ii = R ? intersect_aside(a, bb->lb, e, d) : intersect_aside(bb->lb, a, e, d); if (ii == jl_bottom_type) { restore_env(e, root, &se); ii = (jl_value_t*)b; @@ -2045,7 +2044,7 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa return jl_bottom_type; jl_value_t *ub=NULL, *lb=NULL; JL_GC_PUSH2(&lb, &ub); - ub = intersect_ufirst(xub, yub, e, xx ? xx->depth0 : 0); + ub = intersect_aside(xub, yub, e, xx ? xx->depth0 : 0); lb = simple_join(xlb, ylb); if (yy) { if (lb != y) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index b5f6204361230..f35831758a9de 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -176,7 +176,8 @@ end convert(T::Type{<:Symmetric}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m) convert(T::Type{<:Hermitian}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m) -const HermOrSym{T,S} = Union{Hermitian{T,S}, Symmetric{T,S}} +const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}} +const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}} const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}} const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}} @@ -427,11 +428,17 @@ mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,<:StridedMatrix}) *(A::AbstractMatrix, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = A * adjB.parent # ambiguities with transposed AbstractMatrix methods in linalg/matmul.jl +*(transA::Transpose{<:Any,<:RealHermSym}, transB::Transpose{<:Any,<:RealHermSym}) = transA * transB.parent +*(transA::Transpose{<:Any,<:RealHermSym}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA * transB.parent *(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA.parent * transB.parent +*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSym}) = transA.parent * transB *(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSymComplexHerm}) = transA.parent * transB *(transA::Transpose{<:Any,<:RealHermSymComplexHerm}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA * transB.parent +*(adjA::Adjoint{<:Any,<:RealHermSym}, adjB::Adjoint{<:Any,<:RealHermSym}) = adjA * adjB.parent *(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA.parent * adjB.parent +*(adjA::Adjoint{<:Any,<:RealHermSym}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA * adjB.parent *(adjA::Adjoint{<:Any,<:RealHermSymComplexSym}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA * adjB.parent +*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSym}) = adjA.parent * adjB *(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSymComplexSym}) = adjA.parent * adjB # ambiguities with AbstractTriangular diff --git a/test/ambiguous.jl b/test/ambiguous.jl index 02d076d5eb8ba..fafd937bcb221 100644 --- a/test/ambiguous.jl +++ b/test/ambiguous.jl @@ -275,6 +275,7 @@ end pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, (Type{Union{T, Nothing}} where T, Core.Compiler.Some))) pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}})) pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}})) + pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Union{Nothing,T}},Union{Nothing,T}} where T)) @test need_to_handle_undef_sparam == Set() end let need_to_handle_undef_sparam = @@ -299,6 +300,12 @@ end pop!(need_to_handle_undef_sparam, which(Base.convert, (Type{Union{T, Nothing}} where T, Some))) pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}})) pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}})) + pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Nothing,T}},Union{Nothing,T}} where T)) + pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,T}},Union{Missing,T}} where T)) + pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,Nothing,T}},Union{Missing,Nothing,T}} where T)) + pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Nothing,T}},Type{Any}} where T)) + pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Missing,T}},Type{Any}} where T)) + pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Missing,Nothing,T}},Type{Any}} where T)) @test need_to_handle_undef_sparam == Set() end end diff --git a/test/subtype.jl b/test/subtype.jl index b2e710ed99ce9..c68a5966864b1 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -847,10 +847,14 @@ function test_intersection() @testintersect(Ref{@UnionAll T @UnionAll S Tuple{T,S}}, Ref{@UnionAll T Tuple{T,T}}, Bottom) + # both of these answers seem acceptable + #@testintersect(Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular}, + # Tuple{AbstractArray{T,N}, AbstractArray{T,N}} where N where T, + # Union{Tuple{T,T} where T<:UpperTriangular, + # Tuple{T,T} where T<:UnitUpperTriangular}) @testintersect(Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular}, Tuple{AbstractArray{T,N}, AbstractArray{T,N}} where N where T, - Union{Tuple{T,T} where T<:UpperTriangular, - Tuple{T,T} where T<:UnitUpperTriangular}) + Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular}) @testintersect(DataType, Type, DataType) @testintersect(DataType, Type{T} where T<:Integer, Type{T} where T<:Integer) @@ -1333,3 +1337,14 @@ struct A28256{names, T<:NamedTuple{names, <:Tuple}} x::T end @test A28256{(:a,), NamedTuple{(:a,),Tuple{Int}}}((a=1,)) isa A28256 + +# issue #25752 +@testintersect(Base.RefValue, Ref{Union{Int,T}} where T, + Base.RefValue{Union{Int,T}} where T) +# issue #29269 +@testintersect((Tuple{Int, Array{T}} where T), + (Tuple{Any, Vector{Union{Missing,T}}} where T), + (Tuple{Int, Vector{Union{Missing,T}}} where T)) +@testintersect((Tuple{Int, Array{T}} where T), + (Tuple{Any, Vector{Union{Missing,Nothing,T}}} where T), + (Tuple{Int, Vector{Union{Missing,Nothing,T}}} where T))