
Commit

Reverts 525b646
PiperOrigin-RevId: 707146329
hawkinsp authored and Google-ML-Automation committed Dec 17, 2024
1 parent 4911a39 commit 7de9eb2
Showing 13 changed files with 237 additions and 31 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.export.export` can be used for device-polymorphic export with
shardings constructed with {func}`jax.sharding.AbstractMesh`.
See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
* Added {func}`jax.lax.split`. This is a primitive version of
{func}`jax.numpy.split`, added because it yields a more compact
transpose during automatic differentiation.

## jax 0.4.37 (Dec 9, 2024)

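To illustrate the new jax.lax.split entry above, a minimal usage sketch (assuming only the semantics documented in this commit: the sizes must sum to the length of the split axis):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.)
a, b = lax.split(x, sizes=(2, 4))            # pieces of length 2 and 4 along axis 0
m = jnp.arange(12).reshape(3, 4)
left, right = lax.split(m, (1, 3), axis=1)   # shapes (3, 1) and (3, 3)
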
1 change: 1 addition & 0 deletions docs/jax.lax.rst
@@ -154,6 +154,7 @@ Operators
slice_in_dim
sort
sort_key_val
split
sqrt
square
squeeze
104 changes: 92 additions & 12 deletions jax/_src/lax/lax.py
@@ -673,6 +673,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
return concatenate_p.bind(*operands, dimension=dimension)


def split(operand: ArrayLike, sizes: Sequence[int],
axis: int = 0) -> Sequence[Array]:
"""Splits an array along ``axis``.

  Args:
operand: an array to split
sizes: the sizes of the split arrays. The sum of the sizes must be equal
to the size of the ``axis`` dimension of ``operand``.
axis: the axis along which to split the array.

  Returns:
A sequence of ``len(sizes)`` arrays. If ``sizes`` is
``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``,
taken along ``axis``.
"""
operand = asarray(operand)
return split_p.bind(operand, sizes=tuple(sizes),
axis=canonicalize_axis(axis, operand.ndim))


_precision_strings: dict[Any, Precision] = {}

class Precision(enum.Enum):
@@ -4454,18 +4474,8 @@ def _concatenate_transpose_rule(t, *operands, dimension):
return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
for o in operands]
else:
limit_points = np.cumsum(
[shape[dimension] for shape in operand_shapes]).tolist()
starts = np.zeros((len(operands), t.ndim), dtype=int).tolist()
limits = np.tile(t.shape, (len(operands), 1)).tolist()

for i, s in enumerate(starts[1:]):
s[dimension] = limit_points[:-1][i]
for i, l in enumerate(limits):
l[dimension] = limit_points[i]

return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
else None for o, start, limit in zip(operands, starts, limits)]
return split(t, tuple(shape[dimension] for shape in operand_shapes),
axis=dimension)

def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
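
The payoff of the new transpose rule above: transposing a concatenation now produces a single split of the cotangent instead of one slice per operand. A quick sketch (the printed jaxpr is illustrative and may vary by version):

import jax
import jax.numpy as jnp
from jax import lax

f = lambda x, y: lax.concatenate([x, y], dimension=0)
f_t = jax.linear_transpose(f, jnp.ones(2), jnp.ones(3))
print(jax.make_jaxpr(f_t)(jnp.ones(5)))   # expect one `split` primitive rather than per-operand slices
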
@@ -4499,6 +4509,76 @@ def _concatenate_lower(ctx, *xs, dimension):
mlir.register_lowering(concatenate_p, _concatenate_lower)


def _split_shape_rule(operand, *, sizes, axis):
shapes = []
shape = list(operand.shape)
if any(s < 0 for s in sizes):
raise ValueError(
f"Sizes passed to split must be nonnegative, got {list(sizes)}")
if operand.shape[axis] != np.sum(sizes):
raise ValueError(
f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the "
f"operand shape {list(operand.shape)}")
for size in sizes:
shape[axis] = size
shapes.append(tuple(shape))
return shapes
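
The shape rule substitutes each size along the split axis; a sketch of the resulting abstract shapes via jax.eval_shape (exact repr may differ):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 5, 3))
outs = jax.eval_shape(lambda a: lax.split(a, (2, 3), axis=1), x)
# roughly: [ShapeDtypeStruct((2, 2, 3), float32), ShapeDtypeStruct((2, 3, 3), float32)]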

def _split_dtype_rule(operand, *, sizes, axis):
return (operand.dtype,) * len(sizes)

def _split_weak_type_rule(operand, *, sizes, axis):
return (operand.weak_type,) * len(sizes)

def _split_transpose_rule(cotangents, operand, *, sizes, axis):
assert ad.is_undefined_primal(operand)
if all(type(t) is ad_util.Zero for t in cotangents):
return ad_util.Zero(operand.aval),
cotangents = [
_zeros(t.aval) if type(t) is ad_util.Zero else t
for t in cotangents
]
return concatenate(cotangents, dimension=axis),
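
That is, the transpose (VJP) of split concatenates the cotangents back along the same axis. A small sketch:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5.)
_, f_vjp = jax.vjp(lambda a: tuple(lax.split(a, (2, 3))), x)
(ct_x,) = f_vjp((jnp.ones(2), jnp.ones(3)))   # ct_x == jnp.ones(5)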

def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
operand, = batched_args
bdim, = batch_dims
new_bdims = (bdim,) * len(sizes)
out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis)
return out, new_bdims
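
Under vmap the batch dimension is carried through to every output. A sketch:

import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.zeros((8, 5))                           # a batch of 8 length-5 vectors
lo, hi = jax.vmap(lambda v: lax.split(v, (2, 3)))(xs)
# lo.shape == (8, 2), hi.shape == (8, 3)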

def _split_lower(ctx, x, *, sizes, axis):
x_aval, = ctx.avals_in
start_indices = [0] * x_aval.ndim
limit_indices = list(x_aval.shape)
strides = (1,) * x_aval.ndim
outs = []
for aval_out in ctx.avals_out:
limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
limit_indices=limit_indices, strides=strides)
outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out)
if config.sharding_in_types.value else out)
start_indices[axis] = limit_indices[axis]
return outs
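
The lowering emits one static slice per output. In user-level JAX it is equivalent to the following reference sketch (the helper name split_via_slices is just for illustration; the real rule builds the stablehlo slice ops directly):

import jax.numpy as jnp
from jax import lax

def split_via_slices(x, sizes, axis=0):
  # Take consecutive chunks of the requested sizes along `axis`.
  outs, start = [], 0
  for size in sizes:
    outs.append(lax.slice_in_dim(x, start, start + size, axis=axis))
    start += size
  return outs

ref = split_via_slices(jnp.arange(6.), (2, 4))
new = lax.split(jnp.arange(6.), (2, 4))
assert all((a == b).all() for a, b in zip(ref, new))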

def _split_sharding_rule(operand, *, sizes, axis):
# TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
# change this logic to `return operand.sharding` directly.
out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis)
return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split')
for out_sh in out_shapes]

split_p = core.Primitive('split')
split_p.multiple_results = True
split_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule,
_split_dtype_rule, _split_weak_type_rule, _split_sharding_rule))
split_p.def_impl(partial(dispatch.apply_primitive, split_p))
ad.deflinear2(split_p, _split_transpose_rule)
batching.primitive_batchers[split_p] = _split_batch_rule
mlir.register_lowering(split_p, _split_lower)
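
With the primitive registered, split appears directly in jaxprs; a quick way to confirm (printed form may vary by version):

import jax
import jax.numpy as jnp
from jax import lax

print(jax.make_jaxpr(lambda x: lax.split(x, (2, 3)))(jnp.arange(5.)))
# roughly: { lambda ; a:f32[5]. let b:f32[2] c:f32[3] = split[axis=0 sizes=(2, 3)] a in (b, c) }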

def _pad_dtype_rule(operand, padding_value, *, padding_config):
if operand.dtype != padding_value.dtype:
msg = "pad operand and padding_value must be same dtype: got {} and {}."
3 changes: 2 additions & 1 deletion jax/_src/numpy/array_methods.py
@@ -629,7 +629,8 @@ def _multi_slice(self: Array,
# avoid circular imports.
@jax.jit
def _unstack(x: Array) -> list[Array]:
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
dims = (0,)
return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])]
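
The rewrite is behavior-preserving: splitting into size-1 pieces and squeezing gives the same values as indexing each leading slice. A sketch of the equivalence:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.).reshape(3, 2)
old = [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
new = [lax.squeeze(t, (0,)) for t in lax.split(x, (1,) * x.shape[0])]
assert all((a == b).all() for a, b in zip(old, new))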

def _chunk_iter(x, size):
if size > x.shape[0]:
31 changes: 14 additions & 17 deletions jax/_src/numpy/lax_numpy.py
@@ -68,7 +68,7 @@
)
from jax._src.util import (
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2,
ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
tuple_replace)
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
PartitionSpec as P)
@@ -3273,10 +3273,10 @@ def _split(op: str, ary: ArrayLike,
if (isinstance(indices_or_sections, (tuple, list)) or
isinstance(indices_or_sections, (np.ndarray, Array)) and
indices_or_sections.ndim > 0):
indices_or_sections = [
split_indices = np.asarray([0] + [
core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1")
for i_s in indices_or_sections]
split_indices = [0] + list(indices_or_sections) + [size]
for i_s in indices_or_sections] + [size])
sizes = list(np.diff(split_indices))
else:
if core.is_symbolic_dim(indices_or_sections):
raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is "
@@ -3285,21 +3285,14 @@
f"in jax.numpy.{op} argument 1")
part_size, r = divmod(size, num_sections)
if r == 0:
split_indices = [i * part_size
for i in range(num_sections + 1)]
sizes = [part_size] * num_sections
elif op == "array_split":
split_indices = (
[i * (part_size + 1) for i in range(r + 1)] +
[i * part_size + ((r + 1) * (part_size + 1) - 1)
for i in range(num_sections - r)])
sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r)
else:
raise ValueError(f"array split does not result in an equal division: rest is {r}")
split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc]
for i in split_indices]
starts, ends = [0] * ndim(ary), shape(ary)
_subval = lambda x, i, v: subvals(x, [(i, v)])
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
for start, end in zip(split_indices[:-1], split_indices[1:])]
sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc]
for i in sizes]
return list(lax.split(ary, sizes, axis=axis))
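
A small numeric sketch of the index-to-sizes conversion above: for indices [2, 3] on an array of length 6, the padded index list is [0, 2, 3, 6] and np.diff gives sizes [2, 1, 3]:

import numpy as np
import jax.numpy as jnp

pieces = jnp.split(jnp.arange(6.), [2, 3])
assert [int(p.shape[0]) for p in pieces] == list(np.diff([0, 2, 3, 6]))   # [2, 1, 3]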


@export
@@ -4662,7 +4655,11 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
"Unstack requires arrays with rank > 0, however a scalar array was "
"passed."
)
return tuple(moveaxis(x, axis, 0))
dimensions = (axis,)
return tuple(
lax.squeeze(t, dimensions)
for t in lax.split(x, (1,) * x.shape[axis], axis=axis)
)
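
Similarly, jnp.unstack now splits along the requested axis and squeezes it away, e.g.:

import jax.numpy as jnp

cols = jnp.unstack(jnp.arange(6).reshape(2, 3), axis=1)
# a tuple of 3 arrays, each of shape (2,)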


@export
21 changes: 21 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
@@ -1901,6 +1901,27 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule


def _split_lowering_rule(
ctx: LoweringRuleContext, x, *, sizes, axis
):
(x_aval,) = ctx.avals_in
slice_size = np.array(x_aval.shape, dtype=np.int64)
starts = np.zeros_like(slice_size)
strides = np.ones_like(slice_size)
outs = []
for size, aval_out in zip(sizes, ctx.avals_out):
slice_size[axis] = size
outs.append(
vector.extract_strided_slice(
aval_to_ir_type(aval_out), x, starts, slice_size, strides
)
)
starts[axis] += size
return outs

lowering_rules[lax.split_p] = _split_lowering_rule


def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
sharding):
out_type = aval_to_ir_type(ctx.avals_out[0])
6 changes: 6 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -2087,6 +2087,12 @@ def _concatenate(*operands, dimension):
tf_impl[lax.concatenate_p] = _concatenate


def _split(operand, *, sizes, axis):
return tf.split(operand, _eval_shape(sizes), axis=axis)

tf_impl[lax.split_p] = _split
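
On the jax2tf side this maps onto tf.split with explicit sizes. A hedged round-trip sketch (requires TensorFlow; the exact output container may vary):

import numpy as np
from jax import lax
from jax.experimental import jax2tf

f_tf = jax2tf.convert(lambda x: lax.split(x, (2, 3)))
a, b = f_tf(np.arange(5.0, dtype=np.float32))   # pieces of length 2 and 3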


def _conv_general_dimension_numbers_proto(dimension_numbers):
"""Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
5 changes: 4 additions & 1 deletion jax/experimental/jet.py
@@ -73,7 +73,7 @@
from jax._src.api_util import shaped_abstractify
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache
from jax._src.util import unzip2, weakref_lru_cache, safe_zip


def jet(fun, primals, series):
@@ -310,6 +310,8 @@ def deflinear(prim):
def linear_prop(prim, primals_in, series_in, **params):
primal_out = prim.bind(*primals_in, **params)
series_out = [prim.bind(*terms_in, **params) for terms_in in zip(*series_in)]
if prim.multiple_results:
series_out = safe_zip(*series_out)
return primal_out, series_out
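
For a multiple-results primitive, prim.bind returns one tuple of outputs per series term, but the rest of jet expects one tuple of terms per output; the safe_zip regroups accordingly. A plain-Python sketch of the relabelling:

# two outputs (o1, o2), two Taylor terms (t1, t2)
per_order  = [("o1_t1", "o2_t1"), ("o1_t2", "o2_t2")]   # as produced by the list comprehension
per_output = list(zip(*per_order))                      # [("o1_t1", "o1_t2"), ("o2_t1", "o2_t2")]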

deflinear(lax.neg_p)
@@ -323,6 +325,7 @@ def linear_prop(prim, primals_in, series_in, **params):
deflinear(lax.convert_element_type_p)
deflinear(lax.broadcast_in_dim_p)
deflinear(lax.concatenate_p)
deflinear(lax.split_p)
deflinear(lax.pad_p)
deflinear(lax.reshape_p)
deflinear(lax.squeeze_p)
2 changes: 2 additions & 0 deletions jax/lax/__init__.py
@@ -203,6 +203,8 @@
sort as sort,
sort_key_val as sort_key_val,
sort_p as sort_p,
split as split,
split_p as split_p,
sqrt as sqrt,
sqrt_p as sqrt_p,
square as square,
18 changes: 18 additions & 0 deletions tests/lax_autodiff_test.py
@@ -276,6 +276,24 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
concatenate = lambda *args: lax.concatenate(args, dim)
check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(len(base_shape))
],
num_pieces=range(3),
dtype=float_dtypes,
)
def testSplitGrad(self, axis, base_shape, dtype, num_pieces):
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
shape = list(base_shape)
shape[axis] = np.sum(sizes)
rng = jtu.rand_default(self.rng())
operands = (rng(shape, dtype),)
split = lambda x: lax.split(x, sizes, axis)
check_grads(split, operands, 2, ["fwd", "rev"], eps=1.)


@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
for lhs_shape, rhs_shape, all_strides in itertools.chain(
27 changes: 27 additions & 0 deletions tests/lax_test.py
@@ -285,6 +285,33 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs):
numpy_op = lambda *args: lax_reference.concatenate(args, dim)
self._CheckAgainstNumpy(numpy_op, op, args_maker)

@jtu.sample_product(
[dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(len(shape))],
num_pieces=range(3),
dtype=lax_test_util.default_dtypes,
)
def testSplit(self, axis, base_shape, dtype, num_pieces):
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
shape = list(base_shape)
shape[axis] = np.sum(sizes)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.split(x, sizes, axis=axis)
def numpy_op(x):
return np.split(x, np.cumsum(sizes[:-1]), axis=axis)
self._CompileAndCheck(op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)

def testSplitErrors(self):
with self.assertRaisesRegex(ValueError,
"Sizes passed to split must be nonnegative"):
lax.split(np.arange(5), [-1])
with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"):
lax.split(np.arange(5), [6])
with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"):
lax.split(np.arange(5), sizes=(), axis=1)

@jtu.sample_product(
[
dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
18 changes: 18 additions & 0 deletions tests/lax_vmap_test.py
@@ -344,6 +344,24 @@ def testSlice(self, shape, dtype, starts, limits, strides, bdims):
op = lambda x: lax.slice(x, starts, limits, strides)
self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis, bdims=bdims)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(len(base_shape))
for bdims in lax_test_util.all_bdims(base_shape)
],
num_pieces=range(3),
dtype=lax_test_util.default_dtypes,
)
def testSplit(self, base_shape, dtype, num_pieces, axis, bdims):
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
shape = list(base_shape)
shape[axis] = np.sum(sizes)
rng = jtu.rand_default(self.rng())
op = lambda x: lax.split(x, sizes, axis)
self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng,
multiple_results=True)

@jtu.sample_product(
[dict(shape=shape, perm=perm, bdims=bdims)
for shape, perm in [