Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JEP 9263: Typed keys & pluggable RNGs #17297

Merged
merged 1 commit into from
Sep 20, 2023
Merged

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Aug 25, 2023

Draft JAX Enhancement Proposal (JEP) document for the work tracked in #9263

Rendered preview: https://jax--17297.org.readthedocs.build/en/17297/jep/9263-typed-keys.html

@jakevdp jakevdp added the JEP JAX enhancement proposal label Aug 25, 2023
@jakevdp jakevdp requested a review from froystig August 25, 2023 18:06
@jakevdp jakevdp self-assigned this Aug 25, 2023
@jakevdp jakevdp marked this pull request as draft August 25, 2023 18:06
@jakevdp jakevdp changed the title JAP 9263: Typed keys & pluggable RNGs JEP 9263: Typed keys & pluggable RNGs Aug 25, 2023
@jakevdp jakevdp force-pushed the jep-9263 branch 2 times, most recently from 0c5e6d0 to 960009f Compare August 25, 2023 18:12
@jakevdp jakevdp marked this pull request as ready for review August 25, 2023 19:55
@jakevdp jakevdp marked this pull request as draft August 25, 2023 20:05
@jakevdp jakevdp force-pushed the jep-9263 branch 3 times, most recently from c91db95 to 2f6bfe7 Compare August 25, 2023 21:26
@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 25, 2023

@froystig the JEP mentions jax.extend.random.wrap_key, which we haven't created yet. Is jax._src.prng.random_wrap the right function for this?

@froystig
Copy link
Member

Good catch. Sent you #17307 for review.

