Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HELP-REQ] Expose KMeans init_plus_plus in pylibraft #1198

Merged
27 changes: 18 additions & 9 deletions python/pylibraft/pylibraft/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -202,24 +202,33 @@ def compute_new_centroids(X,

@auto_sync_handle
@auto_convert_output
def init_plus_plus(X, n_clusters, seed=None, handle=None):
def init_plus_plus(X, n_clusters=None, seed=None, handle=None, centroids=None):
if n_clusters is not None and centroids is not None:
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
msg = ("Parameters 'n_clusters' and 'centroids' are exclusive. Only " +
"pass one at a time.")
raise RuntimeError(msg)

cdef device_resources *h = <device_resources*><size_t>handle.getHandle()

X_cai = cai_wrapper(X)
X_cai.validate_shape_dtype(expected_dims=2)
dtype = X_cai.dtype

if centroids is not None:
n_clusters = centroids.shape[0]
else:
centroids_shape = (n_clusters, X_cai.shape[1])
centroids = device_ndarray.empty(centroids_shape, dtype=dtype)

centroids_cai = cai_wrapper(centroids)

# Can't set attributes of KMeansParameters after creating it, so taking
# a detour via a dict to collect the possible constructor arguments
params_ = dict(n_clusters=n_clusters)
if seed is not None:
params_["seed"] = seed
params = KMeansParams(**params_)

X_cai = cai_wrapper(X)
X_cai.validate_shape_dtype(expected_dims=2)
dtype = X_cai.dtype

centroids_shape = (n_clusters, X_cai.shape[1])
centroids = device_ndarray.empty(centroids_shape, dtype=dtype)
centroids_cai = cai_wrapper(centroids)

if dtype == np.float64:
cpp_init_plus_plus(deref(h),
params.c_obj,
Expand Down
44 changes: 37 additions & 7 deletions python/pylibraft/pylibraft/test/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def test_compute_new_centroids(
new_centroids_device = device_ndarray(new_centroids)

sample_weights = np.ones((n_rows,)).astype(dtype) / n_rows
sample_weights_device = (
device_ndarray(sample_weights) if additional_args else None
)
sample_weights_device = device_ndarray(sample_weights) if additional_args else None

# Compute new centroids naively
dists = np.zeros((n_rows, n_clusters), dtype=dtype)
Expand Down Expand Up @@ -141,9 +139,7 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype):
).copy_to_host()
cluster_ids = np.argmin(distances, axis=1)

cluster_distances = np.take_along_axis(
distances, cluster_ids[:, None], axis=1
)
cluster_distances = np.take_along_axis(distances, cluster_ids[:, None], axis=1)

# need reduced tolerance for float32
tol = 1e-3 if dtype == np.float32 else 1e-6
Expand All @@ -165,4 +161,38 @@ def test_init_plus_plus(n_rows, n_cols, n_clusters, dtype):

# Centroids are selected from the existing points
for centroid in centroids_:
assert (centroid == X).all(axis=1).any()
assert (centroid == X).all(axis=1).any()


@pytest.mark.parametrize("n_rows", [100])
@pytest.mark.parametrize("n_cols", [5, 25])
@pytest.mark.parametrize("n_clusters", [4, 15])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_init_plus_plus_preallocated_output(n_rows, n_cols, n_clusters, dtype):
X = np.random.random_sample((n_rows, n_cols)).astype(dtype)
X_device = device_ndarray(X)

centroids = device_ndarray.empty((n_clusters, n_cols), dtype=dtype)

new_centroids = init_plus_plus(X_device, centroids=centroids, seed=1)
new_centroids_ = new_centroids.copy_to_host()

# The shape should not have changed
assert new_centroids_.shape == centroids.shape

# Centroids are selected from the existing points
for centroid in new_centroids_:
assert (centroid == X).all(axis=1).any()


def test_init_plus_plus_exclusive_arguments():
X = np.random.random_sample((10, 5)).astype(np.float64)
X = device_ndarray(X)

n_clusters = 3

centroids = np.random.random_sample((n_clusters, 5)).astype(np.float64)
centroids = device_ndarray(centroids)

with pytest.raises(RuntimeError, match="Parameters 'n_clusters' and 'centroids'"):
init_plus_plus(X, n_clusters, centroids=centroids)