diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index cf938a5d33..a7eb759d89 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -307,6 +307,8 @@ if(RAFT_COMPILE_DIST_LIBRARY)
     src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu
     src/distance/cluster/kmeans_fit_float.cu
     src/distance/cluster/kmeans_fit_double.cu
+    src/distance/cluster/kmeans_init_plus_plus_double.cu
+    src/distance/cluster/kmeans_init_plus_plus_float.cu
     src/distance/distance/specializations/detail/canberra_double_double_double_int.cu
     src/distance/distance/specializations/detail/canberra_float_float_float_int.cu
     src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu
diff --git a/cpp/include/raft_runtime/cluster/kmeans.hpp b/cpp/include/raft_runtime/cluster/kmeans.hpp
index 3386774414..aab8c14eab 100644
--- a/cpp/include/raft_runtime/cluster/kmeans.hpp
+++ b/cpp/include/raft_runtime/cluster/kmeans.hpp
@@ -66,6 +66,25 @@ void fit(raft::device_resources const& handle,
          raft::host_scalar_view inertia,
          raft::host_scalar_view n_iter);
 
+/**
+ * @brief Choose initial centroids with the "kmeans++" algorithm (float overload).
+ *
+ * @param[in]  handle    RAFT resources handle
+ * @param[in]  params    k-means parameters
+ * @param[in]  X         input samples, row-major
+ * @param[out] centroids row-major matrix that receives the selected centroids
+ */
+void init_plus_plus(raft::device_resources const& handle,
+                    const raft::cluster::kmeans::KMeansParams& params,
+                    raft::device_matrix_view<const float, int, raft::row_major> X,
+                    raft::device_matrix_view<float, int, raft::row_major> centroids);
+
+/** @brief Double-precision overload of init_plus_plus; same contract as the float overload. */
+void init_plus_plus(raft::device_resources const& handle,
+                    const raft::cluster::kmeans::KMeansParams& params,
+                    raft::device_matrix_view<const double, int, raft::row_major> X,
+                    raft::device_matrix_view<double, int, raft::row_major> centroids);
+
 void cluster_cost(raft::device_resources const& handle,
                   const float* X,
                   int n_samples,
diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu
new file mode 100644
index 0000000000..53132e13e7
--- /dev/null
+++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <raft/cluster/kmeans.cuh>
+#include <raft/core/device_mdarray.hpp>
+#include <raft/core/device_resources.hpp>
+
+namespace raft::runtime::cluster::kmeans {
+
+// Double-precision runtime entry point for k-means++ seeding; forwards to the
+// templated raft::cluster::kmeans::init_plus_plus.
+void init_plus_plus(raft::device_resources const& handle,
+                    const raft::cluster::kmeans::KMeansParams& params,
+                    raft::device_matrix_view<const double, int, raft::row_major> X,
+                    raft::device_matrix_view<double, int, raft::row_major> centroids)
+{
+  // Zero-sized scratch buffer on the handle's stream, passed through to the implementation.
+  rmm::device_uvector<char> workspace(0, handle.get_stream());
+  raft::cluster::kmeans::init_plus_plus(handle, params, X, centroids, workspace);
+}
+}  // namespace raft::runtime::cluster::kmeans
diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu
new file mode 100644
index 0000000000..814b0b41b8
--- /dev/null
+++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <raft/cluster/kmeans.cuh>
+#include <raft/core/device_mdarray.hpp>
+#include <raft/core/device_resources.hpp>
+
+namespace raft::runtime::cluster::kmeans {
+
+// Single-precision runtime entry point for k-means++ seeding; forwards to the
+// templated raft::cluster::kmeans::init_plus_plus.
+void init_plus_plus(raft::device_resources const& handle,
+                    const raft::cluster::kmeans::KMeansParams& params,
+                    raft::device_matrix_view<const float, int, raft::row_major> X,
+                    raft::device_matrix_view<float, int, raft::row_major> centroids)
+{
+  // Zero-sized scratch buffer on the handle's stream, passed through to the implementation.
+  rmm::device_uvector<char> workspace(0, handle.get_stream());
+  raft::cluster::kmeans::init_plus_plus(handle, params, X, centroids, workspace);
+}
+}  // namespace raft::runtime::cluster::kmeans
diff --git a/python/pylibraft/pylibraft/cluster/__init__.py b/python/pylibraft/pylibraft/cluster/__init__.py
index 4facc3dae2..00b12eab9b 100644
--- a/python/pylibraft/pylibraft/cluster/__init__.py
+++ b/python/pylibraft/pylibraft/cluster/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022, NVIDIA CORPORATION.
+# Copyright (c) 2022-2023, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +13,18 @@
 # limitations under the License.
 #
 
-from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit
+from .kmeans import (
+    KMeansParams,
+    cluster_cost,
+    compute_new_centroids,
+    fit,
+    init_plus_plus,
+)
 
-__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit"]
+__all__ = [
+    "KMeansParams",
+    "cluster_cost",
+    "compute_new_centroids",
+    "fit",
+    "init_plus_plus",
+]
diff --git a/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd b/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd
index c43f18ac3f..4a5a47de68 100644
--- a/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd
+++ b/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd
@@ -75,6 +75,18 @@ cdef extern from "raft_runtime/cluster/kmeans.hpp" \
         const double * centroids,
         double * cost) except +
 
