Skip to content

Commit

Permalink
Merge pull request #15479 from jakevdp:einsum-element-type
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 523161165
  • Loading branch information
jax authors committed Apr 10, 2023
2 parents 4a5bf29 + 3eb61e1 commit 646339e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
13 changes: 9 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3144,6 +3144,7 @@ def einsum(
out=None,
optimize="optimal",
precision=None,
preferred_element_type=None,
_use_xeinsum=False,
_dot_general=lax.dot_general,
):
Expand Down Expand Up @@ -3176,7 +3177,8 @@ def einsum(

_einsum_computation = jax.named_call(
_einsum, name=spec) if spec is not None else _einsum
return _einsum_computation(operands, contractions, precision, _dot_general)
return _einsum_computation(operands, contractions, precision,
preferred_element_type, _dot_general)


# Enable other modules to override einsum_contact_path.
Expand All @@ -3201,11 +3203,12 @@ def _removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))


@partial(jit, static_argnums=(1, 2, 3))
@partial(jit, static_argnums=(1, 2, 3, 4), inline=True)
def _einsum(
operands: Sequence,
contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
precision,
preferred_element_type,
_dot_general=lax.dot_general,
):
operands = list(util.promote_dtypes(*operands))
Expand Down Expand Up @@ -3320,11 +3323,13 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
names = batch_names_str + remaining_rhs_names + remaining_lhs_names
if names == result_names:
dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
operand = _dot_general(rhs, lhs, dimension_numbers, precision)
operand = _dot_general(rhs, lhs, dimension_numbers, precision,
preferred_element_type=preferred_element_type)
else:
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
operand = _dot_general(lhs, rhs, dimension_numbers, precision)
operand = _dot_general(lhs, rhs, dimension_numbers, precision,
preferred_element_type=preferred_element_type)
else:
raise NotImplementedError # if this is actually reachable, open an issue!

Expand Down
15 changes: 15 additions & 0 deletions tests/lax_numpy_einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,21 @@ def test_no_unnecessary_transpose(self):
jaxpr = jax.make_jaxpr(partial(jnp.einsum, "ijk,kl->ijl"))(x, y)
self.assertNotIn('transpose', str(jaxpr))

def test_preferred_element_type(self):
  """einsum forwards preferred_element_type down to the dot_general eqn."""
  rng = self.rng()
  lhs = rng.randn(2, 2).astype('bfloat16')
  rhs = rng.randn(2).astype('bfloat16')
  pattern = "ij,j->i"

  # By default no preferred element type is recorded on the equation.
  fun_default = partial(jnp.einsum, pattern)
  jaxpr = jax.make_jaxpr(fun_default)(lhs, rhs)
  self.assertLen(jaxpr.eqns, 1)
  self.assertIsNone(jaxpr.eqns[0].params['preferred_element_type'])

  # An explicit request must appear verbatim in the eqn params.
  fun_f32 = partial(jnp.einsum, pattern, preferred_element_type='float32')
  jaxpr = jax.make_jaxpr(fun_f32)(lhs, rhs)
  self.assertLen(jaxpr.eqns, 1)
  self.assertEqual(jaxpr.eqns[0].params['preferred_element_type'], 'float32')


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 646339e

Please sign in to comment.