Skip to content

Commit

Permalink
make iter(DeviceArray) return DeviceArrays w/o sync
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj authored and hawkinsp committed Sep 29, 2021
1 parent 22dce0f commit 41f7cd0
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 16 deletions.
16 changes: 16 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6681,6 +6681,22 @@ def _multi_slice(arr,
results.append(sliced)
return results

# The next two functions are related to iter(device_array), implemented here to
# avoid circular imports.
@jit
def _unstack(x):
  """Split `x` along its leading axis.

  Returns:
    A list with `x.shape[0]` entries; entry `i` is `x` indexed at position `i`
    of axis 0, with that axis removed.
  """
  num_slices = x.shape[0]
  return [lax.index_in_dim(x, index, keepdims=False)
          for index in range(num_slices)]
setattr(DeviceArray, "_unstack", _unstack)
def _chunk_iter(x, size):
  """Yield consecutive pieces of `x` taken along its leading axis.

  Args:
    x: array-like with at least one dimension.
    size: number of leading-axis entries per piece.

  Yields:
    `x` itself when `size` exceeds `x.shape[0]`; otherwise slices holding
    `size` entries each, followed by one shorter slice with the remainder if
    `x.shape[0]` is not a multiple of `size`.
  """
  length = x.shape[0]
  if size > length:
    # The whole array fits in a single (undersized) piece.
    yield x
    return
  full_chunks, remainder = divmod(length, size)
  for start in range(0, full_chunks * size, size):
    yield lax.dynamic_slice_in_dim(x, start, size)
  if remainder:
    yield lax.dynamic_slice_in_dim(x, full_chunks * size, remainder)
setattr(DeviceArray, "_chunk_iter", _chunk_iter)

# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
Expand Down
36 changes: 26 additions & 10 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,21 +618,34 @@ def _sda_value(self):


# Indexing for sharded device arrays: an index that exactly matches an entry of
# `self.indices` is answered from the matching per-device buffer; anything else
# is handed to the generic `__getitem__`.
# NOTE(review): this listing comes from a diff shown without +/- markers. The
# lookup of `cidx` appears twice below, which looks like the removed and the
# added version printed back to back — confirm against the repository.
def _sda__getitem__(self, idx):
self._check_if_deleted()
# Pad `idx` with full slices so that `cidx` has one entry per dimension of
# `self.aval.shape`, the form in which it is looked up in `self.indices`.
if not isinstance(idx, tuple):
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
try:
buf_idx = self.indices.index(cidx)
except ValueError:
# NOTE: Slow path, this will materialize the sharded array on a single
# device and use XLA's Gather to index into the resulting array.
return xla.DeviceArray.__getitem__(self, idx)
# Presumably the buffer fast path is skipped once a host copy is cached in
# `_npy_value` — TODO confirm.
if self._npy_value is None:
try:
buf_idx = self.indices.index(cidx)
except ValueError:
# No entry of `self.indices` equals `cidx`; fall through to the generic path.
buf_idx = None
if buf_idx is not None:
# Build the result directly from the matching device buffer.
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
# NOTE(review): `super(self.__class__, self)` resolves relative to the runtime
# class of `self`, so it would recurse endlessly if a subclass inherited this
# method — confirm no subclass relies on it.
return super(self.__class__, self).__getitem__(idx)


# Iterate over the leading axis of a sharded device array by indexing
# `self[i]` for each `i` in `range(self.shape[0])`; 0-d arrays raise TypeError.
# NOTE(review): `buf_idx` is never assigned in this function, so the lines
# under `else:` that use it cannot run as written. They look like removed diff
# lines printed without a marker — confirm against the repository.
def _sda__iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
self._check_if_deleted()
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
return (self[i] for i in range(self.shape[0]))

def _sda__reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0] - 1, -1, -1))


for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
Expand All @@ -643,6 +656,8 @@ def _sda__getitem__(self, idx):
setattr(sda, "block_until_ready", _sda_block_until_ready)
setattr(sda, "_value", property(_sda_value))
setattr(sda, "__getitem__", _sda__getitem__)
setattr(sda, "__iter__", _sda__iter__)
setattr(sda, "__reversed__", _sda__reversed__)

del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)
Expand All @@ -655,6 +670,7 @@ def _sda__getitem__(self, idx):
ShardedDeviceArray = _ShardedDeviceArray



def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
idx)
Expand Down
20 changes: 15 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,15 +1347,12 @@ def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return self._value.__iter__()
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())

setattr(device_array, "__iter__", __iter__)

# Reverse iteration for device arrays.
# NOTE(review): two implementations appear below — `reversed(self._value)`
# guarded by a 0-d check, and `iter(self[::-1])`. In this marker-less diff
# listing one is presumably the removed version and the other its replacement;
# confirm which is current before relying on either.
def __reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array")
else:
return reversed(self._value)
return iter(self[::-1])

setattr(device_array, "__reversed__", __reversed__)

Expand Down Expand Up @@ -1397,6 +1394,19 @@ def __array__(self, dtype=None, context=None):
# clobbered when jax.numpy is imported, but useful in tests
setattr(device_array, "__eq__", lambda self, other: self._value == other)

# def __hash__(self):
# if self.ndim == 0:
# # We allow 0D DeviceArrays to be hashable, mainly so that when we unpack a
# # 1D DeviceArray with __iter__ we get an iterable of hashable values. That
# # is loosely analogous to how NumPy unpacks 1D arrays into hashable NumPy
# # scalars (but JAX doesn't have special scalar values distinct from 0D
# # arrays).
# return hash(self.item())
# else:
# raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")

# setattr(device_array, "__hash__", __hash__)

# The following methods are dynamically overridden in lax_numpy.py.
def raise_not_implemented():
  """Unconditionally raise `NotImplementedError`."""
  raise NotImplementedError()
Expand Down
7 changes: 6 additions & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2146,7 +2146,7 @@ def test_device_array_repr(self):
self.assertStartsWith(repr(rep), "DeviceArray")

def test_device_array_hash(self):
rep = jnp.ones(()) + 1.
rep = jnp.ones((1,)) + 1.
self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
self.assertNotIsInstance(rep, collections.Hashable)
with self.assertRaisesRegex(TypeError, 'unhashable type'):
Expand Down Expand Up @@ -3034,6 +3034,11 @@ def test_jnp_array_doesnt_device_put(self):
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)

# def test_device_array_unpacking_1D_hashable(self):
# xs = device_put(np.array([1, 1]))
# x, _ = xs
# hash(x) # doesn't crash


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 41f7cd0

Please sign in to comment.