New style Jax RNG keys cannot be converted directly to Numpy random state seeds. #176

Open
miltonllera opened this issue Feb 27, 2024 · 2 comments

Comments

@miltonllera

JAX has recently started moving toward new-style RNG keys, which entails replacing calls to jax.random.PRNGKey with calls to jax.key. In addition, a new-style key used to seed numpy.random.RandomState (as in compute_cvt_centroids) must first be passed through jax.random.key_data; otherwise the conversion fails with a type error.
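A minimal sketch of the conversion described above, assuming JAX 0.4.16 or later (where jax.random.key, jax.random.key_data and jax.dtypes.prng_key are available). The helper name key_to_numpy_seed is hypothetical and not part of this project; it accepts either key style so that call sites like compute_cvt_centroids would not need to care which one they receive.

```python
import jax
import numpy as np


def key_to_numpy_seed(key):
    """Hypothetical helper: turn a JAX PRNG key into a NumPy seed.

    Accepts both new-style typed keys (jax.random.key) and legacy
    raw uint32 keys (jax.random.PRNGKey).
    """
    if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
        # New-style typed key: NumPy cannot interpret it directly,
        # so extract the underlying uint32 words first.
        key = jax.random.key_data(key)
    # Raw key data is a 1-D uint32 array, which RandomState accepts
    # as an array seed.
    return np.asarray(key, dtype=np.uint32)


# Usage: both key styles can now seed a NumPy RandomState.
new_key = jax.random.key(42)
legacy_key = jax.random.PRNGKey(42)
rs_new = np.random.RandomState(key_to_numpy_seed(new_key))
rs_legacy = np.random.RandomState(key_to_numpy_seed(legacy_key))
```

With the default PRNG implementation, jax.random.key(n) wraps the same raw bits as jax.random.PRNGKey(n), so the two generators above are seeded identically; branching on the key dtype keeps the helper backward compatible with code that still passes legacy keys.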

I am happy to write a pull request for this.

@Aneoshun
Member

Aneoshun commented Mar 5, 2024

Hi Milton,

Sounds great. Please write the PR for this; we will be happy to review and accept it.

@miltonllera
Author

Sorry I've been slow on this.

I've submitted a PR, but I ran into some issues (as the failed tests show). The problem is that the Haiku version listed in requirements.txt (0.0.10) doesn't support new-style RNG keys. The latest version (0.0.13) does support them, as I have tested on my own machine. Updating the requirements should therefore allow the examples and tests to remain unchanged.
