Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding decomposition MPSPrep #6896

Open
wants to merge 30 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3527063
decomposition MPS
KetpuntoG Jan 28, 2025
67f04b6
Merge branch 'master' into MPSPrep-decomposition
KetpuntoG Jan 28, 2025
f1ec861
extra tests
KetpuntoG Jan 29, 2025
57b0c31
Merge branch 'MPSPrep-decomposition' of https://github.com/PennyLaneA…
KetpuntoG Jan 29, 2025
0d1ab8d
change log
KetpuntoG Jan 29, 2025
9a55803
Merge branch 'master' into MPSPrep-decomposition
KetpuntoG Jan 29, 2025
ac71564
black
KetpuntoG Jan 29, 2025
3259d71
Merge branch 'MPSPrep-decomposition' of https://github.com/PennyLaneA…
KetpuntoG Jan 29, 2025
c440144
local random
KetpuntoG Jan 29, 2025
ac0a31e
Merge branch 'master' into MPSPrep-decomposition
KetpuntoG Jan 29, 2025
625758f
Update state_prep_mps.py
KetpuntoG Jan 29, 2025
298cec9
comments Utkarsh
KetpuntoG Jan 31, 2025
4a1b932
QR decomposition
KetpuntoG Jan 31, 2025
39af831
Update pennylane/templates/state_preparations/state_prep_mps.py
KetpuntoG Feb 3, 2025
8ca8615
Update pennylane/templates/state_preparations/state_prep_mps.py
KetpuntoG Feb 3, 2025
10bbbaf
adding test that now fails
KetpuntoG Feb 5, 2025
2b2a60a
Update test_state_prep_mps.py
KetpuntoG Feb 7, 2025
98c33d8
review comments
KetpuntoG Feb 7, 2025
03e9ba1
Merge branch 'master' into MPSPrep-decomposition
KetpuntoG Feb 7, 2025
df0c52e
testing max_bond_dimension error
KetpuntoG Feb 7, 2025
39ceac9
Update test_state_prep_mps.py
KetpuntoG Feb 7, 2025
4d7781c
adding conversion to right canonical
KetpuntoG Feb 10, 2025
760d4d6
Merge branch 'master' into MPSPrep-decomposition
KetpuntoG Feb 10, 2025
b3cd75b
Austin suggestions
KetpuntoG Feb 10, 2025
e0e475e
adding right_canonicalize_mps function
KetpuntoG Feb 11, 2025
b09a651
Update state_prep_mps.py
KetpuntoG Feb 11, 2025
d929274
Merge branch 'master' into MPSPrep-decomposition
KetpuntoG Feb 11, 2025
f751f5f
Update doc/releases/changelog-dev.md
KetpuntoG Feb 11, 2025
d6c73a1
removing numpy
KetpuntoG Feb 11, 2025
6a34b4a
Merge branch 'MPSPrep-decomposition' of https://github.com/PennyLaneA…
KetpuntoG Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@
* The requested `diff_method` is now validated when program capture is enabled.
[(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852)


* The template `MPSPrep` now has a gate decomposition. This enables its use with any device.
[(#6896)](https://github.com/PennyLaneAI/pennylane/pull/6896)

* The `qml.clifford_t_decomposition` has been improved to use less gates when decomposing `qml.PhaseShift`.
[(#6842)](https://github.com/PennyLaneAI/pennylane/pull/6842)

Expand Down Expand Up @@ -318,6 +322,8 @@

This release contains contributions from (in alphabetical order):


Guillermo Alonso,
Utkarsh Azad,
Yushao Chen,
Isaac De Vlugt,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/templates/state_preparations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
from .cosine_window import CosineWindow
from .mottonen import MottonenStatePreparation
from .superposition import Superposition
from .state_prep_mps import MPSPrep
from .state_prep_mps import MPSPrep, right_canonicalize_mps
199 changes: 191 additions & 8 deletions pennylane/templates/state_preparations/state_prep_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,89 @@
Contains the MPSPrep template.
"""

import numpy as np

import pennylane as qml
from pennylane.operation import Operation
from pennylane.wires import Wires


def right_canonicalize_mps(mps):
"""
Transform an MPS into a right-canonical MPS.

Args:
mps (list[Array]): List of tensors representing the MPS.

Returns:
A list of tensors representing the MPS in right-canonical form.
"""

L = len(mps)
output_mps = [None] * L

is_right_canonical = True
for i in range(1, L - 1):
tensor = mps[i]
# Right-canonical definition
M = qml.math.tensordot(tensor, tensor.conj(), axes=([1, 2], [1, 2]))
if not qml.math.allclose(M, qml.math.eye(tensor.shape[0])):
is_right_canonical = False
break

if is_right_canonical:
return mps

max_bond_dim = 0
for tensor in mps[1:-1]:
D_left = tensor.shape[0]
D_right = tensor.shape[2]
max_bond_dim = max(max_bond_dim, D_left, D_right)

# Procedure analogous to the left-canonical conversion but starting from the right and storing the Vd
for i in range(L - 1, 0, -1):
chi_left, d, chi_right = mps[i].shape
M = mps[i].reshape(chi_left, d * chi_right)
U, S, Vd = qml.math.linalg.svd(M, full_matrices=False)

# Truncate SVD components if needed
chi_new = min(int(max_bond_dim), len(S))
U = U[:, :chi_new]
S = S[:chi_new]
Vd = Vd[:chi_new, :]

output_mps[i] = Vd.reshape(chi_new, d, chi_right)

US = U @ qml.math.diag(S)
mps[i - 1] = qml.math.tensordot(mps[i - 1], US, axes=([2], [0]))
Comment on lines +58 to +72
Copy link
Contributor

@austingmhuang austingmhuang Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there some kind of reference that explains what M U S etc.. refer to?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

U and S are part of the svd decomposition.
And M is an auxiliary matrix that we use in the process.
I can't find a really useful reference 😕


output_mps[0] = mps[0]
return output_mps


class MPSPrep(Operation):
r"""Prepares an initial state from a matrix product state (MPS) representation.

.. note::

Currently, this operator can only be used with ``qml.device(“lightning.tensor”)``.
This operator is natively supported on the ``lightning.tensor`` device, designed to run MPS structures
efficiently. For other devices, implementing this operation uses a gate-based decomposition which requires
auxiliary qubits (via ``work_wires``) to prepare the state vector represented by the MPS in a quantum circuit.



Args:
mps (List[Array]): list of arrays of rank-3 and rank-2 tensors representing an MPS state as a
product of site matrices. See the usage details section for more information.
mps (list[Array]): list of arrays of rank-3 and rank-2 tensors representing a right-canonized MPS state
as a product of site matrices. See the usage details section for more information.

wires (Sequence[int]): wires that the template acts on
work_wires (Sequence[int]): list of extra qubits needed in the decomposition. The maximum permissible bond
dimension of the provided MPS is defined as ``2^len(work_wires)``. Default is ``None``.


The decomposition follows Eq. (23) in `[arXiv:2310.18410] <https://arxiv.org/pdf/2310.18410>`_.

.. seealso:: :func:`~.right_canonicalize_mps`.

**Example**

Expand Down Expand Up @@ -87,6 +152,13 @@
Additionally, the physical dimension of the site should always be fixed at :math:`2`
(since the dimension of a qubit is :math:`2`), while the other dimensions must be powers of two.

A right-canonized MPS is a matrix product state where each tensor :math:`A^{(j)}` satisfies
the following orthonormality condition:

.. math::

\sum_{\alpha_j} A^{(j)}_{\alpha_{j-1}, s, \alpha_j} \left( A^{(j)}_{\alpha'_{j-1}, s, \alpha_j} \right)^* = \delta_{\alpha_{j-1}, \alpha'_{j-1}}

The following example shows a valid MPS input containing four tensors with
dimensions :math:`[(2,2), (2,2,4), (4,2,2), (2,2)]` which satisfy the criteria described above.

Expand All @@ -112,7 +184,7 @@
]
"""

def __init__(self, mps, wires, id=None):
def __init__(self, mps, wires, work_wires=None, right_canonicalize=False, id=None):

Check notice on line 187 in pennylane/templates/state_preparations/state_prep_mps.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/state_preparations/state_prep_mps.py#L187

Too many arguments (6/5) (too-many-arguments)

Check notice on line 187 in pennylane/templates/state_preparations/state_prep_mps.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/state_preparations/state_prep_mps.py#L187

Too many positional arguments (6/5) (too-many-positional-arguments)

# Validate the shape and dimensions of the first tensor
assert qml.math.isclose(
Expand Down Expand Up @@ -160,15 +232,30 @@
assert qml.math.isclose(
new_dj0, dj2
), "Dimension mismatch: the last tensor's first dimension does not match the previous third dimension."
super().__init__(*mps, wires=wires, id=id)

self.hyperparameters["input_wires"] = qml.wires.Wires(wires)
self.hyperparameters["right_canonicalize"] = right_canonicalize

if work_wires:
self.hyperparameters["work_wires"] = qml.wires.Wires(work_wires)
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
all_wires = self.hyperparameters["input_wires"] + self.hyperparameters["work_wires"]
else:
self.hyperparameters["work_wires"] = None
all_wires = self.hyperparameters["input_wires"]

super().__init__(*mps, wires=all_wires, id=id)

@property
def mps(self):
"""list representing the MPS input"""
return self.data

def _flatten(self):
hyperparameters = (("wires", self.wires),)
hyperparameters = (
("wires", self.hyperparameters["input_wires"]),
("work_wires", self.hyperparameters["work_wires"]),
("right_canonicalize", self.hyperparameters["right_canonicalize"]),
)
return self.mps, hyperparameters

@classmethod
Expand All @@ -177,8 +264,16 @@
return cls(data, **hyperparams_dict)

def map_wires(self, wire_map):
new_wires = Wires([wire_map.get(wire, wire) for wire in self.wires])
return MPSPrep(self.mps, new_wires)
new_wires = Wires(
[wire_map.get(wire, wire) for wire in self.hyperparameters["input_wires"]]
)
new_work_wires = Wires(
[wire_map.get(wire, wire) for wire in self.hyperparameters["work_wires"]]
)

return MPSPrep(
self.mps, new_wires, new_work_wires, self.hyperparameters["right_canonicalize"]
)

@classmethod
def _primitive_bind_call(cls, mps, wires, id=None):
Expand All @@ -188,6 +283,94 @@
return type.__call__(cls, mps=mps, wires=wires, id=id) # pragma: no cover
return cls._primitive.bind(*mps, wires=wires, id=id)

def decomposition(self): # pylint: disable=arguments-differ
filtered_hyperparameters = {
key: value for key, value in self.hyperparameters.items() if key != "input_wires"
}
return self.compute_decomposition(
self.parameters, wires=self.hyperparameters["input_wires"], **filtered_hyperparameters
)

KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def compute_decomposition(
mps, wires, work_wires, right_canonicalize=False
): # pylint: disable=arguments-differ
r"""Representation of the operator as a product of other operators.
KetpuntoG marked this conversation as resolved.
Show resolved Hide resolved
The decomposition follows Eq. (23) in `[arXiv:2310.18410] <https://arxiv.org/pdf/2310.18410>`_.

Args:
mps (list[Array]): list of arrays of rank-3 and rank-2 tensors representing an MPS state as a
product of site matrices.

wires (Sequence[int]): wires that the template acts on
work_wires (Sequence[int]): list of extra qubits needed in the decomposition. The maximum permissible bond
dimension of the provided MPS is defined as ``2^len(work_wires)``. Default is ``None``.

Returns:
list[.Operator]: Decomposition of the operator
"""

if work_wires is None:
raise ValueError("The qml.MPSPrep decomposition requires `work_wires` to be specified.")

bond_dimensions = []

for i in range(len(mps) - 1):
bond_dim = mps[i].shape[-1]
bond_dimensions.append(bond_dim)

max_bond_dimension = max(bond_dimensions)

if 2 ** len(work_wires) < max_bond_dimension:
raise ValueError("The bond dimension cannot exceed `2**len(work_wires)`.")

ops = []
n_wires = len(work_wires) + 1

mps[0] = mps[0].reshape((1, *mps[0].shape))
mps[-1] = mps[-1].reshape((*mps[-1].shape, 1))

# We transform the mps to ensure that the generated matrix is unitary
if right_canonicalize:
mps = right_canonicalize_mps(mps)

for i, Ai in enumerate(mps):

# encodes the tensor Ai in a unitary matrix following Eq.23 in https://arxiv.org/pdf/2310.18410

vectors = []
for column in Ai:

interface, dtype = qml.math.get_interface(mps[0]), mps[0].dtype
vector = qml.math.zeros(2**n_wires, like=interface, dtype=dtype)

if interface == "jax":
vector = vector.at[: len(column[0])].set(column[0])
vector = vector.at[
2 ** (n_wires - 1) : 2 ** (n_wires - 1) + len(column[1])
].set(column[1])

else:
vector[: len(column[0])] = column[0]
vector[2 ** (n_wires - 1) : 2 ** (n_wires - 1) + len(column[1])] = column[1]

vectors.append(vector)

vectors = qml.math.stack(vectors).T
d = vectors.shape[0]
k = vectors.shape[1]

# The unitary is completed using QR decomposition
rng = np.random.default_rng(42)
new_columns = qml.math.array(rng.random((d, d - k)))

matrix, R = qml.math.linalg.qr(qml.math.hstack([vectors, new_columns]))
matrix *= qml.math.sign(qml.math.diag(R)) # enforces uniqueness for QR decomposition

ops.append(qml.QubitUnitary(matrix, wires=[wires[i]] + work_wires))

return ops


if MPSPrep._primitive is not None: # pylint: disable=protected-access

Expand Down
Loading
Loading