Skip to content

Commit

Permalink
Merge pull request #20094 from froystig:vmap-rbg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 614034982
  • Loading branch information
jax authors committed Mar 8, 2024
2 parents 59e9ee3 + 29edfd8 commit c4cf265
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 16 deletions.
16 changes: 9 additions & 7 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,16 +2044,18 @@ def map(f, xs):
return ys

def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
"""Calls RBG in a loop and stacks the results."""
key, = batched_args
keys, = batched_args
bd, = batch_dims
if bd is batching.not_mapped:
return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype,
algorithm=algorithm), (None, None)
key = batching.moveaxis(key, bd, 0)
map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm)
stacked_keys, stacked_bits = map(map_body, key)
return (stacked_keys, stacked_bits), (0, 0)
keys = batching.moveaxis(keys, bd, 0)
batch_size = keys.shape[0]
key = keys[0]
new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape),
dtype=dtype, algorithm=algorithm)
new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0)
return (new_keys, bits), (0, 0)

batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False):
keys = keys.flatten()
alphas = a.flatten()

if use_vmap:
if use_vmap and _key_impl(key) is prng.threefry_prng_impl:
samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
else:
samples = lax.map(
Expand Down
3 changes: 3 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,9 @@ jax_test(
"notsan", # Times out
],
},
backend_variant_args = {
"gpu": ["--jax_num_generated_cases=40"],
},
shard_count = {
"cpu": 40,
"gpu": 30,
Expand Down
18 changes: 18 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,6 +2652,24 @@ def testRngBitGeneratorReturnedKey(self):
new_key, _ = lax.rng_bit_generator(key, (0,))
self.assertAllClose(key, new_key)

def test_rng_bit_generator_vmap(self):
def f(key):
return lax.rng_bit_generator(key, shape=(5, 7))

keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32)
out_keys, bits = jax.vmap(f)(keys)
self.assertEqual(out_keys.shape, (3, 4))
self.assertEqual(bits.shape, (3, 5, 7))

def test_rng_bit_generator_vmap_vmap(self):
def f(key):
return lax.rng_bit_generator(key, shape=(5, 7))

keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32)
out_keys, bits = jax.vmap(jax.vmap(f))(keys)
self.assertEqual(out_keys.shape, (2, 3, 4))
self.assertEqual(bits.shape, (2, 3, 5, 7))

@jtu.sample_product(
dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types,
weak_type=[True, False],
Expand Down
59 changes: 51 additions & 8 deletions tests/random_lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,7 @@ def test_vmap_fold_in_shape(self):
out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T)
self.assertEqual(out.shape, (3, 2))

@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
Expand Down Expand Up @@ -1408,24 +1409,57 @@ def test_vmap_split_not_mapped_key(self):
self.assertArraysEqual(random.key_data(vk),
random.key_data(single_split_key))

def test_vmap_split_mapped_key(self):
@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key_shape(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_keys = [random.split(k) for k in mapped_keys]
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape))
for fk, vk in zip(forloop_keys, vmapped_keys):
self.assertArraysEqual(random.key_data(fk),

@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
ref_keys = [random.split(k) for k in mapped_keys]
for rk, vk in zip(ref_keys, vmapped_keys):
self.assertArraysEqual(random.key_data(rk),
random.key_data(vk))

def test_vmap_random_bits(self):
rand_fun = lambda key: random.randint(key, (), 0, 100)
@jax.enable_key_reuse_checks(False)
def test_vmap_random_bits_shape(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
forloop_rand_nums = [rand_fun(k) for k in mapped_keys]
rand_nums = vmap(rand_fun)(mapped_keys)
self.assertEqual(rand_nums.shape, (3,))
self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums))

@jtu.skip_on_devices("tpu")
@jax.enable_key_reuse_checks(False)
def test_vmap_random_bits_value(self):
rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100)
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
rand_nums = vmap(rand_fun)(mapped_keys)
ref_nums = rand_fun(mapped_keys[0], shape=(3,))
self.assertArraysEqual(rand_nums, ref_nums)

def test_vmap_random_bits_distribution(self):
dtype = jnp.float32
keys = lambda: jax.random.split(self.make_key(0), 10)

def rand(key):
nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key)
return nums.flatten()

crand = jax.jit(rand)

uncompiled_samples = rand(keys())
compiled_samples = crand(keys())

for samples in [uncompiled_samples, compiled_samples]:
self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)

def test_cannot_add(self):
key = self.make_key(73)
Expand Down Expand Up @@ -1455,6 +1489,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def make_key(self, seed):
return random.PRNGKey(seed, impl="unsafe_rbg")

@jtu.skip_on_devices("tpu")
@jax.enable_key_reuse_checks(False)
def test_vmap_split_mapped_key_values(self):
key = self.make_key(73)
mapped_keys = random.split(key, num=3)
vmapped_keys = vmap(random.split)(mapped_keys)
ref_keys = random.split(mapped_keys[0], (3, 2))
self.assertArraysEqual(random.key_data(vmapped_keys),
random.key_data(ref_keys))

def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')
Expand Down

0 comments on commit c4cf265

Please sign in to comment.