Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] add eval_jaxpr method to DefaultQubit #6594

Merged
merged 46 commits into from
Nov 20, 2024
Merged
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
e827977
exploring an idea
albi3ro Jun 21, 2024
33b946c
lightning interpreter
albi3ro Jun 21, 2024
bdc2b1e
add call jaxpr function
albi3ro Jun 21, 2024
5185112
add demo notebook
albi3ro Jun 21, 2024
786caf3
call is function transform
albi3ro Aug 9, 2024
d8f7e59
Merge branch 'master' into plxpr-interpreter
albi3ro Aug 13, 2024
13b091d
assorted improvements
albi3ro Aug 13, 2024
3f7b5ff
more moving things around
albi3ro Aug 13, 2024
6e2b334
something?
albi3ro Aug 15, 2024
4f3efb6
Merge branch 'master' into plxpr-interpreter
albi3ro Aug 21, 2024
a61acd2
Add PLxprInterpreter base class
albi3ro Aug 26, 2024
0d68f07
Merge branch 'master' into plxpr-interpreter-base
albi3ro Aug 26, 2024
1609ed3
qnode fix
albi3ro Aug 26, 2024
f4a6ba8
Merge branch 'plxpr-interpreter-base' of https://github.com/PennyLane…
albi3ro Aug 26, 2024
f593d45
starting to write tests
albi3ro Aug 27, 2024
0defdde
trying to improve op math handling
albi3ro Aug 28, 2024
0ecf1a8
testing
albi3ro Aug 30, 2024
65202fa
improvementS
albi3ro Aug 30, 2024
22e32f7
more fixes
albi3ro Sep 3, 2024
93724c4
Merge branch 'master' into plxpr-interpreter-base
albi3ro Sep 9, 2024
de6f4b3
Merge branch 'master' into plxpr-interpreter-base
albi3ro Sep 30, 2024
ba298eb
adding tests
albi3ro Oct 1, 2024
d9d9fd7
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 1, 2024
1cc4e79
more tests
albi3ro Oct 1, 2024
1a8f410
more test changes
albi3ro Oct 2, 2024
769570f
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 2, 2024
e8a5c5a
some more tests and polishing
albi3ro Oct 2, 2024
c42090c
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 2, 2024
f806708
add default qubit interpreter
albi3ro Oct 2, 2024
751d66c
use copy not child
albi3ro Oct 9, 2024
c232286
merge in master
albi3ro Nov 11, 2024
bfbe68d
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 11, 2024
d3f729e
starting to write tests
albi3ro Nov 12, 2024
879bdad
more testing
albi3ro Nov 12, 2024
7741113
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 12, 2024
b80d290
more tests
albi3ro Nov 13, 2024
d0298ad
Apply suggestions from code review
albi3ro Nov 15, 2024
1d4bf51
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 15, 2024
07fae7b
update initial key each execution
albi3ro Nov 18, 2024
e076a48
changelog
albi3ro Nov 18, 2024
2c736bd
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 18, 2024
53dfac0
add eval_jaxpr method to DefaultQubit
albi3ro Nov 18, 2024
cf22d05
Update doc/releases/changelog-dev.md
albi3ro Nov 18, 2024
eb3e277
fix no prng key
albi3ro Nov 18, 2024
03644f7
Update tests/devices/default_qubit/test_default_qubit_plxpr.py
albi3ro Nov 20, 2024
e9beec7
Merge branch 'master' into dq-eval-jaxpr
albi3ro Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use copy not child
albi3ro committed Oct 9, 2024
commit 751d66cbb2cf76f9128021a121e3d7555e5a8293
54 changes: 32 additions & 22 deletions pennylane/devices/qubit/dq_interpreter.py
Original file line number Diff line number Diff line change
@@ -14,10 +14,12 @@
"""
This module contains a class for executing plxpr using default qubit tools.
"""
from copy import copy

import jax
import numpy as np

from pennylane.capture import PlxprInterpreter
from pennylane.capture.base_interpreter import PlxprInterpreter
from pennylane.capture.primitives import (
adjoint_transform_prim,
cond_prim,
@@ -26,7 +28,7 @@
measure_prim,
while_loop_prim,
)
from pennylane.measurements import MidMeasureMP
from pennylane.measurements import MidMeasureMP, Shots

from .apply_operation import apply_operation
from .initialize_state import create_initial_state
@@ -52,34 +54,45 @@ class DefaultQubitInterpreter(PlxprInterpreter):

"""

def __init__(self, num_wires, shots, key=None, stateref=None):
def __init__(self, num_wires, shots, key: None | jax.numpy.ndarray = None):
self.num_wires = num_wires
self.shots = shots
self.stateref = stateref or {"state": None}
self.key = key
self.shots = Shots(shots)
if key is None:
key = jax.random.PRNGKey(np.random.random())
self.stateref = {"state": None, "key": key, "mcms": None}

@property
def state(self):
def state(self) -> jax.numpy.ndarray:
return self.stateref["state"]

@state.setter
def state(self, value):
def state(self, value: jax.numpy.ndarray):
self.stateref["state"] = value

def child(self) -> "DefaultQubitInterpreter":
return type(self)(
num_wires=self.num_wires, shots=self.shots, key=self.key, stateref=self.stateref
)
@property
def key(self) -> jax.numpy.ndarray:
return self.stateref["key"]

@property
def mcms(self):
return self.stateref["mcms"]

@key.setter
def key(self, value):
self.stateref["key"] = value

def setup(self):
if self.state is None:
self.state = create_initial_state(range(self.num_wires))
self.state = create_initial_state(range(self.num_wires), like="jax")
if self.mcms is None:
self.stateref["mcms"] = {}

def interpret_operation(self, op):
self.state = apply_operation(op, self.state)

def interpret_measurement_eqn(self, primitive, *invals, **params):
mp = primitive.impl(*invals, **params)

if self.shots:
self.key, new_key = jax.random.split(self.key, 2)
# note that this does *not* group commuting measurements
@@ -107,7 +120,7 @@ def _(self, *invals, jaxpr_body_fn, n_consts):

res = None
for i in range(start, stop, step):
res = self.child().eval(jaxpr_body_fn, consts, i, *init_state)
res = copy(self).eval(jaxpr_body_fn, consts, i, *init_state)

return res

@@ -119,21 +132,18 @@ def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond)
init_state = invals[n_consts_body + n_consts_cond :]

fn_res = init_state
while self.child().eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]:
fn_res = self.child().eval(jaxpr_body_fn, consts_body, *fn_res)
while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res)

return fn_res


@DefaultQubitInterpreter.register_primitive(measure_prim)
def _(self, *invals, reset, postselect):
mp = MidMeasureMP(invals, reset=reset, postselect=postselect)
mid_measurements = {}
self.key, new_key = jax.random.split(self.key, 2)
self.state = apply_operation(
mp, self.state, mid_measurements=mid_measurements, prng_key=new_key
)
return mid_measurements[mp]
self.state = apply_operation(mp, self.state, mid_measurements=self.mcms, prng_key=new_key)
return self.mcms[mp]


@DefaultQubitInterpreter.register_primitive(cond_prim)
@@ -148,5 +158,5 @@ def _(self, *invals, jaxpr_branches, n_consts_per_branch, n_args):
consts = consts_flat[start : start + n_consts]
start += n_consts
if pred and jaxpr is not None:
return self.child().eval_jaxpr(jaxpr, consts, *args)
return copy(self).eval_jaxpr(jaxpr, consts, *args)
return ()