From 43a9faa06a1dcbd993dd224d55d1d3efa3346eda Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 24 Jan 2024 14:14:19 -0800 Subject: [PATCH] Rename _wraps to implements --- jax/_src/numpy/fft.py | 38 +-- jax/_src/numpy/lax_numpy.py | 350 ++++++++++---------- jax/_src/numpy/linalg.py | 54 +-- jax/_src/numpy/polynomial.py | 22 +- jax/_src/numpy/reductions.py | 54 +-- jax/_src/numpy/setops.py | 22 +- jax/_src/numpy/ufunc_api.py | 12 +- jax/_src/numpy/ufuncs.py | 94 +++--- jax/_src/numpy/util.py | 60 ++-- jax/_src/scipy/cluster/vq.py | 4 +- jax/_src/scipy/fft.py | 10 +- jax/_src/scipy/integrate.py | 2 +- jax/_src/scipy/linalg.py | 46 +-- jax/_src/scipy/ndimage.py | 4 +- jax/_src/scipy/signal.py | 22 +- jax/_src/scipy/spatial/transform.py | 6 +- jax/_src/scipy/special.py | 68 ++-- jax/_src/scipy/stats/_core.py | 8 +- jax/_src/scipy/stats/bernoulli.py | 10 +- jax/_src/scipy/stats/beta.py | 14 +- jax/_src/scipy/stats/betabinom.py | 6 +- jax/_src/scipy/stats/binom.py | 6 +- jax/_src/scipy/stats/cauchy.py | 18 +- jax/_src/scipy/stats/chi2.py | 14 +- jax/_src/scipy/stats/dirichlet.py | 6 +- jax/_src/scipy/stats/expon.py | 6 +- jax/_src/scipy/stats/gamma.py | 14 +- jax/_src/scipy/stats/gennorm.py | 8 +- jax/_src/scipy/stats/geom.py | 6 +- jax/_src/scipy/stats/kde.py | 18 +- jax/_src/scipy/stats/laplace.py | 8 +- jax/_src/scipy/stats/logistic.py | 14 +- jax/_src/scipy/stats/multinomial.py | 6 +- jax/_src/scipy/stats/multivariate_normal.py | 6 +- jax/_src/scipy/stats/nbinom.py | 6 +- jax/_src/scipy/stats/norm.py | 18 +- jax/_src/scipy/stats/pareto.py | 6 +- jax/_src/scipy/stats/poisson.py | 8 +- jax/_src/scipy/stats/t.py | 6 +- jax/_src/scipy/stats/truncnorm.py | 14 +- jax/_src/scipy/stats/uniform.py | 10 +- jax/_src/scipy/stats/vonmises.py | 6 +- jax/_src/scipy/stats/wrapcauchy.py | 6 +- jax/_src/third_party/numpy/linalg.py | 10 +- jax/_src/third_party/scipy/interpolate.py | 6 +- jax/_src/third_party/scipy/linalg.py | 4 +- tests/lax_numpy_test.py | 4 +- 47 files changed, 569 insertions(+), 571 deletions(-) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 7574135ded4f..2d4cc319e901 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -22,7 +22,7 @@ from jax import lax from jax._src.lib import xla_client from jax._src.util import safe_zip -from jax._src.numpy.util import check_arraylike, _wraps +from jax._src.numpy.util import check_arraylike, implements from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import ufuncs, reductions from jax._src.typing import Array, ArrayLike @@ -105,28 +105,28 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, return transformed -@_wraps(np.fft.fftn) +@implements(np.fft.fftn) def fftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm) -@_wraps(np.fft.ifftn) +@implements(np.fft.ifftn) def ifftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) -@_wraps(np.fft.rfftn) +@implements(np.fft.rfftn) def rfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) -@_wraps(np.fft.irfftn) +@implements(np.fft.irfftn) def irfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: @@ -150,31 +150,31 @@ def _fft_core_1d(func_name: str, fft_type: xla_client.FftType, return _fft_core(func_name, fft_type, a, s, axes, norm) -@_wraps(np.fft.fft) +@implements(np.fft.fft) def fft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis, norm=norm) -@_wraps(np.fft.ifft) +@implements(np.fft.ifft) def ifft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis, norm=norm) -@_wraps(np.fft.rfft) +@implements(np.fft.rfft) def rfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis, norm=norm) -@_wraps(np.fft.irfft) +@implements(np.fft.irfft) def irfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis, norm=norm) -@_wraps(np.fft.hfft) +@implements(np.fft.hfft) def hfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: conj_a = ufuncs.conj(a) @@ -183,7 +183,7 @@ def hfft(a: ArrayLike, n: int | None = None, return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis, norm=norm) * nn -@_wraps(np.fft.ihfft) +@implements(np.fft.ihfft) def ihfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: _axis_check_1d('ihfft', axis) @@ -206,32 +206,32 @@ def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike, return _fft_core(func_name, fft_type, a, s, axes, norm) -@_wraps(np.fft.fft2) +@implements(np.fft.fft2) def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes, norm=norm) -@_wraps(np.fft.ifft2) +@implements(np.fft.ifft2) def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, norm=norm) -@_wraps(np.fft.rfft2) +@implements(np.fft.rfft2) def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, norm=norm) -@_wraps(np.fft.irfft2) +@implements(np.fft.irfft2) def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, norm=norm) -@_wraps(np.fft.fftfreq, extra_params=""" +@implements(np.fft.fftfreq, extra_params=""" dtype : Optional The dtype of the returned frequencies. If not specified, JAX's default floating point dtype will be used. @@ -266,7 +266,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: return k / jnp.array(d * n, dtype=dtype) -@_wraps(np.fft.rfftfreq, extra_params=""" +@implements(np.fft.rfftfreq, extra_params=""" dtype : Optional The dtype of the returned frequencies. If not specified, JAX's default floating point dtype will be used. @@ -292,7 +292,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array: return k / jnp.array(d * n, dtype=dtype) -@_wraps(np.fft.fftshift) +@implements(np.fft.fftshift) def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: check_arraylike("fftshift", x) x = jnp.asarray(x) @@ -308,7 +308,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: return jnp.roll(x, shift, axes) -@_wraps(np.fft.ifftshift) +@implements(np.fft.ifftshift) def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array: check_arraylike("ifftshift", x) x = jnp.asarray(x) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 39f935422b44..0b0e8b8ead35 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -111,7 +111,7 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: printoptions = np.printoptions set_printoptions = np.set_printoptions -@util._wraps(np.iscomplexobj) +@util.implements(np.iscomplexobj) def iscomplexobj(x: Any) -> bool: if x is None: return False @@ -218,7 +218,7 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: save = np.save savez = np.savez -@util._wraps(np.dtype) +@util.implements(np.dtype) def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False, copy: bool = False) -> DType: """Similar to np.dtype, but respects JAX dtype defaults.""" @@ -283,7 +283,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: return clip(val, min_val, max_val).astype(dtype) -@util._wraps(np.load, update_doc=False) +@util.implements(np.load, update_doc=False) def load(*args: Any, **kwargs: Any) -> Array: # The main purpose of this wrapper is to recover bfloat16 data types. # Note: this will only work for files created via np.save(), not np.savez(). @@ -300,21 +300,21 @@ def load(*args: Any, **kwargs: Any) -> Array: ### implementations of numpy functions in terms of lax -@util._wraps(np.fmin, module='numpy') +@util.implements(np.fmin, module='numpy') @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) -@util._wraps(np.fmax, module='numpy') +@util.implements(np.fmax, module='numpy') @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) -@util._wraps(np.issubdtype) +@util.implements(np.issubdtype) def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) -@util._wraps(np.isscalar) +@util.implements(np.isscalar) def isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): element = element.__jax_array__() @@ -322,12 +322,12 @@ def isscalar(element: Any) -> bool: iterable = np.iterable -@util._wraps(np.result_type) +@util.implements(np.result_type) def result_type(*args: Any) -> DType: return dtypes.result_type(*args) -@util._wraps(np.trunc, module='numpy') +@util.implements(np.trunc, module='numpy') @jit def trunc(x: ArrayLike) -> Array: util.check_arraylike('trunc', x) @@ -381,8 +381,8 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, return result[0, 0, out_order] -@util._wraps(np.convolve, lax_description=_PRECISION_DOC, - extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(np.convolve, lax_description=_PRECISION_DOC, + extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None, @@ -392,8 +392,8 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision=precision, preferred_element_type=preferred_element_type) -@util._wraps(np.correlate, lax_description=_PRECISION_DOC, - extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(np.correlate, lax_description=_PRECISION_DOC, + extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None, @@ -403,7 +403,7 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision=precision, preferred_element_type=preferred_element_type) -@util._wraps(np.histogram_bin_edges) +@util.implements(np.histogram_bin_edges) def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range: None | Array | Sequence[ArrayLike] = None, weights: ArrayLike | None = None) -> Array: @@ -429,7 +429,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, return linspace(range[0], range[1], bins_int + 1, dtype=dtype) -@util._wraps(np.histogram) +@util.implements(np.histogram) def histogram(a: ArrayLike, bins: ArrayLike = 10, range: Sequence[ArrayLike] | None = None, weights: ArrayLike | None = None, @@ -453,7 +453,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, counts = counts / bin_widths / counts.sum() return counts, bin_edges -@util._wraps(np.histogram2d) +@util.implements(np.histogram2d) def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -472,7 +472,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = hist, edges = histogramdd(sample, bins, range, weights, density) return hist, edges[0], edges[1] -@util._wraps(np.histogramdd) +@util.implements(np.histogramdd) def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -536,7 +536,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, view of the input. """ -@util._wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.transpose, lax_description=_ARRAY_VIEW_DOC) def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: util.check_arraylike("transpose", a) axes_ = list(range(ndim(a))[::-1]) if axes is None else axes @@ -544,13 +544,13 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: return lax.transpose(a, axes_) -@util._wraps(getattr(np, "permute_dims", None)) +@util.implements(getattr(np, "permute_dims", None)) def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: util.check_arraylike("permute_dims", a) return lax.transpose(a, axes) -@util._wraps(getattr(np, 'matrix_transpose', None)) +@util.implements(getattr(np, 'matrix_transpose', None)) def matrix_transpose(x: ArrayLike, /) -> Array: """Transposes the last two dimensions of x. @@ -572,7 +572,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return lax.transpose(x, axes) -@util._wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.rot90, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: util.check_arraylike("rot90", m) @@ -599,7 +599,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) -@util._wraps(np.flip, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.flip, lax_description=_ARRAY_VIEW_DOC) def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: util.check_arraylike("flip", m) return _flip(asarray(m), reductions._ensure_optional_axes(axis)) @@ -612,30 +612,30 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) -@util._wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.fliplr, lax_description=_ARRAY_VIEW_DOC) def fliplr(m: ArrayLike) -> Array: util.check_arraylike("fliplr", m) return _flip(asarray(m), 1) -@util._wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.flipud, lax_description=_ARRAY_VIEW_DOC) def flipud(m: ArrayLike) -> Array: util.check_arraylike("flipud", m) return _flip(asarray(m), 0) -@util._wraps(np.iscomplex) +@util.implements(np.iscomplex) @jit def iscomplex(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) -@util._wraps(np.isreal) +@util.implements(np.isreal) @jit def isreal(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.eq(i, _lax_const(i, 0)) -@util._wraps(np.angle) +@util.implements(np.angle) @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: re = ufuncs.real(z) @@ -650,7 +650,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: return ufuncs.degrees(result) if deg else result -@util._wraps(np.diff) +@util.implements(np.diff) @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, @@ -710,7 +710,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, loses precision. """ -@util._wraps(np.ediff1d, lax_description=_EDIFF1D_DOC) +@util.implements(np.ediff1d, lax_description=_EDIFF1D_DOC) @jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: @@ -726,7 +726,7 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, return result -@util._wraps(np.gradient, skip_params=['edge_order']) +@util.implements(np.gradient, skip_params=['edge_order']) @partial(jit, static_argnames=('axis', 'edge_order')) def gradient(f: ArrayLike, *varargs: ArrayLike, axis: int | Sequence[int] | None = None, @@ -771,12 +771,12 @@ def gradient_along_axis(a, h, axis): return a_grad[0] if len(axis_tuple) == 1 else a_grad -@util._wraps(np.isrealobj) +@util.implements(np.isrealobj) def isrealobj(x: Any) -> bool: return not iscomplexobj(x) -@util._wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.reshape, lax_description=_ARRAY_VIEW_DOC) def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array: __tracebackhide__ = True util.check_arraylike("reshape", a) @@ -788,7 +788,7 @@ def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array: return asarray(a).reshape(newshape, order=order) -@util._wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.ravel, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('order',), inline=True) def ravel(a: ArrayLike, order: str = "C") -> Array: util.check_arraylike("ravel", a) @@ -797,7 +797,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: return reshape(a, (size(a),), order) -@util._wraps(np.ravel_multi_index) +@util.implements(np.ravel_multi_index) def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = 'raise', order: str = 'C') -> Array: assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" @@ -840,7 +840,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], and out-of-bounds indices are clipped into the valid range. """ -@util._wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) +@util.implements(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: util.check_arraylike("unravel_index", indices) indices_arr = asarray(indices) @@ -860,7 +860,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) for s, i in safe_zip(shape, out_indices)) -@util._wraps(np.resize) +@util.implements(np.resize) @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: util.check_arraylike("resize", a) @@ -880,7 +880,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: return reshape(arr, new_shape) -@util._wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.squeeze, lax_description=_ARRAY_VIEW_DOC) def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: util.check_arraylike("squeeze", a) return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) @@ -896,14 +896,14 @@ def _squeeze(a: Array, axis: tuple[int]) -> Array: return lax.squeeze(a, axis) -@util._wraps(np.expand_dims) +@util.implements(np.expand_dims) def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: util.check_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) return lax.expand_dims(a, axis) -@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.swapaxes, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: util.check_arraylike("swapaxes", a) @@ -912,7 +912,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: return lax.transpose(a, list(perm)) -@util._wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.moveaxis, lax_description=_ARRAY_VIEW_DOC) def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: util.check_arraylike("moveaxis", a) @@ -932,7 +932,7 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) - return lax.transpose(a, perm) -@util._wraps(np.isclose) +@util.implements(np.isclose) @partial(jit, static_argnames=('equal_nan',)) def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -1034,7 +1034,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f -@util._wraps(np.interp, +@util.implements(np.interp, lax_description=_dedent(""" In addition to constant interpolation supported by NumPy, jnp.interp also supports left='extrapolate' and right='extrapolate' to indicate linear @@ -1074,7 +1074,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, _DEPRECATED_WHERE_ARG = object() -@util._wraps(np.where, # type: ignore[no-redef] +@util.implements(np.where, # type: ignore[no-redef] lax_description=_dedent(""" At present, JAX does not support JIT-compilation of the single-argument form of :py:func:`jax.numpy.where` because its output shape is data-dependent. The @@ -1137,7 +1137,7 @@ def where( return util._where(acondition, if_true, if_false) -@util._wraps(np.select) +@util.implements(np.select) def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], @@ -1156,7 +1156,7 @@ def select( return output -@util._wraps(np.bincount, lax_description="""\ +@util.implements(np.bincount, lax_description="""\ Jax adds the optional `length` parameter which specifies the output length, and defaults to ``x.max() + 1``. It must be specified for bincount to be compiled with non-static operands. Values larger than the specified length will be discarded. @@ -1196,7 +1196,7 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | core.Tracer] ) -> tuple[int | core.Tracer, ...]: ... -@util._wraps(getattr(np, "broadcast_shapes", None)) +@util.implements(getattr(np, "broadcast_shapes", None)) def broadcast_shapes(*shapes): if not shapes: return () @@ -1204,14 +1204,14 @@ def broadcast_shapes(*shapes): return lax.broadcast_shapes(*shapes) -@util._wraps(np.broadcast_arrays, lax_description="""\ +@util.implements(np.broadcast_arrays, lax_description="""\ The JAX version does not necessarily return a view of the input. """) def broadcast_arrays(*args: ArrayLike) -> list[Array]: return util._broadcast_arrays(*args) -@util._wraps(np.broadcast_to, lax_description="""\ +@util.implements(np.broadcast_to, lax_description="""\ The JAX version does not necessarily return a view of the input. """) def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: @@ -1257,13 +1257,13 @@ def _split(op: str, ary: ArrayLike, return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) for start, end in zip(split_indices[:-1], split_indices[1:])] -@util._wraps(np.split, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.split, lax_description=_ARRAY_VIEW_DOC) def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: return _split("split", ary, indices_or_sections, axis=axis) def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, int | ArrayLike], list[Array]]: - @util._wraps(getattr(np, op), update_doc=False) + @util.implements(getattr(np, op), update_doc=False) def f(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: # for 1-D array, hsplit becomes vsplit nonlocal axis @@ -1278,12 +1278,12 @@ def f(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> l hsplit = _split_on_axis("hsplit", axis=1) dsplit = _split_on_axis("dsplit", axis=2) -@util._wraps(np.array_split) +@util.implements(np.array_split) def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: return _split("array_split", ary, indices_or_sections, axis=axis) -@util._wraps(np.clip, skip_params=['out']) +@util.implements(np.clip, skip_params=['out']) @jit def clip(a: ArrayLike, a_min: ArrayLike | None = None, a_max: ArrayLike | None = None, out: None = None) -> Array: @@ -1298,7 +1298,7 @@ def clip(a: ArrayLike, a_min: ArrayLike | None = None, a = ufuncs.minimum(a_max, a) return asarray(a) -@util._wraps(np.around, skip_params=['out']) +@util.implements(np.around, skip_params=['out']) @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: util.check_arraylike("round", a) @@ -1334,7 +1334,7 @@ def _round_float(x: ArrayLike) -> Array: round_ = round -@util._wraps(np.fix, skip_params=['out']) +@util.implements(np.fix, skip_params=['out']) @jit def fix(x: ArrayLike, out: None = None) -> Array: util.check_arraylike("fix", x) @@ -1344,7 +1344,7 @@ def fix(x: ArrayLike, out: None = None) -> Array: return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) -@util._wraps(np.nan_to_num) +@util.implements(np.nan_to_num) @jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: ArrayLike | None = None, @@ -1367,7 +1367,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, return out -@util._wraps(np.allclose) +@util.implements(np.allclose) @partial(jit, static_argnames=('equal_nan',)) def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -1390,7 +1390,7 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, remaining elements will be filled with ``fill_value``, which defaults to zero. """ -@util._wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) +@util.implements(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) def nonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: @@ -1422,13 +1422,13 @@ def nonzero(a: ArrayLike, *, size: int | None = None, out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out)) return out -@util._wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) +@util.implements(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) def flatnonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike] = None) -> Array: return nonzero(ravel(a), size=size, fill_value=fill_value)[0] -@util._wraps(np.unwrap) +@util.implements(np.unwrap) @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, axis: int = -1, period: ArrayLike = 2 * pi) -> Array: @@ -1764,7 +1764,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") -@util._wraps(np.pad, lax_description="""\ +@util.implements(np.pad, lax_description="""\ Unlike numpy, JAX "function" mode's argument (which is another function) should return the modified array. This is because Jax arrays are immutable. (In numpy, "function" mode's argument should modify a rank 1 array in-place.) @@ -1810,7 +1810,7 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], ### Array-creation functions -@util._wraps(np.stack, skip_params=['out']) +@util.implements(np.stack, skip_params=['out']) def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: if not len(arrays): @@ -1831,7 +1831,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype) -@util._wraps(np.tile) +@util.implements(np.tile) def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: util.check_arraylike("tile", A) try: @@ -1863,7 +1863,7 @@ def _concatenate_array(arr: ArrayLike, axis: int | None, dimensions = [*range(1, axis + 1), 0, *range(axis + 1, arr.ndim)] return lax.reshape(arr, shape, dimensions) -@util._wraps(np.concatenate) +@util.implements(np.concatenate) def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> Array: if isinstance(arrays, (np.ndarray, Array)): @@ -1890,13 +1890,13 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] -@util._wraps(getattr(np, "concat", None)) +@util.implements(getattr(np, "concat", None)) def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: util.check_arraylike("concat", *arrays) return jax.numpy.concatenate(arrays, axis=axis) -@util._wraps(np.vstack) +@util.implements(np.vstack) def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: arrs: Array | list[Array] @@ -1909,7 +1909,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0, dtype=dtype) -@util._wraps(np.hstack) +@util.implements(np.hstack) def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: arrs: Array | list[Array] @@ -1924,7 +1924,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) -@util._wraps(np.dstack) +@util.implements(np.dstack) def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: arrs: Array | list[Array] @@ -1937,7 +1937,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=2, dtype=dtype) -@util._wraps(np.column_stack) +@util.implements(np.column_stack) def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: arrs: Array | list[Array] | np.ndarray if isinstance(tup, (np.ndarray, Array)): @@ -1949,7 +1949,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: return concatenate(arrs, 1) -@util._wraps(np.choose, skip_params=['out']) +@util.implements(np.choose, skip_params=['out']) def choose(a: ArrayLike, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: if out is not None: @@ -1996,7 +1996,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: else: return asarray(xs), 1 -@util._wraps(np.block) +@util.implements(np.block) @jit def block(arrays: ArrayLike | list[ArrayLike]) -> Array: out, _ = _block(arrays) @@ -2012,7 +2012,7 @@ def atleast_1d(x: ArrayLike, /) -> Array: @overload def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util._wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. @@ -2033,7 +2033,7 @@ def atleast_2d(x: ArrayLike, /) -> Array: @overload def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util._wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. @@ -2059,7 +2059,7 @@ def atleast_3d(x: ArrayLike, /) -> Array: @overload def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util._wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. @@ -2093,7 +2093,7 @@ def _supports_buffer_protocol(obj): https://jax.readthedocs.io/en/latest/faq.html). """ -@util._wraps(np.array, lax_description=_ARRAY_DOC) +@util.implements(np.array, lax_description=_ARRAY_DOC) def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0) -> Array: if order is not None and order != "K": @@ -2185,7 +2185,7 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x -@util._wraps(getattr(np, "astype", None), lax_description=""" +@util.implements(getattr(np, "astype", None), lax_description=""" This is implemented via :func:`jax.lax.convert_element_type`, which may have slightly different behavior than :func:`numpy.astype` in some cases. In particular, the details of float-to-int and int-to-float casts are @@ -2199,7 +2199,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Ar return lax.convert_element_type(x, dtype) -@util._wraps(np.asarray, lax_description=_ARRAY_DOC) +@util.implements(np.asarray, lax_description=_ARRAY_DOC) def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, *, copy: bool | None = None) -> Array: # For copy=False, the array API specifies that we raise a ValueError if the input supports @@ -2217,13 +2217,13 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order) # type: ignore -@util._wraps(np.copy, lax_description=_ARRAY_DOC) +@util.implements(np.copy, lax_description=_ARRAY_DOC) def copy(a: ArrayLike, order: str | None = None) -> Array: util.check_arraylike("copy", a) return array(a, copy=True, order=order) -@util._wraps(np.zeros_like) +@util.implements(np.zeros_like) def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None) -> Array: @@ -2235,7 +2235,7 @@ def zeros_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 0, dtype, shape) -@util._wraps(np.ones_like) +@util.implements(np.ones_like) def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None) -> Array: @@ -2247,7 +2247,7 @@ def ones_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 1, dtype, shape) -@util._wraps(np.empty_like, lax_description="""\ +@util.implements(np.empty_like, lax_description="""\ Because XLA cannot create uninitialized arrays, the JAX version will return an array initialized with zeros.""") def empty_like(prototype: ArrayLike | DuckTypedArray, @@ -2269,7 +2269,7 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No return device -@util._wraps(np.full) +@util.implements(np.full) def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -2283,7 +2283,7 @@ def full(shape: Any, fill_value: ArrayLike, return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device) -@util._wraps(np.full_like) +@util.implements(np.full_like) def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, shape: Any = None) -> Array: @@ -2302,7 +2302,7 @@ def full_like(a: ArrayLike | DuckTypedArray, return broadcast_to(asarray(fill_value, dtype=dtype), shape) -@util._wraps(np.zeros) +@util.implements(np.zeros) def zeros(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: if isinstance(shape, types.GeneratorType): @@ -2312,7 +2312,7 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, shape = canonicalize_shape(shape) return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) -@util._wraps(np.ones) +@util.implements(np.ones) def ones(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: if isinstance(shape, types.GeneratorType): @@ -2322,7 +2322,7 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, dtypes.check_user_dtype_supported(dtype, "ones") return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) -@util._wraps(np.empty, lax_description="""\ +@util.implements(np.empty, lax_description="""\ Because XLA cannot create uninitialized arrays, the JAX version will return an array initialized with zeros.""") def empty(shape: Any, dtype: DTypeLike | None = None, *, @@ -2340,7 +2340,7 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore "with a single tuple argument for the shape?") -@util._wraps(np.array_equal) +@util.implements(np.array_equal) def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: try: a1, a2 = asarray(a1), asarray(a2) @@ -2359,7 +2359,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: return reductions.all(eq) -@util._wraps(np.array_equiv) +@util.implements(np.array_equiv) def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: try: a1, a2 = asarray(a1), asarray(a2) @@ -2380,7 +2380,7 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. -@util._wraps(np.frombuffer) +@util.implements(np.frombuffer) def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) @@ -2421,7 +2421,7 @@ def fromiter(*args, **kwargs): "because of its potential side-effect of consuming the iterable object; for more information see " "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") -@util._wraps(getattr(np, "from_dlpack", None), lax_description=""" +@util.implements(getattr(np, "from_dlpack", None), lax_description=""" .. note:: While JAX arrays are always immutable, dlpack buffers cannot be marked as @@ -2434,7 +2434,7 @@ def from_dlpack(x: Any) -> Array: from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top return from_dlpack(x.__dlpack__()) -@util._wraps(np.fromfunction) +@util.implements(np.fromfunction) def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = float, **kwargs) -> Array: shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()") @@ -2444,12 +2444,12 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) -@util._wraps(np.fromstring) +@util.implements(np.fromstring) def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) -@util._wraps(np.eye) +@util.implements(np.eye) def eye(N: DimSize, M: DimSize | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "eye") @@ -2461,13 +2461,13 @@ def eye(N: DimSize, M: DimSize | None = None, k: int = 0, return lax_internal._eye(_jnp_dtype(dtype), (N_int, M_int), k) -@util._wraps(np.identity) +@util.implements(np.identity) def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "identity") return eye(n, dtype=dtype) -@util._wraps(np.arange,lax_description= """ +@util.implements(np.arange,lax_description= """ .. note:: Using ``arange`` with the ``step`` argument can lead to precision errors, @@ -2561,7 +2561,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, axis: int = 0) -> Array | tuple[Array, Array]: ... -@util._wraps(np.linspace) +@util.implements(np.linspace) def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, @@ -2628,7 +2628,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(out, dtype) -@util._wraps(np.logspace) +@util.implements(np.logspace) def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: DTypeLike | None = None, axis: int = 0) -> Array: @@ -2654,7 +2654,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(ufuncs.power(base, lin), dtype) -@util._wraps(np.geomspace) +@util.implements(np.geomspace) def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: DTypeLike | None = None, axis: int = 0) -> Array: num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") @@ -2685,7 +2685,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) -@util._wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: util.check_arraylike("meshgrid", *xi) @@ -2708,7 +2708,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, @custom_jvp -@util._wraps(np.i0) +@util.implements(np.i0) @jit def i0(x: ArrayLike) -> Array: x_arr, = util.promote_args_inexact("i0", x) @@ -2723,7 +2723,7 @@ def _i0_jvp(primals, tangents): return primal_out, where(primals[0] == 0, 0.0, tangent_out) -@util._wraps(np.ix_) +@util.implements(np.ix_) def ix_(*args: ArrayLike) -> tuple[Array, ...]: util.check_arraylike("ix", *args) n = len(args) @@ -2755,7 +2755,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, sparse: bool = False) -> Array | tuple[Array, ...]: ... -@util._wraps(np.indices) +@util.implements(np.indices) def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, sparse: bool = False) -> Array | tuple[Array, ...]: dimensions = tuple( @@ -2784,7 +2784,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, """ -@util._wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC) +@util.implements(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC) def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) -> Array: util.check_arraylike("repeat", a) @@ -2865,7 +2865,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) -@util._wraps(np.tri) +@util.implements(np.tri) def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "tri") M = M if M is not None else N @@ -2873,7 +2873,7 @@ def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None return lax_internal._tri(dtype, (N, M), k) -@util._wraps(np.tril) +@util.implements(np.tril) @partial(jit, static_argnames=('k',)) def tril(m: ArrayLike, k: int = 0) -> Array: util.check_arraylike("tril", m) @@ -2885,7 +2885,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) -@util._wraps(np.triu, update_doc=False) +@util.implements(np.triu, update_doc=False) @partial(jit, static_argnames=('k',)) def triu(m: ArrayLike, k: int = 0) -> Array: util.check_arraylike("triu", m) @@ -2897,7 +2897,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) -@util._wraps(np.trace, skip_params=['out']) +@util.implements(np.trace, skip_params=['out']) @partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2923,7 +2923,7 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, def _wrap_indices_function(f): - @util._wraps(f, update_doc=False) + @util.implements(f, update_doc=False) def wrapper(*args, **kwargs): args = [core.concrete_or_error( None, arg, f"argument {i} of jnp.{f.__name__}()") @@ -2947,7 +2947,7 @@ def _triu_size(n, m, k): return mk * (mk + 1) // 2 + mk * (m - k - mk) -@util._wraps(np.triu_indices) +@util.implements(np.triu_indices) def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices") k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices") @@ -2956,7 +2956,7 @@ def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j -@util._wraps(np.tril_indices) +@util.implements(np.tril_indices) def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: n = core.concrete_or_error(operator.index, n, "n argument of jnp.triu_indices") k = core.concrete_or_error(operator.index, k, "k argument of jnp.triu_indices") @@ -2965,19 +2965,19 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j -@util._wraps(np.triu_indices_from) +@util.implements(np.triu_indices_from) def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: arr_shape = shape(arr) return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1]) -@util._wraps(np.tril_indices_from) +@util.implements(np.tril_indices_from) def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: arr_shape = shape(arr) return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1]) -@util._wraps(np.fill_diagonal, lax_description=""" +@util.implements(np.fill_diagonal, lax_description=""" The semantics of :func:`numpy.fill_diagonal` is to modify arrays in-place, which JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.fill_diagonal` adds the ``inplace`` parameter, which must be set to ``False`` by the user as a @@ -3005,7 +3005,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n)) -@util._wraps(np.diag_indices) +@util.implements(np.diag_indices) def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()") ndim = core.concrete_or_error(operator.index, ndim, "'ndim' argument of jnp.diag_indices()") @@ -3017,7 +3017,7 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: .format(ndim)) return (lax.iota(int_, n),) * ndim -@util._wraps(np.diag_indices_from) +@util.implements(np.diag_indices_from) def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: util.check_arraylike("diag_indices_from", arr) nd = ndim(arr) @@ -3030,7 +3030,7 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: return diag_indices(s[0], ndim=nd) -@util._wraps(np.diagonal, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.diagonal, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: @@ -3049,7 +3049,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, return a[..., i, j] if offset >= 0 else a[..., j, i] -@util._wraps(np.diag, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.diag, lax_description=_ARRAY_VIEW_DOC) def diag(v: ArrayLike, k: int = 0) -> Array: return _diag(v, operator.index(k)) @@ -3073,7 +3073,7 @@ def _diag(v, k): return a scalar depending on the type of v. """ -@util._wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC) +@util.implements(np.diagflat, lax_description=_SCALAR_VALUE_DOC) def diagflat(v: ArrayLike, k: int = 0) -> Array: util.check_arraylike("diagflat", v) v_ravel = ravel(v) @@ -3090,7 +3090,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res -@util._wraps(np.trim_zeros) +@util.implements(np.trim_zeros) def trim_zeros(filt, trim='fb'): filt = core.concrete_or_error(asarray, filt, "Error arose in the `filt` argument of trim_zeros()") @@ -3113,7 +3113,7 @@ def trim_zeros_tol(filt, tol, trim='fb'): return filt[start:len(filt) - end] -@util._wraps(np.append) +@util.implements(np.append) @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None @@ -3124,7 +3124,7 @@ def append( return concatenate([arr, values], axis=axis) -@util._wraps(np.delete, +@util.implements(np.delete, lax_description=_dedent(""" delete() usually requires the index specification to be static. If the index is an integer array that is guaranteed to contain unique entries, you may @@ -3195,7 +3195,7 @@ def delete( return a[tuple(slice(None) for i in range(axis)) + (mask,)] -@util._wraps(np.insert) +@util.implements(np.insert) def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, axis: int | None = None) -> Array: util.check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values) @@ -3248,7 +3248,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, return out -@util._wraps(np.apply_along_axis) +@util.implements(np.apply_along_axis) def apply_along_axis( func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs ) -> Array: @@ -3264,7 +3264,7 @@ def apply_along_axis( return func(arr) -@util._wraps(np.apply_over_axes) +@util.implements(np.apply_over_axes) def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, axes: Sequence[int]) -> Array: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. @@ -3291,8 +3291,8 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, rules of the input array dtypes. """ -@util._wraps(np.dot, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(np.dot, lax_description=_PRECISION_DOC, + extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -3324,8 +3324,8 @@ def dot(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) -@util._wraps(np.matmul, module='numpy', lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(np.matmul, module='numpy', lax_description=_PRECISION_DOC, + extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -3397,8 +3397,8 @@ def matmul(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) -@util._wraps(np.vdot, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(np.vdot, lax_description=_PRECISION_DOC, + extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( a: ArrayLike, b: ArrayLike, *, @@ -3412,8 +3412,8 @@ def vdot( preferred_element_type=preferred_element_type) -@util._wraps(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC, + extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -3427,8 +3427,8 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, signature="(n),(n)->()")(x1_arr, x2_arr) -@util._wraps(np.tensordot, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(np.tensordot, lax_description=_PRECISION_DOC, + extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = 2, *, precision: PrecisionLike = None, @@ -3505,7 +3505,7 @@ def einsum( _dot_general: Callable[..., Array] = lax.dot_general, ) -> Array: ... -@util._wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out']) +@util.implements(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out']) def einsum( subscripts, /, *operands, @@ -3557,7 +3557,7 @@ def _default_poly_einsum_handler(*operands, **kwargs): contract_operands = [operands[mapping[id(d)]] for d in out_dummies] return contract_operands, contractions -@util._wraps(np.einsum_path) +@util.implements(np.einsum_path) def einsum_path(subscripts, *operands, optimize='greedy'): # using einsum_call=True here is an internal api for opt_einsum return opt_einsum.contract_path(subscripts, *operands, optimize=optimize) @@ -3716,8 +3716,8 @@ def filter_singleton_dims(operand, names, other_shape, other_names): return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type) -@util._wraps(np.inner, lax_description=_PRECISION_DOC, - extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) +@util.implements(np.inner, lax_description=_PRECISION_DOC, + extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -3733,7 +3733,7 @@ def inner( preferred_element_type=preferred_element_type) -@util._wraps(np.outer, skip_params=['out']) +@util.implements(np.outer, skip_params=['out']) @partial(jit, inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: if out is not None: @@ -3743,7 +3743,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: a, b = util.promote_dtypes(a, b) return ravel(a)[:, None] * ravel(b)[None, :] -@util._wraps(np.cross) +@util.implements(np.cross) @partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis')) def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int | None = None): @@ -3773,7 +3773,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) -@util._wraps(np.kron) +@util.implements(np.kron) @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. @@ -3789,7 +3789,7 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) -@util._wraps(np.vander) +@util.implements(np.vander) @partial(jit, static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False @@ -3822,7 +3822,7 @@ def vander( """ -@util._wraps(np.argwhere, +@util.implements(np.argwhere, lax_description=_dedent(""" Because the size of the output of ``argwhere`` is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional ``size`` argument which @@ -3847,7 +3847,7 @@ def argwhere( return result.reshape(result.shape[0], ndim(a)) -@util._wraps(np.argmax, skip_params=['out']) +@util.implements(np.argmax, skip_params=['out']) def argmax(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: util.check_arraylike("argmax", a) @@ -3869,7 +3869,7 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_)) return expand_dims(result, dims) if keepdims else result -@util._wraps(np.argmin, skip_params=['out']) +@util.implements(np.argmin, skip_params=['out']) def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: util.check_arraylike("argmin", a) @@ -3898,7 +3898,7 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: """ -@util._wraps(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out']) +@util.implements(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out']) def nanargmax( a: ArrayLike, axis: int | None = None, @@ -3921,7 +3921,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) -@util._wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out']) +@util.implements(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out']) def nanargmin( a: ArrayLike, axis: int | None = None, @@ -3944,8 +3944,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) -@util._wraps(np.sort, - extra_params=""" +@util.implements(np.sort, extra_params=""" kind : deprecated; specify sort algorithm using stable=True or stable=False order : not supported stable : bool, default=True @@ -3980,14 +3979,14 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result -@util._wraps(np.sort_complex) +@util.implements(np.sort_complex) @jit def sort_complex(a: ArrayLike) -> Array: util.check_arraylike("sort_complex", a) a = lax.sort(asarray(a), dimension=0) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) -@util._wraps(np.lexsort) +@util.implements(np.lexsort) @partial(jit, static_argnames=('axis',)) def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: key_tuple = tuple(keys) @@ -4006,8 +4005,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1] -@util._wraps(np.argsort, - extra_params=""" +@util.implements(np.argsort, extra_params=""" kind : deprecated; specify sort algorithm using stable=True or stable=False order : not supported stable : bool, default=True @@ -4051,7 +4049,7 @@ def argsort( return lax.rev(indices, dimensions=[dimension]) if descending else indices -@util._wraps(np.partition, lax_description=""" +@util.implements(np.partition, lax_description=""" The JAX version requires the ``kth`` argument to be a static integer rather than a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If you're only accessing the top or bottom k values of the output, it may be more @@ -4077,7 +4075,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) -@util._wraps(np.argpartition, lax_description=""" +@util.implements(np.argpartition, lax_description=""" The JAX version requires the ``kth`` argument to be a static integer rather than a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If you're only accessing the top or bottom k values of the output, it may be more @@ -4136,7 +4134,7 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: dimension=ax) return a -@util._wraps(np.roll) +@util.implements(np.roll) def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], axis: int | Sequence[int] | None = None) -> Array: util.check_arraylike("roll", a) @@ -4153,7 +4151,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], return _roll_static(arr, shift, axis) -@util._wraps(np.rollaxis, lax_description=_ARRAY_VIEW_DOC) +@util.implements(np.rollaxis, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('axis', 'start')) def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: util.check_arraylike("rollaxis", a) @@ -4169,7 +4167,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: return moveaxis(a, axis, start) -@util._wraps(np.packbits) +@util.implements(np.packbits) @partial(jit, static_argnames=('axis', 'bitorder')) def packbits( a: ArrayLike, axis: int | None = None, bitorder: str = "big" @@ -4200,7 +4198,7 @@ def packbits( return swapaxes(packed, axis, -1) -@util._wraps(np.unpackbits) +@util.implements(np.unpackbits) @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, @@ -4231,7 +4229,7 @@ def unpackbits( return swapaxes(unpacked, axis, -1) -@util._wraps(np.take, skip_params=['out'], +@util.implements(np.take, skip_params=['out'], lax_description=""" By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the ``mode`` parameter (see below). @@ -4346,7 +4344,7 @@ def _normalize_index(index, axis_size): """ -@util._wraps(np.take_along_axis, update_doc=False, +@util.implements(np.take_along_axis, update_doc=False, lax_description=TAKE_ALONG_AXIS_DOC) @partial(jit, static_argnames=('axis', 'mode')) def take_along_axis( @@ -5064,7 +5062,7 @@ def clamp_index(i: DimSize, which: str): return start, step, slice_size -@util._wraps(np.blackman) +@util.implements(np.blackman) def blackman(M: int) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.blackman") dtype = dtypes.canonicalize_dtype(float_) @@ -5074,7 +5072,7 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) -@util._wraps(np.bartlett) +@util.implements(np.bartlett) def bartlett(M: int) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") dtype = dtypes.canonicalize_dtype(float_) @@ -5084,7 +5082,7 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) -@util._wraps(np.hamming) +@util.implements(np.hamming) def hamming(M: int) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.hamming") dtype = dtypes.canonicalize_dtype(float_) @@ -5094,7 +5092,7 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) -@util._wraps(np.hanning) +@util.implements(np.hanning) def hanning(M: int) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.hanning") dtype = dtypes.canonicalize_dtype(float_) @@ -5104,7 +5102,7 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) -@util._wraps(np.kaiser) +@util.implements(np.kaiser) def kaiser(M: int, beta: ArrayLike) -> Array: M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") dtype = dtypes.canonicalize_dtype(float_) @@ -5125,7 +5123,7 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) -@util._wraps(np.gcd, module='numpy') +@util.implements(np.gcd, module='numpy') @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: util.check_arraylike("gcd", x1, x2) @@ -5137,7 +5135,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd -@util._wraps(np.lcm, module='numpy') +@util.implements(np.lcm, module='numpy') @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: util.check_arraylike("lcm", x1, x2) @@ -5150,12 +5148,12 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) -@util._wraps(np.extract) +@util.implements(np.extract) def extract(condition: ArrayLike, arr: ArrayLike) -> Array: return compress(ravel(condition), ravel(arr)) -@util._wraps(np.compress, skip_params=['out']) +@util.implements(np.compress, skip_params=['out']) def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, out: None = None) -> Array: util.check_arraylike("compress", condition, a) @@ -5176,7 +5174,7 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, return moveaxis(arr[condition_arr], 0, axis) -@util._wraps(np.cov) +@util.implements(np.cov) @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, bias: bool = False, ddof: int | None = None, @@ -5244,7 +5242,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze() -@util._wraps(np.corrcoef) +@util.implements(np.corrcoef) @partial(jit, static_argnames=('rowvar',)) def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array: util.check_arraylike("corrcoef", x) @@ -5299,7 +5297,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) -@util._wraps(np.searchsorted, skip_params=['sorter'], +@util.implements(np.searchsorted, skip_params=['sorter'], extra_params=_dedent(""" method : str One of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. Controls the method used by the @@ -5335,7 +5333,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', }[method] return impl(asarray(a), asarray(v), side, dtype) # type: ignore -@util._wraps(np.digitize) +@util.implements(np.digitize) @partial(jit, static_argnames=('right',)) def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: util.check_arraylike("digitize", x, bins) @@ -5358,7 +5356,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: See the :func:`jax.lax.switch` documentation for more information. """ -@util._wraps(np.piecewise, lax_description=_PIECEWISE_DOC) +@util.implements(np.piecewise, lax_description=_PIECEWISE_DOC) def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], funclist: list[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: @@ -5400,7 +5398,7 @@ def _tile_to_size(arr: Array, size: int) -> Array: return arr[:size] if arr.size > size else arr -@util._wraps(np.place, lax_description=""" +@util.implements(np.place, lax_description=""" The semantics of :func:`numpy.place` is to modify arrays in-place, which JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.place` adds the ``inplace`` parameter, which must be set to ``False`` by the user as a @@ -5430,7 +5428,7 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape) -@util._wraps(np.put, lax_description=""" +@util.implements(np.put, lax_description=""" The semantics of :func:`numpy.put` is to modify arrays in-place, which JAX cannot do because JAX arrays are immutable. Thus :func:`jax.numpy.put` adds the ``inplace`` parameter, which must be set to ``False`` by the user as a diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 8a66d2befa15..0a68ca208fc6 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -30,7 +30,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions, ufuncs -from jax._src.numpy.util import _wraps, promote_dtypes_inexact, check_arraylike +from jax._src.numpy.util import implements, promote_dtypes_inexact, check_arraylike from jax._src.util import canonicalize_axis from jax._src.typing import ArrayLike, Array @@ -63,7 +63,7 @@ def _H(x: ArrayLike) -> Array: def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 -@_wraps(np.linalg.cholesky) +@implements(np.linalg.cholesky) @jit def cholesky(a: ArrayLike) -> Array: check_arraylike("jnp.linalg.cholesky", a) @@ -86,7 +86,7 @@ def svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False) -> Array | SVDResult: ... -@_wraps(np.linalg.svd) +@implements(np.linalg.svd) @partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian')) def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False) -> Array | SVDResult: @@ -115,7 +115,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=False) -@_wraps(np.linalg.matrix_power) +@implements(np.linalg.matrix_power) @partial(jit, static_argnames=('n',)) def matrix_power(a: ArrayLike, n: int) -> Array: check_arraylike("jnp.linalg.matrix_power", a) @@ -154,7 +154,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: return result -@_wraps(np.linalg.matrix_rank) +@implements(np.linalg.matrix_rank) @jit def matrix_rank(M: ArrayLike, tol: ArrayLike | None = None) -> Array: check_arraylike("jnp.linalg.matrix_rank", M) @@ -211,7 +211,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det -@_wraps( +@implements( np.linalg.slogdet, extra_params=textwrap.dedent(""" method: string, optional @@ -357,7 +357,7 @@ def _det_3x3(a: Array) -> Array: @custom_jvp -@_wraps(np.linalg.det) +@implements(np.linalg.det) @jit def det(a: ArrayLike) -> Array: check_arraylike("jnp.linalg.det", a) @@ -383,7 +383,7 @@ def _det_jvp(primals, tangents): return y, jnp.trace(z, axis1=-1, axis2=-2) -@_wraps(np.linalg.eig, lax_description=""" +@implements(np.linalg.eig, lax_description=""" This differs from :func:`numpy.linalg.eig` in that the return type of :func:`jax.numpy.linalg.eig` is always ``complex64`` for 32-bit input, and ``complex128`` for 64-bit input. @@ -399,7 +399,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: return w, v -@_wraps(np.linalg.eigvals) +@implements(np.linalg.eigvals) @jit def eigvals(a: ArrayLike) -> Array: check_arraylike("jnp.linalg.eigvals", a) @@ -407,7 +407,7 @@ def eigvals(a: ArrayLike) -> Array: compute_right_eigenvectors=False)[0] -@_wraps(np.linalg.eigh) +@implements(np.linalg.eigh) @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: @@ -425,7 +425,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None, return EighResult(w, v) -@_wraps(np.linalg.eigvalsh) +@implements(np.linalg.eigvalsh) @partial(jit, static_argnames=('UPLO',)) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: check_arraylike("jnp.linalg.eigvalsh", a) @@ -434,7 +434,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: @partial(custom_jvp, nondiff_argnums=(1, 2)) -@_wraps(np.linalg.pinv, lax_description=textwrap.dedent("""\ +@implements(np.linalg.pinv, lax_description=textwrap.dedent("""\ It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the default `rcond` is `1e-15`. Here the default is `10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`. @@ -494,7 +494,7 @@ def _pinv_jvp(rcond, hermitian, primals, tangents): return p, p_dot -@_wraps(np.linalg.inv) +@implements(np.linalg.inv) @jit def inv(a: ArrayLike) -> Array: check_arraylike("jnp.linalg.inv", a) @@ -506,7 +506,7 @@ def inv(a: ArrayLike) -> Array: arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2])) -@_wraps(np.linalg.norm) +@implements(np.linalg.norm) @partial(jit, static_argnames=('ord', 'axis', 'keepdims')) def norm(x: ArrayLike, ord: int | str | None = None, axis: None | tuple[int, ...] | int = None, @@ -608,7 +608,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ... -@_wraps(np.linalg.qr) +@implements(np.linalg.qr) @partial(jit, static_argnames=('mode',)) def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: check_arraylike("jnp.linalg.qr", a) @@ -628,7 +628,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: return QRResult(q, r) -@_wraps(np.linalg.solve) +@implements(np.linalg.solve) @jit def solve(a: ArrayLike, b: ArrayLike) -> Array: check_arraylike("jnp.linalg.solve", a, b) @@ -689,7 +689,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, _jit_lstsq = jit(partial(_lstsq, numpy_resid=False)) -@_wraps(np.linalg.lstsq, lax_description=textwrap.dedent("""\ +@implements(np.linalg.lstsq, lax_description=textwrap.dedent("""\ It has two important differences: 1. In `numpy.linalg.lstsq`, the default `rcond` is `-1`, and warns that in the future @@ -710,7 +710,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, return _jit_lstsq(a, b, rcond) -@_wraps(getattr(np.linalg, "cross", None)) +@implements(getattr(np.linalg, "cross", None)) def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): check_arraylike("jnp.linalg.outer", x1, x2) x1, x2 = jnp.asarray(x1), jnp.asarray(x2) @@ -722,7 +722,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): return jnp.cross(x1, x2, axis=axis) -@_wraps(getattr(np.linalg, "outer", None)) +@implements(getattr(np.linalg, "outer", None)) def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("jnp.linalg.outer", x1, x2) x1, x2 = jnp.asarray(x1), jnp.asarray(x2) @@ -731,7 +731,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: return x1[:, None] * x2[None, :] -@_wraps(getattr(np.linalg, "matrix_norm", None)) +@implements(getattr(np.linalg, "matrix_norm", None)) def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array: """ Computes the matrix norm of a matrix (or a stack of matrices) x. @@ -740,7 +740,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1)) -@_wraps(getattr(np.linalg, "matrix_transpose", None)) +@implements(getattr(np.linalg, "matrix_transpose", None)) def matrix_transpose(x: ArrayLike, /) -> Array: """Transposes a matrix (or a stack of matrices) x.""" check_arraylike('jnp.linalg.matrix_transpose', x) @@ -751,7 +751,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) -@_wraps(getattr(np.linalg, "vector_norm", None)) +@implements(getattr(np.linalg, "vector_norm", None)) def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: """Computes the vector norm of a vector (or batch of vectors) x.""" @@ -764,31 +764,31 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa return norm(x, axis=axis, keepdims=keepdims, ord=ord) -@_wraps(getattr(np.linalg, "vecdot", None)) +@implements(getattr(np.linalg, "vecdot", None)) def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array: return jnp.vecdot(x1, x2, axis=axis) -@_wraps(getattr(np.linalg, "matmul", None)) +@implements(getattr(np.linalg, "matmul", None)) def matmul(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike('jnp.linalg.matmul', x1, x2) return jnp.matmul(x1, x2) -@_wraps(getattr(np.linalg, "tensordot", None)) +@implements(getattr(np.linalg, "tensordot", None)) def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2) -> Array: check_arraylike('jnp.linalg.tensordot', x1, x2) return jnp.tensordot(x1, x2, axes=axes) -@_wraps(getattr(np.linalg, "svdvals", None)) +@implements(getattr(np.linalg, "svdvals", None)) def svdvals(x: ArrayLike, /) -> Array: check_arraylike('jnp.linalg.svdvals', x) return svd(x, compute_uv=False, hermitian=False) -@_wraps(getattr(np.linalg, "diagonal", None)) +@implements(getattr(np.linalg, "diagonal", None)) def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: check_arraylike('jnp.linalg.diagonal', x) return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 41603c9dcf1c..dc2ff57c0d28 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -31,7 +31,7 @@ from jax._src.numpy.reductions import all from jax._src.numpy import linalg from jax._src.numpy.util import ( - check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, _wraps) + check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, implements) from jax._src.typing import Array, ArrayLike @@ -57,7 +57,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array: return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) -@_wraps(np.roots, lax_description="""\ +@implements(np.roots, lax_description="""\ Unlike the numpy version of this function, the JAX version returns the roots in a complex array regardless of the values of the roots. Additionally, the jax version of this function adds the ``strip_zeros`` function which must be set to @@ -106,7 +106,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix Also, it works best on rcond <= 10e-3 values. """ -@_wraps(np.polyfit, lax_description=_POLYFIT_DOC) +@implements(np.polyfit, lax_description=_POLYFIT_DOC) @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, full: bool = False, w: Array | None = None, cov: bool = False @@ -187,7 +187,7 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, jax returns an array with a complex dtype in such cases. """ -@_wraps(np.poly, lax_description=_POLY_DOC) +@implements(np.poly, lax_description=_POLY_DOC) @jit def poly(seq_of_zeros: Array) -> Array: check_arraylike('poly', seq_of_zeros) @@ -214,7 +214,7 @@ def poly(seq_of_zeros: Array) -> Array: return a -@_wraps(np.polyval, lax_description="""\ +@implements(np.polyval, lax_description="""\ The ``unroll`` parameter is JAX specific. It does not effect correctness but can have a major impact on performance for evaluating high-order polynomials. The parameter controls the number of unrolled steps with ``lax.scan`` inside the @@ -231,7 +231,7 @@ def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll) return y -@_wraps(np.polyadd) +@implements(np.polyadd) @jit def polyadd(a1: Array, a2: Array) -> Array: check_arraylike("polyadd", a1, a2) @@ -242,7 +242,7 @@ def polyadd(a1: Array, a2: Array) -> Array: return a2.at[-a1.shape[0]:].add(a1) -@_wraps(np.polyint) +@implements(np.polyint) @partial(jit, static_argnames=('m',)) def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") @@ -265,7 +265,7 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: return true_divide(concatenate((p, k_arr)), coeff) -@_wraps(np.polyder) +@implements(np.polyder) @partial(jit, static_argnames=('m',)) def polyder(p: Array, m: int = 1) -> Array: check_arraylike("polyder", p) @@ -288,7 +288,7 @@ def polyder(p: Array, m: int = 1) -> Array: JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True. """ -@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC) +@implements(np.polymul, lax_description=_LEADING_ZEROS_DOC) def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: check_arraylike("polymul", a1, a2) a1_arr, a2_arr = promote_dtypes_inexact(a1, a2) @@ -300,7 +300,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - a2_arr = asarray([0], dtype=a1_arr.dtype) return convolve(a1_arr, a2_arr, mode='full') -@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC) +@implements(np.polydiv, lax_description=_LEADING_ZEROS_DOC) def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: check_arraylike("polydiv", u, v) u_arr, v_arr = promote_dtypes_inexact(u, v) @@ -317,7 +317,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f') return q, u_arr -@_wraps(np.polysub) +@implements(np.polysub) @jit def polysub(a1: Array, a2: Array) -> Array: check_arraylike("polysub", a1, a2) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 28950fbac575..b1e80b952a96 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -31,7 +31,7 @@ from jax._src.numpy import ufuncs from jax._src.numpy.util import ( _broadcast_to, check_arraylike, _complex_elem_type, - promote_dtypes_inexact, promote_dtypes_numeric, _where, _wraps) + promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.util import ( @@ -219,7 +219,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, initial=initial, where_=where, parallel_reduce=lax.psum, promote_integers=promote_integers) -@_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) +@implements(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: @@ -238,7 +238,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) -@_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) +@implements(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, @@ -256,7 +256,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) -@_wraps(np.max, skip_params=['out']) +@implements(np.max, skip_params=['out']) def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -271,7 +271,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) -@_wraps(np.min, skip_params=['out']) +@implements(np.min, skip_params=['out']) def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -284,7 +284,7 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) -@_wraps(np.all, skip_params=['out']) +@implements(np.all, skip_params=['out']) def all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out, @@ -296,7 +296,7 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) -@_wraps(np.any, skip_params=['out']) +@implements(np.any, skip_params=['out']) def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, @@ -316,7 +316,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) return size -@_wraps(np.mean, skip_params=['out']) +@implements(np.mean, skip_params=['out']) def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: @@ -365,7 +365,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, * @overload def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ... -@_wraps(np.average) +@implements(np.average) def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims) @@ -425,7 +425,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, return avg -@_wraps(np.var, skip_params=['out']) +@implements(np.var, skip_params=['out']) def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: @@ -486,7 +486,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy return _upcast_f16(computation_dtype), np.dtype(dtype) -@_wraps(np.std, skip_params=['out']) +@implements(np.std, skip_params=['out']) def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: @@ -506,7 +506,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) -@_wraps(np.ptp, skip_params=['out']) +@implements(np.ptp, skip_params=['out']) def ptp(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: return _ptp(a, _ensure_optional_axes(axis), out, keepdims) @@ -522,7 +522,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, return lax.sub(x, y) -@_wraps(np.count_nonzero) +@implements(np.count_nonzero) @partial(api.jit, static_argnames=('axis', 'keepdims')) def count_nonzero(a: ArrayLike, axis: Axis = None, keepdims: bool = False) -> Array: @@ -546,7 +546,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], else: return out -@_wraps(np.nanmin, skip_params=['out']) +@implements(np.nanmin, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -555,7 +555,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) -@_wraps(np.nanmax, skip_params=['out']) +@implements(np.nanmax, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -564,7 +564,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) -@_wraps(np.nansum, skip_params=['out']) +@implements(np.nansum, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -578,7 +578,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: if nansum.__doc__ is not None: nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n") -@_wraps(np.nanprod, skip_params=['out']) +@implements(np.nanprod, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -588,7 +588,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) -@_wraps(np.nanmean, skip_params=['out']) +@implements(np.nanmean, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, where: ArrayLike | None = None) -> Array: @@ -608,7 +608,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out return td -@_wraps(np.nanvar, skip_params=['out']) +@implements(np.nanvar, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -639,7 +639,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: return lax.convert_element_type(result, dtype) -@_wraps(np.nanstd, skip_params=['out']) +@implements(np.nanstd, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -664,7 +664,7 @@ def __call__(self, a: ArrayLike, axis: Axis = None, def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array], fill_nan: bool = False, fill_value: ArrayLike = 0) -> CumulativeReduction: - @_wraps(np_reduction, skip_params=['out'], + @implements(np_reduction, skip_params=['out'], lax_description=CUML_REDUCTION_LAX_DESCRIPTION) def cumulative_reduction(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -709,7 +709,7 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, fill_nan=True, fill_value=1) # Quantiles -@_wraps(np.quantile, skip_params=['out', 'overwrite_input']) +@implements(np.quantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -725,7 +725,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No "Use 'method=' instead.", DeprecationWarning) return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False) -@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) +@implements(np.nanquantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -862,7 +862,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = result.reshape(keepdim) return lax.convert_element_type(result, a.dtype) -@_wraps(np.percentile, skip_params=['out', 'overwrite_input']) +@implements(np.percentile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, @@ -874,7 +874,7 @@ def percentile(a: ArrayLike, q: ArrayLike, return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, interpolation=interpolation, method=method, keepdims=keepdims) -@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input']) +@implements(np.nanpercentile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, @@ -887,7 +887,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, interpolation=interpolation, method=method, keepdims=keepdims) -@_wraps(np.median, skip_params=['out', 'overwrite_input']) +@implements(np.median, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, @@ -896,7 +896,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') -@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input']) +@implements(np.nanmedian, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 4caf608396b2..15dc52cda55c 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -34,7 +34,7 @@ sort, where, zeros) from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan -from jax._src.numpy.util import check_arraylike, _wraps +from jax._src.numpy.util import check_arraylike, implements from jax._src.util import canonicalize_axis from jax._src.typing import Array, ArrayLike @@ -61,7 +61,7 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array: else: return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1) -@_wraps(np.setdiff1d, +@implements(np.setdiff1d, lax_description=_dedent(""" Because the size of the output of ``setdiff1d`` is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional ``size`` argument which @@ -98,7 +98,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value) -@_wraps(np.union1d, +@implements(np.union1d, lax_description=_dedent(""" Because the size of the output of ``union1d`` is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional ``size`` argument which @@ -125,7 +125,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, return cast(Array, out) -@_wraps(np.setxor1d, lax_description=""" +@implements(np.setxor1d, lax_description=""" In the JAX version, the input arrays are explicitly flattened regardless of assume_unique value. """) @@ -169,7 +169,7 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo return aux, mask -@_wraps(np.intersect1d) +@implements(np.intersect1d) def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return_indices: bool = False) -> Array | tuple[Array, Array, Array]: check_arraylike("intersect1d", ar1, ar2) @@ -206,7 +206,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return int1d -@_wraps(np.isin, lax_description=""" +@implements(np.isin, lax_description=""" In the JAX version, the `assume_unique` argument is not referenced. """) def isin(element: ArrayLike, test_elements: ArrayLike, @@ -312,7 +312,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo ret += (mask.sum(),) return ret[0] if len(ret) == 1 else ret -@_wraps(np.unique, skip_params=['axis'], +@implements(np.unique, skip_params=['axis'], lax_description=_dedent(""" Because the size of the output of ``unique`` is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional ``size`` argument which @@ -368,7 +368,7 @@ class _UniqueInverseResult(NamedTuple): inverse_indices: Array -@_wraps(getattr(np, "unique_all", None)) +@implements(getattr(np, "unique_all", None)) def unique_all(x: ArrayLike, /) -> _UniqueAllResult: check_arraylike("unique_all", x) values, indices, inverse_indices, counts = unique( @@ -376,21 +376,21 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult: return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) -@_wraps(getattr(np, "unique_counts", None)) +@implements(getattr(np, "unique_counts", None)) def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: check_arraylike("unique_counts", x) values, counts = unique(x, return_counts=True, equal_nan=False) return _UniqueCountsResult(values=values, counts=counts) -@_wraps(getattr(np, "unique_inverse", None)) +@implements(getattr(np, "unique_inverse", None)) def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: check_arraylike("unique_inverse", x) values, inverse_indices = unique(x, return_inverse=True, equal_nan=False) return _UniqueInverseResult(values=values, inverse_indices=inverse_indices) -@_wraps(getattr(np, "unique_values", None)) +@implements(getattr(np, "unique_values", None)) def unique_values(x: ArrayLike, /) -> Array: check_arraylike("unique_values", x) return cast(Array, unique(x, equal_nan=False)) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 7bc2a59b713a..4cf0017b9e90 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -27,7 +27,7 @@ from jax._src.numpy import reductions from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take from jax._src.numpy.reductions import _moveaxis -from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where +from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize from jax._src.util import canonicalize_axis, set_module import numpy as np @@ -131,7 +131,7 @@ def __call__(self, *args: ArrayLike, raise NotImplementedError(f"where argument of {self}") return self._call(*args, **kwargs) - @_wraps(np.ufunc.reduce, module="numpy.ufunc") + @implements(np.ufunc.reduce, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -219,7 +219,7 @@ def body_fun(i, val): result = result.reshape(final_shape) return result - @_wraps(np.ufunc.accumulate, module="numpy.ufunc") + @implements(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -257,7 +257,7 @@ def scan_fun(carry, _): _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @_wraps(np.ufunc.accumulate, module="numpy.ufunc") + @implements(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: @@ -296,7 +296,7 @@ def scan_fun(carry, x): carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) return carry[1] - @_wraps(np.ufunc.reduceat, module="numpy.ufunc") + @implements(np.ufunc.reduceat, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -335,7 +335,7 @@ def loop_body(i, out): out) return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) - @_wraps(np.ufunc.outer, module="numpy.ufunc") + @implements(np.ufunc.outer, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0]) def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array: if self.nin != 2: diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 40032e285f54..dee3d9a12d89 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -34,7 +34,7 @@ from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, - promote_shapes, _where, _wraps, check_no_float0s) + promote_shapes, _where, implements, check_no_float0s) _lax_const = lax._const @@ -68,9 +68,9 @@ def _one_to_one_unop( fn = jit(fn, inline=True) if lax_doc: doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn) + return implements(numpy_fn, lax_description=doc, module='numpy')(fn) else: - return _wraps(numpy_fn, module='numpy')(fn) + return implements(numpy_fn, module='numpy')(fn) def _one_to_one_binop( @@ -87,9 +87,9 @@ def _one_to_one_binop( fn = jit(fn, inline=True) if lax_doc: doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn) + return implements(numpy_fn, lax_description=doc, module='numpy')(fn) else: - return _wraps(numpy_fn, module='numpy')(fn) + return implements(numpy_fn, module='numpy')(fn) def _maybe_bool_binop( @@ -102,9 +102,9 @@ def fn(x1, x2, /): fn = jit(fn, inline=True) if lax_doc: doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() # type: ignore[union-attr] - return _wraps(numpy_fn, lax_description=doc, module='numpy')(fn) + return implements(numpy_fn, lax_description=doc, module='numpy')(fn) else: - return _wraps(numpy_fn, module='numpy')(fn) + return implements(numpy_fn, module='numpy')(fn) def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp: @@ -120,7 +120,7 @@ def fn(x1, x2, /): return lax_fn(x1, x2) fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" fn = jit(fn, inline=True) - return _wraps(numpy_fn, module='numpy')(fn) + return implements(numpy_fn, module='numpy')(fn) @overload def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp) -> UnOp: ... @@ -130,7 +130,7 @@ def _logical_op(np_op: Callable[..., Any], bitwise_op: BinOp) -> BinOp: ... def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: ... def _logical_op(np_op: Callable[..., Any], bitwise_op: UnOp | BinOp) -> UnOp | BinOp: - @_wraps(np_op, update_doc=False, module='numpy') + @implements(np_op, update_doc=False, module='numpy') @partial(jit, inline=True) def op(*args): zero = lambda x: lax.full_like(x, shape=(), fill_value=0) @@ -214,14 +214,14 @@ def _arccosh(x: ArrayLike, /) -> Array: atan2 = _one_to_one_binop(getattr(np, "atan2", np.arctan2), lax.atan2, True) -@_wraps(getattr(np, 'bitwise_count', None), module='numpy') +@implements(getattr(np, 'bitwise_count', None), module='numpy') @jit def bitwise_count(x: ArrayLike, /) -> Array: x, = promote_args_numeric("bitwise_count", x) # Following numpy we take the absolute value and return uint8. return lax.population_count(abs(x)).astype('uint8') -@_wraps(np.right_shift, module='numpy') +@implements(np.right_shift, module='numpy') @partial(jit, inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2) @@ -229,7 +229,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) -@_wraps(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy') +@implements(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy') @partial(jit, inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2) @@ -237,16 +237,16 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) -@_wraps(np.absolute, module='numpy') +@implements(np.absolute, module='numpy') @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: check_arraylike('absolute', x) dt = dtypes.dtype(x) return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) -abs = _wraps(np.abs, module='numpy')(absolute) +abs = implements(np.abs, module='numpy')(absolute) -@_wraps(np.rint, module='numpy') +@implements(np.rint, module='numpy') @jit def rint(x: ArrayLike, /) -> Array: check_arraylike('rint', x) @@ -258,7 +258,7 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) -@_wraps(np.copysign, module='numpy') +@implements(np.copysign, module='numpy') @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_inexact("copysign", x1, x2) @@ -267,7 +267,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) -@_wraps(np.true_divide, module='numpy') +@implements(np.true_divide, module='numpy') @partial(jit, inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_inexact("true_divide", x1, x2) @@ -276,7 +276,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: divide = true_divide -@_wraps(np.floor_divide, module='numpy') +@implements(np.floor_divide, module='numpy') @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_numeric("floor_divide", x1, x2) @@ -301,7 +301,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _float_divmod(x1, x2)[0] -@_wraps(np.divmod, module='numpy') +@implements(np.divmod, module='numpy') @jit def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: x1, x2 = promote_args_numeric("divmod", x1, x2) @@ -323,7 +323,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod -@_wraps(np.power, module='numpy') +@implements(np.power, module='numpy') def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("power", x1, x2) check_no_float0s("power", x1, x2) @@ -393,7 +393,7 @@ def _pow_int_int(x1, x2): @custom_jvp -@_wraps(np.logaddexp, module='numpy') +@implements(np.logaddexp, module='numpy') @jit def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_inexact("logaddexp", x1, x2) @@ -431,7 +431,7 @@ def _logaddexp_jvp(primals, tangents): @custom_jvp -@_wraps(np.logaddexp2, module='numpy') +@implements(np.logaddexp2, module='numpy') @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_inexact("logaddexp2", x1, x2) @@ -459,28 +459,28 @@ def _logaddexp2_jvp(primals, tangents): return primal_out, tangent_out -@_wraps(np.log2, module='numpy') +@implements(np.log2, module='numpy') @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: x, = promote_args_inexact("log2", x) return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) -@_wraps(np.log10, module='numpy') +@implements(np.log10, module='numpy') @partial(jit, inline=True) def log10(x: ArrayLike, /) -> Array: x, = promote_args_inexact("log10", x) return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) -@_wraps(np.exp2, module='numpy') +@implements(np.exp2, module='numpy') @partial(jit, inline=True) def exp2(x: ArrayLike, /) -> Array: x, = promote_args_inexact("exp2", x) return lax.exp2(x) -@_wraps(np.signbit, module='numpy') +@implements(np.signbit, module='numpy') @jit def signbit(x: ArrayLike, /) -> Array: x, = promote_args("signbit", x) @@ -511,7 +511,7 @@ def _normalize_float(x): return lax.bitcast_convert_type(x1, int_type), x2 -@_wraps(np.ldexp, module='numpy') +@implements(np.ldexp, module='numpy') @jit def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("ldexp", x1, x2) @@ -560,7 +560,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) -@_wraps(np.frexp, module='numpy') +@implements(np.frexp, module='numpy') @jit def frexp(x: ArrayLike, /) -> tuple[Array, Array]: check_arraylike("frexp", x) @@ -584,7 +584,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) -@_wraps(np.remainder, module='numpy') +@implements(np.remainder, module='numpy') @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: x1, x2 = promote_args_numeric("remainder", x1, x2) @@ -596,10 +596,10 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) -mod = _wraps(np.mod, module='numpy')(remainder) +mod = implements(np.mod, module='numpy')(remainder) -@_wraps(np.fmod, module='numpy') +@implements(np.fmod, module='numpy') @jit def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("fmod", x1, x2) @@ -608,7 +608,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.rem(*promote_args_numeric("fmod", x1, x2)) -@_wraps(np.square, module='numpy') +@implements(np.square, module='numpy') @partial(jit, inline=True) def square(x: ArrayLike, /) -> Array: check_arraylike("square", x) @@ -616,14 +616,14 @@ def square(x: ArrayLike, /) -> Array: return lax.integer_pow(x, 2) -@_wraps(np.deg2rad, module='numpy') +@implements(np.deg2rad, module='numpy') @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: x, = promote_args_inexact("deg2rad", x) return lax.mul(x, _lax_const(x, np.pi / 180)) -@_wraps(np.rad2deg, module='numpy') +@implements(np.rad2deg, module='numpy') @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: x, = promote_args_inexact("rad2deg", x) @@ -634,7 +634,7 @@ def rad2deg(x: ArrayLike, /) -> Array: radians = deg2rad -@_wraps(np.conjugate, module='numpy') +@implements(np.conjugate, module='numpy') @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: check_arraylike("conjugate", x) @@ -642,20 +642,20 @@ def conjugate(x: ArrayLike, /) -> Array: conj = conjugate -@_wraps(np.imag) +@implements(np.imag) @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: check_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) -@_wraps(np.real) +@implements(np.real) @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: check_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) -@_wraps(np.modf, module='numpy', skip_params=['out']) +@implements(np.modf, module='numpy', skip_params=['out']) @jit def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: check_arraylike("modf", x) @@ -666,7 +666,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: return x - whole, whole -@_wraps(np.isfinite, module='numpy') +@implements(np.isfinite, module='numpy') @partial(jit, inline=True) def isfinite(x: ArrayLike, /) -> Array: check_arraylike("isfinite", x) @@ -679,7 +679,7 @@ def isfinite(x: ArrayLike, /) -> Array: return lax.full_like(x, True, dtype=np.bool_) -@_wraps(np.isinf, module='numpy') +@implements(np.isinf, module='numpy') @jit def isinf(x: ArrayLike, /) -> Array: check_arraylike("isinf", x) @@ -707,24 +707,24 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) -isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])( +isposinf: UnOp = implements(np.isposinf, skip_params=['out'])( lambda x, /, out=None: _isposneginf(np.inf, x, out) ) -isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])( +isneginf: UnOp = implements(np.isneginf, skip_params=['out'])( lambda x, /, out=None: _isposneginf(-np.inf, x, out) ) -@_wraps(np.isnan, module='numpy') +@implements(np.isnan, module='numpy') @partial(jit, inline=True) def isnan(x: ArrayLike, /) -> Array: check_arraylike("isnan", x) return lax.ne(x, x) -@_wraps(np.heaviside, module='numpy') +@implements(np.heaviside, module='numpy') @jit def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("heaviside", x1, x2) @@ -734,7 +734,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) -@_wraps(np.hypot, module='numpy') +@implements(np.hypot, module='numpy') @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: check_arraylike("hypot", x1, x2) @@ -745,7 +745,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1))))) -@_wraps(np.reciprocal, module='numpy') +@implements(np.reciprocal, module='numpy') @partial(jit, inline=True) def reciprocal(x: ArrayLike, /) -> Array: check_arraylike("reciprocal", x) @@ -753,7 +753,7 @@ def reciprocal(x: ArrayLike, /) -> Array: return lax.integer_pow(x, -1) -@_wraps(np.sinc, update_doc=False) +@implements(np.sinc, update_doc=False) @jit def sinc(x: ArrayLike, /) -> Array: check_arraylike("sinc", x) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 3061de85ad1e..1dfd94d8a6df 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -112,13 +112,13 @@ def _parse_parameters(body: str) -> dict[str, str]: def _parse_extra_params(extra_params: str) -> dict[str, str]: - """Parse the extra parameters passed to _wraps()""" + """Parse the extra parameters passed to implements()""" parameters = _parameter_break.split(extra_params.strip('\n')) return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} -def _wraps( - fun: Callable[..., Any] | None, +def implements( + original_fun: Callable[..., Any] | None, update_doc: bool = True, lax_description: str = "", sections: Sequence[str] = ('Parameters', 'Returns', 'References'), @@ -126,46 +126,46 @@ def _wraps( extra_params: str | None = None, module: str | None = None, ) -> Callable[[_T], _T]: - """Specialized version of functools.wraps for wrapping numpy functions. + """Decorator for JAX functions which implement a specified NumPy function. - This produces a wrapped function with a modified docstring. In particular, if - `update_doc` is True, parameters listed in the wrapped function that are not - supported by the decorated function will be removed from the docstring. For - this reason, it is important that parameter names match those in the original - numpy function. + This mainly contains logic to copy and modify the docstring of the original + function. In particular, if `update_doc` is True, parameters listed in the + original function that are not supported by the decorated function will + be removed from the docstring. For this reason, it is important that parameter + names match those in the original numpy function. Args: - fun: The function being wrapped + original_fun: The original function being implemented update_doc: whether to transform the numpy docstring to remove references of parameters that are supported by the numpy version but not the JAX version. If False, include the numpy docstring verbatim. lax_description: a string description that will be added to the beginning of the docstring. sections: a list of sections to include in the docstring. The default is - ["Parameters", "returns", "References"] + ["Parameters", "Returns", "References"] skip_params: a list of strings containing names of parameters accepted by the function that should be skipped in the parameter list. extra_params: an optional string containing additional parameter descriptions. When ``update_doc=True``, these will be added to the list of parameter descriptions in the updated doc. - module: an optional string specifying the module from which the wrapped function + module: an optional string specifying the module from which the original function is imported. This is useful for objects such as ufuncs, where the module cannot - be determined from the wrapped function itself. + be determined from the original function itself. """ - def wrap(op): - op.__np_wrapped__ = fun - # Allows this pattern: @wraps(getattr(np, 'new_function', None)) - if fun is None: + def decorator(wrapped_fun): + wrapped_fun.__np_wrapped__ = original_fun + # Allows this pattern: @implements(getattr(np, 'new_function', None)) + if original_fun is None: if lax_description: - op.__doc__ = lax_description - return op - docstr = getattr(fun, "__doc__", None) - name = getattr(fun, "__name__", getattr(op, "__name__", str(op))) + wrapped_fun.__doc__ = lax_description + return wrapped_fun + docstr = getattr(original_fun, "__doc__", None) + name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))) try: - mod = module or fun.__module__ + mod = module or original_fun.__module__ except AttributeError: if config.enable_checks.value: - raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.") + raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().") else: name = f"{mod}.{name}" if docstr: @@ -173,7 +173,7 @@ def wrap(op): parsed = _parse_numpydoc(docstr) if update_doc and 'Parameters' in parsed.sections: - code = getattr(getattr(op, "__wrapped__", op), "__code__", None) + code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None) # Remove unrecognized parameter descriptions. parameters = _parse_parameters(parsed.sections['Parameters']) if extra_params: @@ -211,18 +211,18 @@ def wrap(op): except: if config.enable_checks.value: raise - docstr = fun.__doc__ + docstr = original_fun.__doc__ - op.__doc__ = docstr + wrapped_fun.__doc__ = docstr for attr in ['__name__', '__qualname__']: try: - value = getattr(fun, attr) + value = getattr(original_fun, attr) except AttributeError: pass else: - setattr(op, attr, value) - return op - return wrap + setattr(wrapped_fun, attr, value) + return wrapped_fun + return decorator _dtype = partial(dtypes.dtype, canonicalize=True) diff --git a/jax/_src/scipy/cluster/vq.py b/jax/_src/scipy/cluster/vq.py index 53ba6adde06f..8a071ee89f57 100644 --- a/jax/_src/scipy/cluster/vq.py +++ b/jax/_src/scipy/cluster/vq.py @@ -19,7 +19,7 @@ from jax import vmap import jax.numpy as jnp -from jax._src.numpy.util import _wraps, check_arraylike, promote_dtypes_inexact +from jax._src.numpy.util import implements, check_arraylike, promote_dtypes_inexact _no_chkfinite_doc = textwrap.dedent(""" @@ -28,7 +28,7 @@ """) -@_wraps(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',)) +@implements(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',)) def vq(obs, code_book, check_finite=True): check_arraylike("scipy.cluster.vq.vq", obs, code_book) if obs.ndim != code_book.ndim: diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index 6c2d3cd5a53a..3f4fa082604f 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -22,7 +22,7 @@ from jax import lax import jax.numpy as jnp from jax._src.util import canonicalize_axis -from jax._src.numpy.util import _wraps, promote_dtypes_complex +from jax._src.numpy.util import implements, promote_dtypes_complex from jax._src.typing import Array def _W4(N: int, k: Array) -> Array: @@ -42,7 +42,7 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array: # Implementation based on # John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980) -@_wraps(osp_fft.dct) +@implements(osp_fft.dct) def dct(x: Array, type: int = 2, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: if type != 2: @@ -81,7 +81,7 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array: return out -@_wraps(osp_fft.dctn) +@implements(osp_fft.dctn) def dctn(x: Array, type: int = 2, s: Sequence[int] | None=None, axes: Sequence[int] | None = None, @@ -109,7 +109,7 @@ def dctn(x: Array, type: int = 2, return x -@_wraps(osp_fft.dct) +@implements(osp_fft.dct) def idct(x: Array, type: int = 2, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: if type != 2: @@ -139,7 +139,7 @@ def idct(x: Array, type: int = 2, n: int | None = None, out = _dct_deinterleave(x.real, axis) return out -@_wraps(osp_fft.idctn) +@implements(osp_fft.idctn) def idctn(x: Array, type: int = 2, s: Sequence[int] | None=None, axes: Sequence[int] | None = None, diff --git a/jax/_src/scipy/integrate.py b/jax/_src/scipy/integrate.py index d9aebbdae9b1..97cfe0ff1d0e 100644 --- a/jax/_src/scipy/integrate.py +++ b/jax/_src/scipy/integrate.py @@ -23,7 +23,7 @@ from jax._src.typing import Array, ArrayLike import jax.numpy as jnp -@util._wraps(scipy.integrate.trapezoid) +@util.implements(scipy.integrate.trapezoid) @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 9aab983e0cdf..f87255aab91a 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -29,7 +29,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src.lax import qdwh from jax._src.numpy.util import ( - check_arraylike, _wraps, promote_dtypes, promote_dtypes_inexact, + check_arraylike, implements, promote_dtypes, promote_dtypes_inexact, promote_dtypes_complex) from jax._src.typing import Array, ArrayLike @@ -46,14 +46,14 @@ def _cholesky(a: ArrayLike, lower: bool) -> Array: l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False) return l if lower else jnp.conj(l.mT) -@_wraps(scipy.linalg.cholesky, +@implements(scipy.linalg.cholesky, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array: del overwrite_a, check_finite # Unused return _cholesky(a, lower) -@_wraps(scipy.linalg.cho_factor, +@implements(scipy.linalg.cho_factor, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, bool]: @@ -70,7 +70,7 @@ def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array: transpose_a=lower, conjugate_a=lower) return b -@_wraps(scipy.linalg.cho_solve, update_doc=False, +@implements(scipy.linalg.cho_solve, update_doc=False, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite')) def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike, overwrite_b: bool = False, check_finite: bool = True) -> Array: @@ -112,7 +112,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ... -@_wraps(scipy.linalg.svd, +@implements(scipy.linalg.svd, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver')) def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, @@ -120,7 +120,7 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, del overwrite_a, check_finite, lapack_driver # unused return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv) -@_wraps(scipy.linalg.det, +@implements(scipy.linalg.det, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: del overwrite_a, check_finite # unused @@ -182,7 +182,7 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ... -@_wraps(scipy.linalg.eigh, +@implements(scipy.linalg.eigh, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'turbo', 'check_finite')) def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, @@ -198,21 +198,21 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]: a = a.astype(dtypes.to_complex_dtype(a.dtype)) return lax_linalg.schur(a) -@_wraps(scipy.linalg.schur) +@implements(scipy.linalg.schur) def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: if output not in ('real', 'complex'): raise ValueError( f"Expected 'output' to be either 'real' or 'complex', got {output=}.") return _schur(a, output) -@_wraps(scipy.linalg.inv, +@implements(scipy.linalg.inv, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: del overwrite_a, check_finite # unused return jnp.linalg.inv(a) -@_wraps(scipy.linalg.lu_factor, +@implements(scipy.linalg.lu_factor, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) @partial(jit, static_argnames=('overwrite_a', 'check_finite')) def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: @@ -222,7 +222,7 @@ def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True return lu, pivots -@_wraps(scipy.linalg.lu_solve, +@implements(scipy.linalg.lu_solve, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite')) @partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite')) def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0, @@ -269,7 +269,7 @@ def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... -@_wraps(scipy.linalg.lu, update_doc=False, +@implements(scipy.linalg.lu, update_doc=False, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) @partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite')) def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, @@ -320,7 +320,7 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Lit def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ... -@_wraps(scipy.linalg.qr, +@implements(scipy.linalg.qr, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork')) def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: @@ -352,7 +352,7 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) -@_wraps(scipy.linalg.solve, +@implements(scipy.linalg.solve, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite')) def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, @@ -391,7 +391,7 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str, else: return out -@_wraps(scipy.linalg.solve_triangular, +@implements(scipy.linalg.solve_triangular, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'debug', 'check_finite')) def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False, unit_diagonal: bool = False, overwrite_b: bool = False, @@ -414,7 +414,7 @@ def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bo - c=1.97 for float32 or complex64 """) -@_wraps(scipy.linalg.expm, lax_description=_expm_description) +@implements(scipy.linalg.expm, lax_description=_expm_description) @partial(jit, static_argnames=('upper_triangular', 'max_squarings')) def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array: A, = promote_dtypes_inexact(A) @@ -572,7 +572,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: ... -@_wraps(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description) +@implements(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description) @partial(jit, static_argnames=('method', 'compute_expm')) def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: @@ -597,7 +597,7 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, return expm_frechet_AE -@_wraps(scipy.linalg.block_diag) +@implements(scipy.linalg.block_diag) @jit def block_diag(*arrs: ArrayLike) -> Array: if len(arrs) == 0: @@ -619,7 +619,7 @@ def block_diag(*arrs: ArrayLike) -> Array: return acc -@_wraps(scipy.linalg.eigh_tridiagonal) +@implements(scipy.linalg.eigh_tridiagonal) @partial(jit, static_argnames=("eigvals_only", "select", "select_range")) def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False, select: str = 'a', select_range: tuple[float, float] | None = None, @@ -901,7 +901,7 @@ def _sqrtm(A: ArrayLike) -> Array: return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST), jnp.conj(Z.T), precision=lax.Precision.HIGHEST) -@_wraps(scipy.linalg.sqrtm, +@implements(scipy.linalg.sqrtm, lax_description=""" This differs from ``scipy.linalg.sqrtm`` in that the return type of ``jax.scipy.linalg.sqrtm`` is always ``complex64`` for 32-bit input, @@ -918,7 +918,7 @@ def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: raise NotImplementedError("Blocked version is not implemented yet.") return _sqrtm(A) -@_wraps(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc) +@implements(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc) @partial(jit, static_argnames=('check_finite',)) def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]: del check_finite # unused @@ -987,7 +987,7 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = Fals def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ... -@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc) +@implements(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc) @partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a')) def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array | tuple[Array, Array]: @@ -1010,7 +1010,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, else: return h -@_wraps(scipy.linalg.toeplitz) +@implements(scipy.linalg.toeplitz) def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: if r is None: check_arraylike("toeplitz", c) diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index 475c0ee62270..1b01af5f4670 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -25,7 +25,7 @@ from jax._src import util from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax._src.typing import ArrayLike, Array from jax._src.util import safe_zip as zip @@ -127,7 +127,7 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike], return result.astype(input_arr.dtype) -@_wraps(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\ +@implements(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\ Only nearest neighbor (``order=0``), linear interpolation (``order=1``) and modes ``'constant'``, ``'nearest'``, ``'wrap'`` ``'mirror'`` and ``'reflect'`` are currently supported. Note that interpolation near boundaries differs from the scipy function, diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 48b7e6e300be..284509be29eb 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -34,13 +34,13 @@ from jax._src.lax.lax import PrecisionLike from jax._src.numpy import linalg from jax._src.numpy.util import ( - check_arraylike, _wraps, promote_dtypes_inexact, promote_dtypes_complex) + check_arraylike, implements, promote_dtypes_inexact, promote_dtypes_complex) from jax._src.third_party.scipy import signal_helper from jax._src.typing import Array, ArrayLike from jax._src.util import canonicalize_axis, tuple_delete, tuple_insert -@_wraps(osp_signal.fftconvolve) +@implements(osp_signal.fftconvolve) def fftconvolve(in1: ArrayLike, in2: ArrayLike, mode: str = "full", axes: Sequence[int] | None = None) -> Array: check_arraylike('fftconvolve', in1, in2) @@ -133,7 +133,7 @@ def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike) return result[0, 0] -@_wraps(osp_signal.convolve) +@implements(osp_signal.convolve) def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: if method == 'fft': @@ -144,7 +144,7 @@ def convolve(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', raise ValueError(f"Got {method=}; expected 'auto', 'fft', or 'direct'.") -@_wraps(osp_signal.convolve2d) +@implements(osp_signal.convolve2d) def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill', fillvalue: float = 0, precision: PrecisionLike = None) -> Array: if boundary != 'fill' or fillvalue != 0: @@ -154,13 +154,13 @@ def convolve2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill return _convolve_nd(in1, in2, mode, precision=precision) -@_wraps(osp_signal.correlate) +@implements(osp_signal.correlate) def correlate(in1: Array, in2: Array, mode: str = 'full', method: str = 'auto', precision: PrecisionLike = None) -> Array: return convolve(in1, jnp.flip(in2.conj()), mode, precision=precision, method=method) -@_wraps(osp_signal.correlate2d) +@implements(osp_signal.correlate2d) def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fill', fillvalue: float = 0, precision: PrecisionLike = None) -> Array: if boundary != 'fill' or fillvalue != 0: @@ -191,7 +191,7 @@ def correlate2d(in1: Array, in2: Array, mode: str = 'full', boundary: str = 'fil return result -@_wraps(osp_signal.detrend) +@implements(osp_signal.detrend) def detrend(data: ArrayLike, axis: int = -1, type: str = 'linear', bp: int = 0, overwrite_data: None = None) -> Array: if overwrite_data is not None: @@ -499,7 +499,7 @@ def detrend_func(d): return freqs, time, result -@_wraps(osp_signal.stft) +@implements(osp_signal.stft) def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256, noverlap: int | None = None, nfft: int | None = None, detrend: bool = False, return_onesided: bool = True, boundary: str | None = 'zeros', @@ -518,7 +518,7 @@ def stft(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int = 256 function as `csd(x, None)`.""" -@_wraps(osp_signal.csd, lax_description=_csd_description) +@implements(osp_signal.csd, lax_description=_csd_description) def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, detrend: str = 'constant', @@ -551,7 +551,7 @@ def csd(x: Array, y: ArrayLike | None, fs: ArrayLike = 1.0, window: str = 'hann' return freqs, Pxy -@_wraps(osp_signal.welch) +@implements(osp_signal.welch) def welch(x: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, detrend: str = 'constant', @@ -613,7 +613,7 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: return x.reshape(tuple(batch_shape) + (-1,)) -@_wraps(osp_signal.istft) +@implements(osp_signal.istft) def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', nperseg: int | None = None, noverlap: int | None = None, nfft: int | None = None, input_onesided: bool = True, diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py index c80e5c6c60ac..3f96511a116b 100644 --- a/jax/_src/scipy/spatial/transform.py +++ b/jax/_src/scipy/spatial/transform.py @@ -22,10 +22,10 @@ import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements -@_wraps(scipy.spatial.transform.Rotation) +@implements(scipy.spatial.transform.Rotation) class Rotation(typing.NamedTuple): """Rotation in 3 dimensions.""" @@ -169,7 +169,7 @@ def single(self) -> bool: return self.quat.ndim == 1 -@_wraps(scipy.spatial.transform.Slerp) +@implements(scipy.spatial.transform.Slerp) class Slerp(typing.NamedTuple): """Spherical Linear Interpolation of Rotations.""" diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index fd4566b98fa2..d4aced143016 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -32,19 +32,19 @@ from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax._src.ops import special as ops_special from jax._src.third_party.scipy.betaln import betaln as _betaln_impl from jax._src.typing import Array, ArrayLike -@_wraps(osp_special.gammaln, module='scipy.special') +@implements(osp_special.gammaln, module='scipy.special') def gammaln(x: ArrayLike) -> Array: x, = promote_args_inexact("gammaln", x) return lax.lgamma(x) -@_wraps(osp_special.gamma, module='scipy.special', lax_description="""\ +@implements(osp_special.gamma, module='scipy.special', lax_description="""\ The JAX version only accepts real-valued inputs.""") def gamma(x: ArrayLike) -> Array: x, = promote_args_inexact("gamma", x) @@ -53,14 +53,14 @@ def gamma(x: ArrayLike) -> Array: sign = jnp.where((x > 0) | (x == floor_x), 1.0, (-1.0) ** floor_x) return sign * lax.exp(lax.lgamma(x)) -betaln = _wraps( +betaln = implements( osp_special.betaln, module='scipy.special', update_doc=False )(_betaln_impl) -@_wraps(osp_special.factorial, module='scipy.special') +@implements(osp_special.factorial, module='scipy.special') def factorial(n: ArrayLike, exact: bool = False) -> Array: if exact: raise NotImplementedError("factorial with exact=True") @@ -68,58 +68,58 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array: return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) -@_wraps(osp_special.beta, module='scipy.special') +@implements(osp_special.beta, module='scipy.special') def beta(x: ArrayLike, y: ArrayLike) -> Array: x, y = promote_args_inexact("beta", x, y) return lax.exp(betaln(x, y)) -@_wraps(osp_special.betainc, module='scipy.special') +@implements(osp_special.betainc, module='scipy.special') def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: a, b, x = promote_args_inexact("betainc", a, b, x) return lax.betainc(a, b, x) -@_wraps(osp_special.digamma, module='scipy.special', lax_description="""\ +@implements(osp_special.digamma, module='scipy.special', lax_description="""\ The JAX version only accepts real-valued inputs.""") def digamma(x: ArrayLike) -> Array: x, = promote_args_inexact("digamma", x) return lax.digamma(x) -@_wraps(osp_special.gammainc, module='scipy.special', update_doc=False) +@implements(osp_special.gammainc, module='scipy.special', update_doc=False) def gammainc(a: ArrayLike, x: ArrayLike) -> Array: a, x = promote_args_inexact("gammainc", a, x) return lax.igamma(a, x) -@_wraps(osp_special.gammaincc, module='scipy.special', update_doc=False) +@implements(osp_special.gammaincc, module='scipy.special', update_doc=False) def gammaincc(a: ArrayLike, x: ArrayLike) -> Array: a, x = promote_args_inexact("gammaincc", a, x) return lax.igammac(a, x) -@_wraps(osp_special.erf, module='scipy.special', skip_params=["out"], +@implements(osp_special.erf, module='scipy.special', skip_params=["out"], lax_description="Note that the JAX version does not support complex inputs.") def erf(x: ArrayLike) -> Array: x, = promote_args_inexact("erf", x) return lax.erf(x) -@_wraps(osp_special.erfc, module='scipy.special', update_doc=False) +@implements(osp_special.erfc, module='scipy.special', update_doc=False) def erfc(x: ArrayLike) -> Array: x, = promote_args_inexact("erfc", x) return lax.erfc(x) -@_wraps(osp_special.erfinv, module='scipy.special') +@implements(osp_special.erfinv, module='scipy.special') def erfinv(x: ArrayLike) -> Array: x, = promote_args_inexact("erfinv", x) return lax.erf_inv(x) @custom_derivatives.custom_jvp -@_wraps(osp_special.logit, module='scipy.special', update_doc=False) +@implements(osp_special.logit, module='scipy.special', update_doc=False) def logit(x: ArrayLike) -> Array: x, = promote_args_inexact("logit", x) return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) @@ -127,17 +127,17 @@ def logit(x: ArrayLike) -> Array: lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x)))) -@_wraps(osp_special.expit, module='scipy.special', update_doc=False) +@implements(osp_special.expit, module='scipy.special', update_doc=False) def expit(x: ArrayLike) -> Array: x, = promote_args_inexact("expit", x) return lax.logistic(x) -logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp) +logsumexp = implements(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp) @custom_derivatives.custom_jvp -@_wraps(osp_special.xlogy, module='scipy.special') +@implements(osp_special.xlogy, module='scipy.special') def xlogy(x: ArrayLike, y: ArrayLike) -> Array: # Note: xlogy(0, 0) should return 0 according to the function documentation. x, y = promote_args_inexact("xlogy", x, y) @@ -153,7 +153,7 @@ def _xlogy_jvp(primals, tangents): @custom_derivatives.custom_jvp -@_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False) +@implements(osp_special.xlog1py, module='scipy.special', update_doc=False) def xlog1py(x: ArrayLike, y: ArrayLike) -> Array: # Note: xlog1py(0, -1) should return 0 according to the function documentation. x, y = promote_args_inexact("xlog1py", x, y) @@ -179,14 +179,14 @@ def _xlogx_jvp(primals, tangents): _xlogx.defjvp(_xlogx_jvp) -@_wraps(osp_special.entr, module='scipy.special') +@implements(osp_special.entr, module='scipy.special') def entr(x: ArrayLike) -> Array: x, = promote_args_inexact("entr", x) return lax.select(lax.lt(x, _lax_const(x, 0)), lax.full_like(x, -np.inf), lax.neg(_xlogx(x))) -@_wraps(osp_special.multigammaln, update_doc=False) +@implements(osp_special.multigammaln, update_doc=False) def multigammaln(a: ArrayLike, d: ArrayLike) -> Array: d = core.concrete_or_error(int, d, "d argument of multigammaln") a, d_ = promote_args_inexact("multigammaln", a, d) @@ -201,7 +201,7 @@ def multigammaln(a: ArrayLike, d: ArrayLike) -> Array: return res + constant -@_wraps(osp_special.kl_div, module="scipy.special") +@implements(osp_special.kl_div, module="scipy.special") def kl_div( p: ArrayLike, q: ArrayLike, @@ -227,7 +227,7 @@ def kl_div( return result -@_wraps(osp_special.rel_entr, module="scipy.special") +@implements(osp_special.rel_entr, module="scipy.special") def rel_entr( p: ArrayLike, q: ArrayLike, @@ -268,7 +268,7 @@ def rel_entr( @custom_derivatives.custom_jvp -@_wraps(osp_special.zeta, module='scipy.special') +@implements(osp_special.zeta, module='scipy.special') def zeta(x: ArrayLike, q: ArrayLike | None = None) -> Array: if q is None: raise NotImplementedError( @@ -311,7 +311,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array: zeta.defjvp(partial(jvp, _zeta_series_expansion)) # type: ignore[arg-type] -@_wraps(osp_special.polygamma, module='scipy.special', update_doc=False) +@implements(osp_special.polygamma, module='scipy.special', update_doc=False) def polygamma(n: ArrayLike, x: ArrayLike) -> Array: assert jnp.issubdtype(lax.dtype(n), jnp.integer) n_arr, x_arr = promote_args_inexact("polygamma", n, x) @@ -725,22 +725,22 @@ def _norm_logpdf(x): log_normalizer = _lax_const(x, _norm_logpdf_constant) return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer) -@_wraps(osp_special.i0e, module='scipy.special') +@implements(osp_special.i0e, module='scipy.special') def i0e(x: ArrayLike) -> Array: x, = promote_args_inexact("i0e", x) return lax.bessel_i0e(x) -@_wraps(osp_special.i0, module='scipy.special') +@implements(osp_special.i0, module='scipy.special') def i0(x: ArrayLike) -> Array: x, = promote_args_inexact("i0", x) return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i0e(x)) -@_wraps(osp_special.i1e, module='scipy.special') +@implements(osp_special.i1e, module='scipy.special') def i1e(x: ArrayLike) -> Array: x, = promote_args_inexact("i1e", x) return lax.bessel_i1e(x) -@_wraps(osp_special.i1, module='scipy.special') +@implements(osp_special.i1, module='scipy.special') def i1(x: ArrayLike) -> Array: x, = promote_args_inexact("i1", x) return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x)) @@ -1459,7 +1459,7 @@ def _expi_neg(x: Array) -> Array: @custom_derivatives.custom_jvp @jit -@_wraps(osp_special.expi, module='scipy.special') +@implements(osp_special.expi, module='scipy.special') def expi(x: ArrayLike) -> Array: x_arr, = promote_args_inexact("expi", x) return jnp.piecewise(x_arr, [x_arr < 0], [_expi_neg, _expi_pos]) @@ -1577,7 +1577,7 @@ def _expn3(n: int, x: Array) -> Array: @partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,)) @jnp.vectorize -@_wraps(osp_special.expn, module='scipy.special') +@implements(osp_special.expn, module='scipy.special') @jit def expn(n: ArrayLike, x: ArrayLike) -> Array: n, x = promote_args_inexact("expn", n, x) @@ -1615,7 +1615,7 @@ def expn_jvp(n, primals, tangents): ) -@_wraps(osp_special.exp1, module="scipy.special") +@implements(osp_special.exp1, module="scipy.special") def exp1(x: ArrayLike, module='scipy.special') -> Array: x, = promote_args_inexact("exp1", x) # Casting because custom_jvp generic does not work correctly with mypy. @@ -1716,7 +1716,7 @@ def spence(x: Array) -> Array: return _spence(x) -@_wraps(osp_special.bernoulli, module='scipy.special') +@implements(osp_special.bernoulli, module='scipy.special') def bernoulli(n: int) -> Array: # Generate Bernoulli numbers using the Chowla and Hartung algorithm. n = core.concrete_or_error(operator.index, n, "Argument n of bernoulli") @@ -1734,7 +1734,7 @@ def bernoulli(n: int) -> Array: @custom_derivatives.custom_jvp -@_wraps(osp_special.poch, module='scipy.special', lax_description="""\ +@implements(osp_special.poch, module='scipy.special', lax_description="""\ The JAX version only accepts positive and real inputs.""") def poch(z: ArrayLike, m: ArrayLike) -> Array: # Factorial definition when m is close to an integer, otherwise gamma definition. @@ -1883,7 +1883,7 @@ def _hyp1f1_x_derivative(a, b, x): @custom_derivatives.custom_jvp @jit @jnp.vectorize -@_wraps(osp_special.hyp1f1, module='scipy.special', lax_description="""\ +@implements(osp_special.hyp1f1, module='scipy.special', lax_description="""\ The JAX version only accepts positive and real inputs. Values of a, b and x leading to high values of 1F1 might be erroneous, considering enabling double precision. Convention for a = b = 0 is 1, unlike in scipy's implementation.""") diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index a3b285094c43..7325b8cfbe83 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -23,7 +23,7 @@ from jax import jit from jax._src import dtypes from jax._src.api import vmap -from jax._src.numpy.util import check_arraylike, _wraps, promote_args_inexact +from jax._src.numpy.util import check_arraylike, implements, promote_args_inexact from jax._src.typing import ArrayLike, Array from jax._src.util import canonicalize_axis @@ -31,7 +31,7 @@ ModeResult = namedtuple('ModeResult', ('mode', 'count')) -@_wraps(scipy.stats.mode, lax_description="""\ +@implements(scipy.stats.mode, lax_description="""\ Currently the only supported nan_policy is 'propagate' """) @partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) @@ -90,7 +90,7 @@ def invert_permutation(i: Array) -> Array: """Helper function that inverts a permutation array.""" return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype)) -@_wraps(scipy.stats.rankdata, lax_description="""\ +@implements(scipy.stats.rankdata, lax_description="""\ Currently the only supported nan_policy is 'propagate' """) @partial(jit, static_argnames=["method", "axis", "nan_policy"]) @@ -148,7 +148,7 @@ def rankdata( return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_)) raise ValueError(f"unknown method '{method}'") -@_wraps(scipy.stats.sem, lax_description="""\ +@implements(scipy.stats.sem, lax_description="""\ Currently the only supported nan_policies are 'propagate' and 'omit' """) @partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py index 70e8250b2004..94d0a6735210 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -18,12 +18,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import xlogy, xlog1py -@_wraps(osp_stats.bernoulli.logpmf, update_doc=False) +@implements(osp_stats.bernoulli.logpmf, update_doc=False) def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc) zero = _lax_const(k, 0) @@ -33,11 +33,11 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)), -jnp.inf, log_probs) -@_wraps(osp_stats.bernoulli.pmf, update_doc=False) +@implements(osp_stats.bernoulli.pmf, update_doc=False) def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: return jnp.exp(logpmf(k, p, loc)) -@_wraps(osp_stats.bernoulli.cdf, update_doc=False) +@implements(osp_stats.bernoulli.cdf, update_doc=False) def cdf(k: ArrayLike, p: ArrayLike) -> Array: k, p = promote_args_inexact('bernoulli.cdf', k, p) zero, one = _lax_const(k, 0), _lax_const(k, 1) @@ -50,7 +50,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array: vals = [jnp.nan, zero, one - p, one] return jnp.select(conds, vals) -@_wraps(osp_stats.bernoulli.ppf, update_doc=False) +@implements(osp_stats.bernoulli.ppf, update_doc=False) def ppf(q: ArrayLike, p: ArrayLike) -> Array: q, p = promote_args_inexact('bernoulli.ppf', q, p) zero, one = _lax_const(q, 0), _lax_const(q, 1) diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 4e796ca33bfc..2b30ed7b824a 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -17,12 +17,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import betaln, betainc, xlogy, xlog1py -@_wraps(osp_stats.beta.logpdf, update_doc=False) +@implements(osp_stats.beta.logpdf, update_doc=False) def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale) @@ -36,13 +36,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, lax.lt(x, loc)), -jnp.inf, log_probs) -@_wraps(osp_stats.beta.pdf, update_doc=False) +@implements(osp_stats.beta.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, b, loc, scale)) -@_wraps(osp_stats.beta.cdf, update_doc=False) +@implements(osp_stats.beta.cdf, update_doc=False) def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale) @@ -57,13 +57,13 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, ) -@_wraps(osp_stats.beta.logcdf, update_doc=False) +@implements(osp_stats.beta.logcdf, update_doc=False) def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(cdf(x, a, b, loc, scale)) -@_wraps(osp_stats.beta.sf, update_doc=False) +@implements(osp_stats.beta.sf, update_doc=False) def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale) @@ -78,7 +78,7 @@ def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, ) -@_wraps(osp_stats.beta.logsf, update_doc=False) +@implements(osp_stats.beta.logsf, update_doc=False) def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(sf(x, a, b, loc, scale)) diff --git a/jax/_src/scipy/stats/betabinom.py b/jax/_src/scipy/stats/betabinom.py index ad73b3375d1b..1c7b1f9bd71c 100644 --- a/jax/_src/scipy/stats/betabinom.py +++ b/jax/_src/scipy/stats/betabinom.py @@ -18,12 +18,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.scipy.special import betaln from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.betabinom.logpmf, update_doc=False) +@implements(osp_stats.betabinom.logpmf, update_doc=False) def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0) -> Array: """JAX implementation of scipy.stats.betabinom.logpmf.""" @@ -40,7 +40,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, return jnp.where(n_a_b_cond, jnp.nan, log_probs) -@_wraps(osp_stats.betabinom.pmf, update_doc=False) +@implements(osp_stats.betabinom.pmf, update_doc=False) def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike, loc: ArrayLike = 0) -> Array: """JAX implementation of scipy.stats.betabinom.pmf.""" diff --git a/jax/_src/scipy/stats/binom.py b/jax/_src/scipy/stats/binom.py index 869eab91c5b5..878fdc744510 100644 --- a/jax/_src/scipy/stats/binom.py +++ b/jax/_src/scipy/stats/binom.py @@ -17,12 +17,12 @@ from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.scipy.special import gammaln, xlogy, xlog1py from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.nbinom.logpmf, update_doc=False) +@implements(osp_stats.nbinom.logpmf, update_doc=False) def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: """JAX implementation of scipy.stats.binom.logpmf.""" k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc) @@ -36,7 +36,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -jnp.inf) -@_wraps(osp_stats.nbinom.pmf, update_doc=False) +@implements(osp_stats.nbinom.pmf, update_doc=False) def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: """JAX implementation of scipy.stats.binom.pmf.""" return lax.exp(logpmf(k, n, p, loc)) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 38565ff65c7a..177cc0dcd197 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -18,12 +18,12 @@ from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax.numpy import arctan from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.cauchy.logpdf, update_doc=False) +@implements(osp_stats.cauchy.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale) pi = _lax_const(x, np.pi) @@ -32,13 +32,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) -@_wraps(osp_stats.cauchy.pdf, update_doc=False) +@implements(osp_stats.cauchy.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) -@_wraps(osp_stats.cauchy.cdf, update_doc=False) +@implements(osp_stats.cauchy.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale) pi = _lax_const(x, np.pi) @@ -46,24 +46,24 @@ def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x))) -@_wraps(osp_stats.cauchy.logcdf, update_doc=False) +@implements(osp_stats.cauchy.logcdf, update_doc=False) def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(cdf(x, loc, scale)) -@_wraps(osp_stats.cauchy.sf, update_doc=False) +@implements(osp_stats.cauchy.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale) return cdf(-x, -loc, scale) -@_wraps(osp_stats.cauchy.logsf, update_doc=False) +@implements(osp_stats.cauchy.logsf, update_doc=False) def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale) return logcdf(-x, -loc, scale) -@_wraps(osp_stats.cauchy.isf, update_doc=False) +@implements(osp_stats.cauchy.isf, update_doc=False) def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale) pi = _lax_const(q, np.pi) @@ -72,7 +72,7 @@ def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.add(lax.mul(unscaled, scale), loc) -@_wraps(osp_stats.cauchy.ppf, update_doc=False) +@implements(osp_stats.cauchy.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale) pi = _lax_const(q, np.pi) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index 76decb29e722..8058d49d7c9e 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -18,12 +18,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import gammainc, gammaincc -@_wraps(osp_stats.chi2.logpdf, update_doc=False) +@implements(osp_stats.chi2.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale) one = _lax_const(x, 1) @@ -38,12 +38,12 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel) return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) -@_wraps(osp_stats.chi2.pdf, update_doc=False) +@implements(osp_stats.chi2.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, df, loc, scale)) -@_wraps(osp_stats.chi2.cdf, update_doc=False) +@implements(osp_stats.chi2.cdf, update_doc=False) def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale) two = _lax_const(scale, 2) @@ -60,12 +60,12 @@ def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) - ) -@_wraps(osp_stats.chi2.logcdf, update_doc=False) +@implements(osp_stats.chi2.logcdf, update_doc=False) def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(cdf(x, df, loc, scale)) -@_wraps(osp_stats.chi2.sf, update_doc=False) +@implements(osp_stats.chi2.sf, update_doc=False) def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale) two = _lax_const(scale, 2) @@ -82,6 +82,6 @@ def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ) -@_wraps(osp_stats.chi2.logsf, update_doc=False) +@implements(osp_stats.chi2.logsf, update_doc=False) def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(sf(x, df, loc, scale)) diff --git a/jax/_src/scipy/stats/dirichlet.py b/jax/_src/scipy/stats/dirichlet.py index 89c4a8715e2d..f8b5705f8118 100644 --- a/jax/_src/scipy/stats/dirichlet.py +++ b/jax/_src/scipy/stats/dirichlet.py @@ -18,7 +18,7 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import promote_dtypes_inexact, _wraps +from jax._src.numpy.util import promote_dtypes_inexact, implements from jax.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike @@ -28,7 +28,7 @@ def _is_simplex(x: Array) -> Array: return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6) -@_wraps(osp_stats.dirichlet.logpdf, update_doc=False) +@implements(osp_stats.dirichlet.logpdf, update_doc=False) def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array: return _logpdf(*promote_dtypes_inexact(x, alpha)) @@ -52,6 +52,6 @@ def _logpdf(x: Array, alpha: Array) -> Array: return jnp.where(_is_simplex(x), log_probs, -jnp.inf) -@_wraps(osp_stats.dirichlet.pdf, update_doc=False) +@implements(osp_stats.dirichlet.pdf, update_doc=False) def pdf(x: ArrayLike, alpha: ArrayLike) -> Array: return lax.exp(logpdf(x, alpha)) diff --git a/jax/_src/scipy/stats/expon.py b/jax/_src/scipy/stats/expon.py index ed6feb996547..0b2ff0ea4058 100644 --- a/jax/_src/scipy/stats/expon.py +++ b/jax/_src/scipy/stats/expon.py @@ -16,11 +16,11 @@ from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.expon.logpdf, update_doc=False) +@implements(osp_stats.expon.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale) log_scale = lax.log(scale) @@ -28,6 +28,6 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: log_probs = lax.neg(lax.add(linear_term, log_scale)) return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) -@_wraps(osp_stats.expon.pdf, update_doc=False) +@implements(osp_stats.expon.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index 8a5e70215bc6..d63429021566 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -17,12 +17,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc -@_wraps(osp_stats.gamma.logpdf, update_doc=False) +@implements(osp_stats.gamma.logpdf, update_doc=False) def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale) one = _lax_const(x, 1) @@ -32,12 +32,12 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_probs = lax.sub(log_linear_term, shape_terms) return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) -@_wraps(osp_stats.gamma.pdf, update_doc=False) +@implements(osp_stats.gamma.pdf, update_doc=False) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, a, loc, scale)) -@_wraps(osp_stats.gamma.cdf, update_doc=False) +@implements(osp_stats.gamma.cdf, update_doc=False) def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale) return gammainc( @@ -50,17 +50,17 @@ def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> ) -@_wraps(osp_stats.gamma.logcdf, update_doc=False) +@implements(osp_stats.gamma.logcdf, update_doc=False) def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(cdf(x, a, loc, scale)) -@_wraps(osp_stats.gamma.sf, update_doc=False) +@implements(osp_stats.gamma.sf, update_doc=False) def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) return gammaincc(a, lax.div(lax.sub(x, loc), scale)) -@_wraps(osp_stats.gamma.logsf, update_doc=False) +@implements(osp_stats.gamma.logsf, update_doc=False) def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(sf(x, a, loc, scale)) diff --git a/jax/_src/scipy/stats/gennorm.py b/jax/_src/scipy/stats/gennorm.py index 348b4e7d98e1..4b89a25289bf 100644 --- a/jax/_src/scipy/stats/gennorm.py +++ b/jax/_src/scipy/stats/gennorm.py @@ -14,20 +14,20 @@ import scipy.stats as osp_stats from jax import lax -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.gennorm.logpdf, update_doc=False) +@implements(osp_stats.gennorm.logpdf, update_doc=False) def logpdf(x: ArrayLike, p: ArrayLike) -> Array: x, p = promote_args_inexact("gennorm.logpdf", x, p) return lax.log(.5 * p) - lax.lgamma(1/p) - lax.abs(x)**p -@_wraps(osp_stats.gennorm.cdf, update_doc=False) +@implements(osp_stats.gennorm.cdf, update_doc=False) def cdf(x: ArrayLike, p: ArrayLike) -> Array: x, p = promote_args_inexact("gennorm.cdf", x, p) return .5 * (1 + lax.sign(x) * lax.igamma(1/p, lax.abs(x)**p)) -@_wraps(osp_stats.gennorm.pdf, update_doc=False) +@implements(osp_stats.gennorm.pdf, update_doc=False) def pdf(x: ArrayLike, p: ArrayLike) -> Array: return lax.exp(logpdf(x, p)) diff --git a/jax/_src/scipy/stats/geom.py b/jax/_src/scipy/stats/geom.py index 25b3bbac939c..6b59cb31db0a 100644 --- a/jax/_src/scipy/stats/geom.py +++ b/jax/_src/scipy/stats/geom.py @@ -17,12 +17,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax.scipy.special import xlog1py from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.geom.logpmf, update_doc=False) +@implements(osp_stats.geom.logpmf, update_doc=False) def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc) zero = _lax_const(k, 0) @@ -32,6 +32,6 @@ def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: return jnp.where(lax.le(x, zero), -jnp.inf, log_probs) -@_wraps(osp_stats.geom.pmf, update_doc=False) +@implements(osp_stats.geom.pmf, update_doc=False) def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: return jnp.exp(logpmf(k, p, loc)) diff --git a/jax/_src/scipy/stats/kde.py b/jax/_src/scipy/stats/kde.py index c778d33fef05..516935525ce1 100644 --- a/jax/_src/scipy/stats/kde.py +++ b/jax/_src/scipy/stats/kde.py @@ -21,12 +21,12 @@ import jax.numpy as jnp from jax import jit, lax, random, vmap -from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps +from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements from jax._src.tree_util import register_pytree_node_class from jax.scipy import linalg, special -@_wraps(osp_stats.gaussian_kde, update_doc=False) +@implements(osp_stats.gaussian_kde, update_doc=False) @register_pytree_node_class @dataclass(frozen=True, init=False) class gaussian_kde: @@ -113,7 +113,7 @@ def d(self): def n(self): return self.dataset.shape[1] - @_wraps(osp_stats.gaussian_kde.evaluate, update_doc=False) + @implements(osp_stats.gaussian_kde.evaluate, update_doc=False) def evaluate(self, points): check_arraylike("evaluate", points) points = self._reshape_points(points) @@ -121,11 +121,11 @@ def evaluate(self, points): points.T, self.inv_cov) return result[:, 0] - @_wraps(osp_stats.gaussian_kde.__call__, update_doc=False) + @implements(osp_stats.gaussian_kde.__call__, update_doc=False) def __call__(self, points): return self.evaluate(points) - @_wraps(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False) + @implements(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False) def integrate_gaussian(self, mean, cov): mean = jnp.atleast_1d(jnp.squeeze(mean)) cov = jnp.atleast_2d(cov) @@ -141,7 +141,7 @@ def integrate_gaussian(self, mean, cov): return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights, mean) - @_wraps(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False) + @implements(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False) def integrate_box_1d(self, low, high): if self.d != 1: raise ValueError("integrate_box_1d() only handles 1D pdfs") @@ -153,7 +153,7 @@ def integrate_box_1d(self, low, high): high = jnp.squeeze((high - self.dataset) / sigma) return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low))) - @_wraps(osp_stats.gaussian_kde.integrate_kde, update_doc=False) + @implements(osp_stats.gaussian_kde.integrate_kde, update_doc=False) def integrate_kde(self, other): if other.d != self.d: raise ValueError("KDEs are not the same dimensionality") @@ -189,11 +189,11 @@ def resample(self, key, shape=()): dtype=self.dataset.dtype).T return self.dataset[:, ind] + eps - @_wraps(osp_stats.gaussian_kde.pdf, update_doc=False) + @implements(osp_stats.gaussian_kde.pdf, update_doc=False) def pdf(self, x): return self.evaluate(x) - @_wraps(osp_stats.gaussian_kde.logpdf, update_doc=False) + @implements(osp_stats.gaussian_kde.logpdf, update_doc=False) def logpdf(self, x): check_arraylike("logpdf", x) x = self._reshape_points(x) diff --git a/jax/_src/scipy/stats/laplace.py b/jax/_src/scipy/stats/laplace.py index a5c7f3b2ccd5..acd39046dcff 100644 --- a/jax/_src/scipy/stats/laplace.py +++ b/jax/_src/scipy/stats/laplace.py @@ -16,11 +16,11 @@ from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.laplace.logpdf, update_doc=False) +@implements(osp_stats.laplace.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale) two = _lax_const(x, 2) @@ -28,12 +28,12 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale)))) -@_wraps(osp_stats.laplace.pdf, update_doc=False) +@implements(osp_stats.laplace.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) -@_wraps(osp_stats.laplace.cdf, update_doc=False) +@implements(osp_stats.laplace.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale) half = _lax_const(x, 0.5) diff --git a/jax/_src/scipy/stats/logistic.py b/jax/_src/scipy/stats/logistic.py index 67901e83fb7d..b9f7b37b3a00 100644 --- a/jax/_src/scipy/stats/logistic.py +++ b/jax/_src/scipy/stats/logistic.py @@ -18,11 +18,11 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.logistic.logpdf, update_doc=False) +@implements(osp_stats.logistic.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale) x = lax.div(lax.sub(x, loc), scale) @@ -31,30 +31,30 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale)) -@_wraps(osp_stats.logistic.pdf, update_doc=False) +@implements(osp_stats.logistic.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) -@_wraps(osp_stats.logistic.ppf, update_doc=False) +@implements(osp_stats.logistic.ppf, update_doc=False) def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale) return lax.add(lax.mul(logit(x), scale), loc) -@_wraps(osp_stats.logistic.sf, update_doc=False) +@implements(osp_stats.logistic.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale) return expit(lax.neg(lax.div(lax.sub(x, loc), scale))) -@_wraps(osp_stats.logistic.isf, update_doc=False) +@implements(osp_stats.logistic.isf, update_doc=False) def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale) return lax.add(lax.mul(lax.neg(logit(x)), scale), loc) -@_wraps(osp_stats.logistic.cdf, update_doc=False) +@implements(osp_stats.logistic.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale) return expit(lax.div(lax.sub(x, loc), scale)) diff --git a/jax/_src/scipy/stats/multinomial.py b/jax/_src/scipy/stats/multinomial.py index 8da08a6138dd..150573ad7db0 100644 --- a/jax/_src/scipy/stats/multinomial.py +++ b/jax/_src/scipy/stats/multinomial.py @@ -16,12 +16,12 @@ import scipy.stats as osp_stats from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import _wraps, promote_args_inexact, promote_args_numeric +from jax._src.numpy.util import implements, promote_args_inexact, promote_args_numeric from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.multinomial.logpmf, update_doc=False) +@implements(osp_stats.multinomial.logpmf, update_doc=False) def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: """JAX implementation of scipy.stats.multinomial.logpmf.""" p, = promote_args_inexact("multinomial.logpmf", p) @@ -34,7 +34,7 @@ def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: return jnp.where(jnp.equal(jnp.sum(x), n), logprobs, -jnp.inf) -@_wraps(osp_stats.multinomial.pmf, update_doc=False) +@implements(osp_stats.multinomial.pmf, update_doc=False) def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array: """JAX implementation of scipy.stats.multinomial.pmf.""" return lax.exp(logpmf(x, n, p)) diff --git a/jax/_src/scipy/stats/multivariate_normal.py b/jax/_src/scipy/stats/multivariate_normal.py index c0d539bceda3..e833da0a49c2 100644 --- a/jax/_src/scipy/stats/multivariate_normal.py +++ b/jax/_src/scipy/stats/multivariate_normal.py @@ -19,11 +19,11 @@ from jax import lax from jax import numpy as jnp -from jax._src.numpy.util import _wraps, promote_dtypes_inexact +from jax._src.numpy.util import implements, promote_dtypes_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description=""" +@implements(osp_stats.multivariate_normal.logpdf, update_doc=False, lax_description=""" In the JAX version, the `allow_singular` argument is not implemented. """) def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> ArrayLike: @@ -50,6 +50,6 @@ def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi) - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1)) -@_wraps(osp_stats.multivariate_normal.pdf, update_doc=False) +@implements(osp_stats.multivariate_normal.pdf, update_doc=False) def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array: return lax.exp(logpdf(x, mean, cov)) diff --git a/jax/_src/scipy/stats/nbinom.py b/jax/_src/scipy/stats/nbinom.py index cc8d1a039daa..6af74442da10 100644 --- a/jax/_src/scipy/stats/nbinom.py +++ b/jax/_src/scipy/stats/nbinom.py @@ -18,12 +18,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.scipy.special import gammaln, xlogy from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.nbinom.logpmf, update_doc=False) +@implements(osp_stats.nbinom.logpmf, update_doc=False) def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: """JAX implementation of scipy.stats.nbinom.logpmf.""" k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc) @@ -37,7 +37,7 @@ def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Arra return jnp.where(lax.lt(k, loc), -jnp.inf, log_probs) -@_wraps(osp_stats.nbinom.pmf, update_doc=False) +@implements(osp_stats.nbinom.pmf, update_doc=False) def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array: """JAX implementation of scipy.stats.nbinom.pmf.""" return lax.exp(logpmf(k, n, p, loc)) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index 4f913cc14f40..1258e8905d89 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -20,12 +20,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy import special -@_wraps(osp_stats.norm.logpdf, update_doc=False) +@implements(osp_stats.norm.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale) scale_sqrd = lax.square(scale) @@ -34,41 +34,41 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.div(lax.add(log_normalizer, quadratic), _lax_const(x, -2)) -@_wraps(osp_stats.norm.pdf, update_doc=False) +@implements(osp_stats.norm.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) -@_wraps(osp_stats.norm.cdf, update_doc=False) +@implements(osp_stats.norm.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale) return special.ndtr(lax.div(lax.sub(x, loc), scale)) -@_wraps(osp_stats.norm.logcdf, update_doc=False) +@implements(osp_stats.norm.logcdf, update_doc=False) def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale) # Cast required because custom_jvp return type is broken. return cast(Array, special.log_ndtr(lax.div(lax.sub(x, loc), scale))) -@_wraps(osp_stats.norm.ppf, update_doc=False) +@implements(osp_stats.norm.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.asarray(special.ndtri(q) * scale + loc, float) -@_wraps(osp_stats.norm.logsf, update_doc=False) +@implements(osp_stats.norm.logsf, update_doc=False) def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale) return logcdf(-x, -loc, scale) -@_wraps(osp_stats.norm.sf, update_doc=False) +@implements(osp_stats.norm.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale) return cdf(-x, -loc, scale) -@_wraps(osp_stats.norm.isf, update_doc=False) +@implements(osp_stats.norm.isf, update_doc=False) def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return ppf(lax.sub(_lax_const(q, 1), q), loc, scale) diff --git a/jax/_src/scipy/stats/pareto.py b/jax/_src/scipy/stats/pareto.py index 20ed059ab440..0600fba29857 100644 --- a/jax/_src/scipy/stats/pareto.py +++ b/jax/_src/scipy/stats/pareto.py @@ -18,11 +18,11 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.pareto.logpdf, update_doc=False) +@implements(osp_stats.pareto.logpdf, update_doc=False) def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale) one = _lax_const(x, 1) @@ -31,6 +31,6 @@ def logpdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x)))) return jnp.where(lax.lt(x, lax.add(loc, scale)), -jnp.inf, log_probs) -@_wraps(osp_stats.pareto.pdf, update_doc=False) +@implements(osp_stats.pareto.pdf, update_doc=False) def pdf(x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, b, loc, scale)) diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py index 473383feccfe..3d1862031be0 100644 --- a/jax/_src/scipy/stats/poisson.py +++ b/jax/_src/scipy/stats/poisson.py @@ -18,12 +18,12 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike from jax.scipy.special import xlogy, gammaln, gammaincc -@_wraps(osp_stats.poisson.logpmf, update_doc=False) +@implements(osp_stats.poisson.logpmf, update_doc=False) def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc) zero = _lax_const(k, 0) @@ -31,11 +31,11 @@ def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: log_probs = xlogy(x, mu) - gammaln(x + 1) - mu return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs) -@_wraps(osp_stats.poisson.pmf, update_doc=False) +@implements(osp_stats.poisson.pmf, update_doc=False) def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: return jnp.exp(logpmf(k, mu, loc)) -@_wraps(osp_stats.poisson.cdf, update_doc=False) +@implements(osp_stats.poisson.cdf, update_doc=False) def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc) zero = _lax_const(k, 0) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index 5a54f2bf5578..742a2e16297c 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -18,11 +18,11 @@ from jax import lax from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.t.logpdf, update_doc=False) +@implements(osp_stats.t.logpdf, update_doc=False) def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale) two = _lax_const(x, 2) @@ -37,6 +37,6 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic)))) -@_wraps(osp_stats.t.pdf, update_doc=False) +@implements(osp_stats.t.pdf, update_doc=False) def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, df, loc, scale)) diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py index e4c48271de60..beadd682da21 100644 --- a/jax/_src/scipy/stats/truncnorm.py +++ b/jax/_src/scipy/stats/truncnorm.py @@ -17,7 +17,7 @@ from jax import lax import jax.numpy as jnp -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.scipy.stats import norm from jax._src.scipy.special import logsumexp, log_ndtr, ndtr @@ -69,7 +69,7 @@ def mass_case_central(a, b): return out -@_wraps(osp_stats.truncnorm.logpdf, update_doc=False) +@implements(osp_stats.truncnorm.logpdf, update_doc=False) def logpdf(x, a, b, loc=0, scale=1): x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale) val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b)) @@ -80,23 +80,23 @@ def logpdf(x, a, b, loc=0, scale=1): return val -@_wraps(osp_stats.truncnorm.pdf, update_doc=False) +@implements(osp_stats.truncnorm.pdf, update_doc=False) def pdf(x, a, b, loc=0, scale=1): return lax.exp(logpdf(x, a, b, loc, scale)) -@_wraps(osp_stats.truncnorm.logsf, update_doc=False) +@implements(osp_stats.truncnorm.logsf, update_doc=False) def logsf(x, a, b, loc=0, scale=1): x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale) return logcdf(-x, -b, -a, -loc, scale) -@_wraps(osp_stats.truncnorm.sf, update_doc=False) +@implements(osp_stats.truncnorm.sf, update_doc=False) def sf(x, a, b, loc=0, scale=1): return lax.exp(logsf(x, a, b, loc, scale)) -@_wraps(osp_stats.truncnorm.logcdf, update_doc=False) +@implements(osp_stats.truncnorm.logcdf, update_doc=False) def logcdf(x, a, b, loc=0, scale=1): x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale) x, a, b = jnp.broadcast_arrays(x, a, b) @@ -113,6 +113,6 @@ def logcdf(x, a, b, loc=0, scale=1): return logcdf -@_wraps(osp_stats.truncnorm.cdf, update_doc=False) +@implements(osp_stats.truncnorm.cdf, update_doc=False) def cdf(x, a, b, loc=0, scale=1): return lax.exp(logcdf(x, a, b, loc, scale)) diff --git a/jax/_src/scipy/stats/uniform.py b/jax/_src/scipy/stats/uniform.py index 7fae1408c309..ba186cc6ca78 100644 --- a/jax/_src/scipy/stats/uniform.py +++ b/jax/_src/scipy/stats/uniform.py @@ -19,10 +19,10 @@ from jax import numpy as jnp from jax.numpy import where, inf, logical_or from jax._src.typing import Array, ArrayLike -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact -@_wraps(osp_stats.uniform.logpdf, update_doc=False) +@implements(osp_stats.uniform.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale) log_probs = lax.neg(lax.log(scale)) @@ -30,11 +30,11 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: lax.lt(x, loc)), -inf, log_probs) -@_wraps(osp_stats.uniform.pdf, update_doc=False) +@implements(osp_stats.uniform.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) -@_wraps(osp_stats.uniform.cdf, update_doc=False) +@implements(osp_stats.uniform.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale) zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype) @@ -43,7 +43,7 @@ def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return jnp.select(conds, vals) -@_wraps(osp_stats.uniform.ppf, update_doc=False) +@implements(osp_stats.uniform.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale) return where( diff --git a/jax/_src/scipy/stats/vonmises.py b/jax/_src/scipy/stats/vonmises.py index 05c61147911d..b32799c37d3d 100644 --- a/jax/_src/scipy/stats/vonmises.py +++ b/jax/_src/scipy/stats/vonmises.py @@ -17,15 +17,15 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.vonmises.logpdf, update_doc=False) +@implements(osp_stats.vonmises.logpdf, update_doc=False) def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array: x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa) zero = _lax_const(kappa, 0) return jnp.where(lax.gt(kappa, zero), kappa * (jnp.cos(x) - 1) - jnp.log(2 * jnp.pi * lax.bessel_i0e(kappa)), jnp.nan) -@_wraps(osp_stats.vonmises.pdf, update_doc=False) +@implements(osp_stats.vonmises.pdf, update_doc=False) def pdf(x: ArrayLike, kappa: ArrayLike) -> Array: return lax.exp(logpdf(x, kappa)) diff --git a/jax/_src/scipy/stats/wrapcauchy.py b/jax/_src/scipy/stats/wrapcauchy.py index 6f45b5decaf0..f05b4e8606ee 100644 --- a/jax/_src/scipy/stats/wrapcauchy.py +++ b/jax/_src/scipy/stats/wrapcauchy.py @@ -17,11 +17,11 @@ from jax import lax import jax.numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import _wraps, promote_args_inexact +from jax._src.numpy.util import implements, promote_args_inexact from jax._src.typing import Array, ArrayLike -@_wraps(osp_stats.wrapcauchy.logpdf, update_doc=False) +@implements(osp_stats.wrapcauchy.logpdf, update_doc=False) def logpdf(x: ArrayLike, c: ArrayLike) -> Array: x, c = promote_args_inexact('wrapcauchy.logpdf', x, c) return jnp.where( @@ -34,6 +34,6 @@ def logpdf(x: ArrayLike, c: ArrayLike) -> Array: jnp.nan, ) -@_wraps(osp_stats.wrapcauchy.pdf, update_doc=False) +@implements(osp_stats.wrapcauchy.pdf, update_doc=False) def pdf(x: ArrayLike, c: ArrayLike) -> Array: return lax.exp(logpdf(x, c)) diff --git a/jax/_src/third_party/numpy/linalg.py b/jax/_src/third_party/numpy/linalg.py index 5d700089b8f0..7c8ffe9d5276 100644 --- a/jax/_src/third_party/numpy/linalg.py +++ b/jax/_src/third_party/numpy/linalg.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import jax.numpy.linalg as la -from jax._src.numpy.util import check_arraylike, _wraps +from jax._src.numpy.util import check_arraylike, implements def _isEmpty2d(arr): @@ -39,7 +39,7 @@ def _assert2d(*arrays): 'Array must be two-dimensional') -@_wraps(np.linalg.cond) +@implements(np.linalg.cond) def cond(x, p=None): check_arraylike('jnp.linalg.cond', x) _assertNoEmpty2d(x) @@ -62,7 +62,7 @@ def cond(x, p=None): return r -@_wraps(np.linalg.tensorinv) +@implements(np.linalg.tensorinv) def tensorinv(a, ind=2): check_arraylike('jnp.linalg.tensorinv', a) a = jnp.asarray(a) @@ -79,7 +79,7 @@ def tensorinv(a, ind=2): return ia.reshape(*invshape) -@_wraps(np.linalg.tensorsolve) +@implements(np.linalg.tensorsolve) def tensorsolve(a, b, axes=None): check_arraylike('jnp.linalg.tensorsolve', a, b) a = jnp.asarray(a) @@ -108,7 +108,7 @@ def tensorsolve(a, b, axes=None): return res -@_wraps(np.linalg.multi_dot) +@implements(np.linalg.multi_dot) def multi_dot(arrays, *, precision=None): check_arraylike('jnp.linalg.multi_dot', *arrays) n = len(arrays) diff --git a/jax/_src/third_party/scipy/interpolate.py b/jax/_src/third_party/scipy/interpolate.py index 00ef8b6a9324..0634eb4fd6a1 100644 --- a/jax/_src/third_party/scipy/interpolate.py +++ b/jax/_src/third_party/scipy/interpolate.py @@ -4,7 +4,7 @@ from jax.numpy import (asarray, broadcast_arrays, can_cast, empty, nan, searchsorted, where, zeros) from jax._src.tree_util import register_pytree_node -from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, _wraps +from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements def _ndim_coords_from_arrays(points, ndim=None): @@ -31,7 +31,7 @@ def _ndim_coords_from_arrays(points, ndim=None): return points -@_wraps( +@implements( osp_interpolate.RegularGridInterpolator, lax_description=""" In the JAX version, `bounds_error` defaults to and must always be `False` since no @@ -76,7 +76,7 @@ def __init__(self, self.grid = tuple(asarray(p) for p in points) self.values = values - @_wraps(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False) + @implements(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False) def __call__(self, xi, method=None): method = self.method if method is None else method if method not in ("linear", "nearest"): diff --git a/jax/_src/third_party/scipy/linalg.py b/jax/_src/third_party/scipy/linalg.py index b30f87a570fa..80f0656572c4 100644 --- a/jax/_src/third_party/scipy/linalg.py +++ b/jax/_src/third_party/scipy/linalg.py @@ -7,7 +7,7 @@ from jax import jit, lax import jax.numpy as jnp from jax._src.numpy.linalg import norm -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax._src.scipy.linalg import rsf2csf, schur from jax._src.typing import ArrayLike, Array @@ -51,7 +51,7 @@ def _inner_loop(i, p_F_minden): will be printed if the error in the array output is estimated to be large. """ -@_wraps(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION) +@implements(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION) def funm(A: ArrayLike, func: Callable[[Array], Array], disp: bool = True) -> Array | tuple[Array, Array]: A_arr = jnp.asarray(A) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f2ada4fbc228..048f4fbf56bb 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -52,7 +52,7 @@ from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal from jax._src.lib import xla_extension_version -from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps +from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements from jax._src.util import safe_zip, NumpyComplexWarning config.parse_flags_with_absl() @@ -5861,7 +5861,7 @@ def wrapped(x, out=None): if jit: wrapped = jax.jit(wrapped) - wrapped = _wraps(orig, skip_params=['out'])(wrapped) + wrapped = implements(orig, skip_params=['out'])(wrapped) doc = wrapped.__doc__ self.assertStartsWith(doc, "Example Docstring")