Skip to content

Commit

Permalink
Add final API of raft KMeans
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Apr 26, 2022
1 parent 696439d commit f1589f1
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 156 deletions.
54 changes: 2 additions & 52 deletions cpp/include/cuml/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cuml/common/log_levels.hpp>
#include <raft/cluster/kmeans_params.hpp>

namespace raft {
class handle_t;
Expand All @@ -26,54 +27,7 @@ namespace ML {

namespace kmeans {

// Hyper-parameters for a k-means run. Every field carries an in-class default,
// so a default-constructed KMeansParams is usable as-is.
struct KMeansParams {
// Strategies for choosing the initial centroids (see 'init' below).
enum InitMethod { KMeansPlusPlus, Random, Array };

// Number of clusters to form, which is also the number of centroids to
// generate (default: 8).
int n_clusters = 8;

/*
* Method for initialization, defaults to k-means++:
* - InitMethod::KMeansPlusPlus (k-means++): Use the scalable k-means++
* algorithm to select the initial cluster centers.
* - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at
* random from the input data for the initial centroids.
* - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers.
*/
InitMethod init = KMeansPlusPlus;

// Maximum number of iterations of the k-means algorithm for a single run
// (default: 300).
int max_iter = 300;

// Relative tolerance with regard to inertia used to declare convergence
// (default: 1e-4).
double tol = 1e-4;

// Verbosity level (default: CUML_LEVEL_INFO).
int verbosity = CUML_LEVEL_INFO;

// Seed for the random number generator (default: 0).
int seed = 0;

// Metric to use for distance computation. Any metric from
// raft::distance::DistanceType can be used; it is stored here as a plain int.
// NOTE(review): which DistanceType enumerator the default 0 corresponds to is
// not visible from this header -- confirm against raft.
int metric = 0;

// Number of times the k-means algorithm will be run with different seeds
// (default: 1).
int n_init = 1;

// Oversampling factor for use in the k-means|| algorithm (default: 2.0).
double oversampling_factor = 2.0;

// batch_samples and batch_centroids are used to tile the 1NN computation,
// which is useful to optimize/control the memory footprint.
// The default tile is [batch_samples x n_clusters], i.e. when batch_centroids
// is 0 the centroids are not tiled.
int batch_samples = 1 << 15;
int batch_centroids = 0; // if 0 then batch_centroids = n_clusters

// NOTE(review): undocumented in the original; presumably toggles an
// inertia-based check during fitting -- confirm against the implementation.
// Disabled by default.
bool inertia_check = false;
};
using KMeansParams = raft::cluster::KMeansParams;

/**
* @brief Compute k-means clustering and predicts cluster index for each sample
Expand Down Expand Up @@ -222,8 +176,6 @@ void predict(const raft::handle_t& handle,
* @param[in] n_features Number of features or the dimensions of each
* sample in 'X' (it should be same as the dimension for each cluster centers in
* 'centroids').
* @param[in] metric Metric to use for distance computation. Any
* metric from raft::distance::DistanceType can be used
* @param[out] X_new X transformed in the new space.
*/
void transform(const raft::handle_t& handle,
Expand All @@ -232,7 +184,6 @@ void transform(const raft::handle_t& handle,
const float* X,
int n_samples,
int n_features,
int metric,
float* X_new);

void transform(const raft::handle_t& handle,
Expand All @@ -241,7 +192,6 @@ void transform(const raft::handle_t& handle,
const double* X,
int n_samples,
int n_features,
int metric,
double* X_new);

}; // end namespace kmeans
Expand Down
1 change: 0 additions & 1 deletion cpp/include/cuml/cluster/kmeans_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class handle_t;

namespace ML {
namespace kmeans {
struct KMeansParams;
namespace opg {

/**
Expand Down
155 changes: 61 additions & 94 deletions cpp/src/kmeans/kmeans.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#include "sg_impl.cuh"
#include <raft/cluster/kmeans.cuh>
#include <cuml/cluster/kmeans.hpp>
#include <raft/cluster/kmeans.cuh>

namespace ML {
namespace kmeans {
Expand All @@ -33,40 +33,15 @@ void fit_predict(const raft::handle_t& handle,
float& inertia,
int& n_iter)
{
/*impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter);
impl::predict(
handle, params, centroids, X, n_samples, n_features, sample_weight, true, labels, inertia);*/
raft::cluster::KMeansParams rParams;
rParams.n_clusters = params.n_clusters;
rParams.init = raft::cluster::KMeansParams::InitMethod::KMeansPlusPlus;
rParams.max_iter = params.max_iter;
rParams.tol = params.tol;
rParams.verbosity = params.verbosity;
rParams.seed = params.seed;
rParams.metric = params.metric;
rParams.n_init = params.n_init;
rParams.oversampling_factor = params.oversampling_factor;
rParams.batch_samples = params.batch_samples;
rParams.batch_centroids = params.batch_centroids;
rParams.inertia_check = params.inertia_check;
auto X_view = raft::make_device_matrix_view(const_cast<float*>(X), n_samples, n_features);
std::optional<raft::device_vector_view<float>> sw = std::nullopt;
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
std::optional<raft::device_vector_view<const float>> sw = std::nullopt;
if (sample_weight != nullptr)
sw = std::make_optional(raft::make_device_vector_view(const_cast<float*>(sample_weight), n_samples));
std::optional<raft::device_matrix_view<float>> rCentroids = std::nullopt;
if (centroids != nullptr)
rCentroids = std::make_optional(raft::make_device_matrix_view(centroids, params.n_clusters, n_features));
auto rLabels = raft::make_device_vector_view(labels, n_samples);

raft::cluster::kmeans_fit_predict<float, int>(handle,
rParams,
X_view,
sw,
rCentroids,
rLabels,
inertia,
n_iter);
handle.sync_stream(handle.get_stream());
sw = std::make_optional(raft::make_device_vector_view((sample_weight), n_samples));
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);
auto rLabels = raft::make_device_vector_view(labels, n_samples);

raft::cluster::kmeans_fit_predict<float, int>(
handle, params, X_view, sw, centroids_view, rLabels, inertia, n_iter);
}

void fit_predict(const raft::handle_t& handle,
Expand All @@ -80,41 +55,15 @@ void fit_predict(const raft::handle_t& handle,
double& inertia,
int& n_iter)
{
/*impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter);
impl::predict(
handle, params, centroids, X, n_samples, n_features, sample_weight, true, labels, inertia);*/

raft::cluster::KMeansParams rParams;
rParams.n_clusters = params.n_clusters;
rParams.init = raft::cluster::KMeansParams::InitMethod::KMeansPlusPlus;
rParams.max_iter = params.max_iter;
rParams.tol = params.tol;
rParams.verbosity = params.verbosity;
rParams.seed = params.seed;
rParams.metric = params.metric;
rParams.n_init = params.n_init;
rParams.oversampling_factor = params.oversampling_factor;
rParams.batch_samples = params.batch_samples;
rParams.batch_centroids = params.batch_centroids;
rParams.inertia_check = params.inertia_check;
auto X_view = raft::make_device_matrix_view((double*)X, n_samples, n_features);
std::optional<raft::device_vector_view<double>> sw = std::nullopt;
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
std::optional<raft::device_vector_view<const double>> sw = std::nullopt;
if (sample_weight != nullptr)
sw = std::make_optional(raft::make_device_vector_view(const_cast<double*>(sample_weight), n_samples));
std::optional<raft::device_matrix_view<double>> rCentroids = std::nullopt;
if (centroids != nullptr)
rCentroids = std::make_optional(raft::make_device_matrix_view(centroids, params.n_clusters, n_features));
auto rLabels = raft::make_device_vector_view(labels, n_samples);

raft::cluster::kmeans_fit_predict<double, int>(handle,
rParams,
X_view,
sw,
rCentroids,
rLabels,
inertia,
n_iter);
handle.sync_stream(handle.get_stream());
sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples));
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);
auto rLabels = raft::make_device_vector_view(labels, n_samples);

raft::cluster::kmeans_fit_predict<double, int>(
handle, params, X_view, sw, centroids_view, rLabels, inertia, n_iter);
}

// ----------------------------- fit ---------------------------------//
Expand All @@ -129,7 +78,14 @@ void fit(const raft::handle_t& handle,
float& inertia,
int& n_iter)
{
impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter);
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
std::optional<raft::device_vector_view<const float>> sw = std::nullopt;
if (sample_weight != nullptr)
sw = std::make_optional(raft::make_device_vector_view((sample_weight), n_samples));
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);

raft::cluster::kmeans_fit<float, int>(
handle, params, X_view, sw, centroids_view, inertia, n_iter);
}

void fit(const raft::handle_t& handle,
Expand All @@ -142,7 +98,14 @@ void fit(const raft::handle_t& handle,
double& inertia,
int& n_iter)
{
impl::fit(handle, params, X, n_samples, n_features, sample_weight, centroids, inertia, n_iter);
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
std::optional<raft::device_vector_view<const double>> sw = std::nullopt;
if (sample_weight != nullptr)
sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples));
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);

