Skip to content

Commit

Permalink
fix remap csr; add test
Browse files Browse the repository at this point in the history
  • Loading branch information
JelmerBot committed Jan 10, 2025
1 parent a0914b6 commit d80a67f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
22 changes: 16 additions & 6 deletions fast_hbcc/hbcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from scipy.sparse import csr_array
from sklearn.utils import check_array, check_X_y
from sklearn.utils.validation import check_is_fitted, _check_sample_weight
from fast_hdbscan.core_graph import core_graph_clusters, core_graph_to_edge_list
from fast_hdbscan.core_graph import (
core_graph_clusters,
core_graph_to_edge_list,
CoreGraph,
)
from fast_hdbscan.hdbscan import (
HDBSCAN,
to_numpy_rec_array,
Expand Down Expand Up @@ -43,12 +47,18 @@ def check_literals(**kwargs):

def remap_csr_graph(graph, finite_index, internal_to_raw, num_points):
new_indptr = np.empty(num_points + 1, dtype=graph.indptr.dtype)
new_indptr[0] = 0
new_indptr[: finite_index[0] + 1] = 0
for idx, (start, end) in enumerate(zip(finite_index, finite_index[1:])):
new_indptr[start + 1 : end + 1] = graph.indptr[idx + 1]

start = end
new_indptr[finite_index[-1] + 1 :] = graph.indptr[-1]
graph.indices[:] = np.vectorize(internal_to_raw.get)(graph.indices)
return graph
return new_indptr


def remap_core_graph(graph, finite_index, internal_to_raw, num_points):
new_indptr = remap_csr_graph(graph, finite_index, internal_to_raw, num_points)
return CoreGraph(graph.weights, graph.distances, graph.indices, new_indptr)


def boundary_coefficient_from_csr(g):
Expand Down Expand Up @@ -403,10 +413,10 @@ def fit(self, X, y=None, sample_weight=None, **fit_params):
self._single_linkage_tree = remap_single_linkage_tree(
self._single_linkage_tree, internal_to_raw, outliers
)
self._core_graph = remap_csr_graph(
self._core_graph = remap_core_graph(
self._core_graph, finite_index, internal_to_raw, X.shape[0]
)

new_labels = np.full(X.shape[0], -1)
new_labels[finite_index] = self.labels_
self.labels_ = new_labels
Expand Down
6 changes: 5 additions & 1 deletion fast_hbcc/tests/test_hbcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def test_missing_data():
assert model.probabilities_[5] == 0
clean_indices = list(range(1, 5)) + list(range(6, 200))
clean_model = HBCC().fit(X_missing_data[clean_indices])
assert np.all(
clean_model._core_graph.indptr[1:]
== model._core_graph.indptr[np.array(clean_indices) + 1]
)
assert np.allclose(clean_model.labels_, model.labels_[clean_indices])


Expand Down Expand Up @@ -174,7 +178,7 @@ def test_hbcc_persistence_threshold():
assert np.all(model.labels_ == -1)


@pytest.mark.skipif(not HAVE_HDBSCAN, reason='requires HDBSCAN')
@pytest.mark.skipif(not HAVE_HDBSCAN, reason="requires HDBSCAN")
def test_attributes():
c = HBCC().fit(X)
assert isinstance(c.condensed_tree_, CondensedTree)
Expand Down

0 comments on commit d80a67f

Please sign in to comment.