Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Compute the jvp of a jaxpr using default.qubit tools #6875

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@

<h3>Internal changes ⚙️</h3>

* The adjoint jvp of a jaxpr can be computed using default.qubit tooling.
[(#6875)](https://github.com/PennyLaneAI/pennylane/pull/6875)

* The source code has been updated use black 25.1.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* The source code has been updated use black 25.1.0

This entry already exists (I am sure this is just the result of a merge conflict in the changelog)

* Remove `QNode.get_gradient_fn` from source code.
[(#6898)](https://github.com/PennyLaneAI/pennylane/pull/6898)

Expand Down
163 changes: 163 additions & 0 deletions pennylane/devices/qubit/jaxpr_adjoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Compute the jvp of a jaxpr using the adjoint Jacobian method.
"""
import jax
from jax import numpy as jnp
from jax.interpreters import ad

from pennylane import adjoint, generator
from pennylane.capture import disable, enable
from pennylane.capture.primitives import AbstractOperator

from .apply_operation import apply_operation
from .initialize_state import create_initial_state


def _read(env, var):
"""Return the value and tangent for a variable."""
return (var.val, ad.Zero(var.aval)) if isinstance(var, jax.core.Literal) else env[var]


def _operator_forward_pass(eqn, env, ket):
"""Apply an operator during the forward pass of the adjoint jvp."""
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars)))
op = eqn.primitive.impl(*invals, **eqn.params)
env[eqn.outvars[0]] = (op, ad.Zero(AbstractOperator()))

if any(not isinstance(t, ad.Zero) for t in tangents[1:]):
raise NotImplementedError("adjoint jvp only differentiable parameters in the 0 position.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this stem from the restriction that the adjoint method only works for single-parameter/single-generator gates? Is the positional restriction an additional restriction? Maybe a dev comment could clarify :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

limitation for the moment to handle both the fact that generators are define with respect a single parameter, and it doesn' t make sense for wires to be differentiable.


if isinstance(eqn.outvars[0], jax.core.DropVar):
return apply_operation(op, ket)
if any(not isinstance(t, ad.Zero) for t in tangents):
# derivatives of op arithmetic. Should be possible later
raise NotImplementedError

return ket


def _measurement_forward_pass(eqn, env, ket):
"""Perform a measurement during the forward pass of the adjoint jvp."""
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars)))
Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars)))
invals, tangents = zip(*(_read(env, var) for var in eqn.invars))

I suggested this, although I suspect you wanted to explicitly convert this to a tuple for efficiency. Is that so?


if any(not isinstance(t, ad.Zero) for t in tangents): # pragma: no cover
# currently prevented by "no differentiable operator arithmetic."
# but better safe than sorry to keep this error
raise NotImplementedError # pragma: no cover

if eqn.primitive.name != "expval_obs":
raise NotImplementedError("adjoint jvp only supports expectations of observables.")

mp = eqn.primitive.impl(*invals, **eqn.params)
bra = apply_operation(mp.obs, ket)
result = jnp.real(jnp.sum(jnp.conj(bra) * ket))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
result = jnp.real(jnp.sum(jnp.conj(bra) * ket))
result = jnp.real(jnp.vdot(bra, ket))

Should be faster

env[eqn.outvars[0]] = (result, None)
return bra


def _other_prim_forward_pass(eqn: jax.core.JaxprEqn, env: dict) -> None:
"""Handle any equation that is not an operator or measurement eqn.

Maps outputs back to the environment
"""
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars)))
if eqn.primitive not in ad.primitive_jvps:
raise NotImplementedError(
f"Primitive {eqn.primitive} does not have a jvp rule and is not supported.."
)
outvals, doutvals = ad.primitive_jvps[eqn.primitive](invals, tangents, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]
doutvals = [doutvals]
for var, v, dv in zip(eqn.outvars, outvals, doutvals, strict=True):
env[var] = (v, dv)


def _forward_pass(jaxpr: jax.core.Jaxpr, env: dict, num_wires: int):
"""Calculate the forward pass of an adjoint jvp calculation."""
bras = []
ket = create_initial_state(range(num_wires))

for eqn in jaxpr.eqns:
if getattr(eqn.primitive, "prim_type", "") == "operator":
ket = _operator_forward_pass(eqn, env, ket)

elif getattr(eqn.primitive, "prim_type", "") == "measurement":
bra = _measurement_forward_pass(eqn, env, ket)
bras.append(bra)
else:
_other_prim_forward_pass(eqn, env)
Comment on lines +94 to +102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for eqn in jaxpr.eqns:
if getattr(eqn.primitive, "prim_type", "") == "operator":
ket = _operator_forward_pass(eqn, env, ket)
elif getattr(eqn.primitive, "prim_type", "") == "measurement":
bra = _measurement_forward_pass(eqn, env, ket)
bras.append(bra)
else:
_other_prim_forward_pass(eqn, env)
for eqn in jaxpr.eqns:
prim_type = getattr(eqn.primitive, "prim_type", "")
if prim_type == "operator":
ket = _operator_forward_pass(eqn, env, ket)
elif prim_type == "measurement":
bras.append(_measurement_forward_pass(eqn, env, ket))
else:
_other_prim_forward_pass(eqn, env)

What do you think about this version with only one call to getattr?


results = [_read(env, var)[0] for var in jaxpr.outvars]
return bras, ket, results


def _backward_pass(jaxpr, bras, ket, results, env):
"""Calculate the jvps during the backward pass stage of an adjoint jvp."""
out_jvps = [jnp.zeros_like(r) for r in results]

modified = False
for eqn in reversed(jaxpr.eqns):
if getattr(eqn.primitive, "prim_type", "") == "operator" and isinstance(
eqn.outvars[0], jax.core.DropVar
):
op = env[eqn.outvars[0]][0]

if eqn.invars:
t = _read(env, eqn.invars[0])[1]
if not isinstance(t, ad.Zero):
disable()
try:
ket_temp = apply_operation(generator(op, format="observable"), ket)
finally:
Comment on lines +123 to +125
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it obvious why this happens with capture disabled? Might be worth a dev comment for future :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seeing the tests, it is clear to me :) Maybe still a good opportunity for a comment.

enable()
modified = True
for i, bra in enumerate(bras):
out_jvps[i] += -2 * t * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp))
Comment on lines +120 to +129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
t = _read(env, eqn.invars[0])[1]
if not isinstance(t, ad.Zero):
disable()
try:
ket_temp = apply_operation(generator(op, format="observable"), ket)
finally:
enable()
modified = True
for i, bra in enumerate(bras):
out_jvps[i] += -2 * t * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp))
tang = _read(env, eqn.invars[0])[1]
if not isinstance(tang , ad.Zero):
disable()
try:
ket_temp = apply_operation(generator(op, format="observable"), ket)
finally:
enable()
modified = True
for i, bra in enumerate(bras):
out_jvps[i] += -2 * tang * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp))

Can we add a more self-explanatory name than t? I prosed tang, but feel free to change it : )


disable()
try:
adj_op = adjoint(op, lazy=False)
finally:
enable()
ket = apply_operation(adj_op, ket)
bras = [apply_operation(adj_op, bra) for bra in bras]

if modified:
return out_jvps
return [ad.Zero(r.aval) for r in results]


def execute_and_jvp(jaxpr: jax.core.Jaxpr, args: tuple, tangents: tuple, num_wires: int):
"""Execute and calculate the jvp for a jaxpr using the adjoint method.

Args:
jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
args : an iterable of tensorlikes. Should include the consts followed by the inputs
tangents: an iterable of tensorlikes and ``jax.interpreter.ad.Zero`` objects. Should
include the consts followed by the inputs.
num_wires (int): the number of wires to use.

Note that the consts for the jaxpr should be included in the beginning of both the ``args``
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Note that the consts for the jaxpr should be included in the beginning of both the ``args``
Note that the consts for the jaxpr should be included at the beginning of both the ``args``

and ``tangents``.
"""
env = {
var: (arg, tangent)
for var, arg, tangent in zip(jaxpr.constvars + jaxpr.invars, args, tangents, strict=True)
}

bras, ket, results = _forward_pass(jaxpr, env, num_wires)
return results, _backward_pass(jaxpr, bras, ket, results, env)
9 changes: 1 addition & 8 deletions tests/devices/qubit/test_dq_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

jax = pytest.importorskip("jax")
pytestmark = pytest.mark.jax
pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")]

from jax import numpy as jnp # pylint: disable=wrong-import-position

Expand All @@ -28,13 +28,6 @@
from pennylane.devices.qubit.dq_interpreter import DefaultQubitInterpreter


@pytest.fixture(autouse=True)
def enable_disable_plxpr():
qml.capture.enable()
yield
qml.capture.disable()


def test_initialization():
"""Test that relevant properties are set on initialization."""
dq = DefaultQubitInterpreter(num_wires=3, shots=None)
Expand Down
Loading
Loading