Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some missing jax.numpy documentation #24537

Merged
merged 1 commit into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,14 +651,44 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
"""Returns the type to which a binary operation should cast its arguments.

For details of JAX's type promotion semantics, see :ref:`type-promotion`.
JAX implementation of :func:`numpy.promote_types`. For details of JAX's
type promotion semantics, see :ref:`type-promotion`.

Args:
a: a :class:`numpy.dtype` or a dtype specifier.
b: a :class:`numpy.dtype` or a dtype specifier.

Returns:
A :class:`numpy.dtype` object.

Examples:
Type specifiers may be strings, dtypes, or scalar types, and the return
value is always a dtype:

>>> jnp.promote_types('int32', 'float32') # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types
dtype('float32')

Built-in scalar types (:type:`int`, :type:`float`, or :type:`complex`) are
treated as weakly-typed and will not change the bit width of a strongly-typed
counterpart (see discussion in :ref:`type-promotion`):

>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')

This differs from the NumPy version of this function, which treats built-in scalar
types as equivalent to 64-bit types:

>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')
"""
# Note: we deliberately avoid `if a in _weak_types` here because we want to check
# object identity, not object equality, due to the behavior of np.dtype.__eq__
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
meta = _ScalarMeta(np_scalar_type.__name__, (object,),
{"dtype": np.dtype(np_scalar_type)})
meta.__module__ = _PUBLIC_MODULE_NAME
meta.__doc__ =\
f"""A JAX scalar constructor of type {np_scalar_type.__name__}.

While NumPy defines scalar types for each data type, JAX represents
scalars as zero-dimensional arrays.
"""
return meta

bool_ = _make_scalar_type(np.bool_)
Expand Down
23 changes: 23 additions & 0 deletions jax/_src/numpy/ufunc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,5 +598,28 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,

Returns:
wrapped : jax.numpy.ufunc wrapper of func.

Examples:
Here is an example of creating a ufunc similar to :obj:`jax.numpy.add`:

>>> import operator
>>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)

Now all the standard :class:`jax.numpy.ufunc` methods are available:

>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11, 2, 3], dtype=int32)
"""
return ufunc(func, nin, nout, identity=identity)
Loading