Skip to content

Commit

Permalink
Feature/kmeans++ (#120)
Browse files Browse the repository at this point in the history
* Add initial impl. from `CR.Sparse`

* Rename file, add to __init__

* Initial fixed point iteration

* Add random initialization

* Better KMeansState

* Fix `cond_fn`

* First working version

* Clean output, use `tree_map`

* Remove reference impl. and dead code

* Add TODO

* Expose `cost_rank` in `PointCloud`

* Add initial kmeans++ implementation

* Fix indexing bug

* Remove `set` methods

* Add tolerance to convergence check

* Rename function

* Store inner errors

* Unify kmeans initializer interface, allow custom

* Reorder arguments

* Add strict convergence criterion

* Add convergence iteration to output

* Simplify `cond_fn`, use `max_iter - 1`

* Clip cosine distance to `[0, 2]`

* Require sqEucl geometry, allow arrays

* Fix random/kmeans++ init

* Remove normalization comment

* Fix dividing by 0 when using weights

* Add TODOs

* Fix k-means++ init centroid

* Use `jax.tree_util.tree_map`

* Rename arguments, use sum instead of mean

* Switch order

* Fix `unique_indices=True` in segment sum

* Don't compute assignment in `init_fn`

* Fix weighting

* Use center shift as convergence criterion

* Remove old TODOs

* Improve final assignment

* Fix centroid/weight adjustment

* [ci skip] Allow geometry with cosine cost

* Fix cosine conversion, add test

* Add more cosine -> sqeucl conversion tests

* Add documentation

* Add skeleton tests

* Add kmeans++ tests

* Add k-means initialization test

* Fix bug when removing empty centroids

* Finish tests

* Increase tolerance

* Address comments

* Use smaller eps
  • Loading branch information
michalk8 authored Aug 10, 2022
1 parent 22d3929 commit d7521fd
Show file tree
Hide file tree
Showing 13 changed files with 880 additions and 17 deletions.
25 changes: 25 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,28 @@ @Article{higham:1997
doi = "10.1023/A:1019150005407",
url = "https://doi.org/10.1023/A:1019150005407"
}

@Article{lloyd:82,
author={Lloyd, S.},
journal={IEEE Transactions on Information Theory},
title={Least squares quantization in PCM},
year={1982},
volume={28},
number={2},
pages={129-137},
doi={10.1109/TIT.1982.1056489}
}

@inproceedings{arthur:07,
author = {Arthur, David and Vassilvitskii, Sergei},
title = {K-Means++: The Advantages of Careful Seeding},
year = {2007},
isbn = {9780898716245},
publisher = {Society for Industrial and Applied Mathematics},
address = {USA},
booktitle = {Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms},
pages = {1027--1035},
numpages = {9},
location = {New Orleans, Louisiana},
series = {SODA '07}
}
8 changes: 8 additions & 0 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ Soft Sorting Algorithms
soft_sort.ranks
soft_sort.sort
soft_sort.sort_with

Clustering
----------
.. autosummary::
:toctree: _autosummary

k_means.k_means
k_means.KMeansOutput
8 changes: 5 additions & 3 deletions ott/core/fixed_point_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def fixpoint_iter_fwd(
"""
force_scan = (min_iterations == max_iterations)
compute_error_flags = jnp.arange(inner_iterations) == inner_iterations - 1
states = jax.tree_map(
states = jax.tree_util.tree_map(
lambda x: jnp.zeros((max_iterations // inner_iterations + 1,) + x.shape,
dtype=x.dtype), state
)
Expand Down Expand Up @@ -176,7 +176,7 @@ def fixpoint_iter_bwd(
force_scan = (min_iterations == max_iterations)
constants, iteration, states = res
# The tree may contain some python floats
g_constants = jax.tree_map(
g_constants = jax.tree_util.tree_map(
lambda x: jnp.zeros_like(x, dtype=x.dtype)
if isinstance(x, (np.ndarray, jnp.ndarray)) else 0, constants
)
Expand All @@ -202,7 +202,9 @@ def one_iteration(iteration_state, compute_error):

def unrolled_body_fn(iteration_g_gconst):
iteration, g, g_constants = iteration_g_gconst
state = jax.tree_map(lambda x: x[iteration // inner_iterations], states)
state = jax.tree_util.tree_map(
lambda x: x[iteration // inner_iterations], states
)
_, pullback = jax.vjp(
unrolled_body_fn_no_errors, iteration, constants, state
)
Expand Down
12 changes: 7 additions & 5 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ def padder(cls, dim: int) -> jnp.ndarray:
return jnp.zeros((1, dim))

def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
return self.pairwise(
x, y
) + (0 if self.norm is None else self.norm(x) + self.norm(y))
cost = self.pairwise(x, y)
if self.norm is None:
return cost
return cost + self.norm(x) + self.norm(y)

def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Compute matrix of all costs (including norms) for vectors in x / y.
Expand Down Expand Up @@ -99,7 +100,7 @@ def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute minus twice the dot-product between vectors."""
return -2 * jnp.vdot(x, y)
return -2. * jnp.vdot(x, y)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
"""Output barycenter of vectors when using squared-Euclidean distance."""
Expand All @@ -121,7 +122,8 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
y_norm = jnp.linalg.norm(y, axis=-1)
cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
cosine_distance = 1.0 - cosine_similarity
return cosine_distance
# similarity is in [-1, 1], clip because of numerical imprecisions
return jnp.clip(cosine_distance, 0., 2.)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
  """Barycenter under the cosine cost; currently unsupported.

  Args:
    weights: Unused.
    xs: Unused.

  Raises:
    NotImplementedError: Always, no barycenter is implemented for this cost.
  """
  message = "Barycenter for cosine cost not yet implemented."
  raise NotImplementedError(message)
Expand Down
4 changes: 2 additions & 2 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
self._kwargs = {**{'init': None, 'decay': None}, **kwargs}

@property
def cost_rank(self) -> None:
def cost_rank(self) -> Optional[int]:
"""Output rank of cost matrix, if any was provided."""
return None

Expand Down Expand Up @@ -198,7 +198,7 @@ def is_symmetric(self) -> bool:
@property
def inv_scale_cost(self) -> float:
"""Compute and return inverse of scaling factor for cost matrix."""
if isinstance(self._scale_cost, (int, float)):
if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
return 1.0 / self._scale_cost
self = self._masked_geom(mask_value=jnp.nan)
if self._scale_cost == 'max_cost':
Expand Down
8 changes: 4 additions & 4 deletions ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,10 @@ def transport_from_scalings(
) -> NoReturn:
"""Not implemented, use :meth:`apply_transport_from_scalings` instead."""
raise ValueError(
'Grid geometry cannot instantiate a transport matrix, use',
' apply_transport_from_scalings(...) if you wish to ',
' apply the transport matrix to a vector, or use a point '
' cloud geometry instead'
'Grid geometry cannot instantiate a transport matrix, use ',
'apply_transport_from_scalings(...) if you wish to ',
'apply the transport matrix to a vector, or use a point '
'cloud geometry instead.'
)

def subset(
Expand Down
2 changes: 1 addition & 1 deletion ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def is_symmetric(self) -> bool:

@property
def inv_scale_cost(self) -> float:
if isinstance(self._scale_cost, (int, float)):
if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
return 1.0 / self._scale_cost
self = self._masked_geom()
if self._scale_cost == 'max_bound':
Expand Down
18 changes: 17 additions & 1 deletion ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,14 @@ def is_online(self) -> bool:
is computed on-the-fly."""
return self.batch_size is not None

# TODO(michalk8): when refactoring, consider PC as a subclass of LR?
@property
def cost_rank(self) -> int:
  """Rank of the cost matrix, taken as the size of the second axis of ``x``."""
  feature_dim = self.x.shape[1]
  return feature_dim

@property
def inv_scale_cost(self) -> float:
if isinstance(self._scale_cost, (int, float)):
if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
return 1.0 / self._scale_cost
self = self._masked_geom()
if self._scale_cost == 'max_cost':
Expand Down Expand Up @@ -567,6 +572,17 @@ def tree_unflatten(cls, aux_data, children):
x, y, cost_fn=cost_fn, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data
)

def _cosine_to_sqeucl(self) -> 'PointCloud':
  """Rebuild this cosine-cost point cloud with a Euclidean cost function.

  Both point sets are divided by their norms along the last axis, the stored
  cost function is replaced by ``costs.Euclidean()`` and ``scale_cost`` is set
  to ``2 / inv_scale_cost``. For unit-norm vectors the squared Euclidean
  distance equals twice the cosine distance, which is presumably why the
  factor of 2 is folded into the scale (confirm against callers).

  Only valid when the current cost is :class:`costs.Cosine` and ``power`` is 2;
  both conditions are checked with assertions.

  Returns:
    A new geometry of the same type as ``self``.
  """
  assert isinstance(self._cost_fn, costs.Cosine), type(self._cost_fn)
  assert self.power == 2, self.power

  def _unit_rows(points):
    # Normalise along the last axis; keepdims so the division broadcasts.
    return points / jnp.linalg.norm(points, axis=-1, keepdims=True)

  children, aux_data = self.tree_flatten()
  # The last child is the current cost function, which gets replaced below.
  x, y, *remaining, _ = children
  x = _unit_rows(x)
  y = _unit_rows(y)
  # TODO(michalk8): find a better way
  aux_data["scale_cost"] = 2. / self.inv_scale_cost
  new_children = [x, y, *remaining, costs.Euclidean()]
  return type(self).tree_unflatten(aux_data, new_children)

def to_LRCGeometry(
self,
scale: float = 1.0,
Expand Down
2 changes: 1 addition & 1 deletion ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.
"""OTT tools: A set of tools to use OT in differentiable ML pipelines."""

from . import gaussian_mixture, plot, sinkhorn_divergence, soft_sort, transport
from . import gaussian_mixture, k_means, plot, sinkhorn_divergence, soft_sort, transport
Loading

0 comments on commit d7521fd

Please sign in to comment.