-
Notifications
You must be signed in to change notification settings - Fork 624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] Compute the jvp of a jaxpr using default.qubit tools #6875
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6875 +/- ##
=======================================
Coverage 99.59% 99.59%
=======================================
Files 478 479 +1
Lines 45294 45375 +81
=======================================
+ Hits 45112 45193 +81
Misses 182 182 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @albi3ro, this addition is really nice to see! 🎉
I like how the separation of forward and backward pass is so convenient to separate the two parts of the adjoint method algorithm 😍
Nice and clean tests as well!
The only real request I'd have is to add some dev comments to make it easier for future work on this code.
|
||
jaxpr = jax.make_jaxpr(f)(0.5).jaxpr | ||
|
||
with pytest.raises(ValueError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this a generator-undefined related error? 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's what I thought. I'm planning on fixing that at some point soon, but for now, ValueError
is the behavior.
assert qml.math.allclose(dres, dexpected) | ||
|
||
def test_jaxpr_consts(self): | ||
"""Test that we can execute jaxpr with consts.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are closed-over variables supported already as well? :) Could test that in the same test, possibly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consts and closed over variables are essentially the same thing from the jaxpr perspective. They both go into jaxpr.consts
. Since I'm testing executing the jaxpr, and not capturing it, I think one way of producing consts
should be fine.
env[eqn.outvars[0]] = (op, ad.Zero(AbstractOperator())) | ||
|
||
if any(not isinstance(t, ad.Zero) for t in tangents[1:]): | ||
raise NotImplementedError("adjoint jvp only differentiable parameters in the 0 position.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this stem from the restriction that the adjoint method only works for single-parameter/single-generator gates? Is the positional restriction an additional restriction? Maybe a dev comment could clarify :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
limitation for the moment to handle both the fact that generators are define with respect a single parameter, and it doesn' t make sense for wires to be differentiable.
Co-authored-by: David Wierichs <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
* The adjoint jvp of a jaxpr can be computed using default.qubit tooling. | ||
[(#6875)](https://github.com/PennyLaneAI/pennylane/pull/6875) | ||
|
||
* The source code has been updated use black 25.1.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* The source code has been updated use black 25.1.0 |
This entry already exists (I am sure this is just the result of a merge conflict in the changelog)
include the consts followed by the inputs. | ||
num_wires (int): the number of wires to use. | ||
|
||
Note that the consts for the jaxpr should be included in the beginning of both the ``args`` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that the consts for the jaxpr should be included in the beginning of both the ``args`` | |
Note that the consts for the jaxpr should be included at the beginning of both the ``args`` |
for eqn in jaxpr.eqns: | ||
if getattr(eqn.primitive, "prim_type", "") == "operator": | ||
ket = _operator_forward_pass(eqn, env, ket) | ||
|
||
elif getattr(eqn.primitive, "prim_type", "") == "measurement": | ||
bra = _measurement_forward_pass(eqn, env, ket) | ||
bras.append(bra) | ||
else: | ||
_other_prim_forward_pass(eqn, env) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for eqn in jaxpr.eqns: | |
if getattr(eqn.primitive, "prim_type", "") == "operator": | |
ket = _operator_forward_pass(eqn, env, ket) | |
elif getattr(eqn.primitive, "prim_type", "") == "measurement": | |
bra = _measurement_forward_pass(eqn, env, ket) | |
bras.append(bra) | |
else: | |
_other_prim_forward_pass(eqn, env) | |
for eqn in jaxpr.eqns: | |
prim_type = getattr(eqn.primitive, "prim_type", "") | |
if prim_type == "operator": | |
ket = _operator_forward_pass(eqn, env, ket) | |
elif prim_type == "measurement": | |
bras.append(_measurement_forward_pass(eqn, env, ket)) | |
else: | |
_other_prim_forward_pass(eqn, env) |
What do you think about this version with only one call to getattr
?
t = _read(env, eqn.invars[0])[1] | ||
if not isinstance(t, ad.Zero): | ||
disable() | ||
try: | ||
ket_temp = apply_operation(generator(op, format="observable"), ket) | ||
finally: | ||
enable() | ||
modified = True | ||
for i, bra in enumerate(bras): | ||
out_jvps[i] += -2 * t * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
t = _read(env, eqn.invars[0])[1] | |
if not isinstance(t, ad.Zero): | |
disable() | |
try: | |
ket_temp = apply_operation(generator(op, format="observable"), ket) | |
finally: | |
enable() | |
modified = True | |
for i, bra in enumerate(bras): | |
out_jvps[i] += -2 * t * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp)) | |
tang = _read(env, eqn.invars[0])[1] | |
if not isinstance(tang , ad.Zero): | |
disable() | |
try: | |
ket_temp = apply_operation(generator(op, format="observable"), ket) | |
finally: | |
enable() | |
modified = True | |
for i, bra in enumerate(bras): | |
out_jvps[i] += -2 * tang * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp)) |
Can we add a more self-explanatory name than t? I prosed tang
, but feel free to change it : )
|
||
def _measurement_forward_pass(eqn, env, ket): | ||
"""Perform a measurement during the forward pass of the adjoint jvp.""" | ||
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars))) | |
invals, tangents = zip(*(_read(env, var) for var in eqn.invars)) |
I suggested this, although I suspect you wanted to explicitly convert this to a tuple
for efficiency. Is that so?
|
||
mp = eqn.primitive.impl(*invals, **eqn.params) | ||
bra = apply_operation(mp.obs, ket) | ||
result = jnp.real(jnp.sum(jnp.conj(bra) * ket)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
result = jnp.real(jnp.sum(jnp.conj(bra) * ket)) | |
result = jnp.real(jnp.vdot(bra, ket)) |
Should be faster
Context:
See PR #6905 for how this will fit in to the entire workflow.
Description of the Change:
Adds a
execute_and_jvp
function todevices/qubit/jaxpr_adjoint
for computing the results and jvp of a jaxpr using adjoint jacobian.Benefits:
More jaxpr derivatives.
Possible Drawbacks:
Not very featureful and robust right now.
Related GitHub Issues:
[sc-82169]