-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LR Sinkhorn improvements #111
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had a look the the notebook that causes the conflict, there are no new changes.
Could you please update the branch as git merge master -X theirs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I got circular import error because of the kmeans
, because it's placed in tools. Making it a relative import in LRSinkhorn.__call__
solved this.
@meyerscetbon @marcocuturi you can both review it now, there some TODOs left for me, but not critical |
Codecov Report
@@ Coverage Diff @@
## main #111 +/- ##
==========================================
- Coverage 89.45% 83.17% -6.28%
==========================================
Files 47 48 +1
Lines 4580 4713 +133
Branches 503 512 +9
==========================================
- Hits 4097 3920 -177
- Misses 364 650 +286
- Partials 119 143 +24
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for this PR, it looks like we have a really great LR implementation now.
d7a2a16
to
fb1a1df
Compare
* test * update lr-sinkhorn * restored_branch * check * review * circular fixed * update review * Fix bugs in `LRSinkhorn` * Use new `k-means` implementation * Fix linter * Refactor `LRSinkhorn` initializers * Use `if` for `is_entropic`, remove dead variables * Slightly improve types * Do not use stateful `gamma` * Fix typo in tests * Fix using `state.gamma` instead of `self.gamma` * Fix point cloud size in notebook * Add assertion to k-means * Use `jax.lax.cond` instead of `jax.numpy.where` * Change convergence criterion * Use safe log * Fix more tests * Fix tests * Fix `tree_flatten` in `KMeansInitializer` * Fix defaults, change `rank_2` -> `rank2` * Simplify `apply` * Update TODOs * Update docs, make `lr_costs` private * Increate tolerance in failing test * Update LR notebook * Address comments * Remove LR Sinkhorn notebook from testing, to slow Co-authored-by: Michal Klein
Here are the implementations of the update for the lr-sinkhorn algorithm: