Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allenw/interp scalars #12

Merged
merged 4 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ def interp1d(
fx = kwargs.pop("fx", None)
outshape = xq.shape + f.shape[1:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq = jnp.atleast_1d(xq)

errorif(
(len(x) != f.shape[axis]) or (jnp.ndim(x) != 1),
ValueError,
Expand Down Expand Up @@ -621,6 +626,11 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
xq, yq = jnp.broadcast_arrays(xq, yq)
outshape = xq.shape + f.shape[2:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq, yq = map(jnp.atleast_1d, (xq, yq))

errorif(
(len(x) != f.shape[0]) or (x.ndim != 1),
ValueError,
Expand Down Expand Up @@ -839,6 +849,11 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
xq, yq, zq = jnp.broadcast_arrays(xq, yq, zq)
outshape = xq.shape + f.shape[3:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq, yq, zq = map(jnp.atleast_1d, (xq, yq, zq))

fx = kwargs.pop("fx", None)
fy = kwargs.pop("fy", None)
fz = kwargs.pop("fz", None)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
equinox
jax[cpu] >= 0.3.2, <= 0.4.14
jax[cpu] >= 0.3.2, <= 0.4.20
numpy >= 1.20.0, < 1.25.0
scipy >= 1.5.0, < 1.11.0
39 changes: 29 additions & 10 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax.config import config as jax_config
from jax import config as jax_config

from interpax import (
Interpolator1D,
Expand All @@ -24,10 +24,16 @@ class TestInterp1D:
"""Tests for interp1d function."""

@pytest.mark.unit
def test_interp1d(self):
@pytest.mark.parametrize(
"x",
[
np.linspace(0, 2 * np.pi, 10000),
0.0,
],
)
def test_interp1d(self, x):
"""Test accuracy of different 1d interpolation methods."""
xp = np.linspace(0, 2 * np.pi, 100)
x = np.linspace(0, 2 * np.pi, 10000)
f = lambda x: np.sin(x)
fp = f(xp)

Expand Down Expand Up @@ -99,12 +105,17 @@ class TestInterp2D:
"""Tests for interp2d function."""

@pytest.mark.unit
def test_interp2d(self):
@pytest.mark.parametrize(
"x, y",
[
(np.linspace(0, 3 * np.pi, 1000), np.linspace(0, 2 * np.pi, 1000)),
(0.0, 0.0),
],
)
def test_interp2d(self, x, y):
"""Test accuracy of different 2d interpolation methods."""
xp = np.linspace(0, 3 * np.pi, 99)
yp = np.linspace(0, 2 * np.pi, 40)
x = np.linspace(0, 3 * np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
xxp, yyp = np.meshgrid(xp, yp, indexing="ij")

f = lambda x, y: np.sin(x) * np.cos(y)
Expand Down Expand Up @@ -150,14 +161,22 @@ class TestInterp3D:
"""Tests for interp3d function."""

@pytest.mark.unit
def test_interp3d(self):
@pytest.mark.parametrize(
"x, y, z",
[
(
np.linspace(0, np.pi, 1000),
np.linspace(0, 2 * np.pi, 1000),
np.linspace(0, 3, 1000),
),
(0.0, 0.0, 0.0),
],
)
def test_interp3d(self, x, y, z):
"""Test accuracy of different 3d interpolation methods."""
xp = np.linspace(0, np.pi, 20)
yp = np.linspace(0, 2 * np.pi, 30)
zp = np.linspace(0, 3, 25)
x = np.linspace(0, np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
z = np.linspace(0, 3, 1000)
xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij")

f = lambda x, y, z: np.sin(x) * np.cos(y) * z**2
Expand Down