Skip to content

Commit

Permalink
Merge pull request #19507 from jakevdp:wraps-implements
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 601505827
  • Loading branch information
jax authors committed Jan 25, 2024
2 parents a6f2630 + 43a9faa commit 45daced
Show file tree
Hide file tree
Showing 47 changed files with 569 additions and 571 deletions.
38 changes: 19 additions & 19 deletions jax/_src/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.numpy.util import check_arraylike, implements
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike
Expand Down Expand Up @@ -105,28 +105,28 @@ def _fft_core(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
return transformed


@_wraps(np.fft.fftn)
@implements(np.fft.fftn)
def fftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('fftn', xla_client.FftType.FFT, a, s, axes, norm)


@_wraps(np.fft.ifftn)
@implements(np.fft.ifftn)
def ifftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm)


@_wraps(np.fft.rfftn)
@implements(np.fft.rfftn)
def rfftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm)


@_wraps(np.fft.irfftn)
@implements(np.fft.irfftn)
def irfftn(a: ArrayLike, s: Shape | None = None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
Expand All @@ -150,31 +150,31 @@ def _fft_core_1d(func_name: str, fft_type: xla_client.FftType,
return _fft_core(func_name, fft_type, a, s, axes, norm)


@_wraps(np.fft.fft)
@implements(np.fft.fft)
def fft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('fft', xla_client.FftType.FFT, a, n=n, axis=axis,
norm=norm)

@_wraps(np.fft.ifft)
@implements(np.fft.ifft)
def ifft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('ifft', xla_client.FftType.IFFT, a, n=n, axis=axis,
norm=norm)

@_wraps(np.fft.rfft)
@implements(np.fft.rfft)
def rfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('rfft', xla_client.FftType.RFFT, a, n=n, axis=axis,
norm=norm)

@_wraps(np.fft.irfft)
@implements(np.fft.irfft)
def irfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
return _fft_core_1d('irfft', xla_client.FftType.IRFFT, a, n=n, axis=axis,
norm=norm)

@_wraps(np.fft.hfft)
@implements(np.fft.hfft)
def hfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
conj_a = ufuncs.conj(a)
Expand All @@ -183,7 +183,7 @@ def hfft(a: ArrayLike, n: int | None = None,
return _fft_core_1d('hfft', xla_client.FftType.IRFFT, conj_a, n=n, axis=axis,
norm=norm) * nn

@_wraps(np.fft.ihfft)
@implements(np.fft.ihfft)
def ihfft(a: ArrayLike, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
_axis_check_1d('ihfft', axis)
Expand All @@ -206,32 +206,32 @@ def _fft_core_2d(func_name: str, fft_type: xla_client.FftType, a: ArrayLike,
return _fft_core(func_name, fft_type, a, s, axes, norm)


@_wraps(np.fft.fft2)
@implements(np.fft.fft2)
def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
norm=norm)

@_wraps(np.fft.ifft2)
@implements(np.fft.ifft2)
def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes,
norm=norm)

@_wraps(np.fft.rfft2)
@implements(np.fft.rfft2)
def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes,
norm=norm)

@_wraps(np.fft.irfft2)
@implements(np.fft.irfft2)
def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm: str | None = None) -> Array:
return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes,
norm=norm)


@_wraps(np.fft.fftfreq, extra_params="""
@implements(np.fft.fftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
Expand Down Expand Up @@ -266,7 +266,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
return k / jnp.array(d * n, dtype=dtype)


@_wraps(np.fft.rfftfreq, extra_params="""
@implements(np.fft.rfftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
Expand All @@ -292,7 +292,7 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
return k / jnp.array(d * n, dtype=dtype)


@_wraps(np.fft.fftshift)
@implements(np.fft.fftshift)
def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("fftshift", x)
x = jnp.asarray(x)
Expand All @@ -308,7 +308,7 @@ def fftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
return jnp.roll(x, shift, axes)


@_wraps(np.fft.ifftshift)
@implements(np.fft.ifftshift)
def ifftshift(x: ArrayLike, axes: None | int | Sequence[int] = None) -> Array:
check_arraylike("ifftshift", x)
x = jnp.asarray(x)
Expand Down
Loading

0 comments on commit 45daced

Please sign in to comment.