Skip to content

Commit

Permalink
Merge pull request jax-ml#9186 from froystig:get-default-rng
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 421454645
  • Loading branch information
jax authors committed Jan 13, 2022
2 parents 436ce79 + 026b91b commit f0e4f04
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
14 changes: 10 additions & 4 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _check_prng_key(key):
'Raw arrays as random keys to jax.random functions are deprecated. '
'Assuming valid threefry2x32 key for now.',
FutureWarning)
return prng.PRNGKeyArray(_get_default_prng_impl(), key), True
return prng.PRNGKeyArray(default_prng_impl(), key), True
else:
raise TypeError(f'unexpected PRNG key type {type(key)}')

Expand All @@ -94,7 +94,13 @@ def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
'unsafe_rbg': prng.unsafe_rbg_prng_impl,
}

def _get_default_prng_impl():
def default_prng_impl():
"""Get the default PRNG implementation.
The default implementation is determined by ``config.jax_default_prng_impl``,
which specifies it by name. This function returns the corresponding
``jax.prng.PRNGImpl`` instance.
"""
impl_name = config.jax_default_prng_impl
assert impl_name in PRNG_IMPLS, impl_name
return PRNG_IMPLS[impl_name]
Expand All @@ -117,13 +123,13 @@ def PRNGKey(seed: int) -> KeyArray:
and ``fold_in``.
"""
impl = _get_default_prng_impl()
impl = default_prng_impl()
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)

# TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name):
default_impl = _get_default_prng_impl()
default_impl = default_prng_impl()
default_name = config.jax_default_prng_impl
if not config.jax_enable_custom_prng and default_impl is not impl:
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
Expand Down
1 change: 1 addition & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
categorical as categorical,
cauchy as cauchy,
choice as choice,
default_prng_impl as default_prng_impl,
dirichlet as dirichlet,
double_sided_maxwell as double_sided_maxwell,
exponential as exponential,
Expand Down
16 changes: 16 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,28 @@ def test_default_prng_selection(self):
('rbg', prng.rbg_prng_impl),
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
key = random.PRNGKey(42)
self.assertIs(key.impl, impl)
k1, k2 = random.split(key, 2)
self.assertIs(k1.impl, impl)
self.assertIs(k2.impl, impl)

def test_default_prng_selection_without_custom_prng_mode(self):
if config.jax_enable_custom_prng:
self.skipTest("test requires that config.jax_enable_custom_prng is False")
for name, impl in [('threefry2x32', prng.threefry_prng_impl),
('rbg', prng.rbg_prng_impl),
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]:
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
key = random.PRNGKey(42)
self.assertEqual(key.shape, impl.key_shape)
k1, k2 = random.split(key, 2)
self.assertEqual(k1.shape, impl.key_shape)
self.assertEqual(k2.shape, impl.key_shape)


def test_explicit_threefry2x32_key(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
Expand Down

0 comments on commit f0e4f04

Please sign in to comment.