From 64256a2210e41be746e1e4675421f0c8d6f80bd7 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 27 Jul 2022 14:28:58 +0200 Subject: [PATCH 1/4] Fix raw pointer API --- cpp/include/raft/cluster/detail/kmeans.cuh | 30 +++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index f34005fc29..27993e8108 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -901,14 +901,14 @@ void kmeans_fit(handle_t const& handle, DataT& inertia, IndexT& n_iter) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); std::optional> sample_weightView = std::nullopt; if (sample_weight) - sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); + sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); + auto inertiaView = raft::make_host_scalar_view(&inertia); + auto n_iterView = raft::make_host_scalar_view(&n_iter); detail::kmeans_fit( handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView); @@ -1034,14 +1034,14 @@ void kmeans_predict(handle_t const& handle, bool normalize_weight, DataT& inertia) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + raft::make_device_matrix_view(centroids, params.n_clusters, n_features); std::optional> sample_weightView{std::nullopt}; if (sample_weight) sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); + raft::make_device_vector_view(sample_weight, n_samples)); + auto labelsView = raft::make_device_vector_view(labels, n_samples); auto inertiaView = raft::make_host_scalar_view(&inertia); detail::kmeans_predict(handle, @@ -1092,18 +1092,18 @@ void kmeans_fit_predict(handle_t const& handle, DataT& inertia, IndexT& n_iter) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); std::optional> sample_weightView{std::nullopt}; if (sample_weight) sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); + raft::make_device_vector_view(sample_weight, n_samples)); std::optional> centroidsView{std::nullopt}; if (centroids) centroidsView.emplace( raft::make_device_matrix_view(centroids, params.n_clusters, n_features)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); + auto labelsView = raft::make_device_vector_view(labels, n_samples); + auto inertiaView = raft::make_host_scalar_view(&inertia); + auto n_iterView = raft::make_host_scalar_view(&n_iter); detail::kmeans_fit_predict( handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView); @@ -1169,9 +1169,9 @@ void kmeans_transform(const raft::handle_t& handle, IndexT n_features, DataT* X_new) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + raft::make_device_matrix_view(centroids, params.n_clusters, n_features); auto X_newView = raft::make_device_matrix_view(X_new, n_samples, n_features); detail::kmeans_transform(handle, params, XView, centroidsView, X_newView); From 41392bb34b750746095ee655b3623bd36f9e862c Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 27 Jul 2022 15:50:25 +0200 Subject: [PATCH 2/4] Fix style --- cpp/include/raft/cluster/detail/kmeans.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 27993e8108..88aafe970e 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -906,7 +906,8 @@ void kmeans_fit(handle_t const& handle, raft::make_device_matrix_view(centroids, params.n_clusters, n_features); std::optional> sample_weightView = std::nullopt; if (sample_weight) - sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); + sample_weightView = + raft::make_device_vector_view(sample_weight, n_samples); auto inertiaView = raft::make_host_scalar_view(&inertia); auto n_iterView = raft::make_host_scalar_view(&n_iter); From 4710b1f792d988c4218d01586b86106fc50dc5b8 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 27 Jul 2022 17:38:53 +0200 Subject: [PATCH 3/4] Fix kmeans transform --- cpp/include/raft/cluster/detail/kmeans.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 88aafe970e..6768aa9b4d 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -1148,11 +1148,11 @@ void kmeans_transform(const raft::handle_t& handle, // datasetView [ns x n_features] - view representing the current batch of // input dataset auto datasetView = - raft::make_device_matrix_view(X.data() + n_features * dIdx, ns, n_features); + raft::make_device_matrix_view(X.data_handle() + n_features * dIdx, ns, n_features); // pairwiseDistanceView [ns x n_clusters] auto pairwiseDistanceView = raft::make_device_matrix_view( - X_new.data() + n_clusters * dIdx, ns, n_clusters); + X_new.data_handle() + n_clusters * dIdx, ns, n_clusters); // calculate pairwise distance between cluster centroids and current batch // of input dataset From 153e4e61a46ed8ba754ce13fd0d1b09a2c7f5dec Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 27 Jul 2022 17:40:54 +0200 Subject: [PATCH 4/4] Fix style --- cpp/include/raft/cluster/detail/kmeans.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 6768aa9b4d..303de77078 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -1147,8 +1147,8 @@ void kmeans_transform(const raft::handle_t& handle, // datasetView [ns x n_features] - view representing the current batch of // input dataset - auto datasetView = - raft::make_device_matrix_view(X.data_handle() + n_features * dIdx, ns, n_features); + auto datasetView = raft::make_device_matrix_view( + X.data_handle() + n_features * dIdx, ns, n_features); // pairwiseDistanceView [ns x n_clusters] auto pairwiseDistanceView = raft::make_device_matrix_view(