From 29edfd89251c98a647ca6af33b813f50fb80c373 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 5 Mar 2024 20:09:14 -0800 Subject: [PATCH] define a loop-free untrue batching rule for `rng_bit_generator` --- jax/_src/lax/control_flow/loops.py | 16 ++++---- jax/_src/random.py | 2 +- tests/BUILD | 3 ++ tests/lax_test.py | 18 +++++++++ tests/random_lax_test.py | 59 ++++++++++++++++++++++++++---- 5 files changed, 82 insertions(+), 16 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index e939f2170f8a..99fbb010130a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2046,16 +2046,18 @@ def map(f, xs): return ys def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): - """Calls RBG in a loop and stacks the results.""" - key, = batched_args + keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype, + return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, algorithm=algorithm), (None, None) - key = batching.moveaxis(key, bd, 0) - map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm) - stacked_keys, stacked_bits = map(map_body, key) - return (stacked_keys, stacked_bits), (0, 0) + keys = batching.moveaxis(keys, bd, 0) + batch_size = keys.shape[0] + key = keys[0] + new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), + dtype=dtype, algorithm=algorithm) + new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) + return (new_keys, bits), (0, 0) batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore diff --git a/jax/_src/random.py b/jax/_src/random.py index f9045ebf3135..31d10db01c05 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1233,7 +1233,7 @@ def _gamma_impl(key, a, *, log_space, use_vmap=False): keys = keys.flatten() alphas = a.flatten() - if use_vmap: + if use_vmap and _key_impl(key) is prng.threefry_prng_impl: samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas) else: samples = lax.map( diff --git a/tests/BUILD b/tests/BUILD index 9c8ca93103b7..6c72e9da98f0 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -784,6 +784,9 @@ jax_test( "notsan", # Times out ], }, + backend_variant_args = { + "gpu": ["--jax_num_generated_cases=40"], + }, shard_count = { "cpu": 40, "gpu": 30, diff --git a/tests/lax_test.py b/tests/lax_test.py index aadac1d64566..613164650bd8 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2652,6 +2652,24 @@ def testRngBitGeneratorReturnedKey(self): new_key, _ = lax.rng_bit_generator(key, (0,)) self.assertAllClose(key, new_key) + def test_rng_bit_generator_vmap(self): + def f(key): + return lax.rng_bit_generator(key, shape=(5, 7)) + + keys = np.arange(3 * 4).reshape((3, 4)).astype(np.uint32) + out_keys, bits = jax.vmap(f)(keys) + self.assertEqual(out_keys.shape, (3, 4)) + self.assertEqual(bits.shape, (3, 5, 7)) + + def test_rng_bit_generator_vmap_vmap(self): + def f(key): + return lax.rng_bit_generator(key, shape=(5, 7)) + + keys = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.uint32) + out_keys, bits = jax.vmap(jax.vmap(f))(keys) + self.assertEqual(out_keys.shape, (2, 3, 4)) + self.assertEqual(bits.shape, (2, 3, 5, 7)) + @jtu.sample_product( dtype=lax_test_util.all_dtypes + lax_test_util.python_scalar_types, weak_type=[True, False], diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 52e8cbc7b262..f6280a5e02ec 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1348,6 +1348,7 @@ def test_vmap_fold_in_shape(self): out = vmap(vmap(random.fold_in), in_axes=(1, 0))(keys(), msgs.T) self.assertEqual(out.shape, (3, 2)) + @jax.enable_key_reuse_checks(False) def test_vmap_split_mapped_key(self): key = self.make_key(73) mapped_keys = random.split(key, num=3) @@ -1408,24 +1409,57 @@ def test_vmap_split_not_mapped_key(self): self.assertArraysEqual(random.key_data(vk), random.key_data(single_split_key)) - def test_vmap_split_mapped_key(self): + @jax.enable_key_reuse_checks(False) + def test_vmap_split_mapped_key_shape(self): key = self.make_key(73) mapped_keys = random.split(key, num=3) - forloop_keys = [random.split(k) for k in mapped_keys] vmapped_keys = vmap(random.split)(mapped_keys) self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape)) - for fk, vk in zip(forloop_keys, vmapped_keys): - self.assertArraysEqual(random.key_data(fk), + + @jax.enable_key_reuse_checks(False) + def test_vmap_split_mapped_key_values(self): + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + vmapped_keys = vmap(random.split)(mapped_keys) + ref_keys = [random.split(k) for k in mapped_keys] + for rk, vk in zip(ref_keys, vmapped_keys): + self.assertArraysEqual(random.key_data(rk), random.key_data(vk)) - def test_vmap_random_bits(self): - rand_fun = lambda key: random.randint(key, (), 0, 100) + @jax.enable_key_reuse_checks(False) + def test_vmap_random_bits_shape(self): + rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100) key = self.make_key(73) mapped_keys = random.split(key, num=3) - forloop_rand_nums = [rand_fun(k) for k in mapped_keys] rand_nums = vmap(rand_fun)(mapped_keys) self.assertEqual(rand_nums.shape, (3,)) - self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums)) + + @jtu.skip_on_devices("tpu") + @jax.enable_key_reuse_checks(False) + def test_vmap_random_bits_value(self): + rand_fun = lambda key, shape=(): random.randint(key, shape, 0, 100) + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + rand_nums = vmap(rand_fun)(mapped_keys) + ref_nums = rand_fun(mapped_keys[0], shape=(3,)) + self.assertArraysEqual(rand_nums, ref_nums) + + def test_vmap_random_bits_distribution(self): + dtype = jnp.float32 + keys = lambda: jax.random.split(self.make_key(0), 10) + + def rand(key): + nums = jax.vmap(lambda key: random.uniform(key, (1000,), dtype))(key) + return nums.flatten() + + crand = jax.jit(rand) + + uncompiled_samples = rand(keys()) + compiled_samples = crand(keys()) + + for samples in [uncompiled_samples, compiled_samples]: + self._CheckCollisions(samples, jnp.finfo(dtype).nmant) + self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf) def test_cannot_add(self): key = self.make_key(73) @@ -1455,6 +1489,15 @@ class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): def make_key(self, seed): return random.PRNGKey(seed, impl="unsafe_rbg") + @jtu.skip_on_devices("tpu") + @jax.enable_key_reuse_checks(False) + def test_vmap_split_mapped_key_values(self): + key = self.make_key(73) + mapped_keys = random.split(key, num=3) + vmapped_keys = vmap(random.split)(mapped_keys) + ref_keys = random.split(mapped_keys[0], (3, 2)) + self.assertArraysEqual(random.key_data(vmapped_keys), + random.key_data(ref_keys)) def _sampler_unimplemented_with_custom_prng(*args, **kwargs): raise SkipTest('sampler only implemented for default RNG')