diff --git a/tests/test_kepler_jax.py b/tests/test_kepler_jax.py index 6703bb8..41a0ada 100644 --- a/tests/test_kepler_jax.py +++ b/tests/test_kepler_jax.py @@ -4,13 +4,12 @@ import pytest import jax -from jax.config import config from jax.test_util import check_grads from kepler_jax import kepler -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) @pytest.fixture(params=[np.float32, np.float64])