diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 33bb6fd1ef..2a35c1efa0 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -407,7 +407,7 @@ void min_cluster_distance(const raft::handle_t& handle, * @param[in] handle The raft handle * @param[in] X The data in row-major format * [dim = n_samples x n_features] --c * @param[in] centroids Centroids data + * @param[in] centroids Centroids data * [dim = n_cluster x n_features] * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest * centroid and it's distance @@ -461,7 +461,6 @@ void min_cluster_and_distance( * [dim = n_samples_to_gather x n_features] * @param[in] n_samples_to_gather Number of sample to gather * @param[in] seed Seed for the shuffle - * @param[in] workspace Temporary workspace buffer which can get resized * */ template @@ -469,10 +468,9 @@ void shuffle_and_gather(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, - uint64_t seed, - rmm::device_uvector* workspace = nullptr) + uint64_t seed) { - detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed, workspace); + detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed); } /** @@ -949,7 +947,6 @@ void minClusterAndDistanceCompute( * [dim = n_samples_to_gather x n_features] * @param[in] n_samples_to_gather Number of sample to gather * @param[in] seed Seed for the shuffle - * @param[in] workspace Temporary workspace buffer which can get resized * */ template @@ -957,10 +954,9 @@ void shuffleAndGather(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, - uint64_t seed, - rmm::device_uvector* workspace = nullptr) + uint64_t seed) { - kmeans::shuffle_and_gather(handle, in, out, n_samples_to_gather, seed, workspace); + kmeans::shuffle_and_gather(handle, in, out, n_samples_to_gather, seed); } /** diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index 453e966394..698d23ac27 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -43,6 +43,20 @@ struct KmeansInputs { bool weighted; }; +template +void run_cluster_cost(const raft::handle_t& handle, + raft::device_vector_view minClusterDistance, + rmm::device_uvector& workspace, + raft::device_scalar_view clusterCost) +{ + raft::cluster::kmeans::cluster_cost( + handle, + minClusterDistance, + workspace, + clusterCost, + [] __device__(const DataT& a, const DataT& b) { return a + b; }); +} + template class KmeansTest : public ::testing::TestWithParam> { protected: @@ -55,6 +69,175 @@ class KmeansTest : public ::testing::TestWithParam> { { } + void apiTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 1; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); + + auto X = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + + raft::random::make_blobs(X.data_handle(), + labels.data_handle(), + n_samples, + n_features, + params.n_clusters, + stream, + true, + nullptr, + nullptr, + T(1.0), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); + rmm::device_uvector d_sample_weight(n_samples, stream); + thrust::fill( + thrust::cuda::par.on(stream), d_sample_weight.data(), d_sample_weight.data() + n_samples, 1); + auto weight_view = + raft::make_device_vector_view(d_sample_weight.data(), n_samples); + + T inertia = 0; + int n_iter = 0; + rmm::device_uvector workspace(0, stream); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + rmm::device_uvector inRankCp(0, stream); + auto X_view = + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); + auto centroids_view = + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + auto miniX = raft::make_device_matrix(handle, n_samples / 4, n_features); + + // Initialize kmeans on a portion of X + raft::cluster::kmeans::shuffle_and_gather( + handle, + X_view, + raft::make_device_matrix_view(miniX.data_handle(), miniX.extent(0), miniX.extent(1)), + miniX.extent(0), + params.rng_state.seed); + + raft::cluster::kmeans::init_plus_plus(handle, + params, + raft::make_device_matrix_view( + miniX.data_handle(), miniX.extent(0), miniX.extent(1)), + centroids_view, + workspace); + + auto minClusterDistance = raft::make_device_vector(handle, n_samples); + auto minClusterAndDistance = + raft::make_device_vector, int>(handle, n_samples); + auto L2NormX = raft::make_device_vector(handle, n_samples); + auto clusterCostBefore = raft::make_device_scalar(handle, 0); + auto clusterCostAfter = raft::make_device_scalar(handle, 0); + + raft::linalg::rowNorm(L2NormX.data_handle(), + X.data_handle(), + X.extent(1), + X.extent(0), + raft::linalg::L2Norm, + true, + stream); + + raft::cluster::kmeans::min_cluster_distance(handle, + X_view, + centroids_view, + minClusterDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); + + run_cluster_cost(handle, minClusterDistance.view(), workspace, clusterCostBefore.view()); + + // Run a fit of kmeans + raft::cluster::kmeans::fit_main(handle, + params, + X_view, + weight_view, + centroids_view, + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter), + workspace); + + // Check that the cluster cost decreased + raft::cluster::kmeans::min_cluster_distance(handle, + X_view, + centroids_view, + minClusterDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); + + run_cluster_cost(handle, minClusterDistance.view(), workspace, clusterCostAfter.view()); + T h_clusterCostBefore = T(0); + T h_clusterCostAfter = T(0); + raft::update_host(&h_clusterCostBefore, clusterCostBefore.data_handle(), 1, stream); + raft::update_host(&h_clusterCostAfter, clusterCostAfter.data_handle(), 1, stream); + ASSERT_TRUE(h_clusterCostAfter < h_clusterCostBefore); + + // Count samples in clusters using 2 methods and compare them + // Fill minClusterAndDistance + raft::cluster::kmeans::min_cluster_and_distance( + handle, + X_view, + raft::make_device_matrix_view( + d_centroids.data(), params.n_clusters, n_features), + minClusterAndDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); + raft::cluster::kmeans::KeyValueIndexOp conversion_op; + cub::TransformInputIterator, + raft::KeyValuePair*> + itr(minClusterAndDistance.data_handle(), conversion_op); + + auto sampleCountInCluster = raft::make_device_vector(handle, params.n_clusters); + auto weigthInCluster = raft::make_device_vector(handle, params.n_clusters); + auto newCentroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + raft::cluster::kmeans::update_centroids(handle, + X_view, + weight_view, + raft::make_device_matrix_view( + d_centroids.data(), params.n_clusters, n_features), + itr, + weigthInCluster.view(), + newCentroids.view()); + raft::cluster::kmeans::count_samples_in_cluster(handle, + params, + X_view, + L2NormX.view(), + newCentroids.view(), + workspace, + sampleCountInCluster.view()); + + ASSERT_TRUE(devArrMatch(sampleCountInCluster.data_handle(), + weigthInCluster.data_handle(), + params.n_clusters, + CompareApprox(params.tol))); + } + void basicTest() { testparams = ::testing::TestWithParam>::GetParam(); @@ -103,11 +286,11 @@ class KmeansTest : public ::testing::TestWithParam> { } raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); - handle.sync_stream(stream); - T inertia = 0; - int n_iter = 0; - auto X_view = (raft::device_matrix_view)X.view(); + T inertia = 0; + int n_iter = 0; + auto X_view = + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)); raft::cluster::kmeans_fit_predict( handle, @@ -135,7 +318,11 @@ class KmeansTest : public ::testing::TestWithParam> { } } - void SetUp() override { basicTest(); } + void SetUp() override + { + basicTest(); + apiTest(); + } protected: raft::handle_t handle; @@ -149,16 +336,16 @@ class KmeansTest : public ::testing::TestWithParam> { raft::cluster::KMeansParams params; }; -const std::vector> inputsf2 = {{1000, 32, 5, 0.0001, true}, - {1000, 32, 5, 0.0001, false}, - {1000, 100, 20, 0.0001, true}, - {1000, 100, 20, 0.0001, false}, - {10000, 32, 10, 0.0001, true}, - {10000, 32, 10, 0.0001, false}, - {10000, 100, 50, 0.0001, true}, - {10000, 100, 50, 0.0001, false}, - {10000, 1000, 200, 0.0001, true}, - {10000, 1000, 200, 0.0001, false}}; +const std::vector> inputsf2 = {{1000, 32, 5, 0.0001f, true}, + {1000, 32, 5, 0.0001f, false}, + {1000, 100, 20, 0.0001f, true}, + {1000, 100, 20, 0.0001f, false}, + {10000, 32, 10, 0.0001f, true}, + {10000, 32, 10, 0.0001f, false}, + {10000, 100, 50, 0.0001f, true}, + {10000, 100, 50, 0.0001f, false}, + {10000, 500, 100, 0.0001f, true}, + {10000, 500, 100, 0.0001f, false}}; const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, {1000, 32, 5, 0.0001, false}, @@ -168,8 +355,8 @@ const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, {10000, 32, 10, 0.0001, false}, {10000, 100, 50, 0.0001, true}, {10000, 100, 50, 0.0001, false}, - {10000, 1000, 200, 0.0001, true}, - {10000, 1000, 200, 0.0001, false}}; + {10000, 500, 100, 0.0001, true}, + {10000, 500, 100, 0.0001, false}}; typedef KmeansTest KmeansTestF; TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); }