Skip to content

Commit

Permalink
removing warning for jax version
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Oct 12, 2023
1 parent 4cfa22f commit 8c4a0a3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 18 deletions.
18 changes: 0 additions & 18 deletions qiskit_dynamics/dispatch/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,6 @@
from jax import Array
from jax.core import Tracer

# warning based on JAX version
from packaging import version
import warnings

if version.parse(jax.__version__) >= version.parse("0.4.4"):
import os

if (
version.parse(jax.__version__) > version.parse("0.4.6")
or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0"
):
warnings.warn(
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, "
"0.4.5, and 0.4.6, using the perturbation module functionality requires setting "
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics."
)

JAX_TYPES = (Array, Tracer)

from ..dispatch import Dispatch
Expand Down
2 changes: 2 additions & 0 deletions qiskit_dynamics/perturbation/dyson_magnus.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@ def dyson_rhs(t, y):
"""
perturbations_evaluation_order = np.array(perturbations_evaluation_order, dtype=int)

new_list = [generator] + perturbations

def multiple_eval(t):
return jnp.array([new_list[idx](t) for idx in perturbations_evaluation_order])

Expand Down

0 comments on commit 8c4a0a3

Please sign in to comment.