-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
JEP 9263: Typed keys & pluggable RNGs #17297
Conversation
0c5e6d0
to
960009f
Compare
c91db95
to
2f6bfe7
Compare
@froystig the JEP mentions |
Good catch. Sent you #17307 for review. |
2523485
to
49aaa5a
Compare
This follows the recommendations of [JEP 9263](jax-ml/jax#17297)
aa0fb7c
to
7e7f8d9
Compare
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see jax-ml/jax#17297) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict. PiperOrigin-RevId: 565133147
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see jax-ml/jax#17297) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict. PiperOrigin-RevId: 565133147
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see jax-ml/jax#17297) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict. PiperOrigin-RevId: 565133147
Functionally, this involves changing uses of jax.random.PRNGKey to jax.random.key. For details on this change and the motivation behind it, see the draft JEP at jax-ml/jax#17297, and please feel free to offer comments and feedback! PiperOrigin-RevId: 563549594
@jakevdp Good points.
Yeah, that's true. The
I see your point. I think it's likely that many type errors in such a situation would be true positives though, but maybe I'm too optimistic. I think the most common usage of
My mistake! Thanks for taking the time to answer. I guess I'll just have a |
This is true! But that doesn't make it any easier to land the change if you're working in a context where you need tests for all dependent projects to stay green. |
For more details, see jax-ml/jax#17297. Previously, we had imagined a world where the jax_enable_custom_prng flag globally determined the presence of typed keys. This proved untenable for a number of reasons. Going forward, old-style and new-style keys are expected to exist side-by-side regardless of the value of `jax_enable_custom_prng`, which will soon be deprecated. Eventually old-style keys will also be deprecated and removed. PiperOrigin-RevId: 565694585
For more details, see jax-ml/jax#17297. Previously, we had imagined a world where the jax_enable_custom_prng flag globally determined the presence of typed keys. This proved untenable for a number of reasons. Going forward, old-style and new-style keys are expected to exist side-by-side regardless of the value of `jax_enable_custom_prng`, which will soon be deprecated. Eventually old-style keys will also be deprecated and removed. PiperOrigin-RevId: 565770648
@NeilGirdhar I just pushed an update to the doc adding explicit sections on instance checking and type annotations. Thanks for the feedback! |
dcd57c4
to
5e6dd06
Compare
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously. PiperOrigin-RevId: 565133147
The context is described more fully in [JEP 9263](jax-ml/jax#17297). If you have comments on the JEP, we'd love to hear them! PiperOrigin-RevId: 566664782
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously. PiperOrigin-RevId: 565133147
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously. PiperOrigin-RevId: 565133147
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously. PiperOrigin-RevId: 565133147
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](jax-ml/jax#17297) for details) Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously. PiperOrigin-RevId: 566933252
We've gotten feedback from a number of stakeholders; I think this is ready for final review & merge! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See the edit suggestion, then we're all set!
Co-authored-by: Roy Frostig <[email protected]>
This follows the recommendations of [JEP 9263](jax-ml/jax#17297)
Functionally, this involves changing uses of jax.random.PRNGKey to jax.random.key. For details on this change and the motivation behind it, see the draft JEP at jax-ml/jax#17297, and please feel free to offer comments and feedback! PiperOrigin-RevId: 565475405
For details, see jax-ml/jax#17297 PiperOrigin-RevId: 565705637 Change-Id: I88c4a3003166580e28111c4e410c8412c23ccb9d
Draft JAX Enhancement Proposal (JEP) document for the work tracked in #9263
Rendered preview: https://jax--17297.org.readthedocs.build/en/17297/jep/9263-typed-keys.html