diff --git a/src/subtype.c b/src/subtype.c index 8a1ea03fdd6fd..c1e0f47acae77 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -39,20 +39,24 @@ extern "C" { // Union type decision points are discovered while the algorithm works. // If a new Union decision is encountered, the `more` flag is set to tell // the forall/exists loop to grow the stack. -// TODO: the stack probably needs to be artificially large because of some -// deeper problem (see #21191) and could be shrunk once that is fixed + +typedef struct jl_bits_stack_t { + uint32_t data[16]; + struct jl_bits_stack_t *next; +} jl_bits_stack_t; + typedef struct { int16_t depth; int16_t more; int16_t used; - uint32_t stack[100]; // stack of bits represented as a bit vector + jl_bits_stack_t stack; } jl_unionstate_t; typedef struct { int16_t depth; int16_t more; int16_t used; - void *stack; + uint8_t *stack; } jl_saved_unionstate_t; // Linked list storing the type variable environment. A new jl_varbinding_t @@ -121,37 +125,111 @@ static jl_varbinding_t *lookup(jl_stenv_t *e, jl_tvar_t *v) JL_GLOBALLY_ROOTED J } #endif +// union-stack tools + static int statestack_get(jl_unionstate_t *st, int i) JL_NOTSAFEPOINT { - assert(i >= 0 && i < sizeof(st->stack) * 8); + assert(i >= 0 && i <= 32767); // limited by the depth bit. // get the `i`th bit in an array of 32-bit words - return (st->stack[i>>5] & (1u<<(i&31))) != 0; + jl_bits_stack_t *stack = &st->stack; + while (i >= sizeof(stack->data) * 8) { + // We should have set this bit. + assert(stack->next); + stack = stack->next; + i -= sizeof(stack->data) * 8; + } + return (stack->data[i>>5] & (1u<<(i&31))) != 0; } static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT { - assert(i >= 0 && i < sizeof(st->stack) * 8); + assert(i >= 0 && i <= 32767); // limited by the depth bit. + jl_bits_stack_t *stack = &st->stack; + while (i >= sizeof(stack->data) * 8) { + if (__unlikely(stack->next == NULL)) { + stack->next = (jl_bits_stack_t *)malloc(sizeof(jl_bits_stack_t)); + stack->next->next = NULL; + } + stack = stack->next; + i -= sizeof(stack->data) * 8; + } if (val) - st->stack[i>>5] |= (1u<<(i&31)); + stack->data[i>>5] |= (1u<<(i&31)); else - st->stack[i>>5] &= ~(1u<<(i&31)); + stack->data[i>>5] &= ~(1u<<(i&31)); +} + +#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0) + +static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT +{ + jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; + if (state->more == 0) + return 0; + // reset `used` and let `pick_union_decision` clean the stack. + state->used = state->more; + statestack_set(state, state->used - 1, 1); + return 1; +} + +static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT +{ + jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; + if (state->depth >= state->used) { + statestack_set(state, state->used, 0); + state->used++; + } + int ui = statestack_get(state, state->depth); + state->depth++; + if (ui == 0) + state->more = state->depth; // memorize that this was the deepest available choice + return ui; } -#define push_unionstate(saved, src) \ - do { \ - (saved)->depth = (src)->depth; \ - (saved)->more = (src)->more; \ - (saved)->used = (src)->used; \ - (saved)->stack = alloca(((src)->used+7)/8); \ - memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \ +static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT +{ + do { + if (pick_union_decision(e, R)) + u = ((jl_uniontype_t*)u)->b; + else + u = ((jl_uniontype_t*)u)->a; + } while (jl_is_uniontype(u)); + return u; +} + +#define push_unionstate(saved, src) \ + do { \ + (saved)->depth = (src)->depth; \ + (saved)->more = (src)->more; \ + (saved)->used = (src)->used; \ + jl_bits_stack_t *srcstack = &(src)->stack; \ + int pushbits = ((saved)->used+7)/8; \ + (saved)->stack = (uint8_t *)alloca(pushbits); \ + for (int n = 0; n < pushbits; n += sizeof(srcstack->data)) { \ + assert(srcstack != NULL); \ + int rest = pushbits - n; \ + if (rest > sizeof(srcstack->data)) \ + rest = sizeof(srcstack->data); \ + memcpy(&(saved)->stack[n], &srcstack->data, rest); \ + srcstack = srcstack->next; \ + } \ } while (0); -#define pop_unionstate(dst, saved) \ - do { \ - (dst)->depth = (saved)->depth; \ - (dst)->more = (saved)->more; \ - (dst)->used = (saved)->used; \ - memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \ +#define pop_unionstate(dst, saved) \ + do { \ + (dst)->depth = (saved)->depth; \ + (dst)->more = (saved)->more; \ + (dst)->used = (saved)->used; \ + jl_bits_stack_t *dststack = &(dst)->stack; \ + int popbits = ((saved)->used+7)/8; \ + for (int n = 0; n < popbits; n += sizeof(dststack->data)) { \ + assert(dststack != NULL); \ + int rest = popbits - n; \ + if (rest > sizeof(dststack->data)) \ + rest = sizeof(dststack->data); \ + memcpy(&dststack->data, &(saved)->stack[n], rest); \ + dststack = dststack->next; \ + } \ } while (0); static int current_env_length(jl_stenv_t *e) @@ -254,6 +332,18 @@ static void free_env(jl_savedenv_t *se) JL_NOTSAFEPOINT se->buf = NULL; } +static void free_stenv(jl_stenv_t *e) JL_NOTSAFEPOINT +{ + for (int R = 0; R < 2; R++) { + jl_bits_stack_t *temp = R ? e->Runions.stack.next : e->Lunions.stack.next; + while (temp != NULL) { + jl_bits_stack_t *next = temp->next; + free(temp); + temp = next; + } + } +} + static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPOINT { jl_value_t **roots = NULL; @@ -586,42 +676,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi) static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param); -static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT -{ - jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; - if (state->more == 0) - return 0; - // reset `used` and let `pick_union_decision` clean the stack. - state->used = state->more; - statestack_set(state, state->used - 1, 1); - return 1; -} - -static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT -{ - jl_unionstate_t *state = R ? &e->Runions : &e->Lunions; - if (state->depth >= state->used) { - statestack_set(state, state->used, 0); - state->used++; - } - int ui = statestack_get(state, state->depth); - state->depth++; - if (ui == 0) - state->more = state->depth; // memorize that this was the deepest available choice - return ui; -} - -static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT -{ - do { - if (pick_union_decision(e, R)) - u = ((jl_uniontype_t*)u)->b; - else - u = ((jl_uniontype_t*)u)->a; - } while (jl_is_uniontype(u)); - return u; -} - static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow); // subtype for variable bounds consistency check. needs its own forall/exists environment. @@ -1486,37 +1540,12 @@ static int is_definite_length_tuple_type(jl_value_t *x) return k == JL_VARARG_NONE || k == JL_VARARG_INT; } -static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore); - -static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT +static int is_exists_typevar(jl_value_t *x, jl_stenv_t *e) { - if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type) - return 0; - if (jl_is_unionall(x)) - return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log); - if (jl_is_datatype(x)) { - jl_datatype_t *xd = (jl_datatype_t *)x; - for (int i = 0; i < jl_nparams(xd); i++) { - jl_value_t *param = jl_tparam(xd, i); - if (jl_is_vararg(param)) - param = jl_unwrap_vararg(param); - if (may_contain_union_decision(param, e, log)) - return 1; - } - return 0; - } if (!jl_is_typevar(x)) - return jl_is_type(x); - jl_typeenv_t *t = log; - while (t != NULL) { - if (x == (jl_value_t *)t->var) - return 1; - t = t->prev; - } - jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log }; - jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x); - return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) || - may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog); + return 0; + jl_varbinding_t *vb = lookup(e, (jl_tvar_t *)x); + return vb && vb->right; } static int has_exists_typevar(jl_value_t *x, jl_stenv_t *e) JL_NOTSAFEPOINT @@ -1547,31 +1576,9 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t int kindy = !jl_has_free_typevars(y); if (kindx && kindy) return jl_subtype(x, y); - if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) { - jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions); - e->Lunions.used = e->Runions.used = 0; - e->Lunions.depth = e->Runions.depth = 0; - e->Lunions.more = e->Runions.more = 0; - int count = 0, noRmore = 0; - sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore); - pop_unionstate(&e->Runions, &oldRunions); - // We could skip the slow path safely if - // 1) `_∀_∃_subtype` has tested all cases - // 2) `_∀_∃_subtype` returns 1 && `x` and `y` contain no ∃ typevar - // Once `limit_slow == 1`, also skip it if - // 1) `_∀_∃_subtype` returns 0 - // 2) the left `Union` looks big - // TODO: `limit_slow` ignores complexity from inner `local_∀_exists_subtype`. - if (limit_slow == -1) - limit_slow = kindx || kindy; - int skip = noRmore || (limit_slow && (count > 3 || !sub)) || - (sub && (kindx || !has_exists_typevar(x, e)) && - (kindy || !has_exists_typevar(y, e))); - if (skip) - e->Runions.more = oldRmore; - } - else { - // slow path + int has_exists = (!kindx && has_exists_typevar(x, e)) || + (!kindy && has_exists_typevar(y, e)); + if (has_exists && (is_exists_typevar(x, e) != is_exists_typevar(y, e))) { e->Lunions.used = 0; while (1) { e->Lunions.more = 0; @@ -1580,7 +1587,51 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t if (!sub || !next_union_state(e, 0)) break; } + return sub; } + if (limit_slow == -1) + limit_slow = kindx || kindy; + jl_savedenv_t se; + save_env(e, &se, has_exists); + int count, limited = 0, ini_count = 0; + jl_saved_unionstate_t latestLunions = {0, 0, 0, NULL}; + while (1) { + count = ini_count; + if (ini_count == 0) + e->Lunions.used = 0; + else + pop_unionstate(&e->Lunions, &latestLunions); + while (1) { + e->Lunions.more = 0; + e->Lunions.depth = 0; + if (count < 4) count++; + sub = subtype(x, y, e, param); + if (limit_slow && count == 4) + limited = 1; + if (!sub || !next_union_state(e, 0)) + break; + if (limited || !has_exists || e->Runions.more == oldRmore) { + // re-save env and freeze the ∃decision for previous ∀Union + // Note: We could ignore the rest `∃Union` decisions if `x` and `y` + // contain no ∃ typevar, as they have no effect on env. + ini_count = count; + push_unionstate(&latestLunions, &e->Lunions); + re_save_env(e, &se, has_exists); + e->Runions.more = oldRmore; + } + } + if (sub || e->Runions.more == oldRmore) + break; + assert(e->Runions.more > oldRmore); + next_union_state(e, 1); + restore_env(e, &se, has_exists); // also restore Rdepth here + e->Runions.more = oldRmore; + } + if (!sub) + assert(e->Runions.more == oldRmore); + else if (limited || !has_exists) + e->Runions.more = oldRmore; + free_env(&se); return sub; } @@ -1650,7 +1701,7 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_savede } } -static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore) +static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param) { // The depth recursion has the following shape, after simplification: // ∀₁ @@ -1662,12 +1713,8 @@ static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, i e->Lunions.used = 0; int sub; - if (count) *count = 0; - if (noRmore) *noRmore = 1; while (1) { sub = exists_subtype(x, y, e, &se, param); - if (count) *count = (*count < 4) ? *count + 1 : 4; - if (noRmore) *noRmore = *noRmore && e->Runions.more == 0; if (!sub || !next_union_state(e, 0)) break; re_save_env(e, &se, 1); @@ -1677,11 +1724,6 @@ static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, i return sub; } -static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param) -{ - return _forall_exists_subtype(x, y, e, param, NULL, NULL); -} - static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz) { e->vars = NULL; @@ -1701,6 +1743,8 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz) e->Lunions.depth = 0; e->Runions.depth = 0; e->Lunions.more = 0; e->Runions.more = 0; e->Lunions.used = 0; e->Runions.used = 0; + e->Lunions.stack.next = NULL; + e->Runions.stack.next = NULL; } // subtyping entry points @@ -2130,6 +2174,7 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env, } init_stenv(&e, env, envsz); int subtype = forall_exists_subtype(x, y, &e, 0); + free_stenv(&e); assert(obvious_subtype == 3 || obvious_subtype == subtype || jl_has_free_typevars(x) || jl_has_free_typevars(y)); #ifndef NDEBUG if (obvious_subtype == 0 || (obvious_subtype == 1 && envsz == 0)) @@ -2222,6 +2267,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b) { init_stenv(&e, NULL, 0); int subtype = forall_exists_subtype(a, b, &e, 0); + free_stenv(&e); assert(subtype_ab == 3 || subtype_ab == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b)); #ifndef NDEBUG if (subtype_ab != 0 && subtype_ab != 1) // ensures that running in a debugger doesn't change the result @@ -2238,6 +2284,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b) { init_stenv(&e, NULL, 0); int subtype = forall_exists_subtype(b, a, &e, 0); + free_stenv(&e); assert(subtype_ba == 3 || subtype_ba == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b)); #ifndef NDEBUG if (subtype_ba != 0 && subtype_ba != 1) // ensures that running in a debugger doesn't change the result @@ -4089,7 +4136,9 @@ static jl_value_t *intersect_types(jl_value_t *x, jl_value_t *y, int emptiness_o init_stenv(&e, NULL, 0); e.intersection = e.ignore_free = 1; e.emptiness_only = emptiness_only; - return intersect_all(x, y, &e); + jl_value_t *ans = intersect_all(x, y, &e); + free_stenv(&e); + return ans; } JL_DLLEXPORT jl_value_t *jl_intersect_types(jl_value_t *x, jl_value_t *y) @@ -4266,6 +4315,7 @@ jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t * memset(env, 0, szb*sizeof(void*)); e.envsz = szb; *ans = intersect_all(a, b, &e); + free_stenv(&e); if (*ans == jl_bottom_type) goto bot; // TODO: code dealing with method signatures is not able to handle unions, so if // `a` and `b` are both tuples, we need to be careful and may not return a union,