From 3ed9a116b65c08b86e4c8ad259cf1af26c3867be Mon Sep 17 00:00:00 2001 From: Daniel Puzzuoli Date: Fri, 1 Mar 2024 09:50:01 -0800 Subject: [PATCH] Fix complex cast warning (#329) --- .../register_functions/linear_combo.py | 5 ++- .../models/operator_collections.py | 41 +++++++++++++++---- test/dynamics/solvers/test_solver_classes.py | 11 ++--- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/qiskit_dynamics/arraylias/register_functions/linear_combo.py b/qiskit_dynamics/arraylias/register_functions/linear_combo.py index 9a263381c..3882f6907 100644 --- a/qiskit_dynamics/arraylias/register_functions/linear_combo.py +++ b/qiskit_dynamics/arraylias/register_functions/linear_combo.py @@ -36,7 +36,10 @@ def _(coeffs, mats): @alias.register_function(lib="jax", path="linear_combo") def _(coeffs, mats): - return jnp.tensordot(coeffs, mats, axes=1) + # real and imag broken up to avoid real/complex tensordot warning + return jnp.tensordot(coeffs, mats.real, axes=1) + 1j * jnp.tensordot( + coeffs, mats.imag, axes=1 + ) from jax.experimental.sparse import sparsify diff --git a/qiskit_dynamics/models/operator_collections.py b/qiskit_dynamics/models/operator_collections.py index 9bddb9ac6..b847211ad 100644 --- a/qiskit_dynamics/models/operator_collections.py +++ b/qiskit_dynamics/models/operator_collections.py @@ -510,9 +510,18 @@ def evaluate_rhs( ) if self._static_dissipators is None: + # real and imag broken up to avoid real/complex tensordot warning + mats = _matmul( + self._dissipator_operators, _matmul(y, self._dissipator_operators_adj) + ) both_mult_contribution = _numpy_multi_dispatch( dis_coefficients, - _matmul(self._dissipator_operators, _matmul(y, self._dissipator_operators_adj)), + mats.real, + path="tensordot", + axes=(-1, -3), + ) + 1j * _numpy_multi_dispatch( + dis_coefficients, + mats.imag, path="tensordot", axes=(-1, -3), ) @@ -522,14 +531,28 @@ def evaluate_rhs( axis=-3, ) else: - both_mult_contribution = unp.sum( - _matmul(self._static_dissipators, _matmul(y, self._static_dissipators_adj)), - axis=-3, - ) + _numpy_multi_dispatch( - dis_coefficients, - _matmul(self._dissipator_operators, _matmul(y, self._dissipator_operators_adj)), - path="tensordot", - axes=(-1, -3), + # real and imag broken up to avoid real/complex tensordot warning + mats = _matmul( + self._dissipator_operators, _matmul(y, self._dissipator_operators_adj) + ) + both_mult_contribution = ( + unp.sum( + _matmul(self._static_dissipators, _matmul(y, self._static_dissipators_adj)), + axis=-3, + ) + + _numpy_multi_dispatch( + dis_coefficients, + mats.real, + path="tensordot", + axes=(-1, -3), + ) + + 1j + * _numpy_multi_dispatch( + dis_coefficients, + mats.imag, + path="tensordot", + axes=(-1, -3), + ) ) return left_mult_contribution + right_mult_contribution + both_mult_contribution diff --git a/test/dynamics/solvers/test_solver_classes.py b/test/dynamics/solvers/test_solver_classes.py index 27be9906b..587173478 100644 --- a/test/dynamics/solvers/test_solver_classes.py +++ b/test/dynamics/solvers/test_solver_classes.py @@ -443,17 +443,18 @@ def test_rwa_td_lindblad_model(self): def test_signals_are_None(self): """Test the model signals return to being None after simulation.""" - ham_solver = Solver(hamiltonian_operators=[self.X]) - ham_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.array([0.0, 1.0])) + ham_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.array([0.0, 1.0], dtype=complex)) self.assertTrue(ham_solver.model.signals is None) lindblad_solver = Solver(hamiltonian_operators=[self.X], static_dissipators=[self.X]) - lindblad_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.eye(2)) + lindblad_solver.solve(signals=[1.0], t_span=[0.0, 0.01], y0=np.eye(2, dtype=complex)) self.assertTrue(lindblad_solver.model.signals == (None, None)) td_lindblad_solver = Solver(hamiltonian_operators=[self.X], dissipator_operators=[self.X]) - td_lindblad_solver.solve(signals=([1.0], [1.0]), t_span=[0.0, 0.01], y0=np.eye(2)) + td_lindblad_solver.solve( + signals=([1.0], [1.0]), t_span=[0.0, 0.01], y0=np.eye(2, dtype=complex) + ) self.assertTrue(td_lindblad_solver.model.signals == (None, None)) @@ -719,7 +720,7 @@ def test_jit_solve(self): def func(a): yf = self.ham_solver.solve( t_span=np.array([0.0, 1.0]), - y0=np.array([0.0, 1.0]), + y0=np.array([0.0, 1.0], dtype=complex), signals=[Signal(lambda t: a, 5.0)], method="jax_odeint", ).y[-1]