Skip to content

Commit

Permalink
Expose cluster_cost to python
Browse files Browse the repository at this point in the history
Add cython bindings for the cluster_cost function, to allow
computing inertia from python.

Closes rapidsai#972
  • Loading branch information
benfred committed Nov 17, 2022
1 parent 611abc7 commit 403e955
Show file tree
Hide file tree
Showing 10 changed files with 343 additions and 31 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/fused_l2_min_arg.cu
src/distance/update_centroids_float.cu
src/distance/update_centroids_double.cu
src/distance/cluster_cost_float.cu
src/distance/cluster_cost_double.cu
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
src/distance/specializations/detail/correlation.cu
Expand Down
17 changes: 16 additions & 1 deletion cpp/include/raft_distance/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,19 @@ void update_centroids(raft::handle_t const& handle,
double* new_centroids,
double* weight_per_cluster);

} // namespace raft::cluster::kmeans::runtime
void cluster_cost(raft::handle_t const& handle,
const float* X,
int n_samples,
int n_features,
int n_clusters,
const float* centroids,
float* cost);

void cluster_cost(raft::handle_t const& handle,
const double* X,
int n_samples,
int n_features,
int n_clusters,
const double* centroids,
double* cost);
} // namespace raft::cluster::kmeans::runtime
79 changes: 79 additions & 0 deletions cpp/src/distance/cluster_cost.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/handle.hpp>

namespace raft::cluster::kmeans::runtime {
template <typename ElementType, typename IndexType>
void cluster_cost(const raft::handle_t& handle,
const ElementType* X,
IndexType n_samples,
IndexType n_features,
IndexType n_clusters,
const ElementType* centroids,
ElementType* cost)
{
rmm::device_uvector<char> workspace(n_samples * sizeof(IndexType), handle.get_stream());

rmm::device_uvector<ElementType> x_norms(n_samples, handle.get_stream());
rmm::device_uvector<ElementType> centroid_norms(n_clusters, handle.get_stream());
raft::linalg::rowNorm(
x_norms.data(), X, n_samples, n_features, raft::linalg::L2Norm, true, handle.get_stream());
raft::linalg::rowNorm(centroid_norms.data(),
centroids,
n_clusters,
n_features,
raft::linalg::L2Norm,
true,
handle.get_stream());

auto min_cluster_distance =
raft::make_device_vector<raft::KeyValuePair<IndexType, ElementType>>(handle, n_samples);
raft::distance::fusedL2NNMinReduce(min_cluster_distance.data_handle(),
X,
centroids,
x_norms.data(),
centroid_norms.data(),
n_samples,
n_features,
n_clusters,
(void*)workspace.data(),
true,
true,
handle.get_stream());

auto distances = raft::make_device_vector<ElementType, IndexType>(handle, n_samples);
thrust::transform(
handle.get_thrust_policy(),
min_cluster_distance.data_handle(),
min_cluster_distance.data_handle() + n_samples,
distances.data_handle(),
[] __device__(const raft::KeyValuePair<IndexType, ElementType>& a) { return a.value; });

rmm::device_scalar<ElementType> device_cost(0, handle.get_stream());
raft::cluster::kmeans::cluster_cost(
handle,
distances.view(),
workspace,
make_device_scalar_view<ElementType>(device_cost.data()),
[] __device__(const ElementType& a, const ElementType& b) { return a + b; });

raft::update_host(cost, device_cost.data(), 1, handle.get_stream());
}
} // namespace raft::cluster::kmeans::runtime
34 changes: 34 additions & 0 deletions cpp/src/distance/cluster_cost_double.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cluster_cost.cuh"
#include <raft/distance/distance_types.hpp>
#include <raft/distance/specializations.cuh>
#include <raft/handle.hpp>

namespace raft::cluster::kmeans::runtime {

void cluster_cost(const raft::handle_t& handle,
const double* X,
int n_samples,
int n_features,
int n_clusters,
const double* centroids,
double* cost)
{
cluster_cost<double, int>(handle, X, n_samples, n_features, n_clusters, centroids, cost);
}
} // namespace raft::cluster::kmeans::runtime
34 changes: 34 additions & 0 deletions cpp/src/distance/cluster_cost_float.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cluster_cost.cuh"
#include <raft/distance/distance_types.hpp>
#include <raft/distance/specializations.cuh>
#include <raft/handle.hpp>

