Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace k-means++ CPU bottleneck with a random::discrete prim #1039

Merged
merged 6 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <raft/linalg/norm.cuh>
#include <raft/linalg/reduce_cols_by_key.cuh>
#include <raft/linalg/reduce_rows_by_key.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_scalar.hpp>
Expand Down Expand Up @@ -109,7 +110,7 @@ void kmeansPlusPlus(const raft::handle_t& handle,
auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples);

// temporary buffers
std::vector<DataT> h_wt(n_samples);
auto indices = raft::make_device_vector<IndexT, IndexT>(handle, n_trials);
auto centroidCandidates = raft::make_device_matrix<DataT, IndexT>(handle, n_trials, n_features);
auto costPerCandidate = raft::make_device_vector<DataT, IndexT>(handle, n_trials);
auto minClusterDistance = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
Expand All @@ -119,6 +120,17 @@ void kmeansPlusPlus(const raft::handle_t& handle,
rmm::device_scalar<DataT> clusterCost(stream);
rmm::device_scalar<cub::KeyValuePair<int, DataT>> minClusterIndexAndDistance(stream);

// Device and matrix views
raft::device_vector_view<IndexT, IndexT> indices_view(indices.data_handle(), n_trials);
auto const_weights_view =
raft::make_device_vector_view<const DataT, IndexT>(minClusterDistance.data_handle(), n_samples);
auto const_indices_view =
raft::make_device_vector_view<const IndexT, IndexT>(indices.data_handle(), n_trials);
auto const_X_view =
raft::make_device_matrix_view<const DataT, IndexT>(X.data_handle(), n_samples, n_features);
raft::device_matrix_view<DataT, IndexT> candidates_view(
centroidCandidates.data_handle(), n_trials, n_features);

// L2 norm of X: ||c||^2
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);

Expand All @@ -133,6 +145,7 @@ void kmeansPlusPlus(const raft::handle_t& handle,
stream);
}

raft::random::RngState rng(params.rng_state.seed, params.rng_state.type);
std::mt19937 gen(params.rng_state.seed);
std::uniform_int_distribution<> dis(0, n_samples - 1);