raft::cluster::kmeans_fit<double, int>(
handle, params, X_view, sw, centroids_view, inertia, n_iter);
}

// ----------------------------- predict ---------------------------------//
Expand All @@ -158,16 +121,15 @@ void predict(const raft::handle_t& handle,
int* labels,
float& inertia)
{
impl::predict(handle,
params,
centroids,
X,
n_samples,
n_features,
sample_weight,
normalize_weights,
labels,
inertia);
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
std::optional<raft::device_vector_view<const float>> sw = std::nullopt;
if (sample_weight != nullptr)
sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples));
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);
auto rLabels = raft::make_device_vector_view(labels, n_samples);

raft::cluster::kmeans_predict<float, int>(
handle, params, X_view, sw, centroids_view, rLabels, normalize_weights, inertia);
}

void predict(const raft::handle_t& handle,
Expand All @@ -181,16 +143,15 @@ void predict(const raft::handle_t& handle,
int* labels,
double& inertia)
{
impl::predict(handle,
params,
centroids,
X,
n_samples,
n_features,
sample_weight,
normalize_weights,
labels,
inertia);
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
std::optional<raft::device_vector_view<const double>> sw = std::nullopt;
if (sample_weight != nullptr)
sw = std::make_optional(raft::make_device_vector_view(sample_weight, n_samples));
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);
auto rLabels = raft::make_device_vector_view(labels, n_samples);

raft::cluster::kmeans_predict<double, int>(
handle, params, X_view, sw, centroids_view, rLabels, normalize_weights, inertia);
}