namespace raft::cluster::kmeans::runtime {

void cluster_cost(const raft::handle_t& handle,
const float* X,
int n_samples,
int n_features,
int n_clusters,
const float* centroids,
float* cost)
{
cluster_cost<float, int>(handle, X, n_samples, n_features, n_clusters, centroids, cost);
}
} // namespace raft::cluster::kmeans::runtime
123 changes: 94 additions & 29 deletions python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,46 +26,24 @@ from libcpp cimport bool, nullptr

from pylibraft.common import Handle
from pylibraft.common.handle import auto_sync_handle

from pylibraft.common.handle cimport handle_t

from pylibraft.common.input_validation import *
from pylibraft.distance import DISTANCE_TYPES

from pylibraft.cpp.kmeans cimport (
cluster_cost as cpp_cluster_cost,
update_centroids,
)


def is_c_cont(cai, dt):
return "strides" not in cai or \
cai["strides"] is None or \
cai["strides"][1] == dt.itemsize


cdef extern from "raft_distance/kmeans.hpp" \
namespace "raft::cluster::kmeans::runtime":

cdef void update_centroids(
const handle_t& handle,
const double *X,
int n_samples,
int n_features,
int n_clusters,
const double *sample_weights,
const double *centroids,
const int* labels,
double *new_centroids,
double *weight_per_cluster) except +

cdef void update_centroids(
const handle_t& handle,
const float *X,
int n_samples,
int n_features,
int n_clusters,
const float *sample_weights,
const float *centroids,
const int* labels,
float *new_centroids,
float *weight_per_cluster) except +


@auto_sync_handle
def compute_new_centroids(X,
centroids,
Expand Down Expand Up @@ -109,7 +87,6 @@ def compute_new_centroids(X,
from pylibraft.common import Handle
from pylibraft.cluster.kmeans import compute_new_centroids
from pylibraft.distance import fused_l2_nn_argmin
# A single RAFT handle can optionally be reused across
# pylibraft functions.
Expand Down Expand Up @@ -220,3 +197,91 @@ def compute_new_centroids(X,
<double*> weight_per_cluster_ptr)
else:
raise ValueError("dtype %s not supported" % x_dt)


@auto_sync_handle
def cluster_cost(X, centroids, handle=None):
"""
Compute cluster cost given an input matrix and existing centroids
Parameters
----------
X : Input CUDA array interface compliant matrix shape (m, k)
centroids : Input CUDA array interface compliant matrix shape
(n_clusters, k)
{handle_docstring}
Examples
--------
.. code-block:: python
import cupy as cp
from pylibraft.cluster.kmeans import cluster_cost
n_samples = 5000
n_features = 50
n_clusters = 3
X = cp.random.random_sample((n_samples, n_features),
dtype=cp.float32)
centroids = cp.random.random_sample((n_clusters, n_features),
dtype=cp.float32)
inertia = cluster_cost(X, centroids)
"""
x_cai = X.__cuda_array_interface__
centroids_cai = centroids.__cuda_array_interface__

m = x_cai["shape"][0]
x_k = x_cai["shape"][1]
n_clusters = centroids_cai["shape"][0]

centroids_k = centroids_cai["shape"][1]

x_dt = np.dtype(x_cai["typestr"])
centroids_dt = np.dtype(centroids_cai["typestr"])

if not do_cols_match(X, centroids):
raise ValueError("X and centroids must have same number of columns.")

x_ptr = <uintptr_t>x_cai["data"][0]
centroids_ptr = <uintptr_t>centroids_cai["data"][0]

handle = handle if handle is not None else Handle()
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

x_c_contiguous = is_c_cont(x_cai, x_dt)
centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt)

if not x_c_contiguous or not centroids_c_contiguous:
raise ValueError("Inputs must all be c contiguous")

if not do_dtypes_match(X, centroids):
raise ValueError("Inputs must all have the same dtypes "
"(float32 or float64)")

cdef float f_cost = 0
cdef double d_cost = 0

if x_dt == np.float32:
cpp_cluster_cost(deref(h),
<float*> x_ptr,
<int> m,
<int> x_k,
<int> n_clusters,
<float*> centroids_ptr,
<float*> &f_cost)
return f_cost
elif x_dt == np.float64:
cpp_cluster_cost(deref(h),
<double*> x_ptr,
<int> m,
<int> x_k,
<int> n_clusters,
<double*> centroids_ptr,
<double*> &d_cost)
return d_cost
else:
raise ValueError("dtype %s not supported" % x_dt)
Empty file.
Empty file.
Loading

0 comments on commit 403e955

Please sign in to comment.