Skip to content

Commit

Permalink
doc: improve docs for jax.lax trig functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 6, 2025
1 parent 5d647cc commit 039ab8d
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/jax.lax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Operators
erfc
erf_inv
exp
exp2
expand_dims
expm1
fft
Expand Down
159 changes: 145 additions & 14 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def round(x: ArrayLike,
rounding_method = RoundingMethod(rounding_method)
return round_p.bind(x, rounding_method=rounding_method)

@export
def is_finite(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{isfinite}`.
Expand All @@ -478,6 +479,7 @@ def is_finite(x: ArrayLike) -> Array:
"""
return is_finite_p.bind(x)

@export
def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`.
Expand All @@ -488,7 +490,7 @@ def exp(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential.
exponential.
See also:
- :func:`jax.lax.exp2`: elementwise base-2 exponentional: :math:`2^x`.
Expand All @@ -509,7 +511,7 @@ def exp2(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
base-2 exponential.
base-2 exponential.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
Expand All @@ -520,6 +522,7 @@ def exp2(x: ArrayLike) -> Array:
"""
return exp2_p.bind(x)

@export
def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`.
Expand All @@ -532,7 +535,7 @@ def expm1(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
exponential minus 1.
exponential minus 1.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
Expand All @@ -542,6 +545,7 @@ def expm1(x: ArrayLike) -> Array:
"""
return expm1_p.bind(x)

@export
def log(x: ArrayLike) -> Array:
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.
Expand All @@ -552,7 +556,7 @@ def log(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm.
natural logarithm.
See also:
- :func:`jax.lax.exp`: elementwise exponentional: :math:`e^x`.
Expand All @@ -561,8 +565,9 @@ def log(x: ArrayLike) -> Array:
"""
return log_p.bind(x)

@export
def log1p(x: ArrayLike) -> Array:
r"""Elementwise :math:`\mathrm{log}(1 + x)`..
r"""Elementwise :math:`\mathrm{log}(1 + x)`.
This function lowers directly to the `stablehlo.log_plus_one`_ operation.
Compared to the naive expression ``lax.log(1 + x)``, it is more accurate
Expand All @@ -573,7 +578,7 @@ def log1p(x: ArrayLike) -> Array:
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
natural logarithm of ``x + 1``.
natural logarithm of ``x + 1``.
See also:
- :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
Expand All @@ -591,17 +596,75 @@ def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
return logistic_p.bind(x)

@export
def sin(x: ArrayLike) -> Array:
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
r"""Elementwise sine: :math:`\mathrm{sin}(x)`.
For floating-point inputs, this function lowers directly to the
`stablehlo.sine`_ operation. For complex inputs, it lowers to a
sequence of HLO operations implementing the complex sine.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
sine.
See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.asin`: elementwise arcsine.
.. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
"""
return sin_p.bind(x)

@export
def cos(x: ArrayLike) -> Array:
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.
For floating-point inputs, this function lowers directly to the
`stablehlo.cosine`_ operation. For complex inputs, it lowers to a
sequence of HLO operations implementing the complex cosine.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
cosine.
See also:
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.acos`: elementwise arccosine.
.. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
"""
return cos_p.bind(x)

@export
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise arc tangent of two variables:
:math:`\mathrm{atan}({x \over y})`."""
r"""Elementwise two-term arc tangent: :math:`\mathrm{atan}({x \over y})`.
This function lowers directly to the `stablehlo.atan2`_ operation.
Args:
x, y: input arrays. Must have a matching floating-point dtypes. If neither is
a scalar, must have the same number of dimensions and be broadcast-compatible.
Returns:
Array of the same shape and dtype as ``x`` and ``y`` containing the element-wise
arctangent of :math:`x \over y`, respecting the quadrant indicated by the sign
of each input.
See also:
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.atan`: elementwise one-term arctangent.
.. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
"""
return atan2_p.bind(x, y)

def real(x: ArrayLike) -> Array:
Expand Down Expand Up @@ -2473,20 +2536,88 @@ def reciprocal(x: ArrayLike) -> Array:
r"""Elementwise reciprocal: :math:`1 \over x`."""
return integer_pow(x, -1)

@export
def tan(x: ArrayLike) -> Array:
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.
This function lowers directly to the `stablehlo.tangent`_ operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
tangent.
See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.atan`: elementwise arctangent.
- :func:`jax.lax.atan2`: elementwise 2-term arctangent.
.. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent
"""
return tan_p.bind(x)

@export
def asin(x: ArrayLike) -> Array:
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.
This function lowers directly to the ``chlo.asin`` operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arcsine.
See also:
- :func:`jax.lax.sin`: elementwise sine.
- :func:`jax.lax.acos`: elementwise arccosine.
- :func:`jax.lax.atan`: elementwise arctangent.
"""
return asin_p.bind(x)

@export
def acos(x: ArrayLike) -> Array:
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.
This function lowers directly to the ``chlo.acos`` operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arccosine.
See also:
- :func:`jax.lax.cos`: elementwise cosine.
- :func:`jax.lax.asin`: elementwise arcsine.
- :func:`jax.lax.atan`: elementwise arctangent.
"""
return acos_p.bind(x)

@export
def atan(x: ArrayLike) -> Array:
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`.
This function lowers directly to the ``chlo.atan`` operation.
Args:
x: input array. Must have floating-point or complex type.
Returns:
Array of the same shape and dtype as ``x`` containing the
element-wise arctangent.
See also:
- :func:`jax.lax.tan`: elementwise tangent.
- :func:`jax.lax.acos`: elementwise arccosine.
- :func:`jax.lax.asin`: elementwise arcsine.
- :func:`jax.lax.atan2`: elementwise 2-term arctangent.
"""
return atan_p.bind(x)

def sinh(x: ArrayLike) -> Array:
Expand Down

0 comments on commit 039ab8d

Please sign in to comment.