Skip to content

Commit

Permalink
[JAX] Update JAX users in preparation for a change that makes iterati…
Browse files Browse the repository at this point in the history
…on over a JAX array return JAX arrays, instead of NumPy arrays.

See #8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.

PiperOrigin-RevId: 406150403
  • Loading branch information
hawkinsp authored and jax authors committed Oct 28, 2021
1 parent 934bfc0 commit 55dc485
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
19 changes: 13 additions & 6 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,13 +722,15 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
Returns:
An array containing the result.
"""
contract_dims_seq, batch_dims_seq = dimension_numbers
contract_dims = tuple(map(tuple, contract_dims_seq)) # type: ignore
batch_dims = tuple(map(tuple, batch_dims_seq)) # type: ignore
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
api_util._ensure_index_tuple(rhs_contract))
bdims = (api_util._ensure_index_tuple(lhs_batch),
api_util._ensure_index_tuple(rhs_batch))
preferred_element_type = (None if preferred_element_type is None else
np.dtype(preferred_element_type))
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(contract_dims, batch_dims),
dimension_numbers=(cdims, bdims),
precision=canonicalize_precision(precision),
preferred_element_type=preferred_element_type)

Expand Down Expand Up @@ -822,14 +824,19 @@ def reshape(operand: Array, new_sizes: Shape,
new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
same_dims = dimensions is None or tuple(dimensions) == tuple(range(np.ndim(operand)))
if dimensions is None:
same_dims = True
dims = None
else:
dims = api_util._ensure_index_tuple(dimensions)
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
if (np.shape(operand) and same_shape and same_dims
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
else:
return reshape_p.bind(
operand, new_sizes=new_sizes,
dimensions=None if dimensions is None or same_dims else tuple(dimensions))
dimensions=None if dims is None or same_dims else dims)

def pad(operand: Array, padding_value: Array,
padding_config: Sequence[Tuple[int, int, int]]) -> Array:
Expand Down
4 changes: 2 additions & 2 deletions tests/ann_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall):
hits = sum(
len(list(x
for x in ann_args_per_q
if x in gt_args_sets[q]))
if x.item() in gt_args_sets[q]))
for q, ann_args_per_q in enumerate(ann_args))
self.assertGreater(hits / (qy_shape[0] * k), recall)

Expand Down Expand Up @@ -81,7 +81,7 @@ def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall):
hits = sum(
len(list(x
for x in ann_args_per_q
if x in gt_args_sets[q]))
if x.item() in gt_args_sets[q]))
for q, ann_args_per_q in enumerate(ann_args))
self.assertGreater(hits / (qy_shape[0] * k), recall)

Expand Down

0 comments on commit 55dc485

Please sign in to comment.