Skip to content

Commit

Permalink
don't device transfer in convert_element_type
Browse files Browse the repository at this point in the history
Co-authored-by: Qiao Zhang <[email protected]>
  • Loading branch information
mattjj and zhangqiaorjc committed Mar 9, 2021
1 parent 9577860 commit 1bd34f4
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 59 deletions.
6 changes: 2 additions & 4 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2)

if not isinstance(operand, (core.Tracer, xla.DeviceArray)):
return _device_put_raw(np.asarray(operand, dtype=new_dtype),
weak_type=new_weak_type)
elif (old_dtype, old_weak_type) == (new_dtype, new_weak_type):
if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ def entr(x):
@_wraps(osp_special.multigammaln, update_doc=False)
def multigammaln(a, d):
d = core.concrete_or_error(int, d, "d argument of multigammaln")
a, d = _promote_args_inexact("multigammaln", a, d)
a, d_ = _promote_args_inexact("multigammaln", a, d)

constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
lax.sub(d, _constant_like(a, 1))),
constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d_),
lax.sub(d_, _constant_like(a, 1))),
lax.log(_constant_like(a, np.pi)))
res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
lax.div(jnp.arange(d), _constant_like(a, 2))),
lax.div(jnp.arange(d), _constant_like(d, 2))),
axis=-1)
return res + constant

Expand Down
61 changes: 18 additions & 43 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,7 @@ def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
self.assertLen(inner_jaxpr.eqns, 2 if config.omnistaging_enabled else 3)
self.assertLen(inner_jaxpr.eqns, 3)
self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')

Expand Down Expand Up @@ -2090,13 +2090,15 @@ def f(x):
self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)

def test_omnistaging_flag(self):
x = jnp.array(1)

if FLAGS.jax_omnistaging:
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
jaxpr = api.make_jaxpr(lambda: jnp.add(x, x))()
self.assertLen(jaxpr.jaxpr.eqns, 1)
else:
# omnistaging can be enabled programmatically without setting the flag,
# but that shouldn't happen in tests
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
jaxpr = api.make_jaxpr(lambda: jnp.add(x, x))()
self.assertLen(jaxpr.jaxpr.eqns, 0)

def test_eval_context(self):
Expand Down Expand Up @@ -2134,8 +2136,6 @@ def f():
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()

# TODO(jakevdp): re-enable this if possible.
@unittest.skipIf(True, "broken by convert_element_type change.")
def test_xla_computation_zeros_doesnt_device_put(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
Expand Down Expand Up @@ -2773,11 +2773,17 @@ def fun(x):
self.assertMultiLineStrippedEqual(expected, str(jaxpr))

def test_cond(self):
# we turn these into jax arrays to avoid convert_element_type pollution in
# the jaxpr
zero = jnp.array(0.)
one = jnp.array(1.)
two = jnp.array(2.)

def f(x):
return lax.cond(x >= 0.,
x + 1.,
return lax.cond(x >= zero,
x + one,
lambda xt: xt + x,
x + 2.,
x + two,
lambda xf: xf - x)
if config.omnistaging_enabled:
expected = """
Expand Down Expand Up @@ -4780,48 +4786,17 @@ def t_(_, t): return id_(7.)
class InvertibleADTest(jtu.JaxTestCase):

def test_invertible_basic(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("Test requires omnistaging")

def f(x):
return (jnp.exp(x) * 4) * x

finv = jax.invertible(f)

x = jnp.ones((5,))

if config.omnistaging_enabled:
expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = mul b a
g = div e a
h = mul b g
i = div g 4.0
j = mul f 4.0
_ = log i
k = mul j i
l = add_any h k
in (l,) }
"""
else:
expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = div e a
g = mul b f
h = mul b a
i = mul h 4.0
j = div f 4.0
k = mul i j
l = add_any g k
in (l,) }
"""

jaxpr = jax.make_jaxpr(lambda p, ct: jax.vjp(finv, p)[1](ct))(x, x)
self.assertMultiLineStrippedEqual(expected, str(jaxpr))

self.assertIn('log', str(jaxpr))
self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x)))(x),
jax.value_and_grad(lambda x: np.sum(finv(x)))(x),
check_dtypes=True)
Expand Down
3 changes: 3 additions & 0 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,9 @@ def test_tap_grad_primal_unused(self):
if not config.omnistaging_enabled:
raise SkipTest("Test requires omnistaging")

# TODO(mattjj,necula): revive this test w/o testing against a jaxpr
raise SkipTest("skipped because testing against jaxpr text is brittle")

# The output of id_print is not needed for backwards pass
def func(x):
return 2. * hcb.id_print(x * 3., what="x * 3",
Expand Down
12 changes: 4 additions & 8 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,15 +967,11 @@ def test_prng_errors(self):
with self.assertRaises(OverflowError):
api.jit(random.PRNGKey)(seed)

def test_random_split_doesnt_device_put(self):
# TODO(mattjj): Enable this after fixing convert_element_type.
raise SkipTest("Broken by convert_element_type.")
key = random.PRNGKey(1)
def test_random_split_doesnt_device_put_during_tracing(self):
key = random.PRNGKey(1).block_until_ready()
with jtu.count_device_put() as count:
key, _ = random.split(key, 2)
self.assertEqual(count[0], 0)


api.jit(random.split)(key)
self.assertEqual(count[0], 1) # 1 for the argument device_put call


if __name__ == "__main__":
Expand Down

0 comments on commit 1bd34f4

Please sign in to comment.