Skip to content

Commit

Permalink
replace to_numeric_matrix_type with asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
to24toro committed Nov 17, 2023
1 parent 1e391cb commit ef5e1c7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
6 changes: 5 additions & 1 deletion qiskit_dynamics/arraylias/register_functions/asarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
import numpy as np
from scipy.sparse import csr_matrix, issparse

from qiskit_dynamics.type_utils import isinstance_qutip_qobj


def register_asarray(alias):
"""register asarray functions to each array libraries"""

@alias.register_default(path="asarray")
def _(arr):
if isinstance_qutip_qobj(arr):
return csr_matrix(arr)
return np.asarray(arr)

@alias.register_fallback(path="asarray")
Expand All @@ -33,7 +37,7 @@ def _(arr):

@alias.register_function(lib="scipy_sparse", path="asarray")
def _(arr):
if issparse(arr):
if issparse(arr) or issparse(arr[0]):
return arr
return csr_matrix(arr)

Expand Down
19 changes: 10 additions & 9 deletions qiskit_dynamics/models/rotating_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def state_into_frame_basis(self, y: ArrayLike) -> ArrayLike:
Returns:
ArrayLike: The state in the frame basis.
"""
y = unp.to_numeric_matrix_type(y)
y = unp.asarray(y)
if self.frame_basis_adjoint is None:
return y

Expand All @@ -158,7 +158,7 @@ def state_out_of_frame_basis(self, y: ArrayLike) -> ArrayLike:
Returns:
Array: The state in the frame basis.
"""
y = unp.to_numeric_matrix_type(y)
y = unp.asarray(y)
if self.frame_basis is None:
return y

Expand All @@ -182,7 +182,7 @@ def operator_into_frame_basis(
Array: The operator in the frame basis.
"""
if convert_type:
op = unp.to_numeric_matrix_type(op)
op = unp.asarray(op)

if self.frame_basis is None or op is None:
return op
Expand Down Expand Up @@ -210,7 +210,7 @@ def operator_out_of_frame_basis(
ArrayLike: The operator in the frame basis.
"""
if convert_type:
op = unp.to_numeric_matrix_type(op)
op = unp.asarray(op)

if self.frame_basis is None or op is None:
return op
Expand Down Expand Up @@ -239,7 +239,7 @@ def state_into_frame(
Returns:
ArrayLike: The state in the rotating frame.
"""
y = unp.to_numeric_matrix_type(y)
y = unp.asarray(y)
if self._frame_operator is None:
return y

Expand Down Expand Up @@ -314,8 +314,9 @@ def _conjugate_and_add(
Returns:
Array of the newly conjugated operator.
"""
operator = unp.to_numeric_matrix_type(operator)
op_to_add_in_fb = unp.to_numeric_matrix_type(op_to_add_in_fb)
operator = unp.asarray(operator)
if op_to_add_in_fb is not None:
op_to_add_in_fb = unp.asarray(op_to_add_in_fb)
if vectorized_operators:
# If passing vectorized operator, undo vectorization temporarily
if self._frame_operator is None:
Expand Down Expand Up @@ -460,7 +461,7 @@ def generator_into_frame(
ArrayLike: The generator in the rotating frame.
"""
if self.frame_operator is None:
return unp.to_numeric_matrix_type(operator)
return unp.asarray(operator)
else:
# conjugate and subtract the frame diagonal
return self._conjugate_and_add(
Expand Down Expand Up @@ -495,7 +496,7 @@ def generator_out_of_frame(
ArrayLike: The generator out of the rotating frame.
"""
if self.frame_operator is None:
return unp.to_numeric_matrix_type(operator)
return unp.asarray(operator)
else:
# conjugate and add the frame diagonal
return self._conjugate_and_add(
Expand Down
10 changes: 4 additions & 6 deletions test/dynamics/models/test_rotating_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from qiskit.quantum_info.operators import Operator
from scipy.sparse import csr_matrix
from qiskit_dynamics.models.rotating_frame import RotatingFrame
from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY_ALIAS
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY as unp
from ..common import JAXTestBase, NumpyTestBase, test_array_backends
Expand Down Expand Up @@ -100,7 +99,6 @@ def test_operator_into_frame_basis_sparse_list(self):

ops = [csr_matrix([[0.0, 1.0], [1.0, 0.0]]), csr_matrix([[1.0, 0.0], [0.0, -1.0]])]
rotating_frame = RotatingFrame(self.asarray([[0.0, 1.0], [1.0, 0.0]]))

val = rotating_frame.operator_into_frame_basis(ops)
U = rotating_frame.frame_basis
Uadj = rotating_frame.frame_basis_adjoint
Expand Down Expand Up @@ -566,19 +564,19 @@ def test_state_transformations_no_frame_array_type(self):
y = self.asarray([1.0, 1j])
out = rotating_frame.state_into_frame(t, y)
self.assertAllClose(out, y)
self.assertTrue(isinstance(out, ArrayLike))
self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library())
out = rotating_frame.state_out_of_frame(t, y)
self.assertAllClose(out, y)
self.assertTrue(isinstance(out, ArrayLike))
self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library())

t = 100.12498
y = self.asarray(np.eye(2))
out = rotating_frame.state_into_frame(t, y)
self.assertAllClose(out, y)
self.assertTrue(isinstance(out, ArrayLike))
self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library())
out = rotating_frame.state_out_of_frame(t, y)
self.assertAllClose(out, y)
self.assertTrue(isinstance(out, ArrayLike))
self.assertEqual(DYNAMICS_NUMPY_ALIAS.infer_libs(out)[0], self.array_library())


@partial(test_array_backends, array_libraries=["jax"])
Expand Down

0 comments on commit ef5e1c7

Please sign in to comment.