From 59daa8ecc6167466122006314717774a286613c2 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 31 Jan 2024 21:22:32 -0500 Subject: [PATCH] Fix bug when extrap is float -Passing in a float for `extrap` caused an error due to some poorly ordered conditionals and weird jax dtype issues. Should work now. Resolves #16 --- interpax/_spline.py | 8 ++++++-- tests/test_interpolate.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 380c307..1b4a8a7 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -1104,6 +1104,8 @@ def _parse_ndarg(arg, n): def _parse_extrap(extrap, n): if isbool(extrap): # same for lower,upper in all dimensions return tuple(extrap for _ in range(2 * n)) + elif jnp.isscalar(extrap): + return tuple(extrap for _ in range(2 * n)) elif len(extrap) == 2 and jnp.isscalar(extrap[0]): # same l,h for all dimensions return tuple(e for _ in range(n) for e in extrap) elif len(extrap) == n and all(len(extrap[i]) == 2 for i in range(n)): @@ -1140,15 +1142,17 @@ def hiclip(fq, hi): def noclip(fq, *_): return fq + # if extrap = True, don't clip. If it's false or numeric, clip to that value + # isbool(x) & bool(x) is testing if extrap is True but works for np/jnp bools fq = jax.lax.cond( - isbool(lo) & lo, + isbool(lo) & jnp.asarray(lo).astype(bool), noclip, loclip, fq, lo, ) fq = jax.lax.cond( - isbool(hi) & hi, + isbool(hi) & jnp.asarray(hi).astype(bool), noclip, hiclip, fq, diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index 128dac5..5f101ff 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -613,3 +613,15 @@ def test_ad_interp3d(self): # for some reason finite difference gives nan at endpoints so ignore that np.testing.assert_allclose(jacf1[1:-1], jacd1[1:-1], rtol=1e-6, atol=1e-6) np.testing.assert_allclose(jacf2[1:-1], jacd2[1:-1], rtol=1e-6, atol=1e-6) + + +@pytest.mark.unit +def test_extrap_float(): + """Test for extrap being a float, from gh issue #16.""" + x = jnp.linspace(0, 10, 10) + y = jnp.linspace(0, 8, 8) + z = jnp.zeros((10, 8)) + 1.0 + interpol = Interpolator2D(x, y, z, extrap=0.0) + np.testing.assert_allclose(interpol(4.5, 5.3), 1.0) + np.testing.assert_allclose(interpol(-4.5, 5.3), 0.0) + np.testing.assert_allclose(interpol(4.5, -5.3), 0.0)