Skip to content

Commit

Permalink
Add Tests for kmeans API (#982)
Browse files Browse the repository at this point in the history
This is also fixing a bug in kmeans `shuffle_and_gather` function and is needed by rapidsai/cuml#4713

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #982
  • Loading branch information
lowener authored Nov 8, 2022
1 parent 9b80321 commit 929be7a
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 26 deletions.
14 changes: 5 additions & 9 deletions cpp/include/raft/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ void min_cluster_distance(const raft::handle_t& handle,
* @param[in] handle The raft handle
* @param[in] X The data in row-major format
* [dim = n_samples x n_features]
-c * @param[in] centroids Centroids data
* @param[in] centroids Centroids data
* [dim = n_cluster x n_features]
* @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest
* centroid and it's distance
Expand Down Expand Up @@ -461,18 +461,16 @@ void min_cluster_and_distance(
* [dim = n_samples_to_gather x n_features]
* @param[in] n_samples_to_gather Number of sample to gather
* @param[in] seed Seed for the shuffle
* @param[in] workspace Temporary workspace buffer which can get resized
*
*/
template <typename DataT, typename IndexT>
void shuffle_and_gather(const raft::handle_t& handle,
raft::device_matrix_view<const DataT, IndexT> in,
raft::device_matrix_view<DataT, IndexT> out,
uint32_t n_samples_to_gather,
uint64_t seed,
rmm::device_uvector<char>* workspace = nullptr)
uint64_t seed)
{
detail::shuffleAndGather<DataT, IndexT>(handle, in, out, n_samples_to_gather, seed, workspace);
detail::shuffleAndGather<DataT, IndexT>(handle, in, out, n_samples_to_gather, seed);
}

/**
Expand Down Expand Up @@ -949,18 +947,16 @@ void minClusterAndDistanceCompute(
* [dim = n_samples_to_gather x n_features]
* @param[in] n_samples_to_gather Number of sample to gather
* @param[in] seed Seed for the shuffle
* @param[in] workspace Temporary workspace buffer which can get resized
*
*/
template <typename DataT, typename IndexT>
void shuffleAndGather(const raft::handle_t& handle,
raft::device_matrix_view<const DataT, IndexT> in,
raft::device_matrix_view<DataT, IndexT> out,
uint32_t n_samples_to_gather,
uint64_t seed,
rmm::device_uvector<char>* workspace = nullptr)
uint64_t seed)
{
kmeans::shuffle_and_gather<DataT, IndexT>(handle, in, out, n_samples_to_gather, seed, workspace);
kmeans::shuffle_and_gather<DataT, IndexT>(handle, in, out, n_samples_to_gather, seed);
}

/**
Expand Down
221 changes: 204 additions & 17 deletions cpp/test/cluster/kmeans.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ struct KmeansInputs {
bool weighted;
};

template <typename DataT, typename IndexT>
void run_cluster_cost(const raft::handle_t& handle,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
rmm::device_uvector<char>& workspace,
raft::device_scalar_view<DataT> clusterCost)
{
raft::cluster::kmeans::cluster_cost(
handle,
minClusterDistance,
workspace,
clusterCost,
[] __device__(const DataT& a, const DataT& b) { return a + b; });
}

template <typename T>
class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
protected:
Expand All @@ -55,6 +69,175 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
{
}

void apiTest()
{
testparams = ::testing::TestWithParam<KmeansInputs<T>>::GetParam();

int n_samples = testparams.n_row;
int n_features = testparams.n_col;
params.n_clusters = testparams.n_clusters;
params.tol = testparams.tol;
params.n_init = 1;
params.rng_state.seed = 1;
params.oversampling_factor = 0;

raft::random::RngState rng(params.rng_state.seed, params.rng_state.type);

auto X = raft::make_device_matrix<T, int>(handle, n_samples, n_features);
auto labels = raft::make_device_vector<int, int>(handle, n_samples);

raft::random::make_blobs<T, int>(X.data_handle(),
labels.data_handle(),
n_samples,
n_features,
params.n_clusters,
stream,
true,
nullptr,
nullptr,
T(1.0),
false,
(T)-10.0f,
(T)10.0f,
(uint64_t)1234);
d_labels.resize(n_samples, stream);
d_labels_ref.resize(n_samples, stream);
d_centroids.resize(params.n_clusters * n_features, stream);
raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream);
rmm::device_uvector<T> d_sample_weight(n_samples, stream);
thrust::fill(
thrust::cuda::par.on(stream), d_sample_weight.data(), d_sample_weight.data() + n_samples, 1);
auto weight_view =
raft::make_device_vector_view<const T, int>(d_sample_weight.data(), n_samples);

T inertia = 0;
int n_iter = 0;
rmm::device_uvector<char> workspace(0, stream);
rmm::device_uvector<T> L2NormBuf_OR_DistBuf(0, stream);
rmm::device_uvector<T> inRankCp(0, stream);
auto X_view =
raft::make_device_matrix_view<const T, int>(X.data_handle(), X.extent(0), X.extent(1));
auto centroids_view =
raft::make_device_matrix_view<T, int>(d_centroids.data(), params.n_clusters, n_features);
auto miniX = raft::make_device_matrix<T, int>(handle, n_samples / 4, n_features);

// Initialize kmeans on a portion of X
raft::cluster::kmeans::shuffle_and_gather(
handle,
X_view,
raft::make_device_matrix_view<T, int>(miniX.data_handle(), miniX.extent(0), miniX.extent(1)),
miniX.extent(0),
params.rng_state.seed);

raft::cluster::kmeans::init_plus_plus(handle,
params,
raft::make_device_matrix_view<const T, int>(
miniX.data_handle(), miniX.extent(0), miniX.extent(1)),
centroids_view,
workspace);

auto minClusterDistance = raft::make_device_vector<T, int>(handle, n_samples);
auto minClusterAndDistance =
raft::make_device_vector<raft::KeyValuePair<int, T>, int>(handle, n_samples);
auto L2NormX = raft::make_device_vector<T, int>(handle, n_samples);
auto clusterCostBefore = raft::make_device_scalar<T>(handle, 0);
auto clusterCostAfter = raft::make_device_scalar<T>(handle, 0);

raft::linalg::rowNorm(L2NormX.data_handle(),
X.data_handle(),
X.extent(1),
X.extent(0),
raft::linalg::L2Norm,
true,
stream);

raft::cluster::kmeans::min_cluster_distance(handle,
X_view,
centroids_view,
minClusterDistance.view(),
L2NormX.view(),
L2NormBuf_OR_DistBuf,
params.metric,
params.batch_samples,
params.batch_centroids,
workspace);

run_cluster_cost(handle, minClusterDistance.view(), workspace, clusterCostBefore.view());

// Run a fit of kmeans
raft::cluster::kmeans::fit_main(handle,
params,
X_view,
weight_view,
centroids_view,
raft::make_host_scalar_view(&inertia),
raft::make_host_scalar_view(&n_iter),
workspace);

// Check that the cluster cost decreased
raft::cluster::kmeans::min_cluster_distance(handle,
X_view,
centroids_view,
minClusterDistance.view(),
L2NormX.view(),
L2NormBuf_OR_DistBuf,
params.metric,
params.batch_samples,
params.batch_centroids,
workspace);

run_cluster_cost(handle, minClusterDistance.view(), workspace, clusterCostAfter.view());
T h_clusterCostBefore = T(0);
T h_clusterCostAfter = T(0);
raft::update_host(&h_clusterCostBefore, clusterCostBefore.data_handle(), 1, stream);
raft::update_host(&h_clusterCostAfter, clusterCostAfter.data_handle(), 1, stream);
ASSERT_TRUE(h_clusterCostAfter < h_clusterCostBefore);

// Count samples in clusters using 2 methods and compare them
// Fill minClusterAndDistance
raft::cluster::kmeans::min_cluster_and_distance(
handle,
X_view,
raft::make_device_matrix_view<const T, int>(
d_centroids.data(), params.n_clusters, n_features),
minClusterAndDistance.view(),
L2NormX.view(),
L2NormBuf_OR_DistBuf,
params.metric,
params.batch_samples,
params.batch_centroids,
workspace);
raft::cluster::kmeans::KeyValueIndexOp<int, T> conversion_op;
cub::TransformInputIterator<int,
raft::cluster::kmeans::KeyValueIndexOp<int, T>,
raft::KeyValuePair<int, T>*>
itr(minClusterAndDistance.data_handle(), conversion_op);

auto sampleCountInCluster = raft::make_device_vector<T, int>(handle, params.n_clusters);
auto weigthInCluster = raft::make_device_vector<T, int>(handle, params.n_clusters);
auto newCentroids = raft::make_device_matrix<T, int>(handle, params.n_clusters, n_features);
raft::cluster::kmeans::update_centroids(handle,
X_view,
weight_view,
raft::make_device_matrix_view<const T, int>(
d_centroids.data(), params.n_clusters, n_features),
itr,
weigthInCluster.view(),
newCentroids.view());
raft::cluster::kmeans::count_samples_in_cluster(handle,
params,
X_view,
L2NormX.view(),
newCentroids.view(),
workspace,
sampleCountInCluster.view());

ASSERT_TRUE(devArrMatch(sampleCountInCluster.data_handle(),
weigthInCluster.data_handle(),
params.n_clusters,
CompareApprox<T>(params.tol)));
}

void basicTest()
{
testparams = ::testing::TestWithParam<KmeansInputs<T>>::GetParam();
Expand Down Expand Up @@ -103,11 +286,11 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
}

raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream);
handle.sync_stream(stream);

T inertia = 0;
int n_iter = 0;
auto X_view = (raft::device_matrix_view<const T, int>)X.view();
T inertia = 0;
int n_iter = 0;
auto X_view =
raft::make_device_matrix_view<const T, int>(X.data_handle(), X.extent(0), X.extent(1));

raft::cluster::kmeans_fit_predict<T, int>(
handle,
Expand Down Expand Up @@ -135,7 +318,11 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
}
}

void SetUp() override { basicTest(); }
void SetUp() override
{
basicTest();
apiTest();
}

protected:
raft::handle_t handle;
Expand All @@ -149,16 +336,16 @@ class KmeansTest : public ::testing::TestWithParam<KmeansInputs<T>> {
raft::cluster::KMeansParams params;
};

const std::vector<KmeansInputs<float>> inputsf2 = {{1000, 32, 5, 0.0001, true},
{1000, 32, 5, 0.0001, false},
{1000, 100, 20, 0.0001, true},
{1000, 100, 20, 0.0001, false},
{10000, 32, 10, 0.0001, true},
{10000, 32, 10, 0.0001, false},
{10000, 100, 50, 0.0001, true},
{10000, 100, 50, 0.0001, false},
{10000, 1000, 200, 0.0001, true},
{10000, 1000, 200, 0.0001, false}};
const std::vector<KmeansInputs<float>> inputsf2 = {{1000, 32, 5, 0.0001f, true},
{1000, 32, 5, 0.0001f, false},
{1000, 100, 20, 0.0001f, true},
{1000, 100, 20, 0.0001f, false},
{10000, 32, 10, 0.0001f, true},
{10000, 32, 10, 0.0001f, false},
{10000, 100, 50, 0.0001f, true},
{10000, 100, 50, 0.0001f, false},
{10000, 500, 100, 0.0001f, true},
{10000, 500, 100, 0.0001f, false}};

const std::vector<KmeansInputs<double>> inputsd2 = {{1000, 32, 5, 0.0001, true},
{1000, 32, 5, 0.0001, false},
Expand All @@ -168,8 +355,8 @@ const std::vector<KmeansInputs<double>> inputsd2 = {{1000, 32, 5, 0.0001, true},
{10000, 32, 10, 0.0001, false},
{10000, 100, 50, 0.0001, true},
{10000, 100, 50, 0.0001, false},
{10000, 1000, 200, 0.0001, true},
{10000, 1000, 200, 0.0001, false}};
{10000, 500, 100, 0.0001, true},
{10000, 500, 100, 0.0001, false}};

typedef KmeansTest<float> KmeansTestF;
TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); }
Expand Down

0 comments on commit 929be7a

Please sign in to comment.