JAX has recently started moving towards new-style (typed) RNG keys, which means replacing calls to `jax.random.PRNGKey` with calls to `jax.random.key`. Additionally, when a new-style key is used to seed `numpy.random.RandomState` (as in `compute_cvt_centroids`), it must first be converted with `jax.random.key_data`; otherwise a type error is raised during the conversion.
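A minimal sketch of the seeding fix, assuming a recent JAX release that provides `jax.random.key` and the `jax.dtypes.prng_key` dtype. The helper `numpy_rng_from_key` is hypothetical (it is not an existing function in the codebase). It only illustrates unwrapping a typed key with `jax.random.key_data` before passing it to `numpy.random.RandomState`, while still accepting old-style `uint32` keys.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_rng_from_key(key: jax.Array) -> np.random.RandomState:
    """Hypothetical helper: build a NumPy RandomState from a JAX PRNG key.

    Accepts both new-style typed keys (jax.random.key) and old-style raw
    uint32 keys (jax.random.PRNGKey).
    """
    # Typed keys have a special key<...> dtype and cannot be converted to an
    # integer array directly; key_data extracts the underlying uint32 words.
    if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
        key = jax.random.key_data(key)
    # RandomState accepts an int or a 1-D array of ints in [0, 2**32).
    return np.random.RandomState(np.asarray(key, dtype=np.uint32))


# Usage: the new-style key and the old-style key built from the same seed
# hold the same raw bits, so both RandomStates produce identical draws.
new_key = jax.random.key(42)
old_key = jax.random.PRNGKey(42)
draws_new = numpy_rng_from_key(new_key).uniform(size=3)
draws_old = numpy_rng_from_key(old_key).uniform(size=3)
```

Branching on the key dtype keeps the helper backward compatible, so call sites can migrate from `jax.random.PRNGKey` to `jax.random.key` one at a time.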
I am happy to write a pull request for this.
I've submitted a PR, but I ran into some issues, as the failing tests show. The problem is that the Haiku version pinned in requirements.txt (0.0.10) does not support new-style RNG keys. The latest version (0.0.13) does; I have tested this on my own machine. Updating the requirements should therefore allow the examples and tests to remain unchanged.