Skip to content

Commit

Permalink
Add support for scalar inputs to interpxd
Browse files Browse the repository at this point in the history
  • Loading branch information
allen-adastra committed Nov 15, 2023
1 parent d7969cb commit fdeb5dd
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
15 changes: 15 additions & 0 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ def interp1d(
fx = kwargs.pop("fx", None)
outshape = xq.shape + f.shape[1:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq = jnp.atleast_1d(xq)

errorif(
(len(x) != f.shape[axis]) or (jnp.ndim(x) != 1),
ValueError,
Expand Down Expand Up @@ -621,6 +626,11 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces
xq, yq = jnp.broadcast_arrays(xq, yq)
outshape = xq.shape + f.shape[2:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq, yq = map(jnp.atleast_1d, (xq, yq))

errorif(
(len(x) != f.shape[0]) or (x.ndim != 1),
ValueError,
Expand Down Expand Up @@ -839,6 +849,11 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces
xq, yq, zq = jnp.broadcast_arrays(xq, yq, zq)
outshape = xq.shape + f.shape[3:]

# Promote scalar query points to 1D array.
# Note this is done after the computation of outshape
# to make jax.grad work in the scalar case.
xq, yq, zq = map(jnp.atleast_1d, (xq, yq, zq))

fx = kwargs.pop("fx", None)
fy = kwargs.pop("fy", None)
fz = kwargs.pop("fz", None)
Expand Down
24 changes: 15 additions & 9 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ class TestInterp1D:
"""Tests for interp1d function."""

@pytest.mark.unit
def test_interp1d(self):
@pytest.mark.parametrize("x", [
np.linspace(0, 2 * np.pi, 10000),
0.0,
])
def test_interp1d(self, x):
"""Test accuracy of different 1d interpolation methods."""
xp = np.linspace(0, 2 * np.pi, 100)
x = np.linspace(0, 2 * np.pi, 10000)
f = lambda x: np.sin(x)
fp = f(xp)

Expand Down Expand Up @@ -99,12 +102,14 @@ class TestInterp2D:
"""Tests for interp2d function."""

@pytest.mark.unit
def test_interp2d(self):
@pytest.mark.parametrize("x, y", [
(np.linspace(0, 3 * np.pi, 1000), np.linspace(0, 2 * np.pi, 1000)),
(0.0, 0.0),
])
def test_interp2d(self, x, y):
"""Test accuracy of different 2d interpolation methods."""
xp = np.linspace(0, 3 * np.pi, 99)
yp = np.linspace(0, 2 * np.pi, 40)
x = np.linspace(0, 3 * np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
xxp, yyp = np.meshgrid(xp, yp, indexing="ij")

f = lambda x, y: np.sin(x) * np.cos(y)
Expand Down Expand Up @@ -150,14 +155,15 @@ class TestInterp3D:
"""Tests for interp3d function."""

@pytest.mark.unit
def test_interp3d(self):
@pytest.mark.parametrize("x, y, z", [
(np.linspace(0, np.pi, 1000), np.linspace(0, 2 * np.pi, 1000), np.linspace(0, 3, 1000)),
(0.0, 0.0, 0.0),
])
def test_interp3d(self, x, y, z):
"""Test accuracy of different 3d interpolation methods."""
xp = np.linspace(0, np.pi, 20)
yp = np.linspace(0, 2 * np.pi, 30)
zp = np.linspace(0, 3, 25)
x = np.linspace(0, np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
z = np.linspace(0, 3, 1000)
xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij")

f = lambda x, y, z: np.sin(x) * np.cos(y) * z**2
Expand Down

0 comments on commit fdeb5dd

Please sign in to comment.