Skip to content

Commit

Permalink
Fix kmeans pytest to correctly compute fp output error (#5426)
Browse files Browse the repository at this point in the history
-- This PR is on top build fix - #5425
-- fixes the fp error checking logic and set the tolerance to account for 3xTF32 based fusedL2NN kernel

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)
  - Philip Hyunsu Cho (https://github.com/hcho3)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5426
  • Loading branch information
mdoijade authored May 18, 2023
1 parent b910401 commit dd432c6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
16 changes: 12 additions & 4 deletions cpp/src/hdbscan/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ struct FixConnectivitiesRedOp {
value_t* core_dists;
value_idx m;

DI FixConnectivitiesRedOp() : colors(0), m(0) {}

FixConnectivitiesRedOp(value_idx* colors_, value_t* core_dists_, value_idx m_)
: colors(colors_), core_dists(core_dists_), m(m_){};

typedef typename raft::KeyValuePair<value_idx, value_t> KVP;
DI void operator()(value_idx rit, KVP* out, const KVP& other)
DI void operator()(value_idx rit, KVP* out, const KVP& other) const
{
if (rit < m && other.value < std::numeric_limits<value_t>::max() &&
colors[rit] != colors[other.key]) {
Expand All @@ -78,7 +80,7 @@ struct FixConnectivitiesRedOp {
}
}

DI KVP operator()(value_idx rit, const KVP& a, const KVP& b)
DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const
{
if (rit < m && a.key > -1 && colors[rit] != colors[a.key]) {
value_t core_dist_rit = core_dists[rit];
Expand All @@ -97,12 +99,18 @@ struct FixConnectivitiesRedOp {
return b;
}

DI void init(value_t* out, value_t maxVal) { *out = maxVal; }
DI void init(KVP* out, value_t maxVal)
DI void init(value_t* out, value_t maxVal) const { *out = maxVal; }
DI void init(KVP* out, value_t maxVal) const
{
out->key = -1;
out->value = maxVal;
}

DI void init_key(value_t& out, value_idx idx) const { return; }
DI void init_key(KVP& out, value_idx idx) const { out.key = idx; }

DI value_t get_value(KVP& out) const { return out.value; }
DI value_t get_value(value_t& out) const { return out; }
};

/**
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_traditional_kmeans_plus_plus_init(
kmeans.fit(cp.asnumpy(X))
sk_score = kmeans.score(cp.asnumpy(X))

assert abs(cu_score - sk_score) <= cluster_std * 1.5
cp.testing.assert_allclose(cu_score, sk_score, atol=0.1, rtol=1e-4)


@pytest.mark.parametrize("nrows", [100, 500])
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_score(nrows, ncols, nclusters, random_state):
expected_score *= -1

cp.testing.assert_allclose(
actual_score, expected_score, atol=0.1, rtol=1e-5
actual_score, expected_score, atol=0.1, rtol=1e-4
)


Expand Down

0 comments on commit dd432c6

Please sign in to comment.