diff --git a/jax/test_util.py b/jax/test_util.py index a5e9eb38b4cd..2e9652a26c70 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -314,6 +314,22 @@ def f_vjp(*args): _check_grads(f, args, order) +@contextmanager +def count_device_put(): + device_put = xla.device_put + count = [0] + + def device_put_and_count(*args, **kwargs): + count[0] += 1 + return device_put(*args, **kwargs) + + xla.device_put = device_put_and_count + try: + yield count + finally: + xla.device_put = device_put + + @contextmanager def count_primitive_compiles(): xla.xla_primitive_callable.cache_clear() diff --git a/tests/api_test.py b/tests/api_test.py index f3f562e5c4f7..0c96d75a03b4 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2134,23 +2134,15 @@ def f(): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f() - # TODO(jakevdp): re-enable this if possible. - @unittest.skipIf(True, "broken by convert_element_type change.") def test_xla_computation_zeros_doesnt_device_put(self): + raise unittest.SkipTest("broken test") # TODO(mattjj): fix + if not config.omnistaging_enabled: raise unittest.SkipTest("test is omnistaging-specific") - count = 0 - def device_put_and_count(*args, **kwargs): - nonlocal count - count += 1 - return orig_device_put(*args, **kwargs) - orig_device_put, xla.device_put = xla.device_put, device_put_and_count - try: + with jtu.count_device_put() as count: api.xla_computation(lambda: jnp.zeros(3))() - finally: - xla.device_put = orig_device_put - self.assertEqual(count, 0) + self.assertEqual(count[0], 0) def test_join_concrete_arrays_with_omnistaging(self): # https://github.com/google/jax/issues/4622 diff --git a/tests/random_test.py b/tests/random_test.py index 3b7dc7e126cb..cd74c8dd66ff 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -967,6 +967,20 @@ def test_prng_errors(self): with self.assertRaises(OverflowError): api.jit(random.PRNGKey)(seed) + def test_random_split_doesnt_device_put_during_tracing(self): + raise SkipTest("broken test") # TODO(mattjj): fix + + if not config.omnistaging_enabled: + raise SkipTest("test is omnistaging-specific") + + key = random.PRNGKey(1) + with jtu.count_device_put() as count: + api.jit(random.split)(key) + key, _ = random.split(key, 2) + self.assertEqual(count[0], 1) # 1 for the argument device_put call + + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())