Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Subtype: some performance tuning. (JuliaLang#56007)
Browse files Browse the repository at this point in the history
The main motivation of this PR is to fix JuliaLang#55807.
dc689fe tries to remove the slow
`may_contain_union_decision` check by reorganizing the code path. The
fast path has now been removed, and most of its optimizations have been
integrated into the preserved slow path.
Since the slow path stores all inner ∃ decisions on the outermost R
stack, there is a risk of overflow.
aee69a4 should address that concern.

The reported MWE now gives the following timings:
```julia
  0.000002 seconds
  0.000040 seconds (105 allocations: 4.828 KiB, 52.00% compilation time)
  0.000023 seconds (105 allocations: 4.828 KiB, 49.36% compilation time)
  0.000026 seconds (105 allocations: 4.828 KiB, 50.38% compilation time)
  0.000027 seconds (105 allocations: 4.828 KiB, 54.95% compilation time)
  0.000019 seconds (106 allocations: 4.922 KiB, 49.73% compilation time)
  0.000024 seconds (105 allocations: 4.828 KiB, 52.24% compilation time)
```

A local benchmark also shows that 72855cd slightly speeds up the
loading of `OmniPackage.jl`:
```julia
julia> @time using OmniPackage
# v1.11rc4
 20.525278 seconds (25.36 M allocations: 1.606 GiB, 8.48% gc time, 12.89% compilation time: 77% of which was recompilation)
# v1.11rc4+aee69a4+72855cd 
 19.527871 seconds (24.92 M allocations: 1.593 GiB, 8.88% gc time, 15.13% compilation time: 82% of which was recompilation)
```
N5N3 authored and Zentrik committed Oct 12, 2024
1 parent a9627fc commit 6b3921f
Showing 1 changed file with 173 additions and 125 deletions.
298 changes: 173 additions & 125 deletions src/subtype.c
Original file line number Diff line number Diff line change
@@ -39,20 +39,24 @@ extern "C" {
// Union type decision points are discovered while the algorithm works.
// If a new Union decision is encountered, the `more` flag is set to tell
// the forall/exists loop to grow the stack.
// TODO: the stack probably needs to be artificially large because of some
// deeper problem (see #21191) and could be shrunk once that is fixed

// One fixed-size segment of a union-decision bit stack.
// Each segment stores 16 * 32 = 512 decision bits; further bits live in the
// segment reached through `next`.
typedef struct jl_bits_stack_t {
uint32_t data[16]; // bit vector: bit i is stored in data[i>>5] at bit position (i&31)
struct jl_bits_stack_t *next; // following segment, or NULL; malloc'ed on demand by statestack_set
} jl_bits_stack_t;

// Live state of one union-decision stack. The visible code keeps two of them
// per environment: `Lunions` and `Runions`.
typedef struct {
int16_t depth; // index of the next decision bit to read (advanced by pick_union_decision)
int16_t more; // 0, or the 1-based position of the deepest decision that was read as 0
int16_t used; // number of decision bits currently stored on the stack
// NOTE(review): this diff rendering shows two `stack` members. Presumably the
// fixed-size array is the removed line and the chained `jl_bits_stack_t` is its
// replacement - confirm against the repository.
uint32_t stack[100]; // stack of bits represented as a bit vector
jl_bits_stack_t stack;
} jl_unionstate_t;

// Snapshot of a jl_unionstate_t, filled by push_unionstate and restored by
// pop_unionstate.
typedef struct {
int16_t depth; // copy of the source state's `depth`
int16_t more; // copy of the source state's `more`
int16_t used; // copy of the source state's `used`; also determines how many bytes `stack` holds
// NOTE(review): this diff rendering shows two `stack` members. Presumably
// `void *stack` is the removed line and `uint8_t *stack` is its replacement -
// confirm against the repository.
void *stack;
uint8_t *stack; // (used+7)/8 bytes of saved decision bits, alloca'ed by push_unionstate
} jl_saved_unionstate_t;

// Linked list storing the type variable environment. A new jl_varbinding_t
@@ -131,37 +135,111 @@ static jl_varbinding_t *lookup(jl_stenv_t *e, jl_tvar_t *v) JL_GLOBALLY_ROOTED J
}
#endif

// union-stack tools

// Return decision bit `i` (0 or 1) of the union stack `st`.
// NOTE(review): this diff rendering interleaves the old body (flat array
// indexing via `st->stack[...]`, first assert and first return) with the new
// one (segment walk) - confirm which lines are current against the repository.
static int statestack_get(jl_unionstate_t *st, int i) JL_NOTSAFEPOINT
{
assert(i >= 0 && i < sizeof(st->stack) * 8);
assert(i >= 0 && i <= 32767); // limited by the depth bit.
// get the `i`th bit in an array of 32-bit words
return (st->stack[i>>5] & (1u<<(i&31))) != 0;
jl_bits_stack_t *stack = &st->stack;
// Walk the segment chain, rebasing `i` so it is relative to the current segment.
while (i >= sizeof(stack->data) * 8) {
// We should have set this bit.
assert(stack->next);
stack = stack->next;
i -= sizeof(stack->data) * 8;
}
// `i` now indexes within this segment: word i>>5, bit i&31.
return (stack->data[i>>5] & (1u<<(i&31))) != 0;
}

// Set decision bit `i` of the union stack `st` to 1 if `val` is nonzero, else to 0.
// Segments beyond the embedded first one are allocated on demand.
// NOTE(review): this diff rendering interleaves the old body (the
// `st->stack[...]` lines and the first assert) with the new one (the
// `stack->data[...]` lines) - confirm which lines are current against the repository.
static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
{
assert(i >= 0 && i < sizeof(st->stack) * 8);
assert(i >= 0 && i <= 32767); // limited by the depth bit.
jl_bits_stack_t *stack = &st->stack;
// Walk (and grow) the segment chain until `i` falls inside the current segment.
while (i >= sizeof(stack->data) * 8) {
if (__unlikely(stack->next == NULL)) {
// NOTE(review): the malloc result is dereferenced without a NULL check, and
// only `next` is initialized here; the new segment's `data` is left as allocated.
stack->next = (jl_bits_stack_t *)malloc(sizeof(jl_bits_stack_t));
stack->next->next = NULL;
}
stack = stack->next;
i -= sizeof(stack->data) * 8;
}
if (val)
st->stack[i>>5] |= (1u<<(i&31));
stack->data[i>>5] |= (1u<<(i&31));
else
st->stack[i>>5] &= ~(1u<<(i&31));
stack->data[i>>5] &= ~(1u<<(i&31));
}

// Evaluates to 1 when the selected union stack of `e` (`Runions` if `R` is
// nonzero, `Lunions` otherwise) has a nonzero `more` field, and to 0 otherwise.
#define has_next_union_state(e, R) \
    (((R) ? (e)->Runions.more : (e)->Lunions.more) != 0)

// Move the selected union stack (`Runions` if `R` is nonzero, else `Lunions`)
// on to its next combination of decisions.
// Returns 0 when `more` is 0 (nothing left to try). Otherwise truncates the
// stack to `more` bits, sets the last of those bits to 1, and returns 1.
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    jl_unionstate_t *state = &e->Lunions;
    if (R)
        state = &e->Runions;
    if (!state->more)
        return 0;
    // Shrink `used` back to the deepest open choice; anything above it is
    // rewritten later by `pick_union_decision`.
    state->used = state->more;
    statestack_set(state, state->more - 1, 1);
    return 1;
}

