diff --git a/qiskit_dynamics/arraylias/register_functions/asarray.py b/qiskit_dynamics/arraylias/register_functions/asarray.py index 62c5d325b..397095802 100644 --- a/qiskit_dynamics/arraylias/register_functions/asarray.py +++ b/qiskit_dynamics/arraylias/register_functions/asarray.py @@ -19,12 +19,16 @@ import numpy as np from scipy.sparse import csr_matrix, issparse +from qiskit_dynamics.type_utils import isinstance_qutip_qobj + def register_asarray(alias): """register asarray functions to each array libraries""" @alias.register_default(path="asarray") def _(arr): + if isinstance_qutip_qobj(arr): + return csr_matrix(arr) return np.asarray(arr) @alias.register_fallback(path="asarray") @@ -33,7 +37,7 @@ def _(arr): @alias.register_function(lib="scipy_sparse", path="asarray") def _(arr): - if issparse(arr): + if issparse(arr) or issparse(arr[0]): return arr return csr_matrix(arr) diff --git a/qiskit_dynamics/models/rotating_frame.py b/qiskit_dynamics/models/rotating_frame.py index 8f2b03024..ec80216ba 100644 --- a/qiskit_dynamics/models/rotating_frame.py +++ b/qiskit_dynamics/models/rotating_frame.py @@ -143,7 +143,7 @@ def state_into_frame_basis(self, y: ArrayLike) -> ArrayLike: Returns: ArrayLike: The state in the frame basis. """ - y = unp.to_numeric_matrix_type(y) + y = unp.asarray(y) if self.frame_basis_adjoint is None: return y @@ -158,7 +158,7 @@ def state_out_of_frame_basis(self, y: ArrayLike) -> ArrayLike: Returns: Array: The state in the frame basis. """ - y = unp.to_numeric_matrix_type(y) + y = unp.asarray(y) if self.frame_basis is None: return y @@ -182,7 +182,7 @@ def operator_into_frame_basis( Array: The operator in the frame basis. """ if convert_type: - op = unp.to_numeric_matrix_type(op) + op = unp.asarray(op) if self.frame_basis is None or op is None: return op @@ -210,7 +210,7 @@ def operator_out_of_frame_basis( ArrayLike: The operator in the frame basis. """ if convert_type: - op = unp.to_numeric_matrix_type(op) + op = unp.asarray(op) if self.frame_basis is None or op is None: return op @@ -239,7 +239,7 @@ def state_into_frame( Returns: ArrayLike: The state in the rotating frame. """ - y = unp.to_numeric_matrix_type(y) + y = unp.asarray(y) if self._frame_operator is None: return y @@ -314,8 +314,9 @@ def _conjugate_and_add( Returns: Array of the newly conjugated operator. """ - operator = unp.to_numeric_matrix_type(operator) - op_to_add_in_fb = unp.to_numeric_matrix_type(op_to_add_in_fb) + operator = unp.asarray(operator) + if op_to_add_in_fb is not None: + op_to_add_in_fb = unp.asarray(op_to_add_in_fb) if vectorized_operators: # If passing vectorized operator, undo vectorization temporarily if self._frame_operator is None: @@ -460,7 +461,7 @@ def generator_into_frame( ArrayLike: The generator in the rotating frame. """ if self.frame_operator is None: - return unp.to_numeric_matrix_type(operator) + return unp.asarray(operator) else: # conjugate and subtract the frame diagonal return self._conjugate_and_add( @@ -495,7 +496,7 @@ def generator_out_of_frame( ArrayLike: The generator out of the rotating frame. """ if self.frame_operator is None: - return unp.to_numeric_matrix_type(operator) + return unp.asarray(operator) else: # conjugate and add the frame diagonal return self._conjugate_and_add( diff --git a/test/dynamics/models/test_rotating_frame.py b/test/dynamics/models/test_rotating_frame.py index 83b5d79d6..53bf1f978 100644 --- a/test/dynamics/models/test_rotating_frame.py +++ b/test/dynamics/models/test_rotating_frame.py @@ -23,7 +23,6 @@ from qiskit.quantum_info.operators import Operator from scipy.sparse import csr_matrix from qiskit_dynamics.models.rotating_frame import RotatingFrame -from qiskit_dynamics.arraylias import ArrayLike from qiskit_dynamics.arraylias import DYNAMICS_NUMPY_ALIAS from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp from ..common import JAXTestBase, NumpyTestBase, test_array_backends @@ -100,7 +99,6 @@ def test_operator_into_frame_basis_sparse_list(self): ops = [csr_matrix([[0.0, 1.0], [1.0, 0.0]]), csr_matrix([[1.0, 0.0], [0.0, -1.0]])] rotating_frame = RotatingFrame(self.asarray([[0.0, 1.0], [1.0, 0.0]])) - val = rotating_frame.operator_into_frame_basis(ops) U = rotating_frame.frame_basis Uadj = rotating_frame.frame_basis_adjoint @@ -566,19 +564,19 @@ def test_state_transformations_no_frame_array_type(self): y = self.asarray([1.0, 1j]) out = rotating_frame.state_into_frame(t, y) self.assertAllClose(out, y) - self.assertTrue(isinstance(out, ArrayLike)) + self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library()) out = rotating_frame.state_out_of_frame(t, y) self.assertAllClose(out, y) - self.assertTrue(isinstance(out, ArrayLike)) + self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library()) t = 100.12498 y = self.asarray(np.eye(2)) out = rotating_frame.state_into_frame(t, y) self.assertAllClose(out, y) - self.assertTrue(isinstance(out, ArrayLike)) + self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library()) out = rotating_frame.state_out_of_frame(t, y) self.assertAllClose(out, y) - self.assertTrue(isinstance(out, ArrayLike)) + self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library()) @partial(test_array_backends, array_libraries=["jax"])