diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 585aab458bc..ee1bcfcd5a4 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -10,6 +10,31 @@
Improvements ðŸ›
+* `qml.QubitUnitary` now accepts sparse CSR matrices (from `scipy.sparse`). This allows efficient representation of large unitaries with mostly zero entries. Note that sparse unitaries are still in early development and may not support all features of their dense counterparts.
+ [(#6889)](https://github.com/PennyLaneAI/pennylane/pull/6889)
+
+ ```pycon
+ >>> import numpy as np
+ >>> import pennylane as qml
+ >>> import scipy as sp
+ >>> U_dense = np.eye(4) # 2-wire identity
+ >>> U_sparse = sp.sparse.csr_matrix(U_dense)
+ >>> op = qml.QubitUnitary(U_sparse, wires=[0, 1])
+ >>> print(op.matrix())
+
+ Coords Values
+ (0, 0) 1.0
+ (1, 1) 1.0
+ (2, 2) 1.0
+ (3, 3) 1.0
+ >>> op.matrix().toarray()
+ array([[1., 0., 0., 0.],
+ [0., 1., 0., 0.],
+ [0., 0., 1., 0.],
+ [0., 0., 0., 1.]])
+ ```
+
* Add a decomposition for multi-controlled global phases into a one-less-controlled phase shift.
[(#6936)](https://github.com/PennyLaneAI/pennylane/pull/6936)
@@ -86,7 +111,7 @@
>>> print(new_circuit.diff_method)
'parameter-shift'
```
-
+
* Devices can now configure whether or not ML framework data is sent to them
via an `ExecutionConfig.convert_to_numpy` parameter. End-to-end jitting on
`default.qubit` is used if the user specified a `jax.random.PRNGKey` as a seed.
diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py
index 6d9902d415f..0598f466fbb 100644
--- a/pennylane/math/single_dispatch.py
+++ b/pennylane/math/single_dispatch.py
@@ -19,8 +19,10 @@
# pylint: disable=wrong-import-order
import autoray as ar
import numpy as np
+import scipy as sp
from packaging.version import Version
from scipy.linalg import block_diag as _scipy_block_diag
+from scipy.sparse.linalg import splu
from .interface_utils import get_deep_interface
from .utils import is_abstract
@@ -68,6 +70,67 @@ def _builtins_shape(x):
ar.register_function("scipy", "ndim", np.ndim)
+# -------------------------------- SciPy Sparse --------------------------------- #
+# the following is required to ensure that general SciPy sparse matrices are
+# not automatically 'unwrapped' to dense NumPy arrays. Note that we assume
+# that whenever the backend is 'scipy', the input is a SciPy sparse matrix.
+
+
+def _det_sparse(x):
+ """Compute determinant of sparse matrices without densification"""
+
+ assert sp.sparse.issparse(x), TypeError(f"Expected SciPy sparse, got {type(x)}")
+
+ x = sp.sparse.csr_matrix(x)
+ if x.shape == (2, 2):
+ # Direct array access
+ indptr, indices, data = x.indptr, x.indices, x.data
+ values = {(i, j): 0.0 for i in range(2) for j in range(2)}
+ for i in range(2):
+ for j_idx in range(indptr[i], indptr[i + 1]):
+ j = indices[j_idx]
+ values[(i, j)] = data[j_idx]
+ return values[(0, 0)] * values[(1, 1)] - values[(0, 1)] * values[(1, 0)]
+ return _generic_sparse_det(x)
+
+
+def _generic_sparse_det(A):
+ """Compute the determinant of a sparse matrix using LU decomposition."""
+
+ assert hasattr(A, "tocsc"), TypeError(f"Expected SciPy sparse, got {type(A)}")
+
+ A_csc = A.tocsc()
+ lu = splu(A_csc)
+ U_diag = lu.U.diagonal()
+ det_A = np.prod(U_diag)
+ parity = _permutation_parity(lu.perm_r)
+ return det_A * parity
+
+
+def _permutation_parity(perm):
+ """Compute the parity of a permutation."""
+
+ parity = 1
+ visited = [False] * len(perm)
+ for i in range(len(perm)):
+ if not visited[i]:
+ cycle_length = 0
+ j = i
+ while not visited[j]:
+ visited[j] = True
+ j = perm[j]
+ cycle_length += 1
+
+ if cycle_length:
+
+ parity *= (-1) ** (cycle_length - 1)
+ return parity
+
+
+ar.register_function("scipy", "linalg.det", _det_sparse)
+ar.register_function("scipy", "linalg.eigs", sp.sparse.linalg.eigs)
+ar.register_function("scipy", "trace", lambda x: x.trace())
+
# -------------------------------- NumPy --------------------------------- #
ar.register_function("numpy", "flatten", lambda x: x.flatten())
diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py
index 823c561bc1a..0c333d58dfb 100644
--- a/pennylane/math/utils.py
+++ b/pennylane/math/utils.py
@@ -16,6 +16,7 @@
# pylint: disable=wrong-import-order
import autoray as ar
import numpy as _np
+import scipy as sp
# pylint: disable=import-outside-toplevel
from autograd.numpy.numpy_boxes import ArrayBox
@@ -174,6 +175,9 @@ def convert_like(tensor1, tensor2):
dev = tensor2.device
return np.asarray(tensor1, device=dev, like=interface)
+ if interface == "scipy":
+ return sp.sparse.csr_matrix(tensor1)
+
return np.asarray(tensor1, like=interface)
diff --git a/pennylane/ops/op_math/decompositions/single_qubit_unitary.py b/pennylane/ops/op_math/decompositions/single_qubit_unitary.py
index eab40d549c7..d8ef8be9071 100644
--- a/pennylane/ops/op_math/decompositions/single_qubit_unitary.py
+++ b/pennylane/ops/op_math/decompositions/single_qubit_unitary.py
@@ -14,8 +14,10 @@
"""Contains transforms and helpers functions for decomposing arbitrary unitary
operations into elementary gates.
"""
+from functools import singledispatch
import numpy as np
+import scipy as sp
import pennylane as qml
from pennylane import math
@@ -42,11 +44,17 @@ def _convert_to_su2(U, return_global_phase=False):
with np.errstate(divide="ignore", invalid="ignore"):
determinants = math.linalg.det(U)
phase = math.angle(determinants) / 2
- U = math.cast_like(U, determinants) * math.exp(-1j * math.cast_like(phase, 1j))[:, None, None]
+ U = (
+ U * math.exp(-1j * phase)
+ if sp.sparse.issparse(U)
+ else math.cast_like(U, determinants)
+ * math.exp(-1j * math.cast_like(phase, 1j))[:, None, None]
+ )
return (U, phase) if return_global_phase else U
+@singledispatch
def _zyz_get_rotation_angles(U):
r"""Computes the rotation angles :math:`\phi`, :math:`\theta`, :math:`\omega`
for a unitary :math:`U` that is :math:`SU(2)`
@@ -91,6 +99,49 @@ def _zyz_get_rotation_angles(U):
return phis, thetas, omegas
+@_zyz_get_rotation_angles.register(sp.sparse.csr_matrix)
+def _zyz_get_rotation_angles_sparse(U):
+ r"""Computes the rotation angles :math:`\phi`, :math:`\theta`, :math:`\omega`
+ for a unitary :math:`U` that is :math:`SU(2)`, sparse case
+
+ Args:
+ U (array[complex]): A matrix that is :math:`SU(2)`
+
+ Returns:
+ tuple[array[float]]: A tuple containing the rotation angles
+ :math:`\phi`, :math:`\theta`, :math:`\omega`
+
+ """
+
+ assert sp.sparse.issparse(U), "Do not use this method if U is not sparse"
+
+ u00 = U[0, 0]
+ u01 = U[0, 1]
+ u10 = U[1, 0]
+
+ # For batched U or single U with non-zero off-diagonal, compute the
+ # normal decomposition instead
+ off_diagonal_elements = math.clip(math.abs(u01), 0, 1)
+ thetas = 2 * math.arcsin(off_diagonal_elements)
+
+ # Compute phi and omega from the angles of the top row; use atan2 to keep
+ # the angle within -np.pi and np.pi
+ angles_U00 = math.arctan2(math.imag(u00), math.real(u00))
+ angles_U10 = math.arctan2(math.imag(u10), math.real(u10))
+
+ phis = -angles_U10 - angles_U00
+ omegas = angles_U10 - angles_U00
+
+ phis, thetas, omegas = map(math.squeeze, [phis, thetas, omegas])
+
+ # Normalize the angles
+ phis = phis % (4 * np.pi)
+ thetas = thetas % (4 * np.pi)
+ omegas = omegas % (4 * np.pi)
+
+ return phis, thetas, omegas
+
+
def _rot_decomposition(U, wire, return_global_phase=False):
r"""Compute the decomposition of a single-qubit matrix :math:`U` in terms of
elementary operations, as a single :class:`.RZ` gate or a :class:`.Rot` gate.
@@ -167,7 +218,7 @@ def _get_single_qubit_rot_angles_via_matrix(
of the matrix of the target operation using ZYZ rotations.
"""
# Cast to batched format for more consistent code
- U = math.expand_dims(U, axis=0) if len(U.shape) == 2 else U
+ U = math.expand_dims(U, axis=0) if len(U.shape) == 2 and not sp.sparse.issparse(U) else U
# Convert to SU(2) format and extract global phase
U_su2, global_phase = _convert_to_su2(U, return_global_phase=True)
diff --git a/pennylane/ops/op_math/decompositions/two_qubit_unitary.py b/pennylane/ops/op_math/decompositions/two_qubit_unitary.py
index de223d5f3ef..666a592bad0 100644
--- a/pennylane/ops/op_math/decompositions/two_qubit_unitary.py
+++ b/pennylane/ops/op_math/decompositions/two_qubit_unitary.py
@@ -17,6 +17,7 @@
import warnings
import numpy as np
+import scipy as sp
import pennylane as qml
from pennylane import math
@@ -102,6 +103,9 @@ def _check_differentiability_warning(U):
)
+global_arrays_name = ["E", "Edag", "CNOT01", "CNOT10", "SWAP", "S_SX", "v_one_cnot", "q_one_cnot"]
+
+
def _convert_to_su4(U):
r"""Convert a 4x4 matrix to :math:`SU(4)`.
@@ -621,6 +625,11 @@ def two_qubit_decomposition(U, wires):
_check_differentiability_warning(U)
# First, we note that this method works only for SU(4) gates, meaning that
# we need to rescale the matrix by its determinant.
+ if sp.sparse.issparse(U):
+ # Convert all the global elements to sparse matrices in-place
+ for name in global_arrays_name:
+ array = globals()[name]
+ globals()[name] = sp.sparse.csr_matrix(array)
U = _convert_to_su4(U)
# The next thing we will do is compute the number of CNOTs needed, as this affects
diff --git a/pennylane/ops/qubit/matrix_ops.py b/pennylane/ops/qubit/matrix_ops.py
index 2ca5bd9d50a..0a79dad6dbb 100644
--- a/pennylane/ops/qubit/matrix_ops.py
+++ b/pennylane/ops/qubit/matrix_ops.py
@@ -21,7 +21,9 @@
from typing import Optional, Union
import numpy as np
+import scipy as sp
from scipy.linalg import fractional_matrix_power
+from scipy.sparse import csr_matrix
import pennylane as qml
from pennylane import numpy as pnp
@@ -78,6 +80,10 @@ class QubitUnitary(Operation):
r"""QubitUnitary(U, wires)
Apply an arbitrary unitary matrix with a dimension that is a power of two.
+ .. warning::
+
+ The sparse matrix representation of QubitUnitary is still under development. Currently we only support a limited set of interfaces that preserve the sparsity of the matrix, including ..method::`adjoint`, ..method::`pow`, ..method::`compute_sparse_matrix` and ..method::`compute_decomposition`. Differentiability is not supported for sparse matrices.
+
**Details:**
* Number of wires: Any (the operation can act on any number of wires)
@@ -86,7 +92,7 @@ class QubitUnitary(Operation):
* Gradient recipe: None
Args:
- U (array[complex]): square unitary matrix
+ U (array[complex] or csr_matrix): square unitary matrix
wires (Sequence[int] or int): the wire(s) the operation acts on
id (str): custom label given to an operator instance,
can be useful for some applications where the instance has to be identified
@@ -121,7 +127,7 @@ class QubitUnitary(Operation):
def __init__(
self,
- U: TensorLike,
+ U: Union[TensorLike, csr_matrix],
wires: WiresLike,
id: Optional[str] = None,
unitary_check: bool = False,
@@ -135,19 +141,16 @@ def __init__(
if len(U_shape) not in {2, 3} or U_shape[-2:] != (dim, dim):
raise ValueError(
f"Input unitary must be of shape {(dim, dim)} or (batch_size, {dim}, {dim}) "
- f"to act on {len(wires)} wires."
+ f"to act on {len(wires)} wires. Got shape {U_shape} instead."
)
+ # If the matrix is sparse, we need to convert it to a csr_matrix
+ if sp.sparse.issparse(U):
+ U = U.tocsr()
+
# Check for unitarity; due to variable precision across the different ML frameworks,
# here we issue a warning to check the operation, instead of raising an error outright.
- if unitary_check and not (
- qml.math.is_abstract(U)
- or qml.math.allclose(
- qml.math.einsum("...ij,...kj->...ik", U, qml.math.conj(U)),
- qml.math.eye(dim),
- atol=1e-6,
- )
- ):
+ if unitary_check and not self._unitary_check(U, dim):
warnings.warn(
f"Operator {U}\n may not be unitary. "
"Verify unitarity of operation, or use a datatype with increased precision.",
@@ -156,6 +159,18 @@ def __init__(
super().__init__(U, wires=wires, id=id)
+ @staticmethod
+ def _unitary_check(U, dim):
+ if isinstance(U, csr_matrix):
+ U_dagger = U.conjugate().transpose()
+ identity = sp.sparse.eye(dim, format="csr")
+ return sp.sparse.linalg.norm(U @ U_dagger - identity) < 1e-10
+ return qml.math.allclose(
+ qml.math.einsum("...ij,...kj->...ik", U, qml.math.conj(U)),
+ qml.math.eye(dim),
+ atol=1e-6,
+ )
+
@staticmethod
def compute_matrix(U: TensorLike): # pylint: disable=arguments-differ
r"""Representation of the operator as a canonical matrix in the computational basis (static method).
@@ -180,6 +195,26 @@ def compute_matrix(U: TensorLike): # pylint: disable=arguments-differ
"""
return U
+ @staticmethod
+ def compute_sparse_matrix(U: TensorLike): # pylint: disable=arguments-differ
+ r"""Representation of the operator as a sparse matrix.
+
+ Args:
+ U (tensor_like): unitary matrix
+
+ Returns:
+ csr_matrix: sparse matrix representation
+
+ **Example**
+
+ >>> U = np.array([[0.98877108+0.j, 0.-0.14943813j], [0.-0.14943813j, 0.98877108+0.j]])
+ >>> qml.QubitUnitary.compute_sparse_matrix(U)
+ <2x2 sparse matrix of type ''
+ with 2 stored elements in Compressed Sparse Row format>
+ """
+ U = qml.math.asarray(U, like="numpy")
+ return sp.sparse.csr_matrix(U)
+
@staticmethod
def compute_decomposition(U: TensorLike, wires: WiresLike):
r"""Representation of the operator as a product of other operators (static method).
@@ -236,10 +271,17 @@ def has_decomposition(self) -> bool:
def adjoint(self) -> "QubitUnitary":
U = self.matrix()
+ if isinstance(U, csr_matrix):
+ adjoint_sp_mat = U.conjugate().transpose()
+ # Note: it is necessary to explicitly cast back to csr, or it will be come csc
+ return QubitUnitary(csr_matrix(adjoint_sp_mat), wires=self.wires)
return QubitUnitary(qml.math.moveaxis(qml.math.conj(U), -2, -1), wires=self.wires)
def pow(self, z: Union[int, float]):
mat = self.matrix()
+ if isinstance(mat, csr_matrix):
+ pow_mat = sp.sparse.linalg.matrix_power(mat, z)
+ return [QubitUnitary(pow_mat, wires=self.wires)]
if isinstance(z, int) and qml.math.get_deep_interface(mat) != "tensorflow":
pow_mat = qml.math.linalg.matrix_power(mat, z)
elif self.batch_size is not None or qml.math.shape(z) != ():
diff --git a/pennylane/ops/qubit/parametric_ops_single_qubit.py b/pennylane/ops/qubit/parametric_ops_single_qubit.py
index c0c8c168147..456a8c480d5 100644
--- a/pennylane/ops/qubit/parametric_ops_single_qubit.py
+++ b/pennylane/ops/qubit/parametric_ops_single_qubit.py
@@ -21,6 +21,7 @@
from typing import Optional, Union
import numpy as np
+import scipy as sp
import pennylane as qml
from pennylane.operation import Operation
@@ -113,6 +114,15 @@ def compute_matrix(theta: TensorLike) -> TensorLike: # pylint: disable=argument
js = -1j * s
return qml.math.stack([stack_last([c, js]), stack_last([js, c])], axis=-2)
+ @staticmethod
+ def compute_sparse_matrix(theta):
+ return sp.sparse.csr_matrix(
+ [
+ [np.cos(theta / 2), -1j * np.sin(theta / 2)],
+ [-1j * np.sin(theta / 2), np.cos(theta / 2)],
+ ]
+ )
+
def adjoint(self) -> "RX":
return RX(-self.data[0], wires=self.wires)
@@ -209,6 +219,12 @@ def compute_matrix(theta: TensorLike) -> TensorLike: # pylint: disable=argument
s = (1 + 0j) * s
return qml.math.stack([stack_last([c, -s]), stack_last([s, c])], axis=-2)
+ @staticmethod
+ def compute_sparse_matrix(theta):
+ return sp.sparse.csr_matrix(
+ [[np.cos(theta / 2), -np.sin(theta / 2)], [np.sin(theta / 2), np.cos(theta / 2)]]
+ )
+
def adjoint(self) -> "RY":
return RY(-self.data[0], wires=self.wires)
@@ -307,6 +323,10 @@ def compute_matrix(theta: TensorLike) -> TensorLike: # pylint: disable=argument
diags = qml.math.exp(qml.math.outer(arg, signs))
return diags[:, :, np.newaxis] * qml.math.cast_like(qml.math.eye(2, like=diags), diags)
+ @staticmethod
+ def compute_sparse_matrix(theta):
+ return sp.sparse.csr_matrix([[np.exp(-1j * theta / 2), 0], [0, np.exp(1j * theta / 2)]])
+
@staticmethod
def compute_eigvals(theta: TensorLike) -> TensorLike: # pylint: disable=arguments-differ
r"""Eigenvalues of the operator in the computational basis (static method).
diff --git a/pennylane/ops/qubit/state_preparation.py b/pennylane/ops/qubit/state_preparation.py
index 342267d7927..ee6b57fe20a 100644
--- a/pennylane/ops/qubit/state_preparation.py
+++ b/pennylane/ops/qubit/state_preparation.py
@@ -326,7 +326,8 @@ def __init__(
validate_norm: bool = True,
):
self.is_sparse = False
- if isinstance(state, csr_matrix):
+ if sp.sparse.issparse(state):
+ state = state.tocsr()
state = self._preprocess_csr(state, wires, None, normalize, validate_norm)
self.is_sparse = True
else:
diff --git a/tests/ops/qubit/test_matrix_ops.py b/tests/ops/qubit/test_matrix_ops.py
index 0348bf7340b..dcaa0a4864c 100644
--- a/tests/ops/qubit/test_matrix_ops.py
+++ b/tests/ops/qubit/test_matrix_ops.py
@@ -20,6 +20,7 @@
import numpy as np
import pytest
from gate_data import H, I, S, T, X, Z
+from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
import pennylane as qml
from pennylane import numpy as pnp
@@ -28,6 +29,137 @@
from pennylane.wires import Wires
+class TestQubitUnitaryCSR:
+ """Tests for using csr_matrix in QubitUnitary."""
+
+ def test_compute_sparse_matrix(self):
+ """Test that the compute_sparse_matrix method works correctly."""
+ U = np.array([[0, 1], [1, 0]])
+ op = qml.QubitUnitary.compute_sparse_matrix(U)
+ assert isinstance(op, csr_matrix)
+ assert np.allclose(op.toarray(), U)
+
+ def test_generic_sparse_convert_to_csr(self):
+ """Test that other generic sparse matrices can be converted to csr_matrix."""
+ # 4x4 Identity as a csr_matrix
+ dense = np.eye(4)
+ sparse = coo_matrix(dense)
+ op = qml.QubitUnitary(sparse, wires=[0, 1])
+ assert isinstance(op.matrix(), csr_matrix)
+ sparse = csc_matrix(dense)
+ op = qml.QubitUnitary(sparse, wires=[0, 1])
+ assert isinstance(op.matrix(), csr_matrix)
+
+ @pytest.mark.parametrize(
+ "dense",
+ [H, I, S, T, X, Z],
+ )
+ def test_csr_matrix_init_success(self, dense):
+ """Test that a valid 2-wire csr_matrix can be instantiated, covering necessary single-qubit gates."""
+ # 4x4 Identity as a csr_matrix
+ sparse = csr_matrix(dense)
+ op = qml.QubitUnitary(sparse, wires=[0])
+ assert isinstance(op.matrix(), csr_matrix) # Should still be sparse
+ assert qml.math.allclose(op.matrix().toarray(), dense)
+
+ def test_csr_matrix_shape_mismatch(self):
+ """Test that shape mismatch with csr_matrix raises an error."""
+ dense = np.eye(2) # Only 2x2
+ sparse = csr_matrix(dense)
+ with pytest.raises(ValueError, match="Input unitary must be of shape"):
+ qml.QubitUnitary(sparse, wires=[0, 1])
+
+ def test_csr_matrix_unitary_check_fail(self):
+ """Test that unitary_check warns if the matrix may not be unitary."""
+ dense = np.array([[1, 0], [0, 0.5]]) # Not a unitary
+ sparse = csr_matrix(dense)
+ with pytest.warns(UserWarning, match="may not be unitary"):
+ qml.QubitUnitary(sparse, wires=0, unitary_check=True)
+
+ def test_csr_matrix_pow_integer(self):
+ """Test that QubitUnitary.pow() works for integer exponents with csr_matrix."""
+ dense = np.eye(4)
+ sparse = csr_matrix(dense)
+ op = qml.QubitUnitary(sparse, wires=[0, 1])
+ powered_ops = op.pow(2)
+ assert len(powered_ops) == 1
+
+ powered_op = powered_ops[0]
+ assert isinstance(powered_op, qml.QubitUnitary)
+ assert isinstance(powered_op.matrix(), csr_matrix)
+ # The resulting matrix should still be the identity
+ final_mat = powered_op.matrix()
+ # If it's still sparse, compare .toarray()
+ if isinstance(final_mat, csr_matrix):
+ final_mat = final_mat.toarray()
+ assert qml.math.allclose(final_mat, dense)
+
+ @pytest.mark.parametrize(
+ "dense",
+ [
+ np.eye(4),
+ np.array(
+ [[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]]
+ ), # sample permutation matrix
+ ],
+ )
+ def test_csr_matrix_adjoint(self, dense):
+ """Test that QubitUnitary.adjoint() works with csr_matrix, matching dense result."""
+ sparse = csr_matrix(dense)
+ op = qml.QubitUnitary(sparse, wires=[0, 1])
+ adj_op = op.adjoint()
+
+ assert isinstance(adj_op, qml.QubitUnitary)
+ assert isinstance(adj_op.matrix(), csr_matrix)
+
+ final_mat = adj_op.matrix()
+ # Compare with dense representation if still sparse
+ if isinstance(final_mat, csr_matrix):
+ final_mat = final_mat.toarray()
+
+ # For real/complex conjugate transpose, if dense is unitary, final_mat == dense^\dagger
+ expected = dense.conjugate().T
+ assert qml.math.allclose(final_mat, expected)
+
+ def test_csr_matrix_adjoint_large(self):
+ """Construct a large sparse matrix (e.g. 2^20 dimension) but only store minimal elements."""
+ N = 20
+ dim = 2**N
+
+ # For demonstration, let's just store a single 1 on the diagonal
+ row_indices = [12345] # some arbitrary index < dim
+ col_indices = [12345]
+ data = [1.0]
+
+ sparse_large = csr_matrix((data, (row_indices, col_indices)), shape=(dim, dim))
+ with pytest.warns(UserWarning, match="may not be unitary"):
+ op = qml.QubitUnitary(sparse_large, wires=range(N), unitary_check=True)
+ adj_op = op.adjoint()
+
+ assert isinstance(adj_op, qml.QubitUnitary)
+ assert isinstance(adj_op.matrix(), csr_matrix)
+
+ # The single element should remain 1 at [12345,12345] after conjugate transpose
+ final_mat = adj_op.matrix()
+ assert final_mat[12345, 12345] == 1.0
+
+ def test_csr_matrix_decomposition(self):
+ """Test that QubitUnitary.decomposition() works with csr_matrix."""
+ # 4x4 Identity as a csr_matrix
+ dense = np.eye(4)
+ sparse = csr_matrix(dense)
+ op = qml.QubitUnitary(sparse, wires=[0, 1])
+ decomp = op.decomposition()
+ assert len(decomp) == 6
+
+ # 2x2 Identity as a csr_matrix
+ dense = np.eye(2)
+ sparse = csr_matrix(dense)
+ op = qml.QubitUnitary(sparse, wires=[0])
+ decomp = op.decomposition()
+ assert len(decomp) == 3
+
+
class TestQubitUnitary:
"""Tests for the QubitUnitary class."""
diff --git a/tests/ops/qubit/test_state_prep.py b/tests/ops/qubit/test_state_prep.py
index 1c6975e644d..49b299d06bc 100644
--- a/tests/ops/qubit/test_state_prep.py
+++ b/tests/ops/qubit/test_state_prep.py
@@ -413,6 +413,13 @@ def test_BasisState_wrong_param_size(self):
class TestSparseStateVector:
"""Test the sparse_state_vector() method of various state-prep operations."""
+ def test_sparse_state_convert_to_csr(self):
+ """Test that the sparse_state_vector() method returns a csr_matrix."""
+ sp_vec = sp.sparse.coo_matrix([0, 0, 1, 0])
+ qsv_op = qml.StatePrep(sp_vec, wires=[0, 1])
+ ket = qsv_op.state_vector()
+ assert isinstance(ket, sp.sparse.csr_matrix)
+
@pytest.mark.parametrize(
"num_wires,wire_order,one_position",
[