Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HELP-REQ] Expose KMeans init_plus_plus in pylibraft #1198

Merged
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu
src/distance/cluster/kmeans_fit_float.cu
src/distance/cluster/kmeans_fit_double.cu
src/distance/cluster/kmeans_init_plus_plus_double.cu
src/distance/cluster/kmeans_init_plus_plus_float.cu
src/distance/distance/specializations/detail/canberra_double_double_double_int.cu
src/distance/distance/specializations/detail/canberra_float_float_float_int.cu
src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu
Expand Down
10 changes: 10 additions & 0 deletions cpp/include/raft_runtime/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ void fit(raft::device_resources const& handle,
raft::host_scalar_view<double, int> inertia,
raft::host_scalar_view<int, int> n_iter);

void init_plus_plus(raft::device_resources const& handle,
const raft::cluster::kmeans::KMeansParams& params,
raft::device_matrix_view<const float, int, row_major> X,
raft::device_matrix_view<float, int, row_major> centroids);

void init_plus_plus(raft::device_resources const& handle,
const raft::cluster::kmeans::KMeansParams& params,
raft::device_matrix_view<const double, int, row_major> X,
raft::device_matrix_view<double, int, row_major> centroids);

void cluster_cost(raft::device_resources const& handle,
const float* X,
int n_samples,
Expand Down
31 changes: 31 additions & 0 deletions cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

void init_plus_plus(raft::device_resources const& handle,
const raft::cluster::kmeans::KMeansParams& params,
raft::device_matrix_view<const double, int> X,
raft::device_matrix_view<double, int> centroids)
{
rmm::device_uvector<char> workspace(0, handle.get_stream());
raft::cluster::kmeans::init_plus_plus<double, int>(handle, params, X, centroids, workspace);
}
} // namespace raft::runtime::cluster::kmeans
31 changes: 31 additions & 0 deletions cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

void init_plus_plus(raft::device_resources const& handle,
const raft::cluster::kmeans::KMeansParams& params,
raft::device_matrix_view<const float, int> X,
raft::device_matrix_view<float, int> centroids)
{
rmm::device_uvector<char> workspace(0, handle.get_stream());
raft::cluster::kmeans::init_plus_plus<float, int>(handle, params, X, centroids, workspace);
}
} // namespace raft::runtime::cluster::kmeans
18 changes: 15 additions & 3 deletions python/pylibraft/pylibraft/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,18 @@
# limitations under the License.
#

from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit
from .kmeans import (
KMeansParams,
cluster_cost,
compute_new_centroids,
fit,
init_plus_plus,
)

__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit"]
__all__ = [
"KMeansParams",
"cluster_cost",
"compute_new_centroids",
"fit",
"init_plus_plus",
]
12 changes: 12 additions & 0 deletions python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ cdef extern from "raft_runtime/cluster/kmeans.hpp" \
const double * centroids,
double * cost) except +

cdef void init_plus_plus(
const device_resources & handle,
const KMeansParams& params,
device_matrix_view[float, int, row_major] X,
device_matrix_view[float, int, row_major] centroids) except +

cdef void init_plus_plus(
const device_resources & handle,
const KMeansParams& params,
device_matrix_view[double, int, row_major] X,
device_matrix_view[double, int, row_major] centroids) except +

