diff --git a/qiskit_dynamics/perturbation/dyson_magnus.py b/qiskit_dynamics/perturbation/dyson_magnus.py index 7642bcd25..484a54ad8 100644 --- a/qiskit_dynamics/perturbation/dyson_magnus.py +++ b/qiskit_dynamics/perturbation/dyson_magnus.py @@ -472,6 +472,8 @@ def _setup_dyson_rhs_jax( custom_matmul = _CustomMatmul(lmult_rule, index_offset=1, backend="jax") + """ + ##################################################################################################OLD version perturbations_evaluation_order = jnp.array(perturbations_evaluation_order, dtype=int) new_list = [generator] + perturbations @@ -483,6 +485,15 @@ def single_eval(idx, t): def dyson_rhs(t, y): return custom_matmul(multiple_eval(perturbations_evaluation_order, t), y) + """ + perturbations_evaluation_order = np.array(perturbations_evaluation_order, dtype=int) + + def multiple_eval(t): + return jnp.array([new_list[idx](t) for idx in perturbations_evaluation_order]) + + def dyson_rhs(t, y): + return custom_matmul(multiple_eval(t), y) + return dyson_rhs diff --git a/test/dynamics/__init__.py b/test/dynamics/__init__.py index af0a00fef..be625db2a 100644 --- a/test/dynamics/__init__.py +++ b/test/dynamics/__init__.py @@ -13,8 +13,3 @@ """ Qiskit Dynamics tests """ - -# temporarily disable a change in JAX 0.4.4 that introduced a bug. Must be run before importing JAX -import os - -os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"