Skip to content

Commit

Permalink
Allow kmeans with no iterations (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
thempel authored Jan 12, 2022
1 parent 6c0d297 commit 86eb11b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
44 changes: 23 additions & 21 deletions deeptime/src/include/deeptime/clustering/bits/kmeans_bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,29 +153,31 @@ inline std::tuple<np_array_nfc<T>, int, int, np_array<T>> cluster_loop(
auto currentCenters = np_centers;

std::vector<T> inertias;
inertias.reserve(max_iter);

do {
auto clusterResult = cluster<Metric>(np_chunk, currentCenters, n_threads);
currentCenters = std::get<0>(clusterResult);
const auto &assignments = std::get<1>(clusterResult);
auto cost = costFunction<Metric>(np_chunk, currentCenters, assignments, n_threads);
inertias.push_back(cost);
rel_change = (cost != 0.0) ? std::abs(cost - prev_cost) / cost : 0;
prev_cost = cost;
if (rel_change <= tolerance) {
converged = true;
} else {
if (!callback.is_none()) {
/* Acquire GIL before calling Python code */
py::gil_scoped_acquire acquire;
callback();
if (max_iter > 0) {
inertias.reserve(max_iter);

do {
auto clusterResult = cluster<Metric>(np_chunk, currentCenters, n_threads);
currentCenters = std::get<0>(clusterResult);
const auto &assignments = std::get<1>(clusterResult);
auto cost = costFunction<Metric>(np_chunk, currentCenters, assignments, n_threads);
inertias.push_back(cost);
rel_change = (cost != 0.0) ? std::abs(cost - prev_cost) / cost : 0;
prev_cost = cost;
if (rel_change <= tolerance) {
converged = true;
} else {
if (!callback.is_none()) {
/* Acquire GIL before calling Python code */
py::gil_scoped_acquire acquire;
callback();
}
}
}

it += 1;
} while (it < max_iter && !converged);
int res = converged ? 0 : 1;
it += 1;
} while (it < max_iter && !converged);
}
int res = max_iter <= 0 || converged ? 0 : 1;
np_array<T> npInertias({static_cast<py::ssize_t>(inertias.size())});
std::copy(inertias.begin(), inertias.end(), npInertias.mutable_data());
return std::make_tuple(currentCenters, res, it, npInertias);
Expand Down
8 changes: 8 additions & 0 deletions tests/clustering/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,14 @@ def callback_loop():
assert init == 3
assert iter <= 2

def test_noiter(self):
initial_centers = np.array([[0, 0, 0], [1, 1, 1]]).astype(np.float32)
X = np.random.rand(100, 3)
kmeans, model = cluster_kmeans(X, k=2, cluster_centers=initial_centers, max_iter=0, n_jobs=1)

np.testing.assert_(model.converged)
np.testing.assert_array_equal(initial_centers, model.cluster_centers)


class TestKmeansResume(unittest.TestCase):

Expand Down

0 comments on commit 86eb11b

Please sign in to comment.