diff --git a/src/kepler_jax/kepler_jax.py b/src/kepler_jax/kepler_jax.py index dfb9fa8..93bfd47 100644 --- a/src/kepler_jax/kepler_jax.py +++ b/src/kepler_jax/kepler_jax.py @@ -10,7 +10,7 @@ from jax.abstract_arrays import ShapedArray from jax.interpreters import ad, batching, mlir, xla from jax.lib import xla_client -from jaxlib.mhlo_helpers import custom_call +from jaxlib.hlo_helpers import custom_call # Register the CPU XLA custom calls from . import cpu_ops