diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 6247ea7..ae9eb3f 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -21,13 +21,16 @@ jobs: unit_tests: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v3 - - name: Set up Python 3.10 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/interpax/_fourier.py b/interpax/_fourier.py index 7a6d2ac..a2a7b46 100644 --- a/interpax/_fourier.py +++ b/interpax/_fourier.py @@ -1,11 +1,12 @@ from functools import partial +import jax import jax.numpy as jnp from jax import jit @partial(jit, static_argnames="n") -def fft_interp1d(f, n, sx=None, dx=1): +def fft_interp1d(f: jax.Array, n: int, sx: jax.Array = None, dx: float = 1.0): """Interpolation of a 1d periodic function via FFT. Parameters @@ -38,7 +39,15 @@ def fft_interp1d(f, n, sx=None, dx=1): @partial(jit, static_argnames=("n1", "n2")) -def fft_interp2d(f, n1, n2, sx=None, sy=None, dx=1, dy=1): +def fft_interp2d( + f: jax.Array, + n1: int, + n2: int, + sx: jax.Array = None, + sy: jax.Array = None, + dx: float = 1.0, + dy: float = 1.0, +): """Interpolation of a 2d periodic function via FFT. Parameters @@ -82,7 +91,7 @@ def fft_interp2d(f, n1, n2, sx=None, sy=None, dx=1, dy=1): return jnp.fft.fft2(c, axes=(0, 1)).real -def _pad_along_axis(array, pad=(0, 0), axis=0): +def _pad_along_axis(array: jax.Array, pad: tuple = (0, 0), axis: int = 0): """Pad with zeros or truncate a given dimension.""" array = jnp.moveaxis(array, axis, 0) diff --git a/interpax/_spline.py b/interpax/_spline.py index dfb1e22..8efb775 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -2,6 +2,7 @@ from collections import OrderedDict from functools import partial +from typing import Union import equinox as eqx import jax @@ -62,11 +63,19 @@ class Interpolator1D(eqx.Module): f: jax.Array derivs: dict method: str - extrap: bool | float | tuple - period: float | tuple + extrap: Union[bool, float, tuple] + period: Union[None, float] axis: int - def __init__(self, x, f, method="cubic", extrap=False, period=None, **kwargs): + def __init__( + self, + x: jax.Array, + f: jax.Array, + method: str = "cubic", + extrap: Union[bool, float, tuple] = False, + period: Union[None, float] = None, + **kwargs, + ): x, f = map(jnp.asarray, (x, f)) axis = kwargs.get("axis", 0) fx = kwargs.pop("fx", None) @@ -90,7 +99,7 @@ def __init__(self, x, f, method="cubic", extrap=False, period=None, **kwargs): self.derivs = {"fx": fx} - def __call__(self, xq, dx=0): + def __call__(self, xq: jax.Array, dx: int = 0): """Evaluate the interpolated function or its derivatives. Parameters @@ -161,11 +170,20 @@ class Interpolator2D(eqx.Module): f: jax.Array derivs: dict method: str - extrap: bool | float | tuple - period: float | tuple + extrap: Union[bool, float, tuple] + period: Union[None, float, tuple] axis: int - def __init__(self, x, y, f, method="cubic", extrap=False, period=None, **kwargs): + def __init__( + self, + x: jax.Array, + y: jax.Array, + f: jax.Array, + method: str = "cubic", + extrap: Union[bool, float, tuple] = False, + period: Union[None, float, tuple] = None, + **kwargs, + ): x, y, f = map(jnp.asarray, (x, y, f)) axis = kwargs.get("axis", 0) fx = kwargs.pop("fx", None) @@ -201,7 +219,7 @@ def __init__(self, x, y, f, method="cubic", extrap=False, period=None, **kwargs) self.derivs = {"fx": fx, "fy": fy, "fxy": fxy} - def __call__(self, xq, yq, dx=0, dy=0): + def __call__(self, xq: jax.Array, yq: jax.Array, dx: int = 0, dy: int = 0): """Evaluate the interpolated function or its derivatives. Parameters @@ -260,7 +278,7 @@ class Interpolator3D(eqx.Module): also be passed as an array or tuple to specify different conditions [[xlow, xhigh],[ylow,yhigh]] period : float > 0, None, array-like, shape(2,) - periodicity of the function in x, y directions. None denotes no periodicity, + periodicity of the function in x, y, z directions. None denotes no periodicity, otherwise function is assumed to be periodic on the interval [0,period]. Use a single value for the same in both directions. @@ -277,11 +295,21 @@ class Interpolator3D(eqx.Module): f: jax.Array derivs: dict method: str - extrap: bool | float | tuple - period: float | tuple + extrap: Union[bool, float, tuple] + period: Union[None, float, tuple] axis: int - def __init__(self, x, y, z, f, method="cubic", extrap=False, period=None, **kwargs): + def __init__( + self, + x: jax.Array, + y: jax.Array, + z: jax.Array, + f: jax.Array, + method: str = "cubic", + extrap: Union[bool, float, tuple] = False, + period: Union[None, float, tuple] = None, + **kwargs, + ): x, y, z, f = map(jnp.asarray, (x, y, z, f)) axis = kwargs.get("axis", 0) @@ -344,7 +372,15 @@ def __init__(self, x, y, z, f, method="cubic", extrap=False, period=None, **kwar "fxyz": fxyz, } - def __call__(self, xq, yq, zq, dx=0, dy=0, dz=0): + def __call__( + self, + xq: jax.Array, + yq: jax.Array, + zq: jax.Array, + dx: int = 0, + dy: int = 0, + dz: int = 0, + ): """Evaluate the interpolated function or its derivatives. Parameters @@ -377,7 +413,14 @@ def __call__(self, xq, yq, zq, dx=0, dy=0, dz=0): @partial(jit, static_argnames="method") def interp1d( - xq, x, f, method="cubic", derivative=0, extrap=False, period=None, **kwargs + xq: jax.Array, + x: jax.Array, + f: jax.Array, + method: str = "cubic", + derivative: int = 0, + extrap: Union[bool, float, tuple] = False, + period: Union[None, float] = None, + **kwargs, ): """Interpolate a 1d function. @@ -510,15 +553,15 @@ def derivative2(): @partial(jit, static_argnames="method") def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces - xq, - yq, - x, - y, - f, - method="cubic", - derivative=0, - extrap=False, - period=None, + xq: jax.Array, + yq: jax.Array, + x: jax.Array, + y: jax.Array, + f: jax.Array, + method: str = "cubic", + derivative: int = 0, + extrap: Union[bool, float, tuple] = False, + period: Union[None, float, tuple] = None, **kwargs, ): """Interpolate a 2d function. @@ -708,17 +751,17 @@ def derivative1(): @partial(jit, static_argnames="method") def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces - xq, - yq, - zq, - x, - y, - z, - f, - method="cubic", - derivative=0, - extrap=False, - period=None, + xq: jax.Array, + yq: jax.Array, + zq: jax.Array, + x: jax.Array, + y: jax.Array, + z: jax.Array, + f: jax.Array, + method: str = "cubic", + derivative: int = 0, + extrap: Union[bool, float, tuple] = False, + period: Union[None, float, tuple] = None, **kwargs, ): """Interpolate a 3d function. @@ -994,7 +1037,7 @@ def derivative1(): @partial(jit, static_argnames=("axis")) -def _make_periodic(xq, x, period, axis, *arrs): +def _make_periodic(xq: jax.Array, x: jax.Array, period: float, axis: int, *arrs): """Make arrays periodic along a specified axis.""" period = abs(period) xq = xq % period @@ -1018,7 +1061,7 @@ def _make_periodic(xq, x, period, axis, *arrs): @jit -def _get_t_der(t, derivative, dxi): +def _get_t_der(t: jax.Array, derivative: int, dxi: jax.Array): """Get arrays of [1,t,t^2,t^3] for cubic interpolation.""" t0 = jnp.zeros_like(t) t1 = jnp.ones_like(t) @@ -1058,7 +1101,13 @@ def _parse_extrap(extrap, n): @jit -def _extrap(xq, fq, x, lo, hi): +def _extrap( + xq: jax.Array, + fq: jax.Array, + x: jax.Array, + lo: Union[bool, float], + hi: Union[bool, float], +): """Clamp or extrapolate values outside bounds.""" def loclip(fq, lo): @@ -1095,7 +1144,9 @@ def noclip(fq, *_): @partial(jit, static_argnames=("method", "axis")) -def approx_df(x, f, method="cubic", axis=-1, **kwargs): +def approx_df( + x: jax.Array, f: jax.Array, method: str = "cubic", axis: int = -1, **kwargs +): """Approximates first derivatives using cubic spline interpolation. Parameters