-
Notifications
You must be signed in to change notification settings - Fork 625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] Compute the jvp of a jaxpr using default.qubit tools #6875
base: master
Are you sure you want to change the base?
Changes from all commits
98bd43f
5cdd1cb
353a3af
b2ef102
8bd6338
26a8da0
c013510
06dda62
bb3bf09
92d549a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,163 @@ | ||||||||||||||||||||||||||||||||||||||||||
# Copyright 2025 Xanadu Quantum Technologies Inc. | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||
# You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||
Compute the jvp of a jaxpr using the adjoint Jacobian method. | ||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||
import jax | ||||||||||||||||||||||||||||||||||||||||||
from jax import numpy as jnp | ||||||||||||||||||||||||||||||||||||||||||
from jax.interpreters import ad | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
from pennylane import adjoint, generator | ||||||||||||||||||||||||||||||||||||||||||
from pennylane.capture import disable, enable | ||||||||||||||||||||||||||||||||||||||||||
from pennylane.capture.primitives import AbstractOperator | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
from .apply_operation import apply_operation | ||||||||||||||||||||||||||||||||||||||||||
from .initialize_state import create_initial_state | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
def _read(env, var): | ||||||||||||||||||||||||||||||||||||||||||
"""Return the value and tangent for a variable.""" | ||||||||||||||||||||||||||||||||||||||||||
return (var.val, ad.Zero(var.aval)) if isinstance(var, jax.core.Literal) else env[var] | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
def _operator_forward_pass(eqn, env, ket): | ||||||||||||||||||||||||||||||||||||||||||
"""Apply an operator during the forward pass of the adjoint jvp.""" | ||||||||||||||||||||||||||||||||||||||||||
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars))) | ||||||||||||||||||||||||||||||||||||||||||
op = eqn.primitive.impl(*invals, **eqn.params) | ||||||||||||||||||||||||||||||||||||||||||
env[eqn.outvars[0]] = (op, ad.Zero(AbstractOperator())) | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if any(not isinstance(t, ad.Zero) for t in tangents[1:]): | ||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError("adjoint jvp only differentiable parameters in the 0 position.") | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this stem from the restriction that the adjoint method only works for single-parameter/single-generator gates? Is the positional restriction an additional restriction? Maybe a dev comment could clarify :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. limitation for the moment to handle both the fact that generators are define with respect a single parameter, and it doesn' t make sense for wires to be differentiable. |
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if isinstance(eqn.outvars[0], jax.core.DropVar): | ||||||||||||||||||||||||||||||||||||||||||
return apply_operation(op, ket) | ||||||||||||||||||||||||||||||||||||||||||
if any(not isinstance(t, ad.Zero) for t in tangents): | ||||||||||||||||||||||||||||||||||||||||||
# derivatives of op arithmetic. Should be possible later | ||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
return ket | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
def _measurement_forward_pass(eqn, env, ket): | ||||||||||||||||||||||||||||||||||||||||||
"""Perform a measurement during the forward pass of the adjoint jvp.""" | ||||||||||||||||||||||||||||||||||||||||||
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars))) | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I suggested this, although I suspect you wanted to explicitly convert this to a |
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if any(not isinstance(t, ad.Zero) for t in tangents): # pragma: no cover | ||||||||||||||||||||||||||||||||||||||||||
# currently prevented by "no differentiable operator arithmetic." | ||||||||||||||||||||||||||||||||||||||||||
# but better safe than sorry to keep this error | ||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError # pragma: no cover | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if eqn.primitive.name != "expval_obs": | ||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError("adjoint jvp only supports expectations of observables.") | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
mp = eqn.primitive.impl(*invals, **eqn.params) | ||||||||||||||||||||||||||||||||||||||||||
bra = apply_operation(mp.obs, ket) | ||||||||||||||||||||||||||||||||||||||||||
result = jnp.real(jnp.sum(jnp.conj(bra) * ket)) | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Should be faster |
||||||||||||||||||||||||||||||||||||||||||
env[eqn.outvars[0]] = (result, None) | ||||||||||||||||||||||||||||||||||||||||||
return bra | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
def _other_prim_forward_pass(eqn: jax.core.JaxprEqn, env: dict) -> None: | ||||||||||||||||||||||||||||||||||||||||||
"""Handle any equation that is not an operator or measurement eqn. | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
Maps outputs back to the environment | ||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars))) | ||||||||||||||||||||||||||||||||||||||||||
if eqn.primitive not in ad.primitive_jvps: | ||||||||||||||||||||||||||||||||||||||||||
raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||
f"Primitive {eqn.primitive} does not have a jvp rule and is not supported.." | ||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||
outvals, doutvals = ad.primitive_jvps[eqn.primitive](invals, tangents, **eqn.params) | ||||||||||||||||||||||||||||||||||||||||||
if not eqn.primitive.multiple_results: | ||||||||||||||||||||||||||||||||||||||||||
outvals = [outvals] | ||||||||||||||||||||||||||||||||||||||||||
doutvals = [doutvals] | ||||||||||||||||||||||||||||||||||||||||||
for var, v, dv in zip(eqn.outvars, outvals, doutvals, strict=True): | ||||||||||||||||||||||||||||||||||||||||||
env[var] = (v, dv) | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
def _forward_pass(jaxpr: jax.core.Jaxpr, env: dict, num_wires: int): | ||||||||||||||||||||||||||||||||||||||||||
"""Calculate the forward pass of an adjoint jvp calculation.""" | ||||||||||||||||||||||||||||||||||||||||||
bras = [] | ||||||||||||||||||||||||||||||||||||||||||
ket = create_initial_state(range(num_wires)) | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
for eqn in jaxpr.eqns: | ||||||||||||||||||||||||||||||||||||||||||
if getattr(eqn.primitive, "prim_type", "") == "operator": | ||||||||||||||||||||||||||||||||||||||||||
ket = _operator_forward_pass(eqn, env, ket) | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
elif getattr(eqn.primitive, "prim_type", "") == "measurement": | ||||||||||||||||||||||||||||||||||||||||||
bra = _measurement_forward_pass(eqn, env, ket) | ||||||||||||||||||||||||||||||||||||||||||
bras.append(bra) | ||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||
_other_prim_forward_pass(eqn, env) | ||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+94
to
+102
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
What do you think about this version with only one call to |
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
results = [_read(env, var)[0] for var in jaxpr.outvars] | ||||||||||||||||||||||||||||||||||||||||||
return bras, ket, results | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
def _backward_pass(jaxpr, bras, ket, results, env): | ||||||||||||||||||||||||||||||||||||||||||
"""Calculate the jvps during the backward pass stage of an adjoint jvp.""" | ||||||||||||||||||||||||||||||||||||||||||
out_jvps = [jnp.zeros_like(r) for r in results] | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
modified = False | ||||||||||||||||||||||||||||||||||||||||||
for eqn in reversed(jaxpr.eqns): | ||||||||||||||||||||||||||||||||||||||||||
if getattr(eqn.primitive, "prim_type", "") == "operator" and isinstance( | ||||||||||||||||||||||||||||||||||||||||||
eqn.outvars[0], jax.core.DropVar | ||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||
op = env[eqn.outvars[0]][0] | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if eqn.invars: | ||||||||||||||||||||||||||||||||||||||||||
t = _read(env, eqn.invars[0])[1] | ||||||||||||||||||||||||||||||||||||||||||
if not isinstance(t, ad.Zero): | ||||||||||||||||||||||||||||||||||||||||||
disable() | ||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||
ket_temp = apply_operation(generator(op, format="observable"), ket) | ||||||||||||||||||||||||||||||||||||||||||
finally: | ||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+123
to
+125
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it obvious why this happens with capture disabled? Might be worth a dev comment for future :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seeing the tests, it is clear to me :) Maybe still a good opportunity for a comment. |
||||||||||||||||||||||||||||||||||||||||||
enable() | ||||||||||||||||||||||||||||||||||||||||||
modified = True | ||||||||||||||||||||||||||||||||||||||||||
for i, bra in enumerate(bras): | ||||||||||||||||||||||||||||||||||||||||||
out_jvps[i] += -2 * t * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp)) | ||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+120
to
+129
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Can we add a more self-explanatory name than t? I prosed |
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
disable() | ||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||
adj_op = adjoint(op, lazy=False) | ||||||||||||||||||||||||||||||||||||||||||
finally: | ||||||||||||||||||||||||||||||||||||||||||
enable() | ||||||||||||||||||||||||||||||||||||||||||
ket = apply_operation(adj_op, ket) | ||||||||||||||||||||||||||||||||||||||||||
bras = [apply_operation(adj_op, bra) for bra in bras] | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
if modified: | ||||||||||||||||||||||||||||||||||||||||||
return out_jvps | ||||||||||||||||||||||||||||||||||||||||||
return [ad.Zero(r.aval) for r in results] | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
def execute_and_jvp(jaxpr: jax.core.Jaxpr, args: tuple, tangents: tuple, num_wires: int): | ||||||||||||||||||||||||||||||||||||||||||
"""Execute and calculate the jvp for a jaxpr using the adjoint method. | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||
jaxpr (jax.core.Jaxpr): the jaxpr to evaluate | ||||||||||||||||||||||||||||||||||||||||||
args : an iterable of tensorlikes. Should include the consts followed by the inputs | ||||||||||||||||||||||||||||||||||||||||||
tangents: an iterable of tensorlikes and ``jax.interpreter.ad.Zero`` objects. Should | ||||||||||||||||||||||||||||||||||||||||||
include the consts followed by the inputs. | ||||||||||||||||||||||||||||||||||||||||||
num_wires (int): the number of wires to use. | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
Note that the consts for the jaxpr should be included in the beginning of both the ``args`` | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
and ``tangents``. | ||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||
env = { | ||||||||||||||||||||||||||||||||||||||||||
var: (arg, tangent) | ||||||||||||||||||||||||||||||||||||||||||
for var, arg, tangent in zip(jaxpr.constvars + jaxpr.invars, args, tangents, strict=True) | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
bras, ket, results = _forward_pass(jaxpr, env, num_wires) | ||||||||||||||||||||||||||||||||||||||||||
return results, _backward_pass(jaxpr, bras, ket, results, env) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This entry already exists (I am sure this is just the result of a merge conflict in the changelog)