Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial port of auto-find-k #1070

Merged
merged 26 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3c4e43c
Initial port of auto-find-k from nvgraph
cjnolet Dec 6, 2022
c56415a
Removing unneeded arguments
cjnolet Dec 6, 2022
46cc28d
Fixing kmin default
cjnolet Dec 6, 2022
08cd31d
Rename function
cjnolet Dec 6, 2022
3bdafdd
Adding countLabels to auto-find k
cjnolet Dec 6, 2022
2456df2
Cleanup
cjnolet Dec 6, 2022
f9666bb
Adding to public api
cjnolet Dec 6, 2022
565a68d
auto find k builds. Need to add test
cjnolet Dec 8, 2022
bef9b42
Adding docs
cjnolet Dec 8, 2022
efb71dd
Adding googletest
cjnolet Dec 8, 2022
45b9541
Getting auto-find-k tests to build
cjnolet Dec 8, 2022
1d6ba6c
Troubleshooting segmentation faults
cjnolet Dec 8, 2022
3d2b432
Merge remote-tracking branch 'rapidsai/branch-23.02' into fea-2302-au…
cjnolet Dec 9, 2022
4c136ea
The tests pass!
cjnolet Dec 9, 2022
5c78528
Making within-cluster variance super small
cjnolet Dec 12, 2022
606b68f
Merge branch 'branch-23.02' into fea-2302-auto-find-k
cjnolet Feb 16, 2023
812e652
Merge branch 'branch-23.04' into fea-2302-auto-find-k
cjnolet Feb 16, 2023
77ec3ee
Some fixes and updates
cjnolet Feb 16, 2023
abe5a1d
linking in kmeans_auto_find_k googletest
cjnolet Feb 17, 2023
15ee9fe
Build fixes
cjnolet Feb 17, 2023
058b6e4
Hopefully tests pass this timey
cjnolet Feb 17, 2023
605ec44
Merge branch 'branch-23.04' into fea-2302-auto-find-k
cjnolet Feb 18, 2023
c77d09e
Implementing review feedback
cjnolet Feb 18, 2023
787c14a
Merge branch 'fea-2302-auto-find-k' of github.com:cjnolet/raft into f…
cjnolet Feb 18, 2023
e335462
Merge branch 'branch-23.04' into fea-2302-auto-find-k
cjnolet Feb 18, 2023
e3be870
Merge branch 'branch-23.04' into fea-2302-auto-find-k
cjnolet Feb 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -818,9 +818,9 @@ void initScalableKMeansPlusPlus(raft::device_resources const& handle,
template <typename DataT, typename IndexT>
void kmeans_fit(raft::device_resources const& handle,
const KMeansParams& params,
raft::device_matrix_view<const DataT> X,
std::optional<raft::device_vector_view<const DataT>> sample_weight,
raft::device_matrix_view<DataT> centroids,
raft::device_matrix_view<const DataT, IndexT> X,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
raft::device_matrix_view<DataT, IndexT> centroids,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter)
{
Expand Down Expand Up @@ -982,10 +982,10 @@ void kmeans_fit(raft::device_resources const& handle,
template <typename DataT, typename IndexT>
void kmeans_predict(raft::device_resources const& handle,
const KMeansParams& params,
raft::device_matrix_view<const DataT> X,
std::optional<raft::device_vector_view<const DataT>> sample_weight,
raft::device_matrix_view<const DataT> centroids,
raft::device_vector_view<IndexT> labels,
raft::device_matrix_view<const DataT, IndexT> X,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_vector_view<IndexT, IndexT> labels,
bool normalize_weight,
raft::host_scalar_view<DataT> inertia)
{
Expand Down Expand Up @@ -1122,10 +1122,10 @@ void kmeans_predict(raft::device_resources const& handle,
template <typename DataT, typename IndexT = int>
void kmeans_fit_predict(raft::device_resources const& handle,
const KMeansParams& params,
raft::device_matrix_view<const DataT> X,
std::optional<raft::device_vector_view<const DataT>> sample_weight,
std::optional<raft::device_matrix_view<DataT>> centroids,
raft::device_vector_view<IndexT> labels,
raft::device_matrix_view<const DataT, IndexT> X,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
std::optional<raft::device_matrix_view<DataT, IndexT>> centroids,
raft::device_vector_view<IndexT, IndexT> labels,
raft::host_scalar_view<DataT> inertia,
raft::host_scalar_view<IndexT> n_iter)
{
Expand Down
219 changes: 219 additions & 0 deletions cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <thrust/host_vector.h>

#include <raft/core/logger.hpp>

#include <raft/cluster/detail/kmeans.cuh>

#include <raft/core/error.hpp>

#include <raft/core/device_resources.hpp>
#include <raft/stats/dispersion.cuh>

namespace raft::cluster::detail {

template <typename idx_t, typename value_t>
void find_k(raft::device_resources const& handle,
raft::device_matrix_view<const value_t, idx_t> X,
raft::host_scalar_view<idx_t> best_k,
raft::host_scalar_view<value_t> residual,
raft::host_scalar_view<idx_t> n_iter,
idx_t kmax,
idx_t kmin = 1,
idx_t maxiter = 100,
value_t tol = 1e-3)
{
idx_t n = X.extent(0);
idx_t d = X.extent(1);

RAFT_EXPECTS(n >= 1, "n must be >= 1");
RAFT_EXPECTS(d >= 1, "d must be >= 1");
RAFT_EXPECTS(kmin >= 1, "kmin must be >= 1");
RAFT_EXPECTS(kmax <= n, "kmax must be <= number of data samples in X");
RAFT_EXPECTS(tol >= 0, "tolerance must be >= 0");
RAFT_EXPECTS(maxiter >= 0, "maxiter must be >= 0");
// Allocate memory
// Device memory

auto centroids = raft::make_device_matrix<value_t, idx_t>(handle, kmax, X.extent(1));
auto clusterSizes = raft::make_device_vector<idx_t>(handle, kmax);
auto labels = raft::make_device_vector<idx_t>(handle, n);

rmm::device_uvector<char> workspace(0, handle.get_stream());

idx_t* clusterSizes_ptr = clusterSizes.data_handle();

// Host memory
auto results = raft::make_host_vector<value_t>(kmax + 1);
auto clusterDispersion = raft::make_host_vector<value_t>(kmax + 1);

auto clusterDispertionView = clusterDispersion.view();
auto resultsView = results.view();

// Loop to find *best* k
// Perform k-means in binary search
int left = kmin; // must be at least 2
int right = kmax; // int(floor(len(data)/2)) #assumption of clusters of size 2 at least
int mid = int(floor((right + left) / 2));
int oldmid = mid;
int tests = 0;
value_t objective[3]; // 0= left of mid, 1= right of mid
if (left == 1) left = 2; // at least do 2 clusters

KMeansParams params;
params.max_iter = maxiter;
params.tol = tol;

auto centroids_const_view =
raft::make_device_matrix_view<const value_t, idx_t>(centroids.data_handle(), left, d);

auto centroids_view =
raft::make_device_matrix_view<value_t, idx_t>(centroids.data_handle(), left, d);

auto cluster_sizes_view =
raft::make_device_vector_view<const idx_t, idx_t>(clusterSizes_ptr, left);

params.n_clusters = left;
kmeans_fit_predict<value_t, idx_t>(handle,
params,
X,
std::nullopt,
std::make_optional(centroids_view),
labels.view(),
residual,
n_iter);

detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, left, workspace);
resultsView[left] = residual[0];

clusterDispertionView[left] = raft::stats::cluster_dispersion(
handle, centroids_const_view, cluster_sizes_view, std::nullopt, n);
// eval right edge0
resultsView[right] = 1e20;
while (resultsView[right] > resultsView[left] && tests < 3) {
centroids_const_view =
raft::make_device_matrix_view<const value_t, idx_t>(centroids.data_handle(), right, d);
centroids_view =
raft::make_device_matrix_view<value_t, idx_t>(centroids.data_handle(), right, d);

cluster_sizes_view = raft::make_device_vector_view<const idx_t, idx_t>(clusterSizes_ptr, right);

params.n_clusters = right;
kmeans_fit_predict<value_t, idx_t>(handle,
params,
X,
std::nullopt,
std::make_optional(centroids_view),
labels.view(),
residual,
n_iter);

detail::countLabels(
handle, labels.data_handle(), clusterSizes.data_handle(), n, right, workspace);

resultsView[right] = residual[0];
clusterDispertionView[right] = raft::stats::cluster_dispersion(
handle, centroids_const_view, cluster_sizes_view, std::nullopt, n);

tests += 1;
}

objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left];
objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right];
while (left < right - 1) {
resultsView[mid] = 1e20;
tests = 0;
while (resultsView[mid] > resultsView[left] && tests < 3) {
centroids_const_view =
raft::make_device_matrix_view<const value_t, idx_t>(centroids.data_handle(), mid, d);
centroids_view =
raft::make_device_matrix_view<value_t, idx_t>(centroids.data_handle(), mid, d);

cluster_sizes_view = raft::make_device_vector_view<const idx_t, idx_t>(clusterSizes_ptr, mid);

params.n_clusters = mid;

kmeans_fit_predict<value_t, idx_t>(handle,
params,
X,
std::nullopt,
std::make_optional(centroids_view),
labels.view(),
residual,
n_iter);

detail::countLabels(
handle, labels.data_handle(), clusterSizes.data_handle(), n, mid, workspace);

resultsView[mid] = residual[0];
clusterDispertionView[mid] = raft::stats::cluster_dispersion(
handle, centroids_const_view, cluster_sizes_view, std::nullopt, n);

if (resultsView[mid] > resultsView[left] && (mid + 1) < right) {
mid += 1;
resultsView[mid] = 1e20;
} else if (resultsView[mid] > resultsView[left] && (mid - 1) > left) {
mid -= 1;
resultsView[mid] = 1e20;
}
tests += 1;
}

// maximize Calinski-Harabasz Index, minimize resid/ cluster
objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left];
objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right];
objective[2] = (n - mid) / (mid - 1) * clusterDispertionView[mid] / resultsView[mid];
objective[0] = (objective[2] - objective[0]) / (mid - left);
objective[1] = (objective[1] - objective[2]) / (right - mid);

if (objective[0] > 0 && objective[1] < 0) {
// our point is in the left-of-mid side
right = mid;
} else {
left = mid;
}
oldmid = mid;
mid = int(floor((right + left) / 2));
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
}

best_k[0] = right;
objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left];
objective[1] = (n - oldmid) / (oldmid - 1) * clusterDispertionView[oldmid] / resultsView[oldmid];
if (objective[1] < objective[0]) { best_k[0] = left; }

// if best_k isn't what we just ran, re-run to get correct centroids and dist data on return->
// this saves memory
if (best_k[0] != oldmid) {
centroids_view =
raft::make_device_matrix_view<value_t, idx_t>(centroids.data_handle(), best_k[0], d);

params.n_clusters = best_k[0];
kmeans_fit_predict<value_t, idx_t>(handle,
params,
X,
std::nullopt,
std::make_optional(centroids_view),
labels.view(),
residual,
n_iter);
}
}
} // namespace raft::cluster::detail
55 changes: 55 additions & 0 deletions cpp/include/raft/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <optional>
#include <raft/cluster/detail/kmeans.cuh>
#include <raft/cluster/detail/kmeans_auto_find_k.cuh>
#include <raft/cluster/kmeans_types.hpp>
#include <raft/core/kvp.hpp>
#include <raft/core/mdarray.hpp>
Expand Down Expand Up @@ -261,6 +262,60 @@ void transform(raft::device_resources const& handle,
handle, params, X, centroids, n_samples, n_features, X_new);
}

