Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/kmeans++ #120

Merged
merged 51 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
c0ba906
Add initial impl. from `CR.Sparse`
michalk8 Aug 3, 2022
5d29f4f
Rename file, add to __init__
michalk8 Aug 3, 2022
a155dbe
Initial fixed point iteration
michalk8 Aug 3, 2022
700005b
Add random initialization
michalk8 Aug 3, 2022
f3aa3a5
Better KMeansState
michalk8 Aug 3, 2022
6001994
Fix `cond_fn`
michalk8 Aug 3, 2022
d6790e5
First working version
michalk8 Aug 3, 2022
023ccbf
Clean output, use `tree_map`
michalk8 Aug 3, 2022
58ecb3d
Remove reference impl. and dead code
michalk8 Aug 3, 2022
6233ef9
Add TODO
michalk8 Aug 3, 2022
a3c1109
Expose `cost_rank` in `PointCloud`
michalk8 Aug 3, 2022
25705ff
Add initial kmeans++ implementation
michalk8 Aug 3, 2022
7ff26ad
Fix indexing bug
michalk8 Aug 3, 2022
2d54ed9
Remove `set` methods
michalk8 Aug 3, 2022
28fb1fa
Add tolerance to convergence check
michalk8 Aug 3, 2022
400cb23
Rename function
michalk8 Aug 3, 2022
99d8442
Store inner errors
michalk8 Aug 4, 2022
dbf070f
Unify kmeans initializer interface, allow custom
michalk8 Aug 4, 2022
68b2116
Reorder arguments
michalk8 Aug 4, 2022
fa02a19
Add strict convergence criterion
michalk8 Aug 4, 2022
88d6bc4
Add convergence iteration to output
michalk8 Aug 4, 2022
b222f68
Simplify `cond_fn`, use `max_iter - 1`
michalk8 Aug 4, 2022
9974b75
Clip cosine distance to `[0, 2]`
michalk8 Aug 4, 2022
c1ae696
Require sqEucl geometry, allow arrays
michalk8 Aug 4, 2022
de998af
Fix random/kmeans++ init
michalk8 Aug 4, 2022
69b071e
Remove normalization comment
michalk8 Aug 4, 2022
9b1ae9a
Fix dividing by 0 when using weights
michalk8 Aug 4, 2022
e28dc84
Add TODOs
michalk8 Aug 4, 2022
15d5c93
Fix k-means++ init centroid
michalk8 Aug 4, 2022
e9652e4
Use `jax.tree_util.tree_map`
michalk8 Aug 4, 2022
f690333
Rename arguments, use sum instead of mean
michalk8 Aug 4, 2022
6578051
Switch order
michalk8 Aug 4, 2022
1be8bc1
Fix `unique_indices=True` in segment sum
michalk8 Aug 5, 2022
1856cde
Don't compute assignment in `init_fn`
michalk8 Aug 5, 2022
d509766
Fix weighting
michalk8 Aug 5, 2022
2723899
Use center shift as convergence criterion
michalk8 Aug 8, 2022
c717876
Remove old TODOs
michalk8 Aug 8, 2022
0ad4b46
Improve final assignment
michalk8 Aug 8, 2022
f1323b7
Fix centroid/weight adjustment
michalk8 Aug 8, 2022
777e5a2
[ci skip] Allow geometry with cosine cost
michalk8 Aug 8, 2022
b316a20
Fix cosine conversion, add test
michalk8 Aug 8, 2022
e7f8c33
Add more cosine -> sqeucl conversion tests
michalk8 Aug 8, 2022
68ea41e
Add documentation
michalk8 Aug 8, 2022
e99b6f0
Add skeleton tests
michalk8 Aug 9, 2022
d8359cd
Add kmeans++ tests
michalk8 Aug 9, 2022
595d9a7
Add k-means initialization test
michalk8 Aug 9, 2022
e33b1c5
Fix bug when removing empty centroids
michalk8 Aug 9, 2022
a0afc5e
Finish tests
michalk8 Aug 9, 2022
4458e8f
Increase tolerance
michalk8 Aug 9, 2022
20ae11e
Address comments
michalk8 Aug 10, 2022
d10ebbf
Use smaller eps
michalk8 Aug 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,28 @@ @Article{higham:1997
doi = "10.1023/A:1019150005407",
url = "https://doi.org/10.1023/A:1019150005407"
}

@Article{lloyd:82,
author={Lloyd, S.},
journal={IEEE Transactions on Information Theory},
title={Least squares quantization in PCM},
year={1982},
volume={28},
number={2},
pages={129-137},
doi={10.1109/TIT.1982.1056489}
}

