diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a88c20bb3625..c59ac5994e08 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -926,7 +926,7 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array], """ start_indices = _dynamic_slice_indices(operand, start_indices) return dynamic_slice_p.bind(operand, *start_indices, - slice_sizes=tuple(slice_sizes)) + slice_sizes=core.canonicalize_shape(slice_sizes)) def dynamic_update_slice(operand: Array, update: Array, start_indices: Array) -> Array: @@ -1362,7 +1362,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array: `_ operator. """ - permutation = tuple(permutation) + permutation = tuple(operator.index(d) for d in permutation) if (permutation == tuple(range(np.ndim(operand))) and isinstance(operand, (core.Tracer, xla.DeviceArray))): return operand diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4e7952a71c37..8cc96e6f2922 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6699,6 +6699,22 @@ def _multi_slice(arr, results.append(sliced) return results +# The next two functions are related to iter(device_array), implemented here to +# avoid circular imports. +@jit +def _unstack(x): + return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] +setattr(DeviceArray, "_unstack", _unstack) +def _chunk_iter(x, size): + if size > x.shape[0]: + yield x + else: + num_chunks, tail = divmod(x.shape[0], size) + for i in range(num_chunks): + yield lax.dynamic_slice_in_dim(x, i * size, size) + if tail: + yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail) +setattr(DeviceArray, "_chunk_iter", _chunk_iter) # Syntactic sugar for scatter operations. class _IndexUpdateHelper: diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 3fbd16a77afa..23090c27e2f7 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -622,21 +622,34 @@ def _sda_value(self): def _sda__getitem__(self, idx): + self._check_if_deleted() if not isinstance(idx, tuple): cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1) else: cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx)) - try: - buf_idx = self.indices.index(cidx) - except ValueError: - # NOTE: Slow path, this will materialize the sharded array on a single - # device and use XLA's Gather to index into the resulting array. - return xla.DeviceArray.__getitem__(self, idx) + if self._npy_value is None: + try: + buf_idx = self.indices.index(cidx) + except ValueError: + buf_idx = None + if buf_idx is not None: + buf = self.device_buffers[buf_idx] + aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype) + return xla.make_device_array(aval, None, buf) + return super(self.__class__, self).__getitem__(idx) + + +def _sda__iter__(self): + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") # same as numpy error else: - self._check_if_deleted() - buf = self.device_buffers[buf_idx] - aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype) - return xla.make_device_array(aval, None, buf) + return (self[i] for i in range(self.shape[0])) + +def _sda__reversed__(self): + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") # same as numpy error + else: + return (self[i] for i in range(self.shape[0] - 1, -1, -1)) for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]: @@ -647,6 +660,8 @@ def _sda__getitem__(self, idx): setattr(sda, "block_until_ready", _sda_block_until_ready) setattr(sda, "_value", property(_sda_value)) setattr(sda, "__getitem__", _sda__getitem__) + setattr(sda, "__iter__", _sda__iter__) + setattr(sda, "__reversed__", _sda__reversed__) del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async, _sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__) @@ -659,6 +674,7 @@ def _sda__getitem__(self, idx): ShardedDeviceArray = _ShardedDeviceArray + def _hashable_index(idx): return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index db50f0ec2ab8..5339ea7ef88a 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -1596,15 +1596,12 @@ def __iter__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") # same as numpy error else: - return self._value.__iter__() + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) setattr(device_array, "__iter__", __iter__) def __reversed__(self): - if self.ndim == 0: - raise TypeError("iteration over a 0-d array") - else: - return reversed(self._value) + return iter(self[::-1]) setattr(device_array, "__reversed__", __reversed__) diff --git a/tests/api_test.py b/tests/api_test.py index e1c227984f4b..6bd10dc1bbb8 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2281,7 +2281,7 @@ def test_device_array_repr(self): self.assertStartsWith(repr(rep), "DeviceArray") def test_device_array_hash(self): - rep = jnp.ones(()) + 1. + rep = jnp.ones((1,)) + 1. self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray) self.assertNotIsInstance(rep, collections.abc.Hashable) with self.assertRaisesRegex(TypeError, 'unhashable type'): diff --git a/tests/lax_test.py b/tests/lax_test.py index a5422a46509f..0f70bb00924b 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -26,6 +26,7 @@ import numpy as np import jax +import jax.numpy as jnp from jax import core from jax._src import dtypes from jax import lax @@ -1497,6 +1498,12 @@ def testDynamicSliceInDim(self): x = rng((6, 7), np.int32) np.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5]) + def testDynamicSliceArraySliceSizes(self): + rng = jtu.rand_default(self.rng()) + x = rng((6, 7), np.int32) + np.testing.assert_equal(lax.dynamic_slice(x, [2, 3], jnp.array([2, 2])), + x[2:4, 3:5]) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_indices={}_update_shape={}".format( jtu.format_shape_dtype_string(shape, dtype), @@ -1556,6 +1563,10 @@ def testTranspose(self, shape, dtype, perm): op = lambda x: lax.transpose(x, perm) self._CompileAndCheck(op, args_maker) + def testTransposeWithArrayPermutation(self): + x = lax.transpose(np.ones((2, 3)), jnp.array([1, 0])) + self.assertEqual((3, 2), x.shape) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_perm={}".format( jtu.format_shape_dtype_string(shape, dtype), perm),