cdef void fit(
const device_resources & handle,
const KMeansParams& params,
Expand Down
85 changes: 85 additions & 0 deletions python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ from pylibraft.distance import DISTANCE_TYPES
from pylibraft.cluster.cpp cimport kmeans as cpp_kmeans, kmeans_types
from pylibraft.cluster.cpp.kmeans cimport (
cluster_cost as cpp_cluster_cost,
init_plus_plus as cpp_init_plus_plus,
update_centroids,
)
from pylibraft.common.cpp.mdspan cimport *
Expand Down Expand Up @@ -199,6 +200,90 @@ def compute_new_centroids(X,
raise ValueError("dtype %s not supported" % x_dt)


@auto_sync_handle
@auto_convert_output
def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None):
"""
Compute initial centroids using the "kmeans++" algorithm.

Parameters
----------

X : Input CUDA array interface compliant matrix shape (m, k)
n_clusters : Number of clusters to select
seed : Controls the random sampling of centroids
centroids : Optional writable CUDA array interface compliant matrix shape
(n_clusters, k). Use instead of passing `n_clusters`.
{handle_docstring}

Examples
--------

>>> import cupy as cp
>>> from pylibraft.cluster.kmeans import init_plus_plus

>>> n_samples = 5000
>>> n_features = 50
>>> n_clusters = 3

>>> X = cp.random.random_sample((n_samples, n_features),
... dtype=cp.float32)

>>> centroids = init_plus_plus(X, n_clusters)
"""
if (n_clusters is not None and
centroids is not None and n_clusters != centroids.shape[0]):
msg = ("Parameters 'n_clusters' and 'centroids' "
"are exclusive. Only pass one at a time.")
raise RuntimeError(msg)

cdef device_resources *h = <device_resources*><size_t>handle.getHandle()

X_cai = cai_wrapper(X)
X_cai.validate_shape_dtype(expected_dims=2)
dtype = X_cai.dtype

if centroids is not None:
n_clusters = centroids.shape[0]
else:
centroids_shape = (n_clusters, X_cai.shape[1])
centroids = device_ndarray.empty(centroids_shape, dtype=dtype)

centroids_cai = cai_wrapper(centroids)

# Can't set attributes of KMeansParameters after creating it, so taking
# a detour via a dict to collect the possible constructor arguments
params_ = dict(n_clusters=n_clusters)
if seed is not None:
params_["seed"] = seed
params = KMeansParams(**params_)

if dtype == np.float64:
cpp_init_plus_plus(
deref(h), params.c_obj,
make_device_matrix_view[double, int, row_major](
<double *><uintptr_t>X_cai.data,
<int>X_cai.shape[0], <int>X_cai.shape[1]),
make_device_matrix_view[double, int, row_major](
<double *><uintptr_t>centroids_cai.data,
<int>centroids_cai.shape[0], <int>centroids_cai.shape[1]),
)
elif dtype == np.float32:
cpp_init_plus_plus(
deref(h), params.c_obj,
make_device_matrix_view[float, int, row_major](
<float *><uintptr_t>X_cai.data,
<int>X_cai.shape[0], <int>X_cai.shape[1]),
make_device_matrix_view[float, int, row_major](
<float *><uintptr_t>centroids_cai.data,
<int>centroids_cai.shape[0], <int>centroids_cai.shape[1]),
)
else:
raise ValueError(f"Unhandled dtype ({dtype}) for X.")

return centroids


@auto_sync_handle
@auto_convert_output
def cluster_cost(X, centroids, handle=None):
Expand Down
57 changes: 57 additions & 0 deletions python/pylibraft/pylibraft/test/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
cluster_cost,
compute_new_centroids,
fit,
init_plus_plus,
)
from pylibraft.common import DeviceResources, device_ndarray
from pylibraft.distance import pairwise_distance
Expand Down Expand Up @@ -147,3 +148,59 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype):
# need reduced tolerance for float32
tol = 1e-3 if dtype == np.float32 else 1e-6
assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol)


@pytest.mark.parametrize("n_rows", [100])
@pytest.mark.parametrize("n_cols", [5, 25])
@pytest.mark.parametrize("n_clusters", [4, 15])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_init_plus_plus(n_rows, n_cols, n_clusters, dtype):
X = np.random.random_sample((n_rows, n_cols)).astype(dtype)
X_device = device_ndarray(X)

centroids = init_plus_plus(X_device, n_clusters, seed=1)
centroids_ = centroids.copy_to_host()

assert centroids_.shape == (n_clusters, X.shape[1])

# Centroids are selected from the existing points
for centroid in centroids_:
assert (centroid == X).all(axis=1).any()


@pytest.mark.parametrize("n_rows", [100])
@pytest.mark.parametrize("n_cols", [5, 25])
@pytest.mark.parametrize("n_clusters", [4, 15])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_init_plus_plus_preallocated_output(n_rows, n_cols, n_clusters, dtype):
X = np.random.random_sample((n_rows, n_cols)).astype(dtype)
X_device = device_ndarray(X)

centroids = device_ndarray.empty((n_clusters, n_cols), dtype=dtype)

new_centroids = init_plus_plus(X_device, centroids=centroids, seed=1)
new_centroids_ = new_centroids.copy_to_host()

# The shape should not have changed
assert new_centroids_.shape == centroids.shape

# Centroids are selected from the existing points
for centroid in new_centroids_:
assert (centroid == X).all(axis=1).any()


def test_init_plus_plus_exclusive_arguments():
# Check an exception is raised when n_clusters and centroids shape
# are inconsistent.
X = np.random.random_sample((10, 5)).astype(np.float64)
X = device_ndarray(X)

n_clusters = 3

centroids = np.random.random_sample((n_clusters + 1, 5)).astype(np.float64)
centroids = device_ndarray(centroids)

with pytest.raises(
RuntimeError, match="Parameters 'n_clusters' and 'centroids'"
):
init_plus_plus(X, n_clusters, centroids=centroids)