Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/kmeans++ #120

Merged
merged 51 commits into from
Aug 10, 2022
Merged

Feature/kmeans++ #120

merged 51 commits into from
Aug 10, 2022

Conversation

michalk8
Copy link
Collaborator

@michalk8 michalk8 commented Aug 4, 2022

In this PR:

  • implement k-means, including k-means++ initialization and weighting
  • fix tree_util warnings in fixed_point_loop
  • allow DeviceArray for scaling of cost matrices
  • allow converting cosine point cloud to sq. Euclidean one

related issue #111

@michalk8 michalk8 added the enhancement New feature or request label Aug 4, 2022
@michalk8 michalk8 self-assigned this Aug 4, 2022
@michalk8 michalk8 requested a review from marcocuturi August 9, 2022 18:58
@michalk8 michalk8 marked this pull request as ready for review August 9, 2022 21:57
Copy link
Contributor

@marcocuturi marcocuturi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot Michal, this looks very nice

ott/core/fixed_point_loop.py Show resolved Hide resolved
ott/core/fixed_point_loop.py Show resolved Hide resolved
ott/geometry/costs.py Show resolved Hide resolved
ott/geometry/pointcloud.py Show resolved Hide resolved
ott/tools/k_means.py Outdated Show resolved Hide resolved
ott/tools/k_means.py Show resolved Hide resolved
ott/tools/k_means.py Outdated Show resolved Hide resolved
tests/tools/k_means_test.py Outdated Show resolved Hide resolved
x, k=k, weights=w, min_iterations=10, max_iterations=10, key=key1
).error

k, eps = 4, 1e-12
Copy link
Contributor

@marcocuturi marcocuturi Aug 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a bit too small? if this works with 1e-12 why not, but this sounds like a scale where numerical errors could easily pop up. Maybe more reasonable to use something like 1e-3 ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed eps=1e-3 and rtol=atol=1e-3. If using 64 bits, it passes with rtol=atol=1e-8, but haven't added a test for this, as it already takes some time.

Copy link
Contributor

@marcocuturi marcocuturi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Michal. Apart from the scale of the eps perturbation, everything sounds good!

@michalk8 michalk8 merged commit d7521fd into ott-jax:main Aug 10, 2022
@michalk8 michalk8 deleted the feature/kmeans++ branch August 10, 2022 13:40
michalk8 added a commit that referenced this pull request Jun 27, 2024
* Add initial impl. from `CR.Sparse`

* Rename file, add to __init__

* Initial fixed point iteration

* Add random initialization

* Better KMeansState

* Fix `cond_fn`

* First working version

* Clean output, use `tree_map`

* Remove reference impl. and dead code

* Add TODO

* Expose `cost_rank` in `PointCloud`

* Add initial kmeans++ implementation

* Fix indexing bug

* Remove `set` methods

* Add tolerance to convergence check

* Rename function

* Store inner errors

* Unify kmeans initializer interface, allow custom

* Reorder arguments

* Add strict convergence criterion

* Add convergence iteration to output

* Simplify `cond_fn`, use `max_iter - 1`

* Clip cosine distance to `[0, 2]`

* Require sqEucl geometry, allow arrays

* Fix random/kmeans++ init

* Remove normalization comment

* Fix dividing by 0 when using weights

* Add TODOs

* Fix k-means++ init centroid

* Use `jax.tree_util.tree_map`

* Rename arguments, use sum instead of mean

* Switch order

* Fix `unique_indices=True` in segment sum

* Don't compute assignment in `init_fn`

* Fix weighting

* Use center shift as convergence criterion

* Remove old TODOs

* Improve final assignment

* Fix centroid/weight adjustment

* [ci skip] Allow geometry with cosine cost

* Fix cosine conversion, add test

* Add more cosine -> sqeucl conversion tests

* Add documentation

* Add skeleton tests

* Add kmeans++ tests

* Add k-means initialization test

* Fix bug when removing empty centroids

* Finish tests

* Increase tolerance

* Address comments

* Use smaller eps
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants