diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 03aacec07751..3ec6359fe43a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5222,18 +5222,71 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) -@util.implements(np.fill_diagonal, lax_description=""" -The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which -JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal` -adds the ``inplace`` parameter, which must be set to ``False`` by the user as a -reminder of this API difference. -""", extra_params=""" -inplace : bool, default=True - If left to its default value of True, JAX will raise an error. This is because - the semantics of :func:`numpy.fill_diagonal` are to modify the array in-place, - which is not possible in JAX due to the immutability of JAX arrays. -""") -def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array: +def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, + inplace: bool = True) -> Array: + """Return a copy of the array with the diagonal overwritten. + + JAX implementation of :func:`numpy.fill_diagonal`. + + The semantics of :func:`numpy.fill_diagonal` are to modify arrays in-place, which + is not possible for JAX's immutable arrays. The JAX version returns a modified + copy of the input, and adds the ``inplace`` parameter which must be set to + `False`` by the user as a reminder of this API difference. + + Args: + a: input array. Must have ``a.ndim >= 2``. If ``a.ndim >= 3``, then all + dimensions must be the same size. + val: scalar or array with which to fill the diagonal. If an array, it will + be flattened and repeated to fill the diagonal entries. + inplace: must be set to False to indicate that the input is not modified + in-place, but rather a modified copy is returned. + + Returns: + A copy of ``a`` with the diagonal set to ``val``. + + Examples: + >>> x = jnp.zeros((3, 3), dtype=int) + >>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False) + Array([[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], dtype=int32) + + Unlike :func:`numpy.fill_diagonal`, the input ``x`` is not modified. + + If the diagonal value has too many entries, it will be truncated + + >>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False) + Array([[100, 0, 0], + [ 0, 101, 0], + [ 0, 0, 102]], dtype=int32) + + If the diagonal has too few entries, it will be repeated: + + >>> x = jnp.zeros((4, 4), dtype=int) + >>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False) + Array([[3, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]], dtype=int32) + + For non-square arrays, the diagonal of the leading square slice is filled: + + >>> x = jnp.zeros((3, 5), dtype=int) + >>> jnp.fill_diagonal(x, 1, inplace=False) + Array([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0]], dtype=int32) + + And for square N-dimensional arrays, the N-dimensional diagonal is filled: + + >>> y = jnp.zeros((2, 2, 2)) + >>> jnp.fill_diagonal(y, 1, inplace=False) + Array([[[1., 0.], + [0., 0.]], + + [[0., 0.], + [0., 1.]]], dtype=float32) + """ if inplace: raise NotImplementedError("JAX arrays are immutable, must use inplace=False") if wrap: @@ -8830,19 +8883,64 @@ def _tile_to_size(arr: Array, size: int) -> Array: return arr[:size] if arr.size > size else arr -@util.implements(np.place, lax_description=""" -The semantics of :func:`numpy.place` is to modify arrays in-place, which JAX -cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.place` adds -the ``inplace`` parameter, which must be set to ``False`` by the user as a -reminder of this API difference. -""", extra_params=""" -inplace : bool, default=True - If left to its default value of True, JAX will raise an error. This is because - the semantics of :func:`numpy.put` are to modify the array in-place, which is - not possible in JAX due to the immutability of JAX arrays. -""") def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: bool = True) -> Array: + """Update array elements based on a mask. + + JAX implementation of :func:`numpy.place`. + + The semantics of :func:`numpy.place` are to modify arrays in-place, which + is not possible for JAX's immutable arrays. The JAX version returns a modified + copy of the input, and adds the ``inplace`` parameter which must be set to + `False`` by the user as a reminder of this API difference. + + Args: + arr: array into which values will be placed. + mask: boolean mask with the same size as ``arr``. + vals: values to be inserted into ``arr`` at the locations indicated + by mask. If too many values are supplied, they will be truncated. + If not enough values are supplied, they will be repeated. + inplace: must be set to False to indicate that the input is not modified + in-place, but rather a modified copy is returned. + + Returns: + A copy of ``arr`` with masked values set to entries from `vals`. + + See Also: + - :func:`jax.numpy.put`: put elements into an array at numerical indices. + - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing + + Examples: + >>> x = jnp.zeros((3, 5), dtype=int) + >>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape) + >>> mask + Array([[ True, False, False, True, False], + [False, True, False, False, True], + [False, False, True, False, False]], dtype=bool) + + Placing a scalar value: + + >>> jnp.place(x, mask, 1, inplace=False) + Array([[1, 0, 0, 1, 0], + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 0]], dtype=int32) + + In this case, ``jnp.place`` is similar to the masked array update syntax: + + >>> x.at[mask].set(1) + Array([[1, 0, 0, 1, 0], + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 0]], dtype=int32) + + ``place`` differs when placing values from an array. The array is repeated + to fill the masked entries: + + >>> vals = jnp.array([1, 3, 5]) + >>> jnp.place(x, mask, vals, inplace=False) + Array([[1, 0, 0, 3, 0], + [0, 5, 0, 0, 1], + [0, 0, 3, 0, 0]], dtype=int32) + """ util.check_arraylike("place", arr, mask, vals) data, mask_arr, vals_arr = asarray(arr), asarray(mask), ravel(vals) if inplace: @@ -8860,19 +8958,70 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape) -@util.implements(np.put, lax_description=""" -The semantics of :func:`numpy.put` is to modify arrays in-place, which JAX -cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.put` adds -the ``inplace`` parameter, which must be set to ``False`` by the user as a -reminder of this API difference. -""", extra_params=""" -inplace : bool, default=True - If left to its default value of True, JAX will raise an error. This is because - the semantics of :func:`numpy.put` are to modify the array in-place, which is - not possible in JAX due to the immutability of JAX arrays. -""") def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = None, *, inplace: bool = True) -> Array: + """Put elements into an array at given indices. + + JAX implementation of :func:`numpy.put`. + + The semantics of :func:`numpy.put` are to modify arrays in-place, which + is not possible for JAX's immutable arrays. The JAX version returns a modified + copy of the input, and adds the ``inplace`` parameter which must be set to + `False`` by the user as a reminder of this API difference. + + Args: + a: array into which values will be placed. + ind: array of indices over the flattened array at which to put values. + v: array of values to put into the array. + mode: string specifying how to handle out-of-bound indices. Supported values: + + - ``"clip"`` (default): clip out-of-bound indices to the final index. + - ``"wrap"``: wrap out-of-bound indices to the beginning of the array. + + inplace: must be set to False to indicate that the input is not modified + in-place, but rather a modified copy is returned. + + Returns: + A copy of ``a`` with specified entries updated. + + See Also: + - :func:`jax.numpy.place`: place elements into an array via boolean mask. + - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing. + - :func:`jax.numpy.take`: extract values from an array at given indices. + + Examples: + >>> x = jnp.zeros(5, dtype=int) + >>> indices = jnp.array([0, 2, 4]) + >>> values = jnp.array([10, 20, 30]) + >>> jnp.put(x, indices, values, inplace=False) + Array([10, 0, 20, 0, 30], dtype=int32) + + This is equivalent to the following :attr:`jax.numpy.ndarray.at` indexing syntax: + + >>> x.at[indices].set(values) + Array([10, 0, 20, 0, 30], dtype=int32) + + There are two modes for handling out-of-bound indices. By default they are + clipped: + + >>> indices = jnp.array([0, 2, 6]) + >>> jnp.put(x, indices, values, inplace=False, mode='clip') + Array([10, 0, 20, 0, 30], dtype=int32) + + Alternatively, they can be wrapped to the beginning of the array: + + >>> jnp.put(x, indices, values, inplace=False, mode='wrap') + Array([10, 30, 20, 0, 0], dtype=int32) + + For N-dimensional inputs, the indices refer to the flattened array: + + >>> x = jnp.zeros((3, 5), dtype=int) + >>> indices = jnp.array([0, 7, 14]) + >>> jnp.put(x, indices, values, inplace=False) + Array([[10, 0, 0, 0, 0], + [ 0, 0, 20, 0, 0], + [ 0, 0, 0, 0, 30]], dtype=int32) + """ util.check_arraylike("put", a, ind, v) arr, ind_arr, v_arr = asarray(a), ravel(ind), ravel(v) if not arr.size or not ind_arr.size or not v_arr.size: