Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Compute the jvp of a jaxpr using default.qubit tools #6875

Open
wants to merge 10 commits into
base: master
Choose a base branch
from

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Jan 22, 2025

Context:

See PR #6905 for how this will fit in to the entire workflow.

Description of the Change:

Adds a execute_and_jvp function to devices/qubit/jaxpr_adjoint for computing the results and jvp of a jaxpr using adjoint jacobian.

Benefits:

More jaxpr derivatives.

Possible Drawbacks:

Not very featureful and robust right now.

Related GitHub Issues:

[sc-82169]

@albi3ro albi3ro marked this pull request as ready for review January 30, 2025 20:13
Copy link

codecov bot commented Jan 30, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.59%. Comparing base (ca9a547) to head (06dda62).
Report is 1 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6875   +/-   ##
=======================================
  Coverage   99.59%   99.59%           
=======================================
  Files         478      479    +1     
  Lines       45294    45375   +81     
=======================================
+ Hits        45112    45193   +81     
  Misses        182      182           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@dwierichs dwierichs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @albi3ro, this addition is really nice to see! 🎉
I like how the separation of forward and backward pass is so convenient to separate the two parts of the adjoint method algorithm 😍
Nice and clean tests as well!

The only real request I'd have is to add some dev comments to make it easier for future work on this code.

pennylane/devices/qubit/jaxpr_adjoint.py Outdated Show resolved Hide resolved
pennylane/devices/qubit/jaxpr_adjoint.py Outdated Show resolved Hide resolved
pennylane/devices/qubit/jaxpr_adjoint.py Outdated Show resolved Hide resolved
pennylane/devices/qubit/jaxpr_adjoint.py Outdated Show resolved Hide resolved
pennylane/devices/qubit/jaxpr_adjoint.py Outdated Show resolved Hide resolved

jaxpr = jax.make_jaxpr(f)(0.5).jaxpr

with pytest.raises(ValueError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this a generator-undefined related error? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I thought. I'm planning on fixing that at some point soon, but for now, ValueError is the behavior.

tests/devices/qubit/test_jaxpr_adjoint_jvp.py Outdated Show resolved Hide resolved
tests/devices/qubit/test_jaxpr_adjoint_jvp.py Outdated Show resolved Hide resolved
assert qml.math.allclose(dres, dexpected)

def test_jaxpr_consts(self):
"""Test that we can execute jaxpr with consts."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are closed-over variables supported already as well? :) Could test that in the same test, possibly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consts and closed over variables are essentially the same thing from the jaxpr perspective. They both go into jaxpr.consts. Since I'm testing executing the jaxpr, and not capturing it, I think one way of producing consts should be fine.

env[eqn.outvars[0]] = (op, ad.Zero(AbstractOperator()))

if any(not isinstance(t, ad.Zero) for t in tangents[1:]):
raise NotImplementedError("adjoint jvp only differentiable parameters in the 0 position.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this stem from the restriction that the adjoint method only works for single-parameter/single-generator gates? Is the positional restriction an additional restriction? Maybe a dev comment could clarify :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

limitation for the moment to handle both the fact that generators are define with respect a single parameter, and it doesn' t make sense for wires to be differentiable.

Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

* The adjoint jvp of a jaxpr can be computed using default.qubit tooling.
[(#6875)](https://github.com/PennyLaneAI/pennylane/pull/6875)

* The source code has been updated use black 25.1.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* The source code has been updated use black 25.1.0

This entry already exists (I am sure this is just the result of a merge conflict in the changelog)

include the consts followed by the inputs.
num_wires (int): the number of wires to use.

Note that the consts for the jaxpr should be included in the beginning of both the ``args``
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Note that the consts for the jaxpr should be included in the beginning of both the ``args``
Note that the consts for the jaxpr should be included at the beginning of both the ``args``

Comment on lines +94 to +102
for eqn in jaxpr.eqns:
if getattr(eqn.primitive, "prim_type", "") == "operator":
ket = _operator_forward_pass(eqn, env, ket)

elif getattr(eqn.primitive, "prim_type", "") == "measurement":
bra = _measurement_forward_pass(eqn, env, ket)
bras.append(bra)
else:
_other_prim_forward_pass(eqn, env)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for eqn in jaxpr.eqns:
if getattr(eqn.primitive, "prim_type", "") == "operator":
ket = _operator_forward_pass(eqn, env, ket)
elif getattr(eqn.primitive, "prim_type", "") == "measurement":
bra = _measurement_forward_pass(eqn, env, ket)
bras.append(bra)
else:
_other_prim_forward_pass(eqn, env)
for eqn in jaxpr.eqns:
prim_type = getattr(eqn.primitive, "prim_type", "")
if prim_type == "operator":
ket = _operator_forward_pass(eqn, env, ket)
elif prim_type == "measurement":
bras.append(_measurement_forward_pass(eqn, env, ket))
else:
_other_prim_forward_pass(eqn, env)

What do you think about this version with only one call to getattr?

Comment on lines +120 to +129
t = _read(env, eqn.invars[0])[1]
if not isinstance(t, ad.Zero):
disable()
try:
ket_temp = apply_operation(generator(op, format="observable"), ket)
finally:
enable()
modified = True
for i, bra in enumerate(bras):
out_jvps[i] += -2 * t * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
t = _read(env, eqn.invars[0])[1]
if not isinstance(t, ad.Zero):
disable()
try:
ket_temp = apply_operation(generator(op, format="observable"), ket)
finally:
enable()
modified = True
for i, bra in enumerate(bras):
out_jvps[i] += -2 * t * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp))
tang = _read(env, eqn.invars[0])[1]
if not isinstance(tang , ad.Zero):
disable()
try:
ket_temp = apply_operation(generator(op, format="observable"), ket)
finally:
enable()
modified = True
for i, bra in enumerate(bras):
out_jvps[i] += -2 * tang * jnp.imag(jnp.sum(jnp.conj(bra) * ket_temp))

Can we add a more self-explanatory name than t? I prosed tang, but feel free to change it : )


def _measurement_forward_pass(eqn, env, ket):
"""Perform a measurement during the forward pass of the adjoint jvp."""
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars)))
Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
invals, tangents = tuple(zip(*(_read(env, var) for var in eqn.invars)))
invals, tangents = zip(*(_read(env, var) for var in eqn.invars))

I suggested this, although I suspect you wanted to explicitly convert this to a tuple for efficiency. Is that so?


mp = eqn.primitive.impl(*invals, **eqn.params)
bra = apply_operation(mp.obs, ket)
result = jnp.real(jnp.sum(jnp.conj(bra) * ket))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
result = jnp.real(jnp.sum(jnp.conj(bra) * ket))
result = jnp.real(jnp.vdot(bra, ket))

Should be faster

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants