-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make random seeding consistent with PRNG keys everywhere instead of seeds #290
Merged
Merged
Changes from 12 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
b1b3457
add ENS
othmanesebbouh 290849c
removed ENS
othmanesebbouh 2cb8a9f
Test
4d497ab
Test
bc53ef5
Change gw_barycenter and neural dual
0326926
fix random number generators in initializers
othmanesebbouh bfeec00
Merge branch 'fix_random' of github.com:othmanesebbouh/ott into fix_r…
othmanesebbouh 1914639
k-means and gromov
c0db6df
fix rng on low rank geometry and datasets
othmanesebbouh 8182b12
Merge branch 'fix_random' of github.com:othmanesebbouh/ott into fix_r…
othmanesebbouh 8b305ef
fix remaining seed changes to rng
othmanesebbouh 55acb30
Merge branch 'main' into fix_random
othmanesebbouh 8e764d9
fixed rng reference and increased epsilon to make algo converge, shou…
othmanesebbouh 74e0bbd
fixed keys to rngs
othmanesebbouh d36ae8f
extended rng instead of keys to inside functions and added jax.random…
othmanesebbouh 7c38fc3
fixed more keys to rng with correct default type
othmanesebbouh fa8aedf
ran pre-commit
othmanesebbouh 6844e19
delete modified examples files and fix small conflicts
othmanesebbouh 8d464d8
Merge branch 'main' into fix_random
michalk8 01f4bf2
Fix typo
michalk8 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -360,11 +360,11 @@ def convertible(geom: geometry.Geometry) -> bool: | |||||
(geom_xy is None or convertible(geom_xy)) | ||||||
) | ||||||
|
||||||
def to_low_rank(self, seed: int = 0) -> "QuadraticProblem": | ||||||
def to_low_rank(self, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)) -> "QuadraticProblem": | ||||||
"""Convert geometries to low-rank. | ||||||
|
||||||
Args: | ||||||
seed: Random seed. | ||||||
rng: Random key for seeding. | ||||||
|
||||||
Returns: | ||||||
Quadratic problem with low-rank geometries. | ||||||
|
@@ -383,19 +383,19 @@ def convert( | |||||
return self | ||||||
|
||||||
(geom_xx, geom_yy, geom_xy, *children), aux_data = self.tree_flatten() | ||||||
(s1, s2, s3) = jax.random.split(jax.random.PRNGKey(seed), 3)[:, 0] | ||||||
(k1, k2, k3) = jax.random.split(rng, 3) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
(No parenthesis to be consistent, and k -> rng |
||||||
(r1, r2, r3), (t1, t2, t3) = convert(self.ranks), convert(self.tolerances) | ||||||
|
||||||
geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, seed=s1) | ||||||
geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, seed=s2) | ||||||
geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, rng=k1) | ||||||
geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, rng=k2) | ||||||
if self.is_fused: | ||||||
if isinstance( | ||||||
geom_xy, pointcloud.PointCloud | ||||||
) and geom_xy.is_squared_euclidean: | ||||||
geom_xy = geom_xy.to_LRCGeometry(scale=self.fused_penalty) | ||||||
else: | ||||||
geom_xy = geom_xy.to_LRCGeometry( | ||||||
rank=r3, tol=t3, seed=s3, scale=self.fused_penalty | ||||||
rank=r3, tol=t3, rng=k3, scale=self.fused_penalty | ||||||
) | ||||||
|
||||||
return type(self).tree_unflatten( | ||||||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the type hint should rather be
jax.random.PRNGKeyArray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same everywhere, it should always be
rng: jax.random.PRNGKeyArray
instead ofrng: jnp.ndarray