diff --git a/build.sh b/build.sh index 1c54276aa5..4e230108d4 100755 --- a/build.sh +++ b/build.sh @@ -120,6 +120,9 @@ if hasArg clean; then CLEAN=1 fi +if hasArg cpp-mgtests; then + BUILD_CUML_MG_TESTS=ON +fi # Long arguments LONG_ARGUMENT_LIST=( diff --git a/cpp/bench/sg/kmeans.cu b/cpp/bench/sg/kmeans.cu index b2603aade8..04728805bd 100644 --- a/cpp/bench/sg/kmeans.cu +++ b/cpp/bench/sg/kmeans.cu @@ -17,6 +17,9 @@ #include "benchmark.cuh" #include #include +#include +#include +#include #include namespace ML { @@ -86,9 +89,9 @@ std::vector getInputs() p.kmeans.init = ML::kmeans::KMeansParams::InitMethod(0); p.kmeans.max_iter = 300; p.kmeans.tol = 1e-4; - p.kmeans.verbosity = CUML_LEVEL_INFO; - p.kmeans.seed = int(p.blobs.seed); - p.kmeans.metric = 0; // L2 + p.kmeans.verbosity = RAFT_LEVEL_INFO; + p.kmeans.metric = raft::distance::DistanceType::L2Expanded; + p.kmeans.rng_state = raft::random::RngState(p.blobs.seed); p.kmeans.inertia_check = true; std::vector> rowcols = { {160000, 64}, diff --git a/cpp/examples/kmeans/kmeans_example.cpp b/cpp/examples/kmeans/kmeans_example.cpp index 5149ffcd2d..31d4d29cf3 100644 --- a/cpp/examples/kmeans/kmeans_example.cpp +++ b/cpp/examples/kmeans/kmeans_example.cpp @@ -23,9 +23,8 @@ #include -#include - #include +#include #ifndef CUDA_RT_CALL #define CUDA_RT_CALL(call) \ @@ -112,7 +111,7 @@ int main(int argc, char* argv[]) params.max_iter = 300; params.tol = 0.05; } - params.metric = 1; + params.metric = raft::distance::DistanceType::L2SqrtExpanded; params.init = ML::kmeans::KMeansParams::InitMethod::Random; // Inputs copied from kmeans_test.cu diff --git a/cpp/include/cuml/cluster/kmeans.hpp b/cpp/include/cuml/cluster/kmeans.hpp index 94bd9eebe8..e9c37791f4 100644 --- a/cpp/include/cuml/cluster/kmeans.hpp +++ b/cpp/include/cuml/cluster/kmeans.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #pragma once #include +#include namespace raft { class handle_t; @@ -26,54 +27,7 @@ namespace ML { namespace kmeans { -struct KMeansParams { - enum InitMethod { KMeansPlusPlus, Random, Array }; - - // The number of clusters to form as well as the number of centroids to - // generate (default:8). - int n_clusters = 8; - - /* - * Method for initialization, defaults to k-means++: - * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm - * to select the initial cluster centers. - * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at - * random from the input data for the initial centroids. - * - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. - */ - InitMethod init = KMeansPlusPlus; - - // Maximum number of iterations of the k-means algorithm for a single run. - int max_iter = 300; - - // Relative tolerance with regards to inertia to declare convergence. - double tol = 1e-4; - - // verbosity level. - int verbosity = CUML_LEVEL_INFO; - - // Seed to the random number generator. - int seed = 0; - - // Metric to use for distance computation. Any metric from - // raft::distance::DistanceType can be used - int metric = 0; - - // Number of instance k-means algorithm will be run with different seeds. - int n_init = 1; - - // Oversampling factor for use in the k-means|| algorithm. - double oversampling_factor = 2.0; - - // batch_samples and batch_centroids are used to tile 1NN computation which is - // useful to optimize/control the memory footprint - // Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 - // then don't tile the centroids - int batch_samples = 1 << 15; - int batch_centroids = 0; // if 0 then batch_centroids = n_clusters - - bool inertia_check = false; -}; +using KMeansParams = raft::cluster::KMeansParams; /** * @brief Compute k-means clustering and predicts cluster index for each sample @@ -222,8 +176,6 @@ void predict(const raft::handle_t& handle, * @param[in] n_features Number of features or the dimensions of each * sample in 'X' (it should be same as the dimension for each cluster centers in * 'centroids'). - * @param[in] metric Metric to use for distance computation. Any - * metric from raft::distance::DistanceType can be used * @param[out] X_new X transformed in the new space.. */ void transform(const raft::handle_t& handle, @@ -232,7 +184,6 @@ void transform(const raft::handle_t& handle, const float* X, int n_samples, int n_features, - int metric, float* X_new); void transform(const raft::handle_t& handle, @@ -241,7 +192,6 @@ void transform(const raft::handle_t& handle, const double* X, int n_samples, int n_features, - int metric, double* X_new); }; // end namespace kmeans diff --git a/cpp/include/cuml/cluster/kmeans_mg.hpp b/cpp/include/cuml/cluster/kmeans_mg.hpp index 9ca3450cab..618f4bbd3a 100644 --- a/cpp/include/cuml/cluster/kmeans_mg.hpp +++ b/cpp/include/cuml/cluster/kmeans_mg.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,14 @@ #pragma once +#include + namespace raft { class handle_t; } namespace ML { namespace kmeans { -struct KMeansParams; namespace opg { /** diff --git a/cpp/src/common/tensor.hpp b/cpp/src/common/tensor.hpp deleted file mode 100644 index 61bb8b84e9..0000000000 --- a/cpp/src/common/tensor.hpp +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include - -namespace ML { - -template -class Tensor { - public: - enum { NumDim = Dim }; - typedef DataT* DataPtrT; - - __host__ ~Tensor() - { - if (_state == AllocState::Owner) { - if (memory_type(_data) == cudaMemoryTypeHost) { delete _data; } - - if (memory_type(_data) == cudaMemoryTypeDevice) { - rmm_alloc->deallocate(_data, this->getSizeInBytes(), _stream); - } else if (memory_type(_data) == cudaMemoryTypeHost) { - delete _data; - } - } - } - - __host__ Tensor(DataPtrT data, const std::vector& sizes) - : _data(data), _state(AllocState::NotOwner) - { - static_assert(Dim > 0, "must have > 0 dimensions"); - - ASSERT(sizes.size() == Dim, - "invalid argument: # of entries in the input argument 'sizes' must " - "match the tensor dimension"); - - for (int i = 0; i < Dim; ++i) { - _size[i] = sizes[i]; - } - - _stride[Dim - 1] = (IndexT)1; - for (int j = Dim - 2; j >= 0; --j) { - _stride[j] = _stride[j + 1] * _size[j + 1]; - } - } - - // allocate the data using the allocator and release when the object goes out of scope - // allocating tensor is the owner of the data - __host__ Tensor(const std::vector& sizes, cudaStream_t stream) - : _stream(stream), _state(AllocState::Owner) - { - static_assert(Dim > 0, "must have > 0 dimensions"); - - ASSERT(sizes.size() == Dim, "dimension mismatch"); - - for (int i = 0; i < Dim; ++i) { - _size[i] = sizes[i]; - } - - _stride[Dim - 1] = (IndexT)1; - for (int j = Dim - 2; j >= 0; --j) { - _stride[j] = _stride[j + 1] * _size[j + 1]; - } - - rmm_alloc = rmm::mr::get_current_device_resource(); - _data = (DataT*)rmm_alloc->allocate(this->getSizeInBytes(), _stream); - - ASSERT(this->data() || (this->getSizeInBytes() == 0), "device allocation failed"); - } - - /// returns the total number of elements contained within our data - __host__ size_t numElements() const - { - size_t num = (size_t)getSize(0); - - for (int i = 1; i < Dim; ++i) { - num *= (size_t)getSize(i); - } - - return num; - } - - /// returns the size of a given dimension, `[0, Dim - 1]` - __host__ inline IndexT getSize(int i) const { return _size[i]; } - - /// returns the stride array - __host__ inline const IndexT* strides() const { return _stride; } - - /// returns the stride array. - __host__ inline const IndexT getStride(int i) const { return _stride[i]; } - - /// returns the total size in bytes of our data - __host__ size_t getSizeInBytes() const { return numElements() * sizeof(DataT); } - - /// returns a raw pointer to the start of our data - __host__ inline DataPtrT data() { return _data; } - - /// returns a raw pointer to the start of our data. - __host__ inline DataPtrT begin() { return _data; } - - /// returns a raw pointer to the end of our data - __host__ inline DataPtrT end() { return data() + numElements(); } - - /// returns a raw pointer to the start of our data (const) - __host__ inline DataPtrT data() const { return _data; } - - /// returns a raw pointer to the end of our data (const) - __host__ inline DataPtrT end() const { return data() + numElements(); } - - /// returns the size array. - __host__ inline const IndexT* sizes() const { return _size; } - - template - __host__ Tensor view(const std::vector& sizes, - const std::vector& start_pos) - { - ASSERT(sizes.size() == NewDim, "invalid view requested"); - ASSERT(start_pos.size() == Dim, "dimensionality of the position if incorrect"); - - // calc offset at start_pos - uint32_t offset = 0; - for (uint32_t dim = 0; dim < Dim; ++dim) { - offset += start_pos[dim] * getStride(dim); - } - DataPtrT newData = this->data() + offset; - - // The total size of the new view must be the <= total size of the old view - size_t curSize = numElements(); - size_t newSize = 1; - - for (auto s : sizes) { - newSize *= s; - } - - ASSERT(newSize <= curSize, "invalid view requested"); - - return Tensor(newData, sizes); - } - - private: - enum AllocState { - /// This tensor itself owns the memory, which must be freed via - /// cudaFree - Owner, - - /// This tensor itself is not an owner of the memory; there is - /// nothing to free - NotOwner - }; - - protected: - /// Raw pointer to where the tensor data begins - DataPtrT _data{}; - - /// Array of strides (in sizeof(T) terms) per each dimension - IndexT _stride[Dim]; - - /// Size per each dimension - IndexT _size[Dim]; - - AllocState _state{}; - - cudaStream_t _stream{}; - - rmm::mr::device_memory_resource* rmm_alloc; -}; - -}; // end namespace ML diff --git a/cpp/src/kmeans/common.cuh b/cpp/src/kmeans/common.cuh deleted file mode 100644 index f70325b7dd..0000000000 --- a/cpp/src/kmeans/common.cuh +++ /dev/null @@ -1,898 +0,0 @@ -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace ML { - -#define LOG(handle, fmt, ...) \ - do { \ - bool isRoot = true; \ - if (handle.comms_initialized()) { \ - const auto& comm = handle.get_comms(); \ - const int my_rank = comm.get_rank(); \ - isRoot = my_rank == 0; \ - } \ - if (isRoot) { CUML_LOG_DEBUG(fmt, ##__VA_ARGS__); } \ - } while (0) - -namespace kmeans { -namespace detail { - -template -struct FusedL2NNReduceOp { - LabelT offset; - - FusedL2NNReduceOp(LabelT _offset) : offset(_offset){}; - - typedef typename cub::KeyValuePair KVP; - DI void operator()(LabelT rit, KVP* out, const KVP& other) - { - if (other.value < out->value) { - out->key = offset + other.key; - out->value = other.value; - } - } - - DI void operator()(LabelT rit, DataT* out, const KVP& other) - { - if (other.value < *out) { *out = other.value; } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) - { - out->key = -1; - out->value = maxVal; - } -}; - -template -struct SamplingOp { - DataT* rnd; - int* flag; - DataT cluster_cost; - double oversampling_factor; - int n_clusters; - - CUB_RUNTIME_FUNCTION __forceinline__ SamplingOp(DataT c, double l, int k, DataT* rand, int* ptr) - : cluster_cost(c), oversampling_factor(l), n_clusters(k), rnd(rand), flag(ptr) - { - } - - __host__ __device__ __forceinline__ bool operator()( - const cub::KeyValuePair& a) const - { - DataT prob_threshold = (DataT)rnd[a.key]; - - DataT prob_x = ((oversampling_factor * n_clusters * a.value) / cluster_cost); - - return !flag[a.key] && (prob_x > prob_threshold); - } -}; - -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const cub::KeyValuePair& a) const - { - return a.key; - } -}; - -template -CountT getDataBatchSize(const KMeansParams& params, CountT n_samples) -{ - auto minVal = std::min(params.batch_samples, n_samples); - return (minVal == 0) ? n_samples : minVal; -} - -template -CountT getCentroidsBatchSize(const KMeansParams& params, CountT n_local_clusters) -{ - auto minVal = std::min(params.batch_centroids, n_local_clusters); - return (minVal == 0) ? n_local_clusters : minVal; -} - -// Computes the intensity histogram from a sequence of labels -template -void countLabels(const raft::handle_t& handle, - SampleIteratorT labels, - CounterT* count, - int n_samples, - int n_clusters, - rmm::device_uvector& workspace, - cudaStream_t stream) -{ - int num_levels = n_clusters + 1; - int lower_level = 0; - int upper_level = n_clusters; - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - n_samples, - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(), - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - n_samples, - stream)); -} - -template -Tensor sampleCentroids(const raft::handle_t& handle, - Tensor& X, - Tensor& minClusterDistance, - Tensor& isSampleCentroid, - typename kmeans::detail::SamplingOp& select_op, - rmm::device_uvector& workspace, - cudaStream_t stream) -{ - int n_local_samples = X.getSize(0); - int n_features = X.getSize(1); - - Tensor nSelected({1}, stream); - - cub::ArgIndexInputIterator ip_itr(minClusterDistance.data()); - Tensor, 1> sampledMinClusterDistance({n_local_samples}, - stream); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr, - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data(), - nSelected.data(), - n_local_samples, - select_op, - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceSelect::If(workspace.data(), - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data(), - nSelected.data(), - n_local_samples, - select_op, - stream)); - - int nPtsSampledInRank = 0; - raft::copy(&nPtsSampledInRank, nSelected.data(), nSelected.numElements(), stream); - handle.sync_stream(stream); - - int* rawPtr_isSampleCentroid = isSampleCentroid.data(); - thrust::for_each_n(handle.get_thrust_policy(), - sampledMinClusterDistance.begin(), - nPtsSampledInRank, - [=] __device__(cub::KeyValuePair val) { - rawPtr_isSampleCentroid[val.key] = 1; - }); - - Tensor inRankCp({nPtsSampledInRank, n_features}, stream); - - raft::matrix::gather( - X.data(), - X.getSize(1), - X.getSize(0), - sampledMinClusterDistance.data(), - nPtsSampledInRank, - inRankCp.data(), - [=] __device__(cub::KeyValuePair val) { // MapTransformOp - return val.key; - }, - stream); - - return inRankCp; -} - -template -void computeClusterCost(const raft::handle_t& handle, - Tensor& minClusterDistance, - rmm::device_uvector& workspace, - DataT* clusterCost, - ReductionOpT reduction_op, - cudaStream_t stream) -{ - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, - temp_storage_bytes, - minClusterDistance.data(), - clusterCost, - minClusterDistance.numElements(), - reduction_op, - DataT(), - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), - temp_storage_bytes, - minClusterDistance.data(), - clusterCost, - minClusterDistance.numElements(), - reduction_op, - DataT(), - stream)); -} - -// calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', -// result will be stored in 'pairwiseDistance[n x k]' -template -void pairwise_distance(const raft::handle_t& handle, - Tensor& X, - Tensor& centroids, - Tensor& pairwiseDistance, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - cudaStream_t stream) -{ - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = centroids.getSize(0); - - ASSERT(X.getSize(1) == centroids.getSize(1), - "# features in dataset and centroids are different (must be same)"); - - ML::Metrics::pairwise_distance(handle, - X.data(), - centroids.data(), - pairwiseDistance.data(), - n_samples, - n_clusters, - n_features, - metric); -} - -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' -template -void minClusterAndDistance( - const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - Tensor& centroids, - Tensor, 1, IndexT>& minClusterAndDistance, - Tensor& L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - cudaStream_t stream) -{ - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = centroids.getSize(0); - auto dataBatchSize = kmeans::detail::getDataBatchSize(params, n_samples); - auto centroidsBatchSize = kmeans::detail::getCentroidsBatchSize(params, n_clusters); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data(), - centroids.getSize(1), - centroids.getSize(0), - raft::linalg::L2Norm, - true, - stream); - } else { - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - Tensor centroidsNorm(L2NormBuf_OR_DistBuf.data(), {n_clusters}); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - Tensor pairwiseDistance(L2NormBuf_OR_DistBuf.data(), - {dataBatchSize, centroidsBatchSize}); - - cub::KeyValuePair initial_value(0, std::numeric_limits::max()); - - thrust::fill(handle.get_thrust_policy(), - minClusterAndDistance.begin(), - minClusterAndDistance.end(), - initial_value); - - // tile over the input dataset - for (auto dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min(dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = X.template view<2>({ns, n_features}, {dIdx, 0}); - - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = minClusterAndDistance.template view<1>({ns}, {dIdx}); - - auto L2NormXView = L2NormX.template view<1>({ns}, {dIdx}); - - // tile over the centroids - for (auto cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min(centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = centroids.template view<2>({nc, n_features}, {cIdx, 0}); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto centroidsNormView = centroidsNorm.template view<1>({nc}, {cIdx}); - workspace.resize((sizeof(int)) * ns, stream); - - FusedL2NNReduceOp redOp(cIdx); - raft::distance::KVPMinReduce pairRedOp; - - raft::distance::fusedL2NN, IndexT>( - minClusterAndDistanceView.data(), - datasetView.data(), - centroidsView.data(), - L2NormXView.data(), - centroidsNormView.data(), - ns, - nc, - n_features, - (void*)workspace.data(), - redOp, - pairRedOp, - (metric == raft::distance::DistanceType::L2Expanded) ? false : true, - false, - stream); - } else { - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = pairwiseDistance.template view<2>({ns, nc}, {0, 0}); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - kmeans::detail::pairwise_distance( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric, stream); - - // argmin reduction returning pair - // calculates the closest centroid and the distance to the closest - // centroid - raft::linalg::coalescedReduction( - minClusterAndDistanceView.data(), - pairwiseDistanceView.data(), - pairwiseDistanceView.getSize(1), - pairwiseDistanceView.getSize(0), - initial_value, - stream, - true, - [=] __device__(const DataT val, const IndexT i) { - cub::KeyValuePair pair; - pair.key = cIdx + i; - pair.value = val; - return pair; - }, - [=] __device__(cub::KeyValuePair a, cub::KeyValuePair b) { - return (b.value < a.value) ? b : a; - }, - [=] __device__(cub::KeyValuePair pair) { return pair; }); - } - } - } -} - -template -void minClusterDistance(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - Tensor& centroids, - Tensor& minClusterDistance, - Tensor& L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - cudaStream_t stream) -{ - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = centroids.getSize(0); - - auto dataBatchSize = kmeans::detail::getDataBatchSize(params, n_samples); - auto centroidsBatchSize = kmeans::detail::getCentroidsBatchSize(params, n_clusters); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data(), - centroids.getSize(1), - centroids.getSize(0), - raft::linalg::L2Norm, - true, - stream); - } else { - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - Tensor centroidsNorm(L2NormBuf_OR_DistBuf.data(), {n_clusters}); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - Tensor pairwiseDistance(L2NormBuf_OR_DistBuf.data(), - {dataBatchSize, centroidsBatchSize}); - - thrust::fill(handle.get_thrust_policy(), - minClusterDistance.begin(), - minClusterDistance.end(), - std::numeric_limits::max()); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (int dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min(dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = X.template view<2>({ns, n_features}, {dIdx, 0}); - - // minClusterDistanceView [ns x n_clusters] - auto minClusterDistanceView = minClusterDistance.template view<1>({ns}, {dIdx}); - - auto L2NormXView = L2NormX.template view<1>({ns}, {dIdx}); - - // tile over the centroids - for (auto cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min(centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = centroids.template view<2>({nc, n_features}, {cIdx, 0}); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto centroidsNormView = centroidsNorm.template view<1>({nc}, {cIdx}); - workspace.resize((sizeof(int)) * ns, stream); - - FusedL2NNReduceOp redOp(cIdx); - raft::distance::KVPMinReduce pairRedOp; - raft::distance::fusedL2NN( - minClusterDistanceView.data(), - datasetView.data(), - centroidsView.data(), - L2NormXView.data(), - centroidsNormView.data(), - ns, - nc, - n_features, - (void*)workspace.data(), - redOp, - pairRedOp, - (metric == raft::distance::DistanceType::L2Expanded) ? false : true, - false, - stream); - } else { - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = pairwiseDistance.template view<2>({ns, nc}, {0, 0}); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - kmeans::detail::pairwise_distance( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric, stream); - - raft::linalg::coalescedReduction( - minClusterDistanceView.data(), - pairwiseDistanceView.data(), - pairwiseDistanceView.getSize(1), - pairwiseDistanceView.getSize(0), - std::numeric_limits::max(), - stream, - true, - [=] __device__(DataT val, int i) { // MainLambda - return val; - }, - [=] __device__(DataT a, DataT b) { // ReduceLambda - return (b < a) ? b : a; - }, - [=] __device__(DataT val) { // FinalLambda - return val; - }); - } - } - } -} - -// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores -// in 'out' does not modify the input -template -void shuffleAndGather(const raft::handle_t& handle, - const Tensor& in, - Tensor& out, - size_t n_samples_to_gather, - int seed, - cudaStream_t stream, - rmm::device_uvector* workspace = nullptr) -{ - auto n_samples = in.getSize(0); - auto n_features = in.getSize(1); - - Tensor indices({n_samples}, stream); - - if (workspace) { - // shuffle indices on device using ml-prims - raft::random::permute( - indices.data(), nullptr, nullptr, in.getSize(1), in.getSize(0), true, stream); - } else { - // shuffle indices on host and copy to device... - std::vector ht_indices(n_samples); - - std::iota(ht_indices.begin(), ht_indices.end(), 0); - - std::mt19937 gen(seed); - std::shuffle(ht_indices.begin(), ht_indices.end(), gen); - - raft::copy(indices.data(), ht_indices.data(), indices.numElements(), stream); - } - - raft::matrix::gather(in.data(), - in.getSize(1), - in.getSize(0), - indices.data(), - n_samples_to_gather, - out.data(), - stream); -} - -template -void countSamplesInCluster(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - Tensor& L2NormX, - Tensor& centroids, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - Tensor& sampleCountInCluster, - cudaStream_t stream) -{ - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = centroids.getSize(0); - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - Tensor, 1, IndexT> minClusterAndDistance({n_samples}, stream); - - // temporary buffer to store distance matrix, destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - kmeans::detail::minClusterAndDistance(handle, - params, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - // Using TransformInputIteratorT to dereference an array of cub::KeyValuePair - // and converting them to just return the Key to be used in reduce_rows_by_key - // prims - kmeans::detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - cub::KeyValuePair*> - itr(minClusterAndDistance.data(), conversion_op); - - // count # of samples in each cluster - kmeans::detail::countLabels( - handle, itr, sampleCountInCluster.data(), n_samples, n_clusters, workspace, stream); -} - -/* - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: while |C| < k - * 3: Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - * 4: C = C U {x} - * 5: end for - */ -template -void kmeansPlusPlus(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - raft::distance::DistanceType metric, - rmm::device_uvector& workspace, - rmm::device_uvector& centroidsRawData, - cudaStream_t stream) -{ - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = params.n_clusters; - - // number of seeding trials for each center (except the first) - auto n_trials = 2 + static_cast(std::ceil(log(n_clusters))); - - LOG(handle, - "Run sequential k-means++ to select %d centroids from %d input samples " - "(%d seeding trials per iterations)", - n_clusters, - n_samples, - n_trials); - - auto dataBatchSize = kmeans::detail::getDataBatchSize(params, n_samples); - - // temporary buffers - std::vector h_wt(n_samples); - - rmm::device_uvector distBuffer(n_trials * n_samples, stream); - - Tensor centroidCandidates({n_trials, n_features}, stream); - - Tensor costPerCandidate({n_trials}, stream); - - Tensor minClusterDistance({n_samples}, stream); - - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - rmm::device_scalar clusterCost(stream); - - rmm::device_scalar> minClusterIndexAndDistance(stream); - - // L2 norm of X: ||c||^2 - Tensor L2NormX({n_samples}, stream); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data(), X.data(), X.getSize(1), X.getSize(0), raft::linalg::L2Norm, true, stream); - } - - std::mt19937 gen(params.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - // <<< Step-1 >>>: C <-- sample a point uniformly at random from X - auto initialCentroid = X.template view<2>({1, n_features}, {dis(gen), 0}); - int n_clusters_picked = 1; - - // reset buffer to store the chosen centroid - centroidsRawData.resize(initialCentroid.numElements(), stream); - raft::copy( - centroidsRawData.begin(), initialCentroid.data(), initialCentroid.numElements(), stream); - - // C = initial set of centroids - Tensor centroids(centroidsRawData.data(), - {initialCentroid.getSize(0), initialCentroid.getSize(1)}); - // <<< End of Step-1 >>> - - // Calculate cluster distance, d^2(x, C), for all the points x in X to the nearest centroid - kmeans::detail::minClusterDistance(handle, - params, - X, - centroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - LOG(handle, " k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - - // <<<< Step-2 >>> : while |C| < k - while (n_clusters_picked < n_clusters) { - // <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - // Choose 'n_trials' centroid candidates from X with probability proportional to the squared - // distance to the nearest existing cluster - raft::copy(h_wt.data(), minClusterDistance.data(), minClusterDistance.numElements(), stream); - handle.sync_stream(stream); - - // Note - n_trials is relative small here, we don't need MLCommon::gather call - std::discrete_distribution<> d(h_wt.begin(), h_wt.end()); - for (int cIdx = 0; cIdx < n_trials; ++cIdx) { - auto rand_idx = d(gen); - auto randCentroid = X.template view<2>({1, n_features}, {rand_idx, 0}); - raft::copy(centroidCandidates.data() + cIdx * n_features, - randCentroid.data(), - randCentroid.numElements(), - stream); - } - - // Calculate pairwise distance between X and the centroid candidates - // Output - pwd [n_trails x n_samples] - Tensor pwd(distBuffer.data(), {n_trials, n_samples}); - kmeans::detail::pairwise_distance( - handle, centroidCandidates, X, pwd, workspace, metric, stream); - - // Update nearest cluster distance for each centroid candidate - // Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values. - // Outputs minDistanceBuf[m_trails x n_samples] where minDistance[i, :] contains updated - // minClusterDistance that includes candidate-i - Tensor minDistBuf(distBuffer.data(), {n_trials, n_samples}); - raft::linalg::matrixVectorOp( - minDistBuf.data(), - pwd.data(), - minClusterDistance.data(), - pwd.getSize(1), - pwd.getSize(0), - true, - true, - [=] __device__(DataT mat, DataT vec) { return vec <= mat ? vec : mat; }, - stream); - - // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using - // centroid candidate-i - raft::linalg::reduce(costPerCandidate.data(), - minDistBuf.data(), - minDistBuf.getSize(1), - minDistBuf.getSize(0), - static_cast(0), - true, - true, - stream); - - // Greedy Choice - Choose the candidate that has minimum cluster cost - // ArgMin operation below identifies the index of minimum cost in costPerCandidate - { - // Determine temporary device storage requirements - size_t temp_storage_bytes = 0; - cub::DeviceReduce::ArgMin(nullptr, - temp_storage_bytes, - costPerCandidate.data(), - minClusterIndexAndDistance.data(), - costPerCandidate.getSize(0)); - - // Allocate temporary storage - workspace.resize(temp_storage_bytes, stream); - - // Run argmin-reduction - cub::DeviceReduce::ArgMin(workspace.data(), - temp_storage_bytes, - costPerCandidate.data(), - minClusterIndexAndDistance.data(), - costPerCandidate.getSize(0)); - - int bestCandidateIdx = -1; - raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); - /// <<< End of Step-3 >>> - - /// <<< Step-4 >>>: C = C U {x} - // Update minimum cluster distance corresponding to the chosen centroid candidate - raft::copy(minClusterDistance.data(), - minDistBuf.data() + bestCandidateIdx * n_samples, - n_samples, - stream); - - raft::copy(centroidsRawData.data() + n_clusters_picked * n_features, - centroidCandidates.data() + bestCandidateIdx * n_features, - n_features, - stream); - - ++n_clusters_picked; - /// <<< End of Step-4 >>> - } - - LOG(handle, " k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - } /// <<<< Step-5 >>> -} - -template -void checkWeights(const raft::handle_t& handle, - rmm::device_uvector& workspace, - Tensor& weight, - cudaStream_t stream) -{ - rmm::device_scalar wt_aggr(stream); - - int n_samples = weight.getSize(0); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data(), wt_aggr.data(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - workspace.data(), temp_storage_bytes, weight.data(), wt_aggr.data(), n_samples, stream)); - - DataT wt_sum = 0; - raft::copy(&wt_sum, wt_aggr.data(), 1, stream); - handle.sync_stream(stream); - - if (wt_sum != n_samples) { - LOG(handle, - "[Warning!] KMeans: normalizing the user provided sample weights to " - "sum up to %d samples", - n_samples); - - DataT scale = n_samples / wt_sum; - raft::linalg::unaryOp( - weight.data(), - weight.data(), - weight.numElements(), - [=] __device__(const DataT& wt) { return wt * scale; }, - stream); - } -} -}; // namespace detail -}; // namespace kmeans -}; // namespace ML diff --git a/cpp/src/kmeans/kmeans.cu b/cpp/src/kmeans/kmeans.cu index 52abed22ff..ab926d960b 100644 --- a/cpp/src/kmeans/kmeans.cu +++ b/cpp/src/kmeans/kmeans.cu @@ -14,15 +14,17 @@ * limitations under the License. */ -#include "sg_impl.cuh" -#include +#include + +#include +#include namespace ML { namespace kmeans { // -------------------------- fit_predict --------------------------------// void fit_predict(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const float* X, int n_samples, int n_features, @@ -32,13 +34,22 @@ void fit_predict(const raft::handle_t& handle, float& inertia, int& n_iter) { - impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); - impl::predict( - handle, params, centroids, X, n_samples, n_features, sample_weight, true, labels, inertia); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + std::optional> sw = std::nullopt; + if (sample_weight != nullptr) + sw = std::make_optional(raft::make_device_vector_view((sample_weight), n_samples)); + auto centroids_opt = + std::make_optional(raft::make_device_matrix_view(centroids, params.n_clusters, n_features)); + auto rLabels = raft::make_device_vector_view(labels, n_samples); + auto inertia_view = raft::make_host_scalar_view(&inertia); + auto n_iter_view = raft::make_host_scalar_view(&n_iter); + + raft::cluster::kmeans_fit_predict( + handle, params, X_view, sw, centroids_opt, rLabels, inertia_view, n_iter_view); } void fit_predict(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const double* X, int n_samples, int n_features, @@ -48,15 +59,24 @@ void fit_predict(const raft::handle_t& handle, double& inertia, int& n_iter) { - impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); - impl::predict( - handle, params, centroids, X, n_samples, n_features, sample_weight, true, labels, inertia); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + std::optional> sw = std::nullopt; + if (sample_weight != nullptr) + sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples)); + auto centroids_opt = + std::make_optional(raft::make_device_matrix_view(centroids, params.n_clusters, n_features)); + auto rLabels = raft::make_device_vector_view(labels, n_samples); + auto inertia_view = raft::make_host_scalar_view(&inertia); + auto n_iter_view = raft::make_host_scalar_view(&n_iter); + + raft::cluster::kmeans_fit_predict( + handle, params, X_view, sw, centroids_opt, rLabels, inertia_view, n_iter_view); } // ----------------------------- fit ---------------------------------// void fit(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const float* X, int n_samples, int n_features, @@ -65,11 +85,20 @@ void fit(const raft::handle_t& handle, float& inertia, int& n_iter) { - impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + std::optional> sw = std::nullopt; + if (sample_weight != nullptr) + sw = std::make_optional(raft::make_device_vector_view((sample_weight), n_samples)); + auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto inertia_view = raft::make_host_scalar_view(&inertia); + auto n_iter_view = raft::make_host_scalar_view(&n_iter); + + raft::cluster::kmeans_fit( + handle, params, X_view, sw, centroids_view, inertia_view, n_iter_view); } void fit(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const double* X, int n_samples, int n_features, @@ -78,13 +107,22 @@ void fit(const raft::handle_t& handle, double& inertia, int& n_iter) { - impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + std::optional> sw = std::nullopt; + if (sample_weight != nullptr) + sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples)); + auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto inertia_view = raft::make_host_scalar_view(&inertia); + auto n_iter_view = raft::make_host_scalar_view(&n_iter); + + raft::cluster::kmeans_fit( + handle, params, X_view, sw, centroids_view, inertia_view, n_iter_view); } // ----------------------------- predict ---------------------------------// void predict(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const float* centroids, const float* X, int n_samples, @@ -94,20 +132,20 @@ void predict(const raft::handle_t& handle, int* labels, float& inertia) { - impl::predict(handle, - params, - centroids, - X, - n_samples, - n_features, - sample_weight, - normalize_weights, - labels, - inertia); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + std::optional> sw = std::nullopt; + if (sample_weight != nullptr) + sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples)); + auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto rLabels = raft::make_device_vector_view(labels, n_samples); + auto inertia_view = raft::make_host_scalar_view(&inertia); + + raft::cluster::kmeans_predict( + handle, params, X_view, sw, centroids_view, rLabels, normalize_weights, inertia_view); } void predict(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const double* centroids, const double* X, int n_samples, @@ -117,41 +155,47 @@ void predict(const raft::handle_t& handle, int* labels, double& inertia) { - impl::predict(handle, - params, - centroids, - X, - n_samples, - n_features, - sample_weight, - normalize_weights, - labels, - inertia); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + std::optional> sw = std::nullopt; + if (sample_weight != nullptr) + sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples)); + auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto rLabels = raft::make_device_vector_view(labels, n_samples); + auto inertia_view = raft::make_host_scalar_view(&inertia); + + raft::cluster::kmeans_predict( + handle, params, X_view, sw, centroids_view, rLabels, normalize_weights, inertia_view); } // ----------------------------- transform ---------------------------------// void transform(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const float* centroids, const float* X, int n_samples, int n_features, - int metric, float* X_new) { - impl::transform(handle, params, centroids, X, n_samples, n_features, metric, X_new); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto rX_new = raft::make_device_matrix_view(X_new, n_samples, n_features); + + raft::cluster::kmeans_transform(handle, params, X_view, centroids_view, rX_new); } void transform(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const double* centroids, const double* X, int n_samples, int n_features, - int metric, double* X_new) { - impl::transform(handle, params, centroids, X, n_samples, n_features, metric, X_new); + auto X_view = raft::make_device_matrix_view(X, n_samples, n_features); + auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto rX_new = raft::make_device_matrix_view(X_new, n_samples, n_features); + + raft::cluster::kmeans_transform(handle, params, X_view, centroids_view, rX_new); } }; // end namespace kmeans diff --git a/cpp/src/kmeans/kmeans_mg.cu b/cpp/src/kmeans/kmeans_mg.cu index ae9e94ec97..89daac8fc5 100644 --- a/cpp/src/kmeans/kmeans_mg.cu +++ b/cpp/src/kmeans/kmeans_mg.cu @@ -14,8 +14,11 @@ * limitations under the License. */ +#include + #include "kmeans_mg_impl.cuh" #include +#include namespace ML { namespace kmeans { @@ -24,7 +27,7 @@ namespace opg { // ----------------------------- fit ---------------------------------// void fit(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const float* X, int n_samples, int n_features, @@ -40,7 +43,7 @@ void fit(const raft::handle_t& handle, } void fit(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::KMeansParams& params, const double* X, int n_samples, int n_features, diff --git a/cpp/src/kmeans/kmeans_mg_impl.cuh b/cpp/src/kmeans/kmeans_mg_impl.cuh index 59069403dc..b803387800 100644 --- a/cpp/src/kmeans/kmeans_mg_impl.cuh +++ b/cpp/src/kmeans/kmeans_mg_impl.cuh @@ -15,21 +15,38 @@ */ #pragma once -#include +#include +#include +#include +#include +#include +#include #include #include #include +#include #include #include #include #include #include -#include "common.cuh" -#include "sg_impl.cuh" +#include namespace ML { + +#define CUML_LOG_KMEANS(handle, fmt, ...) \ + do { \ + bool isRoot = true; \ + if (handle.comms_initialized()) { \ + const auto& comm = handle.get_comms(); \ + const int my_rank = comm.get_rank(); \ + isRoot = my_rank == 0; \ + } \ + if (isRoot) { CUML_LOG_DEBUG(fmt, ##__VA_ARGS__); } \ + } while (0) + namespace kmeans { namespace opg { namespace impl { @@ -39,24 +56,19 @@ namespace impl { // Selects 'n_clusters' samples randomly from X template void initRandom(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - rmm::device_uvector& centroidsRawData) + const raft::cluster::kmeans::KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids) { const auto& comm = handle.get_comms(); cudaStream_t stream = handle.get_stream(); - auto n_local_samples = X.getSize(0); - auto n_features = X.getSize(1); + auto n_local_samples = X.extent(0); + auto n_features = X.extent(1); auto n_clusters = params.n_clusters; const int my_rank = comm.get_rank(); const int n_ranks = comm.get_size(); - // allocate centroids buffer - centroidsRawData.resize(n_clusters * n_features, stream); - auto centroids = - std::move(Tensor(centroidsRawData.data(), {n_clusters, n_features})); - std::vector nCentroidsSampledByRank(n_ranks, 0); std::vector nCentroidsElementsToReceiveFromRank(n_ranks, 0); @@ -72,18 +84,19 @@ void initRandom(const raft::handle_t& handle, nCentroidsElementsToReceiveFromRank[rank] = nCentroidsSampledInRank * n_features; } - int nCentroidsSampledInRank = nCentroidsSampledByRank[my_rank]; - ASSERT(nCentroidsSampledInRank <= n_local_samples, + auto nCentroidsSampledInRank = nCentroidsSampledByRank[my_rank]; + ASSERT((IndexT)nCentroidsSampledInRank <= (IndexT)n_local_samples, "# random samples requested from rank-%d is larger than the available " - "samples at the rank (requested is %d, available is %d)", + "samples at the rank (requested is %lu, available is %lu)", my_rank, - nCentroidsSampledInRank, - n_local_samples); + (size_t)nCentroidsSampledInRank, + (size_t)n_local_samples); - Tensor centroidsSampledInRank({nCentroidsSampledInRank, n_features}, stream); + auto centroidsSampledInRank = + raft::make_device_matrix(handle, nCentroidsSampledInRank, n_features); - kmeans::detail::shuffleAndGather( - handle, X, centroidsSampledInRank, nCentroidsSampledInRank, params.seed, stream); + raft::cluster::kmeans::shuffle_and_gather( + handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, params.rng_state.seed); std::vector displs(n_ranks); thrust::exclusive_scan(thrust::host, @@ -92,8 +105,8 @@ void initRandom(const raft::handle_t& handle, displs.begin()); // gather centroids from all ranks - comm.allgatherv(centroidsSampledInRank.data(), // sendbuff - centroids.data(), // recvbuff + comm.allgatherv(centroidsSampledInRank.data_handle(), // sendbuff + centroids.data_handle(), // recvbuff nCentroidsElementsToReceiveFromRank.data(), // recvcount displs.data(), stream); @@ -115,9 +128,9 @@ void initRandom(const raft::handle_t& handle, */ template void initKMeansPlusPlus(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - rmm::device_uvector& centroidsRawData, + const raft::cluster::kmeans::KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroidsRawData, rmm::device_uvector& workspace) { const auto& comm = handle.get_comms(); @@ -125,12 +138,12 @@ void initKMeansPlusPlus(const raft::handle_t& handle, const int my_rank = comm.get_rank(); const int n_rank = comm.get_size(); - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = params.n_clusters; - raft::distance::DistanceType metric = static_cast(params.metric); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; - raft::random::Rng rng(params.seed, raft::random::GeneratorType::GenPhilox); + raft::random::RngState rng(params.rng_state.seed, raft::random::GeneratorType::GenPhilox); // <<<< Step-1 >>> : C <- sample a point uniformly at random from X // 1.1 - Select a rank r' at random from the available n_rank ranks with a @@ -140,92 +153,100 @@ void initKMeansPlusPlus(const raft::handle_t& handle, // X which will be used as the initial centroid for kmeans++ // 1.3 - Communicate the initial centroid chosen by rank-r' to all other // ranks - std::mt19937 gen(params.seed); + std::mt19937 gen(params.rng_state.seed); std::uniform_int_distribution<> dis(0, n_rank - 1); int rp = dis(gen); // buffer to flag the sample that is chosen as initial centroids - std::vector h_isSampleCentroid(n_samples); + std::vector h_isSampleCentroid(n_samples); std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); - Tensor initialCentroid({1, n_features}, stream); - LOG(handle, "@Rank-%d : KMeans|| : initial centroid is sampled at rank-%d\n", my_rank, rp); + auto initialCentroid = raft::make_device_matrix(handle, 1, n_features); + CUML_LOG_KMEANS( + handle, "@Rank-%d : KMeans|| : initial centroid is sampled at rank-%d\n", my_rank, rp); // 1.2 - Rank r' samples a point uniformly at random from the local dataset // X which will be used as the initial centroid for kmeans++ if (my_rank == rp) { - std::mt19937 gen(params.seed); + std::mt19937 gen(params.rng_state.seed); std::uniform_int_distribution<> dis(0, n_samples - 1); int cIdx = dis(gen); - auto centroidsView = X.template view<2>({1, n_features}, {cIdx, 0}); + auto centroidsView = raft::make_device_matrix_view( + X.data_handle() + cIdx * n_features, 1, n_features); - raft::copy(initialCentroid.data(), centroidsView.data(), centroidsView.numElements(), stream); + raft::copy( + initialCentroid.data_handle(), centroidsView.data_handle(), centroidsView.size(), stream); h_isSampleCentroid[cIdx] = 1; } // 1.3 - Communicate the initial centroid chosen by rank-r' to all other ranks - comm.bcast(initialCentroid.data(), initialCentroid.numElements(), rp, stream); + comm.bcast(initialCentroid.data_handle(), initialCentroid.size(), rp, stream); // device buffer to flag the sample that is chosen as initial centroid - Tensor isSampleCentroid({n_samples}, stream); + auto isSampleCentroid = raft::make_device_vector(handle, n_samples); raft::copy( - isSampleCentroid.data(), h_isSampleCentroid.data(), isSampleCentroid.numElements(), stream); + isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); rmm::device_uvector centroidsBuf(0, stream); // reset buffer to store the chosen centroid - centroidsBuf.resize(initialCentroid.numElements(), stream); - raft::copy(centroidsBuf.begin(), initialCentroid.data(), initialCentroid.numElements(), stream); + centroidsBuf.resize(initialCentroid.size(), stream); + raft::copy(centroidsBuf.begin(), initialCentroid.data_handle(), initialCentroid.size(), stream); - auto potentialCentroids = std::move(Tensor( - centroidsBuf.data(), {initialCentroid.getSize(0), initialCentroid.getSize(1)})); + auto potentialCentroids = raft::make_device_matrix_view( + centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); // <<< End of Step-1 >>> rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // L2 norm of X: ||x||^2 - Tensor L2NormX({n_samples}, stream); + auto L2NormX = raft::make_device_vector(handle, n_samples); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data(), X.data(), X.getSize(1), X.getSize(0), raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(L2NormX.data_handle(), + X.data_handle(), + X.extent(1), + X.extent(0), + raft::linalg::L2Norm, + true, + stream); } - Tensor minClusterDistance({n_samples}, stream); - Tensor uniformRands({n_samples}, stream); + auto minClusterDistance = raft::make_device_vector(handle, n_samples); + auto uniformRands = raft::make_device_vector(handle, n_samples); // <<< Step-2 >>>: psi <- phi_X (C) - rmm::device_scalar clusterCost(stream); - - kmeans::detail::minClusterDistance(handle, - params, - X, - potentialCentroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); + auto clusterCost = raft::make_device_scalar(handle, 0); + + raft::cluster::kmeans::min_cluster_distance(handle, + X, + potentialCentroids, + minClusterDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); // compute partial cluster cost from the samples in rank - kmeans::detail::computeClusterCost( + raft::cluster::kmeans::cluster_cost( handle, - minClusterDistance, + minClusterDistance.view(), workspace, - clusterCost.data(), - [] __device__(const DataT& a, const DataT& b) { return a + b; }, - stream); + clusterCost.view(), + [] __device__(const DataT& a, const DataT& b) { return a + b; }); // compute total cluster cost by accumulating the partial cost from all the // ranks - comm.allreduce(clusterCost.data(), clusterCost.data(), 1, raft::comms::op_t::SUM, stream); + comm.allreduce( + clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); DataT psi = 0; - psi = clusterCost.value(stream); + raft::copy(&psi, clusterCost.data_handle(), 1, stream); // <<< End of Step-2 >>> @@ -235,54 +256,64 @@ void initKMeansPlusPlus(const raft::handle_t& handle, // Scalable kmeans++ paper claims 8 rounds is sufficient int niter = std::min(8, (int)ceil(log(psi))); - LOG(handle, - "@Rank-%d:KMeans|| :phi - %f, max # of iterations for kmeans++ loop - " - "%d\n", - my_rank, - psi, - niter); + CUML_LOG_KMEANS(handle, + "@Rank-%d:KMeans|| :phi - %f, max # of iterations for kmeans++ loop - " + "%d\n", + my_rank, + psi, + niter); // <<<< Step-3 >>> : for O( log(psi) ) times do for (int iter = 0; iter < niter; ++iter) { - LOG(handle, - "@Rank-%d:KMeans|| - Iteration %d: # potential centroids sampled - " - "%d\n", - my_rank, - iter, - potentialCentroids.getSize(0)); - - kmeans::detail::minClusterDistance(handle, - params, - X, - potentialCentroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - kmeans::detail::computeClusterCost( + CUML_LOG_KMEANS(handle, + "@Rank-%d:KMeans|| - Iteration %d: # potential centroids sampled - " + "%d\n", + my_rank, + iter, + potentialCentroids.extent(0)); + + raft::cluster::kmeans::min_cluster_distance(handle, + X, + potentialCentroids, + minClusterDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); + + raft::cluster::kmeans::cluster_cost( handle, - minClusterDistance, + minClusterDistance.view(), workspace, - clusterCost.data(), - [] __device__(const DataT& a, const DataT& b) { return a + b; }, - stream); - comm.allreduce(clusterCost.data(), clusterCost.data(), 1, raft::comms::op_t::SUM, stream); - psi = clusterCost.value(stream); + clusterCost.view(), + [] __device__(const DataT& a, const DataT& b) { return a + b; }); + comm.allreduce( + clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); + raft::copy(&psi, clusterCost.data_handle(), 1, stream); ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, "An error occurred in the distributed operation. This can result " "from a failed rank"); // <<<< Step-4 >>> : Sample each point x in X independently and identify new // potentialCentroids - rng.uniform(uniformRands.data(), uniformRands.getSize(0), (DataT)0, (DataT)1, stream); - kmeans::detail::SamplingOp select_op( - psi, params.oversampling_factor, n_clusters, uniformRands.data(), isSampleCentroid.data()); - - auto inRankCp = kmeans::detail::sampleCentroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, workspace, stream); + raft::random::uniform( + handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1); + raft::cluster::kmeans::SamplingOp select_op(psi, + params.oversampling_factor, + n_clusters, + uniformRands.data_handle(), + isSampleCentroid.data_handle()); + + rmm::device_uvector inRankCp(0, stream); + raft::cluster::kmeans::sample_centroids(handle, + X, + minClusterDistance.view(), + isSampleCentroid.view(), + select_op, + inRankCp, + workspace); /// <<<< End of Step-4 >>>> int* nPtsSampledByRank; @@ -293,13 +324,13 @@ void initKMeansPlusPlus(const raft::handle_t& handle, // potentialCentroids // RAFT_CUDA_TRY(cudaMemsetAsync(nPtsSampledByRank, 0, n_rank * sizeof(int), stream)); std::fill(nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); - nPtsSampledByRank[my_rank] = inRankCp.getSize(0); + nPtsSampledByRank[my_rank] = inRankCp.size() / n_features; comm.allgather(&(nPtsSampledByRank[my_rank]), nPtsSampledByRank, 1, stream); ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, "An error occurred in the distributed operation. This can result " "from a failed rank"); - int nPtsSampled = + auto nPtsSampled = thrust::reduce(thrust::host, nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); // gather centroids from all ranks @@ -321,31 +352,31 @@ void initKMeansPlusPlus(const raft::handle_t& handle, displs.data(), stream); - int tot_centroids = potentialCentroids.getSize(0) + nPtsSampled; + auto tot_centroids = potentialCentroids.extent(0) + nPtsSampled; potentialCentroids = - std::move(Tensor(centroidsBuf.data(), {tot_centroids, n_features})); + raft::make_device_matrix_view(centroidsBuf.data(), tot_centroids, n_features); /// <<<< End of Step-5 >>> } /// <<<< Step-6 >>> - LOG(handle, - "@Rank-%d:KMeans||: # potential centroids sampled - %d\n", - my_rank, - potentialCentroids.getSize(0)); + CUML_LOG_KMEANS(handle, + "@Rank-%d:KMeans||: # potential centroids sampled - %d\n", + my_rank, + potentialCentroids.extent(0)); - if (potentialCentroids.getSize(0) > n_clusters) { + if ((IndexT)potentialCentroids.extent(0) > (IndexT)n_clusters) { // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X // temporary buffer to store the sample count per cluster, destructor // releases the resource - Tensor weight({potentialCentroids.getSize(0)}, stream); + auto weight = raft::make_device_vector(handle, potentialCentroids.extent(0)); - kmeans::detail::countSamplesInCluster( - handle, params, X, L2NormX, potentialCentroids, workspace, metric, weight, stream); + raft::cluster::kmeans::count_samples_in_cluster( + handle, params, X, L2NormX.view(), potentialCentroids, workspace, weight.view()); // merge the local histogram from all ranks - comm.allreduce(weight.data(), // sendbuff - weight.data(), // recvbuff - weight.numElements(), // count + comm.allreduce(weight.data_handle(), // sendbuff + weight.data_handle(), // recvbuff + weight.size(), // count raft::comms::op_t::SUM, stream); @@ -354,77 +385,78 @@ void initKMeansPlusPlus(const raft::handle_t& handle, // Step-8: Recluster the weighted points in C into k clusters // Note - reclustering step is duplicated across all ranks and with the same // seed they should generate the same potentialCentroids - centroidsRawData.resize(n_clusters * n_features, stream); - kmeans::detail::kmeansPlusPlus( - handle, params, potentialCentroids, metric, workspace, centroidsRawData, stream); - - DataT inertia = 0; - int n_iter = 0; - KMeansParams default_params; + auto const_centroids = raft::make_device_matrix_view( + potentialCentroids.data_handle(), potentialCentroids.extent(0), potentialCentroids.extent(1)); + raft::cluster::kmeans::init_plus_plus( + handle, params, const_centroids, centroidsRawData, workspace); + + auto inertia = raft::make_host_scalar(0); + auto n_iter = raft::make_host_scalar(0); + auto weight_view = + raft::make_device_vector_view(weight.data_handle(), weight.extent(0)); + raft::cluster::kmeans::KMeansParams default_params; default_params.n_clusters = params.n_clusters; - ML::kmeans::impl::fit(handle, - default_params, - potentialCentroids, - weight, - centroidsRawData, - inertia, - n_iter, - workspace); + raft::cluster::kmeans::fit_main(handle, + default_params, + const_centroids, + weight_view, + centroidsRawData, + inertia.view(), + n_iter.view(), + workspace); - } else if (potentialCentroids.getSize(0) < n_clusters) { + } else if ((IndexT)potentialCentroids.extent(0) < (IndexT)n_clusters) { // supplement with random - auto n_random_clusters = n_clusters - potentialCentroids.getSize(0); - LOG(handle, - "[Warning!] KMeans||: found fewer than %d centroids during " - "initialization (found %d centroids, remaining %d centroids will be " - "chosen randomly from input samples)\n", - n_clusters, - potentialCentroids.getSize(0), - n_random_clusters); - - // reset buffer to store the chosen centroid - centroidsRawData.resize(n_clusters * n_features, stream); + auto n_random_clusters = n_clusters - potentialCentroids.extent(0); + CUML_LOG_KMEANS(handle, + "[Warning!] KMeans||: found fewer than %d centroids during " + "initialization (found %d centroids, remaining %d centroids will be " + "chosen randomly from input samples)\n", + n_clusters, + potentialCentroids.extent(0), + n_random_clusters); // generate `n_random_clusters` centroids - KMeansParams rand_params; - rand_params.init = KMeansParams::InitMethod::Random; + raft::cluster::kmeans::KMeansParams rand_params; + rand_params.init = raft::cluster::kmeans::KMeansParams::InitMethod::Random; rand_params.n_clusters = n_random_clusters; initRandom(handle, rand_params, X, centroidsRawData); // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data() + n_random_clusters * n_features, - potentialCentroids.data(), - potentialCentroids.numElements(), + raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, + potentialCentroids.data_handle(), + potentialCentroids.size(), stream); } else { // found the required n_clusters - centroidsRawData.resize(n_clusters * n_features, stream); - raft::copy( - centroidsRawData.data(), potentialCentroids.data(), potentialCentroids.numElements(), stream); + raft::copy(centroidsRawData.data_handle(), + potentialCentroids.data_handle(), + potentialCentroids.size(), + stream); } } template void checkWeights(const raft::handle_t& handle, rmm::device_uvector& workspace, - Tensor& weight, - cudaStream_t stream) + raft::device_vector_view weight) { + cudaStream_t stream = handle.get_stream(); rmm::device_scalar wt_aggr(stream); const auto& comm = handle.get_comms(); - int n_samples = weight.getSize(0); + auto n_samples = weight.extent(0); size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data(), wt_aggr.data(), n_samples, stream)); + nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); workspace.resize(temp_storage_bytes, stream); RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - workspace.data(), temp_storage_bytes, weight.data(), wt_aggr.data(), n_samples, stream)); + workspace.data(), temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); comm.allreduce(wt_aggr.data(), // sendbuff wt_aggr.data(), // recvbuff @@ -435,16 +467,16 @@ void checkWeights(const raft::handle_t& handle, handle.sync_stream(stream); if (wt_sum != n_samples) { - LOG(handle, - "[Warning!] KMeans: normalizing the user provided sample weights to " - "sum up to %d samples", - n_samples); + CUML_LOG_KMEANS(handle, + "[Warning!] KMeans: normalizing the user provided sample weights to " + "sum up to %d samples", + n_samples); DataT scale = n_samples / wt_sum; raft::linalg::unaryOp( - weight.data(), - weight.data(), - weight.numElements(), + weight.data_handle(), + weight.data_handle(), + weight.size(), [=] __device__(const DataT& wt) { return wt * scale; }, stream); } @@ -452,26 +484,26 @@ void checkWeights(const raft::handle_t& handle, template void fit(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - Tensor& weight, - rmm::device_uvector& centroidsRawData, - DataT& inertia, - int& n_iter, + const raft::cluster::kmeans::KMeansParams& params, + raft::device_matrix_view X, + raft::device_vector_view weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, rmm::device_uvector& workspace) { const auto& comm = handle.get_comms(); cudaStream_t stream = handle.get_stream(); - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); auto n_clusters = params.n_clusters; - - raft::distance::DistanceType metric = static_cast(params.metric); + auto metric = params.metric; // stores (key, value) pair corresponding to each sample where // - key is the index of nearest cluster // - value is the distance to the nearest cluster - Tensor, 1, IndexT> minClusterAndDistance({n_samples}, stream); + auto minClusterAndDistance = + raft::make_device_vector, IndexT>(handle, n_samples); // temporary buffer to store L2 norm of centroids or distance matrix, // destructor releases the resource @@ -479,85 +511,94 @@ void fit(const raft::handle_t& handle, // temporary buffer to store intermediate centroids, destructor releases the // resource - Tensor newCentroids({n_clusters, n_features}, stream); + auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); // temporary buffer to store the weights per cluster, destructor releases // the resource - Tensor wtInCluster({n_clusters}, stream); + auto wtInCluster = raft::make_device_vector(handle, n_clusters); // L2 norm of X: ||x||^2 - Tensor L2NormX({n_samples}, stream); + auto L2NormX = raft::make_device_vector(handle, n_samples); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data(), X.data(), X.getSize(1), X.getSize(0), raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(L2NormX.data_handle(), + X.data_handle(), + X.extent(1), + X.extent(0), + raft::linalg::L2Norm, + true, + stream); } DataT priorClusteringCost = 0; - for (n_iter = 0; n_iter < params.max_iter; ++n_iter) { - LOG(handle, - "KMeans.fit: Iteration-%d: fitting the model using the initialize " - "cluster centers\n", - n_iter); - - auto centroids = - std::move(Tensor(centroidsRawData.data(), {n_clusters, n_features})); - + for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { + CUML_LOG_KMEANS(handle, + "KMeans.fit: Iteration-%d: fitting the model using the initialize " + "cluster centers\n", + n_iter[0]); + + auto const_centroids = raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)); // computes minClusterAndDistance[0:n_samples) where // minClusterAndDistance[i] is a pair where // 'key' is index to an sample in 'centroids' (index of the nearest // centroid) and 'value' is the distance between the sample 'X[i]' and the // 'centroid[key]' - kmeans::detail::minClusterAndDistance(handle, - params, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); + raft::cluster::kmeans::min_cluster_and_distance(handle, + X, + const_centroids, + minClusterAndDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); // Using TransformInputIteratorT to dereference an array of // cub::KeyValuePair and converting them to just return the Key to be used // in reduce_rows_by_key prims - kmeans::detail::KeyValueIndexOp conversion_op; + raft::cluster::kmeans::KeyValueIndexOp conversion_op; cub::TransformInputIterator, - cub::KeyValuePair*> - itr(minClusterAndDistance.data(), conversion_op); + raft::cluster::kmeans::KeyValueIndexOp, + raft::KeyValuePair*> + itr(minClusterAndDistance.data_handle(), conversion_op); workspace.resize(n_samples, stream); // Calculates weighted sum of all the samples assigned to cluster-i and // store the result in newCentroids[i] - raft::linalg::reduce_rows_by_key(X.data(), - X.getSize(1), + raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), + X.extent(1), itr, - weight.data(), + weight.data_handle(), workspace.data(), - X.getSize(0), - X.getSize(1), + X.extent(0), + X.extent(1), n_clusters, - newCentroids.data(), + newCentroids.data_handle(), stream); // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key( - weight.data(), itr, wtInCluster.data(), 1, weight.getSize(0), n_clusters, stream); + raft::linalg::reduce_cols_by_key(weight.data_handle(), + itr, + wtInCluster.data_handle(), + (IndexT)1, + (IndexT)weight.extent(0), + (IndexT)n_clusters, + stream); // merge the local histogram from all ranks - comm.allreduce(wtInCluster.data(), // sendbuff - wtInCluster.data(), // recvbuff - wtInCluster.numElements(), // count + comm.allreduce(wtInCluster.data_handle(), // sendbuff + wtInCluster.data_handle(), // recvbuff + wtInCluster.size(), // count raft::comms::op_t::SUM, stream); // reduces newCentroids from all ranks - comm.allreduce(newCentroids.data(), // sendbuff - newCentroids.data(), // recvbuff - newCentroids.numElements(), // count + comm.allreduce(newCentroids.data_handle(), // sendbuff + newCentroids.data_handle(), // recvbuff + newCentroids.size(), // count raft::comms::op_t::SUM, stream); @@ -569,11 +610,11 @@ void fit(const raft::handle_t& handle, // Note - when wtInCluster[i] is 0, newCentroid[i] is reset to 0 raft::linalg::matrixVectorOp( - newCentroids.data(), - newCentroids.data(), - wtInCluster.data(), - newCentroids.getSize(1), - newCentroids.getSize(0), + newCentroids.data_handle(), + newCentroids.data_handle(), + wtInCluster.data_handle(), + newCentroids.extent(1), + newCentroids.extent(0), true, false, [=] __device__(DataT mat, DataT vec) { @@ -585,64 +626,63 @@ void fit(const raft::handle_t& handle, stream); // copy the centroids[i] to newCentroids[i] when wtInCluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(wtInCluster.data()); + cub::ArgIndexInputIterator itr_wt(wtInCluster.data_handle()); raft::matrix::gather_if( - centroids.data(), - centroids.getSize(1), - centroids.getSize(0), + centroids.data_handle(), + centroids.extent(1), + centroids.extent(0), itr_wt, itr_wt, - wtInCluster.numElements(), - newCentroids.data(), - [=] __device__(cub::KeyValuePair map) { // predicate + wtInCluster.size(), + newCentroids.data_handle(), + [=] __device__(raft::KeyValuePair map) { // predicate // copy when the # of samples in the cluster is 0 if (map.value == 0) return true; else return false; }, - [=] __device__(cub::KeyValuePair map) { // map + [=] __device__(raft::KeyValuePair map) { // map return map.key; }, stream); // compute the squared norm between the newCentroids and the original // centroids, destructor releases the resource - Tensor sqrdNorm({1}, stream); + auto sqrdNorm = raft::make_device_scalar(handle, 1); raft::linalg::mapThenSumReduce( - sqrdNorm.data(), - newCentroids.numElements(), + sqrdNorm.data_handle(), + newCentroids.size(), [=] __device__(const DataT a, const DataT b) { DataT diff = a - b; return diff * diff; }, stream, - centroids.data(), - newCentroids.data()); + centroids.data_handle(), + newCentroids.data_handle()); DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data(), sqrdNorm.numElements(), stream); + raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); - raft::copy(centroidsRawData.data(), newCentroids.data(), newCentroids.numElements(), stream); + raft::copy(centroids.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); bool done = false; if (params.inertia_check) { - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar> clusterCostD(stream); // calculate cluster cost phi_x(C) - kmeans::detail::computeClusterCost( + raft::cluster::kmeans::cluster_cost( handle, - minClusterAndDistance, + minClusterAndDistance.view(), workspace, - clusterCostD.data(), - [] __device__(const cub::KeyValuePair& a, - const cub::KeyValuePair& b) { - cub::KeyValuePair res; + raft::make_device_scalar_view(clusterCostD.data()), + [] __device__(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) { + raft::KeyValuePair res; res.key = 0; res.value = a.value + b.value; return res; - }, - stream); + }); // Cluster cost phi_x(C) from all ranks comm.allreduce(&(clusterCostD.data()->value), @@ -661,7 +701,7 @@ void fit(const raft::handle_t& handle, "Too few points and centriods being found is getting 0 cost from " "centers\n"); - if (n_iter > 0) { + if (n_iter[0] > 0) { DataT delta = curClusteringCost / priorClusteringCost; if (delta > 1 - params.tol) done = true; } @@ -672,7 +712,8 @@ void fit(const raft::handle_t& handle, if (sqrdNormError < params.tol) done = true; if (done) { - LOG(handle, "Threshold triggered after %d iterations. Terminating early.\n", n_iter); + CUML_LOG_KMEANS( + handle, "Threshold triggered after %d iterations. Terminating early.\n", n_iter[0]); break; } } @@ -680,76 +721,82 @@ void fit(const raft::handle_t& handle, template void fit(const raft::handle_t& handle, - const KMeansParams& params, + const raft::cluster::kmeans::KMeansParams& params, const DataT* X, - const int n_local_samples, - const int n_features, + const IndexT n_local_samples, + const IndexT n_features, const DataT* sample_weight, DataT* centroids, DataT& inertia, - int& n_iter) + IndexT& n_iter) { cudaStream_t stream = handle.get_stream(); ASSERT(n_local_samples > 0, "# of samples must be > 0"); - ASSERT(params.oversampling_factor > 0, "oversampling factor must be > 0 (requested %d)", (int)params.oversampling_factor); - ASSERT(is_device_or_managed_type(X), "input data must be device accessible"); - Tensor data((DataT*)X, {n_local_samples, n_features}); - - Tensor weight({n_local_samples}, stream); + auto n_clusters = params.n_clusters; + auto data = raft::make_device_matrix_view(X, n_local_samples, n_features); + auto weight = raft::make_device_vector(handle, n_local_samples); if (sample_weight != nullptr) { - raft::copy(weight.data(), sample_weight, n_local_samples, stream); + raft::copy(weight.data_handle(), sample_weight, n_local_samples, stream); } else { - thrust::fill(handle.get_thrust_policy(), weight.begin(), weight.end(), 1); + thrust::fill( + handle.get_thrust_policy(), weight.data_handle(), weight.data_handle() + weight.size(), 1); } // underlying expandable storage that holds centroids data - rmm::device_uvector centroidsRawData(0, stream); + auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); // Device-accessible allocation of expandable storage used as temorary buffers rmm::device_uvector workspace(0, stream); // check if weights sum up to n_samples - checkWeights(handle, workspace, weight, stream); + checkWeights(handle, workspace, weight.view()); - if (params.init == KMeansParams::InitMethod::Random) { + if (params.init == raft::cluster::kmeans::KMeansParams::InitMethod::Random) { // initializing with random samples from input dataset - LOG(handle, - "KMeans.fit: initialize cluster centers by randomly choosing from the " - "input data.\n"); - initRandom(handle, params, data, centroidsRawData); - } else if (params.init == KMeansParams::InitMethod::KMeansPlusPlus) { + CUML_LOG_KMEANS(handle, + "KMeans.fit: initialize cluster centers by randomly choosing from the " + "input data.\n"); + initRandom(handle, params, data, centroidsRawData.view()); + } else if (params.init == raft::cluster::kmeans::KMeansParams::InitMethod::KMeansPlusPlus) { // default method to initialize is kmeans++ - LOG(handle, "KMeans.fit: initialize cluster centers using k-means++ algorithm.\n"); - initKMeansPlusPlus(handle, params, data, centroidsRawData, workspace); - } else if (params.init == KMeansParams::InitMethod::Array) { - LOG(handle, - "KMeans.fit: initialize cluster centers from the ndarray array input " - "passed to init arguement.\n"); + CUML_LOG_KMEANS(handle, "KMeans.fit: initialize cluster centers using k-means++ algorithm.\n"); + initKMeansPlusPlus(handle, params, data, centroidsRawData.view(), workspace); + } else if (params.init == raft::cluster::kmeans::KMeansParams::InitMethod::Array) { + CUML_LOG_KMEANS(handle, + "KMeans.fit: initialize cluster centers from the ndarray array input " + "passed to init arguement.\n"); ASSERT(centroids != nullptr, "centroids array is null (require a valid array of centroids for " "the requested initialization method)"); - centroidsRawData.resize(params.n_clusters * n_features, stream); - raft::copy(centroidsRawData.begin(), centroids, params.n_clusters * n_features, stream); - + raft::copy(centroidsRawData.data_handle(), centroids, params.n_clusters * n_features, stream); } else { THROW("unknown initialization method to select initial centers"); } - - fit(handle, params, data, weight, centroidsRawData, inertia, n_iter, workspace); - - raft::copy(centroids, centroidsRawData.data(), params.n_clusters * n_features, stream); - - LOG(handle, - "KMeans.fit: async call returned (fit could still be running on the " - "device)\n"); + auto inertiaView = raft::make_host_scalar_view(&inertia); + auto n_iterView = raft::make_host_scalar_view(&n_iter); + + fit(handle, + params, + data, + weight.view(), + centroidsRawData.view(), + inertiaView, + n_iterView, + workspace); + + raft::copy(centroids, centroidsRawData.data_handle(), params.n_clusters * n_features, stream); + + CUML_LOG_KMEANS(handle, + "KMeans.fit: async call returned (fit could still be running on the " + "device)\n"); } }; // end namespace impl diff --git a/cpp/src/kmeans/sg_impl.cuh b/cpp/src/kmeans/sg_impl.cuh deleted file mode 100644 index 5084f449fa..0000000000 --- a/cpp/src/kmeans/sg_impl.cuh +++ /dev/null @@ -1,844 +0,0 @@ -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "common.cuh" -#include -#include -#include -#include - -#include -#include - -namespace ML { - -namespace kmeans { - -namespace impl { - -// Selects 'n_clusters' samples randomly from X -template -void initRandom(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - rmm::device_uvector& centroidsRawData) -{ - cudaStream_t stream = handle.get_stream(); - auto n_features = X.getSize(1); - auto n_clusters = params.n_clusters; - // allocate centroids buffer - centroidsRawData.resize(n_clusters * n_features, stream); - auto centroids = - std::move(Tensor(centroidsRawData.data(), {n_clusters, n_features})); - - kmeans::detail::shuffleAndGather(handle, X, centroids, n_clusters, params.seed, stream); -} - -template -void fit(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - Tensor& weight, - rmm::device_uvector& centroidsRawData, - DataT& inertia, - int& n_iter, - rmm::device_uvector& workspace) -{ - ML::Logger::get().setLevel(params.verbosity); - cudaStream_t stream = handle.get_stream(); - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = params.n_clusters; - - raft::distance::DistanceType metric = static_cast(params.metric); - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - Tensor, 1, IndexT> minClusterAndDistance({n_samples}, stream); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - Tensor newCentroids({n_clusters, n_features}, stream); - - // temporary buffer to store weights per cluster, destructor releases the - // resource - Tensor wtInCluster({n_clusters}, stream); - - rmm::device_scalar> clusterCostD(stream); - - // L2 norm of X: ||x||^2 - Tensor L2NormX({n_samples}, stream); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data(), X.data(), X.getSize(1), X.getSize(0), raft::linalg::L2Norm, true, stream); - } - - LOG(handle, - "Calling KMeans.fit with %d samples of input data and the initialized " - "cluster centers", - n_samples); - - DataT priorClusteringCost = 0; - for (n_iter = 1; n_iter <= params.max_iter; ++n_iter) { - LOG(handle, - "KMeans.fit: Iteration-%d: fitting the model using the initialized " - "cluster centers", - n_iter); - - auto centroids = - std::move(Tensor(centroidsRawData.data(), {n_clusters, n_features})); - - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - kmeans::detail::minClusterAndDistance(handle, - params, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - // Using TransformInputIteratorT to dereference an array of - // cub::KeyValuePair and converting them to just return the Key to be used - // in reduce_rows_by_key prims - kmeans::detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - cub::KeyValuePair*> - itr(minClusterAndDistance.data(), conversion_op); - - workspace.resize(n_samples, stream); - - // Calculates weighted sum of all the samples assigned to cluster-i and store the - // result in newCentroids[i] - raft::linalg::reduce_rows_by_key(X.data(), - X.getSize(1), - itr, - weight.data(), - workspace.data(), - X.getSize(0), - X.getSize(1), - n_clusters, - newCentroids.data(), - stream); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key( - weight.data(), itr, wtInCluster.data(), 1, weight.getSize(0), n_clusters, stream); - - // Computes newCentroids[i] = newCentroids[i]/wtInCluster[i] where - // newCentroids[n_clusters x n_features] - 2D array, newCentroids[i] has sum of all the - // samples assigned to cluster-i wtInCluster[n_clusters] - 1D array, wtInCluster[i] contains # - // of samples in cluster-i. - // Note - when wtInCluster[i] is 0, newCentroid[i] is reset to 0 - raft::linalg::matrixVectorOp( - newCentroids.data(), - newCentroids.data(), - wtInCluster.data(), - newCentroids.getSize(1), - newCentroids.getSize(0), - true, - false, - [=] __device__(DataT mat, DataT vec) { - if (vec == 0) - return DataT(0); - else - return mat / vec; - }, - stream); - - // copy centroids[i] to newCentroids[i] when wtInCluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(wtInCluster.data()); - raft::matrix::gather_if( - centroids.data(), - centroids.getSize(1), - centroids.getSize(0), - itr_wt, - itr_wt, - wtInCluster.numElements(), - newCentroids.data(), - [=] __device__(cub::KeyValuePair map) { // predicate - // copy when the # of samples in the cluster is 0 - if (map.value == 0) - return true; - else - return false; - }, - [=] __device__(cub::KeyValuePair map) { // map - return map.key; - }, - stream); - - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - Tensor sqrdNorm({1}, stream); - raft::linalg::mapThenSumReduce( - sqrdNorm.data(), - newCentroids.numElements(), - [=] __device__(const DataT a, const DataT b) { - DataT diff = a - b; - return diff * diff; - }, - stream, - centroids.data(), - newCentroids.data()); - - DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data(), sqrdNorm.numElements(), stream); - - raft::copy(centroidsRawData.data(), newCentroids.data(), newCentroids.numElements(), stream); - - bool done = false; - if (params.inertia_check) { - // calculate cluster cost phi_x(C) - kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance, - workspace, - clusterCostD.data(), - [] __device__(const cub::KeyValuePair& a, - const cub::KeyValuePair& b) { - cub::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }, - stream); - - DataT curClusteringCost = 0; - raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); - - handle.sync_stream(stream); - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centriods being found is getting 0 cost from " - "centers"); - - if (n_iter > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - handle.sync_stream(stream); - if (sqrdNormError < params.tol) done = true; - - if (done) { - LOG(handle, "Threshold triggered after %d iterations. Terminating early.", n_iter); - break; - } - } - - auto centroids = - std::move(Tensor(centroidsRawData.data(), {n_clusters, n_features})); - - kmeans::detail::minClusterAndDistance(handle, - params, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - thrust::transform(handle.get_thrust_policy(), - minClusterAndDistance.begin(), - minClusterAndDistance.end(), - weight.data(), - minClusterAndDistance.begin(), - [=] __device__(const cub::KeyValuePair kvp, DataT wt) { - cub::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - // calculate cluster cost phi_x(C) - kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance, - workspace, - clusterCostD.data(), - [] __device__(const cub::KeyValuePair& a, - const cub::KeyValuePair& b) { - cub::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }, - stream); - - raft::copy(&inertia, &(clusterCostD.data()->value), 1, stream); - - LOG(handle, - "KMeans.fit: completed after %d iterations with %f inertia ", - n_iter > params.max_iter ? n_iter - 1 : n_iter, - inertia); -} - -template -void initKMeansPlusPlus(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - rmm::device_uvector& centroidsRawData, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = handle.get_stream(); - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = params.n_clusters; - raft::distance::DistanceType metric = static_cast(params.metric); - centroidsRawData.resize(n_clusters * n_features, stream); - kmeans::detail::kmeansPlusPlus(handle, params, X, metric, workspace, centroidsRawData, stream); -} - -/* - * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. - - * @note This is the algorithm described in - * "Scalable K-Means++", 2012, Bahman Bahmani, Benjamin Moseley, - * Andrea Vattani, Ravi Kumar, Sergei Vassilvitskii, - * https://arxiv.org/abs/1203.6402 - - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: psi = phi_X (C) - * 3: for O( log(psi) ) times do - * 4: C' = sample each point x in X independently with probability - * p_x = l * (d^2(x, C) / phi_X (C) ) - * 5: C = C U C' - * 6: end for - * 7: For x in C, set w_x to be the number of points in X closer to x than any - * other point in C - * 8: Recluster the weighted points in C into k clusters - - */ -template -void initScalableKMeansPlusPlus(const raft::handle_t& handle, - const KMeansParams& params, - Tensor& X, - rmm::device_uvector& centroidsRawData, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = handle.get_stream(); - auto n_samples = X.getSize(0); - auto n_features = X.getSize(1); - auto n_clusters = params.n_clusters; - raft::distance::DistanceType metric = static_cast(params.metric); - - raft::random::Rng rng(params.seed, raft::random::GeneratorType::GenPhilox); - - // <<<< Step-1 >>> : C <- sample a point uniformly at random from X - std::mt19937 gen(params.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - int cIdx = dis(gen); - auto initialCentroid = X.template view<2>({1, n_features}, {cIdx, 0}); - - // flag the sample that is chosen as initial centroid - std::vector h_isSampleCentroid(n_samples); - std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); - h_isSampleCentroid[cIdx] = 1; - - // device buffer to flag the sample that is chosen as initial centroid - Tensor isSampleCentroid({n_samples}, stream); - - raft::copy( - isSampleCentroid.data(), h_isSampleCentroid.data(), isSampleCentroid.numElements(), stream); - - rmm::device_uvector centroidsBuf(0, stream); - - // reset buffer to store the chosen centroid - centroidsBuf.resize(initialCentroid.numElements(), stream); - raft::copy(centroidsBuf.begin(), initialCentroid.data(), initialCentroid.numElements(), stream); - - auto potentialCentroids = std::move(Tensor( - centroidsBuf.data(), {initialCentroid.getSize(0), initialCentroid.getSize(1)})); - // <<< End of Step-1 >>> - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - Tensor L2NormX({n_samples}, stream); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data(), X.data(), X.getSize(1), X.getSize(0), raft::linalg::L2Norm, true, stream); - } - - Tensor minClusterDistance({n_samples}, stream); - Tensor uniformRands({n_samples}, stream); - rmm::device_scalar clusterCost(stream); - - // <<< Step-2 >>>: psi <- phi_X (C) - kmeans::detail::minClusterDistance(handle, - params, - X, - potentialCentroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - // compute partial cluster cost from the samples in rank - kmeans::detail::computeClusterCost( - handle, - minClusterDistance, - workspace, - clusterCost.data(), - [] __device__(const DataT& a, const DataT& b) { return a + b; }, - stream); - - DataT psi = 0; - psi = clusterCost.value(stream); - - // <<< End of Step-2 >>> - - // Scalable kmeans++ paper claims 8 rounds is sufficient - handle.sync_stream(stream); - int niter = std::min(8, (int)ceil(log(psi))); - LOG(handle, "KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter); - - // <<<< Step-3 >>> : for O( log(psi) ) times do - for (int iter = 0; iter < niter; ++iter) { - LOG(handle, - "KMeans|| - Iteration %d: # potential centroids sampled - %d", - iter, - potentialCentroids.getSize(0)); - - kmeans::detail::minClusterDistance(handle, - params, - X, - potentialCentroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - kmeans::detail::computeClusterCost( - handle, - minClusterDistance, - workspace, - clusterCost.data(), - [] __device__(const DataT& a, const DataT& b) { return a + b; }, - stream); - - psi = clusterCost.value(stream); - - // <<<< Step-4 >>> : Sample each point x in X independently and identify new - // potentialCentroids - rng.uniform(uniformRands.data(), uniformRands.getSize(0), (DataT)0, (DataT)1, stream); - - kmeans::detail::SamplingOp select_op( - psi, params.oversampling_factor, n_clusters, uniformRands.data(), isSampleCentroid.data()); - - auto Cp = kmeans::detail::sampleCentroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, workspace, stream); - /// <<<< End of Step-4 >>>> - - /// <<<< Step-5 >>> : C = C U C' - // append the data in Cp to the buffer holding the potentialCentroids - centroidsBuf.resize(centroidsBuf.size() + Cp.numElements(), stream); - raft::copy(centroidsBuf.end() - Cp.numElements(), Cp.data(), Cp.numElements(), stream); - - int tot_centroids = potentialCentroids.getSize(0) + Cp.getSize(0); - potentialCentroids = - std::move(Tensor(centroidsBuf.data(), {tot_centroids, n_features})); - /// <<<< End of Step-5 >>> - } /// <<<< Step-6 >>> - - LOG(handle, "KMeans||: total # potential centroids sampled - %d", potentialCentroids.getSize(0)); - - if (potentialCentroids.getSize(0) > n_clusters) { - // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X - // temporary buffer to store the sample count per cluster, destructor - // releases the resource - Tensor weight({potentialCentroids.getSize(0)}, stream); - - kmeans::detail::countSamplesInCluster( - handle, params, X, L2NormX, potentialCentroids, workspace, metric, weight, stream); - - // <<< end of Step-7 >>> - - // Step-8: Recluster the weighted points in C into k clusters - centroidsRawData.resize(n_clusters * n_features, stream); - kmeans::detail::kmeansPlusPlus( - handle, params, potentialCentroids, metric, workspace, centroidsRawData, stream); - - DataT inertia = 0; - int n_iter = 0; - KMeansParams default_params; - default_params.n_clusters = params.n_clusters; - - ML::kmeans::impl::fit(handle, - default_params, - potentialCentroids, - weight, - centroidsRawData, - inertia, - n_iter, - workspace); - - } else if (potentialCentroids.getSize(0) < n_clusters) { - // supplement with random - auto n_random_clusters = n_clusters - potentialCentroids.getSize(0); - - LOG(handle, - "[Warning!] KMeans||: found fewer than %d centroids during " - "initialization (found %d centroids, remaining %d centroids will be " - "chosen randomly from input samples)", - n_clusters, - potentialCentroids.getSize(0), - n_random_clusters); - - // reset buffer to store the chosen centroid - centroidsRawData.resize(n_clusters * n_features, stream); - - // generate `n_random_clusters` centroids - KMeansParams rand_params; - rand_params.init = KMeansParams::InitMethod::Random; - rand_params.n_clusters = n_random_clusters; - initRandom(handle, rand_params, X, centroidsRawData); - - // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data() + n_random_clusters * n_features, - potentialCentroids.data(), - potentialCentroids.numElements(), - stream); - } else { - // found the required n_clusters - centroidsRawData.resize(n_clusters * n_features, stream); - raft::copy( - centroidsRawData.data(), potentialCentroids.data(), potentialCentroids.numElements(), stream); - } -} - -template -void fit(const raft::handle_t& handle, - const KMeansParams& km_params, - const DataT* X, - const int n_samples, - const int n_features, - const DataT* sample_weight, - DataT* centroids, - DataT& inertia, - int& n_iter) -{ - ML::Logger::get().setLevel(km_params.verbosity); - cudaStream_t stream = handle.get_stream(); - - ASSERT(n_samples > 0, "# of samples must be > 0"); - - ASSERT(km_params.oversampling_factor >= 0, - "oversampling factor must be >= 0 (requested %f)", - km_params.oversampling_factor); - - ASSERT(is_device_or_managed_type(X), "input data must be device accessible"); - - Tensor data((DataT*)X, {n_samples, n_features}); - - Tensor weight({n_samples}, stream); - if (sample_weight != nullptr) { - raft::copy(weight.data(), sample_weight, n_samples, stream); - } else { - thrust::fill(handle.get_thrust_policy(), weight.begin(), weight.end(), 1); - } - - // underlying expandable storage that holds centroids data - rmm::device_uvector centroidsRawData(0, stream); - - // Device-accessible allocation of expandable storage used as temorary buffers - rmm::device_uvector workspace(0, stream); - - // check if weights sum up to n_samples - kmeans::detail::checkWeights(handle, workspace, weight, stream); - - auto n_init = km_params.n_init; - if (km_params.init == KMeansParams::InitMethod::Array && n_init != 1) { - LOG(handle, - "Explicit initial center position passed: performing only one init in " - "k-means instead of n_init=%d", - n_init); - n_init = 1; - } - - std::mt19937 gen(km_params.seed); - inertia = std::numeric_limits::max(); - - // run k-means algorithm with different seeds - for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { - // generate KMeansParams with different seed - KMeansParams params = km_params; - params.seed = gen(); - - DataT _inertia = std::numeric_limits::max(); - int _n_iter = 0; - - if (params.init == KMeansParams::InitMethod::Random) { - // initializing with random samples from input dataset - LOG(handle, - "\n\nKMeans.fit (Iteration-%d/%d): initialize cluster centers by " - "randomly choosing from the " - "input data.", - seed_iter + 1, - n_init); - initRandom(handle, params, data, centroidsRawData); - } else if (params.init == KMeansParams::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - LOG(handle, - "\n\nKMeans.fit (Iteration-%d/%d): initialize cluster centers using " - "k-means++ algorithm.", - seed_iter + 1, - n_init); - if (params.oversampling_factor == 0) - initKMeansPlusPlus(handle, params, data, centroidsRawData, workspace); - else - initScalableKMeansPlusPlus(handle, params, data, centroidsRawData, workspace); - } else if (params.init == KMeansParams::InitMethod::Array) { - LOG(handle, - "\n\nKMeans.fit (Iteration-%d/%d): initialize cluster centers from " - "the ndarray array input " - "passed to init arguement.", - seed_iter + 1, - n_init); - - ASSERT(centroids != nullptr, - "centroids array is null (require a valid array of centroids for " - "the requested initialization method)"); - - centroidsRawData.resize(params.n_clusters * n_features, stream); - raft::copy(centroidsRawData.begin(), centroids, params.n_clusters * n_features, stream); - - } else { - THROW("unknown initialization method to select initial centers"); - } - - fit(handle, params, data, weight, centroidsRawData, _inertia, _n_iter, workspace); - - if (_inertia < inertia) { - inertia = _inertia; - n_iter = _n_iter; - raft::copy(centroids, centroidsRawData.data(), params.n_clusters * n_features, stream); - } - - LOG(handle, - "KMeans.fit after iteration-%d/%d: inertia - %f, n_iter - %d", - seed_iter + 1, - n_init, - inertia, - n_iter); - - // auto centroidsT = std::move(Tensor( - // centroids, {params.n_clusters, n_features})); - } - - LOG(handle, - "KMeans.fit: async call returned (fit could still be running on the " - "device)"); -} - -template -void predict(const raft::handle_t& handle, - const KMeansParams& params, - const DataT* cptr, - const DataT* Xptr, - const int n_samples, - const int n_features, - const DataT* sample_weight, - bool normalize_weights, - IndexT* labelsRawPtr, - DataT& inertia) -{ - ML::Logger::get().setLevel(params.verbosity); - cudaStream_t stream = handle.get_stream(); - auto n_clusters = params.n_clusters; - - ASSERT(n_clusters > 0 && cptr != nullptr, "no clusters exist"); - - ASSERT(is_device_or_managed_type(Xptr), "input data must be device accessible"); - - ASSERT(is_device_or_managed_type(cptr), "centroid data must be device accessible"); - - raft::distance::DistanceType metric = static_cast(params.metric); - - Tensor X((DataT*)Xptr, {n_samples, n_features}); - Tensor centroids((DataT*)cptr, {n_clusters, n_features}); - - Tensor weight({n_samples}, stream); - if (sample_weight != nullptr) { - raft::copy(weight.data(), sample_weight, n_samples, stream); - } else { - thrust::fill(handle.get_thrust_policy(), weight.begin(), weight.end(), 1); - } - - // underlying expandable storage that holds labels - rmm::device_uvector labelsRawData(0, stream); - - // Device-accessible allocation of expandable storage used as temorary buffers - rmm::device_uvector workspace(0, stream); - - // check if weights sum up to n_samples - if (normalize_weights) kmeans::detail::checkWeights(handle, workspace, weight, stream); - - Tensor, 1> minClusterAndDistance({n_samples}, stream); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - Tensor L2NormX({n_samples}, stream); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data(), X.data(), X.getSize(1), X.getSize(0), raft::linalg::L2Norm, true, stream); - } - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - kmeans::detail::minClusterAndDistance(handle, - params, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - workspace, - metric, - stream); - - // calculate cluster cost phi_x(C) - rmm::device_scalar> clusterCostD(stream); - - thrust::transform(handle.get_thrust_policy(), - minClusterAndDistance.begin(), - minClusterAndDistance.end(), - weight.data(), - minClusterAndDistance.begin(), - [=] __device__(const cub::KeyValuePair kvp, DataT wt) { - cub::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance, - workspace, - clusterCostD.data(), - [] __device__(const cub::KeyValuePair& a, - const cub::KeyValuePair& b) { - cub::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }, - stream); - - raft::copy(&inertia, &(clusterCostD.data()->value), 1, stream); - - labelsRawData.resize(n_samples, stream); - - thrust::transform(handle.get_thrust_policy(), - minClusterAndDistance.begin(), - minClusterAndDistance.end(), - labelsRawData.data(), - [=] __device__(cub::KeyValuePair pair) { return pair.key; }); - - raft::copy(labelsRawPtr, labelsRawData.data(), n_samples, stream); -} - -template -void transform(const raft::handle_t& handle, - const KMeansParams& params, - const DataT* cptr, - const DataT* Xptr, - int n_samples, - int n_features, - int transform_metric, - DataT* X_new) -{ - ML::Logger::get().setLevel(params.verbosity); - cudaStream_t stream = handle.get_stream(); - auto n_clusters = params.n_clusters; - raft::distance::DistanceType metric = static_cast(transform_metric); - - ASSERT(n_clusters > 0 && cptr != nullptr, "no clusters exist"); - - ASSERT(is_device_or_managed_type(Xptr), "input data must be device accessible"); - - ASSERT(is_device_or_managed_type(cptr), "centroid data must be device accessible"); - - ASSERT(is_device_or_managed_type(X_new), "output data storage must be device accessible"); - - Tensor dataset((DataT*)Xptr, {n_samples, n_features}); - Tensor centroids((DataT*)cptr, {n_clusters, n_features}); - Tensor pairwiseDistance((DataT*)X_new, {n_samples, n_clusters}); - - // Device-accessible allocation of expandable storage used as temorary buffers - rmm::device_uvector workspace(0, stream); - - auto dataBatchSize = kmeans::detail::getDataBatchSize(params, n_samples); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (int dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - int ns = std::min(dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = dataset.template view<2>({ns, n_features}, {dIdx, 0}); - - // pairwiseDistanceView [ns x n_clusters] - auto pairwiseDistanceView = pairwiseDistance.template view<2>({ns, n_clusters}, {dIdx, 0}); - - // calculate pairwise distance between cluster centroids and current batch - // of input dataset - kmeans::detail::pairwise_distance( - handle, datasetView, centroids, pairwiseDistanceView, workspace, metric, stream); - } -} - -}; // namespace impl -}; // namespace kmeans -}; // end namespace ML diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 455df0eabd..d99163629a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -197,26 +197,14 @@ endif() if(BUILD_CUML_MG_TESTS) + ConfigureTest(PREFIX MG NAME KMEANS_TEST PATH mg/kmeans_test.cu OPTIONAL NCCL CUMLPRIMS ML_INCLUDE) if(MPI_CXX_FOUND) # (please keep the filenames in alphabetical order) - ConfigureTest(PREFIX MG NAME KNN_TEST PATH knn.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) - ConfigureTest(PREFIX MG NAME KNN_CLASSIFY_TEST PATH knn_classify.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) - ConfigureTest(PREFIX MG NAME KNN_REGRESS_TEST PATH knn_regress.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) - ConfigureTest(PREFIX MG NAME MAIN_TEST PATH main.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) - ConfigureTest(PREFIX MG NAME PCA_TEST PATH pca.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) - - set_target_properties( - ${CUML_MG_TEST_TARGET} - PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib" - ) - - install( - TARGETS ${CUML_MG_TEST_TARGET} - COMPONENT testing - DESTINATION bin/gtests/libcuml_mg - EXCLUDE_FROM_ALL - ) - + ConfigureTest(PREFIX MG NAME KNN_TEST PATH mg/knn.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) + ConfigureTest(PREFIX MG NAME KNN_CLASSIFY_TEST PATH mg/knn_classify.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) + ConfigureTest(PREFIX MG NAME KNN_REGRESS_TEST PATH mg/knn_regress.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) + ConfigureTest(PREFIX MG NAME MAIN_TEST PATH mg/main.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) + ConfigureTest(PREFIX MG NAME PCA_TEST PATH mg/pca.cu OPTIONAL NCCL CUMLPRIMS MPI ML_INCLUDE) else(MPI_CXX_FOUND) message("OpenMPI not found. Skipping test '${CUML_MG_TEST_TARGET}'") endif() diff --git a/cpp/test/mg/kmeans_test.cu b/cpp/test/mg/kmeans_test.cu new file mode 100644 index 0000000000..6e83d9076f --- /dev/null +++ b/cpp/test/mg/kmeans_test.cu @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#define NCCLCHECK(cmd) \ + do { \ + ncclResult_t res = cmd; \ + if (res != ncclSuccess) { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(res)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace ML { + +using namespace Datasets; +using namespace Metrics; + +template +struct KmeansInputs { + int n_row; + int n_col; + int n_clusters; + T tol; + bool weighted; +}; + +template +class KmeansTest : public ::testing::TestWithParam> { + protected: + KmeansTest() + : stream(handle.get_stream()), + d_labels(0, stream), + d_labels_ref(0, stream), + d_centroids(0, stream), + d_sample_weight(0, stream) + { + } + + void basicTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + ncclComm_t nccl_comm; + NCCLCHECK(ncclCommInitAll(&nccl_comm, 1, {0})); + raft::comms::build_comms_nccl_only(&handle, nccl_comm, 1, 0); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 5; + params.rng_state.seed = 1; + params.oversampling_factor = 1; + + auto stream = handle.get_stream(); + rmm::device_uvector X(n_samples * n_features, stream); + rmm::device_uvector labels(n_samples, stream); + + make_blobs(handle, + X.data(), + labels.data(), + n_samples, + n_features, + params.n_clusters, + true, + nullptr, + nullptr, + 1.0, + false, + -10.0f, + 10.0f, + 1234ULL); + + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + + T* d_sample_weight_ptr = nullptr; + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + d_sample_weight_ptr = d_sample_weight.data(); + thrust::fill( + thrust::cuda::par.on(stream), d_sample_weight_ptr, d_sample_weight_ptr + n_samples, 1); + } + + raft::copy(d_labels_ref.data(), labels.data(), n_samples, stream); + + handle.sync_stream(stream); + + T inertia = 0; + int n_iter = 0; + + ML::kmeans::opg::fit(handle, + params, + X.data(), + n_samples, + n_features, + d_sample_weight_ptr, + d_centroids.data(), + inertia, + n_iter); + + kmeans::predict(handle, + params, + d_centroids.data(), + X.data(), + n_samples, + n_features, + d_sample_weight_ptr, + true, + d_labels.data(), + inertia); + + score = adjusted_rand_index(handle, d_labels_ref.data(), d_labels.data(), n_samples); + handle.sync_stream(stream); + + if (score < 0.99) { + std::stringstream ss; + ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); + CUML_LOG_WARN(ss.str().c_str()); + ss.str(std::string()); + ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); + CUML_LOG_WARN(ss.str().c_str()); + CUML_LOG_WARN("Score = %lf", score); + } + + ncclCommDestroy(nccl_comm); + } + + void SetUp() override { basicTest(); } + + protected: + raft::handle_t handle; + cudaStream_t stream; + KmeansInputs testparams; + rmm::device_uvector d_labels; + rmm::device_uvector d_labels_ref; + rmm::device_uvector d_centroids; + rmm::device_uvector d_sample_weight; + double score; + ML::kmeans::KMeansParams params; +}; + +const std::vector> inputsf2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}}; + +const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}}; + +typedef KmeansTest KmeansTestF; +TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score >= 0.99); } + +typedef KmeansTest KmeansTestD; +TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score >= 0.99); } + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); + +} // end namespace ML \ No newline at end of file diff --git a/cpp/test/sg/kmeans_test.cu b/cpp/test/sg/kmeans_test.cu index ecd3cf1c50..dd7f10761a 100644 --- a/cpp/test/sg/kmeans_test.cu +++ b/cpp/test/sg/kmeans_test.cu @@ -15,9 +15,10 @@ */ #include -#include +#include #include -#include +#include +#include #include #include #include @@ -64,7 +65,7 @@ class KmeansTest : public ::testing::TestWithParam> { params.n_clusters = testparams.n_clusters; params.tol = testparams.tol; params.n_init = 5; - params.seed = 1; + params.rng_state.seed = 1; params.oversampling_factor = 0; auto stream = handle.get_stream(); diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx index c3c457e63c..a95d004f61 100644 --- a/python/cuml/cluster/kmeans.pyx +++ b/python/cuml/cluster/kmeans.pyx @@ -22,6 +22,7 @@ import rmm import warnings import typing +from cython.operator cimport dereference as deref from libcpp cimport bool from libc.stdint cimport uintptr_t, int64_t from libc.stdlib cimport calloc, malloc, free @@ -34,6 +35,7 @@ from cuml.common.mixins import ClusterMixin from cuml.common.mixins import CMajorInputTagMixin from cuml.common import input_to_cuml_array from cuml.cluster.kmeans_utils cimport * +from cuml.metrics.distance_type cimport DistanceType from pylibraft.common.handle cimport handle_t cdef extern from "cuml/cluster/kmeans.hpp" namespace "ML::kmeans": @@ -88,7 +90,6 @@ cdef extern from "cuml/cluster/kmeans.hpp" namespace "ML::kmeans": const float *X, int n_samples, int n_features, - int metric, float *X_new) except + cdef void transform(handle_t& handle, @@ -97,7 +98,6 @@ cdef extern from "cuml/cluster/kmeans.hpp" namespace "ML::kmeans": const double *X, int n_samples, int n_features, - int metric, double *X_new) except + @@ -243,6 +243,21 @@ class KMeans(Base, labels_ = CumlArrayDescriptor() cluster_centers_ = CumlArrayDescriptor() + def _get_kmeans_params(self): + cdef KMeansParams* params = \ + calloc(1, sizeof(KMeansParams)) + params.n_clusters = self.n_clusters + params.init = self._params_init + params.max_iter = self.max_iter + params.tol = self.tol + params.verbosity = self.verbose + params.rng_state.seed = self.random_state + params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501 + params.batch_samples = self.max_samples_per_batch + params.oversampling_factor = self.oversampling_factor + params.n_init = self.n_init + return params + def __init__(self, *, handle=None, n_clusters=8, max_iter=300, tol=1e-4, verbose=False, random_state=1, init='scalable-k-means++', n_init=1, oversampling_factor=2.0, @@ -264,9 +279,6 @@ class KMeans(Base, self.labels_ = None self.cluster_centers_ = None - cdef KMeansParams params - params.n_clusters = self.n_clusters - # cuPy does not allow comparing with string. See issue #2372 init_str = init if isinstance(init, str) else None @@ -278,29 +290,19 @@ class KMeans(Base, if (init_str in ['scalable-k-means++', 'k-means||']): self.init = init_str - params.init = KMeansPlusPlus + self._params_init = KMeansPlusPlus elif (init_str == 'random'): self.init = init - params.init = Random + self._params_init = Random else: self.init = 'preset' - params.init = Array + self._params_init = Array self.cluster_centers_, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(init, order='C', check_dtype=[np.float32, np.float64]) - params.max_iter = self.max_iter - params.tol = self.tol - params.verbosity = self.verbose - params.seed = self.random_state - params.metric = 0 # distance metric as squared L2: @todo - support other metrics # noqa: E501 - params.batch_samples=self.max_samples_per_batch - params.oversampling_factor=self.oversampling_factor - params.n_init = self.n_init - self._params = params - @generate_docstring() def fit(self, X, sample_weight=None) -> "KMeans": """ @@ -346,13 +348,14 @@ class KMeans(Base, cdef float inertiaf = 0 cdef double inertiad = 0 - cdef KMeansParams params = self._params + cdef KMeansParams* params = \ + self._get_kmeans_params() cdef int n_iter = 0 if self.dtype == np.float32: fit_predict( handle_[0], - params, + deref(params), input_ptr, n_rows, self.n_cols, @@ -367,7 +370,7 @@ class KMeans(Base, elif self.dtype == np.float64: fit_predict( handle_[0], - params, + deref(params), input_ptr, n_rows, self.n_cols, @@ -387,6 +390,7 @@ class KMeans(Base, self.handle.sync() del(X_m) del(sample_weight_m) + free(params) return self @generate_docstring(return_values={'name': 'preds', @@ -461,11 +465,13 @@ class KMeans(Base, # Sum of squared distances of samples to their closest cluster center. cdef float inertiaf = 0 cdef double inertiad = 0 + cdef KMeansParams* params = \ + self._get_kmeans_params() if self.dtype == np.float32: predict( handle_[0], - self._params, + deref(params), cluster_centers_ptr, input_ptr, n_rows, @@ -479,7 +485,7 @@ class KMeans(Base, elif self.dtype == np.float64: predict( handle_[0], - self._params, + deref(params), cluster_centers_ptr, input_ptr, n_rows, @@ -498,6 +504,7 @@ class KMeans(Base, self.handle.sync() del(X_m) del(sample_weight_m) + free(params) return self.labels_, inertia @generate_docstring(return_values={'name': 'preds', @@ -547,27 +554,27 @@ class KMeans(Base, cdef uintptr_t preds_ptr = preds.ptr # distance metric as L2-norm/euclidean distance: @todo - support other metrics # noqa: E501 - distance_metric = 1 + cdef KMeansParams* params = \ + self._get_kmeans_params() + params.metric = DistanceType.L2SqrtExpanded if self.dtype == np.float32: transform( handle_[0], - self._params, + deref(params), cluster_centers_ptr, input_ptr, n_rows, self.n_cols, - distance_metric, preds_ptr) elif self.dtype == np.float64: transform( handle_[0], - self._params, + deref(params), cluster_centers_ptr, input_ptr, n_rows, self.n_cols, - distance_metric, preds_ptr) else: raise TypeError('KMeans supports only float32 and float64 input,' @@ -577,6 +584,7 @@ class KMeans(Base, self.handle.sync() del(X_m) + free(params) return preds @generate_docstring(return_values={'name': 'score', diff --git a/python/cuml/cluster/kmeans_mg.pyx b/python/cuml/cluster/kmeans_mg.pyx index 650b0aa337..cfb7c640b0 100644 --- a/python/cuml/cluster/kmeans_mg.pyx +++ b/python/cuml/cluster/kmeans_mg.pyx @@ -22,6 +22,7 @@ import warnings import rmm +from cython.operator cimport dereference as deref from libcpp cimport bool from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free @@ -122,14 +123,15 @@ class KMeansMG(KMeans): cdef float inertiaf = 0 cdef double inertiad = 0 - cdef KMeansParams params = self._params + cdef KMeansParams* params = \ + self._get_kmeans_params() cdef int n_iter = 0 if self.dtype == np.float32: with nogil: fit( handle_[0], - params, + deref(params), input_ptr, n_rows, n_cols, @@ -144,7 +146,7 @@ class KMeansMG(KMeans): with nogil: fit( handle_[0], - params, + deref(params), input_ptr, n_rows, n_cols, @@ -163,5 +165,6 @@ class KMeansMG(KMeans): self.handle.sync() del(X_m) + free(params) return self diff --git a/python/cuml/cluster/kmeans_utils.pxd b/python/cuml/cluster/kmeans_utils.pxd index 5ae36f1ca6..efbe27dcd7 100644 --- a/python/cuml/cluster/kmeans_utils.pxd +++ b/python/cuml/cluster/kmeans_utils.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2020, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ import ctypes from libcpp cimport bool +from cuml.metrics.distance_type cimport DistanceType +from cuml.common.rng_state cimport RngState cdef extern from "cuml/cluster/kmeans.hpp" namespace \ "ML::kmeans::KMeansParams": @@ -29,8 +31,8 @@ cdef extern from "cuml/cluster/kmeans.hpp" namespace \ int max_iter, double tol, int verbosity, - int seed, - int metric, + RngState rng_state, + DistanceType metric, int n_init, double oversampling_factor, int batch_samples, diff --git a/python/cuml/common/rng_state.pxd b/python/cuml/common/rng_state.pxd new file mode 100644 index 0000000000..43f7905542 --- /dev/null +++ b/python/cuml/common/rng_state.pxd @@ -0,0 +1,30 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import ctypes +from libcpp cimport bool +from libc.stdint cimport uint64_t + +cdef extern from "raft/random/rng_state.hpp" namespace \ + "raft::random": + enum GeneratorType: + GenPhilox, GenPC + + cdef struct RngState: + RngState(uint64_t seed) except + + uint64_t seed, + uint64_t base_subsequence, + GeneratorType type diff --git a/python/cuml/tests/dask/test_dask_kmeans.py b/python/cuml/tests/dask/test_dask_kmeans.py index 0fb2cbb927..521da77321 100644 --- a/python/cuml/tests/dask/test_dask_kmeans.py +++ b/python/cuml/tests/dask/test_dask_kmeans.py @@ -93,6 +93,7 @@ def test_end_to_end(nrows, ncols, nclusters, n_parts, assert 1.0 == score +@pytest.mark.mg @pytest.mark.parametrize('nrows', [500]) @pytest.mark.parametrize('ncols', [5]) @pytest.mark.parametrize('nclusters', [3, 10])