#define push_unionstate(saved, src) \
do { \
(saved)->depth = (src)->depth; \
(saved)->more = (src)->more; \
(saved)->used = (src)->used; \
(saved)->stack = alloca(((src)->used+7)/8); \
memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8); \
// Read the decision bit at the current depth of the selected union stack
// (`Runions` if `R` is nonzero, else `Lunions`) and advance the depth by one.
// If that position has not been pushed yet, a 0 bit is pushed first.
// Whenever the bit read is 0, `more` is set to the new depth.
// Returns the bit that was read (0 or 1).
static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    jl_unionstate_t *state = &e->Lunions;
    if (R)
        state = &e->Runions;
    int pos = state->depth;
    if (pos >= state->used) {
        // New decision point: start with choice 0.
        statestack_set(state, state->used, 0);
        state->used++;
    }
    int choice = statestack_get(state, pos);
    state->depth++;
    if (!choice)
        state->more = state->depth; // memorize that this was the deepest available choice
    return choice;
}

// Descend through the union type `u`, taking component `b` when
// pick_union_decision returns nonzero and component `a` otherwise, until the
// result is no longer a union type. `u` is treated as a union on entry.
// Returns the selected non-union component.
static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    do {
        jl_uniontype_t *un = (jl_uniontype_t*)u;
        u = pick_union_decision(e, R) ? un->b : un->a;
    } while (jl_is_uniontype(u));
    return u;
}

