diff --git a/CHANGELOG.md b/CHANGELOG.md index 30156bea9..784d13e36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ measurements by providing a list of PennyLane `measurements `_ themselves. [(#405)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/405) + [(#466)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/466) * Added the support for converting conditional operations based on mid-circuit measurements and two of the ``ControlFlowOp`` operations - ``IfElseOp`` and ``SwitchCaseOp`` when converting diff --git a/pennylane_qiskit/converter.py b/pennylane_qiskit/converter.py index d68803804..b71514f35 100644 --- a/pennylane_qiskit/converter.py +++ b/pennylane_qiskit/converter.py @@ -15,7 +15,7 @@ This module contains functions for converting Qiskit QuantumCircuit objects into PennyLane circuit templates. """ -from typing import Dict, Any, Sequence, Union +from typing import Dict, Any, Iterable, Sequence, Union import warnings from functools import partial, reduce @@ -349,12 +349,12 @@ def load(quantum_circuit: QuantumCircuit, measurements=None): Args: quantum_circuit (qiskit.QuantumCircuit): the QuantumCircuit to be converted - measurements (list[pennylane.measurements.MeasurementProcess]): the list of PennyLane - `measurements `_ - that overrides the terminal measurements that may be present in the input circuit. + measurements (None | pennylane.measurements.MeasurementProcess | list[pennylane.measurements.MeasurementProcess]): + the PennyLane `measurements `_ + that override the terminal measurements that may be present in the input circuit Returns: - function: the resulting PennyLane template + function: The resulting PennyLane template. """ # pylint:disable=too-many-branches, fixme, protected-access @@ -553,9 +553,13 @@ def _function(*args, params: dict = None, wires: list = None, **kwargs): # Use the user-provided measurements if measurements: - if qml.queuing.QueuingManager.active_context(): + if not qml.queuing.QueuingManager.active_context(): + return measurements + + if isinstance(measurements, Iterable): return [qml.apply(meas) for meas in measurements] - return measurements + + return qml.apply(measurements) return tuple(mid_circ_meas + list(map(qml.measure, terminal_meas))) or None diff --git a/tests/test_converter.py b/tests/test_converter.py index c9ea24cd8..e5f7d46b1 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -1468,8 +1468,28 @@ def cost(x, y): assert np.allclose(jac, jac_expected) - def test_meas_circuit_in_qnode(self, qubit_device_2_wires): - """Tests loading a converted template in a QNode with measurements.""" + def test_quantum_circuit_with_single_measurement(self, qubit_device_single_wire): + """Tests loading a converted template in a QNode with a single measurement.""" + qc = QuantumCircuit(1) + qc.h(0) + qc.measure_all() + + measurement = qml.expval(qml.PauliZ(0)) + quantum_circuit = load(qc, measurements=measurement) + + @qml.qnode(qubit_device_single_wire) + def circuit_loaded_qiskit_circuit(): + return quantum_circuit() + + @qml.qnode(qubit_device_single_wire) + def circuit_native_pennylane(): + qml.Hadamard(0) + return qml.expval(qml.PauliZ(0)) + + assert circuit_loaded_qiskit_circuit() == circuit_native_pennylane() + + def test_quantum_circuit_with_multiple_measurements(self, qubit_device_2_wires): + """Tests loading a converted template in a QNode with multiple measurements.""" angle = 0.543