diff --git a/doc/development/deprecations.rst b/doc/development/deprecations.rst
index db5d4482959..ca4152d8c91 100644
--- a/doc/development/deprecations.rst
+++ b/doc/development/deprecations.rst
@@ -20,11 +20,6 @@ Pending deprecations
- Deprecated in v0.39
- Will be removed in v0.40
-* ``qml.broadcast`` has been deprecated. Users should use ``for`` loops instead.
-
- - Deprecated in v0.39
- - Will be removed in v0.40
-
* The ``qml.qinfo`` module has been deprecated. Please see the respective functions in the ``qml.math`` and ``qml.measurements``
modules instead.
@@ -111,6 +106,11 @@ Other deprecations
Completed deprecation cycles
----------------------------
+* ``qml.broadcast`` has been removed. Users should use ``for`` loops instead.
+
+ - Deprecated in v0.39
+ - Removed in v0.40
+
* The ``max_expansion`` argument for :func:`~pennylane.transforms.decompositions.clifford_t_decomposition`
has been removed.
diff --git a/doc/introduction/templates.rst b/doc/introduction/templates.rst
index 4e7b5bac43d..a7b300e8124 100644
--- a/doc/introduction/templates.rst
+++ b/doc/introduction/templates.rst
@@ -360,54 +360,6 @@ Other useful templates which do not belong to the previous categories can be fou
.. _intro_ref_temp_constr:
-Broadcasting function
----------------------
-
-PennyLane offers a broadcasting function to easily construct templates: :func:`~.broadcast`
-takes either quantum gates or templates and applies them to wires in a specific pattern.
-
-.. warning::
-
- While the broadcasting function can make template construction very convenient, it
- adds an overhead and is therefore not recommended when speed is a major concern.
-
-.. gallery-item::
- :description: :doc:`Broadcast (Single) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_single.png
-
-.. gallery-item::
- :description: :doc:`Broadcast (Double) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_double.png
-
-.. gallery-item::
- :description: :doc:`Broadcast (Double Odd) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_double_odd.png
-
-.. gallery-item::
- :description: :doc:`Broadcast (Chain) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_chain.png
-
-.. gallery-item::
- :description: :doc:`Broadcast (Ring) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_ring.png
-
-.. gallery-item::
- :description: :doc:`Broadcast (Pyramid) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_pyramid.png
-
-.. gallery-item::
- :description: :doc:`Broadcast (All-to-All) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_alltoall.png
-
-.. gallery-item::
- :description: :doc:`Broadcast (Custom) <../code/api/pennylane.broadcast>`
- :figure: _static/templates/broadcast_custom.png
-
-.. raw:: html
-
-
-
-.. _intro_ref_temp_init:
Parameter initializations
-------------------------
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 8781537ec9d..7dd0163e66e 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -71,6 +71,9 @@
Breaking changes 💔
+* `qml.broadcast` has been removed. Users should use `for` loops instead.
+ [(#6527)](https://github.com/PennyLaneAI/pennylane/pull/6527)
+
* The `max_expansion` argument for `qml.transforms.clifford_t_decomposition` has been removed.
[(#6531)](https://github.com/PennyLaneAI/pennylane/pull/6531)
diff --git a/pennylane/__init__.py b/pennylane/__init__.py
index 1553eea188d..b659b23e6ee 100644
--- a/pennylane/__init__.py
+++ b/pennylane/__init__.py
@@ -74,7 +74,7 @@
)
from pennylane.ops import *
from pennylane.ops import adjoint, ctrl, cond, exp, sum, pow, prod, s_prod
-from pennylane.templates import broadcast, layer
+from pennylane.templates import layer
from pennylane.templates.embeddings import *
from pennylane.templates.layers import *
from pennylane.templates.tensornetworks import *
diff --git a/pennylane/templates/__init__.py b/pennylane/templates/__init__.py
index c0fa18761df..7613a1724a2 100644
--- a/pennylane/templates/__init__.py
+++ b/pennylane/templates/__init__.py
@@ -15,7 +15,6 @@
This module contains templates, which are pre-coded routines that can be used in a quantum node.
"""
-from .broadcast import *
from .embeddings import *
from .layer import *
from .layers import *
diff --git a/pennylane/templates/broadcast.py b/pennylane/templates/broadcast.py
deleted file mode 100644
index 11e2b4faa47..00000000000
--- a/pennylane/templates/broadcast.py
+++ /dev/null
@@ -1,579 +0,0 @@
-# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-r"""
-Contains the ``broadcast`` template constructor.
-To add a new pattern:
-* extend the variables ``OPTIONS``, ``n_parameters`` and ``wire_sequence``,
-* update the list in the docstring and add a usage example at the end of the docstring's
- ``details`` section,
-* add tests to parametrizations in :func:`test_templates_broadcast`.
-"""
-from warnings import warn
-
-# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
-import pennylane as qml
-from pennylane.wires import Wires
-
-OPTIONS = {"single", "double", "double_odd", "chain", "ring", "pyramid", "all_to_all", "custom"}
-
-###################
-# helpers to define pattern wire sequences
-
-
-def wires_ring(wires):
- """Wire sequence for the ring pattern"""
-
- if len(wires) in [0, 1]:
- return []
-
- if len(wires) == 2:
- # deviation from the rule: for 2 wires ring is set equal to chain,
- # to avoid duplication of single gate
- return [wires.subset([0, 1])]
-
- sequence = [wires.subset([i, i + 1], periodic_boundary=True) for i in range(len(wires))]
- return sequence
-
-
-def wires_pyramid(wires):
- """Wire sequence for the pyramid pattern."""
- sequence = []
- for layer in range(len(wires) // 2):
- block = wires[layer : len(wires) - layer]
- sequence += [block.subset([i, i + 1]) for i in range(0, len(block) - 1, 2)]
- return sequence
-
-
-def wires_all_to_all(wires):
- """Wire sequence for the all-to-all pattern"""
- sequence = []
- for i in range(len(wires)):
- for j in range(i + 1, len(wires)):
- sequence += [wires.subset([i, j])]
- return sequence
-
-
-# define wire sequences for patterns
-PATTERN_TO_WIRES = {
- "single": lambda wires: [wires.subset([i]) for i in range(len(wires))],
- "double": lambda wires: [wires.subset([i, i + 1]) for i in range(0, len(wires) - 1, 2)],
- "double_odd": lambda wires: [wires.subset([i, i + 1]) for i in range(1, len(wires) - 1, 2)],
- "chain": lambda wires: [wires.subset([i, i + 1]) for i in range(len(wires) - 1)],
- "ring": wires_ring,
- "pyramid": wires_pyramid,
- "all_to_all": wires_all_to_all,
- "custom": lambda wires: wires,
-}
-
-# define required number of parameters
-PATTERN_TO_NUM_PARAMS = {
- "single": len, # Use the length of the given wires.
- "double": lambda wires: 0 if len(wires) in [0, 1] else len(wires) // 2,
- "double_odd": lambda wires: 0 if len(wires) in [0, 1] else (len(wires) - 1) // 2,
- "chain": lambda wires: 0 if len(wires) in [0, 1] else len(wires) - 1,
- "ring": lambda wires: 0 if len(wires) in [0, 1] else (1 if len(wires) == 2 else len(wires)),
- "pyramid": lambda w: 0 if len(w) in [0, 1] else sum(i + 1 for i in range(len(w) // 2)),
- "all_to_all": lambda wires: 0 if len(wires) in [0, 1] else len(wires) * (len(wires) - 1) // 2,
- "custom": lambda wires: len(wires) if wires is not None else None,
-}
-###################
-
-
-def _preprocess(parameters, pattern, wires):
- """Validate and pre-process inputs as follows:
-
- * Check that pattern is recognised, or use default pattern if None.
- * Check the dimension of the parameters
- * Create wire sequence of the pattern.
-
- Args:
- parameters (tensor_like): trainable parameters of the template
- pattern (str): specifies the wire pattern
- wires (Wires): wires that template acts on
-
- Returns:
- wire_sequence, parameters: preprocessed pattern and parameters
- """
-
- if isinstance(pattern, str):
- _wires = wires
- if pattern not in OPTIONS:
- raise ValueError(f"did not recognize pattern {pattern}")
- else:
- # turn custom pattern into list of Wires objects
- _wires = [Wires(w) for w in pattern]
- # set "pattern" to "custom", indicating that custom settings have to be used
- pattern = "custom"
-
- # check that there are enough parameters for pattern
- if parameters is not None:
- shape = qml.math.shape(parameters)
-
- # expand dimension so that parameter sets for each unitary can be unpacked
- if len(shape) == 1:
- parameters = qml.math.expand_dims(parameters, 1)
-
- # specific error message for ring edge case of 2 wires
- if (pattern == "ring") and (len(wires) == 2) and (shape[0] != 1):
- raise ValueError(
- "the ring pattern with 2 wires is an exception and only applies one unitary"
- )
- num_params = PATTERN_TO_NUM_PARAMS[pattern](_wires)
- if shape[0] != num_params:
- raise ValueError(
- f"Parameters must contain entries for {num_params} unitaries; got {shape[0]} entries"
- )
-
- wire_sequence = PATTERN_TO_WIRES[pattern](_wires)
- return wire_sequence, parameters
-
-
-def broadcast(unitary, wires, pattern, parameters=None, kwargs=None):
- r"""Applies a unitary multiple times to a specific pattern of wires.
-
- The unitary, defined by the argument ``unitary``, is either a quantum operation
- (such as :meth:`~.pennylane.ops.RX`), or a
- user-supplied template. Depending on the chosen pattern, ``unitary`` is applied to a wire or a subset of wires:
-
- * ``pattern="single"`` applies a single-wire unitary to each one of the :math:`M` wires:
-
- .. figure:: ../../_static/templates/broadcast_single.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- * ``pattern="double"`` applies a two-wire unitary to :math:`\lfloor \frac{M}{2} \rfloor`
- subsequent pairs of wires:
-
- .. figure:: ../../_static/templates/broadcast_double.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- * ``pattern="double_odd"`` applies a two-wire unitary to :math:`\lfloor \frac{M-1}{2} \rfloor`
- subsequent pairs of wires, starting with the second wire:
-
- .. figure:: ../../_static/templates/broadcast_double_odd.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- * ``pattern="chain"`` applies a two-wire unitary to all :math:`M-1` neighbouring pairs of wires:
-
- .. figure:: ../../_static/templates/broadcast_chain.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- * ``pattern="ring"`` applies a two-wire unitary to all :math:`M` neighbouring pairs of wires,
- where the last wire is considered to be a neighbour to the first one:
-
- .. figure:: ../../_static/templates/broadcast_ring.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- .. note:: For 2 wires, the ring pattern is automatically replaced by ``pattern = 'chain'`` to avoid
- a mere repetition of the unitary.
-
- * ``pattern="pyramid"`` applies a two-wire unitary to wire pairs shaped in a pyramid declining to the right:
-
- .. figure:: ../../_static/templates/broadcast_pyramid.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- * ``pattern="all_to_all"`` applies a two-wire unitary to wire pairs that connect all wires to each other:
-
- .. figure:: ../../_static/templates/broadcast_alltoall.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- * A custom pattern can be passed by providing a list of wire lists to ``pattern``. The ``unitary`` is applied
- to each set of wires specified in the list.
-
- .. figure:: ../../_static/templates/broadcast_custom.png
- :align: center
- :width: 20%
- :target: javascript:void(0);
-
- Each ``unitary`` may depend on a different set of parameters. These are passed as a list by the ``parameters``
- argument.
-
- For more details, see *Usage Details* below.
-
- .. warning::
-
- ``qml.broadcast`` has been deprecated and will be removed in v0.40. Please use ``for`` loops instead.
-
- Args:
- unitary (func): quantum gate or template
- pattern (str): specifies the wire pattern of the broadcast
- parameters (list): sequence of parameters for each gate applied
- wires (Iterable or Wires): Wires that the template acts on. Accepts an iterable of numbers or strings, or
- a Wires object.
- kwargs (dict): dictionary of auxilliary parameters for ``unitary``
-
- Raises:
- ValueError: if inputs do not have the correct format
-
- .. details::
- :title: Usage Details
-
- **Broadcasting single gates**
-
- In the simplest case the unitary is typically an :meth:`~.pennylane.operation.Operation` object
- implementing a quantum gate.
-
- .. code-block:: python
-
- import pennylane as qml
- from pennylane import broadcast
-
- dev = qml.device('default.qubit', wires=3)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.RX, pattern="single", wires=[0,1,2], parameters=pars)
- return qml.expval(qml.Z(0))
-
- circuit([1, 1, 2])
-
- This is equivalent to the following circuit:
-
- .. code-block:: python
-
- @qml.qnode(dev)
- def circuit(pars):
- qml.RX(pars[0], wires=[0])
- qml.RX(pars[1], wires=[1])
- qml.RX(pars[2], wires=[2])
- return qml.expval(qml.Z(0))
-
- circuit([1, 1, 2])
-
- **Broadcasting templates**
-
- Alternatively, one can broadcast a built-in or user-defined template:
-
- .. code-block:: python
-
- def mytemplate(pars, wires):
- qml.Hadamard(wires=wires)
- qml.RY(pars, wires=wires)
-
- dev = qml.device('default.qubit', wires=3)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=mytemplate, pattern="single", wires=[0,1,2], parameters=pars)
- return qml.expval(qml.Z(0))
-
- print(circuit([1, 1, 0.1]))
-
- **Constant unitaries**
-
- If the ``unitary`` argument does not take parameters, no ``parameters`` argument is passed to
- :func:`~.pennylane.broadcast`:
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=3)
-
- @qml.qnode(dev)
- def circuit():
- broadcast(unitary=qml.Hadamard, pattern="single", wires=[0,1,2])
- return qml.expval(qml.Z(0))
-
- circuit()
-
- **Multiple parameters in unitary**
-
- The unitary, whether it is a single gate or a user-defined template,
- can take multiple parameters. For example:
-
- .. code-block:: python
-
- def mytemplate(pars1, pars2, wires):
- qml.Hadamard(wires=wires)
- qml.RY(pars1, wires=wires)
- qml.RX(pars2, wires=wires)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=mytemplate, pattern="single", wires=[0,1,2], parameters=pars)
- return qml.expval(qml.Z(0))
-
- circuit([[1, 1], [2, 1], [0.1, 1]])
-
- In general, the unitary takes D parameters and **must** have the following signature:
-
- .. code-block:: python
-
- unitary(parameter1, parameter2, ... parameterD, wires, **kwargs)
-
- If ``unitary`` does not depend on parameters (:math:`D=0`), the signature is
-
- .. code-block:: python
-
- unitary(wires, **kwargs)
-
- As a result, ``parameters`` must be a list or array of length-:math:`D` lists or arrays.
-
- If :math:`D` becomes large, the signature can be simplified by wrapping each entry in ``parameters``:
-
- .. code-block:: python
-
- def mytemplate(pars, wires):
- qml.Hadamard(wires=wires)
- qml.RY(pars[0], wires=wires)
- qml.RX(pars[1], wires=wires)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=mytemplate, pattern="single", wires=[0,1,2], parameters=pars)
- return qml.expval(qml.Z(0))
-
- print(circuit([[[1, 1]], [[2, 1]], [[0.1, 1]]]))
-
- If the number of parameters for each wire does not match the unitary, an error gets thrown:
-
- .. code-block:: python
-
- def mytemplate(pars1, pars2, wires):
- qml.Hadamard(wires=wires)
- qml.RY(pars1, wires=wires)
- qml.RX(pars2, wires=wires)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=mytemplate, pattern="single", wires=[0, 1, 2], parameters=pars)
- return qml.expval(qml.Z(0))
-
- >>> circuit([1, 2, 3]))
- TypeError: mytemplate() missing 1 required positional argument: 'pars2'
-
- **Keyword arguments**
-
- The unitary can be a template that takes additional keyword arguments.
-
- .. code-block:: python
-
- def mytemplate(wires, h=True):
- if h:
- qml.Hadamard(wires=wires)
- qml.T(wires=wires)
-
- @qml.qnode(dev)
- def circuit(hadamard=None):
- broadcast(unitary=mytemplate, pattern="single", wires=[0, 1, 2], kwargs={'h': hadamard})
- return qml.expval(qml.Z(0))
-
- circuit(hadamard=False)
-
- **Different patterns**
-
- The basic usage of the different patterns works as follows:
-
- * Double pattern
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=4)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern='double',
- wires=[0,1,2,3], parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [-1, 2.5, 3]
- pars2 = [-1, 4, 2]
-
- circuit([pars1, pars2])
-
- * Double-odd pattern
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=4)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern='double_odd',
- wires=[0,1,2,3], parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [-5.3, 2.3, 3]
-
- circuit([pars1])
-
- * Chain pattern
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=4)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern='chain',
- wires=[0,1,2,3], parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [1.8, 2, 3]
- pars2 = [-1, 3, 1]
- pars3 = [2, -1.2, 4]
-
- circuit([pars1, pars2, pars3])
-
- * Ring pattern
-
- In general, the number of parameter sequences has to match
- the number of wires:
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=3)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern='ring',
- wires=[0,1,2], parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [1, -2.2, 3]
- pars2 = [-1, 3, 1]
- pars3 = [2.6, 1, 4]
-
- circuit([pars1, pars2, pars3])
-
- However, there is an exception for 2 wires, where only one set of parameters is needed.
- This avoids repeating a gate over the
- same wires twice:
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=2)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern='ring',
- wires=[0,1], parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [-3.2, 2, 1.2]
-
- circuit([pars1])
-
- * Pyramid pattern
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=4)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern='pyramid',
- wires=[0,1,2,3], parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [1.1, 2, 3]
- pars2 = [-1, 3, 1]
- pars3 = [2, 1, 4.2]
-
- circuit([pars1, pars2, pars3])
-
- * All-to-all pattern
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=4)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern="all_to_all",
- wires=[0,1,2,3], parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [1, 2, 3]
- pars2 = [-1, 3, 1]
- pars3 = [2, 1, 4]
- pars4 = [-1, -2, -3]
- pars5 = [2, 1, 4]
- pars6 = [3, -2, -3]
-
- circuit([pars1, pars2, pars3, pars4, pars5, pars6])
-
- * Custom pattern
-
- For a custom pattern, the wire lists for each application of the unitary is
- passed to ``pattern``:
-
- .. code-block:: python
-
- dev = qml.device('default.qubit', wires=5)
-
- pattern = [[0, 1], [3, 4]]
-
- @qml.qnode(dev)
- def circuit():
- broadcast(unitary=qml.CNOT, pattern=pattern,
- wires=range(5))
- return qml.expval(qml.Z(0))
-
- circuit()
-
- When using a parametrized unitary, make sure that the number of wire lists in ``pattern`` corresponds to the
- number of parameters in ``parameters``.
-
- .. code-block:: python
-
- pattern = [[0, 1], [3, 4]]
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=qml.CRot, pattern=pattern,
- wires=range(5), parameters=pars)
- return qml.expval(qml.Z(0))
-
- pars1 = [1, 2, 3]
- pars2 = [-1, 3, 1]
- pars = [pars1, pars2]
-
- assert len(pars) == len(pattern)
-
- circuit(pars)
- """
- # We deliberately disable iterating using enumerate here, since
- # it causes a slowdown when iterating over TensorFlow variables.
- # pylint: disable=consider-using-enumerate
-
- warn(
- "qml.broadcast is deprecated and will be removed in v0.40. Please use a for loop instead",
- qml.PennyLaneDeprecationWarning,
- )
-
- wires = Wires(wires)
- if kwargs is None:
- kwargs = {}
-
- wire_sequence, parameters = _preprocess(parameters, pattern, wires)
-
- if parameters is None:
- for i in range(len(wire_sequence)):
- unitary(wires=wire_sequence[i], **kwargs)
- else:
- for i in range(len(wire_sequence)):
- unitary(*parameters[i], wires=wire_sequence[i], **kwargs)
diff --git a/tests/templates/test_broadcast.py b/tests/templates/test_broadcast.py
deleted file mode 100644
index 49d0b77fc60..00000000000
--- a/tests/templates/test_broadcast.py
+++ /dev/null
@@ -1,379 +0,0 @@
-# Copyright 2018-2020 Xanadu Quantum Technologies Inc.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Unit tests for the :func:`pennylane.template.broadcast` function.
-Integration tests should be placed into ``test_templates.py``.
-"""
-# pylint: disable=protected-access,cell-var-from-loop,too-many-arguments
-from math import pi
-
-import numpy as np
-import pytest
-
-import pennylane as qml
-from pennylane.ops import CNOT, CRX, RX, RY, CRot, Rot, S, T
-from pennylane.templates import broadcast
-from pennylane.templates.broadcast import wires_all_to_all, wires_pyramid, wires_ring
-from pennylane.wires import Wires
-
-pytestmark = pytest.mark.filterwarnings(
- "ignore:qml.broadcast is deprecated:pennylane.PennyLaneDeprecationWarning"
-)
-
-
-def ConstantTemplate(wires):
- T(wires=wires)
- S(wires=wires)
-
-
-def ParametrizedTemplate(par1, par2, wires):
- RX(par1, wires=wires)
- RY(par2, wires=wires)
-
-
-def KwargTemplate(par, wires, a=True):
- if a:
- T(wires=wires)
- RY(par, wires=wires)
-
-
-def ConstantTemplateDouble(wires):
- T(wires=wires[0])
- CNOT(wires=wires)
-
-
-def ParametrizedTemplateDouble(par1, par2, wires):
- CRX(par1, wires=wires)
- RY(par2, wires=wires[0])
-
-
-def KwargTemplateDouble(par, wires, a=True):
- if a:
- T(wires=wires[0])
- CRX(par, wires=wires)
-
-
-TARGET_OUTPUTS = [
- ("single", 4, [pi, pi, pi / 2, 0], RX, [1, 1, 0, -1]),
- ("double", 4, [pi / 2, pi / 2], CRX, [-1, 0, -1, 0]),
- ("double", 4, None, CNOT, [-1, 1, -1, 1]),
- ("double_odd", 4, [pi / 2], CRX, [-1, -1, 0, -1]),
- ("chain", 4, [pi, pi, pi / 2], CRX, [-1, 1, -1, 0]),
- ("ring", 4, [pi, pi, pi / 2, pi], CRX, [0, 1, -1, 0]),
- ("pyramid", 4, [0, pi, pi / 2], CRX, [-1, -1, 0, 1]),
- ("all_to_all", 4, [pi / 2, pi / 2, pi / 2, pi / 2, pi / 2, pi / 2], CRX, [-1, 0, 1 / 2, 3 / 4]),
-]
-
-GATE_PARAMETERS = [
- ("single", 0, T, []),
- ("single", 1, T, [[]]),
- ("single", 2, T, [[], []]),
- ("single", 3, T, [[], [], []]),
- ("single", 3, RX, [[0.1], [0.2], [0.3]]),
- ("single", 3, Rot, [[0.1, 0.2, 0.3], [0.3, 0.2, 0.1], [0.3, 0.2, -0.1]]),
- ("double", 0, CNOT, []),
- ("double", 1, CNOT, []),
- ("double", 3, CNOT, [[]]),
- ("double", 2, CNOT, [[]]),
- ("double", 3, CRX, [[0.1]]),
- ("double", 3, CRot, [[0.1, 0.2, 0.3]]),
- ("double_odd", 0, CNOT, []),
- ("double_odd", 1, CNOT, []),
- ("double_odd", 2, CNOT, []),
- ("double_odd", 3, CNOT, [[]]),
- ("double_odd", 3, CRX, [[0.1]]),
- ("double_odd", 3, CRot, [[0.3, 0.2, 0.1]]),
- ("chain", 0, CNOT, []),
- ("chain", 1, CNOT, []),
- ("chain", 2, CNOT, [[]]),
- ("chain", 3, CNOT, [[], []]),
- ("chain", 3, CRX, [[0.1], [0.1]]),
- ("chain", 3, CRot, [[0.3, 0.2, 0.1], [0.3, 0.2, 0.1]]),
- ("ring", 0, CNOT, []),
- ("ring", 1, CNOT, []),
- ("ring", 2, CNOT, [[]]),
- ("ring", 3, CNOT, [[], [], []]),
- ("ring", 3, CRX, [[0.1], [0.1], [0.1]]),
- ("ring", 3, CRot, [[0.3, 0.2, 0.1], [0.3, 0.2, 0.1], [0.3, 0.2, 0.1]]),
- ("pyramid", 0, CNOT, []),
- ("pyramid", 1, CNOT, []),
- ("pyramid", 2, CNOT, [[]]),
- ("pyramid", 4, CNOT, [[], [], []]),
- ("pyramid", 3, CRX, [[0.1]]),
- ("pyramid", 4, CRX, [[0.1], [0.1], [0.1]]),
- ("pyramid", 4, CRot, [[0.3, 0.2, 0.1], [0.3, 0.2, 0.1], [0.3, 0.2, 0.1]]),
- ("all_to_all", 0, CNOT, []),
- ("all_to_all", 1, CNOT, []),
- ("all_to_all", 2, CNOT, [[]]),
- ("all_to_all", 4, CNOT, [[], [], [], [], [], []]),
- ("all_to_all", 3, CRX, [[0.1], [0.1], [0.1]]),
- ("all_to_all", 4, CRX, [[0.1], [0.1], [0.1], [0.1], [0.1], [0.1]]),
- (
- "all_to_all",
- 4,
- CRot,
- [
- [0.3, 0.2, 0.1],
- [0.3, 0.2, 0.1],
- [0.3, 0.2, 0.1],
- [0.3, 0.2, 0.1],
- [0.3, 0.2, 0.1],
- [0.3, 0.2, 0.1],
- ],
- ),
-]
-
-
-def test_broadcast_deprecation():
- """Test that a warning is raised when using qml.broadcast"""
- op = qml.Hadamard
- wires = [0, 1, 2]
-
- with pytest.warns(qml.PennyLaneDeprecationWarning, match="qml.broadcast is deprecated"):
- qml.broadcast(op, wires, "single")
-
-
-class TestBuiltinPatterns:
- """Tests the built-in patterns ("single", "ring", etc) of the broadcast template constructor."""
-
- @pytest.mark.parametrize(
- "unitary, parameters",
- [
- (RX, [[0.1], [0.2], [0.3]]),
- (Rot, [[0.1, 0.2, 0.3], [0.3, 0.2, 0.1], [0.3, 0.2, -0.1]]),
- (T, [[], [], []]),
- ],
- )
- def test_correct_queue_for_gate_unitary(self, unitary, parameters):
- """Tests that correct gate queue is created when 'unitary' is a single gate."""
-
- with qml.tape.OperationRecorder() as rec:
- broadcast(unitary=unitary, pattern="single", wires=range(3), parameters=parameters)
-
- for gate in rec.queue:
- assert isinstance(gate, unitary)
-
- @pytest.mark.parametrize(
- "unitary, gates, parameters",
- [
- (ParametrizedTemplate, [RX, RY], [[0.1, 1], [0.2, 1], [0.1, 1]]),
- (ConstantTemplate, [T, S], [[], [], []]),
- ],
- )
- def test_correct_queue_for_template_unitary(self, unitary, gates, parameters):
- """Tests that correct gate queue is created when 'unitary' is a template."""
-
- with qml.tape.OperationRecorder() as rec:
- broadcast(unitary=unitary, pattern="single", wires=range(3), parameters=parameters)
-
- first_gate = gates[0]
- second_gate = gates[1]
- for idx, gate in enumerate(rec.queue):
- if idx % 2 == 0:
- assert isinstance(gate, first_gate)
- else:
- assert isinstance(gate, second_gate)
-
- @pytest.mark.parametrize(
- "template, kwarg, target_queue, parameters",
- [
- (KwargTemplate, True, [T, RY, T, RY], [[1], [2]]),
- (KwargTemplate, False, [RY, RY], [[1], [2]]),
- ],
- )
- def test_correct_queue_for_template_unitary_with_keyword(
- self, template, kwarg, target_queue, parameters
- ):
- """Tests that correct gate queue is created when 'unitary' is a template that uses a keyword."""
-
- with qml.tape.OperationRecorder() as rec:
- broadcast(
- unitary=template,
- pattern="single",
- wires=range(2),
- parameters=parameters,
- kwargs={"a": kwarg},
- )
-
- for gate, target_gate in zip(rec.queue, target_queue):
- assert isinstance(gate, target_gate)
-
- @pytest.mark.parametrize(
- "pars1, pars2, gate",
- [
- ([[], [], []], None, T),
- ([1, 2, 3], [[1], [2], [3]], RX),
- ],
- )
- def test_correct_queue_same_gate_unitary_different_parameter_formats(self, pars1, pars2, gate):
- """Tests that specific parameter inputs have the same output."""
-
- with qml.tape.OperationRecorder() as rec1:
- broadcast(unitary=gate, pattern="single", wires=range(3), parameters=pars1)
-
- with qml.tape.OperationRecorder() as rec2:
- broadcast(unitary=gate, pattern="single", wires=range(3), parameters=pars2)
-
- for g1, g2 in zip(rec1.queue, rec2.queue):
- assert g1.parameters == g2.parameters
-
- @pytest.mark.parametrize("pattern, n_wires, gate, parameters", GATE_PARAMETERS)
- def test_correct_parameters_in_queue(self, pattern, n_wires, gate, parameters):
- """Tests that gate queue has correct parameters."""
-
- with qml.tape.OperationRecorder() as rec:
- broadcast(unitary=gate, pattern=pattern, wires=range(n_wires), parameters=parameters)
-
- for target_par, g in zip(parameters, rec.queue):
- assert g.parameters == target_par
-
- @pytest.mark.parametrize("pattern, n_wires, parameters, unitary, target", TARGET_OUTPUTS)
- def test_prepares_correct_state(self, pattern, n_wires, parameters, unitary, target):
- """Tests the state produced by different unitaries."""
-
- dev = qml.device("default.qubit", wires=n_wires)
-
- @qml.qnode(dev)
- def circuit():
- for w in range(4):
- qml.PauliX(wires=w)
- broadcast(unitary=unitary, pattern=pattern, wires=range(4), parameters=parameters)
- return [qml.expval(qml.PauliZ(wires=w)) for w in range(4)]
-
- res = circuit()
- assert np.allclose(res, target)
-
- @pytest.mark.parametrize("parameters, n_wires", [(np.array([0]), 2), ([0, 0, 0, 1, 0], 3)])
- def test_throws_error_when_mismatch_params_wires(self, parameters, n_wires):
- """Tests that error thrown when 'parameters' does not contain one set
- of parameters for each wire."""
-
- dev = qml.device("default.qubit", wires=n_wires)
-
- @qml.qnode(dev)
- def circuit():
- broadcast(unitary=RX, wires=range(n_wires), pattern="single", parameters=parameters)
- return qml.expval(qml.PauliZ(0))
-
- with pytest.raises(ValueError, match="Parameters must contain entries for"):
- circuit()
-
- def test_throws_special_error_for_ring_pattern_2_wires(self):
- """Tests that the special error is thrown when 'parameters' does not contain one sequence
- of parameters for a two-wire ring pattern."""
-
- dev = qml.device("default.qubit", wires=2)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=RX, wires=range(2), pattern="ring", parameters=pars)
- return qml.expval(qml.PauliZ(0))
-
- pars = [[1.6], [2.1]]
-
- with pytest.raises(ValueError, match="the ring pattern with 2 wires is an exception"):
- circuit(pars)
-
- @pytest.mark.parametrize(
- "function, wires, target",
- [
- (wires_pyramid, [8, 2, 0, 4, 6, 1], [[8, 2], [0, 4], [6, 1], [2, 0], [4, 6], [0, 4]]),
- (
- wires_pyramid,
- [5, 10, 1, 0, 3, 4, 6],
- [[5, 10], [1, 0], [3, 4], [10, 1], [0, 3], [1, 0]],
- ),
- (wires_pyramid, [0], []),
- (wires_ring, [8, 2, 0, 4, 6, 1], [[8, 2], [2, 0], [0, 4], [4, 6], [6, 1], [1, 8]]),
- (wires_ring, [0], []),
- (wires_ring, [4, 2], [[4, 2]]),
- (wires_all_to_all, [8, 2, 0, 4], [[8, 2], [8, 0], [8, 4], [2, 0], [2, 4], [0, 4]]),
- (wires_all_to_all, [0], []),
- ],
- )
- def test_wire_sequence_generating_functions(self, function, wires, target):
- """Tests that the wire list generating functions for different patterns create the correct sequence."""
-
- wires = Wires(wires)
- sequence = function(wires)
- for w, t in zip(sequence, target):
- assert w.tolist() == t
-
-
-class TestCustomPattern:
- """Additional tests for using broadcast with a custom pattern."""
-
- @pytest.mark.parametrize(
- "custom_pattern, pattern",
- [
- ([[0, 1], [1, 2], [2, 3], [3, 0]], "ring"),
- ([[0, 1], [1, 2], [2, 3]], "chain"),
- ([[0, 1], [2, 3]], "double"),
- ],
- )
- def test_reproduce_builtin_patterns(self, custom_pattern, pattern):
- """Tests that the custom pattern can reproduce the built in patterns."""
-
- dev = qml.device("default.qubit", wires=4)
-
- # qnode using custom pattern
- @qml.qnode(dev)
- def circuit1():
- broadcast(unitary=qml.CNOT, pattern=custom_pattern, wires=range(4))
- return [qml.expval(qml.PauliZ(wires=w)) for w in range(4)]
-
- # qnode using built-in pattern
- @qml.qnode(dev)
- def circuit2():
- broadcast(unitary=qml.CNOT, pattern=pattern, wires=range(4))
- return [qml.expval(qml.PauliZ(wires=w)) for w in range(4)]
-
- custom = circuit1()
- built_in = circuit2()
- assert np.allclose(custom, built_in)
-
- @pytest.mark.parametrize(
- "custom_pattern, expected",
- [
- ([[0], [2], [3], [2]], [-1.0, 1.0, 1.0, -1.0]),
- ([[3], [2], [0]], [-1.0, 1.0, -1.0, -1.0]),
- ],
- )
- def test_correct_output(self, custom_pattern, expected):
- """Tests the output for simple cases."""
-
- dev = qml.device("default.qubit", wires=4)
-
- @qml.qnode(dev)
- def circuit():
- broadcast(unitary=qml.PauliX, wires=range(4), pattern=custom_pattern)
- return [qml.expval(qml.PauliZ(w)) for w in range(4)]
-
- res = circuit()
- assert np.allclose(res, expected)
-
-
-def test_unknown_pattern():
- """Test that an unknown pattern raises an error"""
- dev = qml.device("default.qubit", wires=2)
-
- @qml.qnode(dev)
- def circuit(pars):
- broadcast(unitary=RX, wires=range(2), pattern="hello", parameters=pars)
- return qml.expval(qml.PauliZ(0))
-
- pars = [[1.6], [2.1]]
-
- with pytest.raises(ValueError, match="did not recognize pattern hello"):
- circuit(pars)