Skip to content

Commit

Permalink
Merge pull request #8043 from hawkinsp:iter
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 406822933
  • Loading branch information
jax authors committed Nov 1, 2021
2 parents 5ae0795 + 96623c3 commit 335857b
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 18 deletions.
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
return dynamic_slice_p.bind(operand, *start_indices,
slice_sizes=tuple(slice_sizes))
slice_sizes=core.canonicalize_shape(slice_sizes))

def dynamic_update_slice(operand: Array, update: Array,
start_indices: Array) -> Array:
Expand Down Expand Up @@ -1362,7 +1362,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
operator.
"""
permutation = tuple(permutation)
permutation = tuple(operator.index(d) for d in permutation)
if (permutation == tuple(range(np.ndim(operand)))
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
Expand Down
16 changes: 16 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6699,6 +6699,22 @@ def _multi_slice(arr,
results.append(sliced)
return results

# The next two functions are related to iter(device_array), implemented here to
# avoid circular imports.
@jit
def _unstack(x):
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
setattr(DeviceArray, "_unstack", _unstack)
def _chunk_iter(x, size):
if size > x.shape[0]:
yield x
else:
num_chunks, tail = divmod(x.shape[0], size)
for i in range(num_chunks):
yield lax.dynamic_slice_in_dim(x, i * size, size)
if tail:
yield lax.dynamic_slice_in_dim(x, num_chunks * size, tail)
setattr(DeviceArray, "_chunk_iter", _chunk_iter)

# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
Expand Down
36 changes: 26 additions & 10 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,21 +622,34 @@ def _sda_value(self):


def _sda__getitem__(self, idx):
self._check_if_deleted()
if not isinstance(idx, tuple):
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
try:
buf_idx = self.indices.index(cidx)
except ValueError:
# NOTE: Slow path, this will materialize the sharded array on a single
# device and use XLA's Gather to index into the resulting array.
return xla.DeviceArray.__getitem__(self, idx)
if self._npy_value is None:
try:
buf_idx = self.indices.index(cidx)
except ValueError:
buf_idx = None
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
return super(self.__class__, self).__getitem__(idx)


def _sda__iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
self._check_if_deleted()
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return xla.make_device_array(aval, None, buf)
return (self[i] for i in range(self.shape[0]))

def _sda__reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0] - 1, -1, -1))


for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
Expand All @@ -647,6 +660,8 @@ def _sda__getitem__(self, idx):
setattr(sda, "block_until_ready", _sda_block_until_ready)
setattr(sda, "_value", property(_sda_value))
setattr(sda, "__getitem__", _sda__getitem__)
setattr(sda, "__iter__", _sda__iter__)
setattr(sda, "__reversed__", _sda__reversed__)

del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)
Expand All @@ -659,6 +674,7 @@ def _sda__getitem__(self, idx):
ShardedDeviceArray = _ShardedDeviceArray



def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
idx)
Expand Down
7 changes: 2 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,15 +1596,12 @@ def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return self._value.__iter__()
return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())

setattr(device_array, "__iter__", __iter__)

def __reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array")
else:
return reversed(self._value)
return iter(self[::-1])

setattr(device_array, "__reversed__", __reversed__)

Expand Down
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,7 +2281,7 @@ def test_device_array_repr(self):
self.assertStartsWith(repr(rep), "DeviceArray")

def test_device_array_hash(self):
rep = jnp.ones(()) + 1.
rep = jnp.ones((1,)) + 1.
self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
self.assertNotIsInstance(rep, collections.abc.Hashable)
with self.assertRaisesRegex(TypeError, 'unhashable type'):
Expand Down
11 changes: 11 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import numpy as np

import jax
import jax.numpy as jnp
from jax import core
from jax._src import dtypes
from jax import lax
Expand Down Expand Up @@ -1497,6 +1498,12 @@ def testDynamicSliceInDim(self):
x = rng((6, 7), np.int32)
np.testing.assert_equal(lax.dynamic_slice_in_dim(x, 2, 3), x[2:5])

def testDynamicSliceArraySliceSizes(self):
rng = jtu.rand_default(self.rng())
x = rng((6, 7), np.int32)
np.testing.assert_equal(lax.dynamic_slice(x, [2, 3], jnp.array([2, 2])),
x[2:4, 3:5])

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_indices={}_update_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype),
Expand Down Expand Up @@ -1556,6 +1563,10 @@ def testTranspose(self, shape, dtype, perm):
op = lambda x: lax.transpose(x, perm)
self._CompileAndCheck(op, args_maker)

def testTransposeWithArrayPermutation(self):
x = lax.transpose(np.ones((2, 3)), jnp.array([1, 0]))
self.assertEqual((3, 2), x.shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_perm={}".format(
jtu.format_shape_dtype_string(shape, dtype), perm),
Expand Down

0 comments on commit 335857b

Please sign in to comment.