// Snapshot the union state `src` into `saved`.
// Copies depth/more/used, then gathers the first (used+7)/8 bytes of decision
// bits from the chained segments of `src` into one contiguous buffer obtained
// with alloca (so the buffer belongs to the frame of the function that
// expands this macro).
#define push_unionstate(saved, src)                                          \
    do {                                                                     \
        (saved)->depth = (src)->depth;                                       \
        (saved)->more = (src)->more;                                         \
        (saved)->used = (src)->used;                                         \
        jl_bits_stack_t *push_seg = &(src)->stack;                           \
        int push_nbytes = ((saved)->used + 7) / 8;                           \
        (saved)->stack = (uint8_t *)alloca(push_nbytes);                     \
        int push_off = 0;                                                    \
        while (push_off < push_nbytes) {                                     \
            assert(push_seg != NULL);                                        \
            int push_len = push_nbytes - push_off;                           \
            if (push_len > sizeof(push_seg->data))                           \
                push_len = sizeof(push_seg->data);                           \
            memcpy((saved)->stack + push_off, push_seg->data, push_len);     \
            push_off += sizeof(push_seg->data);                              \
            push_seg = push_seg->next;                                       \
        }                                                                    \
    } while (0);

#define pop_unionstate(dst, saved) \
do { \
(dst)->depth = (saved)->depth; \
(dst)->more = (saved)->more; \
(dst)->used = (saved)->used; \
memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8); \
// Restore the union state `dst` from the snapshot `saved`.
// Copies depth/more/used back, then scatters the (used+7)/8 saved bytes of
// decision bits into the chained segments of `dst`, one segment at a time.
// The segments needed must already exist in `dst` (asserted).
#define pop_unionstate(dst, saved)                                           \
    do {                                                                     \
        (dst)->depth = (saved)->depth;                                       \
        (dst)->more = (saved)->more;                                         \
        (dst)->used = (saved)->used;                                         \
        jl_bits_stack_t *pop_seg = &(dst)->stack;                            \
        int pop_nbytes = ((saved)->used + 7) / 8;                            \
        int pop_off = 0;                                                     \
        while (pop_off < pop_nbytes) {                                       \
            assert(pop_seg != NULL);                                         \
            int pop_len = pop_nbytes - pop_off;                              \
            if (pop_len > sizeof(pop_seg->data))                             \
                pop_len = sizeof(pop_seg->data);                             \
            memcpy(pop_seg->data, (saved)->stack + pop_off, pop_len);        \
            pop_off += sizeof(pop_seg->data);                                \
            pop_seg = pop_seg->next;                                         \
        }                                                                    \
    } while (0);

static int current_env_length(jl_stenv_t *e)
@@ -264,6 +342,18 @@ static void free_env(jl_savedenv_t *se) JL_NOTSAFEPOINT
se->buf = NULL;
}

// Release the heap-allocated overflow segments of both union-decision stacks
// (`Lunions` and `Runions`) of `e`. The first segment of each stack is embedded
// in `e` itself and is not freed.
// Each chain head is reset to NULL before its segments are freed, so a repeated
// call, or any later use of `e`, never follows a pointer to freed memory.
static void free_stenv(jl_stenv_t *e) JL_NOTSAFEPOINT
{
    jl_unionstate_t *states[2] = { &e->Lunions, &e->Runions };
    for (int R = 0; R < 2; R++) {
        jl_bits_stack_t *seg = states[R]->stack.next;
        // Detach the chain first so the environment no longer references it.
        states[R]->stack.next = NULL;
        while (seg != NULL) {
            jl_bits_stack_t *next = seg->next;
            free(seg);
            seg = next;
        }
    }
}

static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPOINT
{
jl_value_t **roots = NULL;
@@ -587,44 +677,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)

static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

// Nonzero when the selected union stack of `e` (`Runions` if `R` is nonzero,
// `Lunions` otherwise) has a nonzero `more` field.
#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)

// Move the selected union stack (`Runions` if `R` is nonzero, else `Lunions`)
// on to its next combination of decisions.
// Returns 0 when `more` is 0; otherwise truncates the stack to `more` bits,
// sets the last of those bits to 1, and returns 1.
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
if (state->more == 0)
return 0;
// reset `used` and let `pick_union_decision` clean the stack.
state->used = state->more;
// Bit `more - 1` is the deepest decision that was read as 0; switch it to 1.
statestack_set(state, state->used - 1, 1);
return 1;
}