@jakevdp jakevdp force-pushed the jep-9263 branch 4 times, most recently from 2523485 to 49aaa5a Compare August 29, 2023 21:22
jakevdp added a commit to jakevdp/flax that referenced this pull request Sep 8, 2023
This follows the recommendations of [JEP 9263](jax-ml/jax#17297)
@jakevdp jakevdp force-pushed the jep-9263 branch 2 times, most recently from aa0fb7c to 7e7f8d9 Compare September 13, 2023 16:38
copybara-service bot pushed a commit to google-deepmind/chex that referenced this pull request Sep 13, 2023
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see jax-ml/jax#17297)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict.

PiperOrigin-RevId: 565133147
copybara-service bot pushed a commit to google-deepmind/distrax that referenced this pull request Sep 13, 2023
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see jax-ml/jax#17297)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict.

PiperOrigin-RevId: 565133147
copybara-service bot pushed a commit to google-deepmind/chex that referenced this pull request Sep 13, 2023
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see jax-ml/jax#17297)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict.

PiperOrigin-RevId: 565133147
copybara-service bot pushed a commit to google/flax that referenced this pull request Sep 14, 2023
Functionally, this involves changing uses of jax.random.PRNGKey to jax.random.key. For details on this change and the motivation behind it, see the draft JEP at jax-ml/jax#17297, and please feel free to offer comments and feedback!

PiperOrigin-RevId: 563549594
@NeilGirdhar
Copy link
Contributor

@jakevdp Good points.

  1. If we change it so alias to Array, this code will begin to silently return unexpected results.

Yeah, that's true. The NewType approach would be an option to avoid that since it does not support isinstance.

, and so in practice it leads to a hard break where checks begin failing at the new release.

I see your point. I think it's likely that many type errors in such a situation would be true positives though, but maybe I'm too optimistic. I think the most common usage of KeyArray is an unbroken path from PRNGKey to random functions, and that wouldn't break. But I see the principle you're getting at, and that makes sense.

Regarding issubdtype – it's already mentioned prominently in the first paragraph,

My mistake!

Thanks for taking the time to answer. I guess I'll just have a KeyArray type. It's just unfortunate that the NewType approach is too inconvenient since PRNGKey doesn't produce them.

@jakevdp
Copy link
Collaborator Author

jakevdp commented Sep 15, 2023

I think it's likely that many type errors in such a situation would be true positives though

This is true! But that doesn't make it any easier to land the change if you're working in a context where you need tests for all dependent projects to stay green.

copybara-service bot pushed a commit to google-deepmind/dm-haiku that referenced this pull request Sep 15, 2023
For more details, see jax-ml/jax#17297. Previously, we had imagined a world where the jax_enable_custom_prng flag globally determined the presence of typed keys. This proved untenable for a number of reasons. Going forward, old-style and new-style keys are expected to exist side-by-side regardless of the value of `jax_enable_custom_prng`, which will soon be deprecated. Eventually old-style keys will also be deprecated and removed.

PiperOrigin-RevId: 565694585
copybara-service bot pushed a commit to google-deepmind/dm-haiku that referenced this pull request Sep 15, 2023
For more details, see jax-ml/jax#17297. Previously, we had imagined a world where the jax_enable_custom_prng flag globally determined the presence of typed keys. This proved untenable for a number of reasons. Going forward, old-style and new-style keys are expected to exist side-by-side regardless of the value of `jax_enable_custom_prng`, which will soon be deprecated. Eventually old-style keys will also be deprecated and removed.

PiperOrigin-RevId: 565770648
@jakevdp
Copy link
Collaborator Author

jakevdp commented Sep 18, 2023

@NeilGirdhar I just pushed an update to the doc adding explicit sections on instance checking and type annotations. Thanks for the feedback!

@jakevdp jakevdp force-pushed the jep-9263 branch 2 times, most recently from dcd57c4 to 5e6dd06 Compare September 18, 2023 20:54
copybara-service bot pushed a commit to google-deepmind/chex that referenced this pull request Sep 19, 2023
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously.

PiperOrigin-RevId: 565133147
copybara-service bot pushed a commit to tensorflow/probability that referenced this pull request Sep 19, 2023
The context is described more fully in [JEP 9263](jax-ml/jax#17297).
If you have comments on the JEP, we'd love to hear them!

PiperOrigin-RevId: 566664782
copybara-service bot pushed a commit to google-deepmind/chex that referenced this pull request Sep 19, 2023
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously.

PiperOrigin-RevId: 565133147
copybara-service bot pushed a commit to google-deepmind/chex that referenced this pull request Sep 20, 2023
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously.

PiperOrigin-RevId: 565133147
copybara-service bot pushed a commit to google-deepmind/chex that referenced this pull request Sep 20, 2023
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously.

PiperOrigin-RevId: 565133147
copybara-service bot pushed a commit to google-deepmind/chex that referenced this pull request Sep 20, 2023
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously.

PiperOrigin-RevId: 566933252
@jakevdp jakevdp marked this pull request as ready for review September 20, 2023 18:19
@jakevdp
Copy link
Collaborator Author

jakevdp commented Sep 20, 2023

We've gotten feedback from a number of stakeholders; I think this is ready for final review & merge!

Copy link
Member

@froystig froystig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the edit suggestion, then we're all set!

docs/jep/9263-typed-keys.md Outdated Show resolved Hide resolved
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Sep 20, 2023
@copybara-service copybara-service bot merged commit 41c7cce into jax-ml:main Sep 20, 2023
@jakevdp jakevdp deleted the jep-9263 branch September 20, 2023 18:38
8bitmp3 pushed a commit to 8bitmp3/flax that referenced this pull request Oct 9, 2023
This follows the recommendations of [JEP 9263](jax-ml/jax#17297)
8bitmp3 pushed a commit to 8bitmp3/flax that referenced this pull request Oct 9, 2023
Functionally, this involves changing uses of jax.random.PRNGKey to jax.random.key. For details on this change and the motivation behind it, see the draft JEP at jax-ml/jax#17297, and please feel free to offer comments and feedback!

PiperOrigin-RevId: 565475405
Augustin-Zidek pushed a commit to google-deepmind/alphafold that referenced this pull request Oct 24, 2023
For details, see jax-ml/jax#17297

PiperOrigin-RevId: 565705637
Change-Id: I88c4a3003166580e28111c4e410c8412c23ccb9d
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
JEP JAX enhancement proposal pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants