Skip to content

Commit

Permalink
Deprecate some unneeded methods of QuantumState (#4075)
Browse files Browse the repository at this point in the history
* Deprecate some unneeded methods of QuantumState

* Simplify QuantumState

* Remove rep and data from QuantumState so it only contains information about subsystem dimensions. This is to match recent changes with BaseOperator.
* Stops DensityMatrix and Statevector from always copying input array if it doesn't need to.

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
chriseclectic and mergify[bot] authored Apr 7, 2020
1 parent 5555835 commit b2a87e1
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 125 deletions.
52 changes: 25 additions & 27 deletions qiskit/quantum_info/states/densitymatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def __init__(self, data, dims=None):
# If no 'to_operator' attribute exists we next look for a
# 'to_matrix' attribute to a matrix that will be cast into
# a complex numpy matrix.
mat = np.array(data.to_matrix(), dtype=complex)
mat = np.asarray(data.to_matrix(), dtype=complex)
elif isinstance(data, (list, np.ndarray)):
# Finally we check if the input is a raw matrix in either a
# python list or numpy array format.
mat = np.array(data, dtype=complex)
mat = np.asarray(data, dtype=complex)
else:
raise QiskitError("Invalid input data format for DensityMatrix")
# Convert statevector into a density matrix
Expand All @@ -72,8 +72,25 @@ def __init__(self, data, dims=None):
if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
raise QiskitError(
"Invalid DensityMatrix input: not a square matrix.")
subsystem_dims = self._automatic_dims(dims, mat.shape[0])
super().__init__('DensityMatrix', mat, subsystem_dims)
self._data = mat
super().__init__(self._automatic_dims(dims, self._data.shape[0]))

def __eq__(self, other):
return super().__eq__(other) and np.allclose(
self._data, other._data, rtol=self.rtol, atol=self.atol)

def __repr__(self):
prefix = 'DensityMatrix('
pad = len(prefix) * ' '
return '{}{},\n{}dims={})'.format(
prefix, np.array2string(
self._data, separator=', ', prefix=prefix),
pad, self._dims)

@property
def data(self):
"""Return data."""
return self._data

def is_valid(self, atol=None, rtol=None):
"""Return True if trace 1 and positive semidefinite."""
Expand Down Expand Up @@ -147,7 +164,7 @@ def expand(self, other):
data = np.kron(other._data, self._data)
return DensityMatrix(data, dims)

def add(self, other):
def _add(self, other):
"""Return the linear combination self + other.
Args:
Expand All @@ -166,33 +183,14 @@ def add(self, other):
raise QiskitError("other DensityMatrix has different dimensions.")
return DensityMatrix(self.data + other.data, self.dims())

def subtract(self, other):
"""Return the linear operator self - other.
Args:
other (DensityMatrix): a quantum state object.
Returns:
DensityMatrix: the linear combination self - other.
Raises:
QiskitError: if other is not a quantum state, or has
incompatible dimensions.
"""
if not isinstance(other, DensityMatrix):
other = DensityMatrix(other)
if self.dim != other.dim:
raise QiskitError("other DensityMatrix has different dimensions.")
return DensityMatrix(self.data - other.data, self.dims())

def multiply(self, other):
"""Return the linear operator self * other.
def _multiply(self, other):
"""Return the scalar multiplied state other * self.
Args:
other (complex): a complex number.
Returns:
DensityMatrix: the linear combination other * self.
DensityMatrix: the scalar multiplied state other * self.
Raises:
QiskitError: if other is not a valid complex number.
Expand Down
105 changes: 59 additions & 46 deletions qiskit/quantum_info/states/quantum_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
Abstract QuantumState class.
"""

import copy
import warnings
from abc import ABC, abstractmethod

import numpy as np
Expand All @@ -33,14 +35,8 @@ class QuantumState(ABC):
_RTOL_DEFAULT = RTOL_DEFAULT
_MAX_TOL = 1e-4

def __init__(self, rep, data, dims):
def __init__(self, dims):
"""Initialize a state object."""
if not isinstance(rep, str):
raise QiskitError("rep must be a string not a {}".format(
rep.__class__))
self._rep = rep
self._data = data

# Dimension attributes
# Note that the tuples of input and output dims are ordered
# from least-significant to most-significant subsystems
Expand All @@ -52,24 +48,7 @@ def __init__(self, rep, data, dims):
self._rng = np.random.RandomState()

def __eq__(self, other):
if (isinstance(other, self.__class__)
and self.dims() == other.dims()):
return np.allclose(
self.data, other.data, rtol=self.rtol, atol=self.atol)
return False

def __repr__(self):
prefix = '{}('.format(self.rep)
pad = len(prefix) * ' '
return '{}{},\n{}dims={})'.format(
prefix, np.array2string(
self.data, separator=', ', prefix=prefix),
pad, self._dims)

@property
def rep(self):
"""Return state representation string."""
return self._rep
return isinstance(other, self.__class__) and self.dims() == other.dims()

@property
def dim(self):
Expand All @@ -81,11 +60,6 @@ def num_qubits(self):
"""Return the number of qubits if a N-qubit state or None otherwise."""
return self._num_qubits

@property
def data(self):
"""Return data."""
return self._data

@property
def atol(self):
"""The absolute tolerance parameter for float comparisons."""
Expand Down Expand Up @@ -148,9 +122,7 @@ def dims(self, qargs=None):

def copy(self):
"""Make a copy of current operator."""
# pylint: disable=no-value-for-parameter
# The constructor of subclasses from raw data should be a copy
return self.__class__(self.data, self.dims())
return copy.deepcopy(self)

def seed(self, value=None):
"""Set the seed for the quantum state RNG."""
Expand Down Expand Up @@ -211,10 +183,42 @@ def expand(self, other):
"""
pass

@abstractmethod
def _add(self, other):
"""Return the linear combination self + other.
Args:
other (QuantumState): a state object.
Returns:
QuantumState: the linear combination self + other.
Raises:
NotImplementedError: if subclass does not support addition.
"""
raise NotImplementedError(
"{} does not support addition".format(type(self)))

def _multiply(self, other):
"""Return the scalar multipled state other * self.
Args:
other (complex): a complex number.
Returns:
QuantumState: the scalar multipled state other * self.
Raises:
NotImplementedError: if subclass does not support scala
multiplication.
"""
raise NotImplementedError(
"{} does not support scalar multiplication".format(type(self)))

def add(self, other):
"""Return the linear combination self + other.
DEPRECATED: use ``state + other`` instead.
Args:
other (QuantumState): a quantum state object.
Expand All @@ -225,12 +229,16 @@ def add(self, other):
QiskitError: if other is not a quantum state, or has
incompatible dimensions.
"""
pass
warnings.warn("`{}.add` method is deprecated, use + binary operator"
"`state + other` instead.".format(self.__class__),
DeprecationWarning)
return self._add(other)

@abstractmethod
def subtract(self, other):
"""Return the linear operator self - other.
DEPRECATED: use ``state - other`` instead.
Args:
other (QuantumState): a quantum state object.
Expand All @@ -241,22 +249,27 @@ def subtract(self, other):
QiskitError: if other is not a quantum state, or has
incompatible dimensions.
"""
pass
warnings.warn("`{}.subtract` method is deprecated, use - binary operator"
"`state - other` instead.".format(self.__class__),
DeprecationWarning)
return self._add(-other)

@abstractmethod
def multiply(self, other):
"""Return the linear operator self * other.
"""Return the scalar multipled state other * self.
Args:
other (complex): a complex number.
Returns:
Operator: the linear combination other * self.
QuantumState: the scalar multipled state other * self.
Raises:
QiskitError: if other is not a valid complex number.
"""
pass
warnings.warn("`{}.multiply` method is deprecated, use * binary operator"
"`other * state` instead.".format(self.__class__),
DeprecationWarning)
return self._multiply(other)

@abstractmethod
def evolve(self, other, qargs=None):
Expand Down Expand Up @@ -628,19 +641,19 @@ def __xor__(self, other):
return self.tensor(other)

def __mul__(self, other):
return self.multiply(other)
return self._multiply(other)

def __truediv__(self, other):
return self.multiply(1 / other)
return self._multiply(1 / other)

def __rmul__(self, other):
return self.__mul__(other)

def __add__(self, other):
return self.add(other)
return self._add(other)

def __sub__(self, other):
return self.subtract(other)
return self._add(-other)

def __neg__(self):
return self.multiply(-1)
return self._multiply(-1)
67 changes: 32 additions & 35 deletions qiskit/quantum_info/states/statevector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,47 @@ class Statevector(QuantumState):
def __init__(self, data, dims=None):
"""Initialize a state object."""
if isinstance(data, Statevector):
# Shallow copy constructor
vec = data.data
self._data = data._data
if dims is None:
dims = data.dims()
dims = data._dims
elif isinstance(data, Operator):
# We allow conversion of column-vector operators to Statevectors
input_dim, output_dim = data.dim
if input_dim != 1:
raise QiskitError("Input Operator is not a column-vector.")
vec = np.reshape(data.data, output_dim)
self._data = np.ravel(data.data)
elif isinstance(data, (list, np.ndarray)):
# Finally we check if the input is a raw vector in either a
# python list or numpy array format.
vec = np.array(data, dtype=complex)
self._data = np.asarray(data, dtype=complex)
else:
raise QiskitError("Invalid input data format for Statevector")
# Check that the input is a numpy vector or column-vector numpy
# matrix. If it is a column-vector matrix reshape to a vector.
if vec.ndim not in [1, 2] or (vec.ndim == 2 and vec.shape[1] != 1):
ndim = self._data.ndim
shape = self._data.shape
if ndim not in [1, 2] or (ndim == 2 and shape[1] != 1):
raise QiskitError("Invalid input: not a vector or column-vector.")
if vec.ndim == 2 and vec.shape[1] == 1:
vec = np.reshape(vec, vec.shape[0])
dim = vec.shape[0]
subsystem_dims = self._automatic_dims(dims, dim)
super().__init__('Statevector', vec, subsystem_dims)
if ndim == 2 and shape[1] == 1:
self._data = np.reshape(self._data, shape[0])
super().__init__(self._automatic_dims(dims, shape[0]))

def __eq__(self, other):
return super().__eq__(other) and np.allclose(
self._data, other._data, rtol=self.rtol, atol=self.atol)

def __repr__(self):
prefix = 'Statevector('
pad = len(prefix) * ' '
return '{}{},\n{}dims={})'.format(
prefix, np.array2string(
self.data, separator=', ', prefix=prefix),
pad, self._dims)

@property
def data(self):
"""Return data."""
return self._data

def is_valid(self, atol=None, rtol=None):
"""Return True if a Statevector has norm 1."""
Expand Down Expand Up @@ -128,14 +144,14 @@ def expand(self, other):
data = np.kron(other._data, self._data)
return Statevector(data, dims)

def add(self, other):
def _add(self, other):
"""Return the linear combination self + other.
Args:
other (Statevector): a quantum state object.
Returns:
LinearOperator: the linear combination self + other.
Statevector: the linear combination self + other.
Raises:
QiskitError: if other is not a quantum state, or has
Expand All @@ -147,33 +163,14 @@ def add(self, other):
raise QiskitError("other Statevector has different dimensions.")
return Statevector(self.data + other.data, self.dims())

def subtract(self, other):
"""Return the linear operator self - other.
Args:
other (Statevector): a quantum state object.
Returns:
LinearOperator: the linear combination self - other.
Raises:
QiskitError: if other is not a quantum state, or has
incompatible dimensions.
"""
if not isinstance(other, Statevector):
other = Statevector(other)
if self.dim != other.dim:
raise QiskitError("other Statevector has different dimensions.")
return Statevector(self.data - other.data, self.dims())

def multiply(self, other):
"""Return the linear operator self * other.
def _multiply(self, other):
"""Return the scalar multiplied state self * other.
Args:
other (complex): a complex number.
Returns:
Operator: the linear combination other * self.
Statevector: the scalar multiplied state other * self.
Raises:
QiskitError: if other is not a valid complex number.
Expand Down
Loading

0 comments on commit b2a87e1

Please sign in to comment.