Expand Down Expand Up @@ -169,20 +182,9 @@ void kmeansPlusPlus(const raft::handle_t& handle,
// <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C)
// Choose 'n_trials' centroid candidates from X with probability proportional to the squared
// distance to the nearest existing cluster
raft::copy(h_wt.data(), minClusterDistance.data_handle(), minClusterDistance.size(), stream);
handle.sync_stream(stream);

// Note - n_trials is relative small here, we don't need raft::gather call
std::discrete_distribution<> d(h_wt.begin(), h_wt.end());
for (int cIdx = 0; cIdx < n_trials; ++cIdx) {
auto rand_idx = d(gen);
auto randCentroid = raft::make_device_matrix_view<const DataT, IndexT>(
X.data_handle() + n_features * rand_idx, 1, n_features);
raft::copy(centroidCandidates.data_handle() + cIdx * n_features,
randCentroid.data_handle(),
randCentroid.size(),
stream);
}
raft::random::discrete(handle, rng, indices_view, const_weights_view);
raft::matrix::gather(handle, const_X_view, const_indices_view, candidates_view);

// Calculate pairwise distance between X and the centroid candidates
// Output - pwd [n_trials x n_samples]
Expand Down
34 changes: 34 additions & 0 deletions cpp/include/raft/random/detail/rng_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,40 @@ __global__ void rngKernel(DeviceState<GenType> rng_state,
return;
}

/**
 * @brief Draw indices with replacement from the discrete distribution described by the inclusive
 *        cumulative sum of its weights.
 *
 * Every thread whose global index is below sampledLen draws one value from a generator
 * constructed from rng_state and that index, scales it by the total weight
 * (weights_csum[len - 1]) and binary-searches weights_csum for the first entry that is greater
 * than or equal to the scaled value. The resulting index, clamped to len - 1, is written to
 * out[tid].
 *
 * Launch layout: 1D grid of 1D blocks (only threadIdx.x / blockIdx.x / blockDim.x are used),
 * one thread per output element; threads with tid >= sampledLen write nothing.
 * Precondition: len >= 1, because weights_csum[len - 1] is read unconditionally by active threads.
 *
 * @tparam GenType    device-side generator type
 * @tparam OutType    type of the sampled indices
 * @tparam WeightType type of the cumulative weights (also the type of the drawn random value)
 * @tparam IdxType    type used for lengths and thread indexing
 *
 * @param[in]  rng_state    generator state passed to the per-thread generator
 * @param[out] out          sampled indices, at least sampledLen elements
 * @param[in]  weights_csum inclusive cumulative sum of the weights, len elements
 * @param[in]  sampledLen   number of indices to draw
 * @param[in]  len          number of weights
 */
template <typename GenType, typename OutType, typename WeightType, typename IdxType>
__global__ void sample_with_replacement_kernel(DeviceState<GenType> rng_state,
                                               OutType* out,
                                               const WeightType* weights_csum,
                                               IdxType sampledLen,
                                               IdxType len)
{
  // todo(lsugy): warp-collaborative binary search

  IdxType tid = threadIdx.x + static_cast<IdxType>(blockIdx.x) * blockDim.x;
  GenType gen(rng_state, static_cast<uint64_t>(tid));

  if (tid < sampledLen) {
    WeightType val_01;
    gen.next(val_01);
    // Scale the drawn value by the total weight (last entry of the inclusive sum)
    WeightType val_search = val_01 * weights_csum[len - 1];

    // Binary search of the first index for which the cumulative sum of weights is larger than or
    // equal to the generated value
    IdxType idx_start = 0;
    IdxType idx_end   = len;
    while (idx_end > idx_start) {
      // Overflow-safe midpoint: idx_start + idx_end may exceed the range of IdxType
      IdxType idx_middle    = idx_start + (idx_end - idx_start) / 2;
      WeightType val_middle = weights_csum[idx_middle];
      if (val_search <= val_middle) {
        idx_end = idx_middle;
      } else {
        idx_start = idx_middle + 1;
      }
    }
    // Clamp in case the search value ended up above every cumulative sum
    out[tid] = static_cast<OutType>(min(idx_start, len - 1));
  }
}

/**
* This kernel is deprecated and should be removed in a future release
*/
Expand Down
43 changes: 43 additions & 0 deletions cpp/include/raft/random/detail/rng_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,49 @@ void laplace(
RAFT_CALL_RNG_FUNC(rng_state, call_rng_kernel<1>, rng_state, stream, ptr, len, params);
}

/**
 * @brief Launch sample_with_replacement_kernel on the given stream and advance the host RNG state.
 *
 * Uses 256 threads per block and ceildiv(sampledLen, 256) blocks, i.e. one thread per sampled
 * index. The host-side rng_state is then advanced by the number of launched threads.
 *
 * @param[in]    dev_state    device generator state forwarded to the kernel
 * @param[inout] rng_state    host generator state, advanced after the launch
 * @param[in]    stream       CUDA stream the kernel is enqueued on
 * @param[out]   out          device array receiving sampledLen indices
 * @param[in]    weights_csum device array with the inclusive cumulative sum of the weights
 * @param[in]    sampledLen   number of indices to draw
 * @param[in]    len          number of weights
 */
template <typename GenType, typename OutType, typename WeightType, typename IdxType>
void call_sample_with_replacement_kernel(DeviceState<GenType> const& dev_state,
                                         RngState& rng_state,
                                         cudaStream_t stream,
                                         OutType* out,
                                         const WeightType* weights_csum,
                                         IdxType sampledLen,
                                         IdxType len)
{
  IdxType n_threads = 256;
  IdxType n_blocks  = raft::ceildiv(sampledLen, n_threads);
  // A grid of zero blocks is an invalid launch configuration; when there is nothing to sample the
  // launch is skipped instead of leaving a CUDA error behind.
  if (n_blocks > 0) {
    sample_with_replacement_kernel<<<n_blocks, n_threads, 0, stream>>>(
      dev_state, out, weights_csum, sampledLen, len);
  }
  rng_state.advance(uint64_t(n_blocks) * n_threads, 1);
}

/**
 * @brief Sample sampledLen integers in [0, len) where index i is drawn with probability
 *        proportional to weights[i].
 *
 * The weights are turned into an inclusive cumulative sum in a temporary device buffer, which is
 * then handed to call_sample_with_replacement_kernel through RAFT_CALL_RNG_FUNC. All device work
 * (allocations, scan, sampling kernel) is ordered on `stream`.
 *
 * @tparam OutType    integral type of the sampled indices
 * @tparam WeightType type of the weights
 * @tparam IndexType  type used for the array lengths
 *
 * @param[inout] rng_state  random number generator state
 * @param[out]   ptr        device array receiving sampledLen indices
 * @param[in]    weights    device array of len weights
 * @param[in]    sampledLen number of indices to draw
 * @param[in]    len        number of weights
 * @param[in]    stream     CUDA stream to order the work on
 */
template <typename OutType, typename WeightType, typename IndexType = OutType>
std::enable_if_t<std::is_integral_v<OutType>> discrete(RngState& rng_state,
                                                       OutType* ptr,
                                                       const WeightType* weights,
                                                       IndexType sampledLen,
                                                       IndexType len,
                                                       cudaStream_t stream)
{
  // Compute the cumulative sums of the weights
  size_t temp_storage_bytes = 0;
  rmm::device_uvector<WeightType> weights_csum(len, stream);
  // First call (null workspace) only queries the required temporary storage size.
  // The stream is passed explicitly so the scan is ordered with the stream-ordered allocations
  // above and with the sampling kernel below, rather than running on the default stream.
  cub::DeviceScan::InclusiveSum(
    nullptr, temp_storage_bytes, weights, weights_csum.data(), len, stream);
  rmm::device_uvector<uint8_t> temp_storage(temp_storage_bytes, stream);
  cub::DeviceScan::InclusiveSum(
    temp_storage.data(), temp_storage_bytes, weights, weights_csum.data(), len, stream);

  // Sample indices with replacement
  RAFT_CALL_RNG_FUNC(rng_state,
                     call_sample_with_replacement_kernel,
                     rng_state,
                     stream,
                     ptr,
                     weights_csum.data(),
                     sampledLen,
                     len);
}

template <typename DataT, typename WeightsT, typename IdxT = int>
void sampleWithoutReplacement(RngState& rng_state,
DataT* out,
Expand Down
27 changes: 27 additions & 0 deletions cpp/include/raft/random/rng.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,33 @@ void laplace(const raft::handle_t& handle,
detail::laplace(rng_state, ptr, len, mu, scale, handle.get_stream());
}

/**
 * @brief Generate random integers, where the probability of i is weights[i]/sum(weights)
 *
 * The number of integers generated is the extent of `out`; the number of candidate values is
 * the extent of `weights`. Work is issued on the stream of `handle`.
 *
 * @tparam OutType integer output type
 * @tparam WeightType weight type
 * @tparam IndexType data type used to represent length of the arrays
 *
 * @param[in] handle raft handle for resource management
 * @param[in] rng_state random number generator state
 * @param[out] out output array
 * @param[in] weights weight array
 */
template <typename OutType, typename WeightType, typename IndexType>
std::enable_if_t<std::is_integral_v<OutType>> discrete(
  const raft::handle_t& handle,
  RngState& rng_state,
  raft::device_vector_view<OutType, IndexType> out,
  raft::device_vector_view<const WeightType, IndexType> weights)
{
  // Unpack the views into raw pointers and lengths for the detail implementation
  auto* sampled_ptr    = out.data_handle();
  auto* weights_ptr    = weights.data_handle();
  const auto n_sampled = out.extent(0);
  const auto n_weights = weights.extent(0);

  detail::discrete(
    rng_state, sampled_ptr, weights_ptr, n_sampled, n_weights, handle.get_stream());
}

namespace sample_without_replacement_impl {
template <typename T>
struct weight_alias {
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ if(BUILD_TESTS)
test/random/multi_variable_gaussian.cu
test/random/permute.cu
test/random/rng.cu
test/random/rng_discrete.cu
test/random/rng_int.cu
test/random/rmat_rectangular_generator.cu
test/random/sample_without_replacement.cu
Expand Down
160 changes: 160 additions & 0 deletions cpp/test/random/rng_discrete.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <raft/linalg/unary_op.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <vector>

namespace raft {
namespace random {

/* In this test we generate pseudo-random values following a probability distribution defined by
 * the given weights. If the probability of i is p=w_i/sum(w), the expected value of the
 * normalized histogram is E=p and its standard deviation is sigma=sqrt(p*(1-p)/n).
 * As the test tolerance we use eps=4*sigma(p,n), where p=min(w_i/sum(w)) and n=sampledLen.
 */

// Parameters of one discrete-sampling test case.
// NOTE: the test inputs are brace-initialized, so the field order below must not change.
template <typename WeightT, typename IdxT>
struct RngDiscreteInputs {
// Absolute tolerance used when comparing the measured and expected normalized histograms
float tolerance;
// Number of values to sample
IdxT sampledLen;
// Host-side weights; index i is expected with frequency weights[i] / sum(weights)
std::vector<WeightT> weights;
// Generator type passed to RngState
GeneratorType gtype;
// Seed passed to RngState
unsigned long long int seed;
};

// Streams a test-case description as "{sampledLen, {w0,w1,...}}" (used by gtest when reporting
// parameterized test values).
template <typename WeightT, typename IdxT>
::std::ostream& operator<<(::std::ostream& os, const RngDiscreteInputs<WeightT, IdxT>& d)
{
  // Make the std::vector printer from namespace raft visible for d.weights
  using raft::operator<<;
  os << "{" << d.sampledLen;
  os << ", " << d.weights;
  os << "}";
  return os;
}

// Builds the normalized histogram of a sequence of device labels: histogram[i] receives the
// fraction of the sampledLen labels that fall into bin i, for i in [0, len).
// Both `labels` and `histogram` are device pointers; all work is enqueued on `stream`.
template <typename LabelT, typename IdxT>
void compute_normalized_histogram(
  const LabelT* labels, float* histogram, IdxT sampledLen, IdxT len, const cudaStream_t& stream)
{
  // len evenly spaced bins over [0, len) need len + 1 bin boundaries
  IdxT n_boundaries = len + 1;
  IdxT range_begin  = 0;
  IdxT range_end    = len;

  rmm::device_uvector<IdxT> bin_counts(len, stream);

  // CUB two-phase pattern: a null workspace only reports the required size, the second call does
  // the actual counting.
  auto run_histogram = [&](void* d_workspace, size_t& workspace_bytes) {
    RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(d_workspace,
                                                      workspace_bytes,
                                                      labels,
                                                      bin_counts.data(),
                                                      n_boundaries,
                                                      range_begin,
                                                      range_end,
                                                      sampledLen,
                                                      stream));
  };

  size_t workspace_bytes = 0;
  run_histogram(nullptr, workspace_bytes);

  rmm::device_uvector<char> workspace(workspace_bytes, stream);
  run_histogram(workspace.data(), workspace_bytes);

  // Turn raw counts into frequencies
  float n_samples_f = static_cast<float>(sampledLen);
  raft::linalg::unaryOp(
    histogram,
    bin_counts.data(),
    len,
    [n_samples_f] __device__(const IdxT& c) { return static_cast<float>(c) / n_samples_f; },
    stream);
}

/**
 * Parameterized fixture for raft::random::discrete.
 *
 * SetUp() copies the host weights to the device, samples params.sampledLen indices with
 * discrete(), and prepares two normalized histograms for the test body to compare:
 *  - `histogram` (device): measured frequency of every index in the sampled output
 *  - `exp_histogram` (host): expected frequency weights[i] / sum(weights)
 *
 * @tparam OutT    type of the sampled indices
 * @tparam WeightT type of the weights
 * @tparam IdxT    type used for lengths
 */
template <typename OutT, typename WeightT, typename IdxT>
class RngDiscreteTest : public ::testing::TestWithParam<RngDiscreteInputs<WeightT, IdxT>> {
 public:
  // The initializer list follows the member declaration order (which is the order in which the
  // members are actually initialized): `stream` and `params` must be ready before the device
  // buffers that are sized from `params` and allocated on `stream`.
  RngDiscreteTest()
    : stream(handle.get_stream()),
      params(::testing::TestWithParam<RngDiscreteInputs<WeightT, IdxT>>::GetParam()),
      out(params.sampledLen, stream),
      weights(params.weights.size(), stream),
      histogram(params.weights.size(), stream),
      exp_histogram(params.weights.size())
  {
  }

 protected:
  void SetUp() override
  {
    IdxT len = params.weights.size();

    // Upload the weights of this test case
    raft::copy(weights.data(), params.weights.data(), len, stream);

    RngState r(params.seed, params.gtype);
    raft::device_vector_view<OutT, IdxT> out_view(out.data(), out.size());
    auto weights_view =
      raft::make_device_vector_view<const WeightT, IdxT>(weights.data(), weights.size());

    discrete(handle, r, out_view, weights_view);

    // Compute the actual and expected normalized histogram of the values
    float total_weight = 0.0f;
    for (IdxT i = 0; i < len; i++) {
      total_weight += params.weights[i];
    }
    for (IdxT i = 0; i < len; i++) {
      exp_histogram[i] = params.weights[i] / total_weight;
    }
    compute_normalized_histogram(out.data(), histogram.data(), params.sampledLen, len, stream);
  }

 protected:
  raft::handle_t handle;
  cudaStream_t stream;

  RngDiscreteInputs<WeightT, IdxT> params;
  rmm::device_uvector<OutT> out;          // sampled indices (device)
  rmm::device_uvector<WeightT> weights;   // weights (device)
  rmm::device_uvector<float> histogram;   // measured normalized histogram (device)
  std::vector<float> exp_histogram;       // expected normalized histogram (host)
};

// Test cases for uint32_t outputs with float weights and int lengths.
// Each entry is {tolerance, sampledLen, weights, generator type, seed}.
const std::vector<RngDiscreteInputs<float, int>> inputs_u32 = {
{0.016f, 10000, {1.f, 2.f, 3.f, 4.f}, GenPhilox, 1234ULL},
// Includes zero-weight entries, whose expected frequency is 0
{0.01f, 10000, {0.5f, 0.3f, 0.3f, 0.f, 0.f, 0.f, 1.5f, 2.0f}, GenPhilox, 1234ULL},

// Same weights as the first case, with the GenPC generator
{0.016f, 10000, {1.f, 2.f, 3.f, 4.f}, GenPC, 1234ULL},
};

// uint32_t samples, float weights, int lengths
using RngDiscreteTestU32F = RngDiscreteTest<uint32_t, float, int>;

// The measured histogram (device) must match the expected one (host) within the per-case
// tolerance.
TEST_P(RngDiscreteTestU32F, Result)
{
  const auto n_bins = exp_histogram.size();
  CompareApprox<float> within_tolerance(params.tolerance);
  ASSERT_TRUE(devArrMatchHost(exp_histogram.data(), histogram.data(), n_bins, within_tolerance));
}
INSTANTIATE_TEST_SUITE_P(RngTests, RngDiscreteTestU32F, ::testing::ValuesIn(inputs_u32));

} // namespace random
} // namespace raft
12 changes: 12 additions & 0 deletions cpp/test/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@

namespace raft {

// Streams a vector as "{e0,e1,...}": elements are comma-separated with no spaces, and an empty
// vector prints as "{}".
template <typename T>
::std::ostream& operator<<(::std::ostream& os, const std::vector<T>& v)
{
  os << "{";
  // The separator is empty before the first element and "," before every later one
  const char* separator = "";
  for (const auto& element : v) {
    os << separator << element;
    separator = ",";
  }
  os << "}";
  return os;
}

template <typename T>
struct Compare {
bool operator()(const T& a, const T& b) const { return a == b; }
Expand Down