// Read the decision bit at the current depth of the selected union stack
// (`Runions` if `R` is nonzero, else `Lunions`) and advance the depth by one.
// Returns the bit that was read (0 or 1).
static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
// Position not pushed yet: push a 0 bit for it.
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
state->more = state->depth; // memorize that this was the deepest available choice
return ui;
}

// Descend through the union type `u`, taking component `b` when
// pick_union_decision returns nonzero and component `a` otherwise, until the
// result is no longer a union type. `u` is treated as a union on entry.
// Returns the selected non-union component.
static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
do {
if (pick_union_decision(e, R))
u = ((jl_uniontype_t*)u)->b;
else
u = ((jl_uniontype_t*)u)->a;
} while (jl_is_uniontype(u));
return u;
}

static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);

// subtype for variable bounds consistency check. needs its own forall/exists environment.
@@ -1513,37 +1565,12 @@ static int is_definite_length_tuple_type(jl_value_t *x)
return k == JL_VARARG_NONE || k == JL_VARARG_INT;
}

static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore);

// NOTE(review): this span is a diff rendering without +/- markers. It appears to
// interleave the removed `may_contain_union_decision` with the new
// `is_exists_typevar`. Presumably the new function consists of the
// `!jl_is_typevar(x)` check, the second `return 0;`, and the final
// `lookup` / `vb->right` lines - confirm against the repository.
// Read that way, is_exists_typevar returns nonzero only when `x` is a type
// variable that has a binding in `e` whose `right` field is nonzero.
static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT
static int is_exists_typevar(jl_value_t *x, jl_stenv_t *e)
{
if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type)
return 0;
if (jl_is_unionall(x))
return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log);
if (jl_is_datatype(x)) {
jl_datatype_t *xd = (jl_datatype_t *)x;
for (int i = 0; i < jl_nparams(xd); i++) {
jl_value_t *param = jl_tparam(xd, i);
if (jl_is_vararg(param))
param = jl_unwrap_vararg(param);
if (may_contain_union_decision(param, e, log))
return 1;
}
return 0;
}
if (!jl_is_typevar(x))
return jl_is_type(x);
jl_typeenv_t *t = log;
while (t != NULL) {
if (x == (jl_value_t *)t->var)
return 1;
t = t->prev;
}
jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log };
jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x);
return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) ||
may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog);
return 0;
jl_varbinding_t *vb = lookup(e, (jl_tvar_t *)x);
return vb && vb->right;
}

static int has_exists_typevar(jl_value_t *x, jl_stenv_t *e) JL_NOTSAFEPOINT
@@ -1574,31 +1601,9 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t
int kindy = !jl_has_free_typevars(y);
if (kindx && kindy)
return jl_subtype(x, y);
if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) {
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Lunions.used = e->Runions.used = 0;
e->Lunions.depth = e->Runions.depth = 0;
e->Lunions.more = e->Runions.more = 0;
int count = 0, noRmore = 0;
sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore);
pop_unionstate(&e->Runions, &oldRunions);
// We could skip the slow path safely if
// 1) `_∀_∃_subtype` has tested all cases
// 2) `_∀_∃_subtype` returns 1 && `x` and `y` contain no ∃ typevar
// Once `limit_slow == 1`, also skip it if
// 1) `_∀_∃_subtype` returns 0
// 2) the left `Union` looks big
// TODO: `limit_slow` ignores complexity from inner `local_∀_exists_subtype`.
if (limit_slow == -1)
limit_slow = kindx || kindy;
int skip = noRmore || (limit_slow && (count > 3 || !sub)) ||
(sub && (kindx || !has_exists_typevar(x, e)) &&
(kindy || !has_exists_typevar(y, e)));
if (skip)
e->Runions.more = oldRmore;
}
else {
// slow path
int has_exists = (!kindx && has_exists_typevar(x, e)) ||
(!kindy && has_exists_typevar(y, e));
if (has_exists && (is_exists_typevar(x, e) != is_exists_typevar(y, e))) {
e->Lunions.used = 0;
while (1) {
e->Lunions.more = 0;
@@ -1607,7 +1612,51 @@ static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t
if (!sub || !next_union_state(e, 0))
break;
}
return sub;
}
if (limit_slow == -1)
limit_slow = kindx || kindy;
jl_savedenv_t se;
save_env(e, &se, has_exists);
int count, limited = 0, ini_count = 0;
jl_saved_unionstate_t latestLunions = {0, 0, 0, NULL};
while (1) {
count = ini_count;
if (ini_count == 0)
e->Lunions.used = 0;
else
pop_unionstate(&e->Lunions, &latestLunions);
while (1) {
e->Lunions.more = 0;
e->Lunions.depth = 0;
if (count < 4) count++;
sub = subtype(x, y, e, param);
if (limit_slow && count == 4)
limited = 1;
if (!sub || !next_union_state(e, 0))
break;
if (limited || !has_exists || e->Runions.more == oldRmore) {
// re-save env and freeze the ∃decision for previous ∀Union
// Note: We could ignore the rest `∃Union` decisions if `x` and `y`
// contain no ∃ typevar, as they have no effect on env.
ini_count = count;
push_unionstate(&latestLunions, &e->Lunions);
re_save_env(e, &se, has_exists);
e->Runions.more = oldRmore;
}
}
if (sub || e->Runions.more == oldRmore)
break;
assert(e->Runions.more > oldRmore);
next_union_state(e, 1);
restore_env(e, &se, has_exists); // also restore Rdepth here
e->Runions.more = oldRmore;
}
if (!sub)
assert(e->Runions.more == oldRmore);
else if (limited || !has_exists)
e->Runions.more = oldRmore;
free_env(&se);
return sub;
}

