-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes numerical errors in Bures barycenter, and sqrtm
, due to low default precision.
#205
Conversation
sqrtm
, due to low default precision.
src/ott/geometry/costs.py
Outdated
rtol: float = 1e-2 | ||
weights: jnp.ndarray, | ||
tolerance: float = 1e-4, | ||
kwargs_sqrtm: Optional[Mapping[str, Any]] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use **kwargs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
at that moment, this was used to avoid mixing threshold
(which is used for sqrtm
fixed point) and tolerance
(for barycenter. Both can be defined though. However, we'll be running into issues if we want to open other parameters (such as min_iterations
) since we have 2 imbricated fixed point loops (sqrtm
and cov. barycenter).
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #205 +/- ##
==========================================
- Coverage 89.26% 81.94% -7.33%
==========================================
Files 52 52
Lines 5300 5322 +22
Branches 543 546 +3
==========================================
- Hits 4731 4361 -370
- Misses 437 809 +372
- Partials 132 152 +20
|
src/ott/geometry/costs.py
Outdated
rtol: float = 1e-2 | ||
weights: jnp.ndarray, | ||
tolerance: float = 1e-4, | ||
**kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is still missing a type hint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh sorry, probably added somewhere else!
…efault precision. (#205) * fixes * lint * fix tolerance in test * add test, minor fixes. * lint * Any in front of kwargs * fix test * lint * improve docs. * lint * lint
This allows user to pass on
kwargs
parameters to Bures barycenters, solving #199.This also lowers the default tolerance of
sqrtm
.example in #199 that was failing now works, with
outputting