diff --git a/docs/references.bib b/docs/references.bib
index aa72b91df..f5a81da4c 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -479,3 +479,28 @@ @Article{higham:1997
   doi = "10.1023/A:1019150005407",
   url = "https://doi.org/10.1023/A:1019150005407"
 }
+
+@Article{lloyd:82,
+  author = {Lloyd, S.},
+  journal = {IEEE Transactions on Information Theory},
+  title = {Least squares quantization in PCM},
+  year = {1982},
+  volume = {28},
+  number = {2},
+  pages = {129--137},
+  doi = {10.1109/TIT.1982.1056489}
+}
+
+@InProceedings{arthur:07,
+  author = {Arthur, David and Vassilvitskii, Sergei},
+  title = {K-Means++: The Advantages of Careful Seeding},
+  year = {2007},
+  isbn = {9780898716245},
+  publisher = {Society for Industrial and Applied Mathematics},
+  address = {USA},
+  booktitle = {Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms},
+  pages = {1027--1035},
+  numpages = {9},
+  location = {New Orleans, Louisiana},
+  series = {SODA '07}
+}
diff --git a/docs/tools.rst b/docs/tools.rst
index d6128d613..2a890d52a 100644
--- a/docs/tools.rst
+++ b/docs/tools.rst
@@ -43,3 +43,11 @@ Soft Sorting Algorithms
     soft_sort.ranks
     soft_sort.sort
     soft_sort.sort_with
+
+Clustering
+----------
+.. autosummary::
+    :toctree: _autosummary
+
+    k_means.k_means
+    k_means.KMeansOutput
diff --git a/ott/core/fixed_point_loop.py b/ott/core/fixed_point_loop.py
index 7e2587fc4..e800ea824 100644
--- a/ott/core/fixed_point_loop.py
+++ b/ott/core/fixed_point_loop.py
@@ -119,7 +119,7 @@ def fixpoint_iter_fwd(
   """
   force_scan = (min_iterations == max_iterations)
   compute_error_flags = jnp.arange(inner_iterations) == inner_iterations - 1
-  states = jax.tree_map(
+  states = jax.tree_util.tree_map(
       lambda x: jnp.zeros((max_iterations // inner_iterations + 1,) + x.shape,
                           dtype=x.dtype), state
   )
@@ -176,7 +176,7 @@ def fixpoint_iter_bwd(
   force_scan = (min_iterations == max_iterations)
   constants, iteration, states = res
   # The tree may contain some python floats
-  g_constants = jax.tree_map(
+  g_constants = jax.tree_util.tree_map(
       lambda x: jnp.zeros_like(x, dtype=x.dtype)
       if isinstance(x, (np.ndarray, jnp.ndarray)) else 0, constants
   )
@@ -202,7 +202,9 @@ def one_iteration(iteration_state, compute_error):

   def unrolled_body_fn(iteration_g_gconst):
     iteration, g, g_constants = iteration_g_gconst
-    state = jax.tree_map(lambda x: x[iteration // inner_iterations], states)
+    state = jax.tree_util.tree_map(
+        lambda x: x[iteration // inner_iterations], states
+    )
     _, pullback = jax.vjp(
         unrolled_body_fn_no_errors, iteration, constants, state
     )
diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py
index e369c2c18..f98c573ec 100644
--- a/ott/geometry/costs.py
+++ b/ott/geometry/costs.py
@@ -54,9 +54,10 @@ def padder(cls, dim: int) -> jnp.ndarray:
     return jnp.zeros((1, dim))

   def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
-    return self.pairwise(
-        x, y
-    ) + (0 if self.norm is None else self.norm(x) + self.norm(y))
+    cost = self.pairwise(x, y)
+    if self.norm is None:
+      return cost
+    return cost + self.norm(x) + self.norm(y)

   def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
     """Compute matrix of all costs (including norms) for vectors in x / y.
@@ -99,7 +100,7 @@ def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:

   def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     """Compute minus twice the dot-product between vectors."""
-    return -2 * jnp.vdot(x, y)
+    return -2. * jnp.vdot(x, y)

   def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
     """Output barycenter of vectors when using squared-Euclidean distance."""
@@ -121,7 +122,8 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
     y_norm = jnp.linalg.norm(y, axis=-1)
     cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
     cosine_distance = 1.0 - cosine_similarity
-    return cosine_distance
+    # similarity is in [-1, 1], clip because of numerical imprecisions
+    return jnp.clip(cosine_distance, 0., 2.)

   def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
     raise NotImplementedError("Barycenter for cosine cost not yet implemented.")
diff --git a/ott/geometry/geometry.py b/ott/geometry/geometry.py
index 590a1b1ed..37c09156c 100644
--- a/ott/geometry/geometry.py
+++ b/ott/geometry/geometry.py
@@ -100,7 +100,7 @@ def __init__(
     self._kwargs = {**{'init': None, 'decay': None}, **kwargs}

   @property
-  def cost_rank(self) -> None:
+  def cost_rank(self) -> Optional[int]:
     """Output rank of cost matrix, if any was provided."""
     return None

@@ -198,7 +198,7 @@ def is_symmetric(self) -> bool:
   @property
   def inv_scale_cost(self) -> float:
     """Compute and return inverse of scaling factor for cost matrix."""
-    if isinstance(self._scale_cost, (int, float)):
+    if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
       return 1.0 / self._scale_cost
     self = self._masked_geom(mask_value=jnp.nan)
     if self._scale_cost == 'max_cost':
diff --git a/ott/geometry/grid.py b/ott/geometry/grid.py
index 8bbb9e742..ffdc449e7 100644
--- a/ott/geometry/grid.py
+++ b/ott/geometry/grid.py
@@ -304,10 +304,10 @@ def transport_from_scalings(
   ) -> NoReturn:
     """Not implemented, use :meth:`apply_transport_from_scalings` instead."""
     raise ValueError(
-        'Grid geometry cannot instantiate a transport matrix, use',
-        ' apply_transport_from_scalings(...) if you wish to ',
-        ' apply the transport matrix to a vector, or use a point '
-        ' cloud geometry instead'
+        'Grid geometry cannot instantiate a transport matrix, use '
+        'apply_transport_from_scalings(...) if you wish to '
+        'apply the transport matrix to a vector, or use a point '
+        'cloud geometry instead.'
     )

   def subset(
diff --git a/ott/geometry/low_rank.py b/ott/geometry/low_rank.py
index 426e792a0..4992519d0 100644
--- a/ott/geometry/low_rank.py
+++ b/ott/geometry/low_rank.py
@@ -100,7 +100,7 @@ def is_symmetric(self) -> bool:

   @property
   def inv_scale_cost(self) -> float:
-    if isinstance(self._scale_cost, (int, float)):
+    if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
       return 1.0 / self._scale_cost
     self = self._masked_geom()
     if self._scale_cost == 'max_bound':
diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py
index 06b8cc32a..e09dad5bb 100644
--- a/ott/geometry/pointcloud.py
+++ b/ott/geometry/pointcloud.py
@@ -131,9 +131,14 @@ def is_online(self) -> bool:
     is computed on-the-fly."""
     return self.batch_size is not None

+  # TODO(michalk8): when refactoring, consider PC as a subclass of LR?
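+  # the feature dimension doubles as the rank of the cost here; the k-means
+  # code added below relies on `cost_rank` to size its centroid arrays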
+  @property
+  def cost_rank(self) -> int:
+    return self.x.shape[1]
+
   @property
   def inv_scale_cost(self) -> float:
-    if isinstance(self._scale_cost, (int, float)):
+    if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
       return 1.0 / self._scale_cost
     self = self._masked_geom()
     if self._scale_cost == 'max_cost':
@@ -567,6 +572,17 @@ def tree_unflatten(cls, aux_data, children):
         x, y, cost_fn=cost_fn, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data
     )

+  def _cosine_to_sqeucl(self) -> 'PointCloud':
+    assert isinstance(self._cost_fn, costs.Cosine), type(self._cost_fn)
+    assert self.power == 2, self.power
+    (x, y, *args, _), aux_data = self.tree_flatten()
+    x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
+    y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
+    # TODO(michalk8): find a better way
+    aux_data["scale_cost"] = 2. / self.inv_scale_cost
+    cost_fn = costs.Euclidean()
+    return type(self).tree_unflatten(aux_data, [x, y] + args + [cost_fn])
+
   def to_LRCGeometry(
       self,
       scale: float = 1.0,
diff --git a/ott/tools/__init__.py b/ott/tools/__init__.py
index 8d86f7272..1dd1945e5 100644
--- a/ott/tools/__init__.py
+++ b/ott/tools/__init__.py
@@ -13,4 +13,4 @@
 # limitations under the License.
 """OTT tools: A set of tools to use OT in differentiable ML pipelines."""

-from . import gaussian_mixture, plot, sinkhorn_divergence, soft_sort, transport
+from . import gaussian_mixture, k_means, plot, sinkhorn_divergence, soft_sort, transport
diff --git a/ott/tools/k_means.py b/ott/tools/k_means.py
new file mode 100644
index 000000000..d9dd667ed
--- /dev/null
+++ b/ott/tools/k_means.py
@@ -0,0 +1,367 @@
+# Copyright 2022 The OTT Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+import math
+from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+from typing_extensions import Literal
+
+from ott.core import fixed_point_loop
+from ott.geometry import costs, pointcloud
+
+__all__ = ["k_means", "KMeansOutput"]
+
+Init_t = Union[Literal["k-means++", "random"],
+               Callable[[pointcloud.PointCloud, int, jnp.ndarray],
+                        jnp.ndarray]]
+
+
+class KPPState(NamedTuple):
+  key: jnp.ndarray
+  centroids: jnp.ndarray
+  centroid_dists: jnp.ndarray
+
+
+class KMeansState(NamedTuple):
+  centroids: jnp.ndarray
+  prev_assignment: jnp.ndarray
+  assignment: jnp.ndarray
+  errors: jnp.ndarray
+  center_shift: float
+
+
+class KMeansOutput(NamedTuple):
+  """Output of the :func:`~ott.tools.k_means.k_means` algorithm.
+
+  Args:
+    centroids: Array of shape ``[k, ndim]`` containing the centroids.
+    assignment: Array of shape ``[n,]`` containing the labels.
+    converged: Whether the algorithm has converged.
+    iteration: The number of iterations run.
+    error: (Weighted) sum of squared distances from each point to its closest
+      center.
+    inner_errors: Array of shape ``[max_iterations,]`` containing the ``error``
+      at every iteration.
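+
+  Note:
+    ``inner_errors`` is ``None`` unless ``store_inner_errors=True`` is passed
+    to :func:`~ott.tools.k_means.k_means`; ``converged`` is set once the
+    assignment stops changing or the decrease in ``error`` falls below
+    ``tol``.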
+ """ + centroids: jnp.ndarray + assignment: jnp.ndarray + converged: bool + iteration: int + error: float + inner_errors: Optional[jnp.ndarray] + + @classmethod + def _from_state( + cls, + state: KMeansState, + *, + tol: float, + store_inner_errors: bool = False + ) -> "KMeansOutput": + errs = state.errors + mask = errs == -1 + error = jnp.nanmin(jnp.where(mask, jnp.nan, errs)) + + assignment_same = jnp.all(state.prev_assignment == state.assignment) + tol_satisfied = jnp.logical_or(jnp.any(mask), (errs[-2] - errs[-1]) <= tol) + converged = jnp.logical_or(assignment_same, tol_satisfied) + + return cls( + centroids=state.centroids, + assignment=state.assignment, + converged=converged, + iteration=jnp.sum(~mask), + error=error, + inner_errors=errs if store_inner_errors else None, + ) + + +def _random_init( + geom: pointcloud.PointCloud, k: int, key: jnp.ndarray +) -> jnp.ndarray: + ixs = jnp.arange(geom.shape[0]) + ixs = jax.random.choice(key, ixs, shape=(k,), replace=False) + return geom.subset(ixs, None).x + + +def _k_means_plus_plus( + geom: pointcloud.PointCloud, + k: int, + key: jnp.ndarray, + n_local_trials: Optional[int] = None, +) -> jnp.ndarray: + + def init_fn(geom: pointcloud.PointCloud, key: jnp.ndarray) -> KPPState: + key, next_key = jax.random.split(key, 2) + ix = jax.random.choice(key, jnp.arange(geom.shape[0]), shape=()) + centroids = jnp.full((k, geom.cost_rank), jnp.inf).at[0].set(geom.x[ix]) + dists = geom.subset(ix, None).cost_matrix[0] + return KPPState(key=next_key, centroids=centroids, centroid_dists=dists) + + def body_fn( + iteration: int, const: Tuple[pointcloud.PointCloud, jnp.ndarray], + state: KPPState, compute_error: bool + ) -> KPPState: + del compute_error + key, next_key = jax.random.split(state.key, 2) + geom, ixs = const + + # no need to normalize when `replace=True` + probs = state.centroid_dists + ixs = jax.random.choice( + key, ixs, shape=(n_local_trials,), p=probs, replace=True + ) + geom = geom.subset(ixs, None) + + candidate_dists = jnp.minimum(geom.cost_matrix, state.centroid_dists) + best_ix = jnp.argmin(candidate_dists.sum(1)) + + centroids = state.centroids.at[iteration + 1].set(geom.x[best_ix]) + centroid_dists = candidate_dists[best_ix] + + return KPPState( + key=next_key, centroids=centroids, centroid_dists=centroid_dists + ) + + if n_local_trials is None: + n_local_trials = 2 + int(math.log(k)) + assert n_local_trials > 0, n_local_trials + + state = init_fn(geom, key) + constants = (geom, jnp.arange(geom.shape[0])) + state = fixed_point_loop.fixpoint_iter( + lambda *_, **__: True, + body_fn, + min_iterations=k - 1, + max_iterations=k - 1, + inner_iterations=1, + constants=constants, + state=state + ) + + return state.centroids + + +@functools.partial(jax.vmap, in_axes=[0] + [None] * 9) +def _k_means( + key: jnp.ndarray, + geom: pointcloud.PointCloud, + k: int, + weights: Optional[jnp.ndarray] = None, + init: Init_t = "k-means++", + n_local_trials: Optional[int] = None, + tol: float = 1e-4, + min_iterations: int = 0, + max_iterations: int = 300, + store_inner_errors: bool = False, +) -> KMeansOutput: + + def center_shift( + old_centroids: jnp.ndarray, new_centroids: jnp.ndarray + ) -> float: + return jnp.linalg.norm(old_centroids - new_centroids, ord="fro") ** 2 + + @functools.partial(jax.vmap, in_axes=[0, 0, 0], out_axes=0) + def reallocate_centroids( + ix: jnp.ndarray, + centroid: jnp.ndarray, + weight: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + is_empty = weight <= 0. 
+    new_centroid = (1 - is_empty) * centroid + is_empty * geom.x[ix]
+    centroid_to_remove = is_empty * weighted_x[ix, :-1]
+    weight_to_remove = is_empty * weights[ix]
+    return new_centroid, jnp.concatenate([centroid_to_remove, weight_to_remove])
+
+  def update_assignment(
+      centroids: jnp.ndarray
+  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    (x, _, *args), aux_data = geom.tree_flatten()
+    cost_matrix = pointcloud.PointCloud.tree_unflatten(
+        aux_data, [x, centroids] + args
+    ).cost_matrix
+
+    assignment = jnp.argmin(cost_matrix, axis=1)
+    dist_to_centers = cost_matrix[jnp.arange(len(assignment)), assignment]
+    return assignment, dist_to_centers
+
+  def update_centroids(
+      assignment: jnp.ndarray, dist_to_centers: jnp.ndarray
+  ) -> jnp.ndarray:
+    data = jax.ops.segment_sum(weighted_x, assignment, num_segments=k)
+    centroids, ws = data[:, :-1], data[:, -1:]
+
+    far_ixs = jnp.argsort(dist_to_centers)[-k:][::-1]
+    centroids, to_remove = reallocate_centroids(far_ixs, centroids, ws)
+    to_remove = jax.ops.segment_sum(
+        to_remove, assignment[far_ixs], num_segments=k
+    )
+    centroids -= to_remove[:, :-1]
+    ws -= to_remove[:, -1:]
+
+    return centroids * jnp.where(ws > 0., 1. / ws, 1.)
+
+  def init_fn(init: Init_t) -> KMeansState:
+    if init == "k-means++":
+      init = functools.partial(
+          _k_means_plus_plus, n_local_trials=n_local_trials
+      )
+    elif init == "random":
+      init = _random_init
+    if not callable(init):
+      raise TypeError(
+          f"Expected `init` to be 'k-means++', 'random' "
+          f"or a callable, found `{init!r}`."
+      )
+
+    centroids = init(geom, k, key)
+    if centroids.shape != (k, geom.cost_rank):
+      raise ValueError(
+          f"Expected initial centroids to have shape "
+          f"`{k, geom.cost_rank}`, found `{centroids.shape}`."
+      )
+    n = geom.shape[0]
+    prev_assignment = jnp.full((n,), -2)
+    assignment = jnp.full((n,), -1)
+    errors = jnp.full((max_iterations,), -1.)
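+    # `-1` marks iterations that never ran; `KMeansOutput._from_state` uses
+    # this sentinel to recover the iteration count and the final error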
+
+    return KMeansState(
+        centroids=centroids,
+        prev_assignment=prev_assignment,
+        assignment=assignment,
+        center_shift=jnp.inf,
+        errors=errors,
+    )
+
+  def cond_fn(iteration: int, const: Any, state: KMeansState) -> bool:
+    del const
+    assignment_not_same = jnp.any(state.prev_assignment != state.assignment)
+    tol_not_satisfied = state.center_shift > tol
+    return jnp.logical_and(assignment_not_same, tol_not_satisfied)
+
+  def body_fn(
+      iteration: int, const: Any, state: KMeansState, compute_error: bool
+  ) -> KMeansState:
+    del compute_error, const
+
+    assignment, dist_to_centers = update_assignment(state.centroids)
+    centroids = update_centroids(assignment, dist_to_centers)
+    err = jnp.sum(weights * dist_to_centers)
+    errors = state.errors.at[iteration].set(err)
+
+    return KMeansState(
+        centroids=centroids,
+        prev_assignment=state.assignment,
+        assignment=assignment,
+        center_shift=center_shift(state.centroids, centroids),
+        errors=errors,
+    )
+
+  def finalize_assignment(state: KMeansState) -> KMeansState:
+    last_iter = jnp.sum(state.errors != -1) - 1
+    assignment, dist_to_centers = update_assignment(state.centroids)
+    err = jnp.sum(weights * dist_to_centers)
+    return state._replace(
+        assignment=assignment, errors=state.errors.at[last_iter].set(err)
+    )
+
+  weighted_x = jnp.hstack([weights[:, None] * geom.x, weights[:, None]])
+  state = fixed_point_loop.fixpoint_iter(
+      cond_fn,
+      body_fn,
+      min_iterations=min_iterations,
+      max_iterations=max_iterations,
+      inner_iterations=1,
+      constants=None,
+      state=init_fn(init)
+  )
+  state = jax.lax.cond(
+      jnp.all(state.prev_assignment == state.assignment), lambda _: _,
+      finalize_assignment, state
+  )
+
+  return KMeansOutput._from_state(
+      state, tol=tol, store_inner_errors=store_inner_errors
+  )
+
+
+def k_means(
+    geom: Union[jnp.ndarray, pointcloud.PointCloud],
+    k: int,
+    weights: Optional[jnp.ndarray] = None,
+    init: Init_t = "k-means++",
+    n_init: int = 10,
+    n_local_trials: Optional[int] = None,
+    tol: float = 1e-4,
+    min_iterations: int = 0,
+    max_iterations: int = 300,
+    store_inner_errors: bool = False,
+    key: Optional[jnp.ndarray] = None,
+) -> KMeansOutput:
+  r"""K-means clustering using Lloyd's algorithm :cite:`lloyd:82`.
+
+  Args:
+    geom: Point cloud of shape ``[n, ndim]`` to cluster. If passed as an array,
+      :class:`~ott.geometry.costs.Euclidean` cost is assumed.
+    k: The number of clusters.
+    weights: The weights of input points. These weights are considered when
+      computing the centroids and inertia. If ``None``, use uniform weights.
+    init: Initialization method. Can be one of the following:
+
+      - **'k-means++'** - select initial centroids that are
+        :math:`\mathcal{O}(\log k)`-optimal :cite:`arthur:07`.
+      - **'random'** - randomly select ``k`` points from ``geom``.
+      - :func:`callable` - a function which takes the point cloud, the number
+        of clusters and a random key and returns the centroids as an array of
+        shape ``[k, ndim]``.
+
+    n_init: Number of times k-means will run with different initial seeds.
+    n_local_trials: Number of local trials when ``init = 'k-means++'``.
+      If ``None``, :math:`2 + \lfloor \log(k) \rfloor` is used.
+    tol: Relative tolerance with respect to the Frobenius norm of the
+      centroids' shift between two consecutive iterations.
+    min_iterations: Minimum number of iterations.
+    max_iterations: Maximum number of iterations.
+    store_inner_errors: Whether to store the errors (inertia) at each
+      iteration.
+    key: Random key to seed the initializations.
+
+  Returns:
+    The k-means clustering result.
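+
+  Example:
+    A minimal usage sketch, clustering random points with default settings::
+
+      import jax
+
+      from ott.tools import k_means
+
+      key = jax.random.PRNGKey(0)
+      x = jax.random.normal(key, shape=(100, 2))
+      out = k_means.k_means(x, k=3, key=key)
+      # out.centroids has shape [3, 2], out.assignment has shape [100,]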
+ """ + if isinstance(geom, jnp.ndarray): + geom = pointcloud.PointCloud(geom) + if isinstance(geom._cost_fn, costs.Cosine): + geom = geom._cosine_to_sqeucl() + assert geom.is_squared_euclidean + + if geom.is_online: + # to allow materializing the cost matrix + children, aux_data = geom.tree_flatten() + aux_data["batch_size"] = None + geom = type(geom).tree_unflatten(aux_data, children) + + if weights is None: + weights = jnp.ones(geom.shape[0]) + assert weights.shape == (geom.shape[0],) + + if key is None: + key = jax.random.PRNGKey(0) + keys = jax.random.split(key, n_init) + out = _k_means( + keys, geom, k, weights, init, n_local_trials, tol, min_iterations, + max_iterations, store_inner_errors + ) + best_ix = jnp.argmin(out.error) + return jax.tree_util.tree_map(lambda arr: arr[best_ix], out) diff --git a/setup.cfg b/setup.cfg index 78041ab90..f250cbbbf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,6 +50,7 @@ install_requires = [options.extras_require] test = + sklearn pytest pytest-xdist pytest-cov diff --git a/tests/geometry/geometry_pointcloud_apply_test.py b/tests/geometry/geometry_pointcloud_apply_test.py index bb48b7daf..bc31770a6 100644 --- a/tests/geometry/geometry_pointcloud_apply_test.py +++ b/tests/geometry/geometry_pointcloud_apply_test.py @@ -14,6 +14,7 @@ # Lint as: python3 """Tests for apply_cost and apply_kernel.""" +from typing import Union import jax import jax.numpy as jnp @@ -114,3 +115,59 @@ def test_apply_cost_without_norm(self, rng: jnp.ndarray, axis: 1): actual = pc.apply_cost(arr, axis=axis).squeeze() np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) + + +class TestPointCloudCosineConversion: + + @pytest.mark.parametrize( + "scale_cost", ["mean", "median", "max_cost", "max_norm", 41] + ) + def test_cosine_to_sqeucl_conversion( + self, rng: jnp.ndarray, scale_cost: Union[str, float] + ): + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, shape=(101, 4)) + y = jax.random.normal(key2, shape=(123, 4)) + cosine = pointcloud.PointCloud( + x, y, cost_fn=costs.Cosine(), scale_cost=scale_cost + ) + + eucl = cosine._cosine_to_sqeucl() + assert eucl.is_squared_euclidean + + np.testing.assert_allclose( + 2. 
* eucl.inv_scale_cost, cosine.inv_scale_cost, rtol=1e-6, atol=1e-6 + ) + np.testing.assert_allclose( + eucl.mean_cost_matrix, cosine.mean_cost_matrix, rtol=1e-6, atol=1e-6 + ) + np.testing.assert_allclose( + eucl.median_cost_matrix, + cosine.median_cost_matrix, + rtol=1e-6, + atol=1e-6 + ) + np.testing.assert_allclose( + eucl.cost_matrix, cosine.cost_matrix, rtol=1e-6, atol=1e-6 + ) + + @pytest.mark.parametrize( + "scale_cost", ["mean", "median", "max_cost", "max_norm", 2.0] + ) + @pytest.mark.parametrize("axis", [0, 1]) + def test_apply_cost_cosine_to_sqeucl( + self, rng: jnp.ndarray, axis: int, scale_cost: Union[str, float] + ): + key1, key2 = jax.random.split(rng, 2) + x = jax.random.normal(key1, shape=(17, 5)) + y = jax.random.normal(key2, shape=(12, 5)) + cosine = pointcloud.PointCloud( + x, y, cost_fn=costs.Cosine(), scale_cost=scale_cost + ) + eucl = cosine._cosine_to_sqeucl() + arr = jnp.ones((x.shape[0],)) if axis == 0 else jnp.ones((y.shape[0],)) + + expected = cosine.apply_cost(arr, axis=axis) + actual = eucl.apply_cost(arr, axis=axis) + + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py new file mode 100644 index 000000000..81b6b1c30 --- /dev/null +++ b/tests/tools/k_means_test.py @@ -0,0 +1,385 @@ +from typing import Any, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from sklearn import datasets +from sklearn.cluster import KMeans +from sklearn.cluster._k_means_common import _is_same_clustering +from sklearn.cluster._kmeans import kmeans_plusplus +from typing_extensions import Literal + +from ott.geometry import costs, pointcloud +from ott.tools import k_means + + +def make_blobs( + *args: Any, + cost_fn: Optional[Literal['sqeucl', 'cosine']] = None, + **kwargs: Any +) -> Tuple[Union[jnp.ndarray, pointcloud.PointCloud], jnp.ndarray, jnp.ndarray]: + X, y, c = datasets.make_blobs(*args, return_centers=True, **kwargs) + X, y, c = jnp.asarray(X), jnp.asarray(y), jnp.asarray(c) + if cost_fn is None: + pass + elif cost_fn == 'sqeucl': + X = pointcloud.PointCloud(X, cost_fn=costs.Euclidean()) + elif cost_fn == 'cosine': + X = pointcloud.PointCloud(X, cost_fn=costs.Cosine()) + else: + raise NotImplementedError(cost_fn) + + return X, y, c + + +def compute_assignment( + x: jnp.ndarray, + centers: jnp.ndarray, + weights: Optional[jnp.ndarray] = None +) -> Tuple[jnp.ndarray, float]: + if weights is None: + weights = jnp.ones(x.shape[0]) + cost_matrix = pointcloud.PointCloud(x, centers).cost_matrix + assignment = jnp.argmin(cost_matrix, axis=1) + dist_to_centers = cost_matrix[jnp.arange(len(assignment)), assignment] + + return assignment, jnp.sum(weights * dist_to_centers) + + +class TestKmeansPlusPlus: + + @pytest.mark.fast.with_args("n_local_trials", [None, 1, 5], only_fast=-1) + def test_n_local_trials(self, rng: jnp.ndarray, n_local_trials): + n, k = 150, 4 + key1, key2 = jax.random.split(rng) + geom, _, c = make_blobs( + n_samples=n, centers=k, cost_fn='sqeucl', random_state=0 + ) + centers1 = k_means._k_means_plus_plus(geom, k, key1, n_local_trials) + centers2 = k_means._k_means_plus_plus(geom, k, key2, 20) + + shift1 = jnp.linalg.norm(centers1 - c, ord="fro") ** 2 + shift2 = jnp.linalg.norm(centers2 - c, ord="fro") ** 2 + + assert shift1 > shift2 + + @pytest.mark.fast.with_args("k", [4, 5, 10], only_fast=0) + def test_matches_sklearn(self, rng: jnp.ndarray, k: int): + ndim = 2 + geom, _, _ = make_blobs( + n_samples=200, + centers=k, + 
n_features=ndim, + cost_fn='sqeucl', + random_state=0 + ) + gt_centers, _ = kmeans_plusplus(np.asarray(geom.x), k, random_state=1) + pred_centers = k_means._k_means_plus_plus(geom, k, rng) + + _, gt_inertia = compute_assignment(geom.x, gt_centers) + _, pred_inertia = compute_assignment(geom.x, pred_centers) + + assert pred_centers.shape == (k, ndim) + np.testing.assert_array_equal( + pred_centers.max(axis=0) <= geom.x.max(axis=0), True + ) + np.testing.assert_array_equal( + pred_centers.min(axis=0) >= geom.x.min(axis=0), True + ) + # the largest was 70.56378 + assert jnp.abs(pred_inertia - gt_inertia) <= 100 + + def test_initialization_differentiable(self, rng: jnp.ndarray): + + def callback(x: jnp.ndarray) -> float: + geom = pointcloud.PointCloud(x) + centers = k_means._k_means_plus_plus(geom, k=3, key=rng) + _, inertia = compute_assignment(x, centers) + return inertia + + X, _, _ = make_blobs(n_samples=34, random_state=0) + fun = jax.value_and_grad(callback) + ineria1, grad = fun(X) + ineria2, _ = fun(X - 0.1 * grad) + + assert ineria2 < ineria1 + + +class TestKmeans: + + @pytest.mark.fast + @pytest.mark.parametrize("k", [1, 6]) + def test_k_means_output(self, rng: jnp.ndarray, k: int): + max_iter, ndim = 10, 4 + geom, gt_assignment, _ = make_blobs( + n_samples=50, n_features=ndim, centers=k, random_state=42 + ) + gt_assignment = np.array(gt_assignment) + + res = k_means.k_means( + geom, k, max_iterations=max_iter, store_inner_errors=False, key=rng + ) + pred_assignment = np.array(res.assignment) + + assert res.centroids.shape == (k, ndim) + assert res.converged + assert res.error >= 0. + assert res.inner_errors is None + assert _is_same_clustering(pred_assignment, gt_assignment, k) + + @pytest.mark.fast + def test_k_means_simple_example(self): + expected_labels = np.asarray([1, 1, 0, 0], dtype=np.int32) + expected_centers = np.asarray([[0.75, 1], [0.25, 0]]) + + x = jnp.asarray([[0, 0], [0.5, 0], [0.5, 1], [1, 1]]) + init = lambda *_: jnp.array([[0.5, 0.5], [3, 3]]) + + res = k_means.k_means(x, k=2, init=init) + + np.testing.assert_array_equal(res.assignment, expected_labels) + np.testing.assert_allclose(res.centroids, expected_centers) + np.testing.assert_allclose(res.error, 0.25) + assert res.iteration == 3 + + @pytest.mark.fast.with_args( + "init", + ["k-means++", "random", "callable", "wrong-callable"], + only_fast=1, + ) + def test_init_method(self, rng: jnp.ndarray, init: str): + if init == "callable": + init_fn = lambda geom, k, _: geom.x[:k] + elif init == "wrong-callable": + init_fn = lambda geom, k, _: geom.x[:k + 1] + else: + init_fn = init + + k = 3 + geom, _, _ = make_blobs(n_samples=50, centers=k + 1) + if init == "wrong-callable": + with pytest.raises(ValueError, match=r"Expected initial centroids"): + _ = k_means.k_means(geom, k, init=init_fn) + else: + _ = k_means.k_means(geom, k, init=init_fn) + + def test_k_means_plus_plus_better_than_random(self, rng: jnp.ndarray): + k = 5 + key1, key2 = jax.random.split(rng, 2) + geom, _, _ = make_blobs(n_samples=50, centers=k, random_state=10) + + res_random = k_means.k_means(geom, k, init="random", key=key1) + res_kpp = k_means.k_means(geom, k, init="k-means++", key=key2) + + assert res_random.converged + assert res_kpp.converged + assert res_kpp.iteration < res_random.iteration + assert res_kpp.error <= res_random.error + + def test_larger_n_init_helps(self, rng: jnp.ndarray): + k = 10 + geom, _, _ = make_blobs(n_samples=150, centers=k, random_state=0) + + res = k_means.k_means(geom, k, n_init=3, key=rng) + res_larger_n_init = 
k_means.k_means(geom, k, n_init=20, key=rng) + + assert res_larger_n_init.error < res.error + + @pytest.mark.parametrize("max_iter", [8, 16]) + def test_store_inner_errors(self, rng: jnp.ndarray, max_iter: int): + ndim, k = 10, 4 + geom, _, _ = make_blobs( + n_samples=40, n_features=ndim, centers=k, random_state=43 + ) + + res = k_means.k_means( + geom, k, max_iterations=max_iter, store_inner_errors=True, key=rng + ) + + errors = res.inner_errors + assert errors.shape == (max_iter,) + assert res.iteration == jnp.sum(errors > 0.) + # check if error is decreasing + np.testing.assert_array_equal(jnp.diff(errors[::-1]) >= 0., True) + + def test_strict_tolerance(self, rng: jnp.ndarray): + k = 11 + geom, _, _ = make_blobs(n_samples=200, centers=k, random_state=39) + + res = k_means.k_means(geom, k=k, tol=1., key=rng) + res_strict = k_means.k_means(geom, k=k, tol=0., key=rng) + + assert res.converged + assert res_strict.converged + assert res.iteration < res_strict.iteration + + @pytest.mark.parametrize( + "tol", [1e-3, 0.], ids=["weak-convergence", "strict-convergence"] + ) + def test_convergence_force_scan(self, rng: jnp.ndarray, tol: float): + k, n_iter = 9, 20 + geom, _, _ = make_blobs(n_samples=100, centers=k, random_state=37) + + res = k_means.k_means( + geom, + k=k, + tol=tol, + min_iterations=n_iter, + max_iterations=n_iter, + store_inner_errors=True, + key=rng + ) + + assert res.converged + assert res.iteration == n_iter + np.testing.assert_array_equal(res.inner_errors == -1, False) + + def test_k_means_min_iterations(self, rng: jnp.ndarray): + k, min_iter = 8, 12 + geom, _, _ = make_blobs(n_samples=160, centers=k, random_state=38) + + res = k_means.k_means( + geom, + k - 2, + store_inner_errors=True, + min_iterations=min_iter, + max_iterations=20, + tol=0., + key=rng + ) + + assert res.converged + assert jnp.sum(res.inner_errors != -1) >= min_iter + + def test_weight_scaling_effects_only_inertia(self, rng: jnp.ndarray): + k = 10 + key1, key2 = jax.random.split(rng) + geom, _, _ = make_blobs(n_samples=130, centers=k, random_state=3) + weights = jnp.abs(jax.random.normal(key1, shape=(geom.shape[0],))) + weights_scaled = weights / jnp.sum(weights) + + res = k_means.k_means(geom, k=k - 1, weights=weights) + res_scaled = k_means.k_means(geom, k=k - 1, weights=weights_scaled) + + np.testing.assert_allclose( + res.centroids, res_scaled.centroids, rtol=1e-5, atol=1e-5 + ) + assert _is_same_clustering( + np.array(res.assignment), np.array(res_scaled.assignment), k + ) + np.testing.assert_allclose( + res.error, res_scaled.error * jnp.sum(weights), rtol=1e-3, atol=1e-3 + ) + + @pytest.mark.fast + def test_empty_weights(self, rng: jnp.ndarray): + n, ndim, k, d = 20, 2, 3, 5. + x = np.random.normal(size=(n, ndim)) + x[:, 0] += d + x[:, 1] += d + y = np.random.normal(size=(n, ndim)) + y[:, 0] -= d + y[:, 1] -= d + z = np.random.normal(size=(n, ndim)) + z[:, 0] += d + z[:, 1] -= d + w = np.random.normal(size=(n, ndim)) + w[:, 0] -= d + w[:, 1] += d + x = jnp.concatenate((x, y, z, w)) + # ignore `x` by setting its weights to 0 + weights = jnp.ones((x.shape[0],)).at[:n].set(0.) 
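+    # with the first block's weights zeroed out, the best k=3 centroids are
+    # the means of the three remaining blobs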
+
+    expected_centroids = jnp.stack([w.mean(0), z.mean(0), y.mean(0)])
+    res = k_means.k_means(x, k=k, weights=weights, key=rng)
+
+    cost = pointcloud.PointCloud(res.centroids, expected_centroids).cost_matrix
+    ixs = jnp.argmin(cost, axis=1)
+
+    np.testing.assert_array_equal(jnp.sort(ixs), jnp.arange(k))
+
+    total_shift = jnp.sum(cost[jnp.arange(k), ixs])
+    np.testing.assert_allclose(total_shift, 0., rtol=1e-3, atol=1e-3)
+
+  def test_cosine_cost_fn(self):
+    k = 4
+    geom, _, _ = make_blobs(n_samples=75)
+    geom_scaled = pointcloud.PointCloud(geom * 10., cost_fn=costs.Cosine())
+    geom = pointcloud.PointCloud(geom, cost_fn=costs.Cosine())
+
+    res_scaled = k_means.k_means(geom_scaled, k=k)
+    res = k_means.k_means(geom, k=k)
+
+    np.testing.assert_allclose(
+        res_scaled.error, res.error, rtol=1e-5, atol=1e-5
+    )
+    assert _is_same_clustering(
+        np.array(res_scaled.assignment), np.array(res.assignment), k
+    )
+
+  @pytest.mark.fast.with_args("init", ["k-means++", "random"], only_fast=0)
+  def test_k_means_jitting(
+      self, rng: jnp.ndarray, init: Literal["k-means++", "random"]
+  ):
+
+    def callback(x: jnp.ndarray) -> k_means.KMeansOutput:
+      return k_means.k_means(
+          x, k=k, init=init, store_inner_errors=True, key=rng
+      )
+
+    k = 7
+    x, _, _ = make_blobs(n_samples=150, centers=k, random_state=0)
+    res = jax.jit(callback)(x)
+    res_jit: k_means.KMeansOutput = jax.jit(callback)(x)
+
+    np.testing.assert_allclose(res.centroids, res_jit.centroids)
+    np.testing.assert_array_equal(res.assignment, res_jit.assignment)
+    np.testing.assert_allclose(res.error, res_jit.error)
+    np.testing.assert_allclose(res.inner_errors, res_jit.inner_errors)
+    assert res.iteration == res_jit.iteration
+    assert res.converged == res_jit.converged
+
+  def test_k_means_differentiability(self, rng: jnp.ndarray):
+
+    def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float:
+      return k_means.k_means(
+          x, k=k, weights=w, min_iterations=10, max_iterations=10, key=key1
+      ).error
+
+    k, eps, tol = 4, 1e-3, 1e-3
+    x, _, _ = make_blobs(n_samples=150, centers=k, random_state=41)
+    key1, key2, key3, key4 = jax.random.split(rng, 4)
+    w = jnp.abs(jax.random.normal(key2, (x.shape[0],)))
+
+    _, (grad_x, grad_w) = jax.value_and_grad(inertia, (0, 1))(x, w)
+
+    v_x = jax.random.normal(key3, shape=x.shape)
+    v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * eps
+    v_w = jax.random.normal(key4, shape=w.shape) * eps
+    v_w = (v_w / jnp.linalg.norm(v_w, axis=-1, keepdims=True)) * eps
+
+    expected = inertia(x + v_x, w) - inertia(x - v_x, w)
+    actual = 2 * jnp.vdot(v_x, grad_x)
+    np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)
+
+    expected = inertia(x, w + v_w) - inertia(x, w - v_w)
+    actual = 2 * jnp.vdot(v_w, grad_w)
+    np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)
+
+  @pytest.mark.parametrize("tol", [1e-3, 0.])
+  @pytest.mark.parametrize("n,k", [(37, 4), (128, 6)])
+  def test_clustering_matches_sklearn(
+      self, rng: jnp.ndarray, n: int, k: int, tol: float
+  ):
+    x, _, _ = make_blobs(n_samples=n, centers=k, random_state=41)
+
+    res_kmeans = KMeans(n_clusters=k, n_init=20, tol=tol, random_state=0).fit(x)
+    res_ours = k_means.k_means(x, k, n_init=20, tol=tol, key=rng)
+    gt_labels = res_kmeans.labels_
+    pred_labels = np.array(res_ours.assignment)
+
+    assert _is_same_clustering(pred_labels, gt_labels, k)
+    np.testing.assert_allclose(
+        res_ours.error, res_kmeans.inertia_, rtol=1e-3, atol=1e-3
+    )