Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] JIT entire workflows with default qubit #6655

Merged
merged 65 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 61 commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
e827977
exploring an idea
albi3ro Jun 21, 2024
33b946c
lightning interpreter
albi3ro Jun 21, 2024
bdc2b1e
add call jaxpr function
albi3ro Jun 21, 2024
5185112
add demo notebook
albi3ro Jun 21, 2024
786caf3
call is function transform
albi3ro Aug 9, 2024
d8f7e59
Merge branch 'master' into plxpr-interpreter
albi3ro Aug 13, 2024
13b091d
assorted improvements
albi3ro Aug 13, 2024
3f7b5ff
more moving things around
albi3ro Aug 13, 2024
6e2b334
something?
albi3ro Aug 15, 2024
4f3efb6
Merge branch 'master' into plxpr-interpreter
albi3ro Aug 21, 2024
a61acd2
Add PLxprInterpreter base class
albi3ro Aug 26, 2024
0d68f07
Merge branch 'master' into plxpr-interpreter-base
albi3ro Aug 26, 2024
1609ed3
qnode fix
albi3ro Aug 26, 2024
f4a6ba8
Merge branch 'plxpr-interpreter-base' of https://github.com/PennyLane…
albi3ro Aug 26, 2024
f593d45
starting to write tests
albi3ro Aug 27, 2024
0defdde
trying to improve op math handling
albi3ro Aug 28, 2024
0ecf1a8
testing
albi3ro Aug 30, 2024
65202fa
improvementS
albi3ro Aug 30, 2024
22e32f7
more fixes
albi3ro Sep 3, 2024
93724c4
Merge branch 'master' into plxpr-interpreter-base
albi3ro Sep 9, 2024
de6f4b3
Merge branch 'master' into plxpr-interpreter-base
albi3ro Sep 30, 2024
ba298eb
adding tests
albi3ro Oct 1, 2024
d9d9fd7
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 1, 2024
1cc4e79
more tests
albi3ro Oct 1, 2024
1a8f410
more test changes
albi3ro Oct 2, 2024
769570f
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 2, 2024
e8a5c5a
some more tests and polishing
albi3ro Oct 2, 2024
c42090c
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 2, 2024
f806708
add default qubit interpreter
albi3ro Oct 2, 2024
751d66c
use copy not child
albi3ro Oct 9, 2024
c232286
merge in master
albi3ro Nov 11, 2024
bfbe68d
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 11, 2024
d3f729e
starting to write tests
albi3ro Nov 12, 2024
879bdad
more testing
albi3ro Nov 12, 2024
7741113
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 12, 2024
b80d290
more tests
albi3ro Nov 13, 2024
d0298ad
Apply suggestions from code review
albi3ro Nov 15, 2024
1d4bf51
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 15, 2024
07fae7b
update initial key each execution
albi3ro Nov 18, 2024
e076a48
changelog
albi3ro Nov 18, 2024
2c736bd
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 18, 2024
53dfac0
add eval_jaxpr method to DefaultQubit
albi3ro Nov 18, 2024
cf22d05
Update doc/releases/changelog-dev.md
albi3ro Nov 18, 2024
14a8cd5
qnode natively executes jaxpr on device
albi3ro Nov 18, 2024
a5e75e4
no seed support
albi3ro Nov 18, 2024
9db9aa0
Merge branch 'master' into native-dq-execution
albi3ro Nov 20, 2024
05b1dae
trying to merge
albi3ro Nov 21, 2024
7784464
fixing up tests
albi3ro Nov 21, 2024
4f5329c
Merge branch 'master' into native-dq-execution
albi3ro Nov 26, 2024
f5a7adb
xfailing more tests
albi3ro Nov 27, 2024
9f77bf2
Merge branch 'master' into native-dq-execution
albi3ro Nov 27, 2024
d175149
fix failing test
albi3ro Nov 28, 2024
b202302
Merge branch 'master' into native-dq-execution
albi3ro Nov 28, 2024
663b69a
add test
albi3ro Nov 28, 2024
a35a697
black, test, pylint
albi3ro Nov 28, 2024
2a29950
Update tests/capture/test_capture_mid_measure.py
albi3ro Nov 28, 2024
68a6ba5
add workflow developement status
albi3ro Nov 29, 2024
e2910ba
add workflow developement status
albi3ro Nov 29, 2024
9b9d939
add workflow developement status
albi3ro Nov 29, 2024
ce5d8c3
jit circuits on dq
albi3ro Nov 29, 2024
9367658
Merge branch 'master' into capture-execution-jit
albi3ro Dec 5, 2024
a06951b
Update tests/capture/test_capture_mid_measure.py
albi3ro Dec 5, 2024
2885933
oops
albi3ro Dec 5, 2024
5cfd98f
Update doc/releases/changelog-dev.md
albi3ro Dec 5, 2024
d9480e4
Merge branch 'master' into capture-execution-jit
albi3ro Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ featuring a `simulate` function for simulating mixed states in analytic mode.
* Execution with capture enabled now follows a new execution pipeline and natively passes the
captured jaxpr to the device. Since it no longer falls back to the old pipeline, execution
only works with a reduced feature set.
[(#6496)](https://github.com/PennyLaneAI/pennylane/pull/6596)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
[(#6596)](https://github.com/PennyLaneAI/pennylane/pull/6596)

* PennyLane transforms can now be captured as primitives with experimental program capture enabled.
Expand Down
4 changes: 3 additions & 1 deletion pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
has_jax = True
try:
import jax
from jax.interpreters import ad, batching
from jax.interpreters import ad, batching, mlir
except ImportError:
has_jax = False

Expand Down Expand Up @@ -298,6 +298,8 @@

batching.primitive_batchers[qnode_prim] = _qnode_batching_rule

mlir.register_lowering(qnode_prim, mlir.lower_fun(qnode_impl, multiple_results=True))

Check notice on line 301 in pennylane/workflow/_capture_qnode.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/workflow/_capture_qnode.py#L301

Undefined variable 'qnode_impl' (undefined-variable)
lillian542 marked this conversation as resolved.
Show resolved Hide resolved

return qnode_prim


Expand Down
2 changes: 1 addition & 1 deletion tests/capture/test_capture_mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ class TestMidMeasureExecute:

@pytest.mark.xfail(strict=False) # single branch statistics sometimes gives good results
@pytest.mark.parametrize("reset", [True, False])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("postselect", [pytest.param(None, marks=pytest.mark.xfail), 0, 1])
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5))
def test_simple_circuit_execution(self, phi, reset, postselect, get_device, shots, mp_fn, seed):
"""Test that circuits with mid-circuit measurements can be executed in a QNode."""
Expand Down
13 changes: 13 additions & 0 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,19 @@ def circuit(x):
assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt))


def test_qnode_jit():
"""Test that executions on default qubit can be jitted."""

@qml.qnode(qml.device("default.qubit", wires=1))
def circuit(x):
qml.RX(x, 0)
return qml.expval(qml.Z(0))

x = jax.numpy.array(-0.5)
res = jax.jit(circuit)(0.5)
assert qml.math.allclose(res, jax.numpy.cos(x))


# pylint: disable=too-many-public-methods
class TestQNodeVmapIntegration:
"""Tests for integrating JAX vmap with the QNode primitive."""
Expand Down
Loading