Skip to content

Commit

Permalink
Finish bringing move_median back
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-riazanov committed Sep 21, 2022
1 parent 02a0ce1 commit cd49b4f
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 25 deletions.
54 changes: 33 additions & 21 deletions bottleneck/src/move_median/move_median.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ mm_update_init(mm_handle *mm, ai_t ai) {
node->idx = n_l;
++mm->n_l;
mm->l_first_leaf = FIRST_LEAF(mm->n_l);
heapify_large_node(mm, n_l);
heapify_large_node_mm(mm, n_l);
} else {
/* add new node to small heap */
mm->s_heap[n_s] = node;
Expand Down Expand Up @@ -179,7 +179,7 @@ mq_update_init(mq_handle *mq, ai_t ai) {
node->idx = n_l;
++mq->n_l;
mq->l_first_leaf = FIRST_LEAF(mq->n_l);
heapify_large_node(mq, n_l);
heapify_large_node_mq(mq, n_l);
} else {
/* add new node to small heap */
mq->s_heap[n_s] = node;
Expand Down Expand Up @@ -216,7 +216,7 @@ mm_update(mm_handle *mm, ai_t ai) {
if (node->region == SH) {
heapify_small_node_mm(mm, node->idx);
} else {
heapify_large_node(mm, node->idx);
heapify_large_node_mm(mm, node->idx);
}

/* return the median */
Expand All @@ -242,7 +242,7 @@ mq_update(mq_handle *mq, ai_t ai) {
if (node->region == SH) {
heapify_small_node_mq(mq, node->idx);
} else {
heapify_large_node(mq, node->idx);
heapify_large_node_mq(mq, node->idx);
}

/* return the median */
Expand Down Expand Up @@ -350,7 +350,7 @@ mm_update_init_nan(mm_handle *mm, ai_t ai) {
node->idx = n_l;
++mm->n_l;
mm->l_first_leaf = FIRST_LEAF(mm->n_l);
heapify_large_node(mm, n_l);
heapify_large_node_mm(mm, n_l);
} else {
/* add new node to small heap */
mm->s_heap[n_s] = node;
Expand Down Expand Up @@ -416,7 +416,7 @@ mq_update_init_nan(mq_handle *mq, ai_t ai) {
node->idx = n_l;
++mq->n_l;
mq->l_first_leaf = FIRST_LEAF(mq->n_l);
heapify_large_node(mq, n_l);
heapify_large_node_mq(mq, n_l);
} else {
/* add new node to small heap */
mq->s_heap[n_s] = node;
Expand Down Expand Up @@ -499,7 +499,7 @@ mm_update_nan(mm_handle *mm, ai_t ai) {
} else {
mm->l_first_leaf = FIRST_LEAF(mm->n_l);
}
heapify_large_node(mm, 0);
heapify_large_node_mm(mm, 0);
}
} else {
if (idx != n_s - 1) {
Expand Down Expand Up @@ -527,7 +527,7 @@ mm_update_nan(mm_handle *mm, ai_t ai) {
} else {
mm->l_first_leaf = FIRST_LEAF(mm->n_l);
}
heapify_large_node(mm, 0);
heapify_large_node_mm(mm, 0);

} else {
mm->s_first_leaf = FIRST_LEAF(mm->n_s);
Expand All @@ -550,7 +550,7 @@ mm_update_nan(mm_handle *mm, ai_t ai) {
if (idx != n_l - 1) {
l_heap[idx] = l_heap[n_l - 1];
l_heap[idx]->idx = idx;
heapify_large_node(mm, idx);
heapify_large_node_mm(mm, idx);
}
--mm->n_l;
if (mm->n_l == 0) {
Expand All @@ -566,7 +566,7 @@ mm_update_nan(mm_handle *mm, ai_t ai) {
l_heap[mm->n_l] = node2;
++mm->n_l;
mm->l_first_leaf = FIRST_LEAF(mm->n_l);
heapify_large_node(mm, node2->idx);
heapify_large_node_mm(mm, node2->idx);

/* plug hole in small heap */
if (n_s != 1) {
Expand All @@ -583,7 +583,7 @@ mm_update_nan(mm_handle *mm, ai_t ai) {
heapify_small_node_mm(mm, 0);
}
/* reorder large heap if needed */
heapify_large_node(mm, idx);
heapify_large_node_mm(mm, idx);
} else if (node->region == NA) {
/* insert node into nan heap */
n_array[idx] = node;
Expand All @@ -592,7 +592,7 @@ mm_update_nan(mm_handle *mm, ai_t ai) {
if (node->region == SH) {
heapify_small_node_mm(mm, idx);
} else if (node->region == LH) {
heapify_large_node(mm, idx);
heapify_large_node_mm(mm, idx);
} else {
/* ai is not NaN but oldest node is in nan array */
if (n_s > n_l) {
Expand All @@ -602,7 +602,7 @@ mm_update_nan(mm_handle *mm, ai_t ai) {
l_heap[n_l] = node;
++mm->n_l;
mm->l_first_leaf = FIRST_LEAF(mm->n_l);
heapify_large_node(mm, n_l);
heapify_large_node_mm(mm, n_l);
} else {
/* insert into small heap */
node->region = SH;
Expand Down Expand Up @@ -689,7 +689,7 @@ mq_update_nan(mq_handle *mq, ai_t ai) {
} else {
mq->l_first_leaf = FIRST_LEAF(mq->n_l);
}
heapify_large_node(mq, 0);
heapify_large_node_mq(mq, 0);
}
} else {
if (idx != n_s - 1) {
Expand Down Expand Up @@ -717,7 +717,7 @@ mq_update_nan(mq_handle *mq, ai_t ai) {
} else {
mq->l_first_leaf = FIRST_LEAF(mq->n_l);
}
heapify_large_node(mq, 0);
heapify_large_node_mq(mq, 0);

} else {
mq->s_first_leaf = FIRST_LEAF(mq->n_s);
Expand All @@ -741,7 +741,7 @@ mq_update_nan(mq_handle *mq, ai_t ai) {
if (idx != n_l - 1) {
l_heap[idx] = l_heap[n_l - 1];
l_heap[idx]->idx = idx;
heapify_large_node(mq, idx);
heapify_large_node_mq(mq, idx);
}
--mq->n_l;
if (mq->n_l == 0) {
Expand All @@ -757,7 +757,7 @@ mq_update_nan(mq_handle *mq, ai_t ai) {
l_heap[mq->n_l] = node2;
++mq->n_l;
mq->l_first_leaf = FIRST_LEAF(mq->n_l);
heapify_large_node(mq, node2->idx);
heapify_large_node_mq(mq, node2->idx);

/* plug hole in small heap */
if (n_s != 1) {
Expand All @@ -774,7 +774,7 @@ mq_update_nan(mq_handle *mq, ai_t ai) {
heapify_small_node_mq(mq, 0);
}
/* reorder large heap if needed */
heapify_large_node(mq, idx);
heapify_large_node_mq(mq, idx);
} else if (node->region == NA) {
/* insert node into nan heap */
n_array[idx] = node;
Expand All @@ -783,7 +783,7 @@ mq_update_nan(mq_handle *mq, ai_t ai) {
if (node->region == SH) {
heapify_small_node_mq(mq, idx);
} else if (node->region == LH) {
heapify_large_node(mq, idx);
heapify_large_node_mq(mq, idx);
} else {
/* ai is not NaN but oldest node is in nan array */
k_stat = mq_k_stat(mq, n_s + n_l + 1);
Expand All @@ -795,7 +795,7 @@ mq_update_nan(mq_handle *mq, ai_t ai) {
l_heap[n_l] = node;
++mq->n_l;
mq->l_first_leaf = FIRST_LEAF(mq->n_l);
heapify_large_node(mq, n_l);
heapify_large_node_mq(mq, n_l);
} else {
/* insert into small heap */
node->region = SH;
Expand Down Expand Up @@ -866,6 +866,18 @@ mq_free(mq_handle *mq) {
-----------------------------------------------------------------------------
*/

/* Return the current median */
static inline ai_t
mm_get_median(mm_handle *mm) {
idx_t n_total = mm->n_l + mm->n_s;
if (n_total < mm->min_count)
return MM_NAN();
if (min(mm->window, n_total) % 2 == 1)
return mm->s_heap[0]->ai;
return (mm->s_heap[0]->ai + mm->l_heap[0]->ai) / 2.0;
}


/* function to find the current index of element correspodning to the quantile */
static inline idx_t mq_k_stat(mq_handle *mq, idx_t idx) {
return (idx_t) floor((idx - 1) * mq->quantile) + 1;
Expand Down Expand Up @@ -943,7 +955,7 @@ HEAPIFY_SMALL_NODE(mq)
/* mtype is mm (move_median) or mq (move_quantile) */
#define HEAPIFY_LARGE_NODE(mtype) \
static inline void \
heapify_large_node##mtype(mtype##_handle *mtype, idx_t idx) { \
heapify_large_node_##mtype(mtype##_handle *mtype, idx_t idx) { \
idx_t idx2; \
mm_node *node; \
mm_node *node2; \
Expand Down
2 changes: 1 addition & 1 deletion bottleneck/src/move_median/move_median_debug.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ai_t *mm_move_median(ai_t *a, idx_t length, idx_t window, idx_t min_count, doubl
idx_t i;

out = malloc(length * sizeof(ai_t));
mm = mm_new_nan(window, min_count, quantile);
mm = mm_new_nan(window, min_count);
for (i=0; i < length; i++) {
if (i < window) {
out[i] = mm_update_init_nan(mm, a[i]);
Expand Down
9 changes: 6 additions & 3 deletions bottleneck/tests/move_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def test_move_quantile_with_nans():
func = bn.move_quantile
func0 = bn.slow.move_quantile
rs = np.random.RandomState([1, 2, 3])
for size in [1, 2, 3, 5, 9, 10, 20, 21]:
for size in [1, 2, 3, 5, 9]:
# for size in [1, 2, 3, 5, 9, 10, 20, 21]:
for _ in range(REPEAT_QUANTILE):
# 0 and 1 are important edge cases
for q in [0., 1., rs.rand()]:
Expand All @@ -234,7 +235,8 @@ def test_move_quantile_without_nans():
func = bn.move_quantile
func0 = bn.slow.move_quantile
rs = np.random.RandomState([1, 2, 3])
for size in [1, 2, 3, 5, 9, 10, 20, 21]:
for size in [1, 2, 3, 5, 9]:
# for size in [1, 2, 3, 5, 9, 10, 20, 21]:
for _ in range(REPEAT_QUANTILE):
for q in [0., 1., rs.rand()]:
a = np.arange(size, dtype=np.float64)
Expand Down Expand Up @@ -296,7 +298,8 @@ def test_move_quantile_with_infs_and_nans():
inf_minf_nan_fracs = [triple for triple in itertools.product(fracs, fracs, fracs) if np.sum(triple) <= 1]
total = 0
# for size in [1, 2, 3, 5, 9, 10, 20, 21, 47, 48]:
for size in [1, 2, 3, 5, 9, 10, 20, 21]:
# for size in [1, 2, 3, 5, 9, 10, 20, 21]:
for size in [1, 2, 3, 5, 9]:
print(size)
for min_count in [1, 2, 3, size//2, size - 1, size]:
if min_count < 1 or min_count > size:
Expand Down

0 comments on commit cd49b4f

Please sign in to comment.