Skip to content

Commit

Permalink
Fix bug when extrap is float
Browse files Browse the repository at this point in the history
-Passing in a float for `extrap` caused an error due to some poorly
ordered conditionals and weird jax dtype issues. Should work now.

Resolves #16
  • Loading branch information
f0uriest committed Feb 1, 2024
1 parent 9633d30 commit 59daa8e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
8 changes: 6 additions & 2 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,8 @@ def _parse_ndarg(arg, n):
def _parse_extrap(extrap, n):
if isbool(extrap): # same for lower,upper in all dimensions
return tuple(extrap for _ in range(2 * n))
elif jnp.isscalar(extrap):
return tuple(extrap for _ in range(2 * n))
elif len(extrap) == 2 and jnp.isscalar(extrap[0]): # same l,h for all dimensions
return tuple(e for _ in range(n) for e in extrap)
elif len(extrap) == n and all(len(extrap[i]) == 2 for i in range(n)):
Expand Down Expand Up @@ -1140,15 +1142,17 @@ def hiclip(fq, hi):
def noclip(fq, *_):
return fq

# if extrap = True, don't clip. If it's false or numeric, clip to that value
# isbool(x) & bool(x) is testing if extrap is True but works for np/jnp bools
fq = jax.lax.cond(
isbool(lo) & lo,
isbool(lo) & jnp.asarray(lo).astype(bool),
noclip,
loclip,
fq,
lo,
)
fq = jax.lax.cond(
isbool(hi) & hi,
isbool(hi) & jnp.asarray(hi).astype(bool),
noclip,
hiclip,
fq,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,15 @@ def test_ad_interp3d(self):
# for some reason finite difference gives nan at endpoints so ignore that
np.testing.assert_allclose(jacf1[1:-1], jacd1[1:-1], rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(jacf2[1:-1], jacd2[1:-1], rtol=1e-6, atol=1e-6)


@pytest.mark.unit
def test_extrap_float():
"""Test for extrap being a float, from gh issue #16."""
x = jnp.linspace(0, 10, 10)
y = jnp.linspace(0, 8, 8)
z = jnp.zeros((10, 8)) + 1.0
interpol = Interpolator2D(x, y, z, extrap=0.0)
np.testing.assert_allclose(interpol(4.5, 5.3), 1.0)
np.testing.assert_allclose(interpol(-4.5, 5.3), 0.0)
np.testing.assert_allclose(interpol(4.5, -5.3), 0.0)

0 comments on commit 59daa8e

Please sign in to comment.