From 49672a4a825e8c3b3e931f7f55c288a1b47f3b5d Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sun, 19 Nov 2023 02:00:46 +0800 Subject: [PATCH] Widen diagonal var during `Type` unwrapping in `instanceof_tfunc` --- base/compiler/tfuncs.jl | 12 ++- base/essentials.jl | 5 + src/subtype.c | 205 +++++++++++++++++++++++++++++++++++++ test/compiler/inference.jl | 13 +++ test/core.jl | 8 ++ 5 files changed, 239 insertions(+), 4 deletions(-) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index eb43e77885d64a..4088ddf58368cc 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -95,7 +95,7 @@ add_tfunc(throw, 1, 1, @nospecs((𝕃::AbstractLattice, x)->Bottom), 0) # if isexact is false, the actual runtime type may (will) be a subtype of t # if isconcrete is true, the actual runtime type is definitely concrete (unreachable if not valid as a typeof) # if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int}) -function instanceof_tfunc(@nospecialize(t), astag::Bool=false) +function instanceof_tfunc(@nospecialize(t), astag::Bool=false, @nospecialize(troot) = t) if isa(t, Const) if isa(t.val, Type) && valid_as_lattice(t.val, astag) return t.val, true, isconcretetype(t.val), true @@ -103,6 +103,7 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false) return Bottom, true, false, false # runtime throws on non-Type end t = widenconst(t) + troot = widenconst(troot) if t === Bottom return Bottom, true, true, false # runtime unreachable elseif t === typeof(Bottom) || !hasintersect(t, Type) @@ -110,10 +111,13 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false) elseif isType(t) tp = t.parameters[1] valid_as_lattice(tp, astag) || return Bottom, true, false, false # runtime unreachable / throws on non-Type + if troot isa UnionAll + tp = widen_diagonal(tp, troot) + end return tp, !has_free_typevars(tp), isconcretetype(tp), true elseif isa(t, UnionAll) t′ = unwrap_unionall(t) - t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag) + t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag, rewrap_unionall(t, troot)) tr = rewrap_unionall(t′′, t) if t′′ isa DataType && t′′.name !== Tuple.name && !has_free_typevars(tr) # a real instance must be within the declared bounds of the type, @@ -128,8 +132,8 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false) end return tr, isexact, isconcrete, istype elseif isa(t, Union) - ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag) - tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag) + ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag, troot) + tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag, troot) isconcrete = isconcrete_a && isconcrete_b istype = istype_a && istype_b # most users already handle the Union case, so here we assume that diff --git a/base/essentials.jl b/base/essentials.jl index 106826d140c574..a9f3bfc40f6228 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -459,6 +459,11 @@ function rename_unionall(@nospecialize(u)) return UnionAll(nv, body{nv}) end +# remove concrete constraint on diagonal TypeVar if it comes from troot +function widen_diagonal(@nospecialize(t), troot::UnionAll) + body = ccall(:jl_widen_diagonal, Any, (Any, Any), t, troot) +end + function isvarargtype(@nospecialize(t)) return isa(t, Core.TypeofVararg) end diff --git a/src/subtype.c b/src/subtype.c index f80e31c58c46b4..edd752278a2852 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -4304,6 +4304,211 @@ int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv) return sub; } +// type utils +static void check_diagonal(jl_value_t *t, jl_varbinding_t *troot, int param) +{ + if (jl_is_uniontype(t)) { + int i, len = 0; + jl_varbinding_t *v; + for (v = troot; v != NULL; v = v->prev) + len++; + int8_t *occurs = (int8_t *)alloca(len); + for (v = troot, i = 0; v != NULL; v = v->prev, i++) + occurs[i] = v->occurs_inv | (v->occurs_cov << 2); + check_diagonal(((jl_uniontype_t *)t)->a, troot, param); + for (v = troot, i = 0; v != NULL; v = v->prev, i++) { + int8_t occurs_inv = occurs[i] & 3; + int8_t occurs_cov = occurs[i] >> 2; + occurs[i] = v->occurs_inv | (v->occurs_cov << 2); + v->occurs_inv = occurs_inv; + v->occurs_cov = occurs_cov; + } + check_diagonal(((jl_uniontype_t *)t)->b, troot, param); + for (v = troot, i = 0; v != NULL; v = v->prev, i++) { + if (v->occurs_inv < (occurs[i] & 3)) + v->occurs_inv = occurs[i] & 3; + if (v->occurs_cov < (occurs[i] >> 2)) + v->occurs_cov = occurs[i] >> 2; + } + } + else if (jl_is_unionall(t)) { + assert(troot != NULL); + jl_varbinding_t *v1 = troot, *v2 = troot->prev; + while (v2 != NULL) { + if (v2->var == ((jl_unionall_t *)t)->var) { + v1->prev = v2->prev; + break; + } + v1 = v2; + v2 = v2->prev; + } + check_diagonal(((jl_unionall_t *)t)->body, troot, param); + v1->prev = v2; + } + else if (jl_is_datatype(t)) { + int nparam = jl_is_tuple_type(t) ? 1 : 2; + if (nparam < param) nparam = param; + for (size_t i = 0; i < jl_nparams(t); i++) { + check_diagonal(jl_tparam(t, i), troot, nparam); + } + } + else if (jl_is_vararg(t)) { + jl_value_t *T = jl_unwrap_vararg(t); + jl_value_t *N = jl_unwrap_vararg_num(t); + int n = (N && jl_is_long(N)) ? jl_unbox_long(N) : 2; + if (T && n > 0) check_diagonal(T, troot, param); + if (T && n > 1) check_diagonal(T, troot, param); + if (N) check_diagonal(N, troot, 2); + } + else if (jl_is_typevar(t)) { + jl_varbinding_t *v = troot; + for (; v != NULL; v = v->prev) { + if (v->var == (jl_tvar_t *)t) { + if (param == 1 && v->occurs_cov < 2) v->occurs_cov++; + if (param == 2 && v->occurs_inv < 2) v->occurs_inv++; + break; + } + } + if (v == NULL) + check_diagonal(((jl_tvar_t *)t)->ub, troot, 0); + } +} + +static jl_value_t *insert_nondiagonal(jl_value_t *type, jl_varbinding_t *troot, int widen2ub) +{ + // we must replace each covariant occurrence of newvar with a different newvar2<:newvar (diagonal rule) + if (jl_is_typevar(type)) { + jl_varbinding_t *v = troot; + for (; v != NULL; v = v->prev) { + if (v->concrete && v->var == (jl_tvar_t *)type) + break; + } + if (v != NULL) { + if (widen2ub) { + type = ((jl_tvar_t *)type)->ub; + } + else { + if (v->innervars == NULL) + v->innervars = jl_alloc_array_1d(jl_array_any_type, 0); + jl_value_t *newvar = NULL, *lb = v->var->lb, *ub = (jl_value_t *)v->var; + jl_array_t *innervars = v->innervars; + JL_GC_PUSH4(&newvar, &lb, &ub, &innervars); + newvar = (jl_value_t *)jl_new_typevar(v->var->name, lb, ub); + jl_array_ptr_1d_push(innervars, newvar); + JL_GC_POP(); + type = newvar; + } + } + } + else if (jl_is_unionall(type)) { + jl_value_t *body = ((jl_unionall_t*)type)->body; + jl_tvar_t *var = ((jl_unionall_t*)type)->var; + jl_varbinding_t *v = troot; + for (; v != NULL; v = v->prev) { + if (v->var == (jl_tvar_t *)var) + break; + } + if (v == NULL) { + jl_value_t *newbody = insert_nondiagonal(body, troot, widen2ub); + jl_value_t *newvar = NULL; + JL_GC_PUSH2(&newbody, &newvar); + if (body == newbody || jl_has_typevar(newbody, var)) { + if (body != newbody) + newbody = jl_new_struct(jl_unionall_type, var, newbody); + // n.b. we do not widen lb, since that would be the wrong direction + newvar = insert_nondiagonal(var->ub, troot, widen2ub); + if (newvar != var->ub) { + newvar = (jl_value_t*)jl_new_typevar(var->name, var->lb, newvar); + newbody = jl_apply_type1(newbody, newvar); + newbody = jl_type_unionall((jl_tvar_t*)newvar, newbody); + } + } + type = newbody; + JL_GC_POP(); + } + } + else if (jl_is_uniontype(type)) { + jl_value_t *a = ((jl_uniontype_t*)type)->a; + jl_value_t *b = ((jl_uniontype_t*)type)->b; + jl_value_t *newa = NULL; + jl_value_t *newb = NULL; + JL_GC_PUSH2(&newa, &newb); + newa = insert_nondiagonal(a, troot, widen2ub); + newb = insert_nondiagonal(b, troot, widen2ub); + if (newa != a || newb != b) + type = jl_new_struct(jl_uniontype_type, newa, newb); + JL_GC_POP(); + } + else if (jl_is_vararg(type)) { + // As for Vararg we'd better widen it's var to ub as otherwise they are still diagonal + jl_value_t *t = jl_unwrap_vararg(type); + jl_value_t *n = jl_unwrap_vararg_num(type); + widen2ub = !(n && jl_is_long(n)) || jl_unbox_long(n) > 1; + jl_value_t *newt; + JL_GC_PUSH1(&newt); + newt = insert_nondiagonal(t, troot, widen2ub); + if (t != newt) + type = jl_new_struct(jl_vararg_type, newt, n); + JL_GC_POP(); + } + else if (jl_is_datatype(type)) { + if (jl_is_tuple_type(type)) { + jl_svec_t *newparams = NULL; + JL_GC_PUSH1(&newparams); + for (size_t i = 0; i < jl_nparams(type); i++) { + jl_value_t *elt = jl_tparam(type, i); + jl_value_t *newelt = insert_nondiagonal(elt, troot, widen2ub); + if (elt != newelt) { + if (!newparams) { + newparams = (jl_svec_t*)newelt; // temporary root + newparams = jl_svec_copy(((jl_datatype_t*)type)->parameters); + } + jl_svecset(newparams, i, newelt); + } + } + if (newparams) + type = (jl_value_t*)jl_apply_tuple_type(newparams, 1); + JL_GC_POP(); + } + } + return type; +} + +static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) { + check_diagonal(t, troot, 0); + int any_concrete = 0; + for (jl_varbinding_t *v = troot; v != NULL; v = v->prev) { + v->concrete = v->occurs_cov > 1 && v->occurs_inv == 0; + any_concrete |= v->concrete; + } + if (!any_concrete) + return t; // no diagonal + return insert_nondiagonal(t, troot, 0); +} + +static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot) +{ + jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot }; + jl_value_t *nt; + JL_GC_PUSH2(&vb.innervars, &nt); + if (jl_is_unionall(u->body)) + nt = widen_diagonal(t, (jl_unionall_t *)u->body, &vb); + else + nt = _widen_diagonal(t, &vb); + if (vb.innervars != NULL) { + for (size_t i = 0; i < jl_array_nrows(vb.innervars); i++) { + jl_tvar_t *var = (jl_tvar_t*)jl_array_ptr_ref(vb.innervars, i); + nt = jl_type_unionall(var, nt); + } + } + JL_GC_POP(); + return nt; +} + +JL_DLLEXPORT jl_value_t *jl_widen_diagonal(jl_value_t *t, jl_unionall_t *ua) +{ + return widen_diagonal(t, ua, NULL); +} // specificity comparison diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index ace17baeb5859b..3512bd5a53ab08 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5533,3 +5533,16 @@ function test_exit_bottom(s) n end @test only(Base.return_types(test_exit_bottom, Tuple{String})) == Int + +# Issue #52168 +f52168(x, t::Type) = x::NTuple{2, Base.inferencebarrier(t)::Type} +@test f52168((1, 2.), Any) === (1, 2.) + +# Issue #27031 +let x = 1, _Any = Any + @noinline bar27031(tt::Tuple{T,T}, ::Type{Val{T}}) where {T} = notsame27031(tt) + @noinline notsame27031(tt::Tuple{T, T}) where {T} = error() + @noinline notsame27031(tt::Tuple{T, S}) where {T, S} = "OK" + foo27031() = bar27031((x, 1.0), Val{_Any}) + @test foo27031() == "OK" +end diff --git a/test/core.jl b/test/core.jl index 00ab41e4ecd487..46c0ba7a32105d 100644 --- a/test/core.jl +++ b/test/core.jl @@ -8059,3 +8059,11 @@ check_globalref_lowering() = @insert_global let src = code_lowered(check_globalref_lowering)[1] @test length(src.code) == 2 end + +# Test correctness of widen_diagonal +let widen_diagonal(x::UnionAll) = Base.rewrap_unionall(Base.widen_diagonal(Base.unwrap_unionall(x), x), x), + check_widen_diagonal(x, y) = !<:(x, y) && x <: widen_diagonal(y) + @test Tuple{Int,Float64} <: widen_diagonal(NTuple) + @test Tuple{Int,Float64} <: widen_diagonal(Tuple{T,T} where {T}) + @test Union{Tuple{T}, Tuple{T,Int}} where {T} == widen_diagonal(Union{Tuple{T}, Tuple{T,Int}} where {T}) +end