Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace k-means++ CPU bottleneck with a random::discrete prim #1039

Merged
merged 6 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <raft/linalg/norm.cuh>
#include <raft/linalg/reduce_cols_by_key.cuh>
#include <raft/linalg/reduce_rows_by_key.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_scalar.hpp>
Expand Down Expand Up @@ -109,7 +110,7 @@ void kmeansPlusPlus(const raft::handle_t& handle,
auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples);

// temporary buffers
std::vector<DataT> h_wt(n_samples);
auto indices = raft::make_device_vector<IndexT, IndexT>(handle, n_trials);
auto centroidCandidates = raft::make_device_matrix<DataT, IndexT>(handle, n_trials, n_features);
auto costPerCandidate = raft::make_device_vector<DataT, IndexT>(handle, n_trials);
auto minClusterDistance = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
Expand All @@ -119,6 +120,17 @@ void kmeansPlusPlus(const raft::handle_t& handle,
rmm::device_scalar<DataT> clusterCost(stream);
rmm::device_scalar<cub::KeyValuePair<int, DataT>> minClusterIndexAndDistance(stream);

// Device and matrix views
raft::device_vector_view<IndexT, IndexT> indices_view(indices.data_handle(), n_trials);
auto const_weights_view =
raft::make_device_vector_view<const DataT, IndexT>(minClusterDistance.data_handle(), n_samples);
auto const_indices_view =
raft::make_device_vector_view<const IndexT, IndexT>(indices.data_handle(), n_trials);
auto const_X_view =
raft::make_device_matrix_view<const DataT, IndexT>(X.data_handle(), n_samples, n_features);
raft::device_matrix_view<DataT, IndexT> candidates_view(
centroidCandidates.data_handle(), n_trials, n_features);

// L2 norm of X: ||c||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);

Expand All @@ -133,6 +145,7 @@ void kmeansPlusPlus(const raft::handle_t& handle,
stream);
}

raft::random::RngState rng(params.rng_state.seed, params.rng_state.type);
std::mt19937 gen(params.rng_state.seed);
std::uniform_int_distribution<> dis(0, n_samples - 1);

