Bures.barycenter returns covariance matrix that has 0 values on the diagonal #199

Closed
marcocuturi opened this issue Dec 7, 2022 Discussed in #198 · 2 comments
Labels
bug Something isn't working

Comments

@marcocuturi
Contributor

Discussed in #198

Originally posted by conorhassan December 7, 2022
Hi there,

I love the OTT package - it's fantastic!

I am trying to use Bures.barycenter to find the weighted average (Gaussian distribution) between multiple Gaussians, but it returns a covariance matrix with 0's on the diagonal. Here is a minimal example:

from ott.geometry.costs import Bures, mean_and_cov_to_x, x_to_means_and_covs
import jax.numpy as jnp

# first Gaussian 
mu1 = jnp.array([-0.8909, -0.3568, 0.2758, 0.0352, -0.1457])
Sigma1 = jnp.array([0.3206, 0.8825, 0.1113, 0.9454, 0.0052]) * jnp.eye(5)
# second Gaussian 
mu2 = jnp.array([-0.8862, -0.3652, 0.2751, 0.0349, -0.1486])
Sigma2 = jnp.array([0.3075, 0.8545, 0.1110, 0.9206, 0.0054]) * jnp.eye(5)

# initializing Bures instance 
weights = jnp.array([300./537., 237./537.])
bures = Bures(5)

# stacking parameter values
xs = jnp.vstack(
    (mean_and_cov_to_x(mu1, Sigma1, 5), 
    mean_and_cov_to_x(mu2, Sigma2, 5))
)

# print output
output = bures.barycenter(weights, xs)
mu, Sigma = x_to_means_and_covs(output, 5)
print(Sigma)


I am trying to use this operation as part of an iterative procedure and thus need to do operations such as Cholesky decompositions on this matrix, etc.

I dug into the covariance_fixpoint_iter function and set the rtol to 1e-6 as opposed to the default value of 1e-2, but this didn't change much.

Any thoughts would be appreciated!

@marcocuturi marcocuturi self-assigned this Dec 7, 2022
@marcocuturi marcocuturi added the bug Something isn't working label Dec 7, 2022
@marcocuturi
Contributor Author

Thanks a lot for your kind words, @conorhassan. As far as I can see, the issue does not come from the fixed-point loop that computes the barycenter of covariance matrices, but rather from the precision of the sqrtm routine, at a lower level. When I set the threshold there to a lower value, this works.
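To illustrate how a loose convergence threshold in a matrix square root can hurt a small eigenvalue, here is a minimal sketch. It assumes a coupled Newton-Schulz iteration with a relative Frobenius residual as the stopping rule; the function `newton_schulz_sqrt` and its parameters are hypothetical, written in plain NumPy, and may well differ from what OTT's sqrtm does. It is not a diagnosis of the bug above, only an example of the kind of effect a threshold can have.

```python
import numpy as np


def newton_schulz_sqrt(a, threshold, max_iterations=100):
    """Hypothetical sketch: square root of an SPD matrix via Newton-Schulz.

    Stops once the relative Frobenius residual ||a - y @ y|| / ||a||
    drops below `threshold`.
    """
    dim = a.shape[0]
    identity = np.eye(dim)
    norm = np.linalg.norm(a)  # Frobenius norm; rescaling keeps the iteration convergent
    a_scaled = a / norm
    y = a_scaled.copy()       # converges to sqrt(a_scaled)
    z = identity.copy()       # converges to the inverse of sqrt(a_scaled)
    for _ in range(max_iterations):
        t = 0.5 * (3.0 * identity - z @ y)
        y = y @ t
        z = t @ z
        residual = np.linalg.norm(a_scaled - y @ y)
        if residual < threshold:
            break
    return np.sqrt(norm) * y


# Diagonal of Sigma1 from the example in the issue.
sigma1 = np.diag([0.3206, 0.8825, 0.1113, 0.9454, 0.0052])
exact = np.sqrt(np.diag(sigma1))

loose = np.diag(newton_schulz_sqrt(sigma1, threshold=1e-2))
tight = np.diag(newton_schulz_sqrt(sigma1, threshold=1e-6))

rel_err_loose = np.abs(loose - exact) / exact
rel_err_tight = np.abs(tight - exact) / exact
print("relative error, threshold=1e-2:", rel_err_loose)
print("relative error, threshold=1e-6:", rel_err_tight)
```

In this sketch the smallest entry (0.0052) contributes very little to the overall residual, so a loose threshold can stop the loop before that entry has converged, while the larger entries already look accurate.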

BTW, if you are only considering diagonal covariance matrices, the barycenter has a much simpler form (see e.g. Remark 2.31, "Distance between Gaussians", in our book). Essentially, in that case you take the diagonals you want to average, compute their element-wise square roots, take the usual weighted average of those square roots, and then square the result.
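The diagonal shortcut described above can be sketched as follows. This is a minimal illustration in plain NumPy (the same operations exist in `jax.numpy`); the helper name `diagonal_gaussian_barycenter` is hypothetical and not part of OTT. It assumes every covariance is diagonal and is passed as a vector of variances, and it uses the weighted average of the means as the barycenter mean.

```python
import numpy as np


def diagonal_gaussian_barycenter(weights, means, variances):
    """Hypothetical helper: Bures-Wasserstein barycenter of diagonal Gaussians.

    weights:   shape (n,), non-negative
    means:     shape (n, d)
    variances: shape (n, d), the diagonals of the covariance matrices
    """
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()            # normalize to sum to one
    mean = weights @ np.asarray(means)           # weighted average of the means
    std = weights @ np.sqrt(np.asarray(variances))  # average the element-wise square roots
    return mean, std ** 2                        # square back to get the variances


# Values from the example in the issue.
mu1 = np.array([-0.8909, -0.3568, 0.2758, 0.0352, -0.1457])
var1 = np.array([0.3206, 0.8825, 0.1113, 0.9454, 0.0052])
mu2 = np.array([-0.8862, -0.3652, 0.2751, 0.0349, -0.1486])
var2 = np.array([0.3075, 0.8545, 0.1110, 0.9206, 0.0054])

weights = np.array([300.0 / 537.0, 237.0 / 537.0])
means = np.stack([mu1, mu2])
variances = np.stack([var1, var2])

bary_mean, bary_var = diagonal_gaussian_barycenter(weights, means, variances)
print(bary_mean)
print(np.diag(bary_var))  # full covariance matrix, if one is needed downstream
```

Because no iterative matrix square root is involved, the resulting variances stay strictly positive whenever the inputs are, which matters if the matrix is later passed to a Cholesky decomposition.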

@marcocuturi
Contributor Author

Solved in #205.
