Skip to content

Commit

Permalink
Fix complex cast warning (#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli authored Mar 1, 2024
1 parent f15fa97 commit 3ed9a11
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
5 changes: 4 additions & 1 deletion qiskit_dynamics/arraylias/register_functions/linear_combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def _(coeffs, mats):

@alias.register_function(lib="jax", path="linear_combo")
def _(coeffs, mats):
return jnp.tensordot(coeffs, mats, axes=1)
# real and imag broken up to avoid real/complex tensordot warning
return jnp.tensordot(coeffs, mats.real, axes=1) + 1j * jnp.tensordot(
coeffs, mats.imag, axes=1
)

from jax.experimental.sparse import sparsify

Expand Down
41 changes: 32 additions & 9 deletions qiskit_dynamics/models/operator_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,18 @@ def evaluate_rhs(
)

if self._static_dissipators is None:
# real and imag broken up to avoid real/complex tensordot warning
mats = _matmul(
self._dissipator_operators, _matmul(y, self._dissipator_operators_adj)
)
both_mult_contribution = _numpy_multi_dispatch(
dis_coefficients,
_matmul(self._dissipator_operators, _matmul(y, self._dissipator_operators_adj)),
mats.real,
path="tensordot",
axes=(-1, -3),
) + 1j * _numpy_multi_dispatch(
dis_coefficients,
mats.imag,
path="tensordot",
axes=(-1, -3),
)
Expand All @@ -522,14 +531,28 @@ def evaluate_rhs(
axis=-3,
)
else:
both_mult_contribution = unp.sum(
_matmul(self._static_dissipators, _matmul(y, self._static_dissipators_adj)),
axis=-3,
) + _numpy_multi_dispatch(
dis_coefficients,
_matmul(self._dissipator_operators, _matmul(y, self._dissipator_operators_adj)),
path="tensordot",
axes=(-1, -3),
# real and imag broken up to avoid real/complex tensordot warning
mats = _matmul(
self._dissipator_operators, _matmul(y, self._dissipator_operators_adj)
)
both_mult_contribution = (
unp.sum(
_matmul(self._static_dissipators, _matmul(y, self._static_dissipators_adj)),
axis=-3,
)
+ _numpy_multi_dispatch(
dis_coefficients,
mats.real,
path="tensordot",
axes=(-1, -3),
)
+ 1j
* _numpy_multi_dispatch(
dis_coefficients,
mats.imag,
path="tensordot",
axes=(-1, -3),
)
)

return left_mult_contribution + right_mult_contribution + both_mult_contribution
Expand Down
11 changes: 6 additions & 5 deletions test/dynamics/solvers/test_solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,17 +443,18 @@ def test_rwa_td_lindblad_model(self):

def test_signals_are_None(self):
"""Test the model signals return to being None after simulation."""

ham_solver = Solver(hamiltonian_operators=[self.X])
ham_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.array([0.0, 1.0]))
ham_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.array([0.0, 1.0], dtype=complex))
self.assertTrue(ham_solver.model.signals is None)

lindblad_solver = Solver(hamiltonian_operators=[self.X], static_dissipators=[self.X])
lindblad_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.eye(2))
lindblad_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.eye(2, dtype=complex))
self.assertTrue(lindblad_solver.model.signals == (None, None))

td_lindblad_solver = Solver(hamiltonian_operators=[self.X], dissipator_operators=[self.X])
td_lindblad_solver.solve(signals=([1.0], [1.0]), t_span=[0.0, 0.01], y0=np.eye(2))
td_lindblad_solver.solve(
signals=([1.0], [1.0]), t_span=[0.0, 0.01], y0=np.eye(2, dtype=complex)
)
self.assertTrue(td_lindblad_solver.model.signals == (None, None))


Expand Down Expand Up @@ -719,7 +720,7 @@ def test_jit_solve(self):
def func(a):
yf = self.ham_solver.solve(
t_span=np.array([0.0, 1.0]),
y0=np.array([0.0, 1.0]),
y0=np.array([0.0, 1.0], dtype=complex),
signals=[Signal(lambda t: a, 5.0)],
method="jax_odeint",
).y[-1]
Expand Down

0 comments on commit 3ed9a11

Please sign in to comment.