Skip to content

Commit

Permalink
Some cleanup of k-means internals (#953)
Browse files Browse the repository at this point in the history
This also refactors the `update centroids` logic into its own function in the detail API. Also exposing `update_centroids` function for scikit-learn pluggability.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #953
  • Loading branch information
cjnolet authored Oct 28, 2022
1 parent 7e95567 commit d199a9f
Show file tree
Hide file tree
Showing 19 changed files with 959 additions and 147 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ build/
build_prims/
dist/
python/**/**/*.cpp
python/raft/record.txt
python/raft-dask/record.txt
python/pylibraft/record.txt
log
.ipynb_checkpoints
Expand Down
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ if(RAFT_COMPILE_DIST_LIBRARY)
add_library(raft_distance_lib
src/distance/pairwise_distance.cu
src/distance/fused_l2_min_arg.cu
src/distance/update_centroids_float.cu
src/distance/update_centroids_double.cu
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
src/distance/specializations/detail/correlation.cu
Expand Down
231 changes: 143 additions & 88 deletions cpp/include/raft/cluster/detail/kmeans.cuh

Large diffs are not rendered by default.

81 changes: 44 additions & 37 deletions cpp/include/raft/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void countLabels(const raft::handle_t& handle,

template <typename DataT, typename IndexT>
void checkWeight(const raft::handle_t& handle,
const raft::device_vector_view<DataT, IndexT>& weight,
raft::device_vector_view<DataT, IndexT> weight,
rmm::device_uvector<char>& workspace)
{
cudaStream_t stream = handle.get_stream();
Expand Down Expand Up @@ -166,24 +166,24 @@ void checkWeight(const raft::handle_t& handle,
}

template <typename IndexT>
IndexT getDataBatchSize(const KMeansParams& params, IndexT n_samples)
IndexT getDataBatchSize(int batch_samples, IndexT n_samples)
{
auto minVal = std::min(static_cast<IndexT>(params.batch_samples), n_samples);
auto minVal = std::min(static_cast<IndexT>(batch_samples), n_samples);
return (minVal == 0) ? n_samples : minVal;
}

template <typename IndexT>
IndexT getCentroidsBatchSize(const KMeansParams& params, IndexT n_local_clusters)
IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters)
{
auto minVal = std::min(static_cast<IndexT>(params.batch_centroids), n_local_clusters);
auto minVal = std::min(static_cast<IndexT>(batch_centroids), n_local_clusters);
return (minVal == 0) ? n_local_clusters : minVal;
}

template <typename DataT, typename ReductionOpT, typename IndexT = int>
void computeClusterCost(const raft::handle_t& handle,
const raft::device_vector_view<DataT, IndexT>& minClusterDistance,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
rmm::device_uvector<char>& workspace,
const raft::device_scalar_view<DataT>& clusterCost,
raft::device_scalar_view<DataT> clusterCost,
ReductionOpT reduction_op)
{
cudaStream_t stream = handle.get_stream();
Expand Down Expand Up @@ -211,9 +211,9 @@ void computeClusterCost(const raft::handle_t& handle,

template <typename DataT, typename IndexT>
void sampleCentroids(const raft::handle_t& handle,
const raft::device_matrix_view<const DataT, IndexT>& X,
const raft::device_vector_view<DataT, IndexT>& minClusterDistance,
const raft::device_vector_view<uint8_t, IndexT>& isSampleCentroid,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
raft::device_vector_view<uint8_t, IndexT> isSampleCentroid,
SamplingOp<DataT, IndexT>& select_op,
rmm::device_uvector<DataT>& inRankCp,
rmm::device_uvector<char>& workspace)
Expand Down Expand Up @@ -277,9 +277,9 @@ void sampleCentroids(const raft::handle_t& handle,
// result will be stored in 'pairwiseDistance[n x k]'
template <typename DataT, typename IndexT>
void pairwise_distance_kmeans(const raft::handle_t& handle,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_matrix_view<DataT, IndexT> pairwiseDistance,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_matrix_view<DataT, IndexT> pairwiseDistance,
rmm::device_uvector<char>& workspace,
raft::distance::DistanceType metric)
{
Expand All @@ -305,8 +305,8 @@ void pairwise_distance_kmeans(const raft::handle_t& handle,
// in 'out' does not modify the input
template <typename DataT, typename IndexT>
void shuffleAndGather(const raft::handle_t& handle,
const raft::device_matrix_view<const DataT, IndexT>& in,
const raft::device_matrix_view<DataT, IndexT>& out,
raft::device_matrix_view<const DataT, IndexT> in,
raft::device_matrix_view<DataT, IndexT> out,
uint32_t n_samples_to_gather,
uint64_t seed)
{
Expand Down Expand Up @@ -340,24 +340,25 @@ void shuffleAndGather(const raft::handle_t& handle,
template <typename DataT, typename IndexT>
void minClusterAndDistanceCompute(
const raft::handle_t& handle,
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT> X,
const raft::device_matrix_view<const DataT, IndexT> centroids,
const raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
const raft::device_vector_view<DataT, IndexT> L2NormX,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
raft::device_vector_view<const DataT, IndexT> L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
raft::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace)
{
cudaStream_t stream = handle.get_stream();
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
auto n_clusters = centroids.extent(0);
auto metric = params.metric;
// todo(lsugy): change batch size computation when using fusedL2NN!
bool is_fused = metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded;
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(params, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(params, n_clusters);
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters);

if (is_fused) {
L2NormBuf_OR_DistBuf.resize(n_clusters, stream);
Expand All @@ -369,6 +370,9 @@ void minClusterAndDistanceCompute(
true,
stream);
} else {
// TODO: Unless pool allocator is used, passing in a workspace for this
// isn't really increasing performance because this needs to do a re-allocation
// anyways. ref https://github.com/rapidsai/raft/issues/930
L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream);
}

Expand Down Expand Up @@ -403,7 +407,7 @@ void minClusterAndDistanceCompute(
minClusterAndDistance.data_handle() + dIdx, ns);

auto L2NormXView =
raft::make_device_vector_view<DataT, IndexT>(L2NormX.data_handle() + dIdx, ns);
raft::make_device_vector_view<const DataT, IndexT>(L2NormX.data_handle() + dIdx, ns);

if (is_fused) {
workspace.resize((sizeof(int)) * ns, stream);
Expand Down Expand Up @@ -471,24 +475,25 @@ void minClusterAndDistanceCompute(

template <typename DataT, typename IndexT>
void minClusterDistanceCompute(const raft::handle_t& handle,
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT>& X,
const raft::device_matrix_view<DataT, IndexT>& centroids,
const raft::device_vector_view<DataT, IndexT>& minClusterDistance,
const raft::device_vector_view<DataT, IndexT>& L2NormX,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<DataT, IndexT> centroids,
raft::device_vector_view<DataT, IndexT> minClusterDistance,
raft::device_vector_view<DataT, IndexT> L2NormX,
rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
raft::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace)
{
cudaStream_t stream = handle.get_stream();
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
auto n_clusters = centroids.extent(0);
auto metric = params.metric;

bool is_fused = metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded;
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(params, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(params, n_clusters);
auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples);
auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters);

if (is_fused) {
L2NormBuf_OR_DistBuf.resize(n_clusters, stream);
Expand Down Expand Up @@ -597,11 +602,11 @@ void minClusterDistanceCompute(const raft::handle_t& handle,
template <typename DataT, typename IndexT>
void countSamplesInCluster(const raft::handle_t& handle,
const KMeansParams& params,
const raft::device_matrix_view<const DataT, IndexT>& X,
const raft::device_vector_view<DataT, IndexT> L2NormX,
const raft::device_matrix_view<DataT, IndexT> centroids,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_vector_view<const DataT, IndexT> L2NormX,
raft::device_matrix_view<DataT, IndexT> centroids,
rmm::device_uvector<char>& workspace,
const raft::device_vector_view<DataT, IndexT> sampleCountInCluster)
raft::device_vector_view<DataT, IndexT> sampleCountInCluster)
{
cudaStream_t stream = handle.get_stream();
auto n_samples = X.extent(0);
Expand All @@ -623,12 +628,14 @@ void countSamplesInCluster(const raft::handle_t& handle,
// centroid) and 'value' is the distance between the sample 'X[i]' and the
// 'centroid[key]'
detail::minClusterAndDistanceCompute(handle,
params,
X,
(raft::device_matrix_view<const DataT, IndexT>)centroids,
minClusterAndDistance.view(),
L2NormX,
L2NormBuf_OR_DistBuf,
params.metric,
params.batch_samples,
params.batch_centroids,
workspace);

// Using TransformInputIteratorT to dereference an array of raft::KeyValuePair
Expand Down
Loading

0 comments on commit d199a9f

Please sign in to comment.