Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: Better doc for jnp.sum #22081

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,76 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
initial=initial, where_=where, parallel_reduce=lax.psum,
promote_integers=promote_integers)

@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)

def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None, promote_integers: bool = True) -> Array:
r"""Sum of the elements of the array over a given axis.

JAX implementation of :func:`numpy.sum`.

Args:
a: Input array.
axis: int or array, default=None. Axis along which the sum to be computed.
If None, the sum is computed along all the axes.
dtype: The type of the output array. Default=None.
out: Unused by JAX
keepdims: bool, default=False. If true, reduced axes are left in the result
with size 1.
initial: int or array, Default=None. Initial value for the sum.
where: int or array, default=None. The elements to be used in the sum. Array
should be broadcast compatible to the input.
promote_integers : bool, default=True. If True, then integer inputs will be
promoted to the widest available integer dtype, following numpy's behavior.
If False, the result will have the same dtype as the input.
``promote_integers`` is ignored if ``dtype`` is specified.

Returns:
An array of the sum along the given axis.

See also:
- :func:`jax.numpy.prod`: Compute the product of array elements over a given
axis.
- :func:`jax.numpy.max`: Compute the maximum of array elements over given axis.
- :func:`jax.numpy.min`: Compute the minimum of array elements over given axis.

Examples:

By default, the sum is computed along all the axes.

>>> x = jnp.array([[1, 3, 4, 2],
... [5, 2, 6, 3],
... [8, 1, 3, 9]])
>>> jnp.sum(x)
Array(47, dtype=int32)

If ``axis=1``, the sum is computed along axis 1.

>>> jnp.sum(x, axis=1)
Array([10, 16, 21], dtype=int32)

If ``keepdims=True``, ``ndim`` of the output is equal to that of the input.

>>> jnp.sum(x, axis=1, keepdims=True)
Array([[10],
[16],
[21]], dtype=int32)

To include only specific elements in the sum, you can use a``where``.

>>> where=jnp.array([[0, 0, 1, 0],
... [0, 0, 1, 1],
... [1, 1, 1, 0]], dtype=bool)
>>> jnp.sum(x, axis=1, keepdims=True, where=where)
Array([[ 4],
[ 9],
[12]], dtype=int32)
>>> where=jnp.array([[False],
... [False],
... [False]])
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)
Expand Down