Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make random seeding consistent with PRNG keys everywhere instead of seeds #290

Merged
merged 20 commits into from
Feb 24, 2023

Conversation

othmanesebbouh
Copy link
Contributor

@othmanesebbouh othmanesebbouh commented Feb 14, 2023

This fixes #172, replacing all the seed with rng, setting jax.random.PRNGKey(0) as default where applicable.

closes #172

tests/solvers/quadratic/fgw_test.py Show resolved Hide resolved
tests/tools/k_means_test.py Outdated Show resolved Hide resolved
src/ott/problems/nn/dataset.py Outdated Show resolved Hide resolved
src/ott/problems/nn/dataset.py Outdated Show resolved Hide resolved
src/ott/tools/k_means.py Outdated Show resolved Hide resolved
@codecov-commenter
Copy link

codecov-commenter commented Feb 15, 2023

Codecov Report

Merging #290 (74e0bbd) into main (2995ef4) will decrease coverage by 0.01%.
The diff coverage is 99.01%.

📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #290      +/-   ##
==========================================
- Coverage   85.93%   85.92%   -0.01%     
==========================================
  Files          52       52              
  Lines        5587     5579       -8     
  Branches      852      576     -276     
==========================================
- Hits         4801     4794       -7     
  Misses        663      663              
+ Partials      123      122       -1     
Impacted Files Coverage Δ
src/ott/geometry/segment.py 100.00% <ø> (ø)
src/ott/math/fixed_point_loop.py 100.00% <ø> (ø)
src/ott/solvers/linear/implicit_differentiation.py 97.33% <ø> (ø)
src/ott/solvers/nn/layers.py 88.46% <ø> (ø)
src/ott/tools/gaussian_mixture/fit_gmm.py 75.60% <ø> (ø)
src/ott/tools/gaussian_mixture/fit_gmm_pair.py 79.51% <ø> (ø)
src/ott/tools/gaussian_mixture/linalg.py 100.00% <ø> (ø)
src/ott/tools/plot.py 19.81% <ø> (ø)
src/ott/tools/segment_sinkhorn.py 100.00% <ø> (ø)
src/ott/tools/soft_sort.py 95.41% <ø> (ø)
... and 42 more

@marcocuturi
Copy link
Contributor

thanks Othmane! you need to run pre-commit to get rid of linting issues, this is in the CONTRIBUTING.md

Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also change the type hints in all cases tot jax.random.PRNGKeyArray to be consistent?
Ideally the hint should be changed in tests as well whenever the rng: jnp.ndarray fixture is used (see, e.g., here), this is not critical however.

tests/solvers/quadratic/fgw_test.py Show resolved Hide resolved
@@ -66,7 +66,7 @@ def __init__(self, rank: int, **kwargs: Any):
def init_q(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type hint should rather be jax.random.PRNGKeyArray

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same everywhere, it should always be rng: jax.random.PRNGKeyArray instead of rng: jnp.ndarray

@othmanesebbouh
Copy link
Contributor Author

othmanesebbouh commented Feb 16, 2023 via email

@@ -383,19 +383,19 @@ def convert(
return self

(geom_xx, geom_yy, geom_xy, *children), aux_data = self.tree_flatten()
(s1, s2, s3) = jax.random.split(jax.random.PRNGKey(seed), 3)[:, 0]
(k1, k2, k3) = jax.random.split(rng, 3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(k1, k2, k3) = jax.random.split(rng, 3)
rng1, rng2, rng3 = jax.random.split(rng, 3)

(No parenthesis to be consistent, and k -> rng

@marcocuturi
Copy link
Contributor

@othmanesebbouh do you think you could run the pre-commit to get rid of linting issues? Thanks!

@michalk8
Copy link
Collaborator

@othmanesebbouh any progress on this? Re the test that doesn't converge here, I'd say either you can keep it as it is (high epsilon), or revert + use another seed (if it works). In any case, will look into it in #302

@othmanesebbouh
Copy link
Contributor Author

@michalk8 just ran the pre-commit, I hope it works now. I actually thought I had already done that, but wasn't using it correctly, sorry.

For the test that doesn't converge, I feel that having epsilon very large is a less cryptic way to signaling that there's a problem than having a seed that miraculously works.

@michalk8
Copy link
Collaborator

For the test that doesn't converge, I feel that having epsilon very large is a less cryptic way to signaling that there's a problem than having a seed that miraculously works.

Works for me, thanks! Could you please merge merge main and resolve the conflicts?
This can be done by, e.g.:

git remote add upstream ssh://[email protected]/ott-jax/ott
git fetch upstream
git merge upstream/main

Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @othmanesebbouh, LGTM!

@michalk8 michalk8 merged commit 05ff6dd into ott-jax:main Feb 24, 2023
michalk8 added a commit that referenced this pull request Jun 27, 2024
…eeds (#290)

* add ENS

* removed ENS

* Test

* Test

* Change gw_barycenter and neural dual

* fix random number generators in initializers

* k-means and gromov

* fix rng on low rank geometry and datasets

* fix remaining seed changes to rng

* fixed rng reference and increased epsilon to make algo converge, should open issue

* fixed keys to rngs

* extended rng instead of keys to inside functions and added jax.random.PRNGKeyArray as default type

* fixed more keys to rng with correct default type

* ran pre-commit

* Fix typo

---------

Co-authored-by: Anastasiia <[email protected]>
Co-authored-by: michalk8 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Unifying passing PRNG seeds
5 participants