diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c8356b199f6b..fba3360f0942 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4964,6 +4964,22 @@ def _multi_slice(arr: DeviceArray, return results setattr(DeviceArray, "_multi_slice", _multi_slice) +# The next two functions are related to iter(device_array), implemented here to +# avoid circular imports. +@jit +def _unstack(x): + return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] +setattr(DeviceArray, "_unstack", _unstack) +def _chunk_iter(x, size): + if size > x.shape[0]: + yield x + else: + num_chunks, tail = divmod(x.shape[0], size) + for i in range(num_chunks): + yield lax.dynamic_slice_in_dim(x, i * size, size) + if tail: + yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail) +setattr(DeviceArray, "_chunk_iter", _chunk_iter) # Syntactic sugar for scatter operations. class _IndexUpdateHelper: diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index c6849fbaf74c..794babdf412e 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -452,14 +452,24 @@ def __getitem__(self, idx): buf_idx = None if buf_idx is not None: buf = self.device_buffers[buf_idx] - # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 - # is the default. - aval = ShapedArray( - getattr(buf, "xla_shape", buf.shape)().dimensions(), - self.aval.dtype) + # TODO(jblespiau): use buf.xla_shape() after jaxlib==0.1.58 is default + aval = ShapedArray(getattr(buf, "xla_shape", buf.shape)().dimensions(), + self.aval.dtype) return xla.make_device_array(aval, None, lazy.array(aval.shape), buf) return super(ShardedDeviceArray, self).__getitem__(idx) + def __iter__(self): + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") # same as numpy error + else: + return (self[i] for i in range(self.shape[0])) + + def __reversed__(self): + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") # same as numpy error + else: + return (self[i] for i in range(self.shape[0] - 1, -1, -1)) + def _hashable_index(idx): return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index ab2bdfdb827d..6a47f416c270 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -1135,13 +1135,21 @@ def __iter__(self): if self.ndim == 0: raise TypeError("iteration over a 0-d array") # same as numpy error else: - return self._value.__iter__() + device = self.device_buffer.device() + if device is None or device.platform == 'cpu': + # do the slicing in NumPy for better performance + return iter(self._value) + device = device or jax.devices('cpu') + aval = ShapedArray(self.aval.shape[1:], self.aval.dtype, + self.aval.weak_type) + lexpr = lazy.array(aval.shape) + return (make_device_array(aval, device, lexpr, *device_put(x, device)) + for x in self._value) + else: + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) def __reversed__(self): - if self.ndim == 0: - raise TypeError("iteration over a 0-d array") - else: - return reversed(self._value) + return iter(self[::-1]) def __format__(self, format_spec): # Simulates behavior of https://github.com/numpy/numpy/pull/9883