+    cdef void init_plus_plus(
+        const device_resources & handle,
+        const KMeansParams& params,
+        device_matrix_view[float, int, row_major] X,
+        device_matrix_view[float, int, row_major] centroids) except +
+
+    cdef void init_plus_plus(
+        const device_resources & handle,
+        const KMeansParams& params,
+        device_matrix_view[double, int, row_major] X,
+        device_matrix_view[double, int, row_major] centroids) except +
+
     cdef void fit(
         const device_resources & handle,
         const KMeansParams& params,
diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx
index 1d0b9ad241..b61fb4ab02 100644
--- a/python/pylibraft/pylibraft/cluster/kmeans.pyx
+++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx
@@ -39,6 +39,7 @@ from pylibraft.distance import DISTANCE_TYPES
 from pylibraft.cluster.cpp cimport kmeans as cpp_kmeans, kmeans_types
 from pylibraft.cluster.cpp.kmeans cimport (
     cluster_cost as cpp_cluster_cost,
+    init_plus_plus as cpp_init_plus_plus,
     update_centroids,
 )
 from pylibraft.common.cpp.mdspan cimport *
@@ -199,6 +200,116 @@ def compute_new_centroids(X,
         raise ValueError("dtype %s not supported" % x_dt)
 
 
+@auto_sync_handle
+@auto_convert_output
+def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None):
+    """
+    Compute initial centroids using the "kmeans++" algorithm.
+
+    Parameters
+    ----------
+
+    X : Input CUDA array interface compliant matrix shape (m, k)
+    n_clusters : Number of clusters to select. Required unless `centroids`
+                 is passed.
+    seed : Controls the random sampling of centroids
+    centroids : Optional writable CUDA array interface compliant matrix shape
+                (n_clusters, k). Use instead of passing `n_clusters`. Must
+                have the same dtype and number of columns as `X`.
+    {handle_docstring}
+
+    Returns
+    -------
+
+    centroids : The selected centroids, shape (n_clusters, k). This is the
+                `centroids` argument when one was passed, otherwise a newly
+                allocated array.
+
+    Examples
+    --------
+
+    >>> import cupy as cp
+    >>> from pylibraft.cluster.kmeans import init_plus_plus
+
+    >>> n_samples = 5000
+    >>> n_features = 50
+    >>> n_clusters = 3
+
+    >>> X = cp.random.random_sample((n_samples, n_features),
+    ...                             dtype=cp.float32)
+
+    >>> centroids = init_plus_plus(X, n_clusters)
+    """
+    # Without either argument the shape of the output is unknown.
+    if n_clusters is None and centroids is None:
+        raise ValueError("Either 'n_clusters' or 'centroids' must be passed.")
+
+    if (n_clusters is not None and
+            centroids is not None and n_clusters != centroids.shape[0]):
+        msg = ("Parameters 'n_clusters' and 'centroids' "
+               "are exclusive. Only pass one at a time.")
+        raise RuntimeError(msg)
+
+    cdef device_resources *h = <device_resources*><size_t>handle.getHandle()
+
+    X_cai = cai_wrapper(X)
+    X_cai.validate_shape_dtype(expected_dims=2)
+    dtype = X_cai.dtype
+
+    if centroids is not None:
+        n_clusters = centroids.shape[0]
+    else:
+        centroids_shape = (n_clusters, X_cai.shape[1])
+        centroids = device_ndarray.empty(centroids_shape, dtype=dtype)
+
+    centroids_cai = cai_wrapper(centroids)
+    centroids_cai.validate_shape_dtype(expected_dims=2)
+
+    # The dispatch below views the centroids buffer with X's element type,
+    # and the centroids are rows picked from X, so a caller-provided output
+    # with another dtype or width must be rejected before reaching C++.
+    if centroids_cai.dtype != dtype:
+        raise ValueError(
+            f"dtype of centroids ({centroids_cai.dtype}) does not match "
+            f"dtype of X ({dtype}).")
+    if centroids_cai.shape[1] != X_cai.shape[1]:
+        raise ValueError(
+            f"centroids has {centroids_cai.shape[1]} columns but X has "
+            f"{X_cai.shape[1]}.")
+
+    # Can't set attributes of KMeansParameters after creating it, so taking
+    # a detour via a dict to collect the possible constructor arguments
+    params_ = dict(n_clusters=n_clusters)
+    if seed is not None:
+        params_["seed"] = seed
+    params = KMeansParams(**params_)
+
+    if dtype == np.float64:
+        cpp_init_plus_plus(
+            deref(h), params.c_obj,
+            make_device_matrix_view[double, int, row_major](
+                <double *><uintptr_t>X_cai.data,
+                <int>X_cai.shape[0], <int>X_cai.shape[1]),
+            make_device_matrix_view[double, int, row_major](
+                <double *><uintptr_t>centroids_cai.data,
+                <int>centroids_cai.shape[0], <int>centroids_cai.shape[1]),
+        )
+    elif dtype == np.float32:
+        cpp_init_plus_plus(
+            deref(h), params.c_obj,
+            make_device_matrix_view[float, int, row_major](
+                <float *><uintptr_t>X_cai.data,
+                <int>X_cai.shape[0], <int>X_cai.shape[1]),
+            make_device_matrix_view[float, int, row_major](
+                <float *><uintptr_t>centroids_cai.data,
+                <int>centroids_cai.shape[0], <int>centroids_cai.shape[1]),
+        )
+    else:
+        raise ValueError(f"Unhandled dtype ({dtype}) for X.")
+
+    return centroids
+
+
 @auto_sync_handle
 @auto_convert_output
 def cluster_cost(X, centroids, handle=None):
diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py
index 4c2388de62..8736c6ee7a 100644
--- a/python/pylibraft/pylibraft/test/test_kmeans.py
+++ b/python/pylibraft/pylibraft/test/test_kmeans.py
@@ -21,6 +21,7 @@
     cluster_cost,
     compute_new_centroids,
     fit,
+    init_plus_plus,
 )
 from pylibraft.common import DeviceResources, device_ndarray
 from pylibraft.distance import pairwise_distance
