Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adjust lax.convert_element_type bind to avoid H2D transfers during tracing #6014

Merged
merged 1 commit into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2)

if not isinstance(operand, (core.Tracer, xla.DeviceArray)):
return _device_put_raw(np.asarray(operand, dtype=new_dtype),
weak_type=new_weak_type)
elif (old_dtype, old_weak_type) == (new_dtype, new_weak_type):
if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
Expand Down Expand Up @@ -2687,10 +2685,13 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
else:
return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type)

convert_element_type_p = standard_primitive(
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
'convert_element_type', _convert_element_type_translation_rule,
weak_type_rule=_convert_element_type_weak_type_rule)
convert_element_type_p = core.convert_element_type_p
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
_convert_element_type_weak_type_rule, standard_named_shape_rule))
xla.translations[convert_element_type_p] = _convert_element_type_translation_rule
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)
Expand Down
9 changes: 5 additions & 4 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,14 @@ def entr(x):
@_wraps(osp_special.multigammaln, update_doc=False)
def multigammaln(a, d):
d = core.concrete_or_error(int, d, "d argument of multigammaln")
a, d = _promote_args_inexact("multigammaln", a, d)
a, d_ = _promote_args_inexact("multigammaln", a, d)

constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
lax.sub(d, _constant_like(a, 1))),
constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d_),
lax.sub(d_, _constant_like(a, 1))),
lax.log(_constant_like(a, np.pi)))
res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
lax.div(jnp.arange(d), _constant_like(a, 2))),
lax.div(jnp.arange(d, dtype=d_.dtype),
_constant_like(a, 2))),
axis=-1)
return res + constant

Expand Down
2 changes: 2 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,8 @@ def concrete_or_error(force: Any, val: Any, context=""):
else:
return force(val)

convert_element_type_p = Primitive('convert_element_type')

class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 2
Expand Down
15 changes: 11 additions & 4 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,10 +1000,17 @@ def lit(var: core.Var) -> Optional[Any]:
new_constvars = [var(v) for v in jaxpr.constvars if not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = [new_jaxpr_eqn([lit(v) or var(v) for v in eqn.invars],
[var(v) if v in used else dropvar for v in eqn.outvars],
eqn.primitive, eqn.params, eqn.source_info)
for eqn in jaxpr.eqns]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var(v) for v in eqn.invars]
if (eqn.primitive is core.convert_element_type_p and type(invars[0]) is Literal):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(invars[0].val, eqn.params['new_dtype'])
else:
# might do DCE here, but we won't until we're more careful about effects
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals
Expand Down
59 changes: 25 additions & 34 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4671,49 +4671,40 @@ def t_(_, t): return id_(7.)

class InvertibleADTest(jtu.JaxTestCase):

@jtu.ignore_warning(message="Values that an @invertible function closes")
def test_invertible_basic(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("Test requires omnistaging")

def f(x):
return (jnp.exp(x) * 4) * x

finv = jax.invertible(f)

x = jnp.ones((5,))

if config.omnistaging_enabled:
expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = mul b a
g = div e a
h = mul b g
i = div g 4.0
j = mul f 4.0
_ = log i
k = mul j i
l = add_any h k
in (l,) }
"""
else:
expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = div e a
g = mul b f
h = mul b a
i = mul h 4.0
j = div f 4.0
k = mul i j
l = add_any g k
in (l,) }
"""

jaxpr = jax.make_jaxpr(lambda p, ct: jax.vjp(finv, p)[1](ct))(x, x)
self.assertMultiLineStrippedEqual(expected, str(jaxpr))

# expected = """
# { lambda ; a b.
# let c = exp a
# d = mul c 4.0
# e = mul d a
# f = mul b a
# g = div e a
# h = mul b g
# i = mul f 4.0
# j = div g 4.0
# k = mul f j
# _ = reduce_sum[ axes=(0,) ] k
# _ = log j
# l = mul i j
# m = add_any h l
# in (m,) }
# """
# self.assertMultiLineStrippedEqual(expected, str(jaxpr)) # no jaxpr test

self.assertIn('div', str(jaxpr))
self.assertIn('log', str(jaxpr)) # assumes no DCE
self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x)))(x),
jax.value_and_grad(lambda x: np.sum(finv(x)))(x),
check_dtypes=True)
Expand Down
12 changes: 3 additions & 9 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,18 +969,12 @@ def test_prng_errors(self):
api.jit(random.PRNGKey)(seed)

def test_random_split_doesnt_device_put_during_tracing(self):
raise SkipTest("broken test") # TODO(mattjj): fix

if not config.omnistaging_enabled:
raise SkipTest("test is omnistaging-specific")

key = random.PRNGKey(1)
raise SkipTest("test requires omnistaging")
key = random.PRNGKey(1).block_until_ready()
with jtu.count_device_put() as count:
api.jit(random.split)(key)
key, _ = random.split(key, 2)
self.assertEqual(count[0], 1) # 1 for the argument device_put call


self.assertEqual(count[0], 1) # 1 for the argument device_put


if __name__ == "__main__":
Expand Down