// ----------------------------- transform ---------------------------------//
Expand All @@ -200,10 +161,13 @@ void transform(const raft::handle_t& handle,
const float* X,
int n_samples,
int n_features,
int metric,
float* X_new)
{
impl::transform(handle, params, centroids, X, n_samples, n_features, metric, X_new);
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);
auto rX_new = raft::make_device_matrix_view(X_new, n_samples, n_features);

raft::cluster::kmeans_transform<float, int>(handle, params, X_view, centroids_view, rX_new);
}

void transform(const raft::handle_t& handle,
Expand All @@ -212,10 +176,13 @@ void transform(const raft::handle_t& handle,
const double* X,
int n_samples,
int n_features,
int metric,
double* X_new)
{
impl::transform(handle, params, centroids, X, n_samples, n_features, metric, X_new);
auto X_view = raft::make_device_matrix_view(X, n_samples, n_features);
auto centroids_view = raft::make_device_matrix_view(centroids, params.n_clusters, n_features);
auto rX_new = raft::make_device_matrix_view(X_new, n_samples, n_features);

raft::cluster::kmeans_transform<double, int>(handle, params, X_view, centroids_view, rX_new);
}

}; // end namespace kmeans
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/spectral/spectral.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <raft/sparse/coo.hpp>

//#include <raft/sparse/linalg/spectral.hpp>
#include <raft/sparse/linalg/spectral.hpp>

namespace raft {
class handle_t;
Expand Down Expand Up @@ -49,7 +49,7 @@ void fit_embedding(const raft::handle_t& handle,
float* out,
unsigned long long seed)
{
//raft::sparse::spectral::fit_embedding(handle, rows, cols, vals, nnz, n, n_components, out, seed);
raft::sparse::spectral::fit_embedding(handle, rows, cols, vals, nnz, n, n_components, out, seed);
}
} // namespace Spectral
} // namespace ML
11 changes: 4 additions & 7 deletions python/cuml/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ cdef extern from "cuml/cluster/kmeans.hpp" namespace "ML::kmeans":
const float *X,
int n_samples,
int n_features,
int metric,
float *X_new) except +

cdef void transform(handle_t& handle,
Expand All @@ -98,7 +97,6 @@ cdef extern from "cuml/cluster/kmeans.hpp" namespace "ML::kmeans":
const double *X,
int n_samples,
int n_features,
int metric,
double *X_new) except +


Expand Down Expand Up @@ -548,27 +546,26 @@ class KMeans(Base,
cdef uintptr_t preds_ptr = preds.ptr

# distance metric as L2-norm/euclidean distance: @todo - support other metrics # noqa: E501
distance_metric = 1
cdef KMeansParams params = self._params
params.metric = 1

if self.dtype == np.float32:
transform(
handle_[0],
<KMeansParams> self._params,
<KMeansParams> params,
<float*> cluster_centers_ptr,
<float*> input_ptr,
<size_t> n_rows,
<size_t> self.n_cols,
<int> distance_metric,
<float*> preds_ptr)
elif self.dtype == np.float64:
transform(
handle_[0],
<KMeansParams> self._params,
<KMeansParams> params,
<double*> cluster_centers_ptr,
<double*> input_ptr,
<size_t> n_rows,
<size_t> self.n_cols,
<int> distance_metric,
<double*> preds_ptr)
else:
raise TypeError('KMeans supports only float32 and float64 input,'
Expand Down

0 comments on commit f1589f1

Please sign in to comment.