diff --git a/qiskit_dynamics/dispatch/backends/jax.py b/qiskit_dynamics/dispatch/backends/jax.py index ed104b9ef..0d5d55936 100644 --- a/qiskit_dynamics/dispatch/backends/jax.py +++ b/qiskit_dynamics/dispatch/backends/jax.py @@ -20,24 +20,6 @@ from jax import Array from jax.core import Tracer - # warning based on JAX version - from packaging import version - import warnings - - if version.parse(jax.__version__) >= version.parse("0.4.4"): - import os - - if ( - version.parse(jax.__version__) > version.parse("0.4.6") - or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0" - ): - warnings.warn( - "The functionality in the perturbation module of Qiskit Dynamics requires a JAX " - "version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, " - "0.4.5, and 0.4.6, using the perturbation module functionality requires setting " - "os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics." - ) - JAX_TYPES = (Array, Tracer) from ..dispatch import Dispatch diff --git a/qiskit_dynamics/perturbation/dyson_magnus.py b/qiskit_dynamics/perturbation/dyson_magnus.py index 484a54ad8..948088fdc 100644 --- a/qiskit_dynamics/perturbation/dyson_magnus.py +++ b/qiskit_dynamics/perturbation/dyson_magnus.py @@ -488,6 +488,8 @@ def dyson_rhs(t, y): """ perturbations_evaluation_order = np.array(perturbations_evaluation_order, dtype=int) + new_list = [generator] + perturbations + def multiple_eval(t): return jnp.array([new_list[idx](t) for idx in perturbations_evaluation_order])