Feature/kmeans++ #120
Thanks a lot Michal, this looks very nice
tests/tools/k_means_test.py
    x, k=k, weights=w, min_iterations=10, max_iterations=10, key=key1
    ).error

    k, eps = 4, 1e-12
Maybe a bit too small? If this works with 1e-12, why not, but this sounds like a scale where numerical errors could easily pop up. Maybe it would be more reasonable to use something like 1e-3?
Changed to `eps=1e-3` and `rtol=atol=1e-3`. With 64-bit precision it passes with `rtol=atol=1e-8`, but I haven't added a test for this, as it already takes some time.
Thanks Michal. Apart from the scale of the eps perturbation, everything sounds good!
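For context on the tolerance discussion above, here is a minimal sketch of why a perturbation of `1e-12` is fragile in single precision, which is what JAX uses unless 64-bit mode is enabled. It assumes the perturbed values are of order 1, which the thread does not state; the helper names `to_float32` and `perturbation_survives` are hypothetical and not part of the PR.

```python
import struct

# Spacing between 1.0 and the next representable float32 (about 1.19e-7).
FLOAT32_EPS = 2.0 ** -23


def to_float32(x: float) -> float:
    """Round a Python float (64-bit) to the nearest 32-bit float."""
    return struct.unpack("f", struct.pack("f", x))[0]


def perturbation_survives(value: float, eps: float) -> bool:
    """Return True if value + eps is distinguishable from value in float32."""
    return to_float32(value + eps) != to_float32(value)


# Assumption: the data being perturbed is of order 1.
print(perturbation_survives(1.0, 1e-12))  # perturbation far below FLOAT32_EPS
print(perturbation_survives(1.0, 1e-3))   # perturbation far above FLOAT32_EPS
print(1.0 + 1e-12 != 1.0)                 # the same check in 64-bit arithmetic
```

Under that assumption, a `1e-12` perturbation is rounded away in 32-bit floats but is still visible in 64-bit floats, which is consistent with the tighter tolerance only passing in 64-bit mode.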
* Add initial impl. from `CR.Sparse`
* Rename file, add to __init__
* Initial fixed point iteration
* Add random initialization
* Better KMeansState
* Fix `cond_fn`
* First working version
* Clean output, use `tree_map`
* Remove reference impl. and dead code
* Add TODO
* Expose `cost_rank` in `PointCloud`
* Add initial kmeans++ implementation
* Fix indexing bug
* Remove `set` methods
* Add tolerance to convergence check
* Rename function
* Store inner errors
* Unify kmeans initializer interface, allow custom
* Reorder arguments
* Add strict convergence criterion
* Add convergence iteration to output
* Simplify `cond_fn`, use `max_iter - 1`
* Clip cosine distance to `[0, 2]`
* Require sqEucl geometry, allow arrays
* Fix random/kmeans++ init
* Remove normalization comment
* Fix dividing by 0 when using weights
* Add TODOs
* Fix k-means++ init centroid
* Use `jax.tree_util.tree_map`
* Rename arguments, use sum instead of mean
* Switch order
* Fix `unique_indices=True` in segment sum
* Don't compute assignment in `init_fn`
* Fix weighting
* Use center shift as convergence criterion
* Remove old TODOs
* Improve final assignment
* Fix centroid/weight adjustment
* [ci skip] Allow geometry with cosine cost
* Fix cosine conversion, add test
* Add more cosine -> sqeucl conversion tests
* Add documentation
* Add skeleton tests
* Add kmeans++ tests
* Add k-means initialization test
* Fix bug when removing empty centroids
* Finish tests
* Increase tolerance
* Address comments
* Use smaller eps
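Several of the commits above concern the cosine cost: clipping it to `[0, 2]` and converting it to squared Euclidean. As background only, the sketch below illustrates the standard identity that, for unit-normalized vectors, the squared Euclidean distance equals twice the cosine distance. It is not the PR's code; the function names are hypothetical, and whether the PR performs the conversion this way is an assumption.

```python
import math


def normalize(v):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def cosine_distance(x, y):
    """Cosine distance 1 - cos(x, y), clipped to its valid range [0, 2]."""
    xn, yn = normalize(x), normalize(y)
    dist = 1.0 - sum(a * b for a, b in zip(xn, yn))
    # Rounding can push the value slightly outside [0, 2], hence the clip.
    return min(max(dist, 0.0), 2.0)


def sq_euclidean(x, y):
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


x, y = [1.0, 2.0, 3.0], [-2.0, 0.5, 4.0]
lhs = sq_euclidean(normalize(x), normalize(y))
rhs = 2.0 * cosine_distance(x, y)
print(abs(lhs - rhs) < 1e-12)  # the two quantities agree up to rounding
```

The identity follows from expanding `||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>` with both norms equal to 1.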
In this PR:
* `k-means`, including `k-means++` initialization and weighting
* `tree_util` warnings in `fixed_point_loop`
* `DeviceArray` for scaling of cost matrices

related issue #111
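To make the "`k-means++` initialization and weighting" item concrete, here is a minimal pure-Python sketch of weighted k-means++ seeding (D² sampling). It is an illustration under stated assumptions, not the implementation added in this PR: the function name `kmeans_plus_plus_init`, its signature, and the choice to sample the first centroid proportionally to the point weights are all hypothetical.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def kmeans_plus_plus_init(points, k, weights=None, seed=0):
    """Pick k initial centroids by weighted D^2 sampling (hypothetical helper)."""
    rng = random.Random(seed)
    n = len(points)
    if weights is None:
        weights = [1.0] * n
    # Assumption: the first centroid is drawn proportionally to the weights.
    first = rng.choices(range(n), weights=weights, k=1)[0]
    centroids = [points[first]]
    # Squared distance from each point to its closest chosen centroid.
    closest = [sq_dist(p, centroids[0]) for p in points]
    while len(centroids) < k:
        probs = [w * d for w, d in zip(weights, closest)]
        if sum(probs) == 0.0:
            # Every point coincides with a centroid; fall back to the weights.
            probs = weights
        idx = rng.choices(range(n), weights=probs, k=1)[0]
        centroids.append(points[idx])
        closest = [min(d, sq_dist(p, points[idx]))
                   for d, p in zip(closest, points)]
    return centroids


points = [(0.0, 0.0), (0.0, 0.0), (10.0, 10.0), (10.0, 10.0)]
print(sorted(kmeans_plus_plus_init(points, k=2)))
```

Because a point that already coincides with a chosen centroid has sampling probability zero, the two seeds in this toy example always land on the two distinct locations, whichever one is picked first.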