Skip to content

Commit

Permalink
attempted fix for JAX perturbation bug
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Oct 12, 2023
1 parent ae2b009 commit 4cfa22f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
11 changes: 11 additions & 0 deletions qiskit_dynamics/perturbation/dyson_magnus.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ def _setup_dyson_rhs_jax(

custom_matmul = _CustomMatmul(lmult_rule, index_offset=1, backend="jax")

"""
##################################################################################################OLD version
perturbations_evaluation_order = jnp.array(perturbations_evaluation_order, dtype=int)
new_list = [generator] + perturbations
Expand All @@ -483,6 +485,15 @@ def single_eval(idx, t):
def dyson_rhs(t, y):
return custom_matmul(multiple_eval(perturbations_evaluation_order, t), y)
"""
perturbations_evaluation_order = np.array(perturbations_evaluation_order, dtype=int)

def multiple_eval(t):
return jnp.array([new_list[idx](t) for idx in perturbations_evaluation_order])

def dyson_rhs(t, y):
return custom_matmul(multiple_eval(t), y)


return dyson_rhs

Expand Down
5 changes: 0 additions & 5 deletions test/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,3 @@
"""
Qiskit Dynamics tests
"""

# temporarily disable a change in JAX 0.4.4 that introduced a bug. Must be run before importing JAX
import os

os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

0 comments on commit 4cfa22f

Please sign in to comment.