Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jtu.count_device_put for tests to count device_put. #5998

Merged
merged 3 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,22 @@ def f_vjp(*args):
_check_grads(f, args, order)


@contextmanager
def count_device_put():
device_put = xla.device_put
count = [0]

def device_put_and_count(*args, **kwargs):
count[0] += 1
return device_put(*args, **kwargs)

xla.device_put = device_put_and_count
try:
yield count
finally:
xla.device_put = device_put


@contextmanager
def count_primitive_compiles():
xla.xla_primitive_callable.cache_clear()
Expand Down
16 changes: 4 additions & 12 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,23 +2134,15 @@ def f():
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

# TODO(jakevdp): re-enable this if possible.
@unittest.skipIf(True, "broken by convert_element_type change.")
def test_xla_computation_zeros_doesnt_device_put(self):
raise unittest.SkipTest("broken test") # TODO(mattjj): fix

if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")

count = 0
def device_put_and_count(*args, **kwargs):
nonlocal count
count += 1
return orig_device_put(*args, **kwargs)
orig_device_put, xla.device_put = xla.device_put, device_put_and_count
try:
with jtu.count_device_put() as count:
api.xla_computation(lambda: jnp.zeros(3))()
finally:
xla.device_put = orig_device_put
self.assertEqual(count, 0)
self.assertEqual(count[0], 0)

def test_join_concrete_arrays_with_omnistaging(self):
# https://github.com/google/jax/issues/4622
Expand Down
14 changes: 14 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,20 @@ def test_prng_errors(self):
with self.assertRaises(OverflowError):
api.jit(random.PRNGKey)(seed)

def test_random_split_doesnt_device_put_during_tracing(self):
raise SkipTest("broken test") # TODO(mattjj): fix

if not config.omnistaging_enabled:
raise SkipTest("test is omnistaging-specific")

key = random.PRNGKey(1)
with jtu.count_device_put() as count:
api.jit(random.split)(key)
key, _ = random.split(key, 2)
self.assertEqual(count[0], 1) # 1 for the argument device_put call




if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())