diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index c992ef016e..c3ca60973a 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -19,974 +19,1116 @@ #include #include #include +#include +#include #include -#include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace raft { namespace cluster { namespace detail { + // ========================================================= -// Useful grid settings +// Init functions // ========================================================= -constexpr unsigned int BLOCK_SIZE = 1024; -constexpr unsigned int WARP_SIZE = 32; -constexpr unsigned int BSIZE_DIV_WSIZE = (BLOCK_SIZE / WARP_SIZE); +// Selects 'n_clusters' samples randomly from X +template +void initRandom(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroids) +{ + cudaStream_t stream = handle.get_stream(); + auto n_clusters = params.n_clusters; + detail::shuffleAndGather(handle, X, centroids, n_clusters, params.rng_state.seed); +} -// ========================================================= -// CUDA kernels -// ========================================================= +/* + * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. -/** - * @brief Compute distances between observation vectors and centroids - * Block dimensions should be (warpSize, 1, - * blockSize/warpSize). Ideally, the grid is large enough so there - * are d threads in the x-direction, k threads in the y-direction, - * and n threads in the z-direction. 
- * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, d*n entries) Observation matrix. Matrix is - * stored column-major and each column is an observation - * vector. Matrix dimensions are d x n. - * @param centroids (Input, d*k entries) Centroid matrix. Matrix is - * stored column-major and each column is a centroid. Matrix - * dimensions are d x k. - * @param dists (Output, n*k entries) Distance matrix. Matrix is - * stored column-major and the (i,j)-entry is the square of the - * Euclidean distance between the ith observation vector and jth - * centroid. Matrix dimensions are n x k. Entries must be - * initialized to zero. + * @note This is the algorithm described in + * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. + * ACM-SIAM symposium on Discrete algorithms. 
+ * + * kmeans++ pseudocode + * 1: C = sample a point uniformly at random from X + * 2: while |C| < k + * 3: Sample x in X with probability p_x = d^2(x, C) / phi_X (C) + * 4: C = C U {x} + * 5: end while */ -template -static __global__ void computeDistances(index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const value_type_t* __restrict__ centroids, - value_type_t* __restrict__ dists) +template +void kmeansPlusPlus(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroidsRawData, + rmm::device_uvector& workspace) { - // Loop index - index_type_t i; - - // Block indices - index_type_t bidx; - // Global indices - index_type_t gidx, gidy, gidz; - - // Private memory - value_type_t centroid_private, dist_private; - - // Global x-index indicates index of vector entry - bidx = blockIdx.x; - while (bidx * blockDim.x < d) { - gidx = threadIdx.x + bidx * blockDim.x; - - // Global y-index indicates centroid - gidy = threadIdx.y + blockIdx.y * blockDim.y; - while (gidy < k) { - // Load centroid coordinate from global memory - centroid_private = (gidx < d) ? centroids[IDX(gidx, gidy, d)] : 0; - - // Global z-index indicates observation vector - gidz = threadIdx.z + blockIdx.z * blockDim.z; - while (gidz < n) { - // Load observation vector coordinate from global memory - dist_private = (gidx < d) ? 
obs[IDX(gidx, gidz, d)] : 0; - - // Compute contribution of current entry to distance - dist_private = centroid_private - dist_private; - dist_private = dist_private * dist_private; - - // Perform reduction on warp - for (i = WARP_SIZE / 2; i > 0; i /= 2) - dist_private += __shfl_down_sync(warp_full_mask(), dist_private, i, 2 * i); - - // Write result to global memory - if (threadIdx.x == 0) atomicAdd(dists + IDX(gidz, gidy, n), dist_private); - - // Move to another observation vector - gidz += blockDim.z * gridDim.z; - } + cudaStream_t stream = handle.get_stream(); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + // number of seeding trials for each center (except the first) + auto n_trials = 2 + static_cast(std::ceil(log(n_clusters))); + + RAFT_LOG_DEBUG( + "Run sequential k-means++ to select %d centroids from %d input samples " + "(%d seeding trials per iterations)", + n_clusters, + n_samples, + n_trials); + + auto dataBatchSize = getDataBatchSize(params, n_samples); + + // temporary buffers + std::vector h_wt(n_samples); + auto centroidCandidates = raft::make_device_matrix(n_trials, n_features, stream); + auto costPerCandidate = raft::make_device_vector(n_trials, stream); + auto minClusterDistance = raft::make_device_vector(n_samples, stream); + auto distBuffer = raft::make_device_matrix(n_trials, n_samples, stream); + + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + rmm::device_scalar clusterCost(stream); + rmm::device_scalar> minClusterIndexAndDistance(stream); + + // L2 norm of X: ||x||^2 + auto L2NormX = raft::make_device_vector(n_samples, stream); + + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormX.data(), X.data(), X.extent(1), X.extent(0), raft::linalg::L2Norm, true, stream); + } - // Move to another centroid - gidy += blockDim.y * gridDim.y; + 
std::mt19937 gen(params.rng_state.seed); + std::uniform_int_distribution<> dis(0, n_samples - 1); + + // <<< Step-1 >>>: C <-- sample a point uniformly at random from X + auto initialCentroid = + raft::make_device_matrix_view(X.data() + dis(gen) * n_features, 1, n_features); + int n_clusters_picked = 1; + + // store the chosen centroid in the buffer + raft::copy(centroidsRawData.data(), initialCentroid.data(), initialCentroid.size(), stream); + + // C = initial set of centroids + auto centroids = raft::make_device_matrix_view( + centroidsRawData.data(), initialCentroid.extent(0), initialCentroid.extent(1)); + // <<< End of Step-1 >>> + + // Calculate cluster distance, d^2(x, C), for all the points x in X to the nearest centroid + detail::minClusterDistanceCompute(handle, + params, + X, + centroids, + minClusterDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + workspace); + + RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); + + // <<<< Step-2 >>> : while |C| < k + while (n_clusters_picked < n_clusters) { + // <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C) + // Choose 'n_trials' centroid candidates from X with probability proportional to the squared + // distance to the nearest existing cluster + raft::copy(h_wt.data(), minClusterDistance.data(), minClusterDistance.size(), stream); + handle.sync_stream(stream); + + // Note - n_trials is relative small here, we don't need raft::gather call + std::discrete_distribution<> d(h_wt.begin(), h_wt.end()); + for (int cIdx = 0; cIdx < n_trials; ++cIdx) { + auto rand_idx = d(gen); + auto randCentroid = + raft::make_device_matrix_view(X.data() + n_features * rand_idx, 1, n_features); + raft::copy(centroidCandidates.data() + cIdx * n_features, + randCentroid.data(), + randCentroid.size(), + stream); } - // Move to another vector entry - bidx += gridDim.x; - } + // Calculate pairwise distance between X and the centroid candidates + // Output - pwd 
[n_trials x n_samples] + auto pwd = distBuffer.view(); + detail::pairwise_distance_kmeans( + handle, centroidCandidates.view(), X, pwd, workspace, metric); + + // Update nearest cluster distance for each centroid candidate + // Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values. + // Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated + // minClusterDistance that includes candidate-i + auto minDistBuf = distBuffer.view(); + raft::linalg::matrixVectorOp( + minDistBuf.data(), + pwd.data(), + minClusterDistance.data(), + pwd.extent(1), + pwd.extent(0), + true, + true, + [=] __device__(DataT mat, DataT vec) { return vec <= mat ? vec : mat; }, + stream); + + // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using + // centroid candidate-i + raft::linalg::reduce(costPerCandidate.data(), + minDistBuf.data(), + minDistBuf.extent(1), + minDistBuf.extent(0), + static_cast(0), + true, + true, + stream); + + // Greedy Choice - Choose the candidate that has minimum cluster cost + // ArgMin operation below identifies the index of minimum cost in costPerCandidate + { + // Determine temporary device storage requirements + size_t temp_storage_bytes = 0; + cub::DeviceReduce::ArgMin(nullptr, + temp_storage_bytes, + costPerCandidate.data(), + minClusterIndexAndDistance.data(), + costPerCandidate.extent(0)); + + // Allocate temporary storage + workspace.resize(temp_storage_bytes, stream); + + // Run argmin-reduction + cub::DeviceReduce::ArgMin(workspace.data(), + temp_storage_bytes, + costPerCandidate.data(), + minClusterIndexAndDistance.data(), + costPerCandidate.extent(0)); + + int bestCandidateIdx = -1; + raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); + /// <<< End of Step-3 >>> + + /// <<< Step-4 >>>: C = C U {x} + // Update minimum cluster distance corresponding to the chosen centroid candidate + 
raft::copy(minClusterDistance.data(), + minDistBuf.data() + bestCandidateIdx * n_samples, + n_samples, + stream); + + raft::copy(centroidsRawData.data() + n_clusters_picked * n_features, + centroidCandidates.data() + bestCandidateIdx * n_features, + n_features, + stream); + + ++n_clusters_picked; + /// <<< End of Step-4 >>> + } + + RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); + } /// <<<< Step-5 >>> } -/** - * @brief Find closest centroid to observation vectors. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param k Number of clusters. - * @param centroids (Input, d*k entries) Centroid matrix. Matrix is - * stored column-major and each column is a centroid. Matrix - * dimensions are d x k. - * @param dists (Input/output, n*k entries) Distance matrix. Matrix - * is stored column-major and the (i,j)-entry is the square of - * the Euclidean distance between the ith observation vector and - * jth centroid. Matrix dimensions are n x k. On exit, the first - * n entries give the square of the Euclidean distance between - * observation vectors and closest centroids. - * @param codes (Output, n entries) Cluster assignments. - * @param clusterSizes (Output, k entries) Number of points in each - * cluster. Entries must be initialized to zero. 
- */ -template -static __global__ void minDistances(index_type_t n, - index_type_t k, - value_type_t* __restrict__ dists, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) +// TODO: Resizing is needed to use mdarray instead of rmm::device_uvector +template +void kmeans_fit_main(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_vector_view& weight, + const raft::device_matrix_view& centroidsRawData, + const raft::host_scalar_view& inertia, + const raft::host_scalar_view& n_iter, + rmm::device_uvector& workspace) { - // Loop index - index_type_t i, j; - - // Current matrix entry - value_type_t dist_curr; - - // Smallest entry in row - value_type_t dist_min; - index_type_t code_min; - - // Each row in observation matrix is processed by a thread - i = threadIdx.x + blockIdx.x * blockDim.x; - while (i < n) { - // Find minimum entry in row - code_min = 0; - dist_min = dists[IDX(i, 0, n)]; - for (j = 1; j < k; ++j) { - dist_curr = dists[IDX(i, j, n)]; - code_min = (dist_curr < dist_min) ? j : code_min; - dist_min = (dist_curr < dist_min) ? 
dist_curr : dist_min; - } + logger::get(RAFT_NAME).set_level(params.verbosity); + cudaStream_t stream = handle.get_stream(); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + // stores (key, value) pair corresponding to each sample where + // - key is the index of nearest cluster + // - value is the distance to the nearest cluster + auto minClusterAndDistance = + raft::make_device_vector>(n_samples, stream); + + // temporary buffer to store L2 norm of centroids or distance matrix, + // destructor releases the resource + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // temporary buffer to store intermediate centroids, destructor releases the + // resource + auto newCentroids = raft::make_device_matrix(n_clusters, n_features, stream); + + // temporary buffer to store weights per cluster, destructor releases the + // resource + auto wtInCluster = raft::make_device_vector(n_clusters, stream); + + rmm::device_scalar> clusterCostD(stream); + + // L2 norm of X: ||x||^2 + auto L2NormX = raft::make_device_vector(n_samples, stream); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormX.data(), X.data(), X.extent(1), X.extent(0), raft::linalg::L2Norm, true, stream); + } - // Transfer result to global memory - dists[i] = dist_min; - codes[i] = code_min; + RAFT_LOG_DEBUG( + "Calling KMeans.fit with %d samples of input data and the initialized " + "cluster centers", + n_samples); + + DataT priorClusteringCost = 0; + for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { + RAFT_LOG_DEBUG( + "KMeans.fit: Iteration-%d: fitting the model using the initialized " + "cluster centers", + n_iter[0]); + + auto centroids = raft::make_device_matrix_view(centroidsRawData.data(), n_clusters, n_features); + + // computes minClusterAndDistance[0:n_samples) where + // 
minClusterAndDistance[i] is a pair where + // 'key' is index to a sample in 'centroids' (index of the nearest + // centroid) and 'value' is the distance between the sample 'X[i]' and the + // 'centroid[key]' + detail::minClusterAndDistanceCompute(handle, + params, + X, + centroids, + minClusterAndDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + workspace); + + // Using TransformInputIteratorT to dereference an array of + // cub::KeyValuePair and converting them to just return the Key to be used + // in reduce_rows_by_key prims + detail::KeyValueIndexOp conversion_op; + cub::TransformInputIterator, + cub::KeyValuePair*> + itr(minClusterAndDistance.data(), conversion_op); + + workspace.resize(n_samples, stream); + + // Calculates weighted sum of all the samples assigned to cluster-i and store the + // result in newCentroids[i] + raft::linalg::reduce_rows_by_key((DataT*)X.data(), + X.extent(1), + itr, + weight.data(), + workspace.data(), + X.extent(0), + X.extent(1), + n_clusters, + newCentroids.data(), + stream); + + // Reduce weights by key to compute weight in each cluster + raft::linalg::reduce_cols_by_key(weight.data(), + itr, + wtInCluster.data(), + (IndexT)1, + (IndexT)weight.extent(0), + (IndexT)n_clusters, + stream); + + // Computes newCentroids[i] = newCentroids[i]/wtInCluster[i] where + // newCentroids[n_clusters x n_features] - 2D array, newCentroids[i] has sum of all the + // samples assigned to cluster-i wtInCluster[n_clusters] - 1D array, wtInCluster[i] contains # + // of samples in cluster-i. 
+ // Note - when wtInCluster[i] is 0, newCentroid[i] is reset to 0 + raft::linalg::matrixVectorOp( + newCentroids.data(), + newCentroids.data(), + wtInCluster.data(), + newCentroids.extent(1), + newCentroids.extent(0), + true, + false, + [=] __device__(DataT mat, DataT vec) { + if (vec == 0) + return DataT(0); + else + return mat / vec; + }, + stream); + + // copy centroids[i] to newCentroids[i] when wtInCluster[i] is 0 + cub::ArgIndexInputIterator itr_wt(wtInCluster.data()); + raft::matrix::gather_if( + centroids.data(), + centroids.extent(1), + centroids.extent(0), + itr_wt, + itr_wt, + wtInCluster.size(), + newCentroids.data(), + [=] __device__(cub::KeyValuePair map) { // predicate + // copy when the # of samples in the cluster is 0 + if (map.value == 0) + return true; + else + return false; + }, + [=] __device__(cub::KeyValuePair map) { // map + return map.key; + }, + stream); + + // compute the squared norm between the newCentroids and the original + // centroids, destructor releases the resource + auto sqrdNorm = raft::make_device_scalar(DataT(0), stream); + raft::linalg::mapThenSumReduce( + sqrdNorm.data(), + newCentroids.size(), + [=] __device__(const DataT a, const DataT b) { + DataT diff = a - b; + return diff * diff; + }, + stream, + centroids.data(), + newCentroids.data()); + + DataT sqrdNormError = 0; + raft::copy(&sqrdNormError, sqrdNorm.data(), sqrdNorm.size(), stream); + + raft::copy(centroidsRawData.data(), newCentroids.data(), newCentroids.size(), stream); + + bool done = false; + if (params.inertia_check) { + // calculate cluster cost phi_x(C) + detail::computeClusterCost(handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + [] __device__(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + cub::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + }); + + DataT curClusteringCost = 0; + raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, 
stream); + + handle.sync_stream(stream); + ASSERT(curClusteringCost != (DataT)0.0, + "Too few points and centroids being found is getting 0 cost from " + "centers"); + + if (n_iter[0] > 1) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + } + priorClusteringCost = curClusteringCost; + } - // Increment cluster sizes - atomicAdd(clusterSizes + code_min, 1); + handle.sync_stream(stream); + if (sqrdNormError < params.tol) done = true; - // Move to another row - i += blockDim.x * gridDim.x; + if (done) { + RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); + break; + } } + + auto centroids = raft::make_device_matrix_view(centroidsRawData.data(), n_clusters, n_features); + + detail::minClusterAndDistanceCompute(handle, + params, + X, + centroids, + minClusterAndDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + workspace); + + // TODO: add different templates for InType of binaryOp to avoid thrust transform + thrust::transform(handle.get_thrust_policy(), + minClusterAndDistance.data(), + minClusterAndDistance.data() + minClusterAndDistance.size(), + weight.data(), + minClusterAndDistance.data(), + [=] __device__(const cub::KeyValuePair kvp, DataT wt) { + cub::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }); + + // calculate cluster cost phi_x(C) + detail::computeClusterCost(handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + [] __device__(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + cub::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + }); + + raft::copy(inertia.data(), &(clusterCostD.data()->value), 1, stream); + + RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", + n_iter[0] > params.max_iter ? 
n_iter[0] - 1 : n_iter[0], + inertia[0]); } -/** - * @brief Check if newly computed distances are smaller than old distances. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param dists_old (Input/output, n entries) Distances between - * observation vectors and closest centroids. On exit, entries - * are replaced by entries in 'dists_new' if the corresponding - * observation vectors are closest to the new centroid. - * @param dists_new (Input, n entries) Distance between observation - * vectors and new centroid. - * @param codes_old (Input/output, n entries) Cluster - * assignments. On exit, entries are replaced with 'code_new' if - * the corresponding observation vectors are closest to the new - * centroid. - * @param code_new Index associated with new centroid. +/* + * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. 
+ + * @note This is the algorithm described in + * "Scalable K-Means++", 2012, Bahman Bahmani, Benjamin Moseley, + * Andrea Vattani, Ravi Kumar, Sergei Vassilvitskii, + * https://arxiv.org/abs/1203.6402 + + * Scalable kmeans++ pseudocode + * 1: C = sample a point uniformly at random from X + * 2: psi = phi_X (C) + * 3: for O( log(psi) ) times do + * 4: C' = sample each point x in X independently with probability + * p_x = l * (d^2(x, C) / phi_X (C) ) + * 5: C = C U C' + * 6: end for + * 7: For x in C, set w_x to be the number of points in X closer to x than any + * other point in C + * 8: Recluster the weighted points in C into k clusters + + * TODO: Resizing is needed to use mdarray instead of rmm::device_uvector + */ -template -static __global__ void minDistances2(index_type_t n, - value_type_t* __restrict__ dists_old, - const value_type_t* __restrict__ dists_new, - index_type_t* __restrict__ codes_old, - index_type_t code_new) +template +void initScalableKMeansPlusPlus(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroidsRawData, + rmm::device_uvector& workspace) { - // Loop index - index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; - - // Distances - value_type_t dist_old_private; - value_type_t dist_new_private; - - // Each row is processed by a thread - while (i < n) { - // Get old and new distances - dist_old_private = dists_old[i]; - dist_new_private = dists_new[i]; - - // Update if new distance is smaller than old distance - if (dist_new_private < dist_old_private) { - dists_old[i] = dist_new_private; - codes_old[i] = code_new; - } + cudaStream_t stream = handle.get_stream(); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); + + // <<<< Step-1 >>> : C <- sample a point uniformly at random from X + 
std::mt19937 gen(params.rng_state.seed); + std::uniform_int_distribution<> dis(0, n_samples - 1); + + auto cIdx = dis(gen); + auto initialCentroid = raft::make_device_matrix_view(X.data() + cIdx * n_features, 1, n_features); + + // flag the sample that is chosen as initial centroid + std::vector h_isSampleCentroid(n_samples); + std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); + h_isSampleCentroid[cIdx] = 1; + + // device buffer to flag the sample that is chosen as initial centroid + auto isSampleCentroid = raft::make_device_vector(n_samples, stream); + + raft::copy(isSampleCentroid.data(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); + + rmm::device_uvector centroidsBuf(initialCentroid.size(), stream); - // Move to another row - i += blockDim.x * gridDim.x; + // reset buffer to store the chosen centroid + raft::copy(centroidsBuf.data(), initialCentroid.data(), initialCentroid.size(), stream); + + auto potentialCentroids = raft::make_device_matrix_view( + centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); + // <<< End of Step-1 >>> + + // temporary buffer to store L2 norm of centroids or distance matrix, + // destructor releases the resource + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // L2 norm of X: ||x||^2 + auto L2NormX = raft::make_device_vector(n_samples, stream); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormX.data(), X.data(), X.extent(1), X.extent(0), raft::linalg::L2Norm, true, stream); } -} -/** - * @brief Compute size of k-means clusters. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @param n Number of observation vectors. - * @param k Number of clusters. - * @param codes (Input, n entries) Cluster assignments. 
- * @param clusterSizes (Output, k entries) Number of points in each - * cluster. Entries must be initialized to zero. - */ -template -static __global__ void computeClusterSizes(index_type_t n, - const index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) -{ - index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; - while (i < n) { - atomicAdd(clusterSizes + codes[i], 1); - i += blockDim.x * gridDim.x; + auto minClusterDistanceVec = raft::make_device_vector(n_samples, stream); + auto uniformRands = raft::make_device_vector(n_samples, stream); + rmm::device_scalar clusterCost(stream); + + // <<< Step-2 >>>: psi <- phi_X (C) + detail::minClusterDistanceCompute(handle, + params, + X, + potentialCentroids, + minClusterDistanceVec.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + workspace); + + // compute partial cluster cost from the samples in rank + detail::computeClusterCost(handle, + minClusterDistanceVec.view(), + workspace, + raft::make_device_scalar_view(clusterCost.data()), + [] __device__(const DataT& a, const DataT& b) { return a + b; }); + + auto psi = clusterCost.value(stream); + + // <<< End of Step-2 >>> + + // Scalable kmeans++ paper claims 8 rounds is sufficient + handle.sync_stream(stream); + int niter = std::min(8, (int)ceil(log(psi))); + RAFT_LOG_DEBUG("KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter); + + // <<<< Step-3 >>> : for O( log(psi) ) times do + for (int iter = 0; iter < niter; ++iter) { + RAFT_LOG_DEBUG("KMeans|| - Iteration %d: # potential centroids sampled - %d", + iter, + potentialCentroids.extent(0)); + + detail::minClusterDistanceCompute(handle, + params, + X, + potentialCentroids, + minClusterDistanceVec.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + workspace); + + detail::computeClusterCost(handle, + minClusterDistanceVec.view(), + workspace, + raft::make_device_scalar_view(clusterCost.data()), + [] __device__(const DataT& a, const DataT& b) { return a + b; }); + + psi = 
clusterCost.value(stream); + + // <<<< Step-4 >>> : Sample each point x in X independently and identify new + // potentialCentroids + raft::random::uniform( + handle, rng, uniformRands.data(), uniformRands.extent(0), (DataT)0, (DataT)1); + + detail::SamplingOp select_op( + psi, params.oversampling_factor, n_clusters, uniformRands.data(), isSampleCentroid.data()); + + rmm::device_uvector CpRaw(0, stream); + detail::sampleCentroids(handle, + X, + minClusterDistanceVec.view(), + isSampleCentroid.view(), + select_op, + CpRaw, + workspace); + auto Cp = raft::make_device_matrix_view(CpRaw.data(), CpRaw.size() / n_features, n_features); + /// <<<< End of Step-4 >>>> + + /// <<<< Step-5 >>> : C = C U C' + // append the data in Cp to the buffer holding the potentialCentroids + centroidsBuf.resize(centroidsBuf.size() + Cp.size(), stream); + raft::copy(centroidsBuf.data() + centroidsBuf.size() - Cp.size(), Cp.data(), Cp.size(), stream); + + IndexT tot_centroids = potentialCentroids.extent(0) + Cp.extent(0); + potentialCentroids = + raft::make_device_matrix_view(centroidsBuf.data(), tot_centroids, n_features); + /// <<<< End of Step-5 >>> + } /// <<<< Step-6 >>> + + RAFT_LOG_DEBUG("KMeans||: total # potential centroids sampled - %d", + potentialCentroids.extent(0)); + + if ((int)potentialCentroids.extent(0) > n_clusters) { + // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X + // temporary buffer to store the sample count per cluster, destructor + // releases the resource + auto weight = raft::make_device_vector(potentialCentroids.extent(0), stream); + + detail::countSamplesInCluster( + handle, params, X, L2NormX.view(), potentialCentroids, workspace, weight.view()); + + // <<< end of Step-7 >>> + + // Step-8: Recluster the weighted points in C into k clusters + detail::kmeansPlusPlus( + handle, params, potentialCentroids, centroidsRawData, workspace); + + auto inertia = make_host_scalar(0); + auto n_iter = make_host_scalar(0); + KMeansParams 
default_params; + default_params.n_clusters = params.n_clusters; + + detail::kmeans_fit_main(handle, + default_params, + potentialCentroids, + weight.view(), + centroidsRawData, + inertia.view(), + n_iter.view(), + workspace); + + } else if ((int)potentialCentroids.extent(0) < n_clusters) { + // supplement with random + auto n_random_clusters = n_clusters - potentialCentroids.extent(0); + + RAFT_LOG_DEBUG( + "[Warning!] KMeans||: found fewer than %d centroids during " + "initialization (found %d centroids, remaining %d centroids will be " + "chosen randomly from input samples)", + n_clusters, + potentialCentroids.extent(0), + n_random_clusters); + + // generate `n_random_clusters` centroids + KMeansParams rand_params; + rand_params.init = KMeansParams::InitMethod::Random; + rand_params.n_clusters = n_random_clusters; + initRandom(handle, rand_params, X, centroidsRawData); + + // copy centroids generated during kmeans|| iteration to the buffer + raft::copy(centroidsRawData.data() + n_random_clusters * n_features, + potentialCentroids.data(), + potentialCentroids.size(), + stream); + } else { + // found the required n_clusters + raft::copy( + centroidsRawData.data(), potentialCentroids.data(), potentialCentroids.size(), stream); } } /** - * @brief Divide rows of centroid matrix by cluster sizes. - * Divides the ith column of the sum matrix by the size of the ith - * cluster. If the sum matrix has been initialized so that the ith - * row is the sum of all observation vectors in the ith cluster, - * this kernel produces cluster centroids. The grid and block - * dimensions should be 2-dimensional. Ideally the grid is large - * enough so there are d threads in the x-direction and k threads - * in the y-direction. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param d Dimension of observation vectors. - * @param k Number of clusters. 
- * @param clusterSizes (Input, k entries) Number of points in each - * cluster. - * @param centroids (Input/output, d*k entries) Sum matrix. Matrix - * is stored column-major and matrix dimensions are d x k. The - * ith column is the sum of all observation vectors in the ith - * cluster. On exit, the matrix is the centroid matrix (each - * column is the mean position of a cluster). + * @brief Find clusters with k-means algorithm. + * Initial centroids are chosen with k-means++ algorithm. Empty + * clusters are reinitialized by choosing new centroids with + * k-means++ algorithm. + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. It must be noted + * that the data must be in row-major format and stored in device accessible + * location. + * @param[in] n_samples Number of samples in the input X. + * @param[in] n_features Number of features or the dimensions of each + * sample. + * @param[in] sample_weight Optional weights for each observation in X. + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers + * [out] Otherwise, generated centroids from the + * kmeans algorithm is stored at the address pointed by 'centroids'. + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. 
*/ -template -static __global__ void divideCentroids(index_type_t d, - index_type_t k, - const index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids) +template +void kmeans_fit(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { - // Global indices - index_type_t gidx, gidy; - - // Current cluster size - index_type_t clusterSize_private; - - // Observation vector is determined by global y-index - gidy = threadIdx.y + blockIdx.y * blockDim.y; - while (gidy < k) { - // Get cluster size from global memory - clusterSize_private = clusterSizes[gidy]; - - // Add vector entries to centroid matrix - // vector entris are determined by global x-index - gidx = threadIdx.x + blockIdx.x * blockDim.x; - while (gidx < d) { - centroids[IDX(gidx, gidy, d)] /= clusterSize_private; - gidx += blockDim.x * gridDim.x; - } + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + cudaStream_t stream = handle.get_stream(); + // Check that parameters are valid + if (sample_weight.has_value()) + RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, + "invalid parameter (sample_weight!=n_samples)"); + RAFT_EXPECTS(n_clusters > 0, "invalid parameter (n_clusters<=0)"); + RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); + RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); + RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, + "invalid parameter (centroids.extent(0) != n_clusters)"); + RAFT_EXPECTS(centroids.extent(1) == n_features, + "invalid parameter (centroids.extent(1) != n_features)"); + + logger::get(RAFT_NAME).set_level(params.verbosity); - // Move to another centroid - gidy += blockDim.y * gridDim.y; + // Allocate memory + rmm::device_uvector workspace(0, stream); + auto weight = 
raft::make_device_vector(handle, n_samples); + if (sample_weight.has_value()) + raft::copy(weight.data(), sample_weight.value().data(), n_samples, stream); + else + thrust::fill(handle.get_thrust_policy(), weight.data(), weight.data() + weight.size(), 1); + + // check if weights sum up to n_samples + checkWeight(handle, weight.view(), workspace); + + auto centroidsRawData = raft::make_device_matrix(n_clusters, n_features, stream); + + auto n_init = params.n_init; + if (params.init == KMeansParams::InitMethod::Array && n_init != 1) { + RAFT_LOG_DEBUG( + "Explicit initial center position passed: performing only one init in " + "k-means instead of n_init=%d", + n_init); + n_init = 1; } -} -// ========================================================= -// Helper functions -// ========================================================= + std::mt19937 gen(params.rng_state.seed); + inertia[0] = std::numeric_limits::max(); + + for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { + KMeansParams iter_params = params; + iter_params.rng_state.seed = gen(); + + DataT iter_inertia = std::numeric_limits::max(); + IndexT n_current_iter = 0; + if (iter_params.init == KMeansParams::InitMethod::Random) { + // initializing with random samples from input dataset + RAFT_LOG_DEBUG( + "KMeans.fit (Iteration-%d/%d): initialize cluster centers by " + "randomly choosing from the " + "input data.", + seed_iter + 1, + n_init); + initRandom(handle, iter_params, X, centroidsRawData.view()); + } else if (iter_params.init == KMeansParams::InitMethod::KMeansPlusPlus) { + // default method to initialize is kmeans++ + RAFT_LOG_DEBUG( + "KMeans.fit (Iteration-%d/%d): initialize cluster centers using " + "k-means++ algorithm.", + seed_iter + 1, + n_init); + if (iter_params.oversampling_factor == 0) + detail::kmeansPlusPlus( + handle, iter_params, X, centroidsRawData.view(), workspace); + else + detail::initScalableKMeansPlusPlus( + handle, iter_params, X, centroidsRawData.view(), workspace); + } 
else if (iter_params.init == KMeansParams::InitMethod::Array) { + RAFT_LOG_DEBUG( + "KMeans.fit (Iteration-%d/%d): initialize cluster centers from " + "the ndarray array input " + "passed to init arguement.", + seed_iter + 1, + n_init); + raft::copy(centroidsRawData.data(), centroids.data(), n_clusters * n_features, stream); + } else { + THROW("unknown initialization method to select initial centers"); + } -/** - * @brief Randomly choose new centroids. - * Centroid is randomly chosen with k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param rand Random number drawn uniformly from [0,1). - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are n x d. - * @param dists (Input, device memory, 2*n entries) Workspace. The - * first n entries should be the distance between observation - * vectors and the closest centroid. - * @param centroid (Output, device memory, d entries) Centroid - * coordinates. - * @return Zero if successful. Otherwise non-zero. 
- */ -template -static int chooseNewCentroid(handle_t const& handle, - index_type_t n, - index_type_t d, - value_type_t rand, - const value_type_t* __restrict__ obs, - value_type_t* __restrict__ dists, - value_type_t* __restrict__ centroid) -{ - // Cumulative sum of distances - value_type_t* distsCumSum = dists + n; - // Residual sum of squares - value_type_t distsSum{0}; - // Observation vector that is chosen as new centroid - index_type_t obsIndex; - - auto stream = handle.get_stream(); - auto thrust_exec_policy = handle.get_thrust_policy(); - - // Compute cumulative sum of distances - thrust::inclusive_scan(thrust_exec_policy, - thrust::device_pointer_cast(dists), - thrust::device_pointer_cast(dists + n), - thrust::device_pointer_cast(distsCumSum)); - CHECK_CUDA(stream); - CUDA_TRY(cudaMemcpyAsync( - &distsSum, distsCumSum + n - 1, sizeof(value_type_t), cudaMemcpyDeviceToHost, stream)); - - // Randomly choose observation vector - // Probabilities are proportional to square of distance to closest - // centroid (see k-means++ algorithm) - // - // seg-faults due to Thrust bug - // on binary-search-like algorithms - // when run with stream dependent - // execution policies; fixed on Thrust GitHub - // hence replace w/ linear interpolation, - // until the Thrust issue gets resolved: - // - // obsIndex = (thrust::lower_bound( - // thrust_exec_policy, thrust::device_pointer_cast(distsCumSum), - // thrust::device_pointer_cast(distsCumSum + n), distsSum * rand) - - // thrust::device_pointer_cast(distsCumSum)); - // - // linear interpolation logic: - //{ - value_type_t minSum{0}; - RAFT_CUDA_TRY( - cudaMemcpyAsync(&minSum, distsCumSum, sizeof(value_type_t), cudaMemcpyDeviceToHost, stream)); - RAFT_CHECK_CUDA(stream); - - if (distsSum > minSum) { - value_type_t vIndex = static_cast(n - 1); - obsIndex = static_cast(vIndex * (distsSum * rand - minSum) / (distsSum - minSum)); - } else { - obsIndex = 0; + detail::kmeans_fit_main(handle, + iter_params, + X, + weight.view(), + 
centroidsRawData.view(), + raft::make_host_scalar_view(&iter_inertia), + raft::make_host_scalar_view(&n_current_iter), + workspace); + if (iter_inertia < inertia[0]) { + inertia[0] = iter_inertia; + n_iter[0] = n_current_iter; + raft::copy(centroids.data(), centroidsRawData.data(), n_clusters * n_features, stream); + } + RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", + seed_iter + 1, + n_init, + inertia[0], + n_iter[0]); } - //} - - RAFT_CHECK_CUDA(stream); - obsIndex = std::max(obsIndex, static_cast(0)); - obsIndex = std::min(obsIndex, n - 1); - - // Record new centroid position - RAFT_CUDA_TRY(cudaMemcpyAsync(centroid, - obs + IDX(0, obsIndex, d), - d * sizeof(value_type_t), - cudaMemcpyDeviceToDevice, - stream)); + RAFT_LOG_DEBUG("KMeans.fit: async call returned (fit could still be running on the device)"); +} - return 0; +template +void kmeans_fit(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + DataT* centroids, + IndexT n_samples, + IndexT n_features, + DataT& inertia, + IndexT& n_iter) +{ + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto centroidsView = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + std::optional> sample_weightView = std::nullopt; + if (sample_weight) sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); + auto inertiaView = raft::make_host_scalar_view(&inertia); + auto n_iterView = raft::make_host_scalar_view(&n_iter); + + detail::kmeans_fit( + handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView); } -/** - * @brief Choose initial cluster centroids for k-means algorithm. - * Centroids are randomly chosen with k-means++ algorithm - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. 
- * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param dists (Output, device memory, 2*n entries) Workspace. On - * exit, the first n entries give the square of the Euclidean - * distance between observation vectors and the closest centroid. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int initializeCentroids(handle_t const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - value_type_t* __restrict__ centroids, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ dists, - unsigned long long seed) +template +void kmeans_predict(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) { - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Loop index - index_type_t i; - - // Random number generator - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution uniformDist(0, 1); - - auto stream = handle.get_stream(); - auto thrust_exec_policy = handle.get_thrust_policy(); - - constexpr unsigned grid_lower_bound{65535}; - - // 
------------------------------------------------------- - // Implementation - // ------------------------------------------------------- - - // Initialize grid dimensions - dim3 blockDim_warp{WARP_SIZE, 1, BSIZE_DIV_WSIZE}; - - // CUDA grid dimensions - dim3 gridDim_warp{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - 1, - std::min(ceildiv(n, BSIZE_DIV_WSIZE), grid_lower_bound)}; - - // CUDA grid dimensions - dim3 gridDim_block{std::min(ceildiv(n, BLOCK_SIZE), grid_lower_bound), 1, 1}; - - // Assign observation vectors to code 0 - RAFT_CUDA_TRY(cudaMemsetAsync(codes, 0, n * sizeof(index_type_t), stream)); - - // Choose first centroid - thrust::fill(thrust_exec_policy, - thrust::device_pointer_cast(dists), - thrust::device_pointer_cast(dists + n), - 1); - RAFT_CHECK_CUDA(stream); - if (chooseNewCentroid(handle, n, d, uniformDist(rng), obs, dists, centroids)) - WARNING("error in k-means++ (could not pick centroid)"); - - // Compute distances from first centroid - RAFT_CUDA_TRY(cudaMemsetAsync(dists, 0, n * sizeof(value_type_t), stream)); - computeDistances<<>>(n, d, 1, obs, centroids, dists); - RAFT_CHECK_CUDA(stream); - - // Choose remaining centroids - for (i = 1; i < k; ++i) { - // Choose ith centroid - if (chooseNewCentroid(handle, n, d, uniformDist(rng), obs, dists, centroids + IDX(0, i, d))) - WARNING("error in k-means++ (could not pick centroid)"); - - // Compute distances from ith centroid - CUDA_TRY(cudaMemsetAsync(dists + n, 0, n * sizeof(value_type_t), stream)); - computeDistances<<>>( - n, d, 1, obs, centroids + IDX(0, i, d), dists + n); - RAFT_CHECK_CUDA(stream); - - // Recompute minimum distances - minDistances2<<>>(n, dists, dists + n, codes, i); - RAFT_CHECK_CUDA(stream); - } + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + cudaStream_t stream = handle.get_stream(); + // Check that parameters are valid + if (sample_weight.has_value()) + RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, + "invalid parameter 
(sample_weight!=n_samples)"); + RAFT_EXPECTS(params.n_clusters > 0, "invalid parameter (n_clusters<=0)"); + RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); + RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); + RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, + "invalid parameter (centroids.extent(0) != n_clusters)"); + RAFT_EXPECTS(centroids.extent(1) == n_features, + "invalid parameter (centroids.extent(1) != n_features)"); + + logger::get(RAFT_NAME).set_level(params.verbosity); + auto metric = params.metric; - // Compute cluster sizes - CUDA_TRY(cudaMemsetAsync(clusterSizes, 0, k * sizeof(index_type_t), stream)); - computeClusterSizes<<>>(n, codes, clusterSizes); - RAFT_CHECK_CUDA(stream); + // Allocate memory + // Device-accessible allocation of expandable storage used as temorary buffers + rmm::device_uvector workspace(0, stream); + auto weight = raft::make_device_vector(handle, n_samples); + if (sample_weight.has_value()) + raft::copy(weight.data(), sample_weight.value().data(), n_samples, stream); + else + thrust::fill(handle.get_thrust_policy(), weight.data(), weight.data() + weight.size(), 1); + + // check if weights sum up to n_samples + if (normalize_weight) checkWeight(handle, weight.view(), workspace); + + auto minClusterAndDistance = + raft::make_device_vector>(n_samples, stream); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // L2 norm of X: ||x||^2 + auto L2NormX = raft::make_device_vector(n_samples, stream); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormX.data(), X.data(), X.extent(1), X.extent(0), raft::linalg::L2Norm, true, stream); + } - return 0; + // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] + // is a pair where + // 'key' is index to a sample in 'centroids' (index of the nearest + // centroid) and 'value' is the distance between 
the sample 'X[i]' and the + // 'centroid[key]' + detail::minClusterAndDistanceCompute(handle, + params, + X, + centroids, + minClusterAndDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + workspace); + + // calculate cluster cost phi_x(C) + rmm::device_scalar> clusterCostD(stream); + // TODO: add different templates for InType of binaryOp to avoid thrust transform + thrust::transform(handle.get_thrust_policy(), + minClusterAndDistance.data(), + minClusterAndDistance.data() + minClusterAndDistance.size(), + weight.data(), + minClusterAndDistance.data(), + [=] __device__(const cub::KeyValuePair kvp, DataT wt) { + cub::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }); + + detail::computeClusterCost(handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + [] __device__(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { + cub::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + }); + + raft::copy(inertia.data(), &(clusterCostD.data()->value), 1, stream); + + thrust::transform(handle.get_thrust_policy(), + minClusterAndDistance.data(), + minClusterAndDistance.data() + minClusterAndDistance.size(), + labels.data(), + [=] __device__(cub::KeyValuePair pair) { return pair.key; }); } -/** - * @brief Find cluster centroids closest to observation vectors. - * Distance is measured with Euclidean norm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param centroids (Input, device memory, d*k entries) Centroid - * matrix. 
Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param dists (Output, device memory, n*k entries) Workspace. On - * exit, the first n entries give the square of the Euclidean - * distance between observation vectors and the closest centroid. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param residual_host (Output, host memory, 1 entry) Residual sum - * of squares of assignment. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int assignCentroids(handle_t const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const value_type_t* __restrict__ centroids, - value_type_t* __restrict__ dists, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* residual_host) +template +void kmeans_predict(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + const DataT* centroids, + IndexT n_samples, + IndexT n_features, + IndexT* labels, + bool normalize_weight, + DataT& inertia) { - auto stream = handle.get_stream(); - auto thrust_exec_policy = handle.get_thrust_policy(); - - // Compute distance between centroids and observation vectors - RAFT_CUDA_TRY(cudaMemsetAsync(dists, 0, n * k * sizeof(value_type_t), stream)); - - // CUDA grid dimensions - dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE}; - - dim3 gridDim; - constexpr unsigned grid_lower_bound{65535}; - gridDim.x = std::min(ceildiv(d, WARP_SIZE), grid_lower_bound); - gridDim.y = std::min(static_cast(k), grid_lower_bound); - gridDim.z = std::min(ceildiv(n, BSIZE_DIV_WSIZE), grid_lower_bound); - - computeDistances<<>>(n, d, k, obs, centroids, dists); - RAFT_CHECK_CUDA(stream); - - // Find centroid closest to each observation vector - CUDA_TRY(cudaMemsetAsync(clusterSizes, 0, 
k * sizeof(index_type_t), stream)); - blockDim.x = BLOCK_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = std::min(ceildiv(n, BLOCK_SIZE), grid_lower_bound); - gridDim.y = 1; - gridDim.z = 1; - minDistances<<>>(n, k, dists, codes, clusterSizes); - CHECK_CUDA(stream); - - // Compute residual sum of squares - *residual_host = thrust::reduce( - thrust_exec_policy, thrust::device_pointer_cast(dists), thrust::device_pointer_cast(dists + n)); - - return 0; + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto centroidsView = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + std::optional> sample_weightView = std::nullopt; + if (sample_weight) sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); + auto labelsView = raft::make_device_vector_view(labels, n_samples); + auto inertiaView = raft::make_host_scalar_view(&inertia); + + detail::kmeans_predict(handle, + params, + XView, + sample_weightView, + centroidsView, + labelsView, + normalize_weight, + inertiaView); } -/** - * @brief Update cluster centroids for k-means algorithm. - * All clusters are assumed to be non-empty. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Input, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Input, device memory, k entries) Number of - * points in each cluster. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. 
- * @param work (Output, device memory, n*d entries) Workspace. - * @param work_int (Output, device memory, 2*d*n entries) - * Workspace. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int updateCentroids(handle_t const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const index_type_t* __restrict__ codes, - const index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids, - value_type_t* __restrict__ work, - index_type_t* __restrict__ work_int) +template +void kmeans_fit_predict(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) { - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Useful constants - const value_type_t one = 1; - const value_type_t zero = 0; - - constexpr unsigned grid_lower_bound{65535}; - - auto stream = handle.get_stream(); - auto cublas_h = handle.get_cublas_handle(); - auto thrust_exec_policy = handle.get_thrust_policy(); - - // Device memory - thrust::device_ptr obs_copy(work); - thrust::device_ptr codes_copy(work_int); - thrust::device_ptr rows(work_int + d * n); - - // Take transpose of observation matrix - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgeam(cublas_h, - CUBLAS_OP_T, - CUBLAS_OP_N, - n, - d, - &one, - obs, - d, - &zero, - (value_type_t*)NULL, - n, - thrust::raw_pointer_cast(obs_copy), - n, - stream)); - - // Cluster assigned to each observation matrix entry - thrust::sequence(thrust_exec_policy, rows, rows + d * n); - RAFT_CHECK_CUDA(stream); - thrust::transform(thrust_exec_policy, - rows, - rows + d * n, - thrust::make_constant_iterator(n), - rows, - thrust::modulus()); - 
RAFT_CHECK_CUDA(stream); - thrust::gather( - thrust_exec_policy, rows, rows + d * n, thrust::device_pointer_cast(codes), codes_copy); - RAFT_CHECK_CUDA(stream); - - // Row associated with each observation matrix entry - thrust::sequence(thrust_exec_policy, rows, rows + d * n); - RAFT_CHECK_CUDA(stream); - thrust::transform(thrust_exec_policy, - rows, - rows + d * n, - thrust::make_constant_iterator(n), - rows, - thrust::divides()); - RAFT_CHECK_CUDA(stream); - - // Sort and reduce to add observation vectors in same cluster - thrust::stable_sort_by_key(thrust_exec_policy, - codes_copy, - codes_copy + d * n, - make_zip_iterator(make_tuple(obs_copy, rows))); - RAFT_CHECK_CUDA(stream); - thrust::reduce_by_key(thrust_exec_policy, - rows, - rows + d * n, - obs_copy, - codes_copy, // Output to codes_copy is ignored - thrust::device_pointer_cast(centroids)); - RAFT_CHECK_CUDA(stream); - - // Divide sums by cluster size to get centroid matrix - // - // CUDA grid dimensions - dim3 blockDim{WARP_SIZE, BLOCK_SIZE / WARP_SIZE, 1}; - - // CUDA grid dimensions - dim3 gridDim{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - std::min(ceildiv(k, BSIZE_DIV_WSIZE), grid_lower_bound), - 1}; - - divideCentroids<<>>(d, k, clusterSizes, centroids); - RAFT_CHECK_CUDA(stream); - - return 0; + if (!centroids.has_value()) { + auto n_features = X.extent(1); + auto centroids_matrix = + raft::make_device_matrix(params.n_clusters, n_features, handle.get_stream()); + detail::kmeans_fit( + handle, params, X, sample_weight, centroids_matrix.view(), inertia, n_iter); + detail::kmeans_predict( + handle, params, X, sample_weight, centroids_matrix.view(), labels, true, inertia); + } else { + detail::kmeans_fit( + handle, params, X, sample_weight, centroids.value(), inertia, n_iter); + detail::kmeans_predict( + handle, params, X, sample_weight, centroids.value(), labels, true, inertia); + } } -// ========================================================= -// k-means algorithm -// 
========================================================= +template +void kmeans_fit_predict(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + DataT* centroids, + IndexT n_samples, + IndexT n_features, + IndexT* labels, + DataT& inertia, + IndexT& n_iter) +{ + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + std::optional> sample_weightView = std::nullopt; + if (sample_weight) sample_weightView = raft::make_device_vector_view(sample_weight, n_samples); + std::optional> centroidsView = std::nullopt; + if (centroids) + centroidsView = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto labelsView = raft::make_device_vector_view(labels, n_samples); + auto inertiaView = raft::make_host_scalar_view(&inertia); + auto n_iterView = raft::make_host_scalar_view(&n_iter); + + detail::kmeans_fit_predict( + handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView); +} /** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. 
- * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param work (Output, device memory, n*max(k,d) entries) - * Workspace. - * @param work_int (Output, device memory, 2*d*n entries) - * Workspace. - * @param residual_host (Output, host memory, 1 entry) Residual sum - * of squares (sum of squares of distances between observation - * vectors and centroids). - * @param iters_host (Output, host memory, 1 entry) Number of - * k-means iterations. - * @param seed random seed to be used. - * @return error flag. + * @brief Transform X to a cluster-distance space. + * + * @param[in] handle The handle to the cuML library context that + * manages the CUDA resources. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format + * @param[in] centroids Cluster centroids. The data must be in row-major format. + * @param[out] X_new X transformed in the new space.. 
*/ -template -int kmeans(handle_t const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids, - value_type_t* __restrict__ work, - index_type_t* __restrict__ work_int, - value_type_t* residual_host, - index_type_t* iters_host, - unsigned long long seed) +template +void kmeans_transform(const raft::handle_t& handle, + const KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_matrix_view X_new) { - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Current iteration - index_type_t iter; - - constexpr unsigned grid_lower_bound{65535}; - - // Residual sum of squares at previous iteration - value_type_t residualPrev = 0; - - // Random number generator - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution uniformDist(0, 1); - - // ------------------------------------------------------- - // Initialization - // ------------------------------------------------------- - - auto stream = handle.get_stream(); - auto cublas_h = handle.get_cublas_handle(); - auto thrust_exec_policy = handle.get_thrust_policy(); - - // Trivial cases - if (k == 1) { - CUDA_TRY(cudaMemsetAsync(codes, 0, n * sizeof(index_type_t), stream)); - CUDA_TRY( - cudaMemcpyAsync(clusterSizes, &n, sizeof(index_type_t), cudaMemcpyHostToDevice, stream)); - if (updateCentroids(handle, n, d, k, obs, codes, clusterSizes, centroids, work, work_int)) - WARNING("could not compute k-means centroids"); - - dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE}; - - dim3 gridDim{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - 1, - std::min(ceildiv(n, BLOCK_SIZE / WARP_SIZE), grid_lower_bound)}; - - CUDA_TRY(cudaMemsetAsync(work, 0, n 
* k * sizeof(value_type_t), stream)); - computeDistances<<>>(n, d, 1, obs, centroids, work); - RAFT_CHECK_CUDA(stream); - *residual_host = thrust::reduce( - thrust_exec_policy, thrust::device_pointer_cast(work), thrust::device_pointer_cast(work + n)); - RAFT_CHECK_CUDA(stream); - return 0; - } - if (n <= k) { - thrust::sequence(thrust_exec_policy, - thrust::device_pointer_cast(codes), - thrust::device_pointer_cast(codes + n)); - RAFT_CHECK_CUDA(stream); - thrust::fill_n(thrust_exec_policy, thrust::device_pointer_cast(clusterSizes), n, 1); - RAFT_CHECK_CUDA(stream); - - if (n < k) - RAFT_CUDA_TRY(cudaMemsetAsync(clusterSizes + n, 0, (k - n) * sizeof(index_type_t), stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync( - centroids, obs, d * n * sizeof(value_type_t), cudaMemcpyDeviceToDevice, stream)); - *residual_host = 0; - return 0; - } - - // Initialize cuBLAS - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY( - raft::linalg::detail::cublassetpointermode(cublas_h, CUBLAS_POINTER_MODE_HOST, stream)); - - // ------------------------------------------------------- - // k-means++ algorithm - // ------------------------------------------------------- - - // Choose initial cluster centroids - if (initializeCentroids(handle, n, d, k, obs, centroids, codes, clusterSizes, work, seed)) - WARNING("could not initialize k-means centroids"); - - // Apply k-means iteration until convergence - for (iter = 0; iter < maxiter; ++iter) { - // Update cluster centroids - if (updateCentroids(handle, n, d, k, obs, codes, clusterSizes, centroids, work, work_int)) - WARNING("could not update k-means centroids"); - - // Determine centroid closest to each observation - residualPrev = *residual_host; - if (assignCentroids(handle, n, d, k, obs, centroids, work, codes, clusterSizes, residual_host)) - WARNING("could not assign observation vectors to k-means clusters"); - - // Reinitialize empty clusters with new centroids - index_type_t emptyCentroid = (thrust::find(thrust_exec_policy, - 
thrust::device_pointer_cast(clusterSizes), - thrust::device_pointer_cast(clusterSizes + k), - 0) - - thrust::device_pointer_cast(clusterSizes)); - - // FIXME: emptyCentroid never reaches k (infinite loop) under certain - // conditions, such as if obs is corrupt (as seen as a result of a - // DataFrame column of NULL edge vals used to create the Graph) - while (emptyCentroid < k) { - if (chooseNewCentroid( - handle, n, d, uniformDist(rng), obs, work, centroids + IDX(0, emptyCentroid, d))) - WARNING("could not replace empty centroid"); - if (assignCentroids( - handle, n, d, k, obs, centroids, work, codes, clusterSizes, residual_host)) - WARNING("could not assign observation vectors to k-means clusters"); - emptyCentroid = (thrust::find(thrust_exec_policy, - thrust::device_pointer_cast(clusterSizes), - thrust::device_pointer_cast(clusterSizes + k), - 0) - - thrust::device_pointer_cast(clusterSizes)); - RAFT_CHECK_CUDA(stream); - } - - // Check for convergence - if (std::fabs(residualPrev - (*residual_host)) / n < tol) { - ++iter; - break; - } + logger::get(RAFT_NAME).set_level(params.verbosity); + cudaStream_t stream = handle.get_stream(); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + // Device-accessible allocation of expandable storage used as temporary buffers + rmm::device_uvector workspace(0, stream); + auto dataBatchSize = getDataBatchSize(params, n_samples); + + // tile over the input data and calculate distance matrix [n_samples x + // n_clusters] + for (IndexT dIdx = 0; dIdx < (IndexT)n_samples; dIdx += dataBatchSize) { + // # of samples for the current batch + auto ns = std::min(dataBatchSize, n_samples - dIdx); + + // datasetView [ns x n_features] - view representing the current batch of + // input dataset + auto datasetView = raft::make_device_matrix_view(X.data() + n_features * dIdx, ns, n_features); + + // pairwiseDistanceView [ns x n_clusters] + auto 
pairwiseDistanceView = + raft::make_device_matrix_view(X_new.data() + n_clusters * dIdx, ns, n_clusters); + + // calculate pairwise distance between cluster centroids and current batch + // of input dataset + pairwise_distance_kmeans( + handle, datasetView, centroids, pairwiseDistanceView, workspace, metric); } - - // Warning if k-means has failed to converge - if (std::fabs(residualPrev - (*residual_host)) / n >= tol) WARNING("k-means failed to converge"); - - *iters_host = iter; - return 0; } -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param residual On exit, residual sum of squares (sum of squares - * of distances between observation vectors and centroids). - * @param iters on exit, number of k-means iterations. - * @param seed random seed to be used. 
- * @return error flag - */ -template -int kmeans(handle_t const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - value_type_t& residual, - index_type_t& iters, - unsigned long long seed = 123456) +template +void kmeans_transform(const raft::handle_t& handle, + const KMeansParams& params, + const DataT* X, + const DataT* centroids, + IndexT n_samples, + IndexT n_features, + DataT* X_new) { - // Check that parameters are valid - RAFT_EXPECTS(n > 0, "invalid parameter (n<1)"); - RAFT_EXPECTS(d > 0, "invalid parameter (d<1)"); - RAFT_EXPECTS(k > 0, "invalid parameter (k<1)"); - RAFT_EXPECTS(tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(maxiter >= 0, "invalid parameter (maxiter<0)"); + auto XView = raft::make_device_matrix_view(X, n_samples, n_features); + auto centroidsView = raft::make_device_matrix_view(centroids, params.n_clusters, n_features); + auto X_newView = raft::make_device_matrix_view(X_new, n_samples, n_features); - // Allocate memory - raft::spectral::matrix::vector_t clusterSizes(handle, k); - raft::spectral::matrix::vector_t centroids(handle, d * k); - raft::spectral::matrix::vector_t work(handle, n * std::max(k, d)); - raft::spectral::matrix::vector_t work_int(handle, 2 * d * n); - - // Perform k-means - return kmeans(handle, - n, - d, - k, - tol, - maxiter, - obs, - codes, - clusterSizes.raw(), - centroids.raw(), - work.raw(), - work_int.raw(), - &residual, - &iters, - seed); + detail::kmeans_transform(handle, params, XView, centroidsView, X_newView); } - } // namespace detail } // namespace cluster } // namespace raft diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh new file mode 100644 index 0000000000..0d46b532c4 --- /dev/null +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -0,0 +1,683 @@ +/* + * Copyright (c) 2022, NVIDIA 
CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace cluster { +namespace detail { + +template +struct FusedL2NNReduceOp { + IndexT offset; + + FusedL2NNReduceOp(IndexT _offset) : offset(_offset){}; + + typedef typename cub::KeyValuePair KVP; + DI void operator()(IndexT rit, KVP* out, const KVP& other) + { + if (other.value < out->value) { + out->key = offset + other.key; + out->value = other.value; + } + } + + DI void operator()(IndexT rit, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) + { + out->key = -1; + out->value = maxVal; + } +}; + +template +struct SamplingOp { + DataT* rnd; + int* flag; + DataT cluster_cost; + double oversampling_factor; + IndexT n_clusters; + + CUB_RUNTIME_FUNCTION __forceinline__ + SamplingOp(DataT c, double l, IndexT k, DataT* rand, int* ptr) + : cluster_cost(c), oversampling_factor(l), n_clusters(k), rnd(rand), flag(ptr) + { + } + + __host__ __device__ __forceinline__ bool operator()( + const cub::KeyValuePair& a) const + { + DataT prob_threshold = 
(DataT)rnd[a.key]; + + DataT prob_x = ((oversampling_factor * n_clusters * a.value) / cluster_cost); + + return !flag[a.key] && (prob_x > prob_threshold); + } +}; + +template +struct KeyValueIndexOp { + __host__ __device__ __forceinline__ IndexT + operator()(const cub::KeyValuePair& a) const + { + return a.key; + } +}; + +// Computes the intensity histogram from a sequence of labels +template +void countLabels(const raft::handle_t& handle, + SampleIteratorT labels, + CounterT* count, + IndexT n_samples, + IndexT n_clusters, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = handle.get_stream(); + IndexT num_levels = n_clusters + 1; + IndexT lower_level = 0; + IndexT upper_level = n_clusters; + + size_t temp_storage_bytes = 0; + RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, + temp_storage_bytes, + labels, + count, + num_levels, + lower_level, + upper_level, + n_samples, + stream)); + + workspace.resize(temp_storage_bytes, stream); + + RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(), + temp_storage_bytes, + labels, + count, + num_levels, + lower_level, + upper_level, + n_samples, + stream)); +} + +template +void checkWeight(const raft::handle_t& handle, + const raft::device_vector_view& weight, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = handle.get_stream(); + auto wt_aggr = raft::make_device_scalar(0, stream); + auto n_samples = weight.extent(0); + + size_t temp_storage_bytes = 0; + RAFT_CUDA_TRY(cub::DeviceReduce::Sum( + nullptr, temp_storage_bytes, weight.data(), wt_aggr.data(), n_samples, stream)); + + workspace.resize(temp_storage_bytes, stream); + + RAFT_CUDA_TRY(cub::DeviceReduce::Sum( + workspace.data(), temp_storage_bytes, weight.data(), wt_aggr.data(), n_samples, stream)); + DataT wt_sum = 0; + raft::copy(&wt_sum, wt_aggr.data(), 1, stream); + handle.sync_stream(stream); + + if (wt_sum != n_samples) { + RAFT_LOG_DEBUG( + "[Warning!] 
KMeans: normalizing the user provided sample weight to " + "sum up to %d samples", + n_samples); + + auto scale = static_cast(n_samples) / wt_sum; + raft::linalg::unaryOp( + weight.data(), + weight.data(), + n_samples, + [=] __device__(const DataT& wt) { return wt * scale; }, + stream); + } +} + +template +IndexT getDataBatchSize(const KMeansParams& params, IndexT n_samples) +{ + auto minVal = std::min(static_cast(params.batch_samples), n_samples); + return (minVal == 0) ? n_samples : minVal; +} + +template +IndexT getCentroidsBatchSize(const KMeansParams& params, IndexT n_local_clusters) +{ + auto minVal = std::min(static_cast(params.batch_centroids), n_local_clusters); + return (minVal == 0) ? n_local_clusters : minVal; +} + +template +void computeClusterCost(const raft::handle_t& handle, + const raft::device_vector_view& minClusterDistance, + rmm::device_uvector& workspace, + const raft::device_scalar_view& clusterCost, + ReductionOpT reduction_op) +{ + cudaStream_t stream = handle.get_stream(); + size_t temp_storage_bytes = 0; + RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, + temp_storage_bytes, + minClusterDistance.data(), + clusterCost.data(), + minClusterDistance.size(), + reduction_op, + DataT(), + stream)); + + workspace.resize(temp_storage_bytes, stream); + + RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), + temp_storage_bytes, + minClusterDistance.data(), + clusterCost.data(), + minClusterDistance.size(), + reduction_op, + DataT(), + stream)); +} + +template +void sampleCentroids(const raft::handle_t& handle, + const raft::device_matrix_view& X, + const raft::device_vector_view& minClusterDistance, + const raft::device_vector_view& isSampleCentroid, + SamplingOp& select_op, + rmm::device_uvector& inRankCp, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = handle.get_stream(); + auto n_local_samples = X.extent(0); + auto n_features = X.extent(1); + + auto nSelected = raft::make_device_scalar(0, stream); + cub::ArgIndexInputIterator 
ip_itr(minClusterDistance.data()); + auto sampledMinClusterDistance = + raft::make_device_vector>(n_local_samples, stream); + size_t temp_storage_bytes = 0; + RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr, + temp_storage_bytes, + ip_itr, + sampledMinClusterDistance.data(), + nSelected.data(), + n_local_samples, + select_op, + stream)); + + workspace.resize(temp_storage_bytes, stream); + + RAFT_CUDA_TRY(cub::DeviceSelect::If(workspace.data(), + temp_storage_bytes, + ip_itr, + sampledMinClusterDistance.data(), + nSelected.data(), + n_local_samples, + select_op, + stream)); + + IndexT nPtsSampledInRank = 0; + raft::copy(&nPtsSampledInRank, nSelected.data(), 1, stream); + handle.sync_stream(stream); + + IndexT* rawPtr_isSampleCentroid = isSampleCentroid.data(); + thrust::for_each_n(handle.get_thrust_policy(), + sampledMinClusterDistance.data(), + nPtsSampledInRank, + [=] __device__(cub::KeyValuePair val) { + rawPtr_isSampleCentroid[val.key] = 1; + }); + + inRankCp.resize(nPtsSampledInRank * n_features, stream); + + raft::matrix::gather((DataT*)X.data(), + X.extent(1), + X.extent(0), + sampledMinClusterDistance.data(), + nPtsSampledInRank, + inRankCp.data(), + [=] __device__(cub::KeyValuePair val) { // MapTransformOp + return val.key; + }, + stream); +} + +// calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', +// result will be stored in 'pairwiseDistance[n x k]' +template +void pairwise_distance_kmeans(const raft::handle_t& handle, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroids, + const raft::device_matrix_view& pairwiseDistance, + rmm::device_uvector& workspace, + raft::distance::DistanceType metric) +{ + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = centroids.extent(0); + + ASSERT(X.extent(1) == centroids.extent(1), + "# features in dataset and centroids are different (must be same)"); + + raft::distance::pairwise_distance(handle, + X.data(), + centroids.data(), + 
pairwiseDistance.data(), + n_samples, + n_clusters, + n_features, + workspace, + metric); +} + +// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores +// in 'out' does not modify the input +template +void shuffleAndGather(const raft::handle_t& handle, + const raft::device_matrix_view& in, + const raft::device_matrix_view& out, + uint32_t n_samples_to_gather, + uint64_t seed, + rmm::device_uvector* workspace = nullptr) +{ + cudaStream_t stream = handle.get_stream(); + auto n_samples = in.extent(0); + auto n_features = in.extent(1); + + auto indices = raft::make_device_vector(n_samples, stream); + + if (workspace) { + // shuffle indices on device + raft::random::permute( + indices.data(), nullptr, nullptr, (IndexT)in.extent(1), (IndexT)in.extent(0), true, stream); + } else { + // shuffle indices on host and copy to device... + std::vector ht_indices(n_samples); + + std::iota(ht_indices.begin(), ht_indices.end(), 0); + + std::mt19937 gen(seed); + std::shuffle(ht_indices.begin(), ht_indices.end(), gen); + + raft::copy(indices.data(), ht_indices.data(), indices.size(), stream); + } + + raft::matrix::gather((DataT*)in.data(), + in.extent(1), + in.extent(0), + indices.data(), + n_samples_to_gather, + out.data(), + stream); +} + +// Calculates a pair for every sample in input 'X' where key is an +// index to an sample in 'centroids' (index of the nearest centroid) and 'value' +// is the distance between the sample and the 'centroid[key]' +template +void minClusterAndDistanceCompute( + const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view X, + const raft::device_matrix_view centroids, + const raft::device_vector_view>& minClusterAndDistance, + const raft::device_vector_view& L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = handle.get_stream(); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = 
centroids.extent(0); + auto metric = params.metric; + auto dataBatchSize = getDataBatchSize(params, n_samples); + auto centroidsBatchSize = getCentroidsBatchSize(params, n_clusters); + + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), + centroids.data(), + centroids.extent(1), + centroids.extent(0), + raft::linalg::L2Norm, + true, + stream); + } else { + L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); + } + + // Note - pairwiseDistance and centroidsNorm share the same buffer + // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm + auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto pairwiseDistance = + raft::make_device_matrix_view(L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + + cub::KeyValuePair initial_value(0, std::numeric_limits::max()); + + thrust::fill(handle.get_thrust_policy(), + minClusterAndDistance.data(), + minClusterAndDistance.data() + minClusterAndDistance.size(), + initial_value); + + // tile over the input dataset + for (std::size_t dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + // # of samples for the current batch + auto ns = std::min(dataBatchSize, n_samples - dIdx); + + // datasetView [ns x n_features] - view representing the current batch of + // input dataset + auto datasetView = + raft::make_device_matrix_view(X.data() + (dIdx * n_features), ns, n_features); + + // minClusterAndDistanceView [ns x n_clusters] + auto minClusterAndDistanceView = + raft::make_device_vector_view(minClusterAndDistance.data() + dIdx, ns); + + auto L2NormXView = raft::make_device_vector_view(L2NormX.data() + dIdx, ns); + + // tile over the centroids + for (std::size_t cIdx = 0; cIdx < n_clusters; cIdx 
+= centroidsBatchSize) { + // # of centroids for the current batch + auto nc = std::min(centroidsBatchSize, n_clusters - cIdx); + + // centroidsView [nc x n_features] - view representing the current batch + // of centroids + auto centroidsView = + raft::make_device_matrix_view(centroids.data() + (cIdx * n_features), nc, n_features); + + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + auto centroidsNormView = raft::make_device_vector_view(centroidsNorm.data() + cIdx, nc); + workspace.resize((sizeof(int)) * ns, stream); + + FusedL2NNReduceOp redOp(cIdx); + raft::distance::KVPMinReduce pairRedOp; + + raft::distance::fusedL2NN, IndexT>( + minClusterAndDistanceView.data(), + datasetView.data(), + centroidsView.data(), + L2NormXView.data(), + centroidsNormView.data(), + ns, + nc, + n_features, + (void*)workspace.data(), + redOp, + pairRedOp, + (metric == raft::distance::DistanceType::L2Expanded) ? false : true, + false, + stream); + } else { + // pairwiseDistanceView [ns x nc] - view representing the pairwise + // distance for current batch + auto pairwiseDistanceView = raft::make_device_matrix_view(pairwiseDistance.data(), ns, nc); + + // calculate pairwise distance between current tile of cluster centroids + // and input dataset + pairwise_distance_kmeans( + handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); + + // argmin reduction returning pair + // calculates the closest centroid and the distance to the closest + // centroid + raft::linalg::coalescedReduction( + minClusterAndDistanceView.data(), + pairwiseDistanceView.data(), + pairwiseDistanceView.extent(1), + pairwiseDistanceView.extent(0), + initial_value, + stream, + true, + [=] __device__(const DataT val, const IndexT i) { + cub::KeyValuePair pair; + pair.key = cIdx + i; + pair.value = val; + return pair; + }, + [=] __device__(cub::KeyValuePair a, cub::KeyValuePair b) { + return (b.value < a.value) ? 
b : a; + }, + [=] __device__(cub::KeyValuePair pair) { return pair; }); + } + } + } +} + +template +void minClusterDistanceCompute(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroids, + const raft::device_vector_view& minClusterDistance, + const raft::device_vector_view& L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = handle.get_stream(); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = centroids.extent(0); + auto metric = params.metric; + + auto dataBatchSize = getDataBatchSize(params, n_samples); + auto centroidsBatchSize = getCentroidsBatchSize(params, n_clusters); + + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), + centroids.data(), + centroids.extent(1), + centroids.extent(0), + raft::linalg::L2Norm, + true, + stream); + } else { + L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); + } + + // Note - pairwiseDistance and centroidsNorm share the same buffer + // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto pairwiseDistance = raft::make_device_matrix_view( + L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + + thrust::fill(handle.get_thrust_policy(), + minClusterDistance.data(), + minClusterDistance.data() + minClusterDistance.size(), + std::numeric_limits::max()); + + // tile over the input data and calculate distance matrix [n_samples x + // n_clusters] + for (std::size_t dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + // # of samples for the current batch 
+ auto ns = std::min(dataBatchSize, n_samples - dIdx); + + // datasetView [ns x n_features] - view representing the current batch of + // input dataset + auto datasetView = + raft::make_device_matrix_view(X.data() + dIdx * n_features, ns, n_features); + + // minClusterDistanceView [ns x n_clusters] + auto minClusterDistanceView = + raft::make_device_vector_view(minClusterDistance.data() + dIdx, ns); + + auto L2NormXView = raft::make_device_vector_view(L2NormX.data() + dIdx, ns); + + // tile over the centroids + for (std::size_t cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { + // # of centroids for the current batch + auto nc = std::min(centroidsBatchSize, n_clusters - cIdx); + + // centroidsView [nc x n_features] - view representing the current batch + // of centroids + auto centroidsView = + raft::make_device_matrix_view(centroids.data() + cIdx * n_features, nc, n_features); + + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + auto centroidsNormView = + raft::make_device_vector_view(centroidsNorm.data() + cIdx, nc); + workspace.resize((sizeof(IndexT)) * ns, stream); + + FusedL2NNReduceOp redOp(cIdx); + raft::distance::KVPMinReduce pairRedOp; + raft::distance::fusedL2NN( + minClusterDistanceView.data(), + datasetView.data(), + centroidsView.data(), + L2NormXView.data(), + centroidsNormView.data(), + ns, + nc, + n_features, + (void*)workspace.data(), + redOp, + pairRedOp, + (metric != raft::distance::DistanceType::L2Expanded), + false, + stream); + } else { + // pairwiseDistanceView [ns x nc] - view representing the pairwise + // distance for current batch + auto pairwiseDistanceView = + raft::make_device_matrix_view(pairwiseDistance.data(), ns, nc); + + // calculate pairwise distance between current tile of cluster centroids + // and input dataset + pairwise_distance_kmeans( + handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); + + 
raft::linalg::coalescedReduction( + minClusterDistanceView.data(), + pairwiseDistanceView.data(), + pairwiseDistanceView.extent(1), + pairwiseDistanceView.extent(0), + std::numeric_limits::max(), + stream, + true, + [=] __device__(DataT val, IndexT i) { // MainLambda + return val; + }, + [=] __device__(DataT a, DataT b) { // ReduceLambda + return (b < a) ? b : a; + }, + [=] __device__(DataT val) { // FinalLambda + return val; + }); + } + } + } +} + +template +void countSamplesInCluster(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_vector_view& L2NormX, + const raft::device_matrix_view& centroids, + rmm::device_uvector& workspace, + const raft::device_vector_view& sampleCountInCluster) +{ + cudaStream_t stream = handle.get_stream(); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = centroids.extent(0); + + // stores (key, value) pair corresponding to each sample where + // - key is the index of nearest cluster + // - value is the distance to the nearest cluster + auto minClusterAndDistance = + raft::make_device_vector>(n_samples, stream); + + // temporary buffer to store distance matrix, destructor releases the resource + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] + // is a pair where + // 'key' is index to an sample in 'centroids' (index of the nearest + // centroid) and 'value' is the distance between the sample 'X[i]' and the + // 'centroid[key]' + detail::minClusterAndDistanceCompute(handle, + params, + X, + (raft::device_matrix_view)centroids, + minClusterAndDistance.view(), + L2NormX, + L2NormBuf_OR_DistBuf, + workspace); + + // Using TransformInputIteratorT to dereference an array of cub::KeyValuePair + // and converting them to just return the Key to be used in reduce_rows_by_key + // prims + detail::KeyValueIndexOp conversion_op; + cub::TransformInputIterator, + 
cub::KeyValuePair*> + itr(minClusterAndDistance.data(), conversion_op); + + // count # of samples in each cluster + countLabels( + handle, itr, sampleCountInCluster.data(), (IndexT)n_samples, (IndexT)n_clusters, workspace); +} +} // namespace detail +} // namespace cluster +} // namespace raft diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 28d4ae0719..3285a98083 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -15,51 +15,478 @@ */ #pragma once +#include #include +#include +#include namespace raft { namespace cluster { /** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param residual On exit, residual sum of squares (sum of squares - * of distances between observation vectors and centroids). - * @param iters on exit, number of k-means iterations. - * @param seed random seed to be used. - * @return error flag + * @brief Find clusters with k-means algorithm. + * Initial centroids are chosen with k-means++ algorithm. Empty + * clusters are reinitialized by choosing new centroids with + * k-means++ algorithm. 
+ * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. */ -template -int kmeans(handle_t const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - value_type_t& residual, - index_type_t& iters, - unsigned long long seed = 123456) -{ - return detail::kmeans( - handle, n, d, k, tol, maxiter, obs, codes, residual, iters, seed); +template +void kmeans_fit(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); +} + +template +void kmeans_fit(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + DataT* centroids, + IndexT n_samples, + IndexT n_features, + DataT& inertia, + IndexT& n_iter) +{ + detail::kmeans_fit( + handle, params, X, sample_weight, centroids, n_samples, n_features, inertia, n_iter); +} + +/** + * @brief Predict the closest cluster each sample in X 
belongs to. + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X New data to predict. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized + * @param[out] labels Index of the cluster each sample in X + * belongs to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. + */ +template +void kmeans_predict(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + detail::kmeans_predict( + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); +} + +template +void kmeans_predict(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + const DataT* centroids, + IndexT n_samples, + IndexT n_features, + IndexT* labels, + bool normalize_weight, + DataT& inertia) +{ + detail::kmeans_predict(handle, + params, + X, + sample_weight, + centroids, + n_samples, + n_features, + labels, + normalize_weight, + inertia); +} + +/** + * @brief Compute k-means clustering and predicts cluster index for each sample + * in the input. + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. 
The data must be + * in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids Optional + * [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] labels Index of the cluster each sample in X belongs + * to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +template +void kmeans_fit_predict(handle_t const& handle, + const KMeansParams& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + detail::kmeans_fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} + +template +void kmeans_fit_predict(handle_t const& handle, + const KMeansParams& params, + const DataT* X, + const DataT* sample_weight, + DataT* centroids, + IndexT n_samples, + IndexT n_features, + IndexT* labels, + DataT& inertia, + IndexT& n_iter) +{ + detail::kmeans_fit_predict( + handle, params, X, sample_weight, centroids, n_samples, n_features, labels, inertia, n_iter); +} + +/** + * @brief Transform X to a cluster-distance space. + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in row-major format. 
+ * [dim = n_clusters x n_features] + * @param[out] X_new X transformed in the new space. + * [dim = n_samples x n_clusters] + */ +template +void kmeans_transform(const raft::handle_t& handle, + const KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_matrix_view X_new) +{ + detail::kmeans_transform(handle, params, X, centroids, X_new); +} + +template +void kmeans_transform(const raft::handle_t& handle, + const KMeansParams& params, + const DataT* X, + const DataT* centroids, + IndexT n_samples, + IndexT n_features, + DataT* X_new) +{ + detail::kmeans_transform( + handle, params, X, centroids, n_samples, n_features, X_new); +} + +template +using SamplingOp = detail::SamplingOp; + +template +using KeyValueIndexOp = detail::KeyValueIndexOp; + +/** + * @brief Select centroids according to a sampling operation + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] minClusterDistance Distance for every sample to its nearest centroid + * [dim = n_samples] + * @param[in] isSampleCentroid Flag the sample chosen as initial centroid + * [dim = n_samples] + * @param[in] select_op The sampling operation used to select the centroids + * @param[out] inRankCp The sampled centroids + * [dim = n_selected_centroids x n_features] + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void sampleCentroids(const raft::handle_t& handle, + const raft::device_matrix_view& X, + const raft::device_vector_view& minClusterDistance, + const raft::device_vector_view& isSampleCentroid, + SamplingOp& select_op, + rmm::device_uvector& inRankCp, + rmm::device_uvector& workspace) +{ + detail::sampleCentroids( + handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); +} + +/** +
* @brief Compute cluster cost + * + * @tparam DataT the type of data used for weights, distances. + * @tparam ReductionOpT the type of data used for the reduction operation. + * + * @param[in] handle The raft handle + * @param[in] minClusterDistance Distance for every sample to its nearest centroid + * [dim = n_samples] + * @param[in] workspace Temporary workspace buffer which can get resized + * @param[out] clusterCost Resulting cluster cost + * @param[in] reduction_op The reduction operation used for the cost + * + */ +template +void computeClusterCost(const raft::handle_t& handle, + const raft::device_vector_view& minClusterDistance, + rmm::device_uvector& workspace, + const raft::device_scalar_view& clusterCost, + ReductionOpT reduction_op) +{ + detail::computeClusterCost( + handle, minClusterDistance, workspace, clusterCost, reduction_op); +} + +/** + * @brief Compute distance for every sample to its nearest centroid + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing.
+ * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] centroids Centroids data + * [dim = n_clusters x n_features] + * @param[out] minClusterDistance Distance for every sample to its nearest centroid + * [dim = n_samples] + * @param[in] L2NormX L2 norm of X : ||x||^2 + * [dim = n_samples] + * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance + * matrix + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void minClusterDistanceCompute(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroids, + const raft::device_vector_view& minClusterDistance, + const raft::device_vector_view& L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace) +{ + detail::minClusterDistanceCompute( + handle, params, X, centroids, minClusterDistance, L2NormX, L2NormBuf_OR_DistBuf, workspace); +} + +/** + * @brief Calculates a pair for every sample in input 'X' where key is an + * index of one of the 'centroids' (index of the nearest centroid) and 'value' + * is the distance between the sample and the 'centroid[key]' + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing.
+ * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] centroids Centroids data + * [dim = n_clusters x n_features] + * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest + * centroid and its distance + * [dim = n_samples] + * @param[in] L2NormX L2 norm of X : ||x||^2 + * [dim = n_samples] + * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance + * matrix + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void minClusterAndDistanceCompute( + const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view X, + const raft::device_matrix_view centroids, + const raft::device_vector_view>& minClusterAndDistance, + const raft::device_vector_view& L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace) +{ + detail::minClusterAndDistanceCompute( + handle, params, X, centroids, minClusterAndDistance, L2NormX, L2NormBuf_OR_DistBuf, workspace); +} + +/** + * @brief Shuffle and randomly select 'n_samples_to_gather' samples from input 'in' and store them + * in 'out'; does not modify the input + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing.
+ * + * @param[in] handle The raft handle + * @param[in] in The data to shuffle and gather + * [dim = n_samples x n_features] + * @param[out] out The sampled data + * [dim = n_samples_to_gather x n_features] + * @param[in] n_samples_to_gather Number of samples to gather + * @param[in] seed Seed for the shuffle + * @param[in] workspace Temporary workspace buffer which can get resized + * + */ +template +void shuffleAndGather(const raft::handle_t& handle, + const raft::device_matrix_view& in, + const raft::device_matrix_view& out, + uint32_t n_samples_to_gather, + uint64_t seed, + rmm::device_uvector* workspace = nullptr) +{ + detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed, workspace); +} + +/** + * @brief Count the number of samples in each cluster + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[in] L2NormX L2 norm of X : ||x||^2 + * [dim = n_samples] + * @param[in] centroids Centroids data + * [dim = n_clusters x n_features] + * @param[in] workspace Temporary workspace buffer which can get resized + * @param[out] sampleCountInCluster The count for each centroid + * [dim = n_clusters] + * + */ +template +void countSamplesInCluster(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_vector_view& L2NormX, + const raft::device_matrix_view& centroids, + rmm::device_uvector& workspace, + const raft::device_vector_view& sampleCountInCluster) +{ + detail::countSamplesInCluster( + handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); +} + +/** + * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. + * + * @note This is the algorithm described in + * "k-means++: the advantages of careful seeding".
2007, Arthur, D. and Vassilvitskii, S. + * ACM-SIAM symposium on Discrete algorithms. + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle + * @param[in] params The parameters for KMeans + * @param[in] X The data in row-major format + * [dim = n_samples x n_features] + * @param[out] centroidsRawData Centroids data + * [dim = n_clusters x n_features] + * @param[in] workspace Temporary workspace buffer which can get resized + */ +template +void kmeansPlusPlus(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_matrix_view& centroidsRawData, + rmm::device_uvector& workspace) +{ + detail::kmeansPlusPlus(handle, params, X, centroidsRawData, workspace); +} + +/** + * @brief Main function used to fit KMeans (after cluster initialization) + * + * @tparam DataT the type of data used for weights, distances. + * @tparam IndexT the type of data used for indexing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] weight Weights for each observation in X. + * [len = n_samples] + * @param[inout] centroidsRawData [in] Initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed to by 'centroidsRawData'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run.
+ * @param[in] workspace Temporary workspace buffer which can get resized + */ +template +void kmeans_fit_main(const raft::handle_t& handle, + const KMeansParams& params, + const raft::device_matrix_view& X, + const raft::device_vector_view& weight, + const raft::device_matrix_view& centroidsRawData, + const raft::host_scalar_view& inertia, + const raft::host_scalar_view& n_iter, + rmm::device_uvector& workspace) +{ + detail::kmeans_fit_main( + handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); } } // namespace cluster } // namespace raft diff --git a/cpp/include/raft/cluster/kmeans_params.hpp b/cpp/include/raft/cluster/kmeans_params.hpp new file mode 100644 index 0000000000..70ea49d36d --- /dev/null +++ b/cpp/include/raft/cluster/kmeans_params.hpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include + +namespace raft { +namespace cluster { + +struct KMeansParams { + enum InitMethod { KMeansPlusPlus, Random, Array }; + + // The number of clusters to form as well as the number of centroids to + // generate (default:8). + int n_clusters = 8; + + /* + * Method for initialization, defaults to k-means++: + * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm + * to select the initial cluster centers. 
+ * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at + * random from the input data for the initial centroids. + * - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. + */ + InitMethod init = KMeansPlusPlus; + + // Maximum number of iterations of the k-means algorithm for a single run. + int max_iter = 300; + + // Relative tolerance with regards to inertia to declare convergence. + double tol = 1e-4; + + // verbosity level. + int verbosity = RAFT_LEVEL_INFO; + + // Seed to the random number generator. + raft::random::RngState rng_state = + raft::random::RngState(0, raft::random::GeneratorType::GenPhilox); + + // Metric to use for distance computation. + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; + + // Number of instance k-means algorithm will be run with different seeds. + int n_init = 1; + + // Oversampling factor for use in the k-means|| algorithm. + double oversampling_factor = 2.0; + + // batch_samples and batch_centroids are used to tile 1NN computation which is + // useful to optimize/control the memory footprint + // Default tile is [batch_samples x n_clusters] i.e. 
when batch_centroids is 0 + // then don't tile the centroids + int batch_samples = 1 << 15; + int batch_centroids = 0; // if 0 then batch_centroids = n_clusters + + bool inertia_check = false; +}; +} // namespace cluster +} // namespace raft diff --git a/cpp/include/raft/comms/detail/ucp_helper.hpp b/cpp/include/raft/comms/detail/ucp_helper.hpp index ef93ae90c5..79976811ed 100644 --- a/cpp/include/raft/comms/detail/ucp_helper.hpp +++ b/cpp/include/raft/comms/detail/ucp_helper.hpp @@ -69,7 +69,7 @@ class ucp_request { }; // by default, match the whole tag -static const ucp_tag_t default_tag_mask = -1; +static const ucp_tag_t default_tag_mask = (ucp_tag_t)-1; /** * @brief Asynchronous send callback sets request to completed diff --git a/cpp/include/raft/core/logger.hpp b/cpp/include/raft/core/logger.hpp index 927eb8943e..22e4dd7a90 100644 --- a/cpp/include/raft/core/logger.hpp +++ b/cpp/include/raft/core/logger.hpp @@ -15,6 +15,9 @@ */ #pragma once +#ifndef __RAFT_RT_LOGGER +#define __RAFT_RT_LOGGER + #include #include @@ -315,3 +318,5 @@ class logger { #define RAFT_LOG_CRITICAL(fmt, ...) void(0) #endif /** @} */ + +#endif \ No newline at end of file diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh new file mode 100644 index 0000000000..dd1da1e498 --- /dev/null +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft { +namespace matrix { +namespace detail { + +// gatherKernel conditionally copies rows from the source matrix 'in' into the destination matrix +// 'out' according to a map (or a transformed map) +template +__global__ void gatherKernel(const MatrixIteratorT in, + IndexT D, + IndexT N, + MapIteratorT map, + StencilIteratorT stencil, + MatrixIteratorT out, + PredicateOp pred_op, + MapTransformOp transform_op) +{ + typedef typename std::iterator_traits::value_type MapValueT; + typedef typename std::iterator_traits::value_type StencilValueT; + + IndexT outRowStart = blockIdx.x * D; + MapValueT map_val = map[blockIdx.x]; + StencilValueT stencil_val = stencil[blockIdx.x]; + + bool predicate = pred_op(stencil_val); + if (predicate) { + IndexT inRowStart = transform_op(map_val) * D; + for (int i = threadIdx.x; i < D; i += TPB) { + out[outRowStart + i] = in[inRowStart + i]; + } + } +} + +/** + * @brief gather conditionally copies rows from a source matrix into a destination matrix according + * to a transformed map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a + * simple pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * type must be convertible to bool type. + * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to IndexT (= int) type. 
+ * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param stencil Pointer to the input sequence of stencil or predicate values + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param pred_op Predicate to apply to the stencil values + * @param transform_op The transformation operation, transforms the map values to IndexT + * @param stream CUDA stream to launch kernels within + */ +template +void gatherImpl(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + StencilIteratorT stencil, + int map_length, + MatrixIteratorT out, + UnaryPredicateOp pred_op, + MapTransformOp transform_op, + cudaStream_t stream) +{ + // skip in case of 0 length input + if (map_length <= 0 || N <= 0 || D <= 0) return; + + // signed integer type for indexing or global offsets + typedef int IndexT; + + // map value type + typedef typename std::iterator_traits::value_type MapValueT; + + // stencil value type + typedef typename std::iterator_traits::value_type StencilValueT; + + // return type of MapTransformOp, must be convertable to IndexT + typedef typename std::result_of::type MapTransformOpReturnT; + static_assert((std::is_convertible::value), + "MapTransformOp's result type must be convertible to signed integer"); + + // return type of UnaryPredicateOp, must be convertible to bool + typedef typename std::result_of::type PredicateOpReturnT; + static_assert((std::is_convertible::value), + "UnaryPredicateOp's result type must be convertible to bool type"); + + if (D <= 32) { + gatherKernel + <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + } else if (D <= 64) { + gatherKernel + <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + } else if (D <= 
128) { + gatherKernel + <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + } else { + gatherKernel + <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param stream CUDA stream to launch kernels within + */ +template +void gather(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + int map_length, + MatrixIteratorT out, + cudaStream_t stream) +{ + typedef typename std::iterator_traits::value_type MapValueT; + gatherImpl( + in, + D, + N, + map, + map, + map_length, + out, + [] __device__(MapValueT val) { return true; }, + [] __device__(MapValueT val) { return val; }, + stream); +} + +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a + * transformed map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to IndexT (= int) type. 
+ * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param transform_op The transformation operation, transforms the map values to IndexT + * @param stream CUDA stream to launch kernels within + */ +template +void gather(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + int map_length, + MatrixIteratorT out, + MapTransformOp transform_op, + cudaStream_t stream) +{ + typedef typename std::iterator_traits::value_type MapValueT; + gatherImpl( + in, + D, + N, + map, + map, + map_length, + out, + [] __device__(MapValueT val) { return true; }, + transform_op, + stream); +} + +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a + * simple pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * type must be convertible to bool type. 
+ * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param stencil Pointer to the input sequence of stencil or predicate values + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param pred_op Predicate to apply to the stencil values + * @param stream CUDA stream to launch kernels within + */ +template +void gather_if(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + StencilIteratorT stencil, + int map_length, + MatrixIteratorT out, + UnaryPredicateOp pred_op, + cudaStream_t stream) +{ + typedef typename std::iterator_traits::value_type MapValueT; + gatherImpl( + in, + D, + N, + map, + stencil, + map_length, + out, + pred_op, + [] __device__(MapValueT val) { return val; }, + stream); +} + +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a transformed map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a + * simple pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * type must be convertible to bool type. + * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to IndexT (= int) type. 
+ * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param stencil Pointer to the input sequence of stencil or predicate values + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param pred_op Predicate to apply to the stencil values + * @param transform_op The transformation operation, transforms the map values to IndexT + * @param stream CUDA stream to launch kernels within + */ +template +void gather_if(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + StencilIteratorT stencil, + int map_length, + MatrixIteratorT out, + UnaryPredicateOp pred_op, + MapTransformOp transform_op, + cudaStream_t stream) +{ + typedef typename std::iterator_traits::value_type MapValueT; + gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); +} +} // namespace detail +} // namespace matrix +} // namespace raft diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh new file mode 100644 index 0000000000..31164b2041 --- /dev/null +++ b/cpp/include/raft/matrix/gather.cuh @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include + +namespace raft { +namespace matrix { + +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param stream CUDA stream to launch kernels within + */ +template +void gather(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + int map_length, + MatrixIteratorT out, + cudaStream_t stream) +{ + detail::gather(in, D, N, map, map_length, out, stream); +} + +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a + * transformed map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to IndexT (= int) type. 
+ * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param transform_op The transformation operation, transforms the map values to IndexT + * @param stream CUDA stream to launch kernels within + */ +template +void gather(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + int map_length, + MatrixIteratorT out, + MapTransformOp transform_op, + cudaStream_t stream) +{ + detail::gather(in, D, N, map, map_length, out, transform_op, stream); +} + +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a + * simple pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * type must be convertible to bool type. 
+ * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param stencil Pointer to the input sequence of stencil or predicate values + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param pred_op Predicate to apply to the stencil values + * @param stream CUDA stream to launch kernels within + */ +template +void gather_if(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + StencilIteratorT stencil, + int map_length, + MatrixIteratorT out, + UnaryPredicateOp pred_op, + cudaStream_t stream) +{ + detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, stream); +} + +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a transformed map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a + * simple pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * type must be convertible to bool type. + * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to IndexT (= int) type. 
+ * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param D Leading dimension of the input matrix 'in', which in-case of row-major + * storage is the number of columns + * @param N Second dimension + * @param map Pointer to the input sequence of gather locations + * @param stencil Pointer to the input sequence of stencil or predicate values + * @param map_length The length of 'map' and 'stencil' + * @param out Pointer to the output matrix (assumed to be row-major) + * @param pred_op Predicate to apply to the stencil values + * @param transform_op The transformation operation, transforms the map values to IndexT + * @param stream CUDA stream to launch kernels within + */ +template +void gather_if(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + StencilIteratorT stencil, + int map_length, + MatrixIteratorT out, + UnaryPredicateOp pred_op, + MapTransformOp transform_op, + cudaStream_t stream) +{ + detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); +} +} // namespace matrix +} // namespace raft diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index e47b43da97..eead64942f 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -67,7 +67,7 @@ namespace detail { RAFT_DEPAREN(func)(r_pc, ##__VA_ARGS__); \ break; \ } \ - default: RAFT_FAIL("Unepxected generator type '%d'", int((rng_state).type)); \ + default: RAFT_FAIL("Unexpected generator type '%d'", int((rng_state).type)); \ } template diff --git a/cpp/include/raft/spectral/cluster_solvers.cuh b/cpp/include/raft/spectral/cluster_solvers.cuh index 27599c9464..8d56f3b172 100644 --- a/cpp/include/raft/spectral/cluster_solvers.cuh +++ b/cpp/include/raft/spectral/cluster_solvers.cuh @@ -57,18 +57,29 @@ struct kmeans_solver_t { RAFT_EXPECTS(codes != nullptr, "Null codes buffer."); value_type_t residual{}; index_type_t iters{}; + 
raft::cluster::KMeansParams km_params; + km_params.n_clusters = config_.n_clusters; + km_params.tol = config_.tol; + km_params.max_iter = config_.maxIter; + km_params.rng_state.seed = config_.seed; - raft::cluster::kmeans(handle, - n_obs_vecs, - dim, - config_.n_clusters, - config_.tol, - config_.maxIter, - obs, - codes, - residual, - iters, - config_.seed); + auto X = raft::make_device_matrix_view(obs, n_obs_vecs, dim); + auto labels = raft::make_device_vector_view(codes, n_obs_vecs); + auto centroids = + raft::make_device_matrix(config_.n_clusters, dim, handle.get_stream()); + auto weight = raft::make_device_vector(n_obs_vecs, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), weight.data(), weight.data() + n_obs_vecs, 1); + + auto sw = std::make_optional((raft::device_vector_view)weight.view()); + raft::cluster::kmeans_fit_predict( + handle, + km_params, + X, + sw, + centroids.view(), + labels, + raft::make_host_scalar_view(&residual), + raft::make_host_scalar_view(&iters)); return std::make_pair(residual, iters); } diff --git a/cpp/include/raft/spectral/cluster_solvers.hpp b/cpp/include/raft/spectral/cluster_solvers.hpp index 9cb773cce2..8709d10026 100644 --- a/cpp/include/raft/spectral/cluster_solvers.hpp +++ b/cpp/include/raft/spectral/cluster_solvers.hpp @@ -19,71 +19,6 @@ * Please use the cuh version instead. 
*/ -#ifndef __CLUSTER_SOLVERS_H -#define __CLUSTER_SOLVERS_H - #pragma once -#include -#include // for std::pair - -namespace raft { -namespace spectral { - -using namespace matrix; - -// aggregate of control params for Eigen Solver: -// -template -struct cluster_solver_config_t { - size_type_t n_clusters; - size_type_t maxIter; - - value_type_t tol; - - unsigned long long seed{123456}; -}; - -template -struct kmeans_solver_t { - explicit kmeans_solver_t( - cluster_solver_config_t const& config) - : config_(config) - { - } - - std::pair solve(handle_t const& handle, - size_type_t n_obs_vecs, - size_type_t dim, - value_type_t const* __restrict__ obs, - index_type_t* __restrict__ codes) const - { - RAFT_EXPECTS(obs != nullptr, "Null obs buffer."); - RAFT_EXPECTS(codes != nullptr, "Null codes buffer."); - value_type_t residual{}; - index_type_t iters{}; - - raft::cluster::kmeans(handle, - n_obs_vecs, - dim, - config_.n_clusters, - config_.tol, - config_.maxIter, - obs, - codes, - residual, - iters, - config_.seed); - return std::make_pair(residual, iters); - } - - auto const& get_config(void) const { return config_; } - - private: - cluster_solver_config_t config_; -}; - -} // namespace spectral -} // namespace raft - -#endif \ No newline at end of file +#include \ No newline at end of file diff --git a/cpp/include/raft/spectral/detail/partition.hpp b/cpp/include/raft/spectral/detail/partition.hpp index 97e10963dc..1e0cc78826 100644 --- a/cpp/include/raft/spectral/detail/partition.hpp +++ b/cpp/include/raft/spectral/detail/partition.hpp @@ -29,6 +29,7 @@ #include #include #include +#include namespace raft { namespace spectral { @@ -91,7 +92,7 @@ std::tuple partition( // Initialize Laplacian /// sparse_matrix_t A{handle, graph}; - laplacian_matrix_t L{handle, csr_m}; + spectral::matrix::laplacian_matrix_t L{handle, csr_m}; auto eigen_config = eigen_solver.get_config(); auto nEigVecs = eigen_config.n_eigVecs; @@ -148,8 +149,8 @@ void analyzePartition(handle_t const& 
handle, weight_t partEdgesCut, clustersize; // Device memory - vector_t part_i(handle, n); - vector_t Lx(handle, n); + spectral::matrix::vector_t part_i(handle, n); + spectral::matrix::vector_t Lx(handle, n); // Initialize cuBLAS RAFT_CUBLAS_TRY( @@ -157,7 +158,7 @@ void analyzePartition(handle_t const& handle, // Initialize Laplacian /// sparse_matrix_t A{handle, graph}; - laplacian_matrix_t L{handle, csr_m}; + spectral::matrix::laplacian_matrix_t L{handle, csr_m}; // Initialize output cost = 0; diff --git a/cpp/include/raft/spectral/detail/spectral_util.cuh b/cpp/include/raft/spectral/detail/spectral_util.cuh index 8fa096b26b..bb8e94b764 100644 --- a/cpp/include/raft/spectral/detail/spectral_util.cuh +++ b/cpp/include/raft/spectral/detail/spectral_util.cuh @@ -160,7 +160,7 @@ void transform_eigen_matrix(handle_t const& handle, edge_t n, vertex_t nEigVecs, // Transpose eigenvector matrix // TODO: in-place transpose { - vector_t work(handle, nEigVecs * n); + raft::spectral::matrix::vector_t work(handle, nEigVecs * n); // TODO: Call from public API when ready RAFT_CUBLAS_TRY( raft::linalg::detail::cublassetpointermode(cublas_h, CUBLAS_POINTER_MODE_HOST, stream)); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3eead4d494..3326a0691c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -16,6 +16,7 @@ # keep the files in alphabetical order! add_executable(test_raft + test/cluster/kmeans.cu test/common/logger.cpp test/common/seive.cu test/cudart_utils.cpp @@ -71,6 +72,7 @@ add_executable(test_raft test/linalg/ternary_op.cu test/linalg/transpose.cu test/linalg/unary_op.cu + test/matrix/gather.cu test/matrix/math.cu test/matrix/matrix.cu test/matrix/columnSort.cu diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu new file mode 100644 index 0000000000..f54484c9ba --- /dev/null +++ b/cpp/test/cluster/kmeans.cu @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +template +struct KmeansInputs { + int n_row; + int n_col; + int n_clusters; + T tol; + bool weighted; +}; + +template +class KmeansTest : public ::testing::TestWithParam> { + protected: + KmeansTest() + : stream(handle.get_stream()), + d_labels(0, stream), + d_labels_ref(0, stream), + d_centroids(0, stream), + d_sample_weight(0, stream) + { + } + + void basicTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 5; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + auto X = raft::make_device_matrix(n_samples, n_features, stream); + auto labels = raft::make_device_vector(n_samples, stream); + + raft::random::make_blobs(X.data(), + labels.data(), + n_samples, + n_features, + params.n_clusters, + stream, + true, + nullptr, + nullptr, + T(1.0), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + + std::optional> d_sw = std::nullopt; + auto d_centroids_view = + 
raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + d_sw = std::make_optional( + raft::make_device_vector_view(d_sample_weight.data(), n_samples)); + thrust::fill(thrust::cuda::par.on(stream), + d_sample_weight.data(), + d_sample_weight.data() + n_samples, + 1); + } + + raft::copy(d_labels_ref.data(), labels.data(), n_samples, stream); + handle.sync_stream(stream); + + T inertia = 0; + int n_iter = 0; + auto X_view = (raft::device_matrix_view)X.view(); + + raft::cluster::kmeans_fit_predict( + handle, + params, + X_view, + d_sw, + d_centroids_view, + raft::make_device_vector_view(d_labels.data(), n_samples), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + + handle.sync_stream(stream); + + score = raft::stats::adjusted_rand_index( + d_labels_ref.data(), d_labels.data(), n_samples, handle.get_stream()); + + if (score < 1.0) { + std::stringstream ss; + ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); + std::cout << (ss.str().c_str()) << '\n'; + ss.str(std::string()); + ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); + std::cout << (ss.str().c_str()) << '\n'; + std::cout << "Score = " << score << '\n'; + } + } + + void SetUp() override { basicTest(); } + + protected: + raft::handle_t handle; + cudaStream_t stream; + KmeansInputs testparams; + rmm::device_uvector d_labels; + rmm::device_uvector d_labels_ref; + rmm::device_uvector d_centroids; + rmm::device_uvector d_sample_weight; + double score; + raft::cluster::KMeansParams params; +}; + +const std::vector> inputsf2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}, + {10000, 1000, 200, 0.0001, true}, + 
{10000, 1000, 200, 0.0001, false}}; + +const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}, + {10000, 1000, 200, 0.0001, true}, + {10000, 1000, 200, 0.0001, false}}; + +typedef KmeansTest KmeansTestF; +TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); } + +typedef KmeansTest KmeansTestD; +TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score == 1.0); } + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); + +} // namespace raft diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu new file mode 100644 index 0000000000..2baeb81881 --- /dev/null +++ b/cpp/test/matrix/gather.cu @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +template +void naiveGatherImpl( + MatrixIteratorT in, int D, int N, MapIteratorT map, int map_length, MatrixIteratorT out) +{ + for (int outRow = 0; outRow < map_length; ++outRow) { + typename std::iterator_traits::value_type map_val = map[outRow]; + int inRowStart = map_val * D; + int outRowStart = outRow * D; + for (int i = 0; i < D; ++i) { + out[outRowStart + i] = in[inRowStart + i]; + } + } +} + +template +void naiveGather( + MatrixIteratorT in, int D, int N, MapIteratorT map, int map_length, MatrixIteratorT out) +{ + naiveGatherImpl(in, D, N, map, map_length, out); +} + +template +void gatherLaunch(MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + int map_length, + MatrixIteratorT out, + cudaStream_t stream) +{ + typedef typename std::iterator_traits::value_type MapValueT; + matrix::gather(in, D, N, map, map_length, out, stream); +} + +struct GatherInputs { + uint32_t nrows; + uint32_t ncols; + uint32_t map_length; + unsigned long long int seed; +}; + +template +class GatherTest : public ::testing::TestWithParam { + protected: + GatherTest() + : stream(handle.get_stream()), + params(::testing::TestWithParam::GetParam()), + d_in(0, stream), + d_out_exp(0, stream), + d_out_act(0, stream), + d_map(0, stream) + { + } + + void SetUp() override + { + raft::random::RngState r(params.seed); + raft::random::RngState r_int(params.seed); + + uint32_t nrows = params.nrows; + uint32_t ncols = params.ncols; + uint32_t map_length = params.map_length; + uint32_t len = nrows * ncols; + + // input matrix setup + d_in.resize(nrows * ncols, stream); + h_in.resize(nrows * ncols); + raft::random::uniform(handle, r, d_in.data(), len, MatrixT(-1.0), MatrixT(1.0)); + raft::update_host(h_in.data(), d_in.data(), len, stream); + + // map setup + d_map.resize(map_length, stream); + h_map.resize(map_length); + raft::random::uniformInt(handle, r_int, d_map.data(), map_length, 
(MapT)0, nrows); + raft::update_host(h_map.data(), d_map.data(), map_length, stream); + + // expected and actual output matrix setup + h_out.resize(map_length * ncols); + d_out_exp.resize(map_length * ncols, stream); + d_out_act.resize(map_length * ncols, stream); + + // launch gather on the host and copy the results to device + naiveGather(h_in.data(), ncols, nrows, h_map.data(), map_length, h_out.data()); + raft::update_device(d_out_exp.data(), h_out.data(), map_length * ncols, stream); + + // launch device version of the kernel + gatherLaunch(d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); + + handle.sync_stream(stream); + } + + protected: + raft::handle_t handle; + cudaStream_t stream = 0; + GatherInputs params; + std::vector h_in, h_out; + std::vector h_map; + rmm::device_uvector d_in, d_out_exp, d_out_act; + rmm::device_uvector d_map; +}; + +const std::vector inputs = {{1024, 32, 128, 1234ULL}, + {1024, 32, 256, 1234ULL}, + {1024, 32, 512, 1234ULL}, + {1024, 32, 1024, 1234ULL}, + {1024, 64, 128, 1234ULL}, + {1024, 64, 256, 1234ULL}, + {1024, 64, 512, 1234ULL}, + {1024, 64, 1024, 1234ULL}, + {1024, 128, 128, 1234ULL}, + {1024, 128, 256, 1234ULL}, + {1024, 128, 512, 1234ULL}, + {1024, 128, 1024, 1234ULL}}; + +typedef GatherTest GatherTestF; +TEST_P(GatherTestF, Result) +{ + ASSERT_TRUE(devArrMatch( + d_out_exp.data(), d_out_act.data(), params.map_length * params.ncols, raft::Compare())); +} + +typedef GatherTest GatherTestD; +TEST_P(GatherTestD, Result) +{ + ASSERT_TRUE(devArrMatch( + d_out_exp.data(), d_out_act.data(), params.map_length * params.ncols, raft::Compare())); +} + +INSTANTIATE_TEST_CASE_P(GatherTests, GatherTestF, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(GatherTests, GatherTestD, ::testing::ValuesIn(inputs)); + +} // end namespace raft \ No newline at end of file