Skip to content

Commit

Permalink
lightning qubit uses parameter shift if metric tensor applied (#5624)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [ ] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [ ] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [ ] Ensure that the test suite passes, by running `make test`.

- [ ] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [ ] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**

**Description of the Change:**

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: David Wierichs <[email protected]>
  • Loading branch information
albi3ro and dwierichs authored May 3, 2024
1 parent 76edb60 commit f964d44
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 76 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.36.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@

<h3>Bug fixes 🐛</h3>

* Patches the QNode so that parameter-shift will be considered best with lightning if
`qml.metric_tensor` is in the transform program.
[(#5624)](https://github.com/PennyLaneAI/pennylane/pull/5624)

* Stopped printing the ID of `qcut.MeasureNode` and `qcut.PrepareNode` in tape drawing.
[(#5613)](https://github.com/PennyLaneAI/pennylane/pull/5613)

Expand Down
50 changes: 20 additions & 30 deletions pennylane/gradients/metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,19 +469,14 @@ def _metric_tensor_cov_matrix(tape, argnum, diag_approx): # pylint: disable=too
# Create a quantum tape with all operations
# prior to the parametrized layer, and the rotations
# to measure in the basis of the parametrized layer generators.
with qml.queuing.AnnotatedQueue() as layer_q:
for op in queue:
# TODO: Maybe there are gates that do not affect the
# generators of interest and thus need not be applied.
qml.apply(op)
# TODO: Maybe there are gates that do not affect the
# generators of interest and thus need not be applied.

for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
if param_in_argnum:
o.diagonalizing_gates()

qml.probs(wires=tape.wires)
for o, param_in_argnum in zip(layer_obs, in_argnum_list[-1]):
if param_in_argnum:
queue.extend(o.diagonalizing_gates())

layer_tape = qml.tape.QuantumScript.from_queue(layer_q)
layer_tape = qml.tape.QuantumScript(queue, [qml.probs(wires=tape.wires)], shots=tape.shots)
metric_tensor_tapes.append(layer_tape)

def processing_fn(probs):
Expand Down Expand Up @@ -573,7 +568,7 @@ def _get_gen_op(op, allow_nonunitary, aux_wire):
) from e


def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire, shots):
r"""Obtain the tapes for the first term of all tensor entries
belonging to an off-diagonal block.
Expand Down Expand Up @@ -610,23 +605,16 @@ def _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire):
for diffed_op_j, par_idx_j in zip(layer_j.ops, layer_j.param_inds):
gen_op_j = _get_gen_op(WrappedObj(diffed_op_j), allow_nonunitary, aux_wire)

with qml.queuing.AnnotatedQueue() as q:
# Initialize auxiliary wire
qml.Hadamard(wires=aux_wire)
# Apply backward cone of first layer
for op in layer_i.pre_ops:
qml.apply(op)
# Controlled-generator operation of first diff'ed op
qml.apply(gen_op_i)
# Apply first layer and operations between layers
for op in ops_between_cgens:
qml.apply(op)
# Controlled-generator operation of second diff'ed op
qml.apply(gen_op_j)
# Measure X on auxiliary wire
qml.expval(qml.X(aux_wire))

tapes.append(qml.tape.QuantumScript.from_queue(q))
ops = [
qml.Hadamard(wires=aux_wire),
*layer_i.pre_ops,
gen_op_i,
*ops_between_cgens,
gen_op_j,
]
new_tape = qml.tape.QuantumScript(ops, [qml.expval(qml.X(aux_wire))], shots=shots)

tapes.append(new_tape)
# Memorize to which metric entry this tape belongs
ids.append((par_idx_i, par_idx_j))

Expand Down Expand Up @@ -707,7 +695,9 @@ def _metric_tensor_hadamard(
block_sizes.append(len(layer_i.param_inds))

for layer_j in layers[idx_i + 1 :]:
_tapes, _ids = _get_first_term_tapes(layer_i, layer_j, allow_nonunitary, aux_wire)
_tapes, _ids = _get_first_term_tapes(
layer_i, layer_j, allow_nonunitary, aux_wire, shots=tape.shots
)
first_term_tapes.extend(_tapes)
ids.extend(_ids)

Expand Down
14 changes: 12 additions & 2 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,9 @@ def __init__(
self.gradient_kwargs = {}
self._tape_cached = False

self._transform_program = qml.transforms.core.TransformProgram()
self._update_gradient_fn()
functools.update_wrapper(self, func)
self._transform_program = qml.transforms.core.TransformProgram()

def __copy__(self):
copied_qnode = QNode.__new__(QNode)
Expand Down Expand Up @@ -592,8 +592,17 @@ def _update_gradient_fn(self, shots=None, tape=None):
return
if tape is None and shots:
tape = qml.tape.QuantumScript([], [], shots=shots)

diff_method = self.diff_method
if (
self.device.name == "lightning.qubit"
and qml.metric_tensor in self.transform_program
and self.diff_method == "best"
):
diff_method = "parameter-shift"

self.gradient_fn, self.gradient_kwargs, self.device = self.get_gradient_fn(
self._original_device, self.interface, self.diff_method, tape=tape
self._original_device, self.interface, diff_method, tape=tape
)
self.gradient_kwargs.update(self._user_gradient_kwargs or {})

Expand Down Expand Up @@ -714,6 +723,7 @@ def get_best_method(device, interface, tape=None):
"""
config = _make_execution_config(None, "best")
if isinstance(device, qml.devices.Device):

if device.supports_derivatives(config, circuit=tape):
new_config = device.preprocess(config)[1]
return new_config.gradient_method, {}, device
Expand Down
1 change: 1 addition & 0 deletions tests/gradients/core/test_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def test_dtype_jax(self, dtype1, dtype2):
determined by the dtype of the dy."""
import jax

jax.config.update("jax_enable_x64", True)
dtype = dtype1
dtype1 = getattr(jax.numpy, dtype1)
dtype2 = getattr(jax.numpy, dtype2)
Expand Down
110 changes: 66 additions & 44 deletions tests/gradients/core/test_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def test_no_trainable_params_tape(self):
mt_tapes, post_processing = qml.metric_tensor(tape)
res = post_processing(qml.execute(mt_tapes, dev, None))

assert mt_tapes == []
assert mt_tapes == [] # pylint: disable=use-implicit-booleaness-not-comparison
assert res == ()


Expand Down Expand Up @@ -1091,8 +1091,13 @@ def qnode(*params):

def mt(*params):
state = qnode(*params)
rqnode = lambda *params: np.real(qnode(*params))
iqnode = lambda *params: np.imag(qnode(*params))

def rqnode(*params):
return np.real(qnode(*params))

def iqnode(*params):
return np.imag(qnode(*params))

rjac = qml.jacobian(rqnode)(*params)
ijac = qml.jacobian(iqnode)(*params)

Expand Down Expand Up @@ -1125,9 +1130,11 @@ class TestFullMetricTensor:
@pytest.mark.autograd
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "autograd"])
def test_correct_output_autograd(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_autograd(self, dev_name, ansatz, params, interface):

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.autograd", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

@qml.qnode(dev, interface=interface)
def circuit(*params):
Expand All @@ -1145,14 +1152,20 @@ def circuit(*params):
@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
def test_correct_output_jax(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_jax(self, dev_name, ansatz, params, interface):
import jax
from jax import numpy as jnp

if ansatz == fubini_ansatz2:
pytest.xfail("Issue involving trainable indices to be resolved.")
if ansatz == fubini_ansatz3 and dev_name == "lightning.qubit":
pytest.xfail("Issue invovling trainable_params to be resolved.")

jax.config.update("jax_enable_x64", True)

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(jnp.array(p) for p in params)

Expand All @@ -1176,10 +1189,11 @@ def circuit(*params):
@pytest.mark.jax
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
def test_jax_argnum_error(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_jax_argnum_error(self, dev_name, ansatz, params, interface):
from jax import numpy as jnp

dev = qml.device("default.qubit.jax", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(jnp.array(p) for p in params)

Expand All @@ -1198,11 +1212,12 @@ def circuit(*params):
@pytest.mark.torch
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "torch"])
def test_correct_output_torch(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_torch(self, dev_name, ansatz, params, interface):
import torch

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.torch", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(torch.tensor(p, dtype=torch.float64, requires_grad=True) for p in params)

Expand All @@ -1222,11 +1237,12 @@ def circuit(*params):
@pytest.mark.tf
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "tf"])
def test_correct_output_tf(self, ansatz, params, interface):
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_tf(self, dev_name, ansatz, params, interface):
import tensorflow as tf

expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
dev = qml.device("default.qubit.tf", wires=self.num_wires + 1)
dev = qml.device(dev_name, wires=self.num_wires + 1)

params = tuple(tf.Variable(p, dtype=tf.float64) for p in params)

Expand Down Expand Up @@ -1254,17 +1270,18 @@ def diffability_ansatz_0(weights, wires=None):
qml.RZ(weights[2], wires=1)


expected_diag_jac_0 = lambda weights: np.array(
[
[0, 0, 0],
[0, 0, 0],
def expected_diag_jac_0(weights):
return np.array(
[
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
0,
],
]
)
[0, 0, 0],
[0, 0, 0],
[
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
np.cos(weights[0] + weights[1]) * np.sin(weights[0] + weights[1]) / 2,
0,
],
]
)


def diffability_ansatz_1(weights, wires=None):
Expand All @@ -1275,17 +1292,18 @@ def diffability_ansatz_1(weights, wires=None):
qml.RZ(weights[2], wires=1)


expected_diag_jac_1 = lambda weights: np.array(
[
[0, 0, 0],
[-np.sin(2 * weights[0]) / 4, 0, 0],
def expected_diag_jac_1(weights):
return np.array(
[
np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)
[0, 0, 0],
[-np.sin(2 * weights[0]) / 4, 0, 0],
[
np.cos(weights[0]) * np.cos(weights[1]) ** 2 * np.sin(weights[0]) / 2,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)


def diffability_ansatz_2(weights, wires=None):
Expand All @@ -1296,17 +1314,19 @@ def diffability_ansatz_2(weights, wires=None):
qml.RZ(weights[2], wires=1)


expected_diag_jac_2 = lambda weights: np.array(
[
[0, 0, 0],
[0, 0, 0],
def expected_diag_jac_2(weights):
return np.array(
[
np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)
[0, 0, 0],
[0, 0, 0],
[
np.cos(weights[1]) ** 2 * np.sin(2 * weights[0]) / 4,
np.cos(weights[0]) ** 2 * np.sin(2 * weights[1]) / 4,
0,
],
]
)


weights_diff = np.array([0.432, 0.12, -0.292], requires_grad=True)

Expand Down Expand Up @@ -1466,7 +1486,9 @@ def test_autograd(self, diff_method, tol, ansatz, weights, interface):
def cost_full(*weights):
return np.array(qml.metric_tensor(qnode, approx=None)(*weights))

_cost_full = lambda *weights: np.array(autodiff_metric_tensor(ansatz, 3)(*weights))
def _cost_full(*weights):
return np.array(autodiff_metric_tensor(ansatz, 3)(*weights))

_c = _cost_full(*weights)
c = cost_full(*weights)
assert all(
Expand Down

0 comments on commit f964d44

Please sign in to comment.