/**
* Automatically find the optimal value of k using a binary search.
* This method maximizes the Calinski-Harabasz Index while minimizing the per-cluster inertia.
*
* @code{.cpp}
* #include <raft/core/handle.hpp>
* #include <raft/cluster/kmeans.cuh>
* #include <raft/cluster/kmeans_types.hpp>
*
* #include <raft/random/make_blobs.cuh>
*
* using namespace raft::cluster;
*
* raft::handle_t handle;
* int n_samples = 100, n_features = 15, n_clusters = 10;
* auto X = raft::make_device_matrix<float, int>(handle, n_samples, n_features);
* auto labels = raft::make_device_vector<float, int>(handle, n_samples);
*
* raft::random::make_blobs(handle, X, labels, n_clusters);
*
* auto best_k = raft::make_host_scalar<int>(0);
* auto n_iter = raft::make_host_scalar<int>(0);
* auto inertia = raft::make_host_scalar<int>(0);
*
* kmeans::find_k(handle, X, best_k.view(), inertia.view(), n_iter.view(), n_clusters+1);
*
* @endcode
*
* @tparam idx_t indexing type (should be integral)
* @tparam value_t value type (should be floating point)
* @param handle raft handle
* @param X input observations (shape n_samples, n_dims)
* @param best_k best k found from binary search
* @param inertia inertia of best k found
* @param n_iter number of iterations used to find best k
* @param kmax maximum k to try in search
* @param kmin minimum k to try in search (should be >= 1)
* @param maxiter maximum number of iterations to run
* @param tol tolerance for early stopping convergence
*/
template <typename idx_t, typename value_t>
void find_k(raft::device_resources const& handle,
raft::device_matrix_view<const value_t, idx_t> X,
raft::host_scalar_view<idx_t> best_k,
raft::host_scalar_view<value_t> inertia,
raft::host_scalar_view<idx_t> n_iter,
idx_t kmax,
idx_t kmin = 1,
idx_t maxiter = 100,
value_t tol = 1e-3)
{
detail::find_k(handle, X, best_k, inertia, n_iter, kmax, kmin, maxiter, tol);
}

/**
* @brief Select centroids according to a sampling operation
*
Expand Down
Loading