@@ -147,3 +148,79 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype):
     # need reduced tolerance for float32
     tol = 1e-3 if dtype == np.float32 else 1e-6
     assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol)
+
+
+@pytest.mark.parametrize("n_rows", [100])
+@pytest.mark.parametrize("n_cols", [5, 25])
+@pytest.mark.parametrize("n_clusters", [4, 15])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_init_plus_plus(n_rows, n_cols, n_clusters, dtype):
+    X = np.random.random_sample((n_rows, n_cols)).astype(dtype)
+    X_device = device_ndarray(X)
+
+    centroids = init_plus_plus(X_device, n_clusters, seed=1)
+    centroids_ = centroids.copy_to_host()
+
+    assert centroids_.shape == (n_clusters, X.shape[1])
+
+    # Centroids are selected from the existing points
+    for centroid in centroids_:
+        assert (centroid == X).all(axis=1).any()
+
+
+@pytest.mark.parametrize("n_rows", [100])
+@pytest.mark.parametrize("n_cols", [5, 25])
+@pytest.mark.parametrize("n_clusters", [4, 15])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_init_plus_plus_preallocated_output(n_rows, n_cols, n_clusters, dtype):
+    X = np.random.random_sample((n_rows, n_cols)).astype(dtype)
+    X_device = device_ndarray(X)
+
+    centroids = device_ndarray.empty((n_clusters, n_cols), dtype=dtype)
+
+    new_centroids = init_plus_plus(X_device, centroids=centroids, seed=1)
+    new_centroids_ = new_centroids.copy_to_host()
+
+    # The shape should not have changed
+    assert new_centroids_.shape == centroids.shape
+
+    # Centroids are selected from the existing points
+    for centroid in new_centroids_:
+        assert (centroid == X).all(axis=1).any()
+
+
+def test_init_plus_plus_exclusive_arguments():
+    # Check an exception is raised when n_clusters and centroids shape
+    # are inconsistent.
+    X = np.random.random_sample((10, 5)).astype(np.float64)
+    X = device_ndarray(X)
+
+    n_clusters = 3
+
+    centroids = np.random.random_sample((n_clusters + 1, 5)).astype(np.float64)
+    centroids = device_ndarray(centroids)
+
+    with pytest.raises(
+        RuntimeError, match="Parameters 'n_clusters' and 'centroids'"
+    ):
+        init_plus_plus(X, n_clusters, centroids=centroids)
+
+
+def test_init_plus_plus_missing_arguments():
+    # Without n_clusters or centroids the output shape is unknown.
+    X = np.random.random_sample((10, 5)).astype(np.float64)
+    X = device_ndarray(X)
+
+    with pytest.raises(ValueError, match="'n_clusters' or 'centroids'"):
+        init_plus_plus(X)
+
+
+def test_init_plus_plus_mismatched_centroids_dtype():
+    # A preallocated output must have the same dtype as X.
+    X = np.random.random_sample((10, 5)).astype(np.float64)
+    X = device_ndarray(X)
+
+    centroids = device_ndarray.empty((3, 5), dtype=np.float32)
+
+    with pytest.raises(ValueError, match="dtype of centroids"):
+        init_plus_plus(X, centroids=centroids)