diff --git a/qiskit_dynamics/models/rotating_wave_approximation.py b/qiskit_dynamics/models/rotating_wave_approximation.py index 74df88231..98f4fd1e0 100644 --- a/qiskit_dynamics/models/rotating_wave_approximation.py +++ b/qiskit_dynamics/models/rotating_wave_approximation.py @@ -232,7 +232,7 @@ def jax_transformable_func(t): ) if return_signal_map: - signal_translator = lambda a, b: (get_rwa_signals(a), get_rwa_signals(b)) + signal_translator = lambda a: (get_rwa_signals(a[0]), get_rwa_signals(a[1])) return rwa_model, signal_translator return rwa_model diff --git a/qiskit_dynamics/solvers/solver_classes.py b/qiskit_dynamics/solvers/solver_classes.py index f45014d0a..a41c8e719 100644 --- a/qiskit_dynamics/solvers/solver_classes.py +++ b/qiskit_dynamics/solvers/solver_classes.py @@ -190,7 +190,9 @@ def signals(self) -> SignalList: return self._signals @signals.setter - def signals(self, new_signals: Union[List[Signal], SignalList]): + def signals( + self, new_signals: Union[List[Signal], SignalList, Tuple[List[Signal]], Tuple[SignalList]] + ): """Set signals for the solver, and pass to the model.""" self._signals = new_signals if self._rwa_signal_map is not None: diff --git a/releasenotes/notes/lindblad-rwa-bug-491aaeed27ab4780.yaml b/releasenotes/notes/lindblad-rwa-bug-491aaeed27ab4780.yaml new file mode 100644 index 000000000..ce7e52d7f --- /dev/null +++ b/releasenotes/notes/lindblad-rwa-bug-491aaeed27ab4780.yaml @@ -0,0 +1,9 @@ +--- +fixes: + - | + The ``rotating_wave_approximation`` function has been fixed in the case of + the ``model`` argument being a ``LindbladModel`` with ``return_signal_map=True``. + The returned signal mapping function was erroneously defined to take two inputs, + one for Hamiltonian signals and one for dissipator signals. This behaviour has been updated + to be consistent with the documentation, which states that in general this function accepts + only a single argument (in this case a tuple storing both sets of signals). diff --git a/test/dynamics/models/test_rotating_wave_approximation.py b/test/dynamics/models/test_rotating_wave_approximation.py index 7f0776b2b..9ba3e01ed 100644 --- a/test/dynamics/models/test_rotating_wave_approximation.py +++ b/test/dynamics/models/test_rotating_wave_approximation.py @@ -282,13 +282,13 @@ def test_signal_translator_lindblad_model(self): dissipator_signals=sigs, ) f = rotating_wave_approximation(LM, 100, return_signal_map=True)[1] - rwa_ham_sig, rwa_dis_sig = f(sigs, sigs) + rwa_ham_sig, rwa_dis_sig = f((sigs, sigs)) self.assertAllClose(rwa_ham_sig.complex_value(2)[:4], SignalList(sigs).complex_value(2)) self.assertAllClose(rwa_dis_sig.complex_value(2)[:4], SignalList(sigs).complex_value(2)) self.assertAllClose(rwa_ham_sig.complex_value(2)[4:], SignalList(s_prime).complex_value(2)) self.assertAllClose(rwa_dis_sig.complex_value(2)[4:], SignalList(s_prime).complex_value(2)) - self.assertTrue(f(None, None) == (None, None)) + self.assertTrue(f((None, None)) == (None, None)) def test_rwa_operators(self): """Tests get_rwa_operators using pseudorandom numbers."""