@inproceedings{arthur:07,
author = {Arthur, David and Vassilvitskii, Sergei},
title = {K-Means++: The Advantages of Careful Seeding},
year = {2007},
isbn = {9780898716245},
publisher = {Society for Industrial and Applied Mathematics},
address = {USA},
booktitle = {Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms},
pages = {1027–1035},
numpages = {9},
location = {New Orleans, Louisiana},
series = {SODA '07}
}
8 changes: 8 additions & 0 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ Soft Sorting Algorithms
soft_sort.ranks
soft_sort.sort
soft_sort.sort_with

Clustering
----------
.. autosummary::
:toctree: _autosummary

k_means.k_means
k_means.KMeansOutput
8 changes: 5 additions & 3 deletions ott/core/fixed_point_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def fixpoint_iter_fwd(
"""
force_scan = (min_iterations == max_iterations)
compute_error_flags = jnp.arange(inner_iterations) == inner_iterations - 1
states = jax.tree_map(
states = jax.tree_util.tree_map(
lambda x: jnp.zeros((max_iterations // inner_iterations + 1,) + x.shape,
dtype=x.dtype), state
)
Expand Down Expand Up @@ -176,7 +176,7 @@ def fixpoint_iter_bwd(
force_scan = (min_iterations == max_iterations)
constants, iteration, states = res
# The tree may contain some python floats
g_constants = jax.tree_map(
g_constants = jax.tree_util.tree_map(
lambda x: jnp.zeros_like(x, dtype=x.dtype)
if isinstance(x, (np.ndarray, jnp.ndarray)) else 0, constants
)
Expand All @@ -202,7 +202,9 @@ def one_iteration(iteration_state, compute_error):

def unrolled_body_fn(iteration_g_gconst):
iteration, g, g_constants = iteration_g_gconst
state = jax.tree_map(lambda x: x[iteration // inner_iterations], states)
state = jax.tree_util.tree_map(
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
lambda x: x[iteration // inner_iterations], states
)
_, pullback = jax.vjp(
unrolled_body_fn_no_errors, iteration, constants, state
)
Expand Down
11 changes: 6 additions & 5 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ def padder(cls, dim: int) -> jnp.ndarray:
return jnp.zeros((1, dim))

def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
return self.pairwise(
x, y
) + (0 if self.norm is None else self.norm(x) + self.norm(y))
cost = self.pairwise(x, y)
if self.norm is None:
return cost
return cost + self.norm(x) + self.norm(y)

def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Compute matrix of all costs (including norms) for vectors in x / y.
Expand Down Expand Up @@ -99,7 +100,7 @@ def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute minus twice the dot-product between vectors."""
return -2 * jnp.vdot(x, y)
return -2. * jnp.vdot(x, y)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
"""Output barycenter of vectors when using squared-Euclidean distance."""
Expand All @@ -121,7 +122,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
y_norm = jnp.linalg.norm(y, axis=-1)
cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + ridge)
cosine_distance = 1.0 - cosine_similarity
return cosine_distance
return jnp.clip(cosine_distance, 0., 2.)
michalk8 marked this conversation as resolved.
Show resolved Hide resolved

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
raise NotImplementedError("Barycenter for cosine cost not yet implemented.")
Expand Down
4 changes: 2 additions & 2 deletions ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
self._kwargs = {**{'init': None, 'decay': None}, **kwargs}

@property
def cost_rank(self) -> None:
def cost_rank(self) -> Optional[int]:
"""Output rank of cost matrix, if any was provided."""
return None

Expand Down Expand Up @@ -198,7 +198,7 @@ def is_symmetric(self) -> bool:
@property
def inv_scale_cost(self) -> float:
"""Compute and return inverse of scaling factor for cost matrix."""
if isinstance(self._scale_cost, (int, float)):
if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
return 1.0 / self._scale_cost
self = self._masked_geom(mask_value=jnp.nan)
if self._scale_cost == 'max_cost':
Expand Down
8 changes: 4 additions & 4 deletions ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,10 @@ def transport_from_scalings(
) -> NoReturn:
"""Not implemented, use :meth:`apply_transport_from_scalings` instead."""
raise ValueError(
'Grid geometry cannot instantiate a transport matrix, use',
' apply_transport_from_scalings(...) if you wish to ',
' apply the transport matrix to a vector, or use a point '
' cloud geometry instead'
'Grid geometry cannot instantiate a transport matrix, use ',
'apply_transport_from_scalings(...) if you wish to ',
'apply the transport matrix to a vector, or use a point '
'cloud geometry instead.'
)

def subset(
Expand Down
2 changes: 1 addition & 1 deletion ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def is_symmetric(self) -> bool:

@property
def inv_scale_cost(self) -> float:
if isinstance(self._scale_cost, (int, float)):
if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
return 1.0 / self._scale_cost
self = self._masked_geom()
if self._scale_cost == 'max_bound':
Expand Down
18 changes: 17 additions & 1 deletion ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,14 @@ def is_online(self) -> bool:
is computed on-the-fly."""
return self.batch_size is not None

# TODO(michalk8): when refactoring, consider PC as a subclass of LR?
@property
def cost_rank(self) -> int:
return self.x.shape[1]

@property
def inv_scale_cost(self) -> float:
if isinstance(self._scale_cost, (int, float)):
if isinstance(self._scale_cost, (int, float, jnp.DeviceArray)):
return 1.0 / self._scale_cost
self = self._masked_geom()
if self._scale_cost == 'max_cost':
Expand Down Expand Up @@ -567,6 +572,17 @@ def tree_unflatten(cls, aux_data, children):
x, y, cost_fn=cost_fn, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data
)

def _cosine_to_sqeucl(self) -> 'PointCloud':
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(self._cost_fn, costs.Cosine), type(self._cost_fn)
assert self.power == 2, self.power
(x, y, *args, _), aux_data = self.tree_flatten()
x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
# TODO(michalk8): find a better way
aux_data["scale_cost"] = 2. / self.inv_scale_cost
cost_fn = costs.Euclidean()
return type(self).tree_unflatten(aux_data, [x, y] + args + [cost_fn])

def to_LRCGeometry(
self,
scale: float = 1.0,
Expand Down
2 changes: 1 addition & 1 deletion ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.
"""OTT tools: A set of tools to use OT in differentiable ML pipelines."""

from . import gaussian_mixture, plot, sinkhorn_divergence, soft_sort, transport
from . import gaussian_mixture, k_means, plot, sinkhorn_divergence, soft_sort, transport
Loading