Fixes numerical errors in Bures barycenter, and sqrtm, due to low default precision. #205

Merged
marcocuturi merged 11 commits into main from marco-issue199
Dec 9, 2022

marcocuturi
Contributor

This allows users to pass keyword arguments through to the Bures barycenter computation, solving #199.

This also lowers the default convergence threshold of sqrtm from 1e-4 to 1e-6.

The example in #199 that was failing now works:

import jax
import jax.numpy as jnp
from ott.geometry.costs import Bures, mean_and_cov_to_x, x_to_means_and_covs

# enable 64-bit floats so the comparison is not limited by float32 precision
jax.config.update("jax_enable_x64", True)
# first Gaussian 
mu1 = jnp.array([-0.8909, -0.3568, 0.2758, 0.0352, -0.1457])
r = jnp.array([0.3206, 0.8825, 0.1113, 0.0052, 0.9454])
Sigma1 = r * jnp.eye(5)
# second Gaussian 
mu2 = jnp.array([-0.8862, -0.3652, 0.2751, 0.0349, -0.1486])
s = jnp.array([0.3075, 0.8545, 0.1110, 0.0054, 0.9206])
Sigma2 = s * jnp.eye(5)

# initializing Bures instance 
weights = jnp.array([300./537., 237./537.])
bures = Bures(5)

# stacking parameter values
xs = jnp.vstack(
    (mean_and_cov_to_x(mu1, Sigma1, 5), 
    mean_and_cov_to_x(mu2, Sigma2, 5))
)

# print output

output = bures.barycenter(weights, xs, tolerance=1e-4)
mu, Sigma = x_to_means_and_covs(output, 5)
print('new default threshold of 1e-6 (not passed)')
print(Sigma)

kwargs_sqrtm={'threshold' : 1e-4}
output = bures.barycenter(weights, xs, kwargs_sqrtm=kwargs_sqrtm)
mu, Sigma = x_to_means_and_covs(output, 5)
print('with former threshold in sqrtm (1e-4), convergence issues')
print(Sigma)

print('groundtruth')
print(jnp.diag(
  (weights[0]*jnp.sqrt(r) + weights[1]*jnp.sqrt(s))**2)
  )

which outputs:

new default threshold of 1e-6 (not passed)
[[0.31478475 0.         0.         0.         0.        ]
 [0.         0.87008681 0.         0.         0.        ]
 [0.         0.         0.11116755 0.         0.        ]
 [0.         0.         0.         0.00528754 0.        ]
 [0.         0.         0.         0.         0.93441411]]
with former threshold in sqrtm (1e-4), convergence issues
[[0.31478475 0.         0.         0.         0.        ]
 [0.         0.87008681 0.         0.         0.        ]
 [0.         0.         0.11116755 0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.93441411]]
groundtruth
[[0.31478475 0.         0.         0.         0.        ]
 [0.         0.87008681 0.         0.         0.        ]
 [0.         0.         0.11116755 0.         0.        ]
 [0.         0.         0.         0.0052878  0.        ]
 [0.         0.         0.         0.         0.93441411]]
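
The "groundtruth" line in the script relies on the fact that both covariances are diagonal, so the barycenter can be computed entry by entry as the squared weighted average of the square roots. As a rough cross-check that does not depend on OTT or JAX, the same expression can be evaluated in plain Python. The helper name `diagonal_bures_barycenter` is made up for this sketch and is not part of the library.

```python
import math
from typing import List, Sequence


def diagonal_bures_barycenter(
    weights: Sequence[float], diagonals: Sequence[Sequence[float]]
) -> List[float]:
    """Entrywise barycenter of diagonal covariances (hypothetical helper).

    Mirrors the expression used for the ground truth in the script above:
    (sum_i w_i * sqrt(d_i[k])) ** 2 for each diagonal entry k.
    """
    dim = len(diagonals[0])
    return [
        sum(w * math.sqrt(d[k]) for w, d in zip(weights, diagonals)) ** 2
        for k in range(dim)
    ]


# Same weights and diagonal entries as in the example above.
weights = [300.0 / 537.0, 237.0 / 537.0]
r = [0.3206, 0.8825, 0.1113, 0.0052, 0.9454]
s = [0.3075, 0.8545, 0.1110, 0.0054, 0.9206]

ground_truth = diagonal_bures_barycenter(weights, [r, s])
print([round(v, 8) for v in ground_truth])
```

The fourth entry is the one that collapsed to zero with the former sqrtm threshold, so it is the value worth comparing against the printed matrices.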

@marcocuturi marcocuturi changed the title Fixes numerical errors Fixes numerical errors in Bures barycenter, and sqrtm, due to low default precision. Dec 9, 2022
@marcocuturi marcocuturi requested a review from michalk8 December 9, 2022 16:22
src/ott/math/matrix_square_root.py
rtol: float = 1e-2
weights: jnp.ndarray,
tolerance: float = 1e-4,
kwargs_sqrtm: Optional[Mapping[str, Any]] = None
Collaborator

Why not use **kwargs?

Contributor Author

At that point, this was done to avoid mixing up threshold (used for the sqrtm fixed point) and tolerance (used for the barycenter). Both can be defined, though. However, we'll run into issues if we want to expose other parameters (such as min_iterations), since we have two nested fixed-point loops (sqrtm and the covariance barycenter).
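
To make the naming concern concrete, here is a minimal scalar sketch of two nested fixed-point loops, each with its own stopping criterion. It is not the OTT implementation: `newton_sqrt`, `scalar_barycenter` and their `max_iterations` arguments are assumptions made for illustration. Only the names `threshold`, `tolerance` and `kwargs_sqrtm` come from this PR.

```python
import math
from typing import Any, Mapping, Optional, Sequence


def newton_sqrt(a: float, threshold: float = 1e-6, max_iterations: int = 100) -> float:
    """Inner loop (hypothetical): Newton iteration for sqrt(a), with a > 0."""
    x = max(a, 1.0)
    for _ in range(max_iterations):
        x_new = 0.5 * (x + a / x)
        # Stop when the relative change falls below `threshold`.
        if abs(x_new - x) <= threshold * x_new:
            return x_new
        x = x_new
    return x


def scalar_barycenter(
    weights: Sequence[float],
    variances: Sequence[float],
    tolerance: float = 1e-4,
    max_iterations: int = 100,
    kwargs_sqrtm: Optional[Mapping[str, Any]] = None,
) -> float:
    """Outer loop (hypothetical): scalar fixed point s <- (sum_i w_i sqrt(s v_i))^2 / s."""
    inner = {} if kwargs_sqrtm is None else dict(kwargs_sqrtm)
    sigma = sum(w * v for w, v in zip(weights, variances))
    for _ in range(max_iterations):
        root = newton_sqrt(sigma, **inner)
        total = sum(
            w * newton_sqrt(root * v * root, **inner)
            for w, v in zip(weights, variances)
        )
        sigma_new = (total / root) ** 2
        # `tolerance` governs the outer loop only.
        if abs(sigma_new - sigma) <= tolerance * sigma_new:
            return sigma_new
        sigma = sigma_new
    return sigma


# Both loops accept `max_iterations`; the dictionary makes the target explicit.
estimate = scalar_barycenter(
    [0.5, 0.5],
    [0.0052, 0.0054],
    tolerance=1e-8,
    kwargs_sqrtm={"threshold": 1e-12, "max_iterations": 50},
)
exact = (0.5 * math.sqrt(0.0052) + 0.5 * math.sqrt(0.0054)) ** 2
print(estimate, exact)
```

With a flat `**kwargs`, a shared name such as `max_iterations` (or `min_iterations`, as mentioned above) would not say which loop it is meant for. A dedicated mapping avoids that, at the cost of a slightly heavier call signature.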

src/ott/geometry/costs.py (outdated)

codecov-commenter commented Dec 9, 2022

Codecov Report

Merging #205 (79b63aa) into main (0bbcca8) will decrease coverage by 7.32%.
The diff coverage is 84.21%.

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #205      +/-   ##
==========================================
- Coverage   89.26%   81.94%   -7.33%     
==========================================
  Files          52       52              
  Lines        5300     5322      +22     
  Branches      543      546       +3     
==========================================
- Hits         4731     4361     -370     
- Misses        437      809     +372     
- Partials      132      152      +20     
Impacted Files                                  Coverage Δ
src/ott/math/matrix_square_root.py              93.61% <78.57%> (-6.39%) ⬇️
src/ott/geometry/costs.py                       98.61% <100.00%> (ø)
src/ott/solvers/linear/discrete_barycenter.py   20.73% <0.00%> (-65.86%) ⬇️
src/ott/tools/segment_sinkhorn.py               43.75% <0.00%> (-56.25%) ⬇️
src/ott/solvers/quadratic/gw_barycenter.py      30.47% <0.00%> (-53.34%) ⬇️
src/ott/problems/quadratic/gw_barycenter.py     28.82% <0.00%> (-44.15%) ⬇️
src/ott/geometry/segment.py                     62.22% <0.00%> (-37.78%) ⬇️
src/ott/tools/sinkhorn_divergence.py            63.15% <0.00%> (-35.09%) ⬇️
src/ott/tools/soft_sort.py                      69.90% <0.00%> (-25.25%) ⬇️
src/ott/geometry/pointcloud.py                  71.09% <0.00%> (-14.46%) ⬇️
... and 25 more

rtol: float = 1e-2
weights: jnp.ndarray,
tolerance: float = 1e-4,
**kwargs
Collaborator

This is still missing a type hint.

Contributor Author

ahh sorry, probably added somewhere else!
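
For reference, annotating a variadic keyword parameter means annotating the type of each value, not the dictionary itself. A small sketch of what the requested hint could look like follows. `barycenter_stub` is a hypothetical stand-in, not the signature in costs.py.

```python
import inspect
from typing import Any, Dict, Sequence


def barycenter_stub(
    weights: Sequence[float], tolerance: float = 1e-4, **kwargs: Any
) -> Dict[str, Any]:
    """Hypothetical stand-in that only records what it was given."""
    # `**kwargs: Any` types each keyword value as Any; kwargs itself is a dict.
    return {"tolerance": tolerance, "sqrtm_kwargs": dict(kwargs)}


result = barycenter_stub([0.5, 0.5], threshold=1e-6)
kwargs_annotation = inspect.signature(barycenter_stub).parameters["kwargs"].annotation
print(result, kwargs_annotation)
```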

@marcocuturi marcocuturi merged commit ad20eba into main Dec 9, 2022
@michalk8 michalk8 deleted the marco-issue199 branch February 24, 2023 13:51
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
…efault precision. (#205)

* fixes

* lint

* fix tolerance in test

* add test, minor fixes.

* lint

* Any in front of kwargs

* fix test

* lint

* improve docs.

* lint

* lint