diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 40e6728dbe..e93368fa3c 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -818,9 +818,9 @@ void initScalableKMeansPlusPlus(raft::device_resources const& handle, template void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { @@ -982,10 +982,10 @@ void kmeans_fit(raft::device_resources const& handle, template void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, bool normalize_weight, raft::host_scalar_view inertia) { @@ -1122,10 +1122,10 @@ void kmeans_predict(raft::device_resources const& handle, template void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { diff --git a/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh b/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh new file mode 100644 index 0000000000..edc74a085f --- /dev/null +++ b/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#include + +#include + +#include +#include + +namespace raft::cluster::detail { + +template +void compute_dispersion(raft::device_resources const& handle, + raft::device_matrix_view X, + KMeansParams& params, + raft::device_matrix_view centroids_view, + raft::device_vector_view labels, + raft::device_vector_view clusterSizes, + rmm::device_uvector& workspace, + raft::host_vector_view clusterDispertionView, + raft::host_vector_view resultsView, + raft::host_scalar_view residual, + raft::host_scalar_view n_iter, + int val, + idx_t n, + idx_t d) +{ + auto centroids_const_view = + raft::make_device_matrix_view(centroids_view.data_handle(), val, d); + + idx_t* clusterSizes_ptr = clusterSizes.data_handle(); + auto cluster_sizes_view = + raft::make_device_vector_view(clusterSizes_ptr, val); + + params.n_clusters = val; + + raft::cluster::detail::kmeans_fit_predict( + handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter); + + detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace); + + resultsView[val] = residual[0]; + clusterDispertionView[val] = raft::stats::cluster_dispersion( + handle, centroids_const_view, cluster_sizes_view, std::nullopt, n); +} + +template +void find_k(raft::device_resources const& handle, + raft::device_matrix_view X, + raft::host_scalar_view best_k, + raft::host_scalar_view residual, + raft::host_scalar_view n_iter, + idx_t kmax, + idx_t kmin = 1, + idx_t maxiter = 100, + value_t tol = 1e-2) +{ + idx_t n = X.extent(0); + idx_t d = X.extent(1); + + RAFT_EXPECTS(n >= 1, "n must be >= 1"); + RAFT_EXPECTS(d >= 1, "d must be >= 1"); + RAFT_EXPECTS(kmin >= 1, "kmin must be >= 1"); + RAFT_EXPECTS(kmax <= n, "kmax must be <= number of data samples in X"); + RAFT_EXPECTS(tol >= 0, "tolerance must be >= 0"); + RAFT_EXPECTS(maxiter >= 0, "maxiter must be >= 0"); + // Allocate memory + // Device memory + + auto centroids = raft::make_device_matrix(handle, kmax, X.extent(1)); + auto clusterSizes = raft::make_device_vector(handle, kmax); + auto labels = raft::make_device_vector(handle, n); + + rmm::device_uvector workspace(0, handle.get_stream()); + + idx_t* clusterSizes_ptr = clusterSizes.data_handle(); + + // Host memory + auto results = raft::make_host_vector(kmax + 1); + auto clusterDispersion = raft::make_host_vector(kmax + 1); + + auto clusterDispertionView = clusterDispersion.view(); + auto resultsView = results.view(); + + // Loop to find *best* k + // Perform k-means in binary search + int left = kmin; // must be at least 2 + int right = kmax; // int(floor(len(data)/2)) #assumption of clusters of size 2 at least + int mid = ((unsigned int)left + (unsigned int)right) >> 1; + int oldmid = mid; + int tests = 0; + double objective[3]; // 0= left of mid, 1= right of mid + if (left == 1) left = 2; // at least do 2 clusters + + KMeansParams params; + params.max_iter = maxiter; + params.tol = tol; + + auto centroids_view = + raft::make_device_matrix_view(centroids.data_handle(), left, d); + compute_dispersion(handle, + X, + params, + centroids_view, + labels.view(), + clusterSizes.view(), + workspace, + clusterDispertionView, + resultsView, + residual, + n_iter, + left, + n, + d); + + // eval right edge0 + resultsView[right] = 1e20; + while (resultsView[right] > resultsView[left] && tests < 3) { + centroids_view = + raft::make_device_matrix_view(centroids.data_handle(), right, d); + compute_dispersion(handle, + X, + params, + centroids_view, + labels.view(), + clusterSizes.view(), + workspace, + clusterDispertionView, + resultsView, + residual, + n_iter, + right, + n, + d); + + tests += 1; + } + + objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; + objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; + while (left < right - 1) { + resultsView[mid] = 1e20; + tests = 0; + while (resultsView[mid] > resultsView[left] && tests < 3) { + centroids_view = + raft::make_device_matrix_view(centroids.data_handle(), mid, d); + compute_dispersion(handle, + X, + params, + centroids_view, + labels.view(), + clusterSizes.view(), + workspace, + clusterDispertionView, + resultsView, + residual, + n_iter, + mid, + n, + d); + + if (resultsView[mid] > resultsView[left] && (mid + 1) < right) { + mid += 1; + resultsView[mid] = 1e20; + } else if (resultsView[mid] > resultsView[left] && (mid - 1) > left) { + mid -= 1; + resultsView[mid] = 1e20; + } + tests += 1; + } + + // maximize Calinski-Harabasz Index, minimize resid/ cluster + objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; + objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; + objective[2] = (n - mid) / (mid - 1) * clusterDispertionView[mid] / resultsView[mid]; + objective[0] = (objective[2] - objective[0]) / (mid - left); + objective[1] = (objective[1] - objective[2]) / (right - mid); + + if (objective[0] > 0 && objective[1] < 0) { + // our point is in the left-of-mid side + right = mid; + } else { + left = mid; + } + oldmid = mid; + mid = ((unsigned int)right + (unsigned int)left) >> 1; + } + + best_k[0] = right; + objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; + objective[1] = (n - oldmid) / (oldmid - 1) * clusterDispertionView[oldmid] / resultsView[oldmid]; + if (objective[1] < objective[0]) { best_k[0] = left; } + + // if best_k isn't what we just ran, re-run to get correct centroids and dist data on return-> + // this saves memory + if (best_k[0] != oldmid) { + auto centroids_view = + raft::make_device_matrix_view(centroids.data_handle(), best_k[0], d); + + params.n_clusters = best_k[0]; + raft::cluster::detail::kmeans_fit_predict(handle, + params, + X, + std::nullopt, + std::make_optional(centroids_view), + labels.view(), + residual, + n_iter); + } +} +} // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index ac9e66d5da..da5f0458ad 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -261,6 +262,60 @@ void transform(raft::device_resources const& handle, handle, params, X, centroids, n_samples, n_features, X_new); } +/** + * Automatically find the optimal value of k using a binary search. + * This method maximizes the Calinski-Harabasz Index while minimizing the per-cluster inertia. + * + * @code{.cpp} + * #include + * #include + * #include + * + * #include + * + * using namespace raft::cluster; + * + * raft::handle_t handle; + * int n_samples = 100, n_features = 15, n_clusters = 10; + * auto X = raft::make_device_matrix(handle, n_samples, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * + * raft::random::make_blobs(handle, X, labels, n_clusters); + * + * auto best_k = raft::make_host_scalar(0); + * auto n_iter = raft::make_host_scalar(0); + * auto inertia = raft::make_host_scalar(0); + * + * kmeans::find_k(handle, X, best_k.view(), inertia.view(), n_iter.view(), n_clusters+1); + * + * @endcode + * + * @tparam idx_t indexing type (should be integral) + * @tparam value_t value type (should be floating point) + * @param handle raft handle + * @param X input observations (shape n_samples, n_dims) + * @param best_k best k found from binary search + * @param inertia inertia of best k found + * @param n_iter number of iterations used to find best k + * @param kmax maximum k to try in search + * @param kmin minimum k to try in search (should be >= 1) + * @param maxiter maximum number of iterations to run + * @param tol tolerance for early stopping convergence + */ +template +void find_k(raft::device_resources const& handle, + raft::device_matrix_view X, + raft::host_scalar_view best_k, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, + idx_t kmax, + idx_t kmin = 1, + idx_t maxiter = 100, + value_t tol = 1e-3) +{ + detail::find_k(handle, X, best_k, inertia, n_iter, kmax, kmin, maxiter, tol); +} + /** * @brief Select centroids according to a sampling operation * diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 575e8cf84b..4b633864a3 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -78,8 +78,17 @@ endfunction() if(BUILD_TESTS) ConfigureTest( - NAME CLUSTER_TEST PATH test/cluster/kmeans.cu test/cluster/kmeans_balanced.cu - test/cluster/cluster_solvers.cu test/cluster/linkage.cu OPTIONAL DIST NN + NAME + CLUSTER_TEST + PATH + test/cluster/kmeans.cu + test/cluster/kmeans_balanced.cu + test/cluster/cluster_solvers.cu + test/cluster/linkage.cu + test/cluster/kmeans_find_k.cu + OPTIONAL + DIST + NN ) ConfigureTest( diff --git a/cpp/test/cluster/kmeans_find_k.cu b/cpp/test/cluster/kmeans_find_k.cu new file mode 100644 index 0000000000..5ac4ebd293 --- /dev/null +++ b/cpp/test/cluster/kmeans_find_k.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include + +#include +#include +#include +#include +#include + +#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED +#include +#endif + +namespace raft { + +template +struct KmeansFindKInputs { + int n_row; + int n_col; + int n_clusters; + T tol; + bool weighted; +}; + +template +class KmeansFindKTest : public ::testing::TestWithParam> { + protected: + KmeansFindKTest() : stream(handle.get_stream()), best_k(raft::make_host_scalar(0)) {} + + void basicTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + int n_clusters = testparams.n_clusters; + + auto X = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + + raft::random::make_blobs(X.data_handle(), + labels.data_handle(), + n_samples, + n_features, + n_clusters, + stream, + true, + nullptr, + nullptr, + T(.001), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + + auto inertia = raft::make_host_scalar(0); + auto n_iter = raft::make_host_scalar(0); + + auto X_view = + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); + + raft::cluster::kmeans::find_k( + handle, X_view, best_k.view(), inertia.view(), n_iter.view(), n_clusters); + + handle.sync_stream(stream); + } + + void SetUp() override { basicTest(); } + + protected: + raft::device_resources handle; + cudaStream_t stream; + KmeansFindKInputs testparams; + raft::host_scalar best_k; +}; + +const std::vector> inputsf2 = {{1000, 32, 8, 0.001f, true}, + {1000, 32, 8, 0.001f, false}, + {1000, 100, 20, 0.001f, true}, + {1000, 100, 20, 0.001f, false}, + {10000, 32, 10, 0.001f, true}, + {10000, 32, 10, 0.001f, false}, + {10000, 100, 50, 0.001f, true}, + {10000, 100, 50, 0.001f, false}, + {10000, 500, 100, 0.001f, true}, + {10000, 500, 100, 0.001f, false}}; + +const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}, + {10000, 500, 100, 0.0001, true}, + {10000, 500, 100, 0.0001, false}}; + +typedef KmeansFindKTest KmeansFindKTestF; +TEST_P(KmeansFindKTestF, Result) +{ + if (best_k.view()[0] != testparams.n_clusters) { + std::cout << best_k.view()[0] << " " << testparams.n_clusters << std::endl; + } + ASSERT_TRUE(best_k.view()[0] == testparams.n_clusters); +} + +typedef KmeansFindKTest KmeansFindKTestD; +TEST_P(KmeansFindKTestD, Result) +{ + if (best_k.view()[0] != testparams.n_clusters) { + std::cout << best_k.view()[0] << " " << testparams.n_clusters << std::endl; + } + + ASSERT_TRUE(best_k.view()[0] == testparams.n_clusters); +} + +INSTANTIATE_TEST_CASE_P(KmeansFindKTests, KmeansFindKTestF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(KmeansFindKTests, KmeansFindKTestD, ::testing::ValuesIn(inputsd2)); + +} // namespace raft