From 615899ee7b4c898c6876f169c6b0bb72128d6104 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 27 Jan 2023 14:59:58 +0000 Subject: [PATCH 1/8] Expose KMeans init_plus_plus in pylibraft --- cpp/CMakeLists.txt | 2 + cpp/include/raft_runtime/cluster/kmeans.hpp | 10 ++++ .../cluster/kmeans_init_plus_plus_double.cu | 33 +++++++++++++ .../cluster/kmeans_init_plus_plus_float.cu | 33 +++++++++++++ .../pylibraft/pylibraft/cluster/__init__.py | 4 +- .../pylibraft/cluster/cpp/kmeans.pxd | 12 +++++ python/pylibraft/pylibraft/cluster/kmeans.pyx | 47 +++++++++++++++++++ .../pylibraft/pylibraft/test/test_kmeans.py | 16 +++++++ 8 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu create mode 100644 cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c6850b290f..4a445274b3 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -288,6 +288,8 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu + src/distance/cluster/kmeans_init_plus_plus_double.cu + src/distance/cluster/kmeans_init_plus_plus_float.cu src/distance/distance/specializations/detail/canberra.cu src/distance/distance/specializations/detail/chebyshev.cu src/distance/distance/specializations/detail/correlation.cu diff --git a/cpp/include/raft_runtime/cluster/kmeans.hpp b/cpp/include/raft_runtime/cluster/kmeans.hpp index 3386774414..aab8c14eab 100644 --- a/cpp/include/raft_runtime/cluster/kmeans.hpp +++ b/cpp/include/raft_runtime/cluster/kmeans.hpp @@ -66,6 +66,16 @@ void fit(raft::device_resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter); +void init_plus_plus(raft::device_resources const& handle, + const raft::cluster::kmeans::KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids); + +void init_plus_plus(raft::device_resources const& handle, + const raft::cluster::kmeans::KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids); + void cluster_cost(raft::device_resources const& handle, const float* X, int n_samples, diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu new file mode 100644 index 0000000000..d7ce35668a --- /dev/null +++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft::runtime::cluster::kmeans { + +void init_plus_plus(raft::device_resources const& handle, + const raft::cluster::kmeans::KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids) +{ + // XXX what should the size of this be? + rmm::device_uvector workspace(10, handle.get_stream()); + raft::cluster::kmeans::init_plus_plus( + handle, params, X, centroids, workspace); +} +} // namespace raft::runtime::cluster::kmeans diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu new file mode 100644 index 0000000000..ae3442f124 --- /dev/null +++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft::runtime::cluster::kmeans { + +void init_plus_plus(raft::device_resources const& handle, + const raft::cluster::kmeans::KMeansParams& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids) +{ + // XXX what should the size of this be? + rmm::device_uvector workspace(10, handle.get_stream()); + raft::cluster::kmeans::init_plus_plus( + handle, params, X, centroids, workspace); +} +} // namespace raft::runtime::cluster::kmeans diff --git a/python/pylibraft/pylibraft/cluster/__init__.py b/python/pylibraft/pylibraft/cluster/__init__.py index 4facc3dae2..730a3f287f 100644 --- a/python/pylibraft/pylibraft/cluster/__init__.py +++ b/python/pylibraft/pylibraft/cluster/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. # -from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit +from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit, init_plus_plus -__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit"] +__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit", "init_plus_plus"] diff --git a/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd b/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd index c43f18ac3f..4a5a47de68 100644 --- a/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd +++ b/python/pylibraft/pylibraft/cluster/cpp/kmeans.pxd @@ -75,6 +75,18 @@ cdef extern from "raft_runtime/cluster/kmeans.hpp" \ const double * centroids, double * cost) except + + cdef void init_plus_plus( + const device_resources & handle, + const KMeansParams& params, + device_matrix_view[float, int, row_major] X, + device_matrix_view[float, int, row_major] centroids) except + + + cdef void init_plus_plus( + const device_resources & handle, + const KMeansParams& params, + device_matrix_view[double, int, row_major] X, + device_matrix_view[double, int, row_major] centroids) except + + cdef void fit( const device_resources & handle, const KMeansParams& params, diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index 1d0b9ad241..cd31b34eaf 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -40,6 +40,7 @@ from pylibraft.cluster.cpp cimport kmeans as cpp_kmeans, kmeans_types from pylibraft.cluster.cpp.kmeans cimport ( cluster_cost as cpp_cluster_cost, update_centroids, + init_plus_plus as cpp_init_plus_plus, ) from pylibraft.common.cpp.mdspan cimport * from pylibraft.common.cpp.optional cimport optional @@ -199,6 +200,52 @@ def compute_new_centroids(X, raise ValueError("dtype %s not supported" % x_dt) +@auto_sync_handle +@auto_convert_output +def init_plus_plus(X, n_clusters, seed=None, handle=None): + cdef device_resources *h = handle.getHandle() + + # Can't set attributes of KMeansParameters after creating it, so taking + # a detour via a dict to collect the possible constructor arguments + params_ = dict(n_clusters=n_clusters) + if seed is not None: + params_["seed"] = seed + params = KMeansParams(**params_) + + X_cai = cai_wrapper(X) + X_cai.validate_shape_dtype(expected_dims=2) + dtype = X_cai.dtype + + centroids_shape = (n_clusters, X_cai.shape[1]) + centroids = device_ndarray.empty(centroids_shape, dtype=dtype) + centroids_cai = cai_wrapper(centroids) + + if dtype == np.float64: + cpp_init_plus_plus(deref(h), + params.c_obj, + make_device_matrix_view[double, int, row_major]( + X_cai.data, + X_cai.shape[0], X_cai.shape[1]), + make_device_matrix_view[double, int, row_major]( + centroids_cai.data, + centroids_cai.shape[0], centroids_cai.shape[1]), + ) + elif dtype == np.float32: + cpp_init_plus_plus(deref(h), + params.c_obj, + make_device_matrix_view[float, int, row_major]( + X_cai.data, + X_cai.shape[0], X_cai.shape[1]), + make_device_matrix_view[float, int, row_major]( + centroids_cai.data, + centroids_cai.shape[0], centroids_cai.shape[1]), + ) + else: + raise ValueError(f"Unhandled dtype ({dtype}) for X.") + + return centroids + + @auto_sync_handle @auto_convert_output def cluster_cost(X, centroids, handle=None): diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index 4c2388de62..dbe5d6030a 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -21,6 +21,7 @@ cluster_cost, compute_new_centroids, fit, + init_plus_plus, ) from pylibraft.common import DeviceResources, device_ndarray from pylibraft.distance import pairwise_distance @@ -147,3 +148,18 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): # need reduced tolerance for float32 tol = 1e-3 if dtype == np.float32 else 1e-6 assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) + + +@pytest.mark.parametrize("n_rows", [100]) +@pytest.mark.parametrize("n_cols", [5, 25]) +@pytest.mark.parametrize("n_clusters", [4, 15]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_init_plus_plus(n_rows, n_cols, n_clusters, dtype): + X = np.random.random_sample((n_rows, n_cols)).astype(dtype) + X_device = device_ndarray(X) + + centroids = init_plus_plus(X_device, n_clusters, seed=1) + + # Centroids are selected from the existing points + for centroid in centroids.copy_to_host(): + assert (centroid == X).all(axis=1).any() \ No newline at end of file From 2140e6e83960d9b97441f9e98b1130e82e871807 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jan 2023 08:01:21 +0000 Subject: [PATCH 2/8] Check centroids shape --- python/pylibraft/pylibraft/test/test_kmeans.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index dbe5d6030a..d2d1d12ac7 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -159,7 +159,10 @@ def test_init_plus_plus(n_rows, n_cols, n_clusters, dtype): X_device = device_ndarray(X) centroids = init_plus_plus(X_device, n_clusters, seed=1) + centroids_ = centroids.copy_to_host() + + assert centroids_.shape == (n_clusters, X.shape[1]) # Centroids are selected from the existing points - for centroid in centroids.copy_to_host(): + for centroid in centroids_: assert (centroid == X).all(axis=1).any() \ No newline at end of file From 3a54c2ed4d700cec334f3461fb9c1dc509cf2e5d Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 30 Jan 2023 09:03:47 +0000 Subject: [PATCH 3/8] Add optional pre-allocated centroids argument --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 27 ++++++++---- .../pylibraft/pylibraft/test/test_kmeans.py | 44 ++++++++++++++++--- 2 files changed, 55 insertions(+), 16 deletions(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index cd31b34eaf..04a76a96b6 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -202,9 +202,26 @@ def compute_new_centroids(X, @auto_sync_handle @auto_convert_output -def init_plus_plus(X, n_clusters, seed=None, handle=None): +def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None): + if n_clusters is not None and centroids is not None: + msg = ("Parameters 'n_clusters' and 'centroids' are exclusive. Only " + + "pass one at a time.") + raise RuntimeError(msg) + cdef device_resources *h = handle.getHandle() + X_cai = cai_wrapper(X) + X_cai.validate_shape_dtype(expected_dims=2) + dtype = X_cai.dtype + + if centroids is not None: + n_clusters = centroids.shape[0] + else: + centroids_shape = (n_clusters, X_cai.shape[1]) + centroids = device_ndarray.empty(centroids_shape, dtype=dtype) + + centroids_cai = cai_wrapper(centroids) + # Can't set attributes of KMeansParameters after creating it, so taking # a detour via a dict to collect the possible constructor arguments params_ = dict(n_clusters=n_clusters) @@ -212,14 +229,6 @@ def init_plus_plus(X, n_clusters, seed=None, handle=None): params_["seed"] = seed params = KMeansParams(**params_) - X_cai = cai_wrapper(X) - X_cai.validate_shape_dtype(expected_dims=2) - dtype = X_cai.dtype - - centroids_shape = (n_clusters, X_cai.shape[1]) - centroids = device_ndarray.empty(centroids_shape, dtype=dtype) - centroids_cai = cai_wrapper(centroids) - if dtype == np.float64: cpp_init_plus_plus(deref(h), params.c_obj, diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index d2d1d12ac7..399fb66064 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -82,9 +82,7 @@ def test_compute_new_centroids( new_centroids_device = device_ndarray(new_centroids) sample_weights = np.ones((n_rows,)).astype(dtype) / n_rows - sample_weights_device = ( - device_ndarray(sample_weights) if additional_args else None - ) + sample_weights_device = device_ndarray(sample_weights) if additional_args else None # Compute new centroids naively dists = np.zeros((n_rows, n_clusters), dtype=dtype) @@ -141,9 +139,7 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): ).copy_to_host() cluster_ids = np.argmin(distances, axis=1) - cluster_distances = np.take_along_axis( - distances, cluster_ids[:, None], axis=1 - ) + cluster_distances = np.take_along_axis(distances, cluster_ids[:, None], axis=1) # need reduced tolerance for float32 tol = 1e-3 if dtype == np.float32 else 1e-6 @@ -165,4 +161,38 @@ def test_init_plus_plus(n_rows, n_cols, n_clusters, dtype): # Centroids are selected from the existing points for centroid in centroids_: - assert (centroid == X).all(axis=1).any() \ No newline at end of file + assert (centroid == X).all(axis=1).any() + + +@pytest.mark.parametrize("n_rows", [100]) +@pytest.mark.parametrize("n_cols", [5, 25]) +@pytest.mark.parametrize("n_clusters", [4, 15]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_init_plus_plus_preallocated_output(n_rows, n_cols, n_clusters, dtype): + X = np.random.random_sample((n_rows, n_cols)).astype(dtype) + X_device = device_ndarray(X) + + centroids = device_ndarray.empty((n_clusters, n_cols), dtype=dtype) + + new_centroids = init_plus_plus(X_device, centroids=centroids, seed=1) + new_centroids_ = new_centroids.copy_to_host() + + # The shape should not have changed + assert new_centroids_.shape == centroids.shape + + # Centroids are selected from the existing points + for centroid in new_centroids_: + assert (centroid == X).all(axis=1).any() + + +def test_init_plus_plus_exclusive_arguments(): + X = np.random.random_sample((10, 5)).astype(np.float64) + X = device_ndarray(X) + + n_clusters = 3 + + centroids = np.random.random_sample((n_clusters, 5)).astype(np.float64) + centroids = device_ndarray(centroids) + + with pytest.raises(RuntimeError, match="Parameters 'n_clusters' and 'centroids'"): + init_plus_plus(X, n_clusters, centroids=centroids) From 2a8de873db5241949929282cd9798fafaf236cf9 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Tue, 31 Jan 2023 10:19:33 +0000 Subject: [PATCH 4/8] Add docstring --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index 04a76a96b6..6b72df6d35 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -203,6 +203,34 @@ def compute_new_centroids(X, @auto_sync_handle @auto_convert_output def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None): + """ + Compute initial centroids using the "kmeans++" algorithm. + + Parameters + ---------- + + X : Input CUDA array interface compliant matrix shape (m, k) + n_clusters : Number of clusters to select + seed : Controls the random sampling of centroids + centroids : Optional writable CUDA array interface compliant matrix shape + (n_clusters, k). Use instead of passing `n_clusters`. + {handle_docstring} + + Examples + -------- + + >>> import cupy as cp + >>> from pylibraft.cluster.kmeans import init_plus_plus + + >>> n_samples = 5000 + >>> n_features = 50 + >>> n_clusters = 3 + + >>> X = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> centroids = init_plus_plus(X, n_clusters) + """ if n_clusters is not None and centroids is not None: msg = ("Parameters 'n_clusters' and 'centroids' are exclusive. Only " + "pass one at a time.") From 318991cfd747aaf5cb31b5b8ac8be213b686c098 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 1 Feb 2023 08:37:48 +0000 Subject: [PATCH 5/8] Do not set a size for the temporary workspace --- cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu | 3 +-- cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu index d7ce35668a..b3dc60eaea 100644 --- a/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu +++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu @@ -25,8 +25,7 @@ void init_plus_plus(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids) { - // XXX what should the size of this be? - rmm::device_uvector workspace(10, handle.get_stream()); + rmm::device_uvector workspace(0, handle.get_stream()); raft::cluster::kmeans::init_plus_plus( handle, params, X, centroids, workspace); } diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu index ae3442f124..29eb74e9fb 100644 --- a/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu +++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu @@ -25,8 +25,7 @@ void init_plus_plus(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids) { - // XXX what should the size of this be? - rmm::device_uvector workspace(10, handle.get_stream()); + rmm::device_uvector workspace(0, handle.get_stream()); raft::cluster::kmeans::init_plus_plus( handle, params, X, centroids, workspace); } From 3a0853d2e9184644c98477ebeba96b91485622f0 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 1 Feb 2023 08:44:34 +0000 Subject: [PATCH 6/8] Only raise an exception when n_clusters and centroids shape are inconsistent --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 4 +++- python/pylibraft/pylibraft/test/test_kmeans.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index 6b72df6d35..d5fa07069e 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -231,7 +231,9 @@ def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None): >>> centroids = init_plus_plus(X, n_clusters) """ - if n_clusters is not None and centroids is not None: + if (n_clusters is not None and + centroids is not None and + n_clusters != centroids.shape[0]): msg = ("Parameters 'n_clusters' and 'centroids' are exclusive. Only " + "pass one at a time.") raise RuntimeError(msg) diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index 399fb66064..f990ace8b4 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -186,12 +186,14 @@ def test_init_plus_plus_preallocated_output(n_rows, n_cols, n_clusters, dtype): def test_init_plus_plus_exclusive_arguments(): + # Check an exception is raised when n_clusters and centroids shape + # are inconsistent. X = np.random.random_sample((10, 5)).astype(np.float64) X = device_ndarray(X) n_clusters = 3 - centroids = np.random.random_sample((n_clusters, 5)).astype(np.float64) + centroids = np.random.random_sample((n_clusters + 1, 5)).astype(np.float64) centroids = device_ndarray(centroids) with pytest.raises(RuntimeError, match="Parameters 'n_clusters' and 'centroids'"): From c4442ee80c9dd2272080fba3b68f543f91521e8b Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 1 Feb 2023 12:41:24 +0000 Subject: [PATCH 7/8] Fix linting --- .../cluster/kmeans_init_plus_plus_double.cu | 3 +-- .../cluster/kmeans_init_plus_plus_float.cu | 3 +-- python/pylibraft/pylibraft/cluster/__init__.py | 18 +++++++++++++++--- python/pylibraft/pylibraft/cluster/kmeans.pyx | 2 +- python/pylibraft/pylibraft/test/test_kmeans.py | 12 +++++++++--- 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu index b3dc60eaea..53132e13e7 100644 --- a/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu +++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu @@ -26,7 +26,6 @@ void init_plus_plus(raft::device_resources const& handle, raft::device_matrix_view centroids) { rmm::device_uvector workspace(0, handle.get_stream()); - raft::cluster::kmeans::init_plus_plus( - handle, params, X, centroids, workspace); + raft::cluster::kmeans::init_plus_plus(handle, params, X, centroids, workspace); } } // namespace raft::runtime::cluster::kmeans diff --git a/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu index 29eb74e9fb..814b0b41b8 100644 --- a/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu +++ b/cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu @@ -26,7 +26,6 @@ void init_plus_plus(raft::device_resources const& handle, raft::device_matrix_view centroids) { rmm::device_uvector workspace(0, handle.get_stream()); - raft::cluster::kmeans::init_plus_plus( - handle, params, X, centroids, workspace); + raft::cluster::kmeans::init_plus_plus(handle, params, X, centroids, workspace); } } // namespace raft::runtime::cluster::kmeans diff --git a/python/pylibraft/pylibraft/cluster/__init__.py b/python/pylibraft/pylibraft/cluster/__init__.py index 730a3f287f..00b12eab9b 100644 --- a/python/pylibraft/pylibraft/cluster/__init__.py +++ b/python/pylibraft/pylibraft/cluster/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,18 @@ # limitations under the License. # -from .kmeans import KMeansParams, cluster_cost, compute_new_centroids, fit, init_plus_plus +from .kmeans import ( + KMeansParams, + cluster_cost, + compute_new_centroids, + fit, + init_plus_plus, +) -__all__ = ["KMeansParams", "cluster_cost", "compute_new_centroids", "fit", "init_plus_plus"] +__all__ = [ + "KMeansParams", + "cluster_cost", + "compute_new_centroids", + "fit", + "init_plus_plus", +] diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index d5fa07069e..ac5bb65a35 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -39,8 +39,8 @@ from pylibraft.distance import DISTANCE_TYPES from pylibraft.cluster.cpp cimport kmeans as cpp_kmeans, kmeans_types from pylibraft.cluster.cpp.kmeans cimport ( cluster_cost as cpp_cluster_cost, - update_centroids, init_plus_plus as cpp_init_plus_plus, + update_centroids, ) from pylibraft.common.cpp.mdspan cimport * from pylibraft.common.cpp.optional cimport optional diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index f990ace8b4..8736c6ee7a 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -82,7 +82,9 @@ def test_compute_new_centroids( new_centroids_device = device_ndarray(new_centroids) sample_weights = np.ones((n_rows,)).astype(dtype) / n_rows - sample_weights_device = device_ndarray(sample_weights) if additional_args else None + sample_weights_device = ( + device_ndarray(sample_weights) if additional_args else None + ) # Compute new centroids naively dists = np.zeros((n_rows, n_clusters), dtype=dtype) @@ -139,7 +141,9 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): ).copy_to_host() cluster_ids = np.argmin(distances, axis=1) - cluster_distances = np.take_along_axis(distances, cluster_ids[:, None], axis=1) + cluster_distances = np.take_along_axis( + distances, cluster_ids[:, None], axis=1 + ) # need reduced tolerance for float32 tol = 1e-3 if dtype == np.float32 else 1e-6 @@ -196,5 +200,7 @@ def test_init_plus_plus_exclusive_arguments(): centroids = np.random.random_sample((n_clusters + 1, 5)).astype(np.float64) centroids = device_ndarray(centroids) - with pytest.raises(RuntimeError, match="Parameters 'n_clusters' and 'centroids'"): + with pytest.raises( + RuntimeError, match="Parameters 'n_clusters' and 'centroids'" + ): init_plus_plus(X, n_clusters, centroids=centroids) From 1a22ffccc73e4e938d0300faf5d5c45fdf637d7c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 3 Feb 2023 15:41:40 -0500 Subject: [PATCH 8/8] Fixing python style --- python/pylibraft/pylibraft/cluster/kmeans.pyx | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index ac5bb65a35..b61fb4ab02 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -232,10 +232,9 @@ def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None): >>> centroids = init_plus_plus(X, n_clusters) """ if (n_clusters is not None and - centroids is not None and - n_clusters != centroids.shape[0]): - msg = ("Parameters 'n_clusters' and 'centroids' are exclusive. Only " + - "pass one at a time.") + centroids is not None and n_clusters != centroids.shape[0]): + msg = ("Parameters 'n_clusters' and 'centroids' " + "are exclusive. Only pass one at a time.") raise RuntimeError(msg) cdef device_resources *h = handle.getHandle() @@ -260,24 +259,24 @@ def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None): params = KMeansParams(**params_) if dtype == np.float64: - cpp_init_plus_plus(deref(h), - params.c_obj, - make_device_matrix_view[double, int, row_major]( - X_cai.data, - X_cai.shape[0], X_cai.shape[1]), - make_device_matrix_view[double, int, row_major]( - centroids_cai.data, - centroids_cai.shape[0], centroids_cai.shape[1]), + cpp_init_plus_plus( + deref(h), params.c_obj, + make_device_matrix_view[double, int, row_major]( + X_cai.data, + X_cai.shape[0], X_cai.shape[1]), + make_device_matrix_view[double, int, row_major]( + centroids_cai.data, + centroids_cai.shape[0], centroids_cai.shape[1]), ) elif dtype == np.float32: - cpp_init_plus_plus(deref(h), - params.c_obj, - make_device_matrix_view[float, int, row_major]( - X_cai.data, - X_cai.shape[0], X_cai.shape[1]), - make_device_matrix_view[float, int, row_major]( - centroids_cai.data, - centroids_cai.shape[0], centroids_cai.shape[1]), + cpp_init_plus_plus( + deref(h), params.c_obj, + make_device_matrix_view[float, int, row_major]( + X_cai.data, + X_cai.shape[0], X_cai.shape[1]), + make_device_matrix_view[float, int, row_major]( + centroids_cai.data, + centroids_cai.shape[0], centroids_cai.shape[1]), ) else: raise ValueError(f"Unhandled dtype ({dtype}) for X.")