Expand Down Expand Up @@ -169,20 +182,9 @@ void kmeansPlusPlus(const raft::handle_t& handle,
// <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C)
// Choose 'n_trials' centroid candidates from X with probability proportional to the squared
// distance to the nearest existing cluster
raft::copy(h_wt.data(), minClusterDistance.data_handle(), minClusterDistance.size(), stream);
handle.sync_stream(stream);

// Note - n_trials is relatively small here, so we don't need a raft::gather call
std::discrete_distribution<> d(h_wt.begin(), h_wt.end());
for (int cIdx = 0; cIdx < n_trials; ++cIdx) {
auto rand_idx = d(gen);
auto randCentroid = raft::make_device_matrix_view<const DataT, IndexT>(
X.data_handle() + n_features * rand_idx, 1, n_features);
raft::copy(centroidCandidates.data_handle() + cIdx * n_features,
randCentroid.data_handle(),
randCentroid.size(),
stream);
}
raft::random::discrete(handle, rng, indices_view, const_weights_view);
raft::matrix::gather(handle, const_X_view, const_indices_view, candidates_view);

// Calculate pairwise distance between X and the centroid candidates
// Output - pwd [n_trials x n_samples]
Expand Down
34 changes: 34 additions & 0 deletions cpp/include/raft/random/detail/rng_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,40 @@ __global__ void rngKernel(DeviceState<GenType> rng_state,
return;
}

/**
 * @brief Sample indices with replacement from a table of cumulative weights.
 *
 * Every thread whose global id is below `sampledLen` draws one random value, scales it by the
 * last entry of `weights_csum` (used as the total weight) and binary-searches `weights_csum`
 * for the first entry that is greater than or equal to the scaled value. The resulting index,
 * clamped to `len - 1`, is written to `out[tid]`.
 *
 * Launch layout: 1D grid of 1D blocks providing at least `sampledLen` threads in total; no
 * shared memory is used. Threads beyond `sampledLen` still construct a generator but write
 * nothing.
 *
 * @tparam GenType    device-side random generator type
 * @tparam OutType    type of the sampled indices
 * @tparam WeightType type of the cumulative weights (also the type of the random draw)
 * @tparam IdxType    type used for lengths and thread indexing
 *
 * @param[in]  rng_state    generator state used to build one generator per thread
 * @param[out] out          sampled indices, at least `sampledLen` elements
 * @param[in]  weights_csum cumulative sums of the weights, `len` elements; the binary search
 *                          assumes the values are non-decreasing
 * @param[in]  sampledLen   number of indices to sample
 * @param[in]  len          number of weights; nothing is written when it is not positive
 */
template <typename GenType, typename OutType, typename WeightType, typename IdxType>
__global__ void sample_with_replacement_kernel(DeviceState<GenType> rng_state,
                                               OutType* out,
                                               const WeightType* weights_csum,
                                               IdxType sampledLen,
                                               IdxType len)
{
  // todo(lsugy): warp-collaborative binary search

  IdxType tid = threadIdx.x + static_cast<IdxType>(blockIdx.x) * blockDim.x;
  GenType gen(rng_state, static_cast<uint64_t>(tid));

  // `len > 0` protects the weights_csum[len - 1] read below against an empty weight table.
  if (tid < sampledLen && len > 0) {
    // NOTE(review): val_01 is presumably uniform over the unit interval; whether the bounds
    // are inclusive depends on GenType - confirm before relying on it.
    WeightType val_01;
    gen.next(val_01);
    WeightType val_search = val_01 * weights_csum[len - 1];

    // Binary search of the first index for which the cumulative sum of weights is greater than
    // or equal to the generated value
    IdxType idx_start = 0;
    IdxType idx_end   = len;
    while (idx_end > idx_start) {
      // Written as start + half-range so the midpoint cannot overflow IdxType for large `len`.
      IdxType idx_middle    = idx_start + (idx_end - idx_start) / 2;
      WeightType val_middle = weights_csum[idx_middle];
      if (val_search <= val_middle) {
        idx_end = idx_middle;
      } else {
        idx_start = idx_middle + 1;
      }
    }
    // The search returns `len` when the value exceeds every cumulative sum (e.g. because of
    // rounding), so clamp to the last valid index.
    out[tid] = static_cast<OutType>(min(idx_start, len - 1));
  }
}

/**
* This kernel is deprecated and should be removed in a future release
*/
Expand Down
43 changes: 43 additions & 0 deletions cpp/include/raft/random/detail/rng_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,49 @@ void laplace(
RAFT_CALL_RNG_FUNC(rng_state, call_rng_kernel<1>, rng_state, stream, ptr, len, params);
}

/**
 * @brief Launch `sample_with_replacement_kernel` on `stream` and advance the host RNG state.
 *
 * Uses blocks of 256 threads and enough blocks to cover `sampledLen` samples. When no sample
 * is requested the launch is skipped, because a grid with zero blocks is not a valid launch
 * configuration.
 *
 * @param[in]    dev_state    device-side generator state forwarded to the kernel
 * @param[inout] rng_state    host-side RNG state, advanced by the number of threads launched
 * @param[in]    stream       CUDA stream the kernel is enqueued on
 * @param[out]   out          device pointer receiving `sampledLen` sampled indices
 * @param[in]    weights_csum device pointer to the cumulative sums of the weights (`len` values)
 * @param[in]    sampledLen   number of indices to sample
 * @param[in]    len          number of weights
 */
template <typename GenType, typename OutType, typename WeightType, typename IdxType>
void call_sample_with_replacement_kernel(DeviceState<GenType> const& dev_state,
                                         RngState& rng_state,
                                         cudaStream_t stream,
                                         OutType* out,
                                         const WeightType* weights_csum,
                                         IdxType sampledLen,
                                         IdxType len)
{
  IdxType n_threads = 256;
  IdxType n_blocks  = raft::ceildiv(sampledLen, n_threads);
  // Only launch when there is work: <<<0, ...>>> fails with an invalid-configuration error.
  if (n_blocks > 0) {
    sample_with_replacement_kernel<<<n_blocks, n_threads, 0, stream>>>(
      dev_state, out, weights_csum, sampledLen, len);
  }
  // Each launched thread builds its own generator from its thread id, so the state is advanced
  // by the total thread count (presumably one subsequence per thread, one number each -
  // confirm against RngState::advance).
  rng_state.advance(uint64_t(n_blocks) * n_threads, 1);
}

/**
 * @brief Sample `sampledLen` integers with replacement, index `i` being drawn with probability
 *        proportional to `weights[i]`.
 *
 * The weights are first turned into an inclusive cumulative sum with CUB, then the sampling
 * kernel is dispatched through `RAFT_CALL_RNG_FUNC`. All device work, including the scan, is
 * enqueued on `stream`.
 *
 * @tparam OutType    integer type of the sampled indices
 * @tparam WeightType type of the weights
 * @tparam IndexType  type used to represent the lengths of the arrays
 *
 * @param[inout] rng_state  random number generator state
 * @param[out]   ptr        device pointer receiving `sampledLen` sampled indices
 * @param[in]    weights    device pointer to `len` weights
 * @param[in]    sampledLen number of indices to sample
 * @param[in]    len        number of weights; when it is not positive there is no valid index
 *                          to draw and the function returns without touching `ptr`
 * @param[in]    stream     CUDA stream used for the allocations, the scan and the kernel
 */
template <typename OutType, typename WeightType, typename IndexType = OutType>
std::enable_if_t<std::is_integral_v<OutType>> discrete(RngState& rng_state,
                                                       OutType* ptr,
                                                       const WeightType* weights,
                                                       IndexType sampledLen,
                                                       IndexType len,
                                                       cudaStream_t stream)
{
  // With an empty weight table the sampling kernel would read weights_csum[len - 1] out of
  // bounds, so there is nothing meaningful to do.
  if (len <= 0) { return; }

  // Compute the cumulative sums of the weights.
  // The stream must be passed explicitly: CUB defaults to the legacy default stream, which
  // would not be ordered with the stream-ordered buffers below nor with the sampling kernel.
  size_t temp_storage_bytes = 0;
  rmm::device_uvector<WeightType> weights_csum(len, stream);
  // First call only queries the amount of temporary storage required.
  cub::DeviceScan::InclusiveSum(
    nullptr, temp_storage_bytes, weights, weights_csum.data(), len, stream);
  rmm::device_uvector<uint8_t> temp_storage(temp_storage_bytes, stream);
  cub::DeviceScan::InclusiveSum(
    temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len, stream);

  // Sample indices with replacement
  RAFT_CALL_RNG_FUNC(rng_state,
                     call_sample_with_replacement_kernel,
                     rng_state,
                     stream,
                     ptr,
                     weights_csum.data(),
                     sampledLen,
                     len);
}

template <typename DataT, typename WeightsT, typename IdxT = int>
void sampleWithoutReplacement(RngState& rng_state,
DataT* out,
Expand Down
40 changes: 40 additions & 0 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,46 @@ void laplace(const raft::handle_t& handle,
detail::laplace(rng_state, ptr, len, mu, scale, handle.get_stream());
}

/**
 * @brief Generate random integers, where the probability of i is weights[i]/sum(weights)
 *
 * The number of integers generated is `out.extent(0)` and the number of candidate values is
 * `weights.extent(0)`. The work is enqueued on the stream of `handle`.
 *
 * Usage example:
 * @code{.cpp}
 *  #include <raft/core/device_mdarray.hpp>
 *  #include <raft/core/handle.hpp>
 *  #include <raft/random/rng.cuh>
 *
 *  raft::handle_t handle;
 *  ...
 *  raft::random::RngState rng(seed);
 *  auto indices = raft::make_device_vector<int>(handle, n_samples);
 *  raft::random::discrete(handle, rng, indices.view(), weights);
 * @endcode
 *
 * @tparam OutType integer output type
 * @tparam WeightType weight type
 * @tparam IndexType data type used to represent length of the arrays
 *
 * @param[in] handle raft handle for resource management
 * @param[in] rng_state random number generator state
 * @param[out] out output array
 * @param[in] weights weight array
 */
template <typename OutType, typename WeightType, typename IndexType>
std::enable_if_t<std::is_integral_v<OutType>> discrete(
  const raft::handle_t& handle,
  RngState& rng_state,
  raft::device_vector_view<OutType, IndexType> out,
  raft::device_vector_view<const WeightType, IndexType> weights)
{
  // Unpack the views into the raw pointers and lengths expected by the detail implementation.
  auto* sampled_ptr      = out.data_handle();
  auto* weights_ptr      = weights.data_handle();
  auto const n_sampled   = out.extent(0);
  auto const n_weights   = weights.extent(0);
  auto const exec_stream = handle.get_stream();

  detail::discrete(rng_state, sampled_ptr, weights_ptr, n_sampled, n_weights, exec_stream);
}

namespace sample_without_replacement_impl {
template <typename T>
struct weight_alias {
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ if(BUILD_TESTS)
test/random/multi_variable_gaussian.cu
test/random/permute.cu
test/random/rng.cu
test/random/rng_discrete.cu
test/random/rng_int.cu
test/random/rmat_rectangular_generator.cu
test/random/sample_without_replacement.cu
Expand Down
Loading