From d7969cb0e1f82f64e45c919f261488bc6266298a Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Wed, 15 Nov 2023 11:08:44 -0500 Subject: [PATCH 1/3] Jax requirement from <= 0.4.14 to <= 0.4.20 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2236ad8..042c41e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ equinox -jax[cpu] >= 0.3.2, <= 0.4.14 +jax[cpu] >= 0.3.2, <= 0.4.20 numpy >= 1.20.0, < 1.25.0 scipy >= 1.5.0, < 1.11.0 From fdeb5dd45f36ae3e0c0d54b96fd01c9a5dd48e25 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Wed, 15 Nov 2023 14:19:03 -0500 Subject: [PATCH 2/3] Add support for scalar inputs to interpxd --- interpax/_spline.py | 15 +++++++++++++++ tests/test_interpolate.py | 24 +++++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/interpax/_spline.py b/interpax/_spline.py index 8efb775..76eb5f1 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -474,6 +474,11 @@ def interp1d( fx = kwargs.pop("fx", None) outshape = xq.shape + f.shape[1:] + # Promote scalar query points to 1D array. + # Note this is done after the computation of outshape + # to make jax.grad work in the scalar case. + xq = jnp.atleast_1d(xq) + errorif( (len(x) != f.shape[axis]) or (jnp.ndim(x) != 1), ValueError, @@ -621,6 +626,11 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces xq, yq = jnp.broadcast_arrays(xq, yq) outshape = xq.shape + f.shape[2:] + # Promote scalar query points to 1D array. + # Note this is done after the computation of outshape + # to make jax.grad work in the scalar case. + xq, yq = map(jnp.atleast_1d, (xq, yq)) + errorif( (len(x) != f.shape[0]) or (x.ndim != 1), ValueError, @@ -839,6 +849,11 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces xq, yq, zq = jnp.broadcast_arrays(xq, yq, zq) outshape = xq.shape + f.shape[3:] + # Promote scalar query points to 1D array. + # Note this is done after the computation of outshape + # to make jax.grad work in the scalar case. + xq, yq, zq = map(jnp.atleast_1d, (xq, yq, zq)) + fx = kwargs.pop("fx", None) fy = kwargs.pop("fy", None) fz = kwargs.pop("fz", None) diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index b20c1d0..0212cd3 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -24,10 +24,13 @@ class TestInterp1D: """Tests for interp1d function.""" @pytest.mark.unit - def test_interp1d(self): + @pytest.mark.parametrize("x", [ + np.linspace(0, 2 * np.pi, 10000), + 0.0, + ]) + def test_interp1d(self, x): """Test accuracy of different 1d interpolation methods.""" xp = np.linspace(0, 2 * np.pi, 100) - x = np.linspace(0, 2 * np.pi, 10000) f = lambda x: np.sin(x) fp = f(xp) @@ -99,12 +102,14 @@ class TestInterp2D: """Tests for interp2d function.""" @pytest.mark.unit - def test_interp2d(self): + @pytest.mark.parametrize("x, y", [ + (np.linspace(0, 3 * np.pi, 1000), np.linspace(0, 2 * np.pi, 1000)), + (0.0, 0.0), + ]) + def test_interp2d(self, x, y): """Test accuracy of different 2d interpolation methods.""" xp = np.linspace(0, 3 * np.pi, 99) yp = np.linspace(0, 2 * np.pi, 40) - x = np.linspace(0, 3 * np.pi, 1000) - y = np.linspace(0, 2 * np.pi, 1000) xxp, yyp = np.meshgrid(xp, yp, indexing="ij") f = lambda x, y: np.sin(x) * np.cos(y) @@ -150,14 +155,15 @@ class TestInterp3D: """Tests for interp3d function.""" @pytest.mark.unit - def test_interp3d(self): + @pytest.mark.parametrize("x, y, z", [ + (np.linspace(0, np.pi, 1000), np.linspace(0, 2 * np.pi, 1000), np.linspace(0, 3, 1000)), + (0.0, 0.0, 0.0), + ]) + def test_interp3d(self, x, y, z): """Test accuracy of different 3d interpolation methods.""" xp = np.linspace(0, np.pi, 20) yp = np.linspace(0, 2 * np.pi, 30) zp = np.linspace(0, 3, 25) - x = np.linspace(0, np.pi, 1000) - y = np.linspace(0, 2 * np.pi, 1000) - z = np.linspace(0, 3, 1000) xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij") f = lambda x, y, z: np.sin(x) * np.cos(y) * z**2 From 9daa42595a2462e6a2db80cafbe85c1f82a1875b Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 17 Nov 2023 09:57:43 -0500 Subject: [PATCH 3/3] Fix jax.config deprecation issue --- tests/test_interpolate.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index 0212cd3..20aa9f0 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import numpy as np import pytest -from jax.config import config as jax_config +from jax import config as jax_config from interpax import ( Interpolator1D, @@ -24,10 +24,13 @@ class TestInterp1D: """Tests for interp1d function.""" @pytest.mark.unit - @pytest.mark.parametrize("x", [ - np.linspace(0, 2 * np.pi, 10000), - 0.0, - ]) + @pytest.mark.parametrize( + "x", + [ + np.linspace(0, 2 * np.pi, 10000), + 0.0, + ], + ) def test_interp1d(self, x): """Test accuracy of different 1d interpolation methods.""" xp = np.linspace(0, 2 * np.pi, 100) @@ -102,10 +105,13 @@ class TestInterp2D: """Tests for interp2d function.""" @pytest.mark.unit - @pytest.mark.parametrize("x, y", [ - (np.linspace(0, 3 * np.pi, 1000), np.linspace(0, 2 * np.pi, 1000)), - (0.0, 0.0), - ]) + @pytest.mark.parametrize( + "x, y", + [ + (np.linspace(0, 3 * np.pi, 1000), np.linspace(0, 2 * np.pi, 1000)), + (0.0, 0.0), + ], + ) def test_interp2d(self, x, y): """Test accuracy of different 2d interpolation methods.""" xp = np.linspace(0, 3 * np.pi, 99) @@ -155,10 +161,17 @@ class TestInterp3D: """Tests for interp3d function.""" @pytest.mark.unit - @pytest.mark.parametrize("x, y, z", [ - (np.linspace(0, np.pi, 1000), np.linspace(0, 2 * np.pi, 1000), np.linspace(0, 3, 1000)), - (0.0, 0.0, 0.0), - ]) + @pytest.mark.parametrize( + "x, y, z", + [ + ( + np.linspace(0, np.pi, 1000), + np.linspace(0, 2 * np.pi, 1000), + np.linspace(0, 3, 1000), + ), + (0.0, 0.0, 0.0), + ], + ) def test_interp3d(self, x, y, z): """Test accuracy of different 3d interpolation methods.""" xp = np.linspace(0, np.pi, 20)