Skip to content

Commit

Permalink
give wider/safer intersection result for vars used in both invariant …
Browse files Browse the repository at this point in the history
…and covariant position

fixes #41738
  • Loading branch information
JeffBezanson committed Aug 16, 2021
1 parent b06f813 commit bc00546
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 32 deletions.
62 changes: 33 additions & 29 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ typedef struct jl_stenv_t {
int ignore_free; // treat free vars as black boxes; used during intersection
int intersection; // true iff subtype is being called from intersection
int emptiness_only; // true iff intersection only needs to test for emptiness
int triangular; // when intersecting Ref{X} with Ref{<:Y}
} jl_stenv_t;

// state manipulation utilities
Expand Down Expand Up @@ -1411,6 +1412,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->ignore_free = 0;
e->intersection = 0;
e->emptiness_only = 0;
e->triangular = 0;
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
}
Expand Down Expand Up @@ -2169,7 +2171,7 @@ static void set_bound(jl_value_t **bound, jl_value_t *val, jl_tvar_t *v, jl_sten
return;
jl_varbinding_t *btemp = e->vars;
while (btemp != NULL) {
if (btemp->lb == (jl_value_t*)v && btemp->ub == (jl_value_t*)v &&
if ((btemp->lb == (jl_value_t*)v || btemp->ub == (jl_value_t*)v) &&
in_union(val, (jl_value_t*)btemp->var))
return;
btemp = btemp->prev;
Expand Down Expand Up @@ -2221,6 +2223,21 @@ static int reachable_var(jl_value_t *x, jl_tvar_t *y, jl_stenv_t *e)
return reachable_var(xv->ub, y, e) || reachable_var(xv->lb, y, e);
}

// check whether setting v == t implies v == SomeType{v}, which is unsatisfiable.
static int check_unsat_bound(jl_value_t *t, jl_tvar_t *v, jl_stenv_t *e) JL_NOTSAFEPOINT
{
if (var_occurs_inside(t, v, 0, 0))
return 1;
jl_varbinding_t *btemp = e->vars;
while (btemp != NULL) {
if (btemp->lb == (jl_value_t*)v && btemp->ub == (jl_value_t*)v &&
var_occurs_inside(t, btemp->var, 0, 0))
return 1;
btemp = btemp->prev;
}
return 0;
}

static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int8_t R, int param)
{
jl_varbinding_t *bb = lookup(e, b);
Expand Down Expand Up @@ -2250,7 +2267,9 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
ub = a;
}
else {
e->triangular++;
ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
e->triangular--;
save_env(e, &root, &se);
int issub = subtype_in_env_existential(bb->lb, ub, e, 0, d);
restore_env(e, root, &se);
Expand All @@ -2262,20 +2281,10 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
}
if (ub != (jl_value_t*)b) {
if (jl_has_free_typevars(ub)) {
// constraint X == Ref{X} is unsatisfiable. also check variables set equal to X.
if (var_occurs_inside(ub, b, 0, 0)) {
if (check_unsat_bound(ub, b, e)) {
JL_GC_POP();
return jl_bottom_type;
}
jl_varbinding_t *btemp = e->vars;
while (btemp != NULL) {
if (btemp->lb == (jl_value_t*)b && btemp->ub == (jl_value_t*)b &&
var_occurs_inside(ub, btemp->var, 0, 0)) {
JL_GC_POP();
return jl_bottom_type;
}
btemp = btemp->prev;
}
}
bb->ub = ub;
bb->lb = ub;
Expand All @@ -2286,7 +2295,13 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
if (ub == jl_bottom_type)
return jl_bottom_type;
if (bb->constraintkind == 0) {
if (bb->constraintkind == 1 || e->triangular) {
if (e->triangular && check_unsat_bound(ub, b, e))
return jl_bottom_type;
set_bound(&bb->ub, ub, b, e);
return (jl_value_t*)b;
}
else if (bb->constraintkind == 0) {
JL_GC_PUSH1(&ub);
if (!jl_is_typevar(a) && try_subtype_in_env(bb->ub, a, e, 0, d)) {
JL_GC_POP();
Expand All @@ -2295,10 +2310,6 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
JL_GC_POP();
return ub;
}
else if (bb->constraintkind == 1) {
set_bound(&bb->ub, ub, b, e);
return (jl_value_t*)b;
}
assert(bb->constraintkind == 2);
if (!jl_is_typevar(a)) {
if (ub == a && bb->lb != jl_bottom_type)
Expand Down Expand Up @@ -2563,11 +2574,11 @@ static jl_value_t *intersect_unionall_(jl_value_t *t, jl_unionall_t *u, jl_stenv

static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
jl_savedenv_t se, se2;
jl_value_t *res=NULL, *save=NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
JL_GC_PUSH5(&res, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
if (res != jl_bottom_type) {
Expand All @@ -2577,18 +2588,11 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
vb.constraintkind = vb.concrete ? 1 : 2;
res = intersect_unionall_(t, u, e, R, param, &vb);
}
else if (vb.occurs_cov) {
save_env(e, &save2, &se2);
else if (vb.occurs_cov && !var_occurs_invariant(u->body, u->var, 0)) {
restore_env(e, save, &se);
vb.occurs_cov = vb.occurs_inv = 0;
vb.lb = u->var->lb; vb.ub = u->var->ub;
vb.constraintkind = 1;
res2 = intersect_unionall_(t, u, e, R, param, &vb);
if (res2 != jl_bottom_type)
res = res2;
else
restore_env(e, save2, &se2);
free_env(&se2);
res = intersect_unionall_(t, u, e, R, param, &vb);
}
}
free_env(&se);
Expand Down
18 changes: 15 additions & 3 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1577,7 +1577,7 @@ f31082(::Pair{B, C}, ::C, ::C) where {B, C} = 1
Tuple{Type{Val{T}},Int,T} where T)
@testintersect(Tuple{Type{Val{T}},Integer,T} where T,
Tuple{Type,Int,Integer},
Tuple{Type{Val{T}},Int,T} where T<:Integer)
Tuple{Type{Val{T}},Int,Integer} where T)
@testintersect(Tuple{Type{Val{T}},Integer,T} where T>:Integer,
Tuple{Type,Int,Integer},
Tuple{Type{Val{T}},Int,Integer} where T>:Integer)
Expand Down Expand Up @@ -1866,7 +1866,7 @@ let A = Tuple{Type{T} where T<:Ref, Ref, Union{T, Union{Ref{T}, T}} where T<:Ref
I = typeintersect(A,B)
# this was a case where <: disagreed with === (due to a badly-normalized type)
@test I == typeintersect(A,B)
@test I == Tuple{Type{T}, Ref{T}, Union{Ref{T}, T}} where T<:Ref
@test I == Tuple{Type{T}, Ref{T}, Ref} where T<:Ref
end

# issue #39218
Expand Down Expand Up @@ -1946,9 +1946,21 @@ let A = Tuple{UnionAll, Vector{Any}},
B = Tuple{Type{T}, T} where T<:AbstractArray,
I = typeintersect(A, B)
@test !isconcretetype(I)
@test_broken I == Tuple{Type{T}, Vector{Any}} where T<:AbstractArray
@test I == Tuple{Type{T}, Vector{Any}} where T<:AbstractArray
end

@testintersect(Tuple{Type{Vector{<:T}}, T} where {T<:Integer},
Tuple{Type{T}, AbstractArray} where T<:Array,
Bottom)

let A = Tuple{Any, Type{Ref{_A}} where _A},
B = Tuple{Type{T}, Type{<:Union{Ref{T}, T}}} where T,
I = typeintersect(A, B)
@test I != Union{}
# TODO: this intersection result is still too narrow
@test_broken Tuple{Type{Ref{Integer}}, Type{Ref{Integer}}} <: I
end

@testintersect(Tuple{Type{T}, T} where T<:(Tuple{Vararg{_A, _B}} where _B where _A),
Tuple{Type{Tuple{Vararg{_A, N}} where _A<:F}, Pair{N, F}} where F where N,
Bottom)

0 comments on commit bc00546

Please sign in to comment.