-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make random seeding consistent with PRNG keys everywhere instead of seeds #290
Conversation
Codecov Report
📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more Additional details and impacted files@@ Coverage Diff @@
## main #290 +/- ##
==========================================
- Coverage 85.93% 85.92% -0.01%
==========================================
Files 52 52
Lines 5587 5579 -8
Branches 852 576 -276
==========================================
- Hits 4801 4794 -7
Misses 663 663
+ Partials 123 122 -1
|
thanks Othmane! you need to run |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please also change the type hints in all cases tot jax.random.PRNGKeyArray
to be consistent?
Ideally the hint should be changed in tests as well whenever the rng: jnp.ndarray
fixture is used (see, e.g., here), this is not critical however.
….PRNGKeyArray as default type
@@ -66,7 +66,7 @@ def __init__(self, rank: int, **kwargs: Any): | |||
def init_q( | |||
self, | |||
ot_prob: Problem_t, | |||
key: jnp.ndarray, | |||
rng: jnp.ndarray, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the type hint should rather be jax.random.PRNGKeyArray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same everywhere, it should always be rng: jax.random.PRNGKeyArray
instead of rng: jnp.ndarray
Yes, modified them and I’ll push the changes this afternoon
…On Thu 16 Feb 2023 at 11:59, Pierre Ablin ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In src/ott/initializers/linear/initializers_lr.py
<#290 (comment)>:
> @@ -66,7 +66,7 @@ def __init__(self, rank: int, **kwargs: Any):
def init_q(
self,
ot_prob: Problem_t,
- key: jnp.ndarray,
+ rng: jnp.ndarray,
Same everywhere, it should always be rng: jax.random.PRNGKeyArray instead
of rng: jnp.ndarray
—
Reply to this email directly, view it on GitHub
<#290 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AKUEHABK54IBF23URBAEVADWXYCBBANCNFSM6AAAAAAU33OSJY>
.
You are receiving this because you were assigned.Message ID:
***@***.***>
|
@@ -383,19 +383,19 @@ def convert( | |||
return self | |||
|
|||
(geom_xx, geom_yy, geom_xy, *children), aux_data = self.tree_flatten() | |||
(s1, s2, s3) = jax.random.split(jax.random.PRNGKey(seed), 3)[:, 0] | |||
(k1, k2, k3) = jax.random.split(rng, 3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(k1, k2, k3) = jax.random.split(rng, 3) | |
rng1, rng2, rng3 = jax.random.split(rng, 3) |
(No parenthesis to be consistent, and k -> rng
@othmanesebbouh do you think you could run the pre-commit to get rid of linting issues? Thanks! |
@othmanesebbouh any progress on this? Re the test that doesn't converge here, I'd say either you can keep it as it is (high epsilon), or revert + use another seed (if it works). In any case, will look into it in #302 |
@michalk8 just ran the pre-commit, I hope it works now. I actually thought I had already done that, but wasn't using it correctly, sorry. For the test that doesn't converge, I feel that having epsilon very large is a less cryptic way to signaling that there's a problem than having a seed that miraculously works. |
Works for me, thanks! Could you please merge merge git remote add upstream ssh://[email protected]/ott-jax/ott
git fetch upstream
git merge upstream/main |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @othmanesebbouh, LGTM!
…eeds (#290) * add ENS * removed ENS * Test * Test * Change gw_barycenter and neural dual * fix random number generators in initializers * k-means and gromov * fix rng on low rank geometry and datasets * fix remaining seed changes to rng * fixed rng reference and increased epsilon to make algo converge, should open issue * fixed keys to rngs * extended rng instead of keys to inside functions and added jax.random.PRNGKeyArray as default type * fixed more keys to rng with correct default type * ran pre-commit * Fix typo --------- Co-authored-by: Anastasiia <[email protected]> Co-authored-by: michalk8 <[email protected]>
This fixes #172, replacing all the
seed
withrng
, settingjax.random.PRNGKey(0)
as default where applicable.closes #172