diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index f34005fc29..303de77078 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -901,14 +901,15 @@ void kmeans_fit(handle_t const& handle, DataT& inertia, IndexT& n_iter) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); std::optional> sample_weightView = std::nullopt; if (sample_weight) - sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); + sample_weightView = + raft::make_device_vector_view(sample_weight, n_samples); + auto inertiaView = raft::make_host_scalar_view(&inertia); + auto n_iterView = raft::make_host_scalar_view(&n_iter); detail::kmeans_fit( handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView); @@ -1034,14 +1035,14 @@ void kmeans_predict(handle_t const& handle, bool normalize_weight, DataT& inertia) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + raft::make_device_matrix_view(centroids, params.n_clusters, n_features); std::optional> sample_weightView{std::nullopt}; if (sample_weight) sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); + raft::make_device_vector_view(sample_weight, n_samples)); + auto labelsView = raft::make_device_vector_view(labels, n_samples); auto inertiaView = raft::make_host_scalar_view(&inertia); detail::kmeans_predict(handle, @@ -1092,18 +1093,18 @@ void kmeans_fit_predict(handle_t const& handle, DataT& inertia, IndexT& n_iter) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); std::optional> sample_weightView{std::nullopt}; if (sample_weight) sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); + raft::make_device_vector_view(sample_weight, n_samples)); std::optional> centroidsView{std::nullopt}; if (centroids) centroidsView.emplace( raft::make_device_matrix_view(centroids, params.n_clusters, n_features)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); + auto labelsView = raft::make_device_vector_view(labels, n_samples); + auto inertiaView = raft::make_host_scalar_view(&inertia); + auto n_iterView = raft::make_host_scalar_view(&n_iter); detail::kmeans_fit_predict( handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView); @@ -1146,12 +1147,12 @@ void kmeans_transform(const raft::handle_t& handle, // datasetView [ns x n_features] - view representing the current batch of // input dataset - auto datasetView = - raft::make_device_matrix_view(X.data() + n_features * dIdx, ns, n_features); + auto datasetView = raft::make_device_matrix_view( + X.data_handle() + n_features * dIdx, ns, n_features); // pairwiseDistanceView [ns x n_clusters] auto pairwiseDistanceView = raft::make_device_matrix_view( - X_new.data() + n_clusters * dIdx, ns, n_clusters); + X_new.data_handle() + n_clusters * dIdx, ns, n_clusters); // calculate pairwise distance between cluster centroids and current batch // of input dataset @@ -1169,9 +1170,9 @@ void kmeans_transform(const raft::handle_t& handle, IndexT n_features, DataT* X_new) { - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + raft::make_device_matrix_view(centroids, params.n_clusters, n_features); auto X_newView = raft::make_device_matrix_view(X_new, n_samples, n_features); detail::kmeans_transform(handle, params, XView, centroidsView, X_newView);