@@ -1677,7 +1726,7 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_savede
}
}

static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore)
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
{
// The depth recursion has the following shape, after simplification:
// ∀₁
@@ -1689,12 +1738,8 @@ static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, i

e->Lunions.used = 0;
int sub;
if (count) *count = 0;
if (noRmore) *noRmore = 1;
while (1) {
sub = exists_subtype(x, y, e, &se, param);
if (count) *count = (*count < 4) ? *count + 1 : 4;
if (noRmore) *noRmore = *noRmore && e->Runions.more == 0;
if (!sub || !next_union_state(e, 0))
break;
re_save_env(e, &se, 1);
@@ -1704,11 +1749,6 @@ static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, i
return sub;
}

// Thin wrapper around _forall_exists_subtype that passes NULL for both the
// `count` and `noRmore` output pointers and returns its result unchanged.
static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
{
return _forall_exists_subtype(x, y, e, param, NULL, NULL);
}

static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
{
e->vars = NULL;
@@ -1728,6 +1768,8 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
e->Lunions.used = 0; e->Runions.used = 0;
e->Lunions.stack.next = NULL;
e->Runions.stack.next = NULL;
}

// subtyping entry points
@@ -2157,6 +2199,7 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env,
}
init_stenv(&e, env, envsz);
int subtype = forall_exists_subtype(x, y, &e, 0);
free_stenv(&e);
assert(obvious_subtype == 3 || obvious_subtype == subtype || jl_has_free_typevars(x) || jl_has_free_typevars(y));
#ifndef NDEBUG
if (obvious_subtype == 0 || (obvious_subtype == 1 && envsz == 0))
@@ -2249,6 +2292,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
{
init_stenv(&e, NULL, 0);
int subtype = forall_exists_subtype(a, b, &e, 0);
free_stenv(&e);
assert(subtype_ab == 3 || subtype_ab == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b));
#ifndef NDEBUG
if (subtype_ab != 0 && subtype_ab != 1) // ensures that running in a debugger doesn't change the result
@@ -2265,6 +2309,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
{
init_stenv(&e, NULL, 0);
int subtype = forall_exists_subtype(b, a, &e, 0);
free_stenv(&e);
assert(subtype_ba == 3 || subtype_ba == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b));
#ifndef NDEBUG
if (subtype_ba != 0 && subtype_ba != 1) // ensures that running in a debugger doesn't change the result
@@ -4230,7 +4275,9 @@ static jl_value_t *intersect_types(jl_value_t *x, jl_value_t *y, int emptiness_o
init_stenv(&e, NULL, 0);
e.intersection = e.ignore_free = 1;
e.emptiness_only = emptiness_only;
return intersect_all(x, y, &e);
jl_value_t *ans = intersect_all(x, y, &e);
free_stenv(&e);
return ans;
}

JL_DLLEXPORT jl_value_t *jl_intersect_types(jl_value_t *x, jl_value_t *y)
@@ -4407,6 +4454,7 @@ jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t *
memset(env, 0, szb*sizeof(void*));
e.envsz = szb;
*ans = intersect_all(a, b, &e);
free_stenv(&e);
if (*ans == jl_bottom_type) goto bot;
// TODO: code dealing with method signatures is not able to handle unions, so if
// `a` and `b` are both tuples, we need to be careful and may not return a union,

0 comments on commit 6b3921f

Please sign in to comment.