Skip to content

Commit

Permalink
make iter(DeviceArray) return DeviceArrays w/o sync
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Nov 17, 2020
1 parent e56c3f4 commit d8510c2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
16 changes: 16 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4964,6 +4964,22 @@ def _multi_slice(arr: DeviceArray,
return results
setattr(DeviceArray, "_multi_slice", _multi_slice)

# The next two functions are related to iter(device_array), implemented here to
# avoid circular imports.
@jit
def _unstack(x):
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
setattr(DeviceArray, "_unstack", _unstack)
def _chunk_iter(x, size):
if size > x.shape[0]:
yield x
else:
num_chunks, tail = divmod(x.shape[0], size)
for i in range(num_chunks):
yield lax.dynamic_slice_in_dim(x, i * size, size)
if tail:
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
setattr(DeviceArray, "_chunk_iter", _chunk_iter)

# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
Expand Down
20 changes: 15 additions & 5 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,24 @@ def __getitem__(self, idx):
buf_idx = None
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
# TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58
# is the default.
aval = ShapedArray(
getattr(buf, "xla_shape", buf.shape)().dimensions(),
self.aval.dtype)
# TODO(jblespiau): use buf.xla_shape() after jaxlib==0.1.58 is default
aval = ShapedArray(getattr(buf, "xla_shape", buf.shape)().dimensions(),
self.aval.dtype)
return xla.make_device_array(aval, None, lazy.array(aval.shape), buf)
return super(ShardedDeviceArray, self).__getitem__(idx)

def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0]))

def __reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0] - 1, -1, -1))


def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
Expand Down
18 changes: 13 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,13 +1135,21 @@ def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return self._value.__iter__()
device = self.device_buffer.device()
if device is None or device.platform == 'cpu':
# do the slicing in NumPy for better performance
return iter(self._value)
device = device or jax.devices('cpu')
aval = ShapedArray(self.aval.shape[1:], self.aval.dtype,
self.aval.weak_type)
lexpr = lazy.array(aval.shape)
return (make_device_array(aval, device, lexpr, *device_put(x, device))
for x in self._value)
else:
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())

def __reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array")
else:
return reversed(self._value)
return iter(self[::-1])

def __format__(self, format_spec):
# Simulates behavior of https://github.com/numpy/numpy/pull/9883
Expand Down

0 comments on commit d8510c2

Please sign in to comment.