From ba12f6a0f947bea48c0e91a405e0c8928c6ac365 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 20 Oct 2022 17:01:45 -0400 Subject: [PATCH] Updates to kmeans public API to fix cuml (#932) Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/932 --- cpp/include/raft/cluster/kmeans.cuh | 56 ++++++++++++++--------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 0ce35da4a5..40435b9580 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -251,11 +251,11 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * @param[in] workspace Temporary workspace buffer which can get resized * */ -template +template void sampleCentroids(const raft::handle_t& handle, - const raft::device_matrix_view& X, - const raft::device_vector_view& minClusterDistance, - const raft::device_vector_view& isSampleCentroid, + raft::device_matrix_view X, + raft::device_vector_view minClusterDistance, + raft::device_vector_view isSampleCentroid, SamplingOp& select_op, rmm::device_uvector& inRankCp, rmm::device_uvector& workspace) @@ -278,11 +278,11 @@ void sampleCentroids(const raft::handle_t& handle, * @param[in] reduction_op The reduction operation used for the cost * */ -template +template void computeClusterCost(const raft::handle_t& handle, - const raft::device_vector_view& minClusterDistance, + raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, - const raft::device_scalar_view& clusterCost, + raft::device_scalar_view clusterCost, ReductionOpT reduction_op) { detail::computeClusterCost( @@ -313,10 +313,10 @@ void computeClusterCost(const raft::handle_t& handle, template void minClusterDistanceCompute(const raft::handle_t& handle, const KMeansParams& params, - const raft::device_matrix_view& X, - const raft::device_matrix_view& centroids, - const raft::device_vector_view& minClusterDistance, - const raft::device_vector_view& L2NormX, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view minClusterDistance, + raft::device_vector_view L2NormX, rmm::device_uvector& L2NormBuf_OR_DistBuf, rmm::device_uvector& workspace) { @@ -352,10 +352,10 @@ template void minClusterAndDistanceCompute( const raft::handle_t& handle, const KMeansParams& params, - const raft::device_matrix_view X, - const raft::device_matrix_view centroids, - const raft::device_vector_view, IndexT>& minClusterAndDistance, - const raft::device_vector_view& L2NormX, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view, IndexT> minClusterAndDistance, + raft::device_vector_view L2NormX, rmm::device_uvector& L2NormBuf_OR_DistBuf, rmm::device_uvector& workspace) { @@ -382,8 +382,8 @@ void minClusterAndDistanceCompute( */ template void shuffleAndGather(const raft::handle_t& handle, - const raft::device_matrix_view& in, - const raft::device_matrix_view& out, + raft::device_matrix_view in, + raft::device_matrix_view out, uint32_t n_samples_to_gather, uint64_t seed, rmm::device_uvector* workspace = nullptr) @@ -413,11 +413,11 @@ void shuffleAndGather(const raft::handle_t& handle, template void countSamplesInCluster(const raft::handle_t& handle, const KMeansParams& params, - const raft::device_matrix_view& X, - const raft::device_vector_view& L2NormX, - const raft::device_matrix_view& centroids, + raft::device_matrix_view X, + raft::device_vector_view L2NormX, + raft::device_matrix_view centroids, rmm::device_uvector& workspace, - const raft::device_vector_view& sampleCountInCluster) + raft::device_vector_view sampleCountInCluster) { detail::countSamplesInCluster( handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); @@ -444,8 +444,8 @@ void countSamplesInCluster(const raft::handle_t& handle, template void kmeansPlusPlus(const raft::handle_t& handle, const KMeansParams& params, - const raft::device_matrix_view& X, - const raft::device_matrix_view& centroidsRawData, + raft::device_matrix_view X, + raft::device_matrix_view centroidsRawData, rmm::device_uvector& workspace) { detail::kmeansPlusPlus(handle, params, X, centroidsRawData, workspace); @@ -477,11 +477,11 @@ void kmeansPlusPlus(const raft::handle_t& handle, template void kmeans_fit_main(const raft::handle_t& handle, const KMeansParams& params, - const raft::device_matrix_view& X, - const raft::device_vector_view& weight, - const raft::device_matrix_view& centroidsRawData, - const raft::host_scalar_view& inertia, - const raft::host_scalar_view& n_iter, + raft::device_matrix_view X, + raft::device_vector_view weight, + raft::device_matrix_view centroidsRawData, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, rmm::device_uvector& workspace) { detail::kmeans_fit_main(