From 41f7cd034f41ce7ab0365237e7bf3d4938d1ca96 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Wed, 22 Jul 2020 12:10:43 -0700
Subject: [PATCH] make iter(DeviceArray) return DeviceArrays w/o sync

---
 jax/_src/numpy/lax_numpy.py | 16 ++++++++++++++++
 jax/interpreters/pxla.py    | 36 ++++++++++++++++++++++++++----------
 jax/interpreters/xla.py     | 20 +++++++++++++++-----
 tests/api_test.py           |  7 ++++++-
 4 files changed, 63 insertions(+), 16 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 33b3a8776097..51b0a0c2999c 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -6681,6 +6681,22 @@ def _multi_slice(arr,
     results.append(sliced)
   return results
 
+# The next two functions are related to iter(device_array), implemented here to
+# avoid circular imports.
+@jit
+def _unstack(x):
+  return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
+setattr(DeviceArray, "_unstack", _unstack)
+def _chunk_iter(x, size):
+  if size > x.shape[0]:
+    yield x
+  else:
+    num_chunks, tail = divmod(x.shape[0], size)
+    for i in range(num_chunks):
+      yield lax.dynamic_slice_in_dim(x, i * size, size)
+    if tail:
+      yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
+setattr(DeviceArray, "_chunk_iter", _chunk_iter)
 
 # Syntactic sugar for scatter operations.
 class _IndexUpdateHelper:
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 980cc80b68c5..6634529a63f3 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -618,21 +618,34 @@ def _sda_value(self):
 
 
 def _sda__getitem__(self, idx):
+  self._check_if_deleted()
   if not isinstance(idx, tuple):
     cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
   else:
     cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
-  try:
-    buf_idx = self.indices.index(cidx)
-  except ValueError:
-    # NOTE: Slow path, this will materialize the sharded array on a single
-    # device and use XLA's Gather to index into the resulting array.
-    return xla.DeviceArray.__getitem__(self, idx)
+  if self._npy_value is None:
+    try:
+      buf_idx = self.indices.index(cidx)
+    except ValueError:
+      buf_idx = None
+    if buf_idx is not None:
+      buf = self.device_buffers[buf_idx]
+      aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
+      return xla.make_device_array(aval, None, buf)
+  return super(self.__class__, self).__getitem__(idx)
+
+
+def _sda__iter__(self):
+  if self.ndim == 0:
+    raise TypeError("iteration over a 0-d array")  # same as numpy error
   else:
-    self._check_if_deleted()
-    buf = self.device_buffers[buf_idx]
-    aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
-    return xla.make_device_array(aval, None, buf)
+    return (self[i] for i in range(self.shape[0]))
+
+def _sda__reversed__(self):
+  if self.ndim == 0:
+    raise TypeError("iteration over a 0-d array")  # same as numpy error
+  else:
+    return (self[i] for i in range(self.shape[0] - 1, -1, -1))
 
 
 for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
@@ -643,6 +656,8 @@ def _sda__getitem__(self, idx):
   setattr(sda, "block_until_ready", _sda_block_until_ready)
   setattr(sda, "_value", property(_sda_value))
   setattr(sda, "__getitem__", _sda__getitem__)
+  setattr(sda, "__iter__", _sda__iter__)
+  setattr(sda, "__reversed__", _sda__reversed__)
 
 del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
      _sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)
@@ -655,6 +670,7 @@ def _sda__getitem__(self, idx):
   ShardedDeviceArray = _ShardedDeviceArray
 
 
+
 def _hashable_index(idx):
   return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
                   idx)
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index 1345705401d5..f15dac25f188 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -1347,15 +1347,12 @@ def __iter__(self):
     if self.ndim == 0:
       raise TypeError("iteration over a 0-d array")  # same as numpy error
     else:
-      return self._value.__iter__()
+      return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())
 
   setattr(device_array, "__iter__", __iter__)
 
   def __reversed__(self):
-    if self.ndim == 0:
-      raise TypeError("iteration over a 0-d array")
-    else:
-      return reversed(self._value)
+    return iter(self[::-1])
 
   setattr(device_array, "__reversed__", __reversed__)
 
@@ -1397,6 +1394,19 @@ def __array__(self, dtype=None, context=None):
   # clobbered when jax.numpy is imported, but useful in tests
   setattr(device_array, "__eq__", lambda self, other: self._value == other)
 
+  # def __hash__(self):
+  #   if self.ndim == 0:
+  #     # We allow 0D DeviceArrays to be hashable, mainly so that when we unpack a
+  #     # 1D DeviceArray with __iter__ we get an iterable of hashable values. That
+  #     # is loosely analogous to how NumPy unpacks 1D arrays into hashable NumPy
+  #     # scalars (but JAX doesn't have special scalar values distinct from 0D
+  #     # arrays).
+  #     return hash(self.item())
+  #   else:
+  #     raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
+
+  # setattr(device_array, "__hash__", __hash__)
+
   # The following methods are dynamically overridden in lax_numpy.py.
   def raise_not_implemented():
     raise NotImplementedError
diff --git a/tests/api_test.py b/tests/api_test.py
index 255bfac1fb7c..e2f5f1de35f3 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -2146,7 +2146,7 @@ def test_device_array_repr(self):
     self.assertStartsWith(repr(rep), "DeviceArray")
 
   def test_device_array_hash(self):
-    rep = jnp.ones(()) + 1.
+    rep = jnp.ones((1,)) + 1.
     self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
     self.assertNotIsInstance(rep, collections.Hashable)
     with self.assertRaisesRegex(TypeError, 'unhashable type'):
@@ -3034,6 +3034,11 @@ def test_jnp_array_doesnt_device_put(self):
       api.make_jaxpr(lambda: jnp.array(3))()
     self.assertEqual(count[0], 0)
 
+  # def test_device_array_unpacking_1D_hashable(self):
+  #   xs = device_put(np.array([1, 1]))
+  #   x, _ = xs
+  #   hash(x)  # doesn't crash
+
 
 class RematTest(jtu.JaxTestCase):
 