From e827977b105dc26a6322155616c1c8d547222f90 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 21 Jun 2024 13:16:19 -0400 Subject: [PATCH 01/45] exploring an idea --- pennylane/capture/interpreters.py | 201 ++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 pennylane/capture/interpreters.py diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py new file mode 100644 index 00000000000..3361d0dcd39 --- /dev/null +++ b/pennylane/capture/interpreters.py @@ -0,0 +1,201 @@ +from typing import Optional + +import jax + +from pennylane.devices.qubit import apply_operation, create_initial_state, measure +from pennylane.tape import QuantumScript +from pennylane.transforms.optimization.cancel_inverses import _are_inverses + +from .primitives import _get_abstract_measurement, _get_abstract_operator + +AbstractOperator = _get_abstract_operator() +AbstractMeasurement = _get_abstract_measurement() + + +class PlxprInterpreter: + + _env: Optional[dict] = None + _op_math_cache: Optional[dict] = None + + def _read(self, var): + """Extract the value corresponding to a variable.""" + if self._env is None: + raise ValueError("_env not yet initialized.") + return var.val if type(var) is jax.core.Literal else self._env[var] + + def setup(self): + pass + + def cleanup(self): + pass + + def interpret_operation(self, op: "pennylane.operation.Operator"): + raise NotImplementedError + + def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): + + invals = [self._read(invar) for invar in eqn.invars] + op = eqn.primitive.impl(*invals, **eqn.params) + if not isinstance(eqn.outvars[0], jax.core.DropVar): + self._op_math_cache[eqn.outvars[0]] = eqn + self._env[eqn.outvars[0]] = op + return + return self.interpret_operation(op) + + def interpret_measurement(self, measurement: "pennylane.measurements.MeasurementProcess"): + raise NotImplementedError + + def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): + invals = [self._read(invar) for invar in eqn.invars] + mp = eqn.primitive.impl(*invals, **eqn.params) + return self.interpret_measurement(mp) + + def __call__(self, jaxpr, consts, *args): + self._env = {} + self._op_math_cache = {} + self.setup() + + for arg, invar in zip(args, jaxpr.invars): + self._env[invar] = arg + for const, constvar in zip(consts, jaxpr.constvars): + self._env[constvar] = const + + measurements = [] + for eqn in jaxpr.eqns: + if isinstance(eqn.outvars[0].aval, AbstractOperator): + self.interpret_operation_eqn(eqn) + + elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): + measurement = self.interpret_measurement_eqn(eqn) + measurements.append(measurement) + else: + invals = [self._read(invar) for invar in eqn.invars] + outvals = eqn.primitive.bind(*invals, **eqn.params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + for outvar, outval in zip(eqn.outvars, outvals): + self._env[outvar] = outval + + self.cleanup() + # Read the final result of the Jaxpr from the environment + return measurements + + +class DefaultQubitInterpreter(PlxprInterpreter): + + _state = None + + def __init__(self, num_wires): + self.num_wires = num_wires + + def setup(self): + self._state = create_initial_state(range(self.num_wires)) + + def cleanup(self): + self._state = None + + def interpret_operation(self, op): + self._state = apply_operation(op, self._state) + + def interpret_measurement(self, m): + return measure(m, self._state) + + +class DecompositionInterpreter(PlxprInterpreter): + """ + >>> def f(x): + ... qml.IsingXX(x, wires=(0,1)) + ... qml.Rot(0.5, x, 1.5, wires=1) + >>> jaxpr = jax.make_jaxpr(f)(0.5) + >>> converter = partial(DecompositionInterpreter(), jaxpr.jaxpr, jaxpr.consts) + >>> jax.make_jaxpr(converter)(0.5) + { lambda ; a:f32[]. let + _:AbstractOperator() = CNOT[n_wires=2] 0 1 + _:AbstractOperator() = RX[n_wires=1] a 0 + _:AbstractOperator() = CNOT[n_wires=2] 0 1 + _:AbstractOperator() = RZ[n_wires=1] 0.5 1 + _:AbstractOperator() = RY[n_wires=1] a 1 + _:AbstractOperator() = RZ[n_wires=1] 1.5 1 + in () } + + """ + + def interpret_operation(self, op): + op.decomposition() + + +class ConvertToTape(PlxprInterpreter): + """ + + >>> def f(x): + ... qml.RX(x, wires=0) + ... return qml.expval(qml.Z(0)) + >>> jaxpr = jax.make_jaxpr(f)(0.5) + >>> ConvertToTape()(jaxpr.jaxpr, jaxpr.consts, 1.2).circuit + [RX(1.2, wires=[0]), expval(Z(0))] + + """ + + def setup(self): + self._ops = [] + self._measurements = [] + + def interpret_operation(self, op): + self._ops.append(op) + + def interpret_measurement(self, m): + self._measurements.append(m) + return m + + def __call__(self, jaxpr, consts, *args): + out = super().__call__(jaxpr, consts, *args) + return QuantumScript(self._ops, self._measurements) + + +class CancelInverses(PlxprInterpreter): + """ + + >>> def f(x): + ... qml.X(0) + ... qml.X(0) + ... qml.Hadamard(0) + ... qml.Y(1) + ... qml.RX(x, 0) + ... qml.adjoint(qml.RX(x, 0)) + >>> jaxpr = jax.make_jaxpr(f)(0.5) + >>> converter = partial(CancelInverses(), jaxpr.jaxpr, jaxpr.consts) + >>> jax.make_jaxpr(converter)(0.5) + { lambda ; a:f64[]. let + _:AbstractOperator() = Hadamard[n_wires=1] 0 + _:AbstractOperator() = PauliY[n_wires=1] 1 + in () } + + """ + + _last_op_on_wires = None + + def setup(self): + self._last_op_on_wires = {} + + def interpret_operation(self, op): + if len(op.wires) != 1: + for w in op.wires: + self._last_op_on_wires[w] = None + op._unflatten(*op._flatten()) + return + + w = op.wires[0] + if w in self._last_op_on_wires: + if _are_inverses(self._last_op_on_wires[w], op): + self._last_op_on_wires[w] = None + return + previous_op = self._last_op_on_wires[w] + if previous_op is not None: + previous_op._unflatten(*previous_op._flatten()) + self._last_op_on_wires[w] = op + return + + def cleanup(self): + for _, op in self._last_op_on_wires.items(): + if op is not None: + op._unflatten(*op._flatten()) From 33b946cee271499e0de19e94f9936419b94ad9fd Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 21 Jun 2024 13:43:59 -0400 Subject: [PATCH 02/45] lightning interpreter --- pennylane/capture/interpreters.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py index 3361d0dcd39..62a79e08370 100644 --- a/pennylane/capture/interpreters.py +++ b/pennylane/capture/interpreters.py @@ -100,6 +100,23 @@ def interpret_operation(self, op): def interpret_measurement(self, m): return measure(m, self._state) +from pennylane_lightning.lightning_qubit._state_vector import LightningStateVector, LightningMeasurements + +class LightningInterpreter(PlxprInterpreter): + + def __init__(self, num_wires): + self._num_wires = num_wires + + def setup(self): + self._state = LightningStateVector(self._num_wires) + + def interpret_operation(self, op): + self._state._apply_lightning([op]) + + def interpret_measurement(self, m): + return LightningMeasurements(self._state).measurement(m) + + class DecompositionInterpreter(PlxprInterpreter): """ From bdc2b1e78d3d18c0ef6b2d1a38ba7fd04f1d04cc Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 21 Jun 2024 13:51:54 -0400 Subject: [PATCH 03/45] add call jaxpr function --- pennylane/capture/interpreters.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py index 62a79e08370..85c06d3f766 100644 --- a/pennylane/capture/interpreters.py +++ b/pennylane/capture/interpreters.py @@ -1,3 +1,5 @@ + +from functools import partial from typing import Optional import jax @@ -50,6 +52,10 @@ def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): mp = eqn.primitive.impl(*invals, **eqn.params) return self.interpret_measurement(mp) + def call_jaxpr(self, jaxpr, consts): + partial_call = partial(self, jaxpr, consts) + return jax.make_jaxpr(partial_call) + def __call__(self, jaxpr, consts, *args): self._env = {} self._op_math_cache = {} @@ -124,8 +130,7 @@ class DecompositionInterpreter(PlxprInterpreter): ... qml.IsingXX(x, wires=(0,1)) ... qml.Rot(0.5, x, 1.5, wires=1) >>> jaxpr = jax.make_jaxpr(f)(0.5) - >>> converter = partial(DecompositionInterpreter(), jaxpr.jaxpr, jaxpr.consts) - >>> jax.make_jaxpr(converter)(0.5) + >>> DecompositionInterpreter().call_jaxpr(jaxpr.jaxpr, jaxpr.consts)(0.5) { lambda ; a:f32[]. let _:AbstractOperator() = CNOT[n_wires=2] 0 1 _:AbstractOperator() = RX[n_wires=1] a 0 @@ -180,8 +185,7 @@ class CancelInverses(PlxprInterpreter): ... qml.RX(x, 0) ... qml.adjoint(qml.RX(x, 0)) >>> jaxpr = jax.make_jaxpr(f)(0.5) - >>> converter = partial(CancelInverses(), jaxpr.jaxpr, jaxpr.consts) - >>> jax.make_jaxpr(converter)(0.5) + >>> CancelInverses().call_jaxpr(jaxpr.jaxpr, jaxpr.consts)(0.5) { lambda ; a:f64[]. let _:AbstractOperator() = Hadamard[n_wires=1] 0 _:AbstractOperator() = PauliY[n_wires=1] 1 From 51851125985f9e420f67ebc94ea66fb078774782 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 21 Jun 2024 15:01:25 -0400 Subject: [PATCH 04/45] add demo notebook --- pennylane/capture/Interpreters_Demo.md | 207 +++++++++++++++++++++++++ pennylane/capture/interpreters.py | 93 ++++++++++- 2 files changed, 293 insertions(+), 7 deletions(-) create mode 100644 pennylane/capture/Interpreters_Demo.md diff --git a/pennylane/capture/Interpreters_Demo.md b/pennylane/capture/Interpreters_Demo.md new file mode 100644 index 00000000000..81d741933e8 --- /dev/null +++ b/pennylane/capture/Interpreters_Demo.md @@ -0,0 +1,207 @@ +```python +import pennylane as qml +import jax + +from pennylane.capture.interpreters import PlxprInterpreter, DefaultQubitInterpreter, LightningInterpreter, DecompositionInterpreter, ConvertToTape, CancelInverses, MergeRotations +qml.capture.enable() +``` + +### Demonstrating Existing Implementations + + +```python +def f(x): + qml.X(0) + qml.adjoint(qml.X(0)) + qml.Hadamard(0) + qml.IsingXX(x, wires=(0,1)) + return qml.expval(qml.Z(0)), qml.probs(wires=(0,1)) + +plxpr = jax.make_jaxpr(f)(0.5) +``` + + +```python +DefaultQubitInterpreter(num_wires=2)(plxpr.jaxpr, plxpr.consts, 1.2) +``` + + + + + [0.0, array([0.34058944, 0.15941056, 0.34058944, 0.15941056])] + + + + +```python +LightningInterpreter(num_wires=2)(plxpr.jaxpr, plxpr.consts, 1.2) +``` + + + + + [0.0, array([0.34058944, 0.15941056, 0.34058944, 0.15941056])] + + + + +```python +tape = ConvertToTape()(plxpr.jaxpr, plxpr.consts, 1.2) +print(tape.draw()) +``` + + 0: ──X──X†──H─╭IsingXX─┤ ╭Probs + 1: ───────────╰IsingXX─┤ ╰Probs + + + +```python +DecompositionInterpreter().call_jaxpr(plxpr.jaxpr, plxpr.consts)(2.5) +``` + + + + + { lambda ; a:f64[]. let + _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 + _:AbstractOperator() = RX[n_wires=1] 3.141592653589793 0 + _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 + _:AbstractOperator() = PauliX[n_wires=1] 0 + _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 + _:AbstractOperator() = RX[n_wires=1] 1.5707963267948966 0 + _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 + _:AbstractOperator() = CNOT[n_wires=2] 0 1 + _:AbstractOperator() = RX[n_wires=1] a 0 + _:AbstractOperator() = CNOT[n_wires=2] 0 1 + b:AbstractOperator() = PauliZ[n_wires=1] 0 + c:AbstractMeasurement(n_wires=None) = expval_obs b + d:AbstractMeasurement(n_wires=2) = probs_wires 0 1 + in (c, d) } + + + + +```python +CancelInverses().call_jaxpr(plxpr.jaxpr, plxpr.consts)(2.5) +``` + + + + + { lambda ; a:f64[]. let + _:AbstractOperator() = IsingXX[n_wires=2] a 0 1 + b:AbstractOperator() = PauliZ[n_wires=1] 0 + c:AbstractMeasurement(n_wires=None) = expval_obs b + d:AbstractMeasurement(n_wires=2) = probs_wires 0 1 + in (c, d) } + + + + +```python +def g(x): + qml.RX(x, 0) + qml.RX(2*x, 0) + qml.RX(-4*x, 0) + qml.X(0) + qml.RX(0.5, 0) + +plxpr = jax.make_jaxpr(g)(1.0) +MergeRotations().call_jaxpr(plxpr.jaxpr, plxpr.consts)(1.0) +``` + + + + + { lambda ; a:f64[]. let + b:f64[] = mul 2.0 a + c:f64[] = add b a + d:f64[] = mul -4.0 a + e:f64[] = add d c + _:AbstractOperator() = RX[n_wires=1] e 0 + _:AbstractOperator() = PauliX[n_wires=1] 0 + _:AbstractOperator() = RX[n_wires=1] 0.5 0 + in () } + + + +### Writing a new interpreter + + +```python +class AddSWAPNoise(PlxprInterpreter): + + def __init__(self, scale, prng_key=jax.random.key(12345)): + self.scale = scale + self.prng_key = prng_key + + def interpret_operation(self, op): + if isinstance(op, qml.SWAP): + self.prng_key, subkey = jax.random.split(self.prng_key) + phi = self.scale*jax.random.uniform(subkey) + qml.PhaseShift(phi, op.wires[0]) + val, structure = jax.tree_util.tree_flatten(op) + jax.tree_util.tree_unflatten(structure, val) + + def interpret_measurement(self, m): + vals, structure = jax.tree_util.tree_flatten(m) + return jax.tree_util.tree_unflatten(structure, vals) +``` + + +```python +def f(): + qml.SWAP((0,1)) + qml.SWAP((1,2)) + return qml.expval(qml.Z(0)) + +plxpr = jax.make_jaxpr(f)() +AddSWAPNoise(0.1).call_jaxpr(plxpr.jaxpr, plxpr.consts)() +``` + + + + + let _uniform = { lambda ; a:key[] b:f64[] c:f64[]. let + d:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b + e:f64[] = convert_element_type[new_dtype=float64 weak_type=False] c + f:u64[] = random_bits[bit_width=64 shape=()] a + g:u64[] = shift_right_logical f 12 + h:u64[] = or g 4607182418800017408 + i:f64[] = bitcast_convert_type[new_dtype=float64] h + j:f64[] = sub i 1.0 + k:f64[] = sub e d + l:f64[] = mul j k + m:f64[] = add l d + n:f64[] = reshape[dimensions=None new_sizes=()] m + o:f64[] = max d n + in (o,) } in + { lambda p:key[]; . let + q:key[2] = random_split[shape=(2,)] p + r:key[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] q + s:key[] = squeeze[dimensions=(0,)] r + t:key[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] q + u:key[] = squeeze[dimensions=(0,)] t + v:f64[] = pjit[name=_uniform jaxpr=_uniform] u 0.0 1.0 + w:f64[] = mul 0.1 v + _:AbstractOperator() = PhaseShift[n_wires=1] w 0 + _:AbstractOperator() = SWAP[n_wires=2] 0 1 + x:key[2] = random_split[shape=(2,)] s + y:key[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] x + _:key[] = squeeze[dimensions=(0,)] y + z:key[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] x + ba:key[] = squeeze[dimensions=(0,)] z + bb:f64[] = pjit[name=_uniform jaxpr=_uniform] ba 0.0 1.0 + bc:f64[] = mul 0.1 bb + _:AbstractOperator() = PhaseShift[n_wires=1] bc 1 + _:AbstractOperator() = SWAP[n_wires=2] 1 2 + bd:AbstractOperator() = PauliZ[n_wires=1] 0 + be:AbstractMeasurement(n_wires=None) = expval_obs bd + in (be,) } + + + + +```python + +``` diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py index 85c06d3f766..463b3e6ebe2 100644 --- a/pennylane/capture/interpreters.py +++ b/pennylane/capture/interpreters.py @@ -1,8 +1,8 @@ - from functools import partial from typing import Optional import jax +from jax.tree_util import tree_flatten, tree_unflatten from pennylane.devices.qubit import apply_operation, create_initial_state, measure from pennylane.tape import QuantumScript @@ -106,7 +106,12 @@ def interpret_operation(self, op): def interpret_measurement(self, m): return measure(m, self._state) -from pennylane_lightning.lightning_qubit._state_vector import LightningStateVector, LightningMeasurements + +from pennylane_lightning.lightning_qubit._state_vector import ( + LightningMeasurements, + LightningStateVector, +) + class LightningInterpreter(PlxprInterpreter): @@ -123,7 +128,6 @@ def interpret_measurement(self, m): return LightningMeasurements(self._state).measurement(m) - class DecompositionInterpreter(PlxprInterpreter): """ >>> def f(x): @@ -143,7 +147,15 @@ class DecompositionInterpreter(PlxprInterpreter): """ def interpret_operation(self, op): - op.decomposition() + if op.has_decomposition: + op.decomposition() + else: + vals, structure = tree_flatten(op) + tree_unflatten(structure, vals) + + def interpret_measurement(self, m): + vals, structure = tree_flatten(m) + return tree_unflatten(structure, vals) class ConvertToTape(PlxprInterpreter): @@ -202,7 +214,8 @@ def interpret_operation(self, op): if len(op.wires) != 1: for w in op.wires: self._last_op_on_wires[w] = None - op._unflatten(*op._flatten()) + vals, structure = tree_flatten(op) + tree_unflatten(structure, vals) return w = op.wires[0] @@ -212,11 +225,77 @@ def interpret_operation(self, op): return previous_op = self._last_op_on_wires[w] if previous_op is not None: - previous_op._unflatten(*previous_op._flatten()) + vals, structure = tree_flatten(previous_op) + tree_unflatten(structure, vals) self._last_op_on_wires[w] = op return + def interpret_measurement(self, m): + vals, structure = tree_flatten(m) + return tree_unflatten(structure, vals) + + def cleanup(self): + for _, op in self._last_op_on_wires.items(): + if op is not None: + vals, structure = tree_flatten(op) + tree_unflatten(structure, vals) + + +class MergeRotations(PlxprInterpreter): + """ + + >>> def g(x): + ... qml.RX(x, 0) + ... qml.RX(2*x, 0) + ... qml.RX(-4*x, 0) + ... qml.X(0) + ... qml.RX(0.5, 0) + >>> plxpr = jax.make_jaxpr(g)(1.0) + >>> MergeRotations().call_jaxpr(plxpr.jaxpr, plxpr.consts)(1.0) + { lambda ; a:f64[]. let + b:f64[] = mul 2.0 a + c:f64[] = add b a + d:f64[] = mul -4.0 a + e:f64[] = add d c + _:AbstractOperator() = RX[n_wires=1] e 0 + _:AbstractOperator() = PauliX[n_wires=1] 0 + _:AbstractOperator() = RX[n_wires=1] 0.5 0 + in () } + + """ + + _last_op_on_wires = None + + def setup(self): + self._last_op_on_wires = {} + + def interpret_operation(self, op): + if len(op.wires) != 1: + for w in op.wires: + self._last_op_on_wires[w] = None + vals, structure = tree_flatten(op) + tree_unflatten(structure, vals) + return + + w = op.wires[0] + if w in self._last_op_on_wires: + previous_op = self._last_op_on_wires[w] + if op.name == previous_op.name and op.wires == previous_op.wires: + new_data = [d1 + d2 for d1, d2 in zip(op.data, previous_op.data)] + self._last_op_on_wires[w] = op._primitive.impl(*new_data, wires=op.wires) + return + if previous_op is not None: + vals, structure = tree_flatten(previous_op) + tree_unflatten(structure, vals) + self._last_op_on_wires[w] = op + return + + def interpret_measurement(self, m): + vals, structure = tree_flatten(m) + return tree_unflatten(structure, vals) + def cleanup(self): for _, op in self._last_op_on_wires.items(): if op is not None: - op._unflatten(*op._flatten()) + vals, structure = tree_flatten(op) + tree_unflatten(structure, vals) From 786caf3e1a0d874a17c2fb2485d6c9309f6f96f3 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 9 Aug 2024 17:01:53 -0400 Subject: [PATCH 05/45] call is function transform --- pennylane/capture/interpreters.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py index 463b3e6ebe2..b352f98c66a 100644 --- a/pennylane/capture/interpreters.py +++ b/pennylane/capture/interpreters.py @@ -52,11 +52,20 @@ def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): mp = eqn.primitive.impl(*invals, **eqn.params) return self.interpret_measurement(mp) - def call_jaxpr(self, jaxpr, consts): - partial_call = partial(self, jaxpr, consts) - return jax.make_jaxpr(partial_call) + def call(self, jaxpr, n_consts): + def wrapper(*args): + return self(jaxpr.jaxpr, args[:n_consts], *args[n_consts:]) - def __call__(self, jaxpr, consts, *args): + return wrapper + + def __call__(self, f): + def wrapper(*args): + jaxpr = jax.make_jaxpr(f)(*args) + return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + + return wrapper + + def eval(self, jaxpr, consts, *args): self._env = {} self._op_math_cache = {} self.setup() From 13b091de5cc3312f989e61ec6291f178e32ab3f9 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 13 Aug 2024 15:10:41 -0400 Subject: [PATCH 06/45] assorted improvements --- pennylane/capture/base_interpreter.py | 188 ++++++++++++++++++++++++++ pennylane/capture/interpreters.py | 106 +-------------- 2 files changed, 194 insertions(+), 100 deletions(-) create mode 100644 pennylane/capture/base_interpreter.py diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py new file mode 100644 index 00000000000..28b014269c4 --- /dev/null +++ b/pennylane/capture/base_interpreter.py @@ -0,0 +1,188 @@ +import copy +from functools import partial, wraps +from typing import Optional + +import jax +from jax.tree_util import tree_flatten, tree_unflatten + +from pennylane import cond +from pennylane.tape import QuantumScript +from pennylane.transforms.optimization.cancel_inverses import _are_inverses +from pennylane.workflow import qnode + +import pennylane as qml +from .primitives import _get_abstract_measurement, _get_abstract_operator +from .capture_qnode import _get_qnode_prim +from pennylane.compiler.qjit_api import _get_for_loop_qfunc_prim, for_loop, while_loop, _get_while_loop_qfunc_prim +from pennylane.ops.op_math.condition import _get_cond_qfunc_prim + +for_prim = _get_for_loop_qfunc_prim() +while_prim = _get_while_loop_qfunc_prim() +cond_prim = _get_cond_qfunc_prim() +qnode_prim = _get_qnode_prim() + +AbstractOperator = _get_abstract_operator() +AbstractMeasurement = _get_abstract_measurement() + + +class PlxprInterpreter: + + _env: dict + _op_math_cache: dict + _primitive_registrations = {} + + def __init_subclass__(cls) -> None: + cls._primitive_registrations = copy.copy(PlxprInterpreter._primitive_registrations) + + def __init__(self, state=None): + self._env = {} + self._op_math_cache = {} + self.state = state + + @classmethod + def register_primitive(cls, primitive): + def decorator(f): + cls._primitive_registrations[primitive] = f + return f + return decorator + + def read(self, var): + """Extract the value corresponding to a variable.""" + if self._env is None: + raise ValueError("_env not yet initialized.") + return var.val if type(var) is jax.core.Literal else self._env[var] + + def setup(self): + pass + + def cleanup(self): + pass + + def interpret_operation(self, op: "pennylane.operation.Operator"): + raise NotImplementedError + + def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): + + invals = [self.read(invar) for invar in eqn.invars] + op = eqn.primitive.impl(*invals, **eqn.params) + if not isinstance(eqn.outvars[0], jax.core.DropVar): + self._op_math_cache[eqn.outvars[0]] = eqn + self._env[eqn.outvars[0]] = op + return + return self.interpret_operation(op) + + def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): + invals = [self.read(invar) for invar in eqn.invars] + return eqn.primitive.bind(*invals, **eqn.params) + + def eval(self, jaxpr, consts, *args): + self._env = {} + self._op_math_cache = {} + self.setup() + + for arg, invar in zip(args, jaxpr.invars): + self._env[invar] = arg + for const, constvar in zip(consts, jaxpr.constvars): + self._env[constvar] = const + + measurements = [] + for eqn in jaxpr.eqns: + custom_handler= self._primitive_registrations.get(eqn.primitive, None) + if custom_handler: + custom_handler(self, eqn.outvars, *eqn.invars, **eqn.params) + elif isinstance(eqn.outvars[0].aval, AbstractOperator): + self.interpret_operation_eqn(eqn) + + elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): + measurement = self.interpret_measurement_eqn(eqn) + self._env[eqn.outvars[0]] = measurement + measurements.append(measurement) + else: + invals = [self.read(invar) for invar in eqn.invars] + outvals = eqn.primitive.bind(*invals, **eqn.params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + for outvar, outval in zip(eqn.outvars, outvals): + self._env[outvar] = outval + + self.cleanup() + # Read the final result of the Jaxpr from the environment + return [self._env[outvar] for outvar in jaxpr.outvars] + + def __call__(self, f): + @wraps(f) + def wrapper(*args, **kwargs): + jaxpr = jax.make_jaxpr(partial(f, **kwargs))(*args) + return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + + return wrapper + +@PlxprInterpreter.register_primitive(for_prim) +def handle_for_loop(self, outvars, *invars, jaxpr_body_fn, n_consts): + invals = [self.read(invar) for invar in invars] + start, stop, step = invals[0], invals[1], invals[2] + consts = invals[3:3+n_consts] + + @for_loop(start, stop, step) + def g(i, *init_state): + return type(self)().eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) + + res = g(*invals[3+n_consts:]) + + for outvar, outval in zip(outvars, res): + self._env[outvar] = outval + +@PlxprInterpreter.register_primitive(cond_prim) +def handle_cond(self, outvars, *invars, jaxpr_branches, n_consts_per_branch, n_args): + n_branches = len(jaxpr_branches) + invals = [self.read(var) for var in invars] + conditions = invals[:n_branches] + consts_flat = invals[n_branches + n_args :] + + @cond(conditions[0]) + def true_fn(*args): + return type(self)().eval(jaxpr_branches[0].jaxpr, consts_flat[:n_consts_per_branch[0]], *args) + + @true_fn.otherwise + def _(*args): + return type(self)().eval(jaxpr_branches[-1].jaxpr, consts_flat[n_consts_per_branch[0]:], *args) + + res = true_fn(*invals[n_branches : n_branches + n_args]) + + for outvar, outval in zip(outvars, res): + self._env[outvar] = outval + + +@PlxprInterpreter.register_primitive(while_prim) +def handle_while_loop(self, outvars, *invars, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): + invals = [self.read(invar) for invar in invars] + consts_body = invals[:n_consts_body] + consts_cond = invals[n_consts_body: n_consts_body+n_consts_cond] + init_state = invals[n_consts_body+n_consts_cond:] + + def cond_fn(*args): + return jax.core.eval_jaxpr(jaxpr_cond_fn.jaxpr, consts_cond, *args) + + @while_loop(cond_fn) + def loop(*args): + return type(self)().eval(jaxpr_body_fn.jaxpr, consts_body, *args) + + res = loop(*init_state) + + for outvar, outval in zip(outvars, res): + self._env[outvar] = outval + + +@PlxprInterpreter.register_primitive(qnode_prim) +def handle_qnode(self, outvars, *invars, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts): + invals = [self.read(invar) for invar in invars] + consts = invals[:n_consts] + + @qml.qnode(device, **qnode_kwargs) + def new_qnode(*args): + return type(self)().eval(qfunc_jaxpr, consts, *args) + + res = new_qnode(invals[n_consts:]) + + for outvar, outval in zip(outvars, res): + self._env[outvar] = outval diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py index b352f98c66a..084bf22e058 100644 --- a/pennylane/capture/interpreters.py +++ b/pennylane/capture/interpreters.py @@ -1,110 +1,20 @@ -from functools import partial -from typing import Optional -import jax from jax.tree_util import tree_flatten, tree_unflatten +from .base_interpreter import PlxprInterpreter + from pennylane.devices.qubit import apply_operation, create_initial_state, measure from pennylane.tape import QuantumScript -from pennylane.transforms.optimization.cancel_inverses import _are_inverses - -from .primitives import _get_abstract_measurement, _get_abstract_operator - -AbstractOperator = _get_abstract_operator() -AbstractMeasurement = _get_abstract_measurement() - - -class PlxprInterpreter: - - _env: Optional[dict] = None - _op_math_cache: Optional[dict] = None - - def _read(self, var): - """Extract the value corresponding to a variable.""" - if self._env is None: - raise ValueError("_env not yet initialized.") - return var.val if type(var) is jax.core.Literal else self._env[var] - - def setup(self): - pass - - def cleanup(self): - pass - - def interpret_operation(self, op: "pennylane.operation.Operator"): - raise NotImplementedError - - def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): - - invals = [self._read(invar) for invar in eqn.invars] - op = eqn.primitive.impl(*invals, **eqn.params) - if not isinstance(eqn.outvars[0], jax.core.DropVar): - self._op_math_cache[eqn.outvars[0]] = eqn - self._env[eqn.outvars[0]] = op - return - return self.interpret_operation(op) - - def interpret_measurement(self, measurement: "pennylane.measurements.MeasurementProcess"): - raise NotImplementedError - - def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): - invals = [self._read(invar) for invar in eqn.invars] - mp = eqn.primitive.impl(*invals, **eqn.params) - return self.interpret_measurement(mp) - - def call(self, jaxpr, n_consts): - def wrapper(*args): - return self(jaxpr.jaxpr, args[:n_consts], *args[n_consts:]) - - return wrapper - - def __call__(self, f): - def wrapper(*args): - jaxpr = jax.make_jaxpr(f)(*args) - return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) - - return wrapper - - def eval(self, jaxpr, consts, *args): - self._env = {} - self._op_math_cache = {} - self.setup() - - for arg, invar in zip(args, jaxpr.invars): - self._env[invar] = arg - for const, constvar in zip(consts, jaxpr.constvars): - self._env[constvar] = const - - measurements = [] - for eqn in jaxpr.eqns: - if isinstance(eqn.outvars[0].aval, AbstractOperator): - self.interpret_operation_eqn(eqn) - - elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): - measurement = self.interpret_measurement_eqn(eqn) - measurements.append(measurement) - else: - invals = [self._read(invar) for invar in eqn.invars] - outvals = eqn.primitive.bind(*invals, **eqn.params) - if not eqn.primitive.multiple_results: - outvals = [outvals] - for outvar, outval in zip(eqn.outvars, outvals): - self._env[outvar] = outval - - self.cleanup() - # Read the final result of the Jaxpr from the environment - return measurements - class DefaultQubitInterpreter(PlxprInterpreter): - _state = None - - def __init__(self, num_wires): + def __init__(self, num_wires, state = None): self.num_wires = num_wires + self._state = {"statevector": state} def setup(self): - self._state = create_initial_state(range(self.num_wires)) + if self._state is not None: + self._state = create_initial_state(range(self.num_wires)) def cleanup(self): self._state = None @@ -162,10 +72,6 @@ def interpret_operation(self, op): vals, structure = tree_flatten(op) tree_unflatten(structure, vals) - def interpret_measurement(self, m): - vals, structure = tree_flatten(m) - return tree_unflatten(structure, vals) - class ConvertToTape(PlxprInterpreter): """ From 3f7b5ffedbaaaaa942397e7f3364ac90be77613f Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 13 Aug 2024 17:00:59 -0400 Subject: [PATCH 07/45] more moving things around --- pennylane/capture/base_interpreter.py | 113 ++++++++++------------- pennylane/capture/interpreters.py | 127 ++++++++++++++++++++------ pennylane/capture/primitives.py | 2 +- pennylane/compiler/qjit_api.py | 1 + 4 files changed, 151 insertions(+), 92 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 28b014269c4..68a10c0b8a5 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -5,16 +5,21 @@ import jax from jax.tree_util import tree_flatten, tree_unflatten +import pennylane as qml from pennylane import cond +from pennylane.compiler.qjit_api import ( + _get_for_loop_qfunc_prim, + _get_while_loop_qfunc_prim, + for_loop, + while_loop, +) +from pennylane.ops.op_math.condition import _get_cond_qfunc_prim from pennylane.tape import QuantumScript from pennylane.transforms.optimization.cancel_inverses import _are_inverses from pennylane.workflow import qnode -import pennylane as qml -from .primitives import _get_abstract_measurement, _get_abstract_operator from .capture_qnode import _get_qnode_prim -from pennylane.compiler.qjit_api import _get_for_loop_qfunc_prim, for_loop, while_loop, _get_while_loop_qfunc_prim -from pennylane.ops.op_math.condition import _get_cond_qfunc_prim +from .primitives import _get_abstract_measurement, _get_abstract_operator for_prim = _get_for_loop_qfunc_prim() while_prim = _get_while_loop_qfunc_prim() @@ -28,15 +33,13 @@ class PlxprInterpreter: _env: dict - _op_math_cache: dict _primitive_registrations = {} - + def __init_subclass__(cls) -> None: - cls._primitive_registrations = copy.copy(PlxprInterpreter._primitive_registrations) + cls._primitive_registrations = copy.copy(cls._primitive_registrations) def __init__(self, state=None): self._env = {} - self._op_math_cache = {} self.state = state @classmethod @@ -44,8 +47,9 @@ def register_primitive(cls, primitive): def decorator(f): cls._primitive_registrations[primitive] = f return f + return decorator - + def read(self, var): """Extract the value corresponding to a variable.""" if self._env is None: @@ -62,14 +66,11 @@ def interpret_operation(self, op: "pennylane.operation.Operator"): raise NotImplementedError def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): - invals = [self.read(invar) for invar in eqn.invars] op = eqn.primitive.impl(*invals, **eqn.params) - if not isinstance(eqn.outvars[0], jax.core.DropVar): - self._op_math_cache[eqn.outvars[0]] = eqn - self._env[eqn.outvars[0]] = op - return - return self.interpret_operation(op) + if isinstance(eqn.outvars[0], jax.core.DropVar): + return self.interpret_operation(op) + return op def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): invals = [self.read(invar) for invar in eqn.invars] @@ -77,7 +78,6 @@ def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): def eval(self, jaxpr, consts, *args): self._env = {} - self._op_math_cache = {} self.setup() for arg, invar in zip(args, jaxpr.invars): @@ -85,25 +85,23 @@ def eval(self, jaxpr, consts, *args): for const, constvar in zip(consts, jaxpr.constvars): self._env[constvar] = const - measurements = [] for eqn in jaxpr.eqns: - custom_handler= self._primitive_registrations.get(eqn.primitive, None) + invals = [self.read(invar) for invar in eqn.invars] + + custom_handler = self._primitive_registrations.get(eqn.primitive, None) if custom_handler: - custom_handler(self, eqn.outvars, *eqn.invars, **eqn.params) + outvals = custom_handler(self, *invals, **eqn.params) elif isinstance(eqn.outvars[0].aval, AbstractOperator): - self.interpret_operation_eqn(eqn) - + outvals = self.interpret_operation_eqn(eqn) elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): - measurement = self.interpret_measurement_eqn(eqn) - self._env[eqn.outvars[0]] = measurement - measurements.append(measurement) + outvals = self.interpret_measurement_eqn(eqn) else: - invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) - if not eqn.primitive.multiple_results: - outvals = [outvals] - for outvar, outval in zip(eqn.outvars, outvals): - self._env[outvar] = outval + + if not eqn.primitive.multiple_results: + outvals = [outvals] + for outvar, outval in zip(eqn.outvars, outvals): + self._env[outvar] = outval self.cleanup() # Read the final result of the Jaxpr from the environment @@ -117,72 +115,61 @@ def wrapper(*args, **kwargs): return wrapper + @PlxprInterpreter.register_primitive(for_prim) -def handle_for_loop(self, outvars, *invars, jaxpr_body_fn, n_consts): - invals = [self.read(invar) for invar in invars] +def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): start, stop, step = invals[0], invals[1], invals[2] - consts = invals[3:3+n_consts] + consts = invals[3 : 3 + n_consts] @for_loop(start, stop, step) def g(i, *init_state): - return type(self)().eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) + return type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) - res = g(*invals[3+n_consts:]) + return g(*invals[3 + n_consts :]) - for outvar, outval in zip(outvars, res): - self._env[outvar] = outval @PlxprInterpreter.register_primitive(cond_prim) -def handle_cond(self, outvars, *invars, jaxpr_branches, n_consts_per_branch, n_args): +def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): n_branches = len(jaxpr_branches) - invals = [self.read(var) for var in invars] conditions = invals[:n_branches] consts_flat = invals[n_branches + n_args :] - @cond(conditions[0]) def true_fn(*args): - return type(self)().eval(jaxpr_branches[0].jaxpr, consts_flat[:n_consts_per_branch[0]], *args) - - @true_fn.otherwise - def _(*args): - return type(self)().eval(jaxpr_branches[-1].jaxpr, consts_flat[n_consts_per_branch[0]:], *args) + return type(self)(state=self.state).eval( + jaxpr_branches[0].jaxpr, consts_flat[: n_consts_per_branch[0]], *args + ) - res = true_fn(*invals[n_branches : n_branches + n_args]) + def false_fn(*args): + return type(self)(state=self.state).eval( + jaxpr_branches[-1].jaxpr, consts_flat[n_consts_per_branch[0] :], *args + ) - for outvar, outval in zip(outvars, res): - self._env[outvar] = outval + cond_fn = cond(conditions[0], true_fn, false_fn=false_fn) + return cond_fn(*invals[n_branches : n_branches + n_args]) @PlxprInterpreter.register_primitive(while_prim) -def handle_while_loop(self, outvars, *invars, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): - invals = [self.read(invar) for invar in invars] +def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): consts_body = invals[:n_consts_body] - consts_cond = invals[n_consts_body: n_consts_body+n_consts_cond] - init_state = invals[n_consts_body+n_consts_cond:] + consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] + init_state = invals[n_consts_body + n_consts_cond :] def cond_fn(*args): return jax.core.eval_jaxpr(jaxpr_cond_fn.jaxpr, consts_cond, *args) @while_loop(cond_fn) def loop(*args): - return type(self)().eval(jaxpr_body_fn.jaxpr, consts_body, *args) + return type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts_body, *args) - res = loop(*init_state) - - for outvar, outval in zip(outvars, res): - self._env[outvar] = outval + return loop(*init_state) @PlxprInterpreter.register_primitive(qnode_prim) -def handle_qnode(self, outvars, *invars, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts): - invals = [self.read(invar) for invar in invars] +def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts): consts = invals[:n_consts] @qml.qnode(device, **qnode_kwargs) def new_qnode(*args): - return type(self)().eval(qfunc_jaxpr, consts, *args) - - res = new_qnode(invals[n_consts:]) + return type(self)(state=self.state).eval(qfunc_jaxpr, consts, *args) - for outvar, outval in zip(outvars, res): - self._env[outvar] = outval + return new_qnode(invals[n_consts:]) diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py index 084bf22e058..6f18d1a6f24 100644 --- a/pennylane/capture/interpreters.py +++ b/pennylane/capture/interpreters.py @@ -1,29 +1,95 @@ +from functools import partial, wraps +import jax from jax.tree_util import tree_flatten, tree_unflatten -from .base_interpreter import PlxprInterpreter - +from pennylane.compiler.qjit_api import ( + _get_for_loop_qfunc_prim, + _get_while_loop_qfunc_prim, + for_loop, + while_loop, +) from pennylane.devices.qubit import apply_operation, create_initial_state, measure +from pennylane.measurements.mid_measure import MidMeasureMP, _create_mid_measure_primitive +from pennylane.ops.op_math.condition import _get_cond_qfunc_prim from pennylane.tape import QuantumScript +from .base_interpreter import PlxprInterpreter + +for_prim = _get_for_loop_qfunc_prim() +midmeasure_prim = _create_mid_measure_primitive() +cond_prim = _get_cond_qfunc_prim() + + class DefaultQubitInterpreter(PlxprInterpreter): - def __init__(self, num_wires, state = None): + def __init__(self, num_wires=None, state=None): self.num_wires = num_wires - self._state = {"statevector": state} + self.state = state + + @property + def statevector(self): + return None if self.state is None else self.state["statevector"] + + @statevector.setter + def statevector(self, val): + if self.state is None: + self.state = {"statevector": val} + else: + self.state["statevector"] = val def setup(self): - if self._state is not None: - self._state = create_initial_state(range(self.num_wires)) + if self.statevector is None: + self.statevector = create_initial_state(range(self.num_wires)) def cleanup(self): - self._state = None + self.state = None def interpret_operation(self, op): - self._state = apply_operation(op, self._state) + self.statevector = apply_operation(op, self.statevector) + return op + + def interpret_measurement_eqn(self, eqn): + invals = [self.read(invar) for invar in eqn.invars] + mp = eqn.primitive.impl(*invals, **eqn.params) + return measure(mp, self.statevector) - def interpret_measurement(self, m): - return measure(m, self._state) + +@DefaultQubitInterpreter.register_primitive(for_prim) +def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): + start, stop, step = invals[0], invals[1], invals[2] + consts = invals[3 : 3 + n_consts] + init_state = invals[3 + n_consts :] + + res = None + for i in range(start, stop, step): + res = type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) + + return res + + +@DefaultQubitInterpreter.register_primitive(midmeasure_prim) +def handle_mm(self, *invals, reset, postselect): + mp = MidMeasureMP(invals, reset=reset, postselect=postselect) + mid_measurements = {} + self.statevector = apply_operation(mp, self.statevector, mid_measurements=mid_measurements) + return mid_measurements[mp] + + +@DefaultQubitInterpreter.register_primitive(cond_prim) +def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): + n_branches = len(jaxpr_branches) + conditions = invals[:n_branches] + consts_flat = invals[n_branches + n_args :] + args = invals[n_branches : n_branches + n_args] + + if conditions[0]: + return type(self)(state=self.state).eval( + jaxpr_branches[0].jaxpr, consts_flat[: n_consts_per_branch[0]], *args + ) + return type(self)(state=self.state).eval( + jaxpr_branches[-1].jaxpr, consts_flat[: n_consts_per_branch[0]], *args + ) from pennylane_lightning.lightning_qubit._state_vector import ( @@ -32,19 +98,17 @@ def interpret_measurement(self, m): ) -class LightningInterpreter(PlxprInterpreter): - - def __init__(self, num_wires): - self._num_wires = num_wires +class LightningInterpreter(DefaultQubitInterpreter): def setup(self): - self._state = LightningStateVector(self._num_wires) + if self.statevector is None: + self.statevector = LightningStateVector(self.num_wires) def interpret_operation(self, op): - self._state._apply_lightning([op]) + self.statevector._apply_lightning([op]) def interpret_measurement(self, m): - return LightningMeasurements(self._state).measurement(m) + return LightningMeasurements(self.statevector).measurement(m) class DecompositionInterpreter(PlxprInterpreter): @@ -86,19 +150,26 @@ class ConvertToTape(PlxprInterpreter): """ def setup(self): - self._ops = [] - self._measurements = [] + if self.state is None: + self.state = {"ops": [], "measurements": []} def interpret_operation(self, op): - self._ops.append(op) - - def interpret_measurement(self, m): - self._measurements.append(m) - return m - - def __call__(self, jaxpr, consts, *args): - out = super().__call__(jaxpr, consts, *args) - return QuantumScript(self._ops, self._measurements) + self.state["ops"].append(op) + + def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): + invals = [self.read(invar) for invar in eqn.invars] + mp = eqn.primitive.bind(*invals, **eqn.params) + self.state["measurements"].append(mp) + return mp + + def __call__(self, f): + @wraps(f) + def wrapper(*args, **kwargs): + jaxpr = jax.make_jaxpr(partial(f, **kwargs))(*args) + _ = self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + return QuantumScript(self.state["ops"], self.state["measurements"]) + + return wrapper class CancelInverses(PlxprInterpreter): diff --git a/pennylane/capture/primitives.py b/pennylane/capture/primitives.py index 09cf9c22a20..41cd7203582 100644 --- a/pennylane/capture/primitives.py +++ b/pennylane/capture/primitives.py @@ -193,7 +193,7 @@ def _(*args, **kwargs): split = None if n_wires == 0 else -n_wires # need to convert array values into integers # for plxpr, all wires must be integers - wires = tuple(int(w) for w in args[split:]) + wires = tuple(w for w in args[split:]) args = args[:split] return type.__call__(operator_type, *args, wires=wires, **kwargs) diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index ed60ad0d14f..daf62032a9c 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -688,6 +688,7 @@ def _call_capture_enabled(self, *init_state): jaxpr_body_fn = jax.make_jaxpr(flat_fn)(0, *init_state) flat_args, _ = jax.tree_util.tree_flatten(init_state) + print(jaxpr_body_fn) results = for_loop_prim.bind( self.lower_bound, self.upper_bound, From 6e2b33406a0749263aee6834a9d76cc76794b800 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 15 Aug 2024 15:45:22 -0400 Subject: [PATCH 08/45] something? --- pennylane/capture/interpreters.py | 13 +++++++++++++ pennylane/compiler/qjit_api.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py index 6f18d1a6f24..373eeeca92c 100644 --- a/pennylane/capture/interpreters.py +++ b/pennylane/capture/interpreters.py @@ -172,6 +172,19 @@ def wrapper(*args, **kwargs): return wrapper +@ConvertToTape.register_primitive(for_prim) +def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): + start, stop, step = invals[0], invals[1], invals[2] + consts = invals[3 : 3 + n_consts] + init_state = invals[3 + n_consts :] + + res = None + for i in range(start, stop, step): + res = type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) + + return res + + class CancelInverses(PlxprInterpreter): """ diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index daf62032a9c..dd8531ac789 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -688,7 +688,7 @@ def _call_capture_enabled(self, *init_state): jaxpr_body_fn = jax.make_jaxpr(flat_fn)(0, *init_state) flat_args, _ = jax.tree_util.tree_flatten(init_state) - print(jaxpr_body_fn) + results = for_loop_prim.bind( self.lower_bound, self.upper_bound, From a61acd283871902f80c0c82a838ada7d5245a52a Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 26 Aug 2024 09:48:03 -0400 Subject: [PATCH 09/45] Add PLxprInterpreter base class --- pennylane/capture/Interpreters_Demo.md | 207 ----------------- pennylane/capture/__init__.py | 16 +- pennylane/capture/base_interpreter.py | 196 +++++++++++++--- pennylane/capture/interpreters.py | 300 ------------------------- 4 files changed, 184 insertions(+), 535 deletions(-) delete mode 100644 pennylane/capture/Interpreters_Demo.md delete mode 100644 pennylane/capture/interpreters.py diff --git a/pennylane/capture/Interpreters_Demo.md b/pennylane/capture/Interpreters_Demo.md deleted file mode 100644 index 81d741933e8..00000000000 --- a/pennylane/capture/Interpreters_Demo.md +++ /dev/null @@ -1,207 +0,0 @@ -```python -import pennylane as qml -import jax - -from pennylane.capture.interpreters import PlxprInterpreter, DefaultQubitInterpreter, LightningInterpreter, DecompositionInterpreter, ConvertToTape, CancelInverses, MergeRotations -qml.capture.enable() -``` - -### Demonstrating Existing Implementations - - -```python -def f(x): - qml.X(0) - qml.adjoint(qml.X(0)) - qml.Hadamard(0) - qml.IsingXX(x, wires=(0,1)) - return qml.expval(qml.Z(0)), qml.probs(wires=(0,1)) - -plxpr = jax.make_jaxpr(f)(0.5) -``` - - -```python -DefaultQubitInterpreter(num_wires=2)(plxpr.jaxpr, plxpr.consts, 1.2) -``` - - - - - [0.0, array([0.34058944, 0.15941056, 0.34058944, 0.15941056])] - - - - -```python -LightningInterpreter(num_wires=2)(plxpr.jaxpr, plxpr.consts, 1.2) -``` - - - - - [0.0, array([0.34058944, 0.15941056, 0.34058944, 0.15941056])] - - - - -```python -tape = ConvertToTape()(plxpr.jaxpr, plxpr.consts, 1.2) -print(tape.draw()) -``` - - 0: ──X──X†──H─╭IsingXX─┤ ╭Probs - 1: ───────────╰IsingXX─┤ ╰Probs - - - -```python -DecompositionInterpreter().call_jaxpr(plxpr.jaxpr, plxpr.consts)(2.5) -``` - - - - - { lambda ; a:f64[]. let - _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 - _:AbstractOperator() = RX[n_wires=1] 3.141592653589793 0 - _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 - _:AbstractOperator() = PauliX[n_wires=1] 0 - _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 - _:AbstractOperator() = RX[n_wires=1] 1.5707963267948966 0 - _:AbstractOperator() = PhaseShift[n_wires=1] 1.5707963267948966 0 - _:AbstractOperator() = CNOT[n_wires=2] 0 1 - _:AbstractOperator() = RX[n_wires=1] a 0 - _:AbstractOperator() = CNOT[n_wires=2] 0 1 - b:AbstractOperator() = PauliZ[n_wires=1] 0 - c:AbstractMeasurement(n_wires=None) = expval_obs b - d:AbstractMeasurement(n_wires=2) = probs_wires 0 1 - in (c, d) } - - - - -```python -CancelInverses().call_jaxpr(plxpr.jaxpr, plxpr.consts)(2.5) -``` - - - - - { lambda ; a:f64[]. let - _:AbstractOperator() = IsingXX[n_wires=2] a 0 1 - b:AbstractOperator() = PauliZ[n_wires=1] 0 - c:AbstractMeasurement(n_wires=None) = expval_obs b - d:AbstractMeasurement(n_wires=2) = probs_wires 0 1 - in (c, d) } - - - - -```python -def g(x): - qml.RX(x, 0) - qml.RX(2*x, 0) - qml.RX(-4*x, 0) - qml.X(0) - qml.RX(0.5, 0) - -plxpr = jax.make_jaxpr(g)(1.0) -MergeRotations().call_jaxpr(plxpr.jaxpr, plxpr.consts)(1.0) -``` - - - - - { lambda ; a:f64[]. let - b:f64[] = mul 2.0 a - c:f64[] = add b a - d:f64[] = mul -4.0 a - e:f64[] = add d c - _:AbstractOperator() = RX[n_wires=1] e 0 - _:AbstractOperator() = PauliX[n_wires=1] 0 - _:AbstractOperator() = RX[n_wires=1] 0.5 0 - in () } - - - -### Writing a new interpreter - - -```python -class AddSWAPNoise(PlxprInterpreter): - - def __init__(self, scale, prng_key=jax.random.key(12345)): - self.scale = scale - self.prng_key = prng_key - - def interpret_operation(self, op): - if isinstance(op, qml.SWAP): - self.prng_key, subkey = jax.random.split(self.prng_key) - phi = self.scale*jax.random.uniform(subkey) - qml.PhaseShift(phi, op.wires[0]) - val, structure = jax.tree_util.tree_flatten(op) - jax.tree_util.tree_unflatten(structure, val) - - def interpret_measurement(self, m): - vals, structure = jax.tree_util.tree_flatten(m) - return jax.tree_util.tree_unflatten(structure, vals) -``` - - -```python -def f(): - qml.SWAP((0,1)) - qml.SWAP((1,2)) - return qml.expval(qml.Z(0)) - -plxpr = jax.make_jaxpr(f)() -AddSWAPNoise(0.1).call_jaxpr(plxpr.jaxpr, plxpr.consts)() -``` - - - - - let _uniform = { lambda ; a:key[] b:f64[] c:f64[]. let - d:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b - e:f64[] = convert_element_type[new_dtype=float64 weak_type=False] c - f:u64[] = random_bits[bit_width=64 shape=()] a - g:u64[] = shift_right_logical f 12 - h:u64[] = or g 4607182418800017408 - i:f64[] = bitcast_convert_type[new_dtype=float64] h - j:f64[] = sub i 1.0 - k:f64[] = sub e d - l:f64[] = mul j k - m:f64[] = add l d - n:f64[] = reshape[dimensions=None new_sizes=()] m - o:f64[] = max d n - in (o,) } in - { lambda p:key[]; . let - q:key[2] = random_split[shape=(2,)] p - r:key[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] q - s:key[] = squeeze[dimensions=(0,)] r - t:key[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] q - u:key[] = squeeze[dimensions=(0,)] t - v:f64[] = pjit[name=_uniform jaxpr=_uniform] u 0.0 1.0 - w:f64[] = mul 0.1 v - _:AbstractOperator() = PhaseShift[n_wires=1] w 0 - _:AbstractOperator() = SWAP[n_wires=2] 0 1 - x:key[2] = random_split[shape=(2,)] s - y:key[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] x - _:key[] = squeeze[dimensions=(0,)] y - z:key[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] x - ba:key[] = squeeze[dimensions=(0,)] z - bb:f64[] = pjit[name=_uniform jaxpr=_uniform] ba 0.0 1.0 - bc:f64[] = mul 0.1 bb - _:AbstractOperator() = PhaseShift[n_wires=1] bc 1 - _:AbstractOperator() = SWAP[n_wires=2] 1 2 - bd:AbstractOperator() = PauliZ[n_wires=1] 0 - be:AbstractMeasurement(n_wires=None) = expval_obs bd - in (be,) } - - - - -```python - -``` diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 98637c2d425..28b3f45a8fc 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -141,24 +141,33 @@ def _(*args, **kwargs): AbstractOperator: type AbstractMeasurement: type qnode_prim: "jax.core.Primitive" +PlxprInterpreter: type # pylint: disable=redefined-outer-name +# pylint: disable=import-outside-toplevel, redefined-outer-name def __getattr__(key): if key == "AbstractOperator": - from .primitives import _get_abstract_operator # pylint: disable=import-outside-toplevel + from .primitives import _get_abstract_operator return _get_abstract_operator() if key == "AbstractMeasurement": - from .primitives import _get_abstract_measurement # pylint: disable=import-outside-toplevel + from .primitives import _get_abstract_measurement return _get_abstract_measurement() if key == "qnode_prim": - from .capture_qnode import _get_qnode_prim # pylint: disable=import-outside-toplevel + from .capture_qnode import _get_qnode_prim return _get_qnode_prim() + if key == "PlxprInterpreter": + from .base_interpreter import ( + PlxprInterpreter, + ) + + return PlxprInterpreter + raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'") @@ -176,4 +185,5 @@ def __getattr__(key): "AbstractOperator", "AbstractMeasurement", "qnode_prim", + "PlxprInterpreter", ) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 68a10c0b8a5..787a3c7b388 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -1,9 +1,26 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule defines a strategy structure for defining custom plxpr interpreters +""" + import copy from functools import partial, wraps -from typing import Optional +from typing import Callable +# note: the module has a jax dependency and cannot exist in the standard import path for now. import jax -from jax.tree_util import tree_flatten, tree_unflatten import pennylane as qml from pennylane import cond @@ -13,10 +30,9 @@ for_loop, while_loop, ) +from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim, adjoint from pennylane.ops.op_math.condition import _get_cond_qfunc_prim -from pennylane.tape import QuantumScript -from pennylane.transforms.optimization.cancel_inverses import _are_inverses -from pennylane.workflow import qnode +from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim, ctrl from .capture_qnode import _get_qnode_prim from .primitives import _get_abstract_measurement, _get_abstract_operator @@ -25,15 +41,30 @@ while_prim = _get_while_loop_qfunc_prim() cond_prim = _get_cond_qfunc_prim() qnode_prim = _get_qnode_prim() +adjoint_transform_prim = _get_adjoint_qfunc_prim() +ctrl_transform_prim = _get_ctrl_qfunc_prim() AbstractOperator = _get_abstract_operator() AbstractMeasurement = _get_abstract_measurement() class PlxprInterpreter: + """A template base class for defining plxpr interpreters + + Args: + state (Any): any kind of information that may need to get carried around between different interpreters. + + **State property:** + + Higher order primitives can often be handled by a separate interpreter, but need to reference or modify the same values. + For example, a device interpreter may need to modify a statevector, or conversion to a tape may need to modify operations + and measurement lists. By maintaining this information in the optional ``state`` property, this information can automatically + by passed to new sub-interpreters. + + """ _env: dict - _primitive_registrations = {} + _primitive_registrations: dict["jax.core.Primitive", Callable] = {} def __init_subclass__(cls) -> None: cls._primitive_registrations = copy.copy(cls._primitive_registrations) @@ -43,13 +74,36 @@ def __init__(self, state=None): self.state = state @classmethod - def register_primitive(cls, primitive): - def decorator(f): + def register_primitive(cls, primitive: "jax.core.Primitive") -> Callable[[Callable], Callable]: + """Registers a custom method for handling a primitive + + Args: + primitive (jax.core.Primitive): the primitive we want custom behavior for + + Returns: + Callable: a decorator for adding a function to the custom registrations map + + Side Effect: + Calling the returned decorator with a function will place the function into the + primitive registrations map. + + ``` + my_primitive = jax.core.Primitive("my_primitve") + + @Interpreter_Type.register(my_primitive) + def handle_my_primitive(self: Interpreter_Type, *invals, **params) + return invals[0] + invals[1] # some sort of custom handling + ``` + + """ + + def decorator(f: Callable) -> Callable: cls._primitive_registrations[primitive] = f return f return decorator + # pylint: disable=unidiomatic-typecheck def read(self, var): """Extract the value corresponding to a variable.""" if self._env is None: @@ -57,26 +111,82 @@ def read(self, var): return var.val if type(var) is jax.core.Literal else self._env[var] def setup(self): - pass + """Initialize the instance before interpretting equations. + + Blank by default, this method can initialize any additional instance variables + needed by an interpreter + + """ def cleanup(self): - pass + """Perform any final steps after iterating through all equations. + + Blank by default, this method can clean up instance variables, or perform + equations that have been deffered till later. + + """ def interpret_operation(self, op: "pennylane.operation.Operator"): - raise NotImplementedError + """Interpret a PennyLane operation instance. - def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): - invals = [self.read(invar) for invar in eqn.invars] - op = eqn.primitive.impl(*invals, **eqn.params) - if isinstance(eqn.outvars[0], jax.core.DropVar): - return self.interpret_operation(op) + Args: + op (Operator): a pennylane operator instance + + Returns: + Any + + This method is only called when the operator's output is a dropped variable, + so the output will not effect later equations in the circuit. + + See also: :meth:`~.interpret_operation_eqn`. + + """ return op - def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): - invals = [self.read(invar) for invar in eqn.invars] - return eqn.primitive.bind(*invals, **eqn.params) + def interpret_operation_eqn(self, primitive, *invals, is_drop_var, **params): + """Interpret an equation corresponding to an operator. + + Args: + primitive (jax.core.Primitive): a jax primitive corresponding to an operation + *invals (Any): the positional input variables for the equation + + Keyword Args: + is_drop_var (bool): whether or not the equation's output is a dropped variable + **params: The equations parameters dictionary + + See also: :meth:`~.interpret_operation`. + + """ + if is_drop_var: + op = primitive.impl(*invals, **params) + return self.interpret_operation(op) + return primitive.bind(*invals, **params) + + def interpret_measurement_eqn(self, primitive, *invals, **params): + """Interpret an equation corresponding to a measurement process. + + Args: + primitive (jax.core.Primitive): a jax primitive corresponding to a measurement. + *invals (Any): the positional input variables for the equation + + Keyword Args: + **params: The equations parameters dictionary + + """ + return primitive.bind(*invals, **params) + + def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: + """Evaluate a jaxpr. + + Args: + jaxpr (jax.core.Jaxpr): the jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args (tuple[TensorLike]): The arguments for the jaxpr. - def eval(self, jaxpr, consts, *args): + Returns: + list[TensorLike]: the results of the execution. + + """ self._env = {} self.setup() @@ -92,9 +202,12 @@ def eval(self, jaxpr, consts, *args): if custom_handler: outvals = custom_handler(self, *invals, **eqn.params) elif isinstance(eqn.outvars[0].aval, AbstractOperator): - outvals = self.interpret_operation_eqn(eqn) + is_drop_var = isinstance(eqn.outvars[0], jax.core.DropVar) + outvals = self.interpret_operation_eqn( + eqn.primitive, *invals, is_drop_var=is_drop_var, **eqn.params + ) elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): - outvals = self.interpret_measurement_eqn(eqn) + outvals = self.interpret_measurement_eqn(eqn.primitive, *invals, **eqn.params) else: outvals = eqn.primitive.bind(*invals, **eqn.params) @@ -105,9 +218,9 @@ def eval(self, jaxpr, consts, *args): self.cleanup() # Read the final result of the Jaxpr from the environment - return [self._env[outvar] for outvar in jaxpr.outvars] + return [self.read(outvar) for outvar in jaxpr.outvars] - def __call__(self, f): + def __call__(self, f: Callable) -> Callable: @wraps(f) def wrapper(*args, **kwargs): jaxpr = jax.make_jaxpr(partial(f, **kwargs))(*args) @@ -116,8 +229,37 @@ def wrapper(*args, **kwargs): return wrapper +@PlxprInterpreter.register_primitive(adjoint_transform_prim) +def handle_adjoint_transform(self, *invals, jaxpr, lazy, n_consts): + """Interpret an adjoint transform primitive.""" + consts = invals[:n_consts] + args = invals[n_consts:] + + def new_qfunc(*inner_args): + return type(self)(state=self.state).eval(jaxpr, consts, *inner_args) + + return adjoint(new_qfunc, lazy=lazy)(*args) + + +# pylint: disable=too-many-arguments +@PlxprInterpreter.register_primitive(ctrl_transform_prim) +def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): + """Interpret a ctrl transform primitive.""" + consts = invals[:n_consts] + control_wires = invals[-n_control:] + args = invals[n_consts:-n_control] + + def new_qfunc(*inner_args): + return type(self)(state=self.state).eval(jaxpr, consts, *inner_args) + + return ctrl( + new_qfunc, control_values=control_values, control=control_wires, work_wires=work_wires + )(*args) + + @PlxprInterpreter.register_primitive(for_prim) def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): + """Handle a for loop primitive.""" start, stop, step = invals[0], invals[1], invals[2] consts = invals[3 : 3 + n_consts] @@ -130,6 +272,7 @@ def g(i, *init_state): @PlxprInterpreter.register_primitive(cond_prim) def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): + """Handle a cond primitive.""" n_branches = len(jaxpr_branches) conditions = invals[:n_branches] consts_flat = invals[n_branches + n_args :] @@ -150,6 +293,7 @@ def false_fn(*args): @PlxprInterpreter.register_primitive(while_prim) def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): + """Handle a while loop primitive.""" consts_body = invals[:n_consts_body] consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] init_state = invals[n_consts_body + n_consts_cond :] @@ -164,12 +308,14 @@ def loop(*args): return loop(*init_state) +# pylint: disable=unused-argument, too-many-arguments @PlxprInterpreter.register_primitive(qnode_prim) def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts): + """Handle a qnode primitive.""" consts = invals[:n_consts] @qml.qnode(device, **qnode_kwargs) def new_qnode(*args): return type(self)(state=self.state).eval(qfunc_jaxpr, consts, *args) - return new_qnode(invals[n_consts:]) + return new_qnode(invals[n_consts:], shots=shots) diff --git a/pennylane/capture/interpreters.py b/pennylane/capture/interpreters.py deleted file mode 100644 index 373eeeca92c..00000000000 --- a/pennylane/capture/interpreters.py +++ /dev/null @@ -1,300 +0,0 @@ -from functools import partial, wraps - -import jax -from jax.tree_util import tree_flatten, tree_unflatten - -from pennylane.compiler.qjit_api import ( - _get_for_loop_qfunc_prim, - _get_while_loop_qfunc_prim, - for_loop, - while_loop, -) -from pennylane.devices.qubit import apply_operation, create_initial_state, measure -from pennylane.measurements.mid_measure import MidMeasureMP, _create_mid_measure_primitive -from pennylane.ops.op_math.condition import _get_cond_qfunc_prim -from pennylane.tape import QuantumScript - -from .base_interpreter import PlxprInterpreter - -for_prim = _get_for_loop_qfunc_prim() -midmeasure_prim = _create_mid_measure_primitive() -cond_prim = _get_cond_qfunc_prim() - - -class DefaultQubitInterpreter(PlxprInterpreter): - - def __init__(self, num_wires=None, state=None): - self.num_wires = num_wires - self.state = state - - @property - def statevector(self): - return None if self.state is None else self.state["statevector"] - - @statevector.setter - def statevector(self, val): - if self.state is None: - self.state = {"statevector": val} - else: - self.state["statevector"] = val - - def setup(self): - if self.statevector is None: - self.statevector = create_initial_state(range(self.num_wires)) - - def cleanup(self): - self.state = None - - def interpret_operation(self, op): - self.statevector = apply_operation(op, self.statevector) - return op - - def interpret_measurement_eqn(self, eqn): - invals = [self.read(invar) for invar in eqn.invars] - mp = eqn.primitive.impl(*invals, **eqn.params) - return measure(mp, self.statevector) - - -@DefaultQubitInterpreter.register_primitive(for_prim) -def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): - start, stop, step = invals[0], invals[1], invals[2] - consts = invals[3 : 3 + n_consts] - init_state = invals[3 + n_consts :] - - res = None - for i in range(start, stop, step): - res = type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) - - return res - - -@DefaultQubitInterpreter.register_primitive(midmeasure_prim) -def handle_mm(self, *invals, reset, postselect): - mp = MidMeasureMP(invals, reset=reset, postselect=postselect) - mid_measurements = {} - self.statevector = apply_operation(mp, self.statevector, mid_measurements=mid_measurements) - return mid_measurements[mp] - - -@DefaultQubitInterpreter.register_primitive(cond_prim) -def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): - n_branches = len(jaxpr_branches) - conditions = invals[:n_branches] - consts_flat = invals[n_branches + n_args :] - args = invals[n_branches : n_branches + n_args] - - if conditions[0]: - return type(self)(state=self.state).eval( - jaxpr_branches[0].jaxpr, consts_flat[: n_consts_per_branch[0]], *args - ) - return type(self)(state=self.state).eval( - jaxpr_branches[-1].jaxpr, consts_flat[: n_consts_per_branch[0]], *args - ) - - -from pennylane_lightning.lightning_qubit._state_vector import ( - LightningMeasurements, - LightningStateVector, -) - - -class LightningInterpreter(DefaultQubitInterpreter): - - def setup(self): - if self.statevector is None: - self.statevector = LightningStateVector(self.num_wires) - - def interpret_operation(self, op): - self.statevector._apply_lightning([op]) - - def interpret_measurement(self, m): - return LightningMeasurements(self.statevector).measurement(m) - - -class DecompositionInterpreter(PlxprInterpreter): - """ - >>> def f(x): - ... qml.IsingXX(x, wires=(0,1)) - ... qml.Rot(0.5, x, 1.5, wires=1) - >>> jaxpr = jax.make_jaxpr(f)(0.5) - >>> DecompositionInterpreter().call_jaxpr(jaxpr.jaxpr, jaxpr.consts)(0.5) - { lambda ; a:f32[]. let - _:AbstractOperator() = CNOT[n_wires=2] 0 1 - _:AbstractOperator() = RX[n_wires=1] a 0 - _:AbstractOperator() = CNOT[n_wires=2] 0 1 - _:AbstractOperator() = RZ[n_wires=1] 0.5 1 - _:AbstractOperator() = RY[n_wires=1] a 1 - _:AbstractOperator() = RZ[n_wires=1] 1.5 1 - in () } - - """ - - def interpret_operation(self, op): - if op.has_decomposition: - op.decomposition() - else: - vals, structure = tree_flatten(op) - tree_unflatten(structure, vals) - - -class ConvertToTape(PlxprInterpreter): - """ - - >>> def f(x): - ... qml.RX(x, wires=0) - ... return qml.expval(qml.Z(0)) - >>> jaxpr = jax.make_jaxpr(f)(0.5) - >>> ConvertToTape()(jaxpr.jaxpr, jaxpr.consts, 1.2).circuit - [RX(1.2, wires=[0]), expval(Z(0))] - - """ - - def setup(self): - if self.state is None: - self.state = {"ops": [], "measurements": []} - - def interpret_operation(self, op): - self.state["ops"].append(op) - - def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): - invals = [self.read(invar) for invar in eqn.invars] - mp = eqn.primitive.bind(*invals, **eqn.params) - self.state["measurements"].append(mp) - return mp - - def __call__(self, f): - @wraps(f) - def wrapper(*args, **kwargs): - jaxpr = jax.make_jaxpr(partial(f, **kwargs))(*args) - _ = self.eval(jaxpr.jaxpr, jaxpr.consts, *args) - return QuantumScript(self.state["ops"], self.state["measurements"]) - - return wrapper - - -@ConvertToTape.register_primitive(for_prim) -def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): - start, stop, step = invals[0], invals[1], invals[2] - consts = invals[3 : 3 + n_consts] - init_state = invals[3 + n_consts :] - - res = None - for i in range(start, stop, step): - res = type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) - - return res - - -class CancelInverses(PlxprInterpreter): - """ - - >>> def f(x): - ... qml.X(0) - ... qml.X(0) - ... qml.Hadamard(0) - ... qml.Y(1) - ... qml.RX(x, 0) - ... qml.adjoint(qml.RX(x, 0)) - >>> jaxpr = jax.make_jaxpr(f)(0.5) - >>> CancelInverses().call_jaxpr(jaxpr.jaxpr, jaxpr.consts)(0.5) - { lambda ; a:f64[]. let - _:AbstractOperator() = Hadamard[n_wires=1] 0 - _:AbstractOperator() = PauliY[n_wires=1] 1 - in () } - - """ - - _last_op_on_wires = None - - def setup(self): - self._last_op_on_wires = {} - - def interpret_operation(self, op): - if len(op.wires) != 1: - for w in op.wires: - self._last_op_on_wires[w] = None - vals, structure = tree_flatten(op) - tree_unflatten(structure, vals) - return - - w = op.wires[0] - if w in self._last_op_on_wires: - if _are_inverses(self._last_op_on_wires[w], op): - self._last_op_on_wires[w] = None - return - previous_op = self._last_op_on_wires[w] - if previous_op is not None: - vals, structure = tree_flatten(previous_op) - tree_unflatten(structure, vals) - self._last_op_on_wires[w] = op - return - - def interpret_measurement(self, m): - vals, structure = tree_flatten(m) - return tree_unflatten(structure, vals) - - def cleanup(self): - for _, op in self._last_op_on_wires.items(): - if op is not None: - vals, structure = tree_flatten(op) - tree_unflatten(structure, vals) - - -class MergeRotations(PlxprInterpreter): - """ - - >>> def g(x): - ... qml.RX(x, 0) - ... qml.RX(2*x, 0) - ... qml.RX(-4*x, 0) - ... qml.X(0) - ... qml.RX(0.5, 0) - >>> plxpr = jax.make_jaxpr(g)(1.0) - >>> MergeRotations().call_jaxpr(plxpr.jaxpr, plxpr.consts)(1.0) - { lambda ; a:f64[]. let - b:f64[] = mul 2.0 a - c:f64[] = add b a - d:f64[] = mul -4.0 a - e:f64[] = add d c - _:AbstractOperator() = RX[n_wires=1] e 0 - _:AbstractOperator() = PauliX[n_wires=1] 0 - _:AbstractOperator() = RX[n_wires=1] 0.5 0 - in () } - - """ - - _last_op_on_wires = None - - def setup(self): - self._last_op_on_wires = {} - - def interpret_operation(self, op): - if len(op.wires) != 1: - for w in op.wires: - self._last_op_on_wires[w] = None - vals, structure = tree_flatten(op) - tree_unflatten(structure, vals) - return - - w = op.wires[0] - if w in self._last_op_on_wires: - previous_op = self._last_op_on_wires[w] - if op.name == previous_op.name and op.wires == previous_op.wires: - new_data = [d1 + d2 for d1, d2 in zip(op.data, previous_op.data)] - self._last_op_on_wires[w] = op._primitive.impl(*new_data, wires=op.wires) - return - if previous_op is not None: - vals, structure = tree_flatten(previous_op) - tree_unflatten(structure, vals) - self._last_op_on_wires[w] = op - return - - def interpret_measurement(self, m): - vals, structure = tree_flatten(m) - return tree_unflatten(structure, vals) - - def cleanup(self): - for _, op in self._last_op_on_wires.items(): - if op is not None: - vals, structure = tree_flatten(op) - tree_unflatten(structure, vals) From 1609ed38d40a371762ac06c8072cf1d5c7a6cf83 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 26 Aug 2024 10:52:03 -0400 Subject: [PATCH 10/45] qnode fix --- pennylane/capture/base_interpreter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 787a3c7b388..97773d53749 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -318,4 +318,4 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, def new_qnode(*args): return type(self)(state=self.state).eval(qfunc_jaxpr, consts, *args) - return new_qnode(invals[n_consts:], shots=shots) + return new_qnode(*invals[n_consts:], shots=shots) From f593d45db1391c4a349f55933db6ee604b2412d7 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 27 Aug 2024 15:55:35 -0400 Subject: [PATCH 11/45] starting to write tests --- pennylane/capture/__init__.py | 1 + tests/capture/test_base_interpreter.py | 80 ++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 tests/capture/test_base_interpreter.py diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 28b3f45a8fc..4f215e875e3 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -34,6 +34,7 @@ ~create_measurement_wires_primitive ~create_measurement_mcm_primitive ~qnode_call + ~PlxprInterpreter To activate and deactivate the new PennyLane program capturing mechanism, use the switches ``qml.capture.enable`` and ``qml.capture.disable``. diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py new file mode 100644 index 00000000000..c4681eba2b2 --- /dev/null +++ b/tests/capture/test_base_interpreter.py @@ -0,0 +1,80 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This submodule tests strategy structure for defining custom plxpr interpreters +""" + +import pytest + +import pennylane as qml + +jax = pytest.importorskip("jax") + +from pennylane.capture.base_interpreter import PlxprInterpreter + +pytestmark = pytest.mark.jax + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + """Enable and disable the PennyLane JAX capture context manager.""" + qml.capture.enable() + yield + qml.capture.disable() + + +class MapWiresInterpreter(PlxprInterpreter): + + def interpret_operation(self, op): + # fairly limited use case, but good enough for testing. + wire_map = {0: 5, 1: 6, 2: 7} + return type(op)(*op.data, wires=tuple(wire_map[w] for w in op.wires)) + + +def test_env_and_state_initialized(): + """Test that env and state are initialized at the start.""" + + interpreter = MapWiresInterpreter() + assert interpreter._env == {} + assert interpreter.state is None + + +def test_primitive_registrations(): + """Test that child primitive registrations dict's are not copied and do + not effect PlxprInterpreeter.""" + + assert ( + MapWiresInterpreter._primitive_registrations + is not PlxprInterpreter._primitive_registrations + ) + + @MapWiresInterpreter.register_primitive(qml.X._primitive) + def _(self, *invals, **params): + return qml.X(*invals) + + assert qml.X._primitive in MapWiresInterpreter._primitive_registrations + assert qml.X._primitive not in PlxprInterpreter._primitive_registrations + + @MapWiresInterpreter() + def f(): + qml.X(0) + qml.Y(0) + + jaxpr = jax.make_jaxpr(f)() + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, []) + + qml.assert_equal(q.queue[0], qml.X(0)) # not mapped due to primitive registration + qml.assert_equal(q.queue[1], qml.Y(5)) # mapped wire From 0defddea2575ba08676b96117f3417db9720c1a9 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 28 Aug 2024 16:30:18 -0400 Subject: [PATCH 12/45] trying to improve op math handling --- pennylane/capture/base_interpreter.py | 45 ++++++++++++++++++++------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 97773d53749..80cdb82dc26 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -65,12 +65,14 @@ class PlxprInterpreter: _env: dict _primitive_registrations: dict["jax.core.Primitive", Callable] = {} + _op_math_cache: dict def __init_subclass__(cls) -> None: cls._primitive_registrations = copy.copy(cls._primitive_registrations) def __init__(self, state=None): self._env = {} + self._op_math_cache = {} self.state = state @classmethod @@ -108,6 +110,8 @@ def read(self, var): """Extract the value corresponding to a variable.""" if self._env is None: raise ValueError("_env not yet initialized.") + if type(var) is jax.core.Literal: + return var.val return var.val if type(var) is jax.core.Literal else self._env[var] def setup(self): @@ -143,24 +147,35 @@ def interpret_operation(self, op: "pennylane.operation.Operator"): """ return op - def interpret_operation_eqn(self, primitive, *invals, is_drop_var, **params): + def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): """Interpret an equation corresponding to an operator. Args: primitive (jax.core.Primitive): a jax primitive corresponding to an operation + outvar *invals (Any): the positional input variables for the equation Keyword Args: - is_drop_var (bool): whether or not the equation's output is a dropped variable **params: The equations parameters dictionary See also: :meth:`~.interpret_operation`. """ - if is_drop_var: - op = primitive.impl(*invals, **params) + + invals = [ + ( + invar.val + if type(invar) is jax.core.Literal + else self._op_math_cache.get(invar, self.read(invar)) + ) + for invar in eqn.invars + ] + op = eqn.primitive.impl(*invals, **eqn.params) + if isinstance(eqn.outvars[0], jax.core.DropVar): return self.interpret_operation(op) - return primitive.bind(*invals, **params) + + self._op_math_cache[eqn.outvars[0]] = op + return op def interpret_measurement_eqn(self, primitive, *invals, **params): """Interpret an equation corresponding to a measurement process. @@ -188,6 +203,7 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: """ self._env = {} + self._op_math_cache = {} self.setup() for arg, invar in zip(args, jaxpr.invars): @@ -196,19 +212,18 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: self._env[constvar] = const for eqn in jaxpr.eqns: - invals = [self.read(invar) for invar in eqn.invars] custom_handler = self._primitive_registrations.get(eqn.primitive, None) if custom_handler: + invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) elif isinstance(eqn.outvars[0].aval, AbstractOperator): - is_drop_var = isinstance(eqn.outvars[0], jax.core.DropVar) - outvals = self.interpret_operation_eqn( - eqn.primitive, *invals, is_drop_var=is_drop_var, **eqn.params - ) + outvals = self.interpret_operation_eqn(eqn) elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): + invals = [self.read(invar) for invar in eqn.invars] outvals = self.interpret_measurement_eqn(eqn.primitive, *invals, **eqn.params) else: + invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) if not eqn.primitive.multiple_results: @@ -218,7 +233,15 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: self.cleanup() # Read the final result of the Jaxpr from the environment - return [self.read(outvar) for outvar in jaxpr.outvars] + outvals = [] + for var in jaxpr.outvars: + if var in self._op_math_cache: + outvals.append(self.interpret_operation(self._op_math_cache[var])) + else: + outvals.append(self.read(var)) + self._op_math_cache = {} + self._env = {} + return outvals def __call__(self, f: Callable) -> Callable: @wraps(f) From 0ecf1a856229556ed78fb116dc60905a9898ddd5 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 30 Aug 2024 11:24:28 -0400 Subject: [PATCH 13/45] testing --- pennylane/capture/base_interpreter.py | 32 +++++++++++-- pennylane/ops/op_math/pow.py | 6 ++- tests/capture/test_base_interpreter.py | 64 +++++++++++++++++++++----- 3 files changed, 84 insertions(+), 18 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 80cdb82dc26..0c9e499d2c2 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -61,6 +61,21 @@ class PlxprInterpreter: and measurement lists. By maintaining this information in the optional ``state`` property, this information can automatically by passed to new sub-interpreters. + + **Examples:** + + .. code-block:: python + + class SimplifyInterpreter(PlxprInterpreter): + + def interpret_operation(self, op): + new_op = op.simplify() + if new_op is op: + # if new op isn't queued, need to requeue op. + data, struct = jax.tree_util.tree_flatten(new_op) + new_op = jax.tree_util.tree_unflatten(struct, data) + return new_op + """ _env: dict @@ -261,7 +276,9 @@ def handle_adjoint_transform(self, *invals, jaxpr, lazy, n_consts): def new_qfunc(*inner_args): return type(self)(state=self.state).eval(jaxpr, consts, *inner_args) - return adjoint(new_qfunc, lazy=lazy)(*args) + jaxpr = jax.make_jaxpr(new_qfunc)(*args) + + return adjoint_transform_prim.bind(*invals, jaxpr=jaxpr.jaxpr, lazy=lazy, n_consts=n_consts) # pylint: disable=too-many-arguments @@ -275,9 +292,16 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_ def new_qfunc(*inner_args): return type(self)(state=self.state).eval(jaxpr, consts, *inner_args) - return ctrl( - new_qfunc, control_values=control_values, control=control_wires, work_wires=work_wires - )(*args) + jaxpr = jax.make_jaxpr(new_qfunc)(*args) + + return ctrl_transform_prim.bind( + *invals, + n_control=n_control, + jaxpr=jaxpr.jaxpr, + control_values=control_values, + work_wires=work_wires, + n_consts=n_consts, + ) @PlxprInterpreter.register_primitive(for_prim) diff --git a/pennylane/ops/op_math/pow.py b/pennylane/ops/op_math/pow.py index 6251fb6fdbb..68f1d1c5fc4 100644 --- a/pennylane/ops/op_math/pow.py +++ b/pennylane/ops/op_math/pow.py @@ -380,14 +380,16 @@ def simplify(self) -> Union["Pow", Identity]: pr.simplify() return pr.operation(wire_order=self.wires) - base = self.base.simplify() + base = self.base if qml.capture.enabled() else self.base.simplify() try: ops = base.pow(z=self.z) if not ops: return qml.Identity(self.wires) op = qml.prod(*ops) if len(ops) > 1 else ops[0] - return op.simplify() + return op if qml.capture.enabled() else op.simplify() except PowUndefinedError: + if qml.capture.enabled(): + return Pow(base.simplify(), z=self.z) return Pow(base=base, z=self.z) diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index c4681eba2b2..5d1e62c0737 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -34,18 +34,21 @@ def enable_disable_plxpr(): qml.capture.disable() -class MapWiresInterpreter(PlxprInterpreter): +class SimplifyInterpreter(PlxprInterpreter): def interpret_operation(self, op): - # fairly limited use case, but good enough for testing. - wire_map = {0: 5, 1: 6, 2: 7} - return type(op)(*op.data, wires=tuple(wire_map[w] for w in op.wires)) + new_op = op.simplify() + if new_op is op: + # if new op isn't queued, need to requeue op. + data, struct = jax.tree_util.tree_flatten(new_op) + new_op = jax.tree_util.tree_unflatten(struct, data) + return new_op def test_env_and_state_initialized(): """Test that env and state are initialized at the start.""" - interpreter = MapWiresInterpreter() + interpreter = SimplifyInterpreter() assert interpreter._env == {} assert interpreter.state is None @@ -55,26 +58,63 @@ def test_primitive_registrations(): not effect PlxprInterpreeter.""" assert ( - MapWiresInterpreter._primitive_registrations + SimplifyInterpreter._primitive_registrations is not PlxprInterpreter._primitive_registrations ) - @MapWiresInterpreter.register_primitive(qml.X._primitive) + @SimplifyInterpreter.register_primitive(qml.X._primitive) def _(self, *invals, **params): - return qml.X(*invals) + print("in custom interpreter") + return qml.Z(*invals) - assert qml.X._primitive in MapWiresInterpreter._primitive_registrations + assert qml.X._primitive in SimplifyInterpreter._primitive_registrations assert qml.X._primitive not in PlxprInterpreter._primitive_registrations - @MapWiresInterpreter() + @SimplifyInterpreter() def f(): qml.X(0) - qml.Y(0) + qml.Y(5) jaxpr = jax.make_jaxpr(f)() with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(jaxpr.jaxpr, []) - qml.assert_equal(q.queue[0], qml.X(0)) # not mapped due to primitive registration + print(jaxpr) + print(q.queue) + qml.assert_equal(q.queue[0], qml.Z(0)) # turned into a Y qml.assert_equal(q.queue[1], qml.Y(5)) # mapped wire + + # restore simplify interpreter to its previous state + SimplifyInterpreter._primitive_registrations.pop(qml.X._primitive) + + +class TestHigherOrderPrimitiveRegistrations: + + @pytest.mark.parametrize("lazy", (True, False)) + def test_adjoint_transform(self, lazy): + """Test the higher order adjoint transform.""" + + @SimplifyInterpreter() + def f(x): + def g(y): + qml.RX(y, 0) ** 3 + + qml.adjoint(g, lazy=lazy)(x) + + jaxpr = jax.make_jaxpr(f)(0.5) + + assert jaxpr.eqns[0].params["lazy"] == lazy + # assert jaxpr.eqns[0].primitive == adjoint_transform_primitive + inner_jaxpr = jaxpr.eqns[0].params["jaxpr"] + # first eqn mul, second RX + assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive + assert len(inner_jaxpr.eqns) == 2 + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5) + + if lazy: + qml.assert_equal(q.queue[0], qml.adjoint(qml.RX(jax.numpy.array(1.5), 0))) + else: + qml.assert_equal(q.queue[0], qml.RX(jax.numpy.array(-1.5), 0)) From 65202fa90e0c6724ac112f1e70507b1a4a72c24d Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 30 Aug 2024 16:32:27 -0400 Subject: [PATCH 14/45] improvementS --- pennylane/capture/base_interpreter.py | 82 +++++++++++++++------------ 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 0c9e499d2c2..94cdeba3ed5 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -30,11 +30,12 @@ for_loop, while_loop, ) -from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim, adjoint +from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim from pennylane.ops.op_math.condition import _get_cond_qfunc_prim -from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim, ctrl +from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim from .capture_qnode import _get_qnode_prim +from .flatfn import FlatFn from .primitives import _get_abstract_measurement, _get_abstract_operator for_prim = _get_for_loop_qfunc_prim() @@ -48,6 +49,15 @@ AbstractMeasurement = _get_abstract_measurement() +def jaxpr_to_jaxpr( + interpreter: "PlxprInterpreter", jaxpr: "jax.core.Jaxpr", consts, *args +) -> "jax.core.Jaxpr": + def f(*inner_args): + return interpreter.eval(jaxpr, consts, *inner_args) + + return jax.make_jaxpr(f)(*args).jaxpr + + class PlxprInterpreter: """A template base class for defining plxpr interpreters @@ -127,7 +137,7 @@ def read(self, var): raise ValueError("_env not yet initialized.") if type(var) is jax.core.Literal: return var.val - return var.val if type(var) is jax.core.Literal else self._env[var] + return self._op_math_cache.get(var, self._env[var]) def setup(self): """Initialize the instance before interpretting equations. @@ -177,14 +187,7 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): """ - invals = [ - ( - invar.val - if type(invar) is jax.core.Literal - else self._op_math_cache.get(invar, self.read(invar)) - ) - for invar in eqn.invars - ] + invals = [self.read(invar) for invar in eqn.invars] op = eqn.primitive.impl(*invals, **eqn.params) if isinstance(eqn.outvars[0], jax.core.DropVar): return self.interpret_operation(op) @@ -203,6 +206,9 @@ def interpret_measurement_eqn(self, primitive, *invals, **params): **params: The equations parameters dictionary """ + invals = [ + self.interpret_operation(op) for op in invals if isinstance(op, qml.operation.Operator) + ] return primitive.bind(*invals, **params) def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: @@ -259,10 +265,15 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: return outvals def __call__(self, f: Callable) -> Callable: + + flat_f = FlatFn(f) + @wraps(f) def wrapper(*args, **kwargs): - jaxpr = jax.make_jaxpr(partial(f, **kwargs))(*args) - return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + jaxpr = jax.make_jaxpr(partial(flat_f, **kwargs))(*args) + results = self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + assert flat_f.out_tree + return jax.tree_util.tree_unflatten(flat_f.out_tree, results) return wrapper @@ -273,12 +284,8 @@ def handle_adjoint_transform(self, *invals, jaxpr, lazy, n_consts): consts = invals[:n_consts] args = invals[n_consts:] - def new_qfunc(*inner_args): - return type(self)(state=self.state).eval(jaxpr, consts, *inner_args) - - jaxpr = jax.make_jaxpr(new_qfunc)(*args) - - return adjoint_transform_prim.bind(*invals, jaxpr=jaxpr.jaxpr, lazy=lazy, n_consts=n_consts) + jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) + return adjoint_transform_prim.bind(*invals, jaxpr=jaxpr, lazy=lazy, n_consts=n_consts) # pylint: disable=too-many-arguments @@ -286,18 +293,13 @@ def new_qfunc(*inner_args): def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): """Interpret a ctrl transform primitive.""" consts = invals[:n_consts] - control_wires = invals[-n_control:] args = invals[n_consts:-n_control] - - def new_qfunc(*inner_args): - return type(self)(state=self.state).eval(jaxpr, consts, *inner_args) - - jaxpr = jax.make_jaxpr(new_qfunc)(*args) + jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) return ctrl_transform_prim.bind( *invals, n_control=n_control, - jaxpr=jaxpr.jaxpr, + jaxpr=jaxpr, control_values=control_values, work_wires=work_wires, n_consts=n_consts, @@ -307,14 +309,16 @@ def new_qfunc(*inner_args): @PlxprInterpreter.register_primitive(for_prim) def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): """Handle a for loop primitive.""" - start, stop, step = invals[0], invals[1], invals[2] + start = invals[0] consts = invals[3 : 3 + n_consts] + init_state = invals[3 + n_consts :] - @for_loop(start, stop, step) - def g(i, *init_state): - return type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts, i, *init_state) + new_jaxpr_body_fn = jaxpr_to_jaxpr( + type(self)(state=self.state), jaxpr_body_fn.jaxpr, consts, start, *init_state + ) - return g(*invals[3 + n_consts :]) + new_jaxpr_body_fn = jax.core.ClosedJaxpr(new_jaxpr_body_fn, consts) + return for_prim.bind(*invals, jaxpr_body_fn=new_jaxpr_body_fn, n_consts=n_consts) @PlxprInterpreter.register_primitive(cond_prim) @@ -361,8 +365,16 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, """Handle a qnode primitive.""" consts = invals[:n_consts] - @qml.qnode(device, **qnode_kwargs) - def new_qnode(*args): - return type(self)(state=self.state).eval(qfunc_jaxpr, consts, *args) + new_qfunc_jaxpr = jaxpr_to_jaxpr( + type(self)(state=self.state), qfunc_jaxpr, consts, *invals[n_consts:] + ) - return new_qnode(*invals[n_consts:], shots=shots) + return qnode_prim.bind( + *invals, + shots=shots, + qnode=qnode, + device=device, + qnode_kwargs=qnode_kwargs, + qfunc_jaxpr=new_qfunc_jaxpr, + n_consts=n_consts, + ) From 22e32f78b3f6c43d633bb5e115c4e3ec9d9b79b6 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 3 Sep 2024 12:33:21 -0400 Subject: [PATCH 15/45] more fixes --- pennylane/capture/base_interpreter.py | 44 +++++++++++++++------------ 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 94cdeba3ed5..da2425c8abf 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -325,21 +325,19 @@ def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): """Handle a cond primitive.""" n_branches = len(jaxpr_branches) - conditions = invals[:n_branches] consts_flat = invals[n_branches + n_args :] + args = invals[n_branches : n_branches + n_args] - def true_fn(*args): - return type(self)(state=self.state).eval( - jaxpr_branches[0].jaxpr, consts_flat[: n_consts_per_branch[0]], *args - ) + new_jaxprs = [] + start = 0 + for n_consts, jaxpr in zip(n_consts_per_branch, jaxpr_branches): + consts = consts_flat[start : start + n_consts] + start += n_consts + new_jaxprs.append(jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr.jaxpr, consts, *args)) - def false_fn(*args): - return type(self)(state=self.state).eval( - jaxpr_branches[-1].jaxpr, consts_flat[n_consts_per_branch[0] :], *args - ) - - cond_fn = cond(conditions[0], true_fn, false_fn=false_fn) - return cond_fn(*invals[n_branches : n_branches + n_args]) + return cond_prim.bind( + *invals, jaxpr_brances=new_jaxprs, n_consts_per_branch=n_consts_per_branch, n_args=n_args + ) @PlxprInterpreter.register_primitive(while_prim) @@ -349,14 +347,22 @@ def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] init_state = invals[n_consts_body + n_consts_cond :] - def cond_fn(*args): - return jax.core.eval_jaxpr(jaxpr_cond_fn.jaxpr, consts_cond, *args) - - @while_loop(cond_fn) - def loop(*args): - return type(self)(state=self.state).eval(jaxpr_body_fn.jaxpr, consts_body, *args) + new_jaxpr_body_fn = jaxpr_to_jaxpr( + type(self)(state=self.state), jaxpr_body_fn.jaxpr, consts_body, *init_state + ) + new_jaxpr_body_fn = jax.core.ClosedJaxpr(new_jaxpr_body_fn, consts_body) + new_jaxpr_cond_fn = jaxpr_to_jaxpr( + type(self)(state=self.state), jaxpr_cond_fn.jaxpr, consts_cond, *init_state + ) + new_jaxpr_cond_fn = jax.core.ClosedJaxpr(new_jaxpr_cond_fn, consts_cond) - return loop(*init_state) + return while_prim.bind( + *invals, + jaxpr_body_fn=new_jaxpr_body_fn, + jaxpr_bond_fn=new_jaxpr_cond_fn, + n_consts_body=n_consts_body, + n_consts_cond=n_consts_cond, + ) # pylint: disable=unused-argument, too-many-arguments From ba298eba8f906bc79bd483531fe23f197ff7e801 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 1 Oct 2024 14:06:32 -0400 Subject: [PATCH 16/45] adding tests --- pennylane/capture/base_interpreter.py | 151 +++++++++++++++------- pennylane/capture/capture_operators.py | 3 +- pennylane/capture/primitives.py | 4 +- pennylane/ops/op_math/adjoint.py | 8 +- pennylane/pauli/pauli_arithmetic.py | 14 +- tests/capture/test_base_interpreter.py | 170 +++++++++++++++++++++++-- 6 files changed, 288 insertions(+), 62 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index da2425c8abf..09260b8620e 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -19,39 +19,30 @@ from functools import partial, wraps from typing import Callable -# note: the module has a jax dependency and cannot exist in the standard import path for now. import jax import pennylane as qml -from pennylane import cond -from pennylane.compiler.qjit_api import ( - _get_for_loop_qfunc_prim, - _get_while_loop_qfunc_prim, - for_loop, - while_loop, -) -from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim -from pennylane.ops.op_math.condition import _get_cond_qfunc_prim -from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim -from .capture_qnode import _get_qnode_prim from .flatfn import FlatFn -from .primitives import _get_abstract_measurement, _get_abstract_operator - -for_prim = _get_for_loop_qfunc_prim() -while_prim = _get_while_loop_qfunc_prim() -cond_prim = _get_cond_qfunc_prim() -qnode_prim = _get_qnode_prim() -adjoint_transform_prim = _get_adjoint_qfunc_prim() -ctrl_transform_prim = _get_ctrl_qfunc_prim() - -AbstractOperator = _get_abstract_operator() -AbstractMeasurement = _get_abstract_measurement() +from .primitives import ( + AbstractMeasurement, + AbstractOperator, + adjoint_transform_prim, + cond_prim, + ctrl_transform_prim, + for_loop_prim, + grad_prim, + jacobian_prim, + qnode_prim, + while_loop_prim, +) def jaxpr_to_jaxpr( interpreter: "PlxprInterpreter", jaxpr: "jax.core.Jaxpr", consts, *args ) -> "jax.core.Jaxpr": + """A convenience uility for converting jaxpr to a new jaxpr via an interpreter.""" + def f(*inner_args): return interpreter.eval(jaxpr, consts, *inner_args) @@ -76,15 +67,57 @@ class PlxprInterpreter: .. code-block:: python + import jax + from pennylane.capture import PlxprInterpreter + class SimplifyInterpreter(PlxprInterpreter): - def interpret_operation(self, op): - new_op = op.simplify() - if new_op is op: - # if new op isn't queued, need to requeue op. - data, struct = jax.tree_util.tree_flatten(new_op) - new_op = jax.tree_util.tree_unflatten(struct, data) - return new_op + def interpret_operation(self, op): + new_op = qml.simplify(op) + if new_op is op: + # if new op isn't queued, need to requeue op. + data, struct = jax.tree_util.tree_flatten(new_op) + new_op = jax.tree_util.tree_unflatten(struct, data) + return new_op + + Now the interpreter can be used to transform functions and jaxpr: + + >>> interpreter = SimplifyInterpreter() + >>> def f(x): + ... qml.RX(x, 0)**2 + ... qml.adjoint(qml.Z(0)) + ... return qml.expval(qml.X(0) + qml.X(0)) + >>> simplified_f = interpreter(f) + >>> print(qml.draw(simplified_f)(0.5) + 0: ──RX(1.00)──Z─┤ <2.00*X> + >>> jaxpr = jax.make_jaxpr(f)(0.5) + >>> interpreter.eval(jaxpr.jaxpr, [], 0.5) + [expval(2.0 * X(0))] + + It will also preserve higher order primitives by default: + + >>> def g(x): + ... @qml.for_loop(3) + ... def loop(i, x): + ... qml.RX(x, 0) ** i + ... return x + ... loop(1.0) + ... return qml.expval(qml.Z(0) + 3*qml.Z(0)) + >>> jax.make_jaxpr(interpreter(g))(0.5) + { lambda ; a:f32[]. let + _:f32[] = for_loop[ + jaxpr_body_fn={ lambda ; b:i32[] c:f32[]. let + d:f32[] = convert_element_type[new_dtype=float32 weak_type=True] b + e:f32[] = mul c d + _:AbstractOperator() = RX[n_wires=1] e 0 + in (c,) } + n_consts=0 + ] 0 3 1 1.0 + f:AbstractOperator() = PauliZ[n_wires=1] 0 + g:AbstractOperator() = SProd[_pauli_rep=4.0 * Z(0)] 4.0 f + h:AbstractMeasurement(n_wires=None) = expval_obs g + in (h,) } + """ @@ -143,7 +176,8 @@ def setup(self): """Initialize the instance before interpretting equations. Blank by default, this method can initialize any additional instance variables - needed by an interpreter + needed by an interpreter. For example, a device interpreter could initialize a statevector, + or a compilation interpreter could initialize a staging area for the latest operation on each wire. """ @@ -151,7 +185,9 @@ def cleanup(self): """Perform any final steps after iterating through all equations. Blank by default, this method can clean up instance variables, or perform - equations that have been deffered till later. + equations that have been deffered till later. For example, if a compilation + interpreter has a staging area for the latest operation on each wire, the cleanup method + could clear out the staging area. """ @@ -170,7 +206,8 @@ def interpret_operation(self, op: "pennylane.operation.Operator"): See also: :meth:`~.interpret_operation_eqn`. """ - return op + data, struct = jax.tree_util.tree_flatten(op) + return jax.tree_util.tree_unflatten(struct, data) def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): """Interpret an equation corresponding to an operator. @@ -187,8 +224,9 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): """ - invals = [self.read(invar) for invar in eqn.invars] - op = eqn.primitive.impl(*invals, **eqn.params) + invals = (self.read(invar) for invar in eqn.invars) + with qml.QueuingManager.stop_recording(): + op = eqn.primitive.impl(*invals, **eqn.params) if isinstance(eqn.outvars[0], jax.core.DropVar): return self.interpret_operation(op) @@ -206,9 +244,9 @@ def interpret_measurement_eqn(self, primitive, *invals, **params): **params: The equations parameters dictionary """ - invals = [ + invals = ( self.interpret_operation(op) for op in invals if isinstance(op, qml.operation.Operator) - ] + ) return primitive.bind(*invals, **params) def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: @@ -270,7 +308,8 @@ def __call__(self, f: Callable) -> Callable: @wraps(f) def wrapper(*args, **kwargs): - jaxpr = jax.make_jaxpr(partial(flat_f, **kwargs))(*args) + with qml.QueuingManager.stop_recording(): + jaxpr = jax.make_jaxpr(partial(flat_f, **kwargs))(*args) results = self.eval(jaxpr.jaxpr, jaxpr.consts, *args) assert flat_f.out_tree return jax.tree_util.tree_unflatten(flat_f.out_tree, results) @@ -306,7 +345,7 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_ ) -@PlxprInterpreter.register_primitive(for_prim) +@PlxprInterpreter.register_primitive(for_loop_prim) def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): """Handle a for loop primitive.""" start = invals[0] @@ -318,7 +357,7 @@ def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): ) new_jaxpr_body_fn = jax.core.ClosedJaxpr(new_jaxpr_body_fn, consts) - return for_prim.bind(*invals, jaxpr_body_fn=new_jaxpr_body_fn, n_consts=n_consts) + return for_loop_prim.bind(*invals, jaxpr_body_fn=new_jaxpr_body_fn, n_consts=n_consts) @PlxprInterpreter.register_primitive(cond_prim) @@ -333,14 +372,18 @@ def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): for n_consts, jaxpr in zip(n_consts_per_branch, jaxpr_branches): consts = consts_flat[start : start + n_consts] start += n_consts - new_jaxprs.append(jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr.jaxpr, consts, *args)) + if jaxpr is None: + new_jaxprs.append(None) + else: + open_jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr.jaxpr, consts, *args) + new_jaxprs.append(jax.core.ClosedJaxpr(open_jaxpr, consts)) return cond_prim.bind( - *invals, jaxpr_brances=new_jaxprs, n_consts_per_branch=n_consts_per_branch, n_args=n_args + *invals, jaxpr_branches=new_jaxprs, n_consts_per_branch=n_consts_per_branch, n_args=n_args ) -@PlxprInterpreter.register_primitive(while_prim) +@PlxprInterpreter.register_primitive(while_loop_prim) def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): """Handle a while loop primitive.""" consts_body = invals[:n_consts_body] @@ -356,10 +399,10 @@ def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body ) new_jaxpr_cond_fn = jax.core.ClosedJaxpr(new_jaxpr_cond_fn, consts_cond) - return while_prim.bind( + return while_loop_prim.bind( *invals, jaxpr_body_fn=new_jaxpr_body_fn, - jaxpr_bond_fn=new_jaxpr_cond_fn, + jaxpr_cond_fn=new_jaxpr_cond_fn, n_consts_body=n_consts_body, n_consts_cond=n_consts_cond, ) @@ -384,3 +427,21 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, qfunc_jaxpr=new_qfunc_jaxpr, n_consts=n_consts, ) + + +@PlxprInterpreter.register_primitive(grad_prim) +def handle_grad(self, *invals, jaxpr, n_consts, **params): + """Handle the grad primitive.""" + consts = invals[:n_consts] + args = invals[n_consts:] + new_jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) + return grad_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) + + +@PlxprInterpreter.register_primitive(jacobian_prim) +def handle_jacobian(self, *invals, jaxpr, n_consts, **params): + """Handle the jacobian primitive.""" + consts = invals[:n_consts] + args = invals[n_consts:] + new_jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) + return jacobian_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py index 8bea0be3f31..a04a34c2ac4 100644 --- a/pennylane/capture/capture_operators.py +++ b/pennylane/capture/capture_operators.py @@ -114,7 +114,8 @@ def _(*args, **kwargs): split = None if n_wires == 0 else -n_wires # need to convert array values into integers # for plxpr, all wires must be integers - wires = tuple(int(w) for w in args[split:]) + # could be abstract when using tracing evaluation in interpreter + wires = tuple(w if qml.math.is_abstract(w) else int(w) for w in args[split:]) args = args[:split] return type.__call__(operator_type, *args, wires=wires, **kwargs) diff --git a/pennylane/capture/primitives.py b/pennylane/capture/primitives.py index 3d578b82f7f..5755c432b86 100644 --- a/pennylane/capture/primitives.py +++ b/pennylane/capture/primitives.py @@ -18,6 +18,7 @@ It has a jax dependency and should be located in a standard import path. """ from pennylane.compiler.qjit_api import _get_for_loop_qfunc_prim, _get_while_loop_qfunc_prim +from pennylane.measurements.mid_measure import _create_mid_measure_primitive from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim from pennylane.ops.op_math.condition import _get_cond_qfunc_prim from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim @@ -37,7 +38,7 @@ cond_prim = _get_cond_qfunc_prim() for_loop_prim = _get_for_loop_qfunc_prim() while_loop_prim = _get_while_loop_qfunc_prim() - +measure_prim = _create_mid_measure_primitive() __all__ = [ "AbstractOperator", @@ -50,4 +51,5 @@ "cond_prim", "for_loop_prim", "while_loop_prim", + "measure_prim", ] diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py index e09ea4b007d..3546d66c4b1 100644 --- a/pennylane/ops/op_math/adjoint.py +++ b/pennylane/ops/op_math/adjoint.py @@ -409,10 +409,10 @@ def adjoint(self): return self.base.queue() def simplify(self): - base = self.base.simplify() - if self.base.has_adjoint: - return base.adjoint().simplify() - return Adjoint(base=base.simplify()) + base = self.base if qml.capture.enabled() else self.base.simplify() + if base.has_adjoint: + return base.adjoint() if qml.capture.enabled() else base.adjoint().simplify() + return Adjoint(base=base) # pylint: disable=no-member diff --git a/pennylane/pauli/pauli_arithmetic.py b/pennylane/pauli/pauli_arithmetic.py index ae24889990f..c6b71c1ea94 100644 --- a/pennylane/pauli/pauli_arithmetic.py +++ b/pennylane/pauli/pauli_arithmetic.py @@ -82,7 +82,7 @@ def _cached_sparse_data(op): elif op == "Y": data = np.array([-1.0j, 1.0j], dtype=np.complex128) indices = np.array([1, 0], dtype=np.int64) - elif op == "Z": + else: # if op == "Z": data = np.array([1.0, -1.0], dtype=np.complex128) indices = np.array([0, 1], dtype=np.int64) return data, indices @@ -510,7 +510,11 @@ def operation(self, wire_order=None, get_as_tensor=False): if len(self) == 0: return Identity(wires=wire_order) - factors = [_make_operation(op, wire) for wire, op in self.items()] + if qml.capture.enabled(): + # cant use lru_cache with program capture + factors = [op_map[op](wire) for wire, op in self.items()] + else: + factors = [_make_operation(op, wire) for wire, op in self.items()] if get_as_tensor: return factors[0] if len(factors) == 1 else Tensor(*factors) @@ -524,7 +528,11 @@ def hamiltonian(self, wire_order=None): raise ValueError("Can't get the Hamiltonian for an empty PauliWord.") return qml.Hamiltonian([1], [Identity(wires=wire_order)]) - obs = [_make_operation(op, wire) for wire, op in self.items()] + if qml.capture.enabled(): + # cant use lru_cache with program capture + obs = [op_map[op](wire) for wire, op in self.items()] + else: + obs = [_make_operation(op, wire) for wire, op in self.items()] return qml.Hamiltonian([1], [obs[0] if len(obs) == 1 else Tensor(*obs)]) def map_wires(self, wire_map: dict) -> "PauliWord": diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index 5d1e62c0737..91c209ca8e2 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -14,14 +14,24 @@ """ This submodule tests strategy structure for defining custom plxpr interpreters """ - +# pylint: disable=protected-access import pytest import pennylane as qml jax = pytest.importorskip("jax") -from pennylane.capture.base_interpreter import PlxprInterpreter +from pennylane.capture.base_interpreter import ( # pylint: disable=wrong-import-position + PlxprInterpreter, +) +from pennylane.capture.primitives import ( # pylint: disable=wrong-import-position + adjoint_transform_prim, + cond_prim, + ctrl_transform_prim, + for_loop_prim, + qnode_prim, + while_loop_prim, +) pytestmark = pytest.mark.jax @@ -49,7 +59,7 @@ def test_env_and_state_initialized(): """Test that env and state are initialized at the start.""" interpreter = SimplifyInterpreter() - assert interpreter._env == {} + assert interpreter._env == {} # pylint: disable=use-implicit-booleaness-not-comparison assert interpreter.state is None @@ -63,7 +73,7 @@ def test_primitive_registrations(): ) @SimplifyInterpreter.register_primitive(qml.X._primitive) - def _(self, *invals, **params): + def _(self, *invals, **params): # pylint: disable=unused-argument print("in custom interpreter") return qml.Z(*invals) @@ -80,8 +90,6 @@ def f(): with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(jaxpr.jaxpr, []) - print(jaxpr) - print(q.queue) qml.assert_equal(q.queue[0], qml.Z(0)) # turned into a Y qml.assert_equal(q.queue[1], qml.Y(5)) # mapped wire @@ -89,6 +97,16 @@ def f(): SimplifyInterpreter._primitive_registrations.pop(qml.X._primitive) +def test_overriding_measurements(): + """Test usage of an interpreter with a custom way of handling measurements.""" + + class MeasurementsToSample(PlxprInterpreter): + + def interpret_measurement_eqn(self, primitive, *invals, **params): + temp_mp = primitive.impl(*invals, **params) + return qml.sample(wires=temp_mp.wires) + + class TestHigherOrderPrimitiveRegistrations: @pytest.mark.parametrize("lazy", (True, False)) @@ -98,14 +116,14 @@ def test_adjoint_transform(self, lazy): @SimplifyInterpreter() def f(x): def g(y): - qml.RX(y, 0) ** 3 + _ = qml.RX(y, 0) ** 3 qml.adjoint(g, lazy=lazy)(x) jaxpr = jax.make_jaxpr(f)(0.5) assert jaxpr.eqns[0].params["lazy"] == lazy - # assert jaxpr.eqns[0].primitive == adjoint_transform_primitive + assert jaxpr.eqns[0].primitive == adjoint_transform_prim inner_jaxpr = jaxpr.eqns[0].params["jaxpr"] # first eqn mul, second RX assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive @@ -118,3 +136,139 @@ def g(y): qml.assert_equal(q.queue[0], qml.adjoint(qml.RX(jax.numpy.array(1.5), 0))) else: qml.assert_equal(q.queue[0], qml.RX(jax.numpy.array(-1.5), 0)) + + def test_ctrl_transform(self): + """Test the higher order adjoint transform.""" + + @SimplifyInterpreter() + def f(x, control): + def g(y): + _ = qml.RY(y, 0) ** 3 + + qml.ctrl(g, control)(x) + + jaxpr = jax.make_jaxpr(f)(0.5, 1) + + assert jaxpr.eqns[0].primitive == ctrl_transform_prim + inner_jaxpr = jaxpr.eqns[0].params["jaxpr"] + # first eqn mul, second RY + assert inner_jaxpr.eqns[1].primitive == qml.RY._primitive + assert len(inner_jaxpr.eqns) == 2 + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2.0, 1) + + qml.assert_equal(q.queue[0], qml.ctrl(qml.RY(jax.numpy.array(6.0), 0), 1)) + + def test_cond(self): + """Test the cond higher order primitive.""" + + @SimplifyInterpreter() + def f(x, control): + + def true_fn(y): + _ = qml.RY(y, 0) ** 2 + + def false_fn(y): + _ = qml.adjoint(qml.RY(y, 0)) + + qml.cond(control, true_fn, false_fn)(x) + + jaxpr = jax.make_jaxpr(f)(0.5, False) + assert jaxpr.eqns[0].primitive == cond_prim + + def test_for_loop(self): + """Test the higher order for loop registration.""" + + @SimplifyInterpreter() + def f(n): + + @qml.for_loop(n) + def g(i): + qml.adjoint(qml.X(i)) + + g() + + jaxpr = jax.make_jaxpr(f)(3) + assert jaxpr.eqns[0].primitive == for_loop_prim + + inner_jaxpr = jaxpr.eqns[0].params["jaxpr_body_fn"] + assert len(inner_jaxpr.eqns) == 1 + assert inner_jaxpr.eqns[0].primitive == qml.X._primitive # no adjoint of x + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3) + + qml.assert_equal(q.queue[0], qml.X(0)) + qml.assert_equal(q.queue[1], qml.X(1)) + qml.assert_equal(q.queue[2], qml.X(2)) + assert len(q) == 3 + + def test_while_loop(self): + """Test the higher order for loop registration.""" + + @SimplifyInterpreter() + def f(n): + + @qml.while_loop(lambda i: i < n) + def g(i): + qml.adjoint(qml.Z(i)) + return i + 1 + + g(0) + + jaxpr = jax.make_jaxpr(f)(3) + assert jaxpr.eqns[0].primitive == while_loop_prim + + inner_jaxpr = jaxpr.eqns[0].params["jaxpr_body_fn"] + assert len(inner_jaxpr.eqns) == 2 + assert inner_jaxpr.eqns[0].primitive == qml.Z._primitive # no adjoint of x + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3) + + qml.assert_equal(q.queue[0], qml.Z(0)) + qml.assert_equal(q.queue[1], qml.Z(1)) + qml.assert_equal(q.queue[2], qml.Z(2)) + assert len(q) == 3 + + def test_qnode(self): + """Test transforming qnodes.""" + + class AddNoise(PlxprInterpreter): + + def interpret_operation(self, op): + data, struct = jax.tree_util.tree_flatten(op) + new_op = jax.tree_util.tree_unflatten(struct, data) + _ = [qml.RX(0.1, w) for w in op.wires] + return new_op + + dev = qml.device("default.qubit", wires=1) + + @AddNoise() + @qml.qnode(dev, diff_method="adjoint", grad_on_execution=False) + def f(): + qml.I(0) + qml.I(0) + return qml.probs(wires=0) + + jaxpr = jax.make_jaxpr(f)() + assert jaxpr.eqns[0].primitive == qnode_prim + inner_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] + + assert len(inner_jaxpr.eqns) == 5 + assert inner_jaxpr.eqns[0].primitive == qml.I._primitive + assert inner_jaxpr.eqns[2].primitive == qml.I._primitive + assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive + assert inner_jaxpr.eqns[3].primitive == qml.RX._primitive + + assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "adjoint" + assert jaxpr.eqns[0].params["qnode_kwargs"]["grad_on_execution"] is False + assert jaxpr.eqns[0].params["device"] == dev + + res1 = f() + # end up performing two rx gates with phase of 0.1 each on wire 0 + expected = jax.numpy.array([jax.numpy.cos(0.2 / 2) ** 2, jax.numpy.sin(0.2 / 2) ** 2]) + assert qml.math.allclose(res1, expected) + res2 = jax.core.eval_jaxpr(jaxpr.jaxpr, []) + assert qml.math.allclose(res2, expected) From 1cc4e79b70c9d3e94c790febf9419502ac859582 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 1 Oct 2024 17:09:57 -0400 Subject: [PATCH 17/45] more tests --- pennylane/capture/base_interpreter.py | 12 +-- tests/capture/test_base_interpreter.py | 106 +++++++++++++++++++++++-- 2 files changed, 104 insertions(+), 14 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 09260b8620e..3ce5693d9b7 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -147,13 +147,13 @@ def register_primitive(cls, primitive: "jax.core.Primitive") -> Callable[[Callab Calling the returned decorator with a function will place the function into the primitive registrations map. - ``` - my_primitive = jax.core.Primitive("my_primitve") + ..code-block:: python - @Interpreter_Type.register(my_primitive) - def handle_my_primitive(self: Interpreter_Type, *invals, **params) - return invals[0] + invals[1] # some sort of custom handling - ``` + my_primitive = jax.core.Primitive("my_primitve") + + @Interpreter_Type.register(my_primitive) + def handle_my_primitive(self: Interpreter_Type, *invals, **params) + return invals[0] + invals[1] # some sort of custom handling """ diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index 91c209ca8e2..f1101d92b39 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -67,20 +67,30 @@ def test_primitive_registrations(): """Test that child primitive registrations dict's are not copied and do not effect PlxprInterpreeter.""" + class SimplifyInterpreterLocal(PlxprInterpreter): + + def interpret_operation(self, op): + new_op = op.simplify() + if new_op is op: + # if new op isn't queued, need to requeue op. + data, struct = jax.tree_util.tree_flatten(new_op) + new_op = jax.tree_util.tree_unflatten(struct, data) + return new_op + assert ( - SimplifyInterpreter._primitive_registrations + SimplifyInterpreterLocal._primitive_registrations is not PlxprInterpreter._primitive_registrations ) - @SimplifyInterpreter.register_primitive(qml.X._primitive) + @SimplifyInterpreterLocal.register_primitive(qml.X._primitive) def _(self, *invals, **params): # pylint: disable=unused-argument print("in custom interpreter") return qml.Z(*invals) - assert qml.X._primitive in SimplifyInterpreter._primitive_registrations + assert qml.X._primitive in SimplifyInterpreterLocal._primitive_registrations assert qml.X._primitive not in PlxprInterpreter._primitive_registrations - @SimplifyInterpreter() + @SimplifyInterpreterLocal() def f(): qml.X(0) qml.Y(5) @@ -93,9 +103,6 @@ def f(): qml.assert_equal(q.queue[0], qml.Z(0)) # turned into a Y qml.assert_equal(q.queue[1], qml.Y(5)) # mapped wire - # restore simplify interpreter to its previous state - SimplifyInterpreter._primitive_registrations.pop(qml.X._primitive) - def test_overriding_measurements(): """Test usage of an interpreter with a custom way of handling measurements.""" @@ -106,6 +113,13 @@ def interpret_measurement_eqn(self, primitive, *invals, **params): temp_mp = primitive.impl(*invals, **params) return qml.sample(wires=temp_mp.wires) + @MeasurementsToSample() + @qml.qnode(qml.device("default.qubit", wires=2, shots=5)) + def circuit(): + return qml.expval(qml.Z(0)), qml.probs(wires=(0, 1)) + + circuit() + class TestHigherOrderPrimitiveRegistrations: @@ -170,13 +184,66 @@ def true_fn(y): _ = qml.RY(y, 0) ** 2 def false_fn(y): - _ = qml.adjoint(qml.RY(y, 0)) + _ = qml.adjoint(qml.RX(y, 0)) qml.cond(control, true_fn, false_fn)(x) jaxpr = jax.make_jaxpr(f)(0.5, False) assert jaxpr.eqns[0].primitive == cond_prim + branch1 = jaxpr.eqns[0].params["jaxpr_branches"][0] + assert len(branch1.eqns) == 2 + assert branch1.eqns[1].primitive == qml.RY._primitive + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(branch1, [], 0.5) + qml.assert_equal(q.queue[0], qml.RY(2 * 0.5, 0)) + + branch2 = jaxpr.eqns[0].params["jaxpr_branches"][1] + assert len(branch2.eqns) == 2 + assert branch2.eqns[1].primitive == qml.RX._primitive + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(branch2, [], 0.5) + qml.assert_equal(q.queue[0], qml.RY(-0.5, 0)) + + assert jaxpr.eqns[0].params["n_args"] == 1 + assert jaxpr.eqns[0].params["n_consts_per_branch"] == [0, 0] + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2.4, True) + + qml.assert_equal(q.queue[0], qml.RY(4.8, 0)) + + with qml.queuing.AnnotatedQueue() as q: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1.23, False) + + qml.assert_equal(q.queue[0], qml.RX(-1.23, 0)) + + def test_cond_no_false_branch(self): + """Test transforming a cond HOP when no false branch exists.""" + + @SimplifyInterpreter() + def f(control): + + @qml.cond(control) + def f(): + qml.X(0) @ qml.X(0) + + f() + + jaxpr = jax.make_jaxpr(f)(True) + + assert jaxpr.eqns[0].params["jaxpr_branches"][-1] is None # no false branch + + with qml.queuing.AnnotatedQueue() as q_true: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, True) + + qml.assert_equal(q_true.queue[0], qml.I(0)) + + with qml.queuing.AnnotatedQueue() as q_false: + jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, False) + + assert len(q_false.queue) == 0 + def test_for_loop(self): """Test the higher order for loop registration.""" @@ -272,3 +339,26 @@ def f(): assert qml.math.allclose(res1, expected) res2 = jax.core.eval_jaxpr(jaxpr.jaxpr, []) assert qml.math.allclose(res2, expected) + + @pytest.mark.parametrize("grad_f", (qml.grad, qml.jacobian)) + def test_grad_and_jac(self, grad_f): + """Test interpreters can handle grad and jacobian HOP's.""" + + class DoubleAngle(PlxprInterpreter): + + def interpret_operation(self, op): + leaves, struct = jax.tree_util.tree_flatten(op) + return jax.tree_util.tree_unflatten(struct, [2 * l for l in leaves]) + + @DoubleAngle() + def f(x): + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(y): + qml.RX(y, 0) + return qml.expval(qml.Z(0)) + + return grad_f(circuit)(x) + + out = f(0.5) + expected = -2 * jax.numpy.sin(2 * 0.5) # includes the factors of 2 from doubling the angle. + assert qml.math.allclose(out, expected) From 1a8f410b4d8881e3d03ef1f421fd4e6c41f0349e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 2 Oct 2024 10:10:40 -0400 Subject: [PATCH 18/45] more test changes --- pennylane/capture/base_interpreter.py | 41 +++++++------------------- tests/capture/test_base_interpreter.py | 20 ++++++------- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 3ce5693d9b7..0da964c417d 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -52,17 +52,6 @@ def f(*inner_args): class PlxprInterpreter: """A template base class for defining plxpr interpreters - Args: - state (Any): any kind of information that may need to get carried around between different interpreters. - - **State property:** - - Higher order primitives can often be handled by a separate interpreter, but need to reference or modify the same values. - For example, a device interpreter may need to modify a statevector, or conversion to a tape may need to modify operations - and measurement lists. By maintaining this information in the optional ``state`` property, this information can automatically - by passed to new sub-interpreters. - - **Examples:** .. code-block:: python @@ -128,10 +117,9 @@ def interpret_operation(self, op): def __init_subclass__(cls) -> None: cls._primitive_registrations = copy.copy(cls._primitive_registrations) - def __init__(self, state=None): + def __init__(self): self._env = {} self._op_math_cache = {} - self.state = state @classmethod def register_primitive(cls, primitive: "jax.core.Primitive") -> Callable[[Callable], Callable]: @@ -206,8 +194,7 @@ def interpret_operation(self, op: "pennylane.operation.Operator"): See also: :meth:`~.interpret_operation_eqn`. """ - data, struct = jax.tree_util.tree_flatten(op) - return jax.tree_util.tree_unflatten(struct, data) + return op._unflatten(*op._flatten()) # pylint: disable=protected-access def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): """Interpret an equation corresponding to an operator. @@ -323,7 +310,7 @@ def handle_adjoint_transform(self, *invals, jaxpr, lazy, n_consts): consts = invals[:n_consts] args = invals[n_consts:] - jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) + jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) return adjoint_transform_prim.bind(*invals, jaxpr=jaxpr, lazy=lazy, n_consts=n_consts) @@ -333,7 +320,7 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_ """Interpret a ctrl transform primitive.""" consts = invals[:n_consts] args = invals[n_consts:-n_control] - jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) + jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) return ctrl_transform_prim.bind( *invals, @@ -353,7 +340,7 @@ def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): init_state = invals[3 + n_consts :] new_jaxpr_body_fn = jaxpr_to_jaxpr( - type(self)(state=self.state), jaxpr_body_fn.jaxpr, consts, start, *init_state + type(self)(), jaxpr_body_fn.jaxpr, consts, start, *init_state ) new_jaxpr_body_fn = jax.core.ClosedJaxpr(new_jaxpr_body_fn, consts) @@ -375,7 +362,7 @@ def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): if jaxpr is None: new_jaxprs.append(None) else: - open_jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr.jaxpr, consts, *args) + open_jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr.jaxpr, consts, *args) new_jaxprs.append(jax.core.ClosedJaxpr(open_jaxpr, consts)) return cond_prim.bind( @@ -390,13 +377,9 @@ def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] init_state = invals[n_consts_body + n_consts_cond :] - new_jaxpr_body_fn = jaxpr_to_jaxpr( - type(self)(state=self.state), jaxpr_body_fn.jaxpr, consts_body, *init_state - ) + new_jaxpr_body_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_body_fn.jaxpr, consts_body, *init_state) new_jaxpr_body_fn = jax.core.ClosedJaxpr(new_jaxpr_body_fn, consts_body) - new_jaxpr_cond_fn = jaxpr_to_jaxpr( - type(self)(state=self.state), jaxpr_cond_fn.jaxpr, consts_cond, *init_state - ) + new_jaxpr_cond_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_cond_fn.jaxpr, consts_cond, *init_state) new_jaxpr_cond_fn = jax.core.ClosedJaxpr(new_jaxpr_cond_fn, consts_cond) return while_loop_prim.bind( @@ -414,9 +397,7 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, """Handle a qnode primitive.""" consts = invals[:n_consts] - new_qfunc_jaxpr = jaxpr_to_jaxpr( - type(self)(state=self.state), qfunc_jaxpr, consts, *invals[n_consts:] - ) + new_qfunc_jaxpr = jaxpr_to_jaxpr(type(self)(), qfunc_jaxpr, consts, *invals[n_consts:]) return qnode_prim.bind( *invals, @@ -434,7 +415,7 @@ def handle_grad(self, *invals, jaxpr, n_consts, **params): """Handle the grad primitive.""" consts = invals[:n_consts] args = invals[n_consts:] - new_jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) + new_jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) return grad_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) @@ -443,5 +424,5 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params): """Handle the jacobian primitive.""" consts = invals[:n_consts] args = invals[n_consts:] - new_jaxpr = jaxpr_to_jaxpr(type(self)(state=self.state), jaxpr, consts, *args) + new_jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) return jacobian_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index f1101d92b39..15c6ce66de9 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -49,18 +49,18 @@ class SimplifyInterpreter(PlxprInterpreter): def interpret_operation(self, op): new_op = op.simplify() if new_op is op: + new_op = new_op._unflatten(*op._flatten()) # if new op isn't queued, need to requeue op. - data, struct = jax.tree_util.tree_flatten(new_op) - new_op = jax.tree_util.tree_unflatten(struct, data) return new_op -def test_env_and_state_initialized(): - """Test that env and state are initialized at the start.""" +# pylint: disable=use-implicit-booleaness-not-comparison +def test_env_and_initialized(): + """Test that env is initialized at the start.""" interpreter = SimplifyInterpreter() - assert interpreter._env == {} # pylint: disable=use-implicit-booleaness-not-comparison - assert interpreter.state is None + assert interpreter._env == {} + assert interpreter._op_math_cache == {} def test_primitive_registrations(): @@ -73,8 +73,7 @@ def interpret_operation(self, op): new_op = op.simplify() if new_op is op: # if new op isn't queued, need to requeue op. - data, struct = jax.tree_util.tree_flatten(new_op) - new_op = jax.tree_util.tree_unflatten(struct, data) + new_op = new_op._unflatten(*op._flatten()) return new_op assert ( @@ -226,7 +225,7 @@ def f(control): @qml.cond(control) def f(): - qml.X(0) @ qml.X(0) + _ = qml.X(0) @ qml.X(0) f() @@ -305,8 +304,7 @@ def test_qnode(self): class AddNoise(PlxprInterpreter): def interpret_operation(self, op): - data, struct = jax.tree_util.tree_flatten(op) - new_op = jax.tree_util.tree_unflatten(struct, data) + new_op = op._unflatten(*op._flatten()) _ = [qml.RX(0.1, w) for w in op.wires] return new_op From e8a5c5a4af7dcb108e16a95bdc9a641629a17671 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 2 Oct 2024 11:51:35 -0400 Subject: [PATCH 19/45] some more tests and polishing --- pennylane/capture/base_interpreter.py | 37 ++++-------- tests/capture/test_base_interpreter.py | 79 ++++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 30 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 0da964c417d..f46b76197df 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -135,7 +135,7 @@ def register_primitive(cls, primitive: "jax.core.Primitive") -> Callable[[Callab Calling the returned decorator with a function will place the function into the primitive registrations map. - ..code-block:: python + .. code-block:: python my_primitive = jax.core.Primitive("my_primitve") @@ -160,7 +160,7 @@ def read(self, var): return var.val return self._op_math_cache.get(var, self._env[var]) - def setup(self): + def setup(self) -> None: """Initialize the instance before interpretting equations. Blank by default, this method can initialize any additional instance variables @@ -169,14 +169,12 @@ def setup(self): """ - def cleanup(self): + def cleanup(self) -> None: """Perform any final steps after iterating through all equations. - Blank by default, this method can clean up instance variables, or perform - equations that have been deffered till later. For example, if a compilation - interpreter has a staging area for the latest operation on each wire, the cleanup method - could clear out the staging area. - + Blank by default, this method can clean up instance variables. Particularily, + this method can be used to deallocate qubits and registers when converting to + catalyst variant jaxpr. """ def interpret_operation(self, op: "pennylane.operation.Operator"): @@ -200,12 +198,7 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): """Interpret an equation corresponding to an operator. Args: - primitive (jax.core.Primitive): a jax primitive corresponding to an operation - outvar - *invals (Any): the positional input variables for the equation - - Keyword Args: - **params: The equations parameters dictionary + eqn (jax.core.JaxprEqn): a jax equation for an operator. See also: :meth:`~.interpret_operation`. @@ -277,7 +270,6 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: for outvar, outval in zip(eqn.outvars, outvals): self._env[outvar] = outval - self.cleanup() # Read the final result of the Jaxpr from the environment outvals = [] for var in jaxpr.outvars: @@ -285,6 +277,7 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: outvals.append(self.interpret_operation(self._op_math_cache[var])) else: outvals.append(self.read(var)) + self.cleanup() self._op_math_cache = {} self._env = {} return outvals @@ -339,11 +332,8 @@ def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): consts = invals[3 : 3 + n_consts] init_state = invals[3 + n_consts :] - new_jaxpr_body_fn = jaxpr_to_jaxpr( - type(self)(), jaxpr_body_fn.jaxpr, consts, start, *init_state - ) + new_jaxpr_body_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_body_fn, consts, start, *init_state) - new_jaxpr_body_fn = jax.core.ClosedJaxpr(new_jaxpr_body_fn, consts) return for_loop_prim.bind(*invals, jaxpr_body_fn=new_jaxpr_body_fn, n_consts=n_consts) @@ -362,8 +352,7 @@ def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): if jaxpr is None: new_jaxprs.append(None) else: - open_jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr.jaxpr, consts, *args) - new_jaxprs.append(jax.core.ClosedJaxpr(open_jaxpr, consts)) + new_jaxprs.append(jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args)) return cond_prim.bind( *invals, jaxpr_branches=new_jaxprs, n_consts_per_branch=n_consts_per_branch, n_args=n_args @@ -377,10 +366,8 @@ def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] init_state = invals[n_consts_body + n_consts_cond :] - new_jaxpr_body_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_body_fn.jaxpr, consts_body, *init_state) - new_jaxpr_body_fn = jax.core.ClosedJaxpr(new_jaxpr_body_fn, consts_body) - new_jaxpr_cond_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_cond_fn.jaxpr, consts_cond, *init_state) - new_jaxpr_cond_fn = jax.core.ClosedJaxpr(new_jaxpr_cond_fn, consts_cond) + new_jaxpr_body_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_body_fn, consts_body, *init_state) + new_jaxpr_cond_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_cond_fn, consts_cond, *init_state) return while_loop_prim.bind( *invals, diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index 15c6ce66de9..89357dae459 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -20,6 +20,7 @@ import pennylane as qml jax = pytest.importorskip("jax") +jnp = pytest.importorskip("jax.numpy") from pennylane.capture.base_interpreter import ( # pylint: disable=wrong-import-position PlxprInterpreter, @@ -117,7 +118,75 @@ def interpret_measurement_eqn(self, primitive, *invals, **params): def circuit(): return qml.expval(qml.Z(0)), qml.probs(wires=(0, 1)) - circuit() + res = circuit() + assert qml.math.allclose(res[0], jax.numpy.zeros(5)) + assert qml.math.allclose(res[1], jax.numpy.zeros((5, 2))) + + jaxpr = jax.make_jaxpr(circuit)() + assert ( + jaxpr.eqns[0].params["qfunc_jaxpr"].eqns[0].primitive + == qml.measurements.SampleMP._wires_primitive + ) + assert ( + jaxpr.eqns[0].params["qfunc_jaxpr"].eqns[1].primitive + == qml.measurements.SampleMP._wires_primitive + ) + + +def test_setup_method(): + """Test that the setup method can be used to initialized variables each call.""" + + class CollectOps(PlxprInterpreter): + + ops = None + + def setup(self): + self.ops = [] + + def interpret_operation(self, op): + self.ops.append(op) + return op._unflatten(*op._flatten()) + + def f(x): + qml.RX(x, 0) + qml.RY(2 * x, 0) + + jaxpr = jax.make_jaxpr(f)(0.5) + inst = CollectOps() + inst.eval(jaxpr.jaxpr, jaxpr.consts, 1.2) + assert inst.ops + assert len(inst.ops) == 2 + qml.assert_equal(inst.ops[0], qml.RX(1.2, 0)) + qml.assert_equal(inst.ops[1], qml.RY(jnp.array(2.4), 0)) + + # refreshed if instance is re-used + inst.eval(jaxpr.jaxpr, jaxpr.consts, -0.5) + assert len(inst.ops) == 2 + qml.assert_equal(inst.ops[0], qml.RX(-0.5, 0)) + qml.assert_equal(inst.ops[1], qml.RY(jnp.array(-1.0), 0)) + + +def test_cleanup_method(): + """Test that the cleanup method.""" + + class CleanupTester(PlxprInterpreter): + + state = "DEFAULT" + + def setup(self): + self.state = "SOME LARGE MEMORY" + + def cleanup(self): + self.state = None + + inst = CleanupTester() + + @inst + def f(x): + qml.RX(x, 0) + + f(0.5) + assert inst.state is None class TestHigherOrderPrimitiveRegistrations: @@ -195,14 +264,14 @@ def false_fn(y): assert branch1.eqns[1].primitive == qml.RY._primitive with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(branch1, [], 0.5) - qml.assert_equal(q.queue[0], qml.RY(2 * 0.5, 0)) + qml.assert_equal(q.queue[0], qml.RY(2 * jax.numpy.array(0.5), 0)) branch2 = jaxpr.eqns[0].params["jaxpr_branches"][1] assert len(branch2.eqns) == 2 assert branch2.eqns[1].primitive == qml.RX._primitive with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(branch2, [], 0.5) - qml.assert_equal(q.queue[0], qml.RY(-0.5, 0)) + qml.assert_equal(q.queue[0], qml.RX(jax.numpy.array(-0.5), 0)) assert jaxpr.eqns[0].params["n_args"] == 1 assert jaxpr.eqns[0].params["n_consts_per_branch"] == [0, 0] @@ -210,12 +279,12 @@ def false_fn(y): with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2.4, True) - qml.assert_equal(q.queue[0], qml.RY(4.8, 0)) + qml.assert_equal(q.queue[0], qml.RY(jax.numpy.array(4.8), 0)) with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1.23, False) - qml.assert_equal(q.queue[0], qml.RX(-1.23, 0)) + qml.assert_equal(q.queue[0], qml.RX(jax.numpy.array(-1.23), 0)) def test_cond_no_false_branch(self): """Test transforming a cond HOP when no false branch exists.""" From f806708456b95f4dd5f09a617c5e7d588d8fa070 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 2 Oct 2024 17:15:25 -0400 Subject: [PATCH 20/45] add default qubit interpreter --- pennylane/devices/qubit/dq_interpreter.py | 152 ++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 pennylane/devices/qubit/dq_interpreter.py diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py new file mode 100644 index 00000000000..ae5fbd1f599 --- /dev/null +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -0,0 +1,152 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains a class for executing plxpr using default qubit tools. +""" + +import jax + +from pennylane.capture import PlxprInterpreter +from pennylane.capture.primitives import ( + adjoint_transform_prim, + cond_prim, + ctrl_transform_prim, + for_loop_prim, + measure_prim, + while_loop_prim, +) +from pennylane.measurements import MidMeasureMP + +from .apply_operation import apply_operation +from .initialize_state import create_initial_state +from .measure import measure +from .sampling import measure_with_samples + + +class DefaultQubitInterpreter(PlxprInterpreter): + """Implements a class for interpreting plxpr using default qubit. + + >>> key = jax.random.PRNGKey(1234) + >>> dq = DefaultQubitInterpreter(num_wires=2, shots=qml.measurements.Shots(50), key=key) + >>> @qml.for_loop(2) + ... def g(i,y): + ... qml.RX(y,0) + ... return y + >>> def f(x): + ... g(x) + ... return qml.expval(qml.Z(0)) + >>> dq(f)(0.5) + Array(-0.79999995, dtype=float32) + + + """ + + def __init__(self, num_wires, shots, key=None, stateref=None): + self.num_wires = num_wires + self.shots = shots + self.stateref = stateref or {"state": None} + self.key = key + + @property + def state(self): + return self.stateref["state"] + + @state.setter + def state(self, value): + self.stateref["state"] = value + + def child(self) -> "DefaultQubitInterpreter": + return type(self)( + num_wires=self.num_wires, shots=self.shots, key=self.key, stateref=self.stateref + ) + + def setup(self): + if self.state is None: + self.state = create_initial_state(range(self.num_wires)) + + def interpret_operation(self, op): + self.state = apply_operation(op, self.state) + + def interpret_measurement_eqn(self, primitive, *invals, **params): + mp = primitive.impl(*invals, **params) + if self.shots: + self.key, new_key = jax.random.split(self.key, 2) + # note that this does *not* group commuting measurements + # further work could figure out how to perform multiple measurements at the same time + return measure_with_samples([mp], self.state, shots=self.shots, prng_key=new_key)[0] + return measure(mp, self.state) + + +# pylint: disable=unused-argument +@DefaultQubitInterpreter.register_primitive(adjoint_transform_prim) +def _(self, *invals, jaxpr, n_consts, lazy=True): + raise NotImplementedError("TODO?") + + +@DefaultQubitInterpreter.register_primitive(ctrl_transform_prim) +def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): + raise NotImplementedError("TODO?") + + +@DefaultQubitInterpreter.register_primitive(for_loop_prim) +def _(self, *invals, jaxpr_body_fn, n_consts): + start, stop, step = invals[0], invals[1], invals[2] + consts = invals[3 : 3 + n_consts] + init_state = invals[3 + n_consts :] + + res = None + for i in range(start, stop, step): + res = self.child().eval(jaxpr_body_fn, consts, i, *init_state) + + return res + + +@DefaultQubitInterpreter.register_primitive(while_loop_prim) +def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): + consts_body = invals[:n_consts_body] + consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] + init_state = invals[n_consts_body + n_consts_cond :] + + fn_res = init_state + while self.child().eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]: + fn_res = self.child().eval(jaxpr_body_fn, consts_body, *fn_res) + + return fn_res + + +@DefaultQubitInterpreter.register_primitive(measure_prim) +def _(self, *invals, reset, postselect): + mp = MidMeasureMP(invals, reset=reset, postselect=postselect) + mid_measurements = {} + self.key, new_key = jax.random.split(self.key, 2) + self.state = apply_operation( + mp, self.state, mid_measurements=mid_measurements, prng_key=new_key + ) + return mid_measurements[mp] + + +@DefaultQubitInterpreter.register_primitive(cond_prim) +def _(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): + n_branches = len(jaxpr_branches) + conditions = invals[:n_branches] + consts_flat = invals[n_branches + n_args :] + args = invals[n_branches : n_branches + n_args] + + start = 0 + for pred, jaxpr, n_consts in zip(conditions, jaxpr_branches, n_consts_per_branch): + consts = consts_flat[start : start + n_consts] + start += n_consts + if pred and jaxpr is not None: + return self.child().eval_jaxpr(jaxpr, consts, *args) + return () From 751d66cbb2cf76f9128021a121e3d7555e5a8293 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 9 Oct 2024 15:08:17 -0400 Subject: [PATCH 21/45] use copy not child --- pennylane/devices/qubit/dq_interpreter.py | 54 ++++++++++++++--------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index ae5fbd1f599..a6d20eb303f 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -14,10 +14,12 @@ """ This module contains a class for executing plxpr using default qubit tools. """ +from copy import copy import jax +import numpy as np -from pennylane.capture import PlxprInterpreter +from pennylane.capture.base_interpreter import PlxprInterpreter from pennylane.capture.primitives import ( adjoint_transform_prim, cond_prim, @@ -26,7 +28,7 @@ measure_prim, while_loop_prim, ) -from pennylane.measurements import MidMeasureMP +from pennylane.measurements import MidMeasureMP, Shots from .apply_operation import apply_operation from .initialize_state import create_initial_state @@ -52,34 +54,45 @@ class DefaultQubitInterpreter(PlxprInterpreter): """ - def __init__(self, num_wires, shots, key=None, stateref=None): + def __init__(self, num_wires, shots, key: None | jax.numpy.ndarray = None): self.num_wires = num_wires - self.shots = shots - self.stateref = stateref or {"state": None} - self.key = key + self.shots = Shots(shots) + if key is None: + key = jax.random.PRNGKey(np.random.random()) + self.stateref = {"state": None, "key": key, "mcms": None} @property - def state(self): + def state(self) -> jax.numpy.ndarray: return self.stateref["state"] @state.setter - def state(self, value): + def state(self, value: jax.numpy.ndarray): self.stateref["state"] = value - def child(self) -> "DefaultQubitInterpreter": - return type(self)( - num_wires=self.num_wires, shots=self.shots, key=self.key, stateref=self.stateref - ) + @property + def key(self) -> jax.numpy.ndarray: + return self.stateref["key"] + + @property + def mcms(self): + return self.stateref["mcms"] + + @key.setter + def key(self, value): + self.stateref["key"] = value def setup(self): if self.state is None: - self.state = create_initial_state(range(self.num_wires)) + self.state = create_initial_state(range(self.num_wires), like="jax") + if self.mcms is None: + self.stateref["mcms"] = {} def interpret_operation(self, op): self.state = apply_operation(op, self.state) def interpret_measurement_eqn(self, primitive, *invals, **params): mp = primitive.impl(*invals, **params) + if self.shots: self.key, new_key = jax.random.split(self.key, 2) # note that this does *not* group commuting measurements @@ -107,7 +120,7 @@ def _(self, *invals, jaxpr_body_fn, n_consts): res = None for i in range(start, stop, step): - res = self.child().eval(jaxpr_body_fn, consts, i, *init_state) + res = copy(self).eval(jaxpr_body_fn, consts, i, *init_state) return res @@ -119,8 +132,8 @@ def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond) init_state = invals[n_consts_body + n_consts_cond :] fn_res = init_state - while self.child().eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]: - fn_res = self.child().eval(jaxpr_body_fn, consts_body, *fn_res) + while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]: + fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res) return fn_res @@ -128,12 +141,9 @@ def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond) @DefaultQubitInterpreter.register_primitive(measure_prim) def _(self, *invals, reset, postselect): mp = MidMeasureMP(invals, reset=reset, postselect=postselect) - mid_measurements = {} self.key, new_key = jax.random.split(self.key, 2) - self.state = apply_operation( - mp, self.state, mid_measurements=mid_measurements, prng_key=new_key - ) - return mid_measurements[mp] + self.state = apply_operation(mp, self.state, mid_measurements=self.mcms, prng_key=new_key) + return self.mcms[mp] @DefaultQubitInterpreter.register_primitive(cond_prim) @@ -148,5 +158,5 @@ def _(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): consts = consts_flat[start : start + n_consts] start += n_consts if pred and jaxpr is not None: - return self.child().eval_jaxpr(jaxpr, consts, *args) + return copy(self).eval_jaxpr(jaxpr, consts, *args) return () From c2322865013eadc6a9c070008c1daf7cdef10ba2 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 11 Nov 2024 14:45:36 -0500 Subject: [PATCH 22/45] merge in master --- pennylane/capture/__init__.py | 7 +- pennylane/capture/base_interpreter.py | 204 ++++++++++++++++--------- pennylane/pauli/pauli_arithmetic.py | 37 ++++- tests/capture/test_base_interpreter.py | 167 +++++++++++++++----- 4 files changed, 292 insertions(+), 123 deletions(-) diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 761bd663b50..080701f0d7d 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -33,6 +33,7 @@ ~create_measurement_obs_primitive ~create_measurement_wires_primitive ~create_measurement_mcm_primitive + ~make_plxpr ~qnode_call ~PlxprInterpreter ~FlatFn @@ -157,6 +158,7 @@ def _(*args, **kwargs): ) from .capture_qnode import qnode_call from .flatfn import FlatFn +from .make_plxpr import make_plxpr # by defining this here, we avoid # E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module) @@ -185,9 +187,7 @@ def __getattr__(key): return _get_qnode_prim() if key == "PlxprInterpreter": - from .base_interpreter import ( - PlxprInterpreter, - ) + from .base_interpreter import PlxprInterpreter return PlxprInterpreter @@ -210,4 +210,5 @@ def __getattr__(key): "qnode_prim", "PlxprInterpreter", "FlatFn", + "make_plxpr", ) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index f46b76197df..9147d08afaf 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -14,8 +14,8 @@ """ This submodule defines a strategy structure for defining custom plxpr interpreters """ - -import copy +# pylint: disable=no-self-use +from copy import copy from functools import partial, wraps from typing import Callable @@ -41,16 +41,15 @@ def jaxpr_to_jaxpr( interpreter: "PlxprInterpreter", jaxpr: "jax.core.Jaxpr", consts, *args ) -> "jax.core.Jaxpr": - """A convenience uility for converting jaxpr to a new jaxpr via an interpreter.""" + """A convenience utility for converting jaxpr to a new jaxpr via an interpreter.""" - def f(*inner_args): - return interpreter.eval(jaxpr, consts, *inner_args) + f = partial(interpreter.eval, jaxpr, consts) return jax.make_jaxpr(f)(*args).jaxpr class PlxprInterpreter: - """A template base class for defining plxpr interpreters + """A base class for defining plxpr interpreters. **Examples:** @@ -64,11 +63,18 @@ class SimplifyInterpreter(PlxprInterpreter): def interpret_operation(self, op): new_op = qml.simplify(op) if new_op is op: - # if new op isn't queued, need to requeue op. + # simplify didnt create a new operator, so it didnt get captured data, struct = jax.tree_util.tree_flatten(new_op) new_op = jax.tree_util.tree_unflatten(struct, data) return new_op + def interpret_measurement(self, measurement): + new_mp = measurement.simplify() + if new_mp is measurement: + new_mp = new_mp._unflatten(*measurement._flatten()) + # if new op isn't queued, need to requeue op. + return new_mp + Now the interpreter can be used to transform functions and jaxpr: >>> interpreter = SimplifyInterpreter() @@ -83,7 +89,16 @@ def interpret_operation(self, op): >>> interpreter.eval(jaxpr.jaxpr, [], 0.5) [expval(2.0 * X(0))] - It will also preserve higher order primitives by default: + **Handling higher order primitives:** + + Two main strategies exist for handling higher order primitives (primitives with jaxpr as metatdata). + + 1) Structure preserving. Tracing the execution preserves the higher order primitive. + 2) Structure flattening. Tracing the execution eliminates the higher order primitive. + + Compilation transforms, like the above ``SimplifyInterpreter``, may prefer to handle higher order primitives + via a structure preserving method. After transforming the jaxpr, the `for_loop` still exists. This maintains + the compact structure of the jaxpr and reduces the size of the program. This behavior is the default. >>> def g(x): ... @qml.for_loop(3) @@ -107,26 +122,60 @@ def interpret_operation(self, op): h:AbstractMeasurement(n_wires=None) = expval_obs g in (h,) } + Accumulation transforms, like device execution or conversion to tapes, may need to flatten out + the higher order primitive to execute it. + .. code-block:: python + + class AccumulateOps(PlxprInterpreter): + + def __init__(self, ops=None): + self.ops = ops + + def setup(self): + if self.ops is None: + self.ops = [] + + def interpret_operation(self, op): + self.ops.append(op) + + @AccumulateOps.register_primitive(qml.capture.primitives.for_loop_prim) + def _(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice): + consts = invals[consts_slice] + state = invals[args_slice] + + for i in range(start, stop, step): + state = copy(self).eval(jaxpr_body_fn, consts, i, *state) + return state + + >>> @qml.for_loop(3) + ... def loop(i, x): + ... qml.RX(x, i) + ... return x + >>> accumulator = AccumlateOps() + >>> accumulator(loop)(0.5) + >>> accumulator.ops + [RX(0.5, wires=[0]), RX(0.5, wires=[1]), RX(0.5, wires=[2])] + + In this case, we need to actually evaluate the jaxpr 3 times using our interpreter. If jax's + evaluation interpreter ran it three times, we wouldn't actually manage to accumulate the operations. """ _env: dict _primitive_registrations: dict["jax.core.Primitive", Callable] = {} - _op_math_cache: dict def __init_subclass__(cls) -> None: - cls._primitive_registrations = copy.copy(cls._primitive_registrations) + cls._primitive_registrations = copy(cls._primitive_registrations) def __init__(self): self._env = {} - self._op_math_cache = {} @classmethod def register_primitive(cls, primitive: "jax.core.Primitive") -> Callable[[Callable], Callable]: """Registers a custom method for handling a primitive Args: - primitive (jax.core.Primitive): the primitive we want custom behavior for + primitive (jax.core.Primitive): the primitive we want custom behavior for Returns: Callable: a decorator for adding a function to the custom registrations map @@ -151,22 +200,16 @@ def decorator(f: Callable) -> Callable: return decorator - # pylint: disable=unidiomatic-typecheck def read(self, var): """Extract the value corresponding to a variable.""" - if self._env is None: - raise ValueError("_env not yet initialized.") - if type(var) is jax.core.Literal: - return var.val - return self._op_math_cache.get(var, self._env[var]) + return var.val if isinstance(var, jax.core.Literal) else self._env[var] def setup(self) -> None: - """Initialize the instance before interpretting equations. + """Initialize the instance before interpreting equations. Blank by default, this method can initialize any additional instance variables needed by an interpreter. For example, a device interpreter could initialize a statevector, or a compilation interpreter could initialize a staging area for the latest operation on each wire. - """ def cleanup(self) -> None: @@ -174,7 +217,7 @@ def cleanup(self) -> None: Blank by default, this method can clean up instance variables. Particularily, this method can be used to deallocate qubits and registers when converting to - catalyst variant jaxpr. + a Catalyst variant jaxpr. """ def interpret_operation(self, op: "pennylane.operation.Operator"): @@ -187,12 +230,13 @@ def interpret_operation(self, op: "pennylane.operation.Operator"): Any This method is only called when the operator's output is a dropped variable, - so the output will not effect later equations in the circuit. + so the output will not affect later equations in the circuit. See also: :meth:`~.interpret_operation_eqn`. """ - return op._unflatten(*op._flatten()) # pylint: disable=protected-access + data, struct = jax.tree_util.tree_flatten(op) + return jax.tree_util.tree_unflatten(struct, data) def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): """Interpret an equation corresponding to an operator. @@ -203,31 +247,38 @@ def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"): See also: :meth:`~.interpret_operation`. """ - invals = (self.read(invar) for invar in eqn.invars) with qml.QueuingManager.stop_recording(): op = eqn.primitive.impl(*invals, **eqn.params) if isinstance(eqn.outvars[0], jax.core.DropVar): return self.interpret_operation(op) - - self._op_math_cache[eqn.outvars[0]] = op return op - def interpret_measurement_eqn(self, primitive, *invals, **params): + def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): """Interpret an equation corresponding to a measurement process. Args: - primitive (jax.core.Primitive): a jax primitive corresponding to a measurement. - *invals (Any): the positional input variables for the equation + eqn (jax.core.JaxprEqn) + + See also :meth:`~.interpret_measurement`. + + """ + invals = (self.read(invar) for invar in eqn.invars) + with qml.QueuingManager.stop_recording(): + mp = eqn.primitive.impl(*invals, **eqn.params) + return self.interpret_measurement(mp) + + def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess"): + """Interpret a measurement process instance. + + Args: + measurement (MeasurementProcess): a measurement instance. - Keyword Args: - **params: The equations parameters dictionary + See also :meth:`~.interpret_measurement_eqn`. """ - invals = ( - self.interpret_operation(op) for op in invals if isinstance(op, qml.operation.Operator) - ) - return primitive.bind(*invals, **params) + data, struct = jax.tree_util.tree_flatten(measurement) + return jax.tree_util.tree_unflatten(struct, data) def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: """Evaluate a jaxpr. @@ -242,12 +293,11 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: """ self._env = {} - self._op_math_cache = {} self.setup() - for arg, invar in zip(args, jaxpr.invars): + for arg, invar in zip(args, jaxpr.invars, strict=True): self._env[invar] = arg - for const, constvar in zip(consts, jaxpr.constvars): + for const, constvar in zip(consts, jaxpr.constvars, strict=True): self._env[constvar] = const for eqn in jaxpr.eqns: @@ -259,26 +309,25 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: elif isinstance(eqn.outvars[0].aval, AbstractOperator): outvals = self.interpret_operation_eqn(eqn) elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): - invals = [self.read(invar) for invar in eqn.invars] - outvals = self.interpret_measurement_eqn(eqn.primitive, *invals, **eqn.params) + outvals = self.interpret_measurement_eqn(eqn) else: invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) if not eqn.primitive.multiple_results: outvals = [outvals] - for outvar, outval in zip(eqn.outvars, outvals): + for outvar, outval in zip(eqn.outvars, outvals, strict=True): self._env[outvar] = outval # Read the final result of the Jaxpr from the environment outvals = [] for var in jaxpr.outvars: - if var in self._op_math_cache: - outvals.append(self.interpret_operation(self._op_math_cache[var])) + outval = self.read(var) + if isinstance(outval, qml.operation.Operator): + outvals.append(self.interpret_operation(outval)) else: - outvals.append(self.read(var)) + outvals.append(outval) self.cleanup() - self._op_math_cache = {} self._env = {} return outvals @@ -303,7 +352,7 @@ def handle_adjoint_transform(self, *invals, jaxpr, lazy, n_consts): consts = invals[:n_consts] args = invals[n_consts:] - jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) + jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) return adjoint_transform_prim.bind(*invals, jaxpr=jaxpr, lazy=lazy, n_consts=n_consts) @@ -313,7 +362,7 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_ """Interpret a ctrl transform primitive.""" consts = invals[:n_consts] args = invals[n_consts:-n_control] - jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) + jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) return ctrl_transform_prim.bind( *invals, @@ -326,55 +375,62 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_ @PlxprInterpreter.register_primitive(for_loop_prim) -def handle_for_loop(self, *invals, jaxpr_body_fn, n_consts): +def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice): """Handle a for loop primitive.""" - start = invals[0] - consts = invals[3 : 3 + n_consts] - init_state = invals[3 + n_consts :] + init_state = args[args_slice] - new_jaxpr_body_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_body_fn, consts, start, *init_state) + new_jaxpr_body_fn = jaxpr_to_jaxpr( + copy(self), jaxpr_body_fn, args[consts_slice], start, *init_state + ) - return for_loop_prim.bind(*invals, jaxpr_body_fn=new_jaxpr_body_fn, n_consts=n_consts) + return for_loop_prim.bind( + start, + stop, + step, + *args, + jaxpr_body_fn=new_jaxpr_body_fn, + consts_slice=consts_slice, + args_slice=args_slice, + ) @PlxprInterpreter.register_primitive(cond_prim) -def handle_cond(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): +def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice): """Handle a cond primitive.""" - n_branches = len(jaxpr_branches) - consts_flat = invals[n_branches + n_args :] - args = invals[n_branches : n_branches + n_args] + args = invals[args_slice] new_jaxprs = [] - start = 0 - for n_consts, jaxpr in zip(n_consts_per_branch, jaxpr_branches): - consts = consts_flat[start : start + n_consts] - start += n_consts + for const_slice, jaxpr in zip(consts_slices, jaxpr_branches): + consts = invals[const_slice] if jaxpr is None: new_jaxprs.append(None) else: - new_jaxprs.append(jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args)) + new_jaxprs.append(jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)) return cond_prim.bind( - *invals, jaxpr_branches=new_jaxprs, n_consts_per_branch=n_consts_per_branch, n_args=n_args + *invals, jaxpr_branches=new_jaxprs, consts_slices=consts_slices, args_slice=args_slice ) @PlxprInterpreter.register_primitive(while_loop_prim) -def handle_while_loop(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): +def handle_while_loop( + self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice +): """Handle a while loop primitive.""" - consts_body = invals[:n_consts_body] - consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] - init_state = invals[n_consts_body + n_consts_cond :] + consts_body = invals[body_slice] + consts_cond = invals[cond_slice] + init_state = invals[args_slice] - new_jaxpr_body_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_body_fn, consts_body, *init_state) - new_jaxpr_cond_fn = jaxpr_to_jaxpr(type(self)(), jaxpr_cond_fn, consts_cond, *init_state) + new_jaxpr_body_fn = jaxpr_to_jaxpr(copy(self), jaxpr_body_fn, consts_body, *init_state) + new_jaxpr_cond_fn = jaxpr_to_jaxpr(copy(self), jaxpr_cond_fn, consts_cond, *init_state) return while_loop_prim.bind( *invals, jaxpr_body_fn=new_jaxpr_body_fn, jaxpr_cond_fn=new_jaxpr_cond_fn, - n_consts_body=n_consts_body, - n_consts_cond=n_consts_cond, + body_slice=body_slice, + cond_slice=cond_slice, + args_slice=args_slice, ) @@ -384,7 +440,7 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, """Handle a qnode primitive.""" consts = invals[:n_consts] - new_qfunc_jaxpr = jaxpr_to_jaxpr(type(self)(), qfunc_jaxpr, consts, *invals[n_consts:]) + new_qfunc_jaxpr = jaxpr_to_jaxpr(copy(self), qfunc_jaxpr, consts, *invals[n_consts:]) return qnode_prim.bind( *invals, @@ -402,7 +458,7 @@ def handle_grad(self, *invals, jaxpr, n_consts, **params): """Handle the grad primitive.""" consts = invals[:n_consts] args = invals[n_consts:] - new_jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) + new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) return grad_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) @@ -411,5 +467,5 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params): """Handle the jacobian primitive.""" consts = invals[:n_consts] args = invals[n_consts:] - new_jaxpr = jaxpr_to_jaxpr(type(self)(), jaxpr, consts, *args) + new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args) return jacobian_prim.bind(*invals, jaxpr=new_jaxpr, n_consts=n_consts, **params) diff --git a/pennylane/pauli/pauli_arithmetic.py b/pennylane/pauli/pauli_arithmetic.py index c6b71c1ea94..9822fd6cadf 100644 --- a/pennylane/pauli/pauli_arithmetic.py +++ b/pennylane/pauli/pauli_arithmetic.py @@ -15,6 +15,7 @@ # pylint:disable=protected-access from copy import copy from functools import lru_cache, reduce +from warnings import warn import numpy as np from scipy import sparse @@ -82,7 +83,7 @@ def _cached_sparse_data(op): elif op == "Y": data = np.array([-1.0j, 1.0j], dtype=np.complex128) indices = np.array([1, 0], dtype=np.int64) - else: # if op == "Z": + else: # op == "Z" data = np.array([1.0, -1.0], dtype=np.complex128) indices = np.array([0, 1], dtype=np.int64) return data, indices @@ -522,17 +523,25 @@ def operation(self, wire_order=None, get_as_tensor=False): return factors[0] if len(factors) == 1 else Prod(*factors, _pauli_rep=pauli_rep) def hamiltonian(self, wire_order=None): - """Return :class:`~pennylane.Hamiltonian` representing the PauliWord.""" + """Return :class:`~pennylane.Hamiltonian` representing the PauliWord. + + .. warning:: + + :meth:`~pennylane.pauli.PauliWord.hamiltonian` is deprecated. Instead, please use + :meth:`~pennylane.pauli.PauliWord.operation` + + """ + warn( + "PauliWord.hamiltonian() is deprecated. Please use PauliWord.operation() instead.", + qml.PennyLaneDeprecationWarning, + ) + if len(self) == 0: if wire_order in (None, [], Wires([])): raise ValueError("Can't get the Hamiltonian for an empty PauliWord.") return qml.Hamiltonian([1], [Identity(wires=wire_order)]) - if qml.capture.enabled(): - # cant use lru_cache with program capture - obs = [op_map[op](wire) for wire, op in self.items()] - else: - obs = [_make_operation(op, wire) for wire, op in self.items()] + obs = [_make_operation(op, wire) for wire, op in self.items()] return qml.Hamiltonian([1], [obs[0] if len(obs) == 1 else Tensor(*obs)]) def map_wires(self, wire_map: dict) -> "PauliWord": @@ -1030,7 +1039,19 @@ def operation(self, wire_order=None): return summands[0] if len(summands) == 1 else Sum(*summands, _pauli_rep=self) def hamiltonian(self, wire_order=None): - """Returns a native PennyLane :class:`~pennylane.Hamiltonian` representing the PauliSentence.""" + """Returns a native PennyLane :class:`~pennylane.Hamiltonian` representing the PauliSentence. + + .. warning:: + + :meth:`~pennylane.pauli.PauliSentence.hamiltonian` is deprecated. Instead, please use + :meth:`~pennylane.pauli.PauliSentence.operation` + + """ + warn( + "PauliSentence.hamiltonian() is deprecated. Please use PauliSentence.operation() instead.", + qml.PennyLaneDeprecationWarning, + ) + if len(self) == 0: if wire_order in (None, [], Wires([])): raise ValueError("Can't get the Hamiltonian for an empty PauliSentence.") diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index 89357dae459..f977ff154ef 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -14,7 +14,7 @@ """ This submodule tests strategy structure for defining custom plxpr interpreters """ -# pylint: disable=protected-access +# pylint: disable=protected-access, too-few-public-methods import pytest import pennylane as qml @@ -22,9 +22,7 @@ jax = pytest.importorskip("jax") jnp = pytest.importorskip("jax.numpy") -from pennylane.capture.base_interpreter import ( # pylint: disable=wrong-import-position - PlxprInterpreter, -) +from pennylane.capture import PlxprInterpreter # pylint: disable=wrong-import-position from pennylane.capture.primitives import ( # pylint: disable=wrong-import-position adjoint_transform_prim, cond_prim, @@ -54,6 +52,13 @@ def interpret_operation(self, op): # if new op isn't queued, need to requeue op. return new_op + def interpret_measurement(self, measurement): + new_mp = measurement.simplify() + if new_mp is measurement: + new_mp = new_mp._unflatten(*measurement._flatten()) + # if new op isn't queued, need to requeue op. + return new_mp + # pylint: disable=use-implicit-booleaness-not-comparison def test_env_and_initialized(): @@ -61,21 +66,34 @@ def test_env_and_initialized(): interpreter = SimplifyInterpreter() assert interpreter._env == {} - assert interpreter._op_math_cache == {} + + +def test_zip_length_validation(): + """Test that errors are raised if the input values isnt long enough for the needed variables.""" + + def f(x): + return x + 1 + + jaxpr = jax.make_jaxpr(f)(0.5) + with pytest.raises(ValueError): + PlxprInterpreter().eval(jaxpr.jaxpr, []) + + y = jax.numpy.array([1.0]) + + def g(): + return y + 2 + + jaxpr = jax.make_jaxpr(g)() + with pytest.raises(ValueError): + PlxprInterpreter().eval(jaxpr.jaxpr, []) def test_primitive_registrations(): """Test that child primitive registrations dict's are not copied and do - not effect PlxprInterpreeter.""" - - class SimplifyInterpreterLocal(PlxprInterpreter): + not affect PlxprInterpreter.""" - def interpret_operation(self, op): - new_op = op.simplify() - if new_op is op: - # if new op isn't queued, need to requeue op. - new_op = new_op._unflatten(*op._flatten()) - return new_op + class SimplifyInterpreterLocal(SimplifyInterpreter): + pass assert ( SimplifyInterpreterLocal._primitive_registrations @@ -84,7 +102,6 @@ def interpret_operation(self, op): @SimplifyInterpreterLocal.register_primitive(qml.X._primitive) def _(self, *invals, **params): # pylint: disable=unused-argument - print("in custom interpreter") return qml.Z(*invals) assert qml.X._primitive in SimplifyInterpreterLocal._primitive_registrations @@ -100,8 +117,74 @@ def f(): with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(jaxpr.jaxpr, []) - qml.assert_equal(q.queue[0], qml.Z(0)) # turned into a Y - qml.assert_equal(q.queue[1], qml.Y(5)) # mapped wire + qml.assert_equal(q.queue[0], qml.Z(0)) # turned into a Z + qml.assert_equal(q.queue[1], qml.Y(5)) + + +def test_default_operator_handling(): + """Test that the PlxprInterpreter itself can handle operators and leaves them unchanged.""" + + @PlxprInterpreter() + def f(x): + qml.adjoint(qml.RX(x, 0)) + qml.T(1) + return qml.X(0) + qml.X(1) + + with qml.queuing.AnnotatedQueue() as q: + out = f(0.5) + + qml.assert_equal(out, qml.X(0) + qml.X(1)) + qml.assert_equal(q.queue[0], qml.adjoint(qml.RX(0.5, 0))) + qml.assert_equal(q.queue[1], qml.T(1)) + qml.assert_equal(q.queue[2], qml.X(0) + qml.X(1)) + + jaxpr = jax.make_jaxpr(f)(1.2) + + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == qml.ops.Adjoint._primitive + assert jaxpr.eqns[2].primitive == qml.T._primitive + assert jaxpr.eqns[3].primitive == qml.X._primitive + assert jaxpr.eqns[4].primitive == qml.X._primitive + assert jaxpr.eqns[5].primitive == qml.ops.Sum._primitive + + +def test_default_measurement_handling(): + """Test that measurements are simply re-queued by default.""" + + def f(): + return qml.expval(qml.Z(0) + qml.Z(0)), qml.probs(wires=0) + + jaxpr = jax.make_jaxpr(f)() + with qml.queuing.AnnotatedQueue() as q: + res1, res2 = PlxprInterpreter().eval(jaxpr.jaxpr, jaxpr.consts) + assert len(q.queue) == 2 + assert q.queue[0] is res1 + assert q.queue[1] is res2 + qml.assert_equal(res1, qml.expval(qml.Z(0) + qml.Z(0))) + qml.assert_equal(res2, qml.probs(wires=0)) + + +def test_measurement_handling(): + """Test that the default measurment handling works.""" + + @SimplifyInterpreter() + def f(w): + return qml.expval(qml.X(w) + qml.X(w)), qml.probs(wires=w) + + m1, m2 = f(0) + qml.assert_equal(m1, qml.expval(2 * qml.X(0))) + qml.assert_equal(m2, qml.probs(wires=0)) + + jaxpr = jax.make_jaxpr(f)(0) + + assert jaxpr.eqns[0].primitive == qml.X._primitive + assert jaxpr.eqns[1].primitive == qml.ops.SProd._primitive + assert jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive + assert jaxpr.eqns[3].primitive == qml.measurements.ProbabilityMP._wires_primitive + + m1, m2 = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0) + qml.assert_equal(m1, qml.expval(2 * qml.X(0))) + qml.assert_equal(m2, qml.probs(wires=0)) def test_overriding_measurements(): @@ -109,9 +192,8 @@ def test_overriding_measurements(): class MeasurementsToSample(PlxprInterpreter): - def interpret_measurement_eqn(self, primitive, *invals, **params): - temp_mp = primitive.impl(*invals, **params) - return qml.sample(wires=temp_mp.wires) + def interpret_measurement(self, measurement): + return qml.sample(wires=measurement.wires) @MeasurementsToSample() @qml.qnode(qml.device("default.qubit", wires=2, shots=5)) @@ -134,7 +216,7 @@ def circuit(): def test_setup_method(): - """Test that the setup method can be used to initialized variables each call.""" + """Test that the setup method can be used to initialize variables at each call.""" class CollectOps(PlxprInterpreter): @@ -167,7 +249,7 @@ def f(x): def test_cleanup_method(): - """Test that the cleanup method.""" + """Test that the cleanup method can be used to reset variables after evaluation.""" class CleanupTester(PlxprInterpreter): @@ -189,6 +271,16 @@ def f(x): assert inst.state is None +def test_returning_operators(): + """Test that operators that are returned are still processed by the interpreter.""" + + @SimplifyInterpreter() + def f(): + return qml.X(0) ** 2 + + qml.assert_equal(f(), qml.I(0)) + + class TestHigherOrderPrimitiveRegistrations: @pytest.mark.parametrize("lazy", (True, False)) @@ -220,7 +312,7 @@ def g(y): qml.assert_equal(q.queue[0], qml.RX(jax.numpy.array(-1.5), 0)) def test_ctrl_transform(self): - """Test the higher order adjoint transform.""" + """Test the higher order ctrl transform.""" @SimplifyInterpreter() def f(x, control): @@ -273,9 +365,6 @@ def false_fn(y): jax.core.eval_jaxpr(branch2, [], 0.5) qml.assert_equal(q.queue[0], qml.RX(jax.numpy.array(-0.5), 0)) - assert jaxpr.eqns[0].params["n_args"] == 1 - assert jaxpr.eqns[0].params["n_consts_per_branch"] == [0, 0] - with qml.queuing.AnnotatedQueue() as q: jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2.4, True) @@ -411,21 +500,23 @@ def f(): def test_grad_and_jac(self, grad_f): """Test interpreters can handle grad and jacobian HOP's.""" - class DoubleAngle(PlxprInterpreter): - - def interpret_operation(self, op): - leaves, struct = jax.tree_util.tree_flatten(op) - return jax.tree_util.tree_unflatten(struct, [2 * l for l in leaves]) - - @DoubleAngle() + @SimplifyInterpreter() def f(x): @qml.qnode(qml.device("default.qubit", wires=2)) def circuit(y): - qml.RX(y, 0) - return qml.expval(qml.Z(0)) + _ = qml.RX(y, 0) ** 2 + return qml.expval(qml.Z(0) + qml.Z(0)) return grad_f(circuit)(x) - out = f(0.5) - expected = -2 * jax.numpy.sin(2 * 0.5) # includes the factors of 2 from doubling the angle. - assert qml.math.allclose(out, expected) + jaxpr = jax.make_jaxpr(f)(0.5) + + if grad_f == qml.grad: + assert jaxpr.eqns[0].primitive == qml.capture.primitives.grad_prim + else: + assert jaxpr.eqns[0].primitive == qml.capture.primitives.jacobian_prim + grad_jaxpr = jaxpr.eqns[0].params["jaxpr"] + qfunc_jaxpr = grad_jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[1].primitive == qml.RX._primitive # eqn 0 is mul + assert qfunc_jaxpr.eqns[2].primitive == qml.Z._primitive + assert qfunc_jaxpr.eqns[3].primitive == qml.ops.SProd._primitive From d3f729e2f2e82811dddbf43918ab516c717593e6 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 12 Nov 2024 09:29:40 -0500 Subject: [PATCH 23/45] starting to write tests --- pennylane/devices/qubit/dq_interpreter.py | 128 +++++++++++++-------- tests/devices/qubit/test_dq_interpreter.py | 93 +++++++++++++++ 2 files changed, 176 insertions(+), 45 deletions(-) create mode 100644 tests/devices/qubit/test_dq_interpreter.py diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index a6d20eb303f..ce120becf86 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -19,6 +19,7 @@ import jax import numpy as np +from pennylane.capture import disable, enable from pennylane.capture.base_interpreter import PlxprInterpreter from pennylane.capture.primitives import ( adjoint_transform_prim, @@ -28,7 +29,7 @@ measure_prim, while_loop_prim, ) -from pennylane.measurements import MidMeasureMP, Shots +from pennylane.measurements import MeasurementValue, MidMeasureMP, Shots from .apply_operation import apply_operation from .initialize_state import create_initial_state @@ -39,8 +40,13 @@ class DefaultQubitInterpreter(PlxprInterpreter): """Implements a class for interpreting plxpr using default qubit. + Args: + num_wires (int): the numberof wires to initialize the state with + shots (int | None): the number of shots to use for the execution + key (None, jax.numpy.ndarray): the ``PRNGKey`` to use for random number generation. + >>> key = jax.random.PRNGKey(1234) - >>> dq = DefaultQubitInterpreter(num_wires=2, shots=qml.measurements.Shots(50), key=key) + >>> dq = DefaultQubitInterpreter(num_wires=2, shots=None, key=key) >>> @qml.for_loop(2) ... def g(i,y): ... qml.RX(y,0) @@ -49,87 +55,122 @@ class DefaultQubitInterpreter(PlxprInterpreter): ... g(x) ... return qml.expval(qml.Z(0)) >>> dq(f)(0.5) - Array(-0.79999995, dtype=float32) + Array(0.54030231, dtype=float64) + + This execution can be differentiated via backprop and jitted as normal. Note that finite shot executions + still cannot be differented with backprop. + + >>> jax.grad(dq(f))(jax.numpy.array(0.5)) + Array(-1.68294197, dtype=float64, weak_type=True) + >>> jax.jit(dq(f))(jax.numpy.array(0.5)) + Array(0.54030231, dtype=float64) """ - def __init__(self, num_wires, shots, key: None | jax.numpy.ndarray = None): + def __init__(self, num_wires: int, shots: int | None, key: None | jax.numpy.ndarray = None): self.num_wires = num_wires self.shots = Shots(shots) + if self.shots.has_partitioned_shots: + raise NotImplementedError( + "DefaultQubitInterpreter does not yet support partitioned shots." + ) if key is None: - key = jax.random.PRNGKey(np.random.random()) - self.stateref = {"state": None, "key": key, "mcms": None} + key = jax.random.PRNGKey(np.random.randint(100000)) + + self.initial_key = key + self.stateref = None + super().__init__() @property - def state(self) -> jax.numpy.ndarray: - return self.stateref["state"] + def state(self) -> None | jax.numpy.ndarray: + """The current state of the system. None if not initialized.""" + return self.stateref["state"] if self.stateref else None @state.setter - def state(self, value: jax.numpy.ndarray): + def state(self, value: jax.numpy.ndarray | None): + if self.stateref is None: + raise ValueError("execution not yet initialized.") self.stateref["state"] = value @property def key(self) -> jax.numpy.ndarray: - return self.stateref["key"] - - @property - def mcms(self): - return self.stateref["mcms"] + """A jax PRNGKey. ``initial_key`` if not yet initialized.""" + return self.stateref["key"] if self.stateref else self.initial_key @key.setter def key(self, value): + if self.stateref is None: + raise ValueError("execution not yet initialized.") self.stateref["key"] = value + @property + def mcms(self) -> None | dict[MeasurementValue, bool]: + """The mid circuit measurements. ``NOne`` if not yet initialized.""" + return self.stateref["mcms"] if self.stateref else None + def setup(self): - if self.state is None: - self.state = create_initial_state(range(self.num_wires), like="jax") - if self.mcms is None: - self.stateref["mcms"] = {} + if self.stateref is None: + self.stateref = { + "state": create_initial_state(range(self.num_wires), like="jax"), + "key": self.initial_key, + "mcms": {}, + } + + def cleanup(self) -> None: + self.stateref = None def interpret_operation(self, op): self.state = apply_operation(op, self.state) - def interpret_measurement_eqn(self, primitive, *invals, **params): - mp = primitive.impl(*invals, **params) - - if self.shots: - self.key, new_key = jax.random.split(self.key, 2) - # note that this does *not* group commuting measurements - # further work could figure out how to perform multiple measurements at the same time - return measure_with_samples([mp], self.state, shots=self.shots, prng_key=new_key)[0] - return measure(mp, self.state) + def interpret_measurement(self, measurement): + disable() + try: # measurements can sometimes create intermediary mps + if self.shots: + self.key, new_key = jax.random.split(self.key, 2) + # note that this does *not* group commuting measurements + # further work could figure out how to perform multiple measurements at the same time + output = measure_with_samples( + [measurement], self.state, shots=self.shots, prng_key=new_key + )[0] + else: + output = measure(measurement, self.state) + finally: + enable() + return output # pylint: disable=unused-argument @DefaultQubitInterpreter.register_primitive(adjoint_transform_prim) def _(self, *invals, jaxpr, n_consts, lazy=True): - raise NotImplementedError("TODO?") + raise NotImplementedError +# pylint: disable=too-many-arguments @DefaultQubitInterpreter.register_primitive(ctrl_transform_prim) def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): - raise NotImplementedError("TODO?") + raise NotImplementedError +# pylint: disable=too-many-arguments @DefaultQubitInterpreter.register_primitive(for_loop_prim) -def _(self, *invals, jaxpr_body_fn, n_consts): - start, stop, step = invals[0], invals[1], invals[2] - consts = invals[3 : 3 + n_consts] - init_state = invals[3 + n_consts :] +def _(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice): + consts = invals[consts_slice] + init_state = invals[args_slice] - res = None + res = init_state for i in range(start, stop, step): res = copy(self).eval(jaxpr_body_fn, consts, i, *init_state) return res +# pylint: disable=too-many-arguments @DefaultQubitInterpreter.register_primitive(while_loop_prim) -def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, n_consts_body, n_consts_cond): - consts_body = invals[:n_consts_body] - consts_cond = invals[n_consts_body : n_consts_body + n_consts_cond] - init_state = invals[n_consts_body + n_consts_cond :] +def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice): + consts_body = invals[body_slice] + consts_cond = invals[cond_slice] + init_state = invals[args_slice] fn_res = init_state while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]: @@ -147,16 +188,13 @@ def _(self, *invals, reset, postselect): @DefaultQubitInterpreter.register_primitive(cond_prim) -def _(self, *invals, jaxpr_branches, n_consts_per_branch, n_args): +def _(self, *invals, jaxpr_branches, consts_slices, args_slice): n_branches = len(jaxpr_branches) conditions = invals[:n_branches] - consts_flat = invals[n_branches + n_args :] - args = invals[n_branches : n_branches + n_args] + args = invals[args_slice] - start = 0 - for pred, jaxpr, n_consts in zip(conditions, jaxpr_branches, n_consts_per_branch): - consts = consts_flat[start : start + n_consts] - start += n_consts + for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices): + consts = invals[const_slice] if pred and jaxpr is not None: return copy(self).eval_jaxpr(jaxpr, consts, *args) return () diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py new file mode 100644 index 00000000000..977acb07d77 --- /dev/null +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -0,0 +1,93 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module tests the default qubit interpreter. +""" +import pytest + +import pennylane as qml + +jax = pytest.importorskip("jax") +pytestmark = pytest.mark.jax + +# must be below the importorskip +# pylint: disable=wrong-import-position +from pennylane.devices.qubit.dq_interpreter import DefaultQubitInterpreter + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + qml.capture.enable() + yield + qml.capture.disable() + + +def test_initialization(): + dq = DefaultQubitInterpreter(num_wires=3, shots=None) + assert dq.num_wires == 3 + assert dq.shots == qml.measurements.Shots(None) + assert isinstance(dq.initial_key, jax.numpy.ndarray) + assert dq.stateref is None + + +def test_setup(): + key = jax.random.PRNGKey(1234) + dq = DefaultQubitInterpreter(num_wires=2, shots=2, key=key) + assert dq.stateref is None + + dq.setup() + assert isinstance(dq.stateref, dict) + assert list(dq.stateref.keys()) == ["state", "key", "mcms"] + + assert dq.stateref["key"] is key + assert dq.key is key + + assert dq.stateref["mcms"] == {} + assert dq.mcms is dq.stateref["mcms"] + + assert dq.state is dq.stateref["state"] + expected = jax.numpy.array([[1.0, 0.0], [0.0, 0.0]], dtype=complex) + assert qml.math.allclose(dq.state, expected) + + dq.cleanup() + + +def test_simple_execution(): + + @DefaultQubitInterpreter(num_wires=1, shots=None) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + res = f(0.5) + assert qml.math.allclose(res, jax.numpy.cos(0.5)) + + g = jax.grad(f)(jax.numpy.array(0.5)) + assert qml.math.allclose(g, -jax.numpy.sin(0.5)) + + +def test_sampling(): + + @DefaultQubitInterpreter(num_wires=2, shots=10) + def sampler(): + qml.X(0) + return qml.sample(wires=(0, 1)) + + results = sampler() + + expected0 = jax.numpy.ones((10,)) # zero wire + expected1 = jax.numpy.zeros((10,)) # one wire + expected = jax.numpy.hstack([expected0, expected1]).T + + assert qml.math.allclose(results, expected) From 879bdadc8be7cbe8a318ad99050b63fe07c59881 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 12 Nov 2024 15:41:39 -0500 Subject: [PATCH 24/45] more testing --- pennylane/devices/qubit/dq_interpreter.py | 52 ++-- tests/devices/qubit/test_dq_interpreter.py | 292 +++++++++++++++++++-- 2 files changed, 307 insertions(+), 37 deletions(-) diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index ce120becf86..9e3fd9ed350 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -29,7 +29,7 @@ measure_prim, while_loop_prim, ) -from pennylane.measurements import MeasurementValue, MidMeasureMP, Shots +from pennylane.measurements import MidMeasureMP, Shots from .apply_operation import apply_operation from .initialize_state import create_initial_state @@ -56,6 +56,9 @@ class DefaultQubitInterpreter(PlxprInterpreter): ... return qml.expval(qml.Z(0)) >>> dq(f)(0.5) Array(0.54030231, dtype=float64) + >>> jaxpr = jax.make_jaxpr(f)(0.5) + >>> dq.eval(jaxpr.jaxpr, jaxpr.consts, 0.5) + Array(0.54030231, dtype=float64) This execution can be differentiated via backprop and jitted as normal. Note that finite shot executions still cannot be differented with backprop. @@ -68,7 +71,9 @@ class DefaultQubitInterpreter(PlxprInterpreter): """ - def __init__(self, num_wires: int, shots: int | None, key: None | jax.numpy.ndarray = None): + def __init__( + self, num_wires: int, shots: int | None = None, key: None | jax.numpy.ndarray = None + ): self.num_wires = num_wires self.shots = Shots(shots) if self.shots.has_partitioned_shots: @@ -104,28 +109,32 @@ def key(self, value): raise ValueError("execution not yet initialized.") self.stateref["key"] = value - @property - def mcms(self) -> None | dict[MeasurementValue, bool]: - """The mid circuit measurements. ``NOne`` if not yet initialized.""" - return self.stateref["mcms"] if self.stateref else None - - def setup(self): + def setup(self) -> None: if self.stateref is None: self.stateref = { "state": create_initial_state(range(self.num_wires), like="jax"), "key": self.initial_key, - "mcms": {}, } + # else set by copying a parent interpreter and we need to modify same stateref def cleanup(self) -> None: self.stateref = None + # Open question: should we update initial key? def interpret_operation(self, op): self.state = apply_operation(op, self.state) + def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): + if "mcm" in eqn.primitive.name: + raise NotImplementedError( + "DefaultQubitInterpreter does not yet support postprocessing mcms" + ) + return super().interpret_measurement_eqn(eqn) + def interpret_measurement(self, measurement): + # measurements can sometimes create intermediary mps, but those intermediaries will not work with capture enabled disable() - try: # measurements can sometimes create intermediary mps + try: if self.shots: self.key, new_key = jax.random.split(self.key, 2) # note that this does *not* group commuting measurements @@ -140,15 +149,26 @@ def interpret_measurement(self, measurement): return output +@DefaultQubitInterpreter.register_primitive(measure_prim) +def _(self, *invals, reset, postselect): + mp = MidMeasureMP(invals, reset=reset, postselect=postselect) + self.key, new_key = jax.random.split(self.key, 2) + mcms = {} + self.state = apply_operation(mp, self.state, mid_measurements=mcms, prng_key=new_key) + return mcms[mp] + + # pylint: disable=unused-argument @DefaultQubitInterpreter.register_primitive(adjoint_transform_prim) def _(self, *invals, jaxpr, n_consts, lazy=True): + # TODO: requires jaxpr -> list of ops first raise NotImplementedError # pylint: disable=too-many-arguments @DefaultQubitInterpreter.register_primitive(ctrl_transform_prim) def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): + # TODO: requires jaxpr -> list of ops first raise NotImplementedError @@ -160,7 +180,7 @@ def _(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice) res = init_state for i in range(start, stop, step): - res = copy(self).eval(jaxpr_body_fn, consts, i, *init_state) + res = copy(self).eval(jaxpr_body_fn, consts, i, *res) return res @@ -179,14 +199,6 @@ def _(self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_ return fn_res -@DefaultQubitInterpreter.register_primitive(measure_prim) -def _(self, *invals, reset, postselect): - mp = MidMeasureMP(invals, reset=reset, postselect=postselect) - self.key, new_key = jax.random.split(self.key, 2) - self.state = apply_operation(mp, self.state, mid_measurements=self.mcms, prng_key=new_key) - return self.mcms[mp] - - @DefaultQubitInterpreter.register_primitive(cond_prim) def _(self, *invals, jaxpr_branches, consts_slices, args_slice): n_branches = len(jaxpr_branches) @@ -196,5 +208,5 @@ def _(self, *invals, jaxpr_branches, consts_slices, args_slice): for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices): consts = invals[const_slice] if pred and jaxpr is not None: - return copy(self).eval_jaxpr(jaxpr, consts, *args) + return copy(self).eval(jaxpr, consts, *args) return () diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py index 977acb07d77..2262caf8b2b 100644 --- a/tests/devices/qubit/test_dq_interpreter.py +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -16,11 +16,13 @@ """ import pytest -import pennylane as qml - jax = pytest.importorskip("jax") pytestmark = pytest.mark.jax +from jax import numpy as jnp # pylint: disable=wrong-import-position + +import pennylane as qml # pylint: disable=wrong-import-position + # must be below the importorskip # pylint: disable=wrong-import-position from pennylane.devices.qubit.dq_interpreter import DefaultQubitInterpreter @@ -34,6 +36,7 @@ def enable_disable_plxpr(): def test_initialization(): + """Test that relevant properties are set on initialization.""" dq = DefaultQubitInterpreter(num_wires=3, shots=None) assert dq.num_wires == 3 assert dq.shots == qml.measurements.Shots(None) @@ -41,29 +44,36 @@ def test_initialization(): assert dq.stateref is None -def test_setup(): +def test_no_partitioned_shots(): + """Test that an error is raised if partitioned shots is requested.""" + + with pytest.raises(NotImplementedError, match="does not yet support partitioned shots"): + DefaultQubitInterpreter(num_wires=1, shots=(100, 100, 100)) + + +def test_setup_and_cleanup(): + """Test setup initializes the stateref dictionary and cleanup removes it.""" key = jax.random.PRNGKey(1234) dq = DefaultQubitInterpreter(num_wires=2, shots=2, key=key) assert dq.stateref is None dq.setup() assert isinstance(dq.stateref, dict) - assert list(dq.stateref.keys()) == ["state", "key", "mcms"] + assert list(dq.stateref.keys()) == ["state", "key"] assert dq.stateref["key"] is key assert dq.key is key - assert dq.stateref["mcms"] == {} - assert dq.mcms is dq.stateref["mcms"] - assert dq.state is dq.stateref["state"] expected = jax.numpy.array([[1.0, 0.0], [0.0, 0.0]], dtype=complex) assert qml.math.allclose(dq.state, expected) dq.cleanup() + assert dq.stateref is None def test_simple_execution(): + """Test the execution, jitting, and gradient of a simple quantum circuit.""" @DefaultQubitInterpreter(num_wires=1, shots=None) def f(x): @@ -73,21 +83,269 @@ def f(x): res = f(0.5) assert qml.math.allclose(res, jax.numpy.cos(0.5)) + jit_res = jax.jit(f)(0.5) + assert qml.math.allclose(jit_res, res) + g = jax.grad(f)(jax.numpy.array(0.5)) assert qml.math.allclose(g, -jax.numpy.sin(0.5)) -def test_sampling(): +def test_capture_remains_enabled_if_measurement_error(): + """Test that capture remains enabled if there is a measurement error.""" + + @DefaultQubitInterpreter(num_wires=1, shots=None) + def g(): + return qml.sample(wires=0) # sampling with analytic execution. + + with pytest.raises(NotImplementedError): + g() + + assert qml.capture.enabled() + + +def test_pytree_function_output(): + """Test that the results respect the pytree output of the function.""" + + @DefaultQubitInterpreter(num_wires=1, shots=None) + def g(): + return { + "probs": qml.probs(wires=0), + "state": qml.state(), + "var_Z": qml.var(qml.Z(0)), + "var_X": qml.var(qml.X(0)), + } + + res = g() + assert qml.math.allclose(res["probs"], [1.0, 0.0]) + assert qml.math.allclose(res["state"], [1.0, 0.0 + 0j]) + assert qml.math.allclose(res["var_Z"], 0.0) + assert qml.math.allclose(res["var_X"], 1.0) + + +class TestSampling: + """Test cases for generating samples.""" + + def test_known_sampling(self): + """Test sampling output with deterministic sampling output""" + + @DefaultQubitInterpreter(num_wires=2, shots=10) + def sampler(): + qml.X(0) + return qml.sample(wires=(0, 1)) + + results = sampler() + + expected0 = jax.numpy.ones((10,)) # zero wire + expected1 = jax.numpy.zeros((10,)) # one wire + expected = jax.numpy.vstack([expected0, expected1]).T + + assert qml.math.allclose(results, expected) + + def test_same_key_same_results(self): + """Test that two circuits with the same key give identical results.""" + key = jax.random.PRNGKey(1234) + + @DefaultQubitInterpreter(num_wires=1, shots=100, key=key) + def circuit1(): + qml.Hadamard(0) + return qml.sample(wires=0) + + @DefaultQubitInterpreter(num_wires=1, shots=100, key=key) + def circuit2(): + qml.Hadamard(0) + return qml.sample(wires=0) + + res1 = circuit1() + res2 = circuit2() + + assert qml.math.allclose(res1, res2) + + @pytest.mark.parametrize("mcm_value", (0, 1)) + def test_return_mcm(self, mcm_value): + """Test that the interpreter can return the result of mid circuit measurements""" + + @DefaultQubitInterpreter(num_wires=1) + def f(): + if mcm_value: + qml.X(0) + return qml.measure(0) + + output = f() + assert qml.math.allclose(output, mcm_value) + + def test_mcm_depends_on_key(self): + """Test that the value of an mcm depends on the key.""" + + def get_mcm_from_key(key): + @DefaultQubitInterpreter(num_wires=1, key=key) + def f(): + qml.H(0) + return qml.measure(0) + + return f() + + for key in range(0, 100, 10): + m1 = get_mcm_from_key(jax.random.PRNGKey(key)) + m2 = get_mcm_from_key(jax.random.PRNGKey(key)) + assert qml.math.allclose(m1, m2) + + samples = [int(get_mcm_from_key(jax.random.PRNGKey(key))) for key in range(0, 100, 1)] + assert set(samples) == {0, 1} + + def test_classical_transformation_mcm_value(self): + """Test that mid circuit measurements can be used in classical manipulations.""" + + @DefaultQubitInterpreter(num_wires=1) + def f(): + qml.X(0) + m0 = qml.measure(0) # 1 + qml.X(0) # reset to 0 + qml.RX(2 * m0, wires=0) + return qml.expval(qml.Z(0)) + + expected = jax.numpy.cos(2.0) + assert qml.math.allclose(f(), expected) + + @pytest.mark.parametrize("mp_type", (qml.sample, qml.expval, qml.probs)) + def test_mcm_measurements_not_yet_implemented(self, mp_type): + """Test that measurements of mcms are not yet implemented""" + + @DefaultQubitInterpreter(num_wires=1) + def f(): + m0 = qml.measure(0) + if mp_type == qml.probs: + return mp_type(op=m0) + return mp_type(m0) + + with pytest.raises(NotImplementedError): + f() + + +class TestQuantumHOP: + """Tests for the quantum higher order primitives: adjoint and ctrl.""" + + def test_adjoint_transform(self): + """Test that the adjoint_transform is not yet implemented.""" + + @DefaultQubitInterpreter(num_wires=1, shots=None) + def circuit(x): + qml.adjoint(qml.RX)(x, 0) + return 1 + + with pytest.raises(NotImplementedError): + circuit(0.5) + + def test_ctrl_transform(self): + """Test that the ctrl_transform is not yet implemented.""" + + @DefaultQubitInterpreter(num_wires=2, shots=None) + def circuit(): + qml.ctrl(qml.X, control=1)(0) + + with pytest.raises(NotImplementedError): + circuit() + + +class TestClassicalComponents: + """Test execution of classical components.""" + + def test_classical_operations_in_circuit(self): + """Test that we can have classical operations in the circuit.""" + + @DefaultQubitInterpreter(num_wires=1) + def f(x, y, w): + qml.RX(2 * x + y, wires=w - 1) + return qml.expval(qml.Z(0)) + + x = jax.numpy.array(0.5) + y = jax.numpy.array(1.2) + w = jax.numpy.array(1) + + output = f(x, y, w) + expected = jax.numpy.cos(2 * x + y) + assert qml.math.allclose(output, expected) + + def test_for_loop(self): + """Test that the for loop can be executed.""" + + @DefaultQubitInterpreter(num_wires=4) + def f(y): + @qml.for_loop(4) + def f(i, x): + qml.RX(x, i) + return x + 0.1 + + f(y) + return [qml.expval(qml.Z(i)) for i in range(4)] + + output = f(1.0) + assert len(output) == 4 + assert qml.math.allclose(output[0], jax.numpy.cos(1.0)) + assert qml.math.allclose(output[1], jax.numpy.cos(1.1)) + assert qml.math.allclose(output[2], jax.numpy.cos(1.2)) + assert qml.math.allclose(output[3], jax.numpy.cos(1.3)) + + def test_while_loop(self): + """Test that the while loop can be executed.""" + + @DefaultQubitInterpreter(num_wires=4) + def f(): + def cond_fn(i): + return i < 4 + + @qml.while_loop(cond_fn) + def f(i): + qml.X(i) + return i + 1 + + f(0) + return [qml.expval(qml.Z(i)) for i in range(4)] + + output = f() + assert qml.math.allclose(output, [-1, -1, -1, -1]) + + def test_cond_boolean(self): + """Test that cond can be used with normal classical values.""" + + def true_fn(x): + qml.RX(x, 0) + return x + 1 + + def false_fn(x): + return 2 * x + + @DefaultQubitInterpreter(num_wires=1) + def f(x, val): + out = qml.cond(val, true_fn, false_fn)(x) + return qml.probs(wires=0), out + + output_true = f(0.5, True) + expected0 = [jax.numpy.cos(0.5 / 2) ** 2, jax.numpy.sin(0.5 / 2) ** 2] + assert qml.math.allclose(output_true[0], expected0) + assert qml.math.allclose(output_true[1], 1.5) # 0.5 + 1 + + output_false = f(0.5, False) + assert qml.math.allclose(output_false[0], [1.0, 0.0]) + assert qml.math.allclose(output_false[1], 1.0) # 2 * 0.5 + + def test_cond_mcm(self): + """Test that cond can be used with the output of mcms.""" - @DefaultQubitInterpreter(num_wires=2, shots=10) - def sampler(): - qml.X(0) - return qml.sample(wires=(0, 1)) + def true_fn(y): + qml.RX(y, 0) - results = sampler() + # pylint: disable=unused-argument + def false_fn(y): + qml.X(0) - expected0 = jax.numpy.ones((10,)) # zero wire - expected1 = jax.numpy.zeros((10,)) # one wire - expected = jax.numpy.hstack([expected0, expected1]).T + @DefaultQubitInterpreter(num_wires=1, shots=None) + def g(x): + qml.X(0) + m0 = qml.measure(0) + qml.X(0) + qml.cond(m0, true_fn, false_fn)(x) + return qml.probs(wires=0) - assert qml.math.allclose(results, expected) + output = g(0.5) + expected = [jnp.cos(0.5 / 2) ** 2, jnp.sin(0.5 / 2) ** 2] + assert qml.math.allclose(output, expected) From b80d290be23d0ff4388abe1eebb91948104a1c58 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 13 Nov 2024 12:56:50 -0500 Subject: [PATCH 25/45] more tests --- pennylane/devices/qubit/dq_interpreter.py | 4 +- pennylane/measurements/mid_measure.py | 4 +- pennylane/ops/op_math/pow.py | 2 - tests/devices/qubit/test_dq_interpreter.py | 160 +++++++++++++++++++++ 4 files changed, 163 insertions(+), 7 deletions(-) diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index 9e3fd9ed350..c6e90ddeea4 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -95,7 +95,7 @@ def state(self) -> None | jax.numpy.ndarray: @state.setter def state(self, value: jax.numpy.ndarray | None): if self.stateref is None: - raise ValueError("execution not yet initialized.") + raise AttributeError("execution not yet initialized.") self.stateref["state"] = value @property @@ -106,7 +106,7 @@ def key(self) -> jax.numpy.ndarray: @key.setter def key(self, value): if self.stateref is None: - raise ValueError("execution not yet initialized.") + raise AttributeError("execution not yet initialized.") self.stateref["key"] = value def setup(self) -> None: diff --git a/pennylane/measurements/mid_measure.py b/pennylane/measurements/mid_measure.py index afe89e680da..5cdcd8cd708 100644 --- a/pennylane/measurements/mid_measure.py +++ b/pennylane/measurements/mid_measure.py @@ -25,9 +25,7 @@ from .measurements import MeasurementProcess, MidMeasure -def measure( - wires: Union[Hashable, Wires], reset: Optional[bool] = False, postselect: Optional[int] = None -): +def measure(wires: Union[Hashable, Wires], reset: bool = False, postselect: Optional[int] = None): r"""Perform a mid-circuit measurement in the computational basis on the supplied qubit. diff --git a/pennylane/ops/op_math/pow.py b/pennylane/ops/op_math/pow.py index 0f7149211e8..e3ff963971f 100644 --- a/pennylane/ops/op_math/pow.py +++ b/pennylane/ops/op_math/pow.py @@ -393,8 +393,6 @@ def simplify(self) -> Union["Pow", Identity]: op = qml.prod(*ops) if len(ops) > 1 else ops[0] return op if qml.capture.enabled() else op.simplify() except PowUndefinedError: - if qml.capture.enabled(): - return Pow(base.simplify(), z=self.z) return Pow(base=base, z=self.z) diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py index 2262caf8b2b..98b50d441fb 100644 --- a/tests/devices/qubit/test_dq_interpreter.py +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -72,6 +72,23 @@ def test_setup_and_cleanup(): assert dq.stateref is None +def test_working_state_key_before_setup(): + """Test that state and key can't be accessed before setup.""" + + key = jax.random.PRNGKey(9876) + + dq = DefaultQubitInterpreter(num_wires=1, key=key) + + assert dq.state is None + assert dq.key is key + + with pytest.raises(AttributeError, match="execution not yet initialized"): + dq.state = [1.0, 0.0] + + with pytest.raises(AttributeError, match="execution not yet initialized"): + dq.key = jax.random.PRNGKey(8765) + + def test_simple_execution(): """Test the execution, jitting, and gradient of a simple quantum circuit.""" @@ -122,6 +139,34 @@ def g(): assert qml.math.allclose(res["var_X"], 1.0) +def test_mcm_reset(): + """Test that mid circuit measurements can reset the state.""" + + @DefaultQubitInterpreter(num_wires=1) + def f(): + qml.X(0) + qml.measure(0, reset=True) + return qml.state() + + out = f() + assert qml.math.allclose(out, jnp.array([1.0, 0.0])) # reset into zero state. + + +def test_operator_arithmetic(): + """Test that dq can execute operator arithmetic.""" + + @DefaultQubitInterpreter(num_wires=2) + def f(x): + qml.RY(1.0, 0) + qml.adjoint(qml.RY(x, 0)) + _ = qml.SX(1) ** 2 + return qml.expval(qml.Z(0) + 2 * qml.Z(1)) + + output = f(0.5) + expected = jnp.cos(1 - 0.5) - 2 * 1 + assert qml.math.allclose(output, expected) + + class TestSampling: """Test cases for generating samples.""" @@ -220,6 +265,38 @@ def f(): with pytest.raises(NotImplementedError): f() + def test_mcms_not_all_same_key(self): + """Test that each mid circuit measurement has a different key and can have different options.""" + + @DefaultQubitInterpreter(num_wires=1, shots=None, key=jax.random.PRNGKey(87665)) + def g(): + qml.Hadamard(0) + m0 = qml.measure(0, reset=0) + qml.Hadamard(0) + m1 = qml.measure(0, reset=0) + qml.Hadamard(0) + m2 = qml.measure(0, reset=0) + qml.Hadamard(0) + m3 = qml.measure(0, reset=0) + qml.Hadamard(0) + m4 = qml.measure(0, reset=0) + return m0, m1, m2, m3, m4 + + output = g() + assert not all(qml.math.allclose(output[0], output[i]) for i in range(1, 5)) + # only way we could different values for some mcms is if they had different seeds + + def test_each_measurement_has_different_key(self): + """Test that each sampling measurement is performed with a different key.""" + + @DefaultQubitInterpreter(num_wires=1, shots=100, key=jax.random.PRNGKey(87665)) + def g(): + qml.Hadamard(0) + return qml.sample(wires=0), qml.sample(wires=0) + + res1, res2 = g() + assert not qml.math.allclose(res1, res2) + class TestQuantumHOP: """Tests for the quantum higher order primitives: adjoint and ctrl.""" @@ -285,6 +362,23 @@ def f(i, x): assert qml.math.allclose(output[2], jax.numpy.cos(1.2)) assert qml.math.allclose(output[3], jax.numpy.cos(1.3)) + def test_for_loop_consts(self): + """Test that the for_loop can be executed properly when it has closure variables.""" + + @DefaultQubitInterpreter(num_wires=2) + def g(x): + @qml.for_loop(2) + def f(i): + qml.RX(x, i) # x is closure variable + + f() + return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)) + + res1, res2 = g(jax.numpy.array(-0.654)) + expected = jnp.cos(-0.654) + assert qml.math.allclose(res1, expected) + assert qml.math.allclose(res2, expected) + def test_while_loop(self): """Test that the while loop can be executed.""" @@ -304,6 +398,26 @@ def f(i): output = f() assert qml.math.allclose(output, [-1, -1, -1, -1]) + def test_while_loop_with_consts(self): + """Test that both the cond_fn and body_fn can contain constants with the while loop.""" + + @DefaultQubitInterpreter(num_wires=2, shots=None, key=jax.random.PRNGKey(87665)) + def g(x, target): + def cond_fn(i): + return i < target + + @qml.while_loop(cond_fn) + def f(i): + qml.RX(x, 0) + return i + 1 + + f(0) + return qml.expval(qml.Z(0)) + + output = g(jnp.array(1.2), jnp.array(2)) + + assert qml.math.allclose(output, jnp.cos(2 * 1.2)) + def test_cond_boolean(self): """Test that cond can be used with normal classical values.""" @@ -349,3 +463,49 @@ def g(x): output = g(0.5) expected = [jnp.cos(0.5 / 2) ** 2, jnp.sin(0.5 / 2) ** 2] assert qml.math.allclose(output, expected) + + def test_cond_false_no_false_fn(self): + """Test nothing is returned when the false_fn is not provided but the condition is false.""" + + def true_fn(w): + qml.X(w) + + @DefaultQubitInterpreter(num_wires=1) + def g(condition): + qml.cond(condition, true_fn)(0) + return qml.expval(qml.Z(0)) + + out = g(False) + assert qml.math.allclose(out, 1.0) + + def test_condition_with_consts(self): + """Test that each branch in a condition can contain consts.""" + + @DefaultQubitInterpreter(num_wires=1) + def circuit(x, y, z, condition0, condition1): + + def true_fn(): + qml.RX(x, 0) + + def false_fn(): + qml.RX(y, 0) + + def elif_fn(): + qml.RX(z, 0) + + qml.cond(condition0, true_fn, false_fn=false_fn, elifs=((condition1, elif_fn),))() + + return qml.expval(qml.Z(0)) + + x = jax.numpy.array(0.3) + y = jax.numpy.array(0.6) + z = jax.numpy.array(1.2) + + res0 = circuit(x, y, z, True, False) + assert qml.math.allclose(res0, jnp.cos(x)) + + res1 = circuit(x, y, z, False, True) + assert qml.math.allclose(res1, jnp.cos(z)) # elif branch = z + + res2 = circuit(x, y, z, False, False) + assert qml.math.allclose(res2, jnp.cos(y)) # false fn = y From d0298ad22364bbab22488c27fac125d0254f854f Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 15 Nov 2024 16:01:57 -0500 Subject: [PATCH 26/45] Apply suggestions from code review Co-authored-by: David Wierichs Co-authored-by: Pietropaolo Frisoni --- pennylane/devices/qubit/dq_interpreter.py | 15 +++++++-------- tests/devices/qubit/test_dq_interpreter.py | 4 ++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index c6e90ddeea4..de0773f45db 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -38,13 +38,14 @@ class DefaultQubitInterpreter(PlxprInterpreter): - """Implements a class for interpreting plxpr using default qubit. + """Implements a class for interpreting plxpr using python simulation tools. Args: - num_wires (int): the numberof wires to initialize the state with - shots (int | None): the number of shots to use for the execution + num_wires (int): the number of wires to initialize the state with + shots (int | None): the number of shots to use for the execution. Shot vectors are not supported yet. key (None, jax.numpy.ndarray): the ``PRNGKey`` to use for random number generation. - +>>> from pennylane.devices.qubit.dq_interpreter import DefaultQubitInterpreter +>>> qml.capture.enable() >>> key = jax.random.PRNGKey(1234) >>> dq = DefaultQubitInterpreter(num_wires=2, shots=None, key=key) >>> @qml.for_loop(2) @@ -61,14 +62,12 @@ class DefaultQubitInterpreter(PlxprInterpreter): Array(0.54030231, dtype=float64) This execution can be differentiated via backprop and jitted as normal. Note that finite shot executions - still cannot be differented with backprop. + still cannot be differentiated with backprop. >>> jax.grad(dq(f))(jax.numpy.array(0.5)) Array(-1.68294197, dtype=float64, weak_type=True) >>> jax.jit(dq(f))(jax.numpy.array(0.5)) Array(0.54030231, dtype=float64) - - """ def __init__( @@ -89,7 +88,7 @@ def __init__( @property def state(self) -> None | jax.numpy.ndarray: - """The current state of the system. None if not initialized.""" + """The current state of the system. None if not initialized.""" return self.stateref["state"] if self.stateref else None @state.setter diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py index 98b50d441fb..147df59dffd 100644 --- a/tests/devices/qubit/test_dq_interpreter.py +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -266,7 +266,7 @@ def f(): f() def test_mcms_not_all_same_key(self): - """Test that each mid circuit measurement has a different key and can have different options.""" + """Test that each mid circuit measurement has a different key.""" @DefaultQubitInterpreter(num_wires=1, shots=None, key=jax.random.PRNGKey(87665)) def g(): @@ -284,7 +284,7 @@ def g(): output = g() assert not all(qml.math.allclose(output[0], output[i]) for i in range(1, 5)) - # only way we could different values for some mcms is if they had different seeds + # only way we could get different values between the mcms is if they had different seeds def test_each_measurement_has_different_key(self): """Test that each sampling measurement is performed with a different key.""" From 07fae7bba71a08d4cde37d0b5bf0ac00be00bcd3 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 18 Nov 2024 09:58:23 -0500 Subject: [PATCH 27/45] update initial key each execution --- pennylane/devices/qubit/dq_interpreter.py | 9 ++++++--- tests/devices/qubit/test_dq_interpreter.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index de0773f45db..7299ad4faa8 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -44,8 +44,11 @@ class DefaultQubitInterpreter(PlxprInterpreter): num_wires (int): the number of wires to initialize the state with shots (int | None): the number of shots to use for the execution. Shot vectors are not supported yet. key (None, jax.numpy.ndarray): the ``PRNGKey`` to use for random number generation. ->>> from pennylane.devices.qubit.dq_interpreter import DefaultQubitInterpreter ->>> qml.capture.enable() + + + >>> from pennylane.devices.qubit.dq_interpreter import DefaultQubitInterpreter + >>> qml.capture.enable() + >>> import jax >>> key = jax.random.PRNGKey(1234) >>> dq = DefaultQubitInterpreter(num_wires=2, shots=None, key=key) >>> @qml.for_loop(2) @@ -117,8 +120,8 @@ def setup(self) -> None: # else set by copying a parent interpreter and we need to modify same stateref def cleanup(self) -> None: + self.initial_key = self.key # be cautious of leaked tracers, but we should be fine. self.stateref = None - # Open question: should we update initial key? def interpret_operation(self, op): self.state = apply_operation(op, self.state) diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py index 147df59dffd..bc857b58975 100644 --- a/tests/devices/qubit/test_dq_interpreter.py +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -297,6 +297,18 @@ def g(): res1, res2 = g() assert not qml.math.allclose(res1, res2) + def test_more_executions_same_interpreter_different_results(self): + """Test that if multiple executions occur with the same interpreter, they will have different results.""" + + @DefaultQubitInterpreter(num_wires=1, shots=100, key=jax.random.PRNGKey(76543)) + def f(): + qml.Hadamard(0) + return qml.sample(0) + + s1 = f() + s2 = f() # should be done with different key, leading to different results. + assert not qml.math.allclose(s1, s2) + class TestQuantumHOP: """Tests for the quantum higher order primitives: adjoint and ctrl.""" From e076a4810836a4b2f9eb3f0a631ae582784d7bd0 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 18 Nov 2024 10:00:31 -0500 Subject: [PATCH 28/45] changelog --- doc/releases/changelog-dev.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index faa9cf7e27d..f2b053c19b3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -49,8 +49,11 @@ pennylane variant jaxpr. [(#6141)](https://github.com/PennyLaneAI/pennylane/pull/6141) +* A `DefaultQubitInterpreter` class has been added to provide plxpr execution using python based tools. + [(#6328)](https://github.com/PennyLaneAI/pennylane/pull/6328) + * An optional method `eval_jaxpr` is added to the device API for native execution of plxpr programs. -[(#6580)](https://github.com/PennyLaneAI/pennylane/pull/6580) + [(#6580)](https://github.com/PennyLaneAI/pennylane/pull/6580)

Other Improvements

From 53dfac0e3033f72b96c7293947fa4861e41c04d5 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 18 Nov 2024 10:57:10 -0500 Subject: [PATCH 29/45] add eval_jaxpr method to DefaultQubit --- doc/releases/changelog-dev.md | 3 +- pennylane/devices/default_qubit.py | 18 +++- .../default_qubit/test_default_qubit_plxpr.py | 84 +++++++++++++++++++ 3 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 tests/devices/default_qubit/test_default_qubit_plxpr.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 74cf25d31ae..9197c537bca 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -49,7 +49,8 @@ pennylane variant jaxpr. [(#6141)](https://github.com/PennyLaneAI/pennylane/pull/6141) -* A `DefaultQubitInterpreter` class has been added to provide plxpr execution using python based tools. +* A `DefaultQubitInterpreter` class has been added to provide plxpr execution using python based tools, + and the `DefaultQubit.eval_jaxpr` method is now implemented. [(#6328)](https://github.com/PennyLaneAI/pennylane/pull/6328) * An optional method `eval_jaxpr` is added to the device API for native execution of plxpr programs. diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 4d2f4d5a76e..54298ecf628 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -32,7 +32,7 @@ from pennylane.tape import QuantumScript, QuantumScriptBatch, QuantumScriptOrBatch from pennylane.transforms import convert_to_numpy_parameters from pennylane.transforms.core import TransformProgram -from pennylane.typing import PostprocessingFn, Result, ResultBatch +from pennylane.typing import PostprocessingFn, Result, ResultBatch, TensorLike from . import Device from .execution_config import DefaultExecutionConfig, ExecutionConfig @@ -891,6 +891,22 @@ def execute_and_compute_vjp( return tuple(zip(*results)) + # pylint: disable=import-outside-toplevel + def eval_jaxpr( + self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args + ) -> list[TensorLike]: + from .qubit.dq_interpreter import DefaultQubitInterpreter + + if self.wires is None: + raise qml.DeviceError("Device wires are required for jaxpr execution.") + if self.shots.has_partitioned_shots: + raise qml.DeviceError("Shot vectors are unsupported with jaxpr execution.") + key = self.get_prng_keys() + interpreter = DefaultQubitInterpreter( + num_wires=len(self.wires), shots=self.shots.total_shots, key=key + ) + return interpreter.eval(jaxpr, consts, *args) + def _simulate_wrapper(circuit, kwargs): return simulate(circuit, **kwargs) diff --git a/tests/devices/default_qubit/test_default_qubit_plxpr.py b/tests/devices/default_qubit/test_default_qubit_plxpr.py new file mode 100644 index 00000000000..263f106b8cf --- /dev/null +++ b/tests/devices/default_qubit/test_default_qubit_plxpr.py @@ -0,0 +1,84 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for default qubit executing jaxpr..""" + +import pytest + +import pennylane as qml + +jax = pytest.importorskip("jax") +pytestmark = pytest.mark.jax + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + qml.capture.enable() + yield + qml.capture.disable() + + +def test_requires_wires(): + """Test that a device error is raised if device wires are not specified.""" + + jaxpr = jax.make_jaxpr(lambda x: x + 1)(0.1) + dev = qml.device("default.qubit") + + with pytest.raises(qml.DeviceError, match="Device wires are required."): + dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.2) + + +def test_no_partitioned_shots(): + """Test that an error is raised if the device has partitioned shots.""" + + jaxpr = jax.make_jaxpr(lambda x: x + 1)(0.1) + dev = qml.device("default.qubit", wires=1, shots=(100, 100)) + + with pytest.raises(qml.DeviceError, match="Shot vectors are unsupported with jaxpr execution."): + dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.cosnts, 0.2) + + +def test_use_device_prng(): + """Test that sampling depends on the device prng.""" + + key1 = jax.random.PRNGKey(1234) + key2 = jax.random.PRNGKey(1234) + + dev1 = qml.device("default.qubit", wires=1, shots=100, seed=key1) + dev2 = qml.device("default.qubit", wires=1, shots=100, seed=key2) + + def f(): + qml.H(0) + return qml.sample(wires=0) + + jaxpr = jax.make_jaxpr(f)() + + samples1 = dev1.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + samples2 = dev2.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + + assert qml.math.allclose(samples1, samples2) + + +def test_simple_execution(): + """Test the execution, jitting, and gradient of a simple quantum circuit.""" + + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)(0.123) + + dev = qml.device("default.qubit", wires=1) + + res = dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5) + assert qml.math.allclose(res, jax.numpy.cos(0.5)) From cf22d059ecda4baa9cfccb61fb2c67360ecbbef4 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Mon, 18 Nov 2024 11:00:38 -0500 Subject: [PATCH 30/45] Update doc/releases/changelog-dev.md --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 9197c537bca..0b1b3586db3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -51,6 +51,7 @@ * A `DefaultQubitInterpreter` class has been added to provide plxpr execution using python based tools, and the `DefaultQubit.eval_jaxpr` method is now implemented. + [(#6594)](https://github.com/PennyLaneAI/pennylane/pull/6594) [(#6328)](https://github.com/PennyLaneAI/pennylane/pull/6328) * An optional method `eval_jaxpr` is added to the device API for native execution of plxpr programs. From 14a8cd562e8e2851d442b6ad0c98662533847782 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 18 Nov 2024 12:25:02 -0500 Subject: [PATCH 31/45] qnode natively executes jaxpr on device --- pennylane/capture/capture_qnode.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pennylane/capture/capture_qnode.py b/pennylane/capture/capture_qnode.py index de19f464701..fb4897cb8a6 100644 --- a/pennylane/capture/capture_qnode.py +++ b/pennylane/capture/capture_qnode.py @@ -105,23 +105,20 @@ def _get_qnode_prim(): qnode_prim = jax.core.Primitive("qnode") qnode_prim.multiple_results = True - # pylint: disable=too-many-arguments + # pylint: disable=too-many-arguments, unused-argument @qnode_prim.def_impl def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): consts = args[:n_consts] non_const_args = args[n_consts:] - def qfunc(*inner_args): - return jax.core.eval_jaxpr(qfunc_jaxpr, consts, *inner_args) - - qnode = qml.QNode(qfunc, device, **qnode_kwargs) - if batch_dims is not None: # pylint: disable=protected-access - return jax.vmap(partial(qnode._impl_call, shots=shots), batch_dims)(*non_const_args) + return jax.vmap(partial(device.eval_jaxpr, qfunc_jaxpr, consts), batch_dims)( + *non_const_args + ) # pylint: disable=protected-access - return qnode._impl_call(*non_const_args, shots=shots) + return device.eval_jaxpr(qfunc_jaxpr, consts, *non_const_args) # pylint: disable=unused-argument @qnode_prim.def_abstract_eval From a5e75e4ce825645762e0169820ec4626e8f1d294 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 18 Nov 2024 16:31:14 -0500 Subject: [PATCH 32/45] no seed support --- pennylane/devices/default_qubit.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 54298ecf628..c8b6ff997dc 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -901,7 +901,13 @@ def eval_jaxpr( raise qml.DeviceError("Device wires are required for jaxpr execution.") if self.shots.has_partitioned_shots: raise qml.DeviceError("Shot vectors are unsupported with jaxpr execution.") - key = self.get_prng_keys() + if self._prng_key is not None: + key = self.get_prng_keys()[0] + else: + import jax + + key = jax.random.PRNGKey(self._rng.integers(100000)) + interpreter = DefaultQubitInterpreter( num_wires=len(self.wires), shots=self.shots.total_shots, key=key ) From 77844647bbadab0a33a91207a89b51c2aeae2bf3 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 21 Nov 2024 16:52:46 -0500 Subject: [PATCH 33/45] fixing up tests --- pennylane/devices/qubit/dq_interpreter.py | 48 +++++++++++------------ pennylane/workflow/_capture_qnode.py | 20 +++++----- tests/capture/test_capture_cond.py | 2 + tests/capture/test_capture_diff.py | 18 +++------ tests/capture/test_capture_mid_measure.py | 13 +++--- tests/capture/test_capture_qnode.py | 18 +++++---- tests/pytest.ini | 2 +- 7 files changed, 62 insertions(+), 59 deletions(-) diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index 7299ad4faa8..598153ce87f 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -89,33 +89,27 @@ def __init__( self.stateref = None super().__init__() - @property - def state(self) -> None | jax.numpy.ndarray: - """The current state of the system. None if not initialized.""" - return self.stateref["state"] if self.stateref else None - - @state.setter - def state(self, value: jax.numpy.ndarray | None): - if self.stateref is None: - raise AttributeError("execution not yet initialized.") - self.stateref["state"] = value - - @property - def key(self) -> jax.numpy.ndarray: - """A jax PRNGKey. ``initial_key`` if not yet initialized.""" - return self.stateref["key"] if self.stateref else self.initial_key - - @key.setter - def key(self, value): - if self.stateref is None: - raise AttributeError("execution not yet initialized.") - self.stateref["key"] = value + def __getattr__(self, key): + if key in {"state", "key", "is_state_batched"}: + if self.stateref is None: + raise AttributeError("execution not yet initialized.") + return self.stateref[key] + return super().__getattr__(key) + + def __setattr__(self, __name: str, __value) -> None: + if __name in {"state", "key", "is_state_batched"}: + if self.stateref is None: + raise AttributeError("execution not yet initialized") + self.stateref[__name] = __value + else: + super().__setattr__(__name, __value) def setup(self) -> None: if self.stateref is None: self.stateref = { "state": create_initial_state(range(self.num_wires), like="jax"), "key": self.initial_key, + "is_state_batched": False, } # else set by copying a parent interpreter and we need to modify same stateref @@ -124,7 +118,9 @@ def cleanup(self) -> None: self.stateref = None def interpret_operation(self, op): - self.state = apply_operation(op, self.state) + self.state = apply_operation(op, self.state, is_state_batched=self.is_state_batched) + if op.batch_size: + self.is_state_batched = True def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"): if "mcm" in eqn.primitive.name: @@ -142,10 +138,14 @@ def interpret_measurement(self, measurement): # note that this does *not* group commuting measurements # further work could figure out how to perform multiple measurements at the same time output = measure_with_samples( - [measurement], self.state, shots=self.shots, prng_key=new_key + [measurement], + self.state, + shots=self.shots, + prng_key=new_key, + is_state_batched=self.is_state_batched, )[0] else: - output = measure(measurement, self.state) + output = measure(measurement, self.state, is_state_batched=self.is_state_batched) finally: enable() return output diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 1594e400f2a..3730d33e53e 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -100,18 +100,19 @@ def _get_qnode_prim(): # pylint: disable=too-many-arguments, unused-argument @qnode_prim.def_impl def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): + if shots != device.shots: + raise NotImplementedError( + "override shots are not yet supported with the program capture execution." + ) + consts = args[:n_consts] non_const_args = args[n_consts:] - if batch_dims is not None: - - # pylint: disable=protected-access - return jax.vmap(partial(device.eval_jaxpr, qfunc_jaxpr, consts), batch_dims[n_consts:])( - *jax.tree_util.tree_leaves(non_const_args) - ) - - # pylint: disable=protected-access - return device.eval_jaxpr(qfunc_jaxpr, consts, *non_const_args) + if batch_dims is None: + return device.eval_jaxpr(qfunc_jaxpr, consts, *non_const_args) + return jax.vmap(partial(device.eval_jaxpr, qfunc_jaxpr, consts), batch_dims[n_consts:])( + *jax.tree_util.tree_leaves(non_const_args) + ) # pylint: disable=unused-argument @qnode_prim.def_abstract_eval @@ -274,6 +275,7 @@ def f(x): shots = qml.measurements.Shots(kwargs.pop("shots")) else: shots = qnode.device.shots + if shots.has_partitioned_shots: # Questions over the pytrees and the nested result object shape raise NotImplementedError("shot vectors are not yet supported with plxpr capture.") diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 03d89570131..6ff88417315 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -645,6 +645,7 @@ def test_circuit_consts(self, pred, arg, expected): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.xfail # currently using single branch statistics @pytest.mark.local_salt(1) @pytest.mark.parametrize("reset", [True, False]) @pytest.mark.parametrize("postselect", [None, 0, 1]) @@ -675,6 +676,7 @@ def f(x, y): assert np.allclose(res, expected), f"Expected {expected}, but got {res}" + @pytest.mark.xfail # currently using single branch statistics @pytest.mark.parametrize("shots", [None, 300]) @pytest.mark.parametrize( "params, expected", diff --git a/tests/capture/test_capture_diff.py b/tests/capture/test_capture_diff.py index c6a210c7d8f..557fc82a322 100644 --- a/tests/capture/test_capture_diff.py +++ b/tests/capture/test_capture_diff.py @@ -189,7 +189,9 @@ def func(x): jax.config.update("jax_enable_x64", initial_mode) - @pytest.mark.parametrize("diff_method", ("backprop", "parameter-shift")) + @pytest.mark.parametrize( + "diff_method", ("backprop", pytest.param("parameter-shift", marks=pytest.mark.xfail)) + ) def test_grad_of_simple_qnode(self, x64_mode, diff_method, mocker): """Test capturing the gradient of a simple qnode.""" # pylint: disable=protected-access @@ -243,12 +245,7 @@ def circuit(x): assert len(grad_eqn.outvars) == 1 assert grad_eqn.outvars[0].aval == jax.core.ShapedArray((2,), fdtype) - spy = mocker.spy(qml.gradients.parameter_shift, "expval_param_shift") manual_res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) - if diff_method == "parameter-shift": - spy.assert_called_once() - else: - spy.assert_not_called() assert qml.math.allclose(manual_res, expected_res) jax.config.update("jax_enable_x64", initial_mode) @@ -506,7 +503,9 @@ def func(x): jax.config.update("jax_enable_x64", initial_mode) - @pytest.mark.parametrize("diff_method", ("backprop", "parameter-shift")) + @pytest.mark.parametrize( + "diff_method", ("backprop", pytest.param("parameter-shift", marks=pytest.mark.xfail)) + ) def test_jacobian_of_simple_qnode(self, x64_mode, diff_method, mocker): """Test capturing the gradient of a simple qnode.""" # pylint: disable=protected-access @@ -560,12 +559,7 @@ def circuit(x): assert [outvar.aval for outvar in jac_eqn.outvars] == jaxpr.out_avals - spy = mocker.spy(qml.gradients.parameter_shift, "expval_param_shift") manual_res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) - if diff_method == "parameter-shift": - spy.assert_called_once() - else: - spy.assert_not_called() assert _jac_allclose(manual_res, expected_res, 1) jax.config.update("jax_enable_x64", initial_mode) diff --git a/tests/capture/test_capture_mid_measure.py b/tests/capture/test_capture_mid_measure.py index fdd69d6527c..914006e844b 100644 --- a/tests/capture/test_capture_mid_measure.py +++ b/tests/capture/test_capture_mid_measure.py @@ -293,6 +293,7 @@ def compare_with_capture_disabled(qnode, *args, **kwargs): res = qnode(*args, **kwargs) qml.capture.disable() expected = qnode(*args, **kwargs) + print(res, expected) return jnp.allclose(res, expected) @@ -312,8 +313,9 @@ class TestMidMeasureExecute: """System-level tests for executing circuits with mid-circuit measurements with program capture enabled.""" + @pytest.mark.skip("flaky failures due to single branch statistics") @pytest.mark.parametrize("reset", [True, False]) - @pytest.mark.parametrize("postselect", [None, 0, 1]) + @pytest.mark.parametrize("postselect", [pytest.param(None, marks=pytest.mark.xfail), 0, 1]) @pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5)) def test_simple_circuit_execution(self, phi, reset, postselect, get_device, shots, mp_fn, seed): """Test that circuits with mid-circuit measurements can be executed in a QNode.""" @@ -322,7 +324,7 @@ def test_simple_circuit_execution(self, phi, reset, postselect, get_device, shot dev = get_device(wires=2, shots=shots, seed=jax.random.PRNGKey(seed)) - @qml.qnode(dev) + @qml.qnode(dev, postselect_mode="fill-shots") def f(x): qml.RX(x, 0) qml.measure(0, reset=reset, postselect=postselect) @@ -330,6 +332,7 @@ def f(x): assert compare_with_capture_disabled(f, phi) + @pytest.mark.xfail # not yet supported @pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5)) @pytest.mark.parametrize("multi_mcm", [True, False]) def test_circuit_with_terminal_measurement_execution( @@ -377,7 +380,6 @@ def f(x, y): assert compare_with_capture_disabled(f, phi, phi + 1.5) - @pytest.mark.xfail @pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5)) def test_circuit_with_classical_processing_execution(self, phi, get_device, shots, mp_fn, seed): """Test that circuits that apply non-boolean operations to mid-circuit measurement @@ -397,7 +399,7 @@ def f(x, y): _ = a ** (m2 / 5) return mp_fn(op=qml.Z(0)) - assert f(phi, phi + 1.5) + _ = f(phi, phi + 1.5) @pytest.mark.xfail @pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5)) @@ -421,7 +423,6 @@ def f(x): assert f(phi) - @pytest.mark.xfail @pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5)) def test_mid_measure_as_gate_parameter_execution(self, phi, get_device, shots, mp_fn, seed): """Test that mid-circuit measurements (simple or classical processed) used as gate @@ -438,4 +439,4 @@ def f(x): qml.RX(m, 0) return mp_fn(op=qml.Z(0)) - assert f(phi) + _ = f(phi) diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 41b1854956f..5178b4b8524 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -65,8 +65,9 @@ def circuit(): circuit() jax.make_jaxpr(partial(circuit, shots=50))() # should run fine - res = circuit(shots=50) - assert qml.math.allclose(res, jax.numpy.zeros((50,))) + with pytest.raises(NotImplementedError, match="override shots are not yet supported"): + res = circuit(shots=50) + assert qml.math.allclose(res, jax.numpy.zeros((50,))) def test_error_if_overridden_shot_vector(): @@ -178,8 +179,9 @@ def circuit(): (50,), jax.numpy.int64 if x64_mode else jax.numpy.int32 ) - res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) - assert qml.math.allclose(res, jax.numpy.zeros((50,))) + with pytest.raises(NotImplementedError): + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) + assert qml.math.allclose(res, jax.numpy.zeros((50,))) jax.config.update("jax_enable_x64", initial_mode) @@ -536,7 +538,8 @@ def circuit(x): x = jax.numpy.array([1.0, 2.0, 3.0]) jaxpr = jax.make_jaxpr(jax.vmap(partial(circuit, shots=50), in_axes=0))(x) - res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + with pytest.raises(NotImplementedError, match="override shots are not yet supported"): + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) assert len(jaxpr.eqns) == 1 eqn0 = jaxpr.eqns[0] @@ -551,8 +554,9 @@ def circuit(x): assert eqn0.outvars[0].aval.shape == (3, 50) - res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) - assert qml.math.allclose(res, jax.numpy.zeros((3, 50))) + with pytest.raises(NotImplementedError, match="override shots are not yet supported"): + res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) + assert qml.math.allclose(res, jax.numpy.zeros((3, 50))) def test_vmap_error_indexing(self): """Test that an IndexError is raised when indexing a batched parameter.""" diff --git a/tests/pytest.ini b/tests/pytest.ini index 27b4c29f17e..036b5196116 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -33,5 +33,5 @@ filterwarnings = ignore:PauliSentence.hamiltonian:pennylane.PennyLaneDeprecationWarning ignore:PauliWord.hamiltonian:pennylane.PennyLaneDeprecationWarning addopts = --benchmark-disable - +xfail_strict=true rng_salt = v0.39.0 From f5a7adbfc2a6eb8ded69437a9994fc40eb682477 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 27 Nov 2024 09:31:55 -0500 Subject: [PATCH 34/45] xfailing more tests --- pennylane/devices/qubit/dq_interpreter.py | 5 ++++- pennylane/workflow/_capture_qnode.py | 4 ++++ tests/capture/test_base_interpreter.py | 4 ++-- tests/capture/test_capture_cond.py | 7 +++++-- tests/capture/test_capture_diff.py | 6 +++--- tests/capture/test_capture_while_loop.py | 1 + tests/capture/test_nested_plxpr.py | 2 ++ 7 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pennylane/devices/qubit/dq_interpreter.py b/pennylane/devices/qubit/dq_interpreter.py index 598153ce87f..b8b0fb6a6b5 100644 --- a/pennylane/devices/qubit/dq_interpreter.py +++ b/pennylane/devices/qubit/dq_interpreter.py @@ -37,6 +37,7 @@ from .sampling import measure_with_samples +# pylint: disable=attribute-defined-outside-init, access-member-before-definition class DefaultQubitInterpreter(PlxprInterpreter): """Implements a class for interpreting plxpr using python simulation tools. @@ -94,7 +95,7 @@ def __getattr__(self, key): if self.stateref is None: raise AttributeError("execution not yet initialized.") return self.stateref[key] - return super().__getattr__(key) + raise AttributeError(f"No attribute {key}") def __setattr__(self, __name: str, __value) -> None: if __name in {"state", "key", "is_state_batched"}: @@ -157,6 +158,8 @@ def _(self, *invals, reset, postselect): self.key, new_key = jax.random.split(self.key, 2) mcms = {} self.state = apply_operation(mp, self.state, mid_measurements=mcms, prng_key=new_key) + if mp.postselect is not None: + self.state = self.state / (1 - jax.numpy.abs(mp.postselect - mcms[mp])) return mcms[mp] diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 3730d33e53e..a9f33474408 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -104,6 +104,10 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_di raise NotImplementedError( "override shots are not yet supported with the program capture execution." ) + if qnode_kwargs["diff_method"] not in {"backprop", "best"}: + raise NotImplementedError( + "only backpropagation derivatives are supported at this time." + ) consts = args[:n_consts] non_const_args = args[n_consts:] diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index 4572b78eb01..ed1e5bd2dad 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -461,7 +461,7 @@ def interpret_operation(self, op): dev = qml.device("default.qubit", wires=1) @AddNoise() - @qml.qnode(dev, diff_method="adjoint", grad_on_execution=False) + @qml.qnode(dev, diff_method="backprop", grad_on_execution=False) def f(): qml.I(0) qml.I(0) @@ -477,7 +477,7 @@ def f(): assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive assert inner_jaxpr.eqns[3].primitive == qml.RX._primitive - assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "adjoint" + assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "backprop" assert jaxpr.eqns[0].params["qnode_kwargs"]["grad_on_execution"] is False assert jaxpr.eqns[0].params["device"] == dev diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 6ff88417315..57f50b12bbe 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -645,7 +645,7 @@ def test_circuit_consts(self, pred, arg, expected): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" - @pytest.mark.xfail # currently using single branch statistics + @pytest.mark.xfail(strict=False) # might pass if postselection equal to measurement @pytest.mark.local_salt(1) @pytest.mark.parametrize("reset", [True, False]) @pytest.mark.parametrize("postselect", [None, 0, 1]) @@ -676,7 +676,9 @@ def f(x, y): assert np.allclose(res, expected), f"Expected {expected}, but got {res}" - @pytest.mark.xfail # currently using single branch statistics + @pytest.mark.xfail( + strict=False + ) # currently using single branch statistics, sometimes gives good results @pytest.mark.parametrize("shots", [None, 300]) @pytest.mark.parametrize( "params, expected", @@ -735,6 +737,7 @@ def f(*x): assert np.allclose(res, expected, atol=atol, rtol=0), f"Expected {expected}, but got {res}" + @pytest.mark.xfail(strict=False) @pytest.mark.parametrize("upper_bound, arg", [(3, [0.1, 0.3, 0.5]), (2, [2, 7, 12])]) def test_nested_cond_for_while_loop(self, upper_bound, arg): """Test that a nested control flows are correctly captured into a jaxpr.""" diff --git a/tests/capture/test_capture_diff.py b/tests/capture/test_capture_diff.py index 557fc82a322..8dfaa8055ee 100644 --- a/tests/capture/test_capture_diff.py +++ b/tests/capture/test_capture_diff.py @@ -192,7 +192,7 @@ def func(x): @pytest.mark.parametrize( "diff_method", ("backprop", pytest.param("parameter-shift", marks=pytest.mark.xfail)) ) - def test_grad_of_simple_qnode(self, x64_mode, diff_method, mocker): + def test_grad_of_simple_qnode(self, x64_mode, diff_method): """Test capturing the gradient of a simple qnode.""" # pylint: disable=protected-access initial_mode = jax.config.jax_enable_x64 @@ -308,7 +308,7 @@ def test_grad_qnode_with_pytrees(self, argnum, x64_mode): dev = qml.device("default.qubit", wires=2) - @qml.qnode(dev) + @qml.qnode(dev, diff_method="backprop") def circuit(x, y, z): qml.RX(x["a"], wires=0) qml.RY(y, wires=0) @@ -506,7 +506,7 @@ def func(x): @pytest.mark.parametrize( "diff_method", ("backprop", pytest.param("parameter-shift", marks=pytest.mark.xfail)) ) - def test_jacobian_of_simple_qnode(self, x64_mode, diff_method, mocker): + def test_jacobian_of_simple_qnode(self, x64_mode, diff_method): """Test capturing the gradient of a simple qnode.""" # pylint: disable=protected-access initial_mode = jax.config.jax_enable_x64 diff --git a/tests/capture/test_capture_while_loop.py b/tests/capture/test_capture_while_loop.py index 1d45a104d44..ac0e4d00d1e 100644 --- a/tests/capture/test_capture_while_loop.py +++ b/tests/capture/test_capture_while_loop.py @@ -211,6 +211,7 @@ def inner(j): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.xfail(strict=False) # mcms only sometimes give the right answer @pytest.mark.parametrize("upper_bound, arg", [(3, 0.5), (2, 12)]) def test_while_and_for_loop_nested(self, upper_bound, arg): """Test that a nested while and for loop is correctly captured into a jaxpr.""" diff --git a/tests/capture/test_nested_plxpr.py b/tests/capture/test_nested_plxpr.py index 8f4715a10a5..031ae5c2328 100644 --- a/tests/capture/test_nested_plxpr.py +++ b/tests/capture/test_nested_plxpr.py @@ -154,6 +154,7 @@ def qfunc(wire): # x is closure variable and a tracer assert len(q) == 1 qml.assert_equal(q.queue[0], qml.adjoint(qml.RX(2.5, 2))) + @pytest.mark.xfail(raises=NotImplementedError) def test_adjoint_grad(self): """Test that adjoint differentiated with grad can be captured.""" from pennylane.capture.primitives import grad_prim, qnode_prim @@ -328,6 +329,7 @@ def workflow(wire): assert len(eqn.params["jaxpr"].eqns) == 5 + include_s + @pytest.mark.xfail(raises=NotImplementedError) def test_ctrl_grad(self): """Test that ctrl differentiated with grad can be captured.""" from pennylane.capture.primitives import grad_prim, qnode_prim From d1751497e72a57e843544530769b1c49ded0c470 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 28 Nov 2024 11:16:26 -0500 Subject: [PATCH 35/45] fix failing test --- doc/releases/changelog-dev.md | 5 +++++ tests/devices/qubit/test_dq_interpreter.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6f369935041..44bec808024 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -91,6 +91,11 @@

Capturing and representing hybrid programs

+* Execution with capture enabled now follows a new execution pipeline and natively passes the + captured jaxpr to the device. Since it no longer falls back to the old pipeline, execution + only works with a reduced feature set. + [(#6496)](https://github.com/PennyLaneAI/pennylane/pull/6596) + * `jax.vmap` can be captured with `qml.capture.make_plxpr` and is compatible with quantum circuits. [(#6349)](https://github.com/PennyLaneAI/pennylane/pull/6349) [(#6422)](https://github.com/PennyLaneAI/pennylane/pull/6422) diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py index 220e3a63690..0fe1729dddf 100644 --- a/tests/devices/qubit/test_dq_interpreter.py +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -59,12 +59,14 @@ def test_setup_and_cleanup(): dq.setup() assert isinstance(dq.stateref, dict) - assert list(dq.stateref.keys()) == ["state", "key"] + assert list(dq.stateref.keys()) == ["state", "key", "is_state_batched"] assert dq.stateref["key"] is key assert dq.key is key assert dq.state is dq.stateref["state"] + assert dq.is_state_batched is False + assert dq.stateref["is_state_batched"] is False expected = jax.numpy.array([[1.0, 0.0], [0.0, 0.0]], dtype=complex) assert qml.math.allclose(dq.state, expected) From 663b69a944b6a178a55f95d385ab21f5ab921021 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 28 Nov 2024 12:50:30 -0500 Subject: [PATCH 36/45] add test --- tests/devices/qubit/test_dq_interpreter.py | 37 ++++++++++++++-------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py index 0fe1729dddf..3dcadc4375f 100644 --- a/tests/devices/qubit/test_dq_interpreter.py +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -81,9 +81,6 @@ def test_working_state_key_before_setup(): dq = DefaultQubitInterpreter(num_wires=1, key=key) - assert dq.state is None - assert dq.key is key - with pytest.raises(AttributeError, match="execution not yet initialized"): dq.state = [1.0, 0.0] @@ -169,13 +166,27 @@ def f(x): assert qml.math.allclose(output, expected) +def test_parameter_broadcasting(): + """Test that dq can execute a circuit with parameter broadcasting.""" + + @DefaultQubitInterpreter(num_wires=3) + def f(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + x = jax.numpy.array([1.2, 2.3, 3.4]) + output = f(x) + expected = jax.numpy.cos(x) + assert qml.math.allclose(output, expected) + + class TestSampling: """Test cases for generating samples.""" - def test_known_sampling(self): + def test_known_sampling(self, seed): """Test sampling output with deterministic sampling output""" - @DefaultQubitInterpreter(num_wires=2, shots=10) + @DefaultQubitInterpreter(num_wires=2, shots=10, key=jax.random.PRNGKey(seed)) def sampler(): qml.X(0) return qml.sample(wires=(0, 1)) @@ -188,9 +199,9 @@ def sampler(): assert qml.math.allclose(results, expected) - def test_same_key_same_results(self): + def test_same_key_same_results(self, seed): """Test that two circuits with the same key give identical results.""" - key = jax.random.PRNGKey(1234) + key = jax.random.PRNGKey(seed) @DefaultQubitInterpreter(num_wires=1, shots=100, key=key) def circuit1(): @@ -270,10 +281,10 @@ def f(): with pytest.raises(NotImplementedError): f() - def test_mcms_not_all_same_key(self): + def test_mcms_not_all_same_key(self, seed): """Test that each mid circuit measurement has a different key.""" - @DefaultQubitInterpreter(num_wires=1, shots=None, key=jax.random.PRNGKey(87665)) + @DefaultQubitInterpreter(num_wires=1, shots=None, key=jax.random.PRNGKey(seed)) def g(): qml.Hadamard(0) m0 = qml.measure(0, reset=0) @@ -291,10 +302,10 @@ def g(): assert not all(qml.math.allclose(output[0], output[i]) for i in range(1, 5)) # only way we could get different values between the mcms is if they had different seeds - def test_each_measurement_has_different_key(self): + def test_each_measurement_has_different_key(self, seed): """Test that each sampling measurement is performed with a different key.""" - @DefaultQubitInterpreter(num_wires=1, shots=100, key=jax.random.PRNGKey(87665)) + @DefaultQubitInterpreter(num_wires=1, shots=100, key=jax.random.PRNGKey(seed)) def g(): qml.Hadamard(0) return qml.sample(wires=0), qml.sample(wires=0) @@ -302,10 +313,10 @@ def g(): res1, res2 = g() assert not qml.math.allclose(res1, res2) - def test_more_executions_same_interpreter_different_results(self): + def test_more_executions_same_interpreter_different_results(self, seed): """Test that if multiple executions occur with the same interpreter, they will have different results.""" - @DefaultQubitInterpreter(num_wires=1, shots=100, key=jax.random.PRNGKey(76543)) + @DefaultQubitInterpreter(num_wires=1, shots=100, key=jax.random.PRNGKey(seed)) def f(): qml.Hadamard(0) return qml.sample(wires=0) From a35a6974eb650607fcb5a5d7c236a7fdcdb95d4c Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 28 Nov 2024 16:15:14 -0500 Subject: [PATCH 37/45] black, test, pylint --- tests/devices/default_qubit/test_default_qubit_plxpr.py | 3 ++- tests/devices/qubit/test_dq_interpreter.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/devices/default_qubit/test_default_qubit_plxpr.py b/tests/devices/default_qubit/test_default_qubit_plxpr.py index c44d83b3fc0..23628177d9c 100644 --- a/tests/devices/default_qubit/test_default_qubit_plxpr.py +++ b/tests/devices/default_qubit/test_default_qubit_plxpr.py @@ -23,6 +23,7 @@ @pytest.fixture(autouse=True) def enable_disable_plxpr(): + """Enable and disable plxpr.""" qml.capture.enable() yield qml.capture.disable() @@ -81,7 +82,7 @@ def f(): res = dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts) assert qml.math.allclose(res, jax.numpy.zeros(100)) - + def test_simple_execution(): """Test the execution, jitting, and gradient of a simple quantum circuit.""" diff --git a/tests/devices/qubit/test_dq_interpreter.py b/tests/devices/qubit/test_dq_interpreter.py index 3dcadc4375f..c32d778d1f7 100644 --- a/tests/devices/qubit/test_dq_interpreter.py +++ b/tests/devices/qubit/test_dq_interpreter.py @@ -57,6 +57,9 @@ def test_setup_and_cleanup(): dq = DefaultQubitInterpreter(num_wires=2, shots=2, key=key) assert dq.stateref is None + with pytest.raises(AttributeError, match="execution not yet initialized"): + _ = dq.state + dq.setup() assert isinstance(dq.stateref, dict) assert list(dq.stateref.keys()) == ["state", "key", "is_state_batched"] From 2a2995078043357b150cbdcc6329dded64ed5e80 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 28 Nov 2024 16:17:13 -0500 Subject: [PATCH 38/45] Update tests/capture/test_capture_mid_measure.py --- tests/capture/test_capture_mid_measure.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/capture/test_capture_mid_measure.py b/tests/capture/test_capture_mid_measure.py index 914006e844b..078c30c399e 100644 --- a/tests/capture/test_capture_mid_measure.py +++ b/tests/capture/test_capture_mid_measure.py @@ -293,7 +293,6 @@ def compare_with_capture_disabled(qnode, *args, **kwargs): res = qnode(*args, **kwargs) qml.capture.disable() expected = qnode(*args, **kwargs) - print(res, expected) return jnp.allclose(res, expected) From 68a6ba5d800c0a7400cb8bafb304ba42070d23a7 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 29 Nov 2024 10:44:54 -0500 Subject: [PATCH 39/45] add workflow developement status --- pennylane/workflow/_capture_qnode.py | 92 +++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index a9f33474408..19a73f8edfa 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -13,6 +13,93 @@ # limitations under the License. """ This submodule defines a capture compatible call to QNodes. + +Workflow Developement Status +---------------------------- + +The non-exhaustive list of unsupported features are: + +**Overridden shots:** Device execution currently pulls the shot information from the device. In order +to support dynamic shots, we need to develop an additional protocol for communicating the shot information +associated with a circuit. Dynamically mutating objects is not compatible with jaxpr and jitting. + +**Shot vectors**. Shot vectors are not yet supported. We need to figure out how to stack +and reshape the outputs from measurements on the device when multiple measurements are present. + +**Gradients other than default qubit backprop**. We managed to get backprop of default qubit for +free, but no other gradients methods have support yet. + +*MCM methods other than single branch statistics*. Mid circuit measurements +are only handled via a "single branch statistics" algorithm, which will lead to unexpected +results. Even on analytic devices, once branch will be randomly chosen on each execution. + +>>> @qml.qnode(qml.device('default.qubit', wires=1)) +>>> def circuit(x): +... qml.H(0) +... m0 = qml.measure(0) +... qml.cond(m0, qml.RX, qml.RZ)(x,0) +... return qml.expval(qml.Z(0)) +>>> circuit(0.5), circuit(0.5), circuit(0.5) +(Array(-0.87758256, dtype=float64), +Array(1., dtype=float64), +Array(-0.87758256, dtype=float64)) +>>> qml.capture.disable() +>>> circuit(0.5) +np.float64(0.06120871905481362) +>>> qml.capture.enable() + +*Device preprocessing and validation*. No device preprocessing and validation will occur. The captured +jaxpr is directly sent to the device, whether or not the device can handle it. + +>>> @qml.qnode(qml.device('default.qubit', wires=3)) +... def circuit(): +... qml.Permute(jax.numpy.array((0,1,2)), wires=(2,1,0)) +... return qml.state() +>>> circuit() +MatrixUndefinedError: + +*Transforms are still under developement*. No transforms will currently be applied as part of the workflow. + +*Breaking vmap/ parameter broadcasting into a non-broadcasted state*. The current workflow assumes +that the device execution can natively handled broadcasted parameters. vmap and parameter broadcasting +will not work with devices other than default qubit. + +>>> @qml.qnode(qml.device('lightning.qubit', wires=1)) +... def circuit(x): +... qml.RX(x, 0) +... return qml.expval(qml.Z(0)) +>>> jax.vmap(circuit)(jax.numpy.array([1.0, 2.0, 3.0])) +TypeError: RX(): incompatible function arguments. The following argument types are supported: + 1. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: list[int], arg1: bool, arg2: list[float]) -> None + 2. (self: pennylane_lightning.lightning_qubit_ops.StateVectorC128, arg0: list[int], arg1: list[bool], arg2: list[int], arg3: bool, arg4: list[float]) -> None + +*Grouping commuting measurements and/ or splitting up non-commuting measurements.* Currently, each +measurment is fully independent and generated from different raw samples than every other measurement. +To generate multiple measurments from the same samples, we need a way of denoting which measurements +should be taken together. A "Combination measurement process" higher order primitive, or something like it. +We will also need to figure out how to implement splitting up a circuit with non-commuting measuremets into +multiple circuits. + +>>> @qml.qnode(qml.device('default.qubit', wires=1, shots=5)) +... def circuit(): +... qml.H(0) +... return qml.sample(wires=0), qml.sample(wires=0) +>>> circuit() +(Array([1, 0, 1, 0, 0], dtype=int64), Array([0, 0, 1, 0, 0], dtype=int64)) + +*Figuring out what types of data can be sent to the device.* Is the device always +responsible for converting jax arrays to numpy arrays? Is the device responsible for having a +pure-callback boundary if the execution is not jittable? We do have an opportunity here +to have gpu-end-to-end simulation on lightning gpu and lightning kokkos. + +*Jitting workflows involving qnodes*. While the execution of jaxpr on default qubit is +currently jittable, we will need to register a lowering for the qnode primitive. We will also +need to figure out where to apply a ``jax.pure_callback`` for devices like lightning qubit that are +not jittable. + +*Unknown other features*. The workflow currently has limited testing, so this list of unsupported +features is non-exhaustive. + """ from copy import copy from dataclasses import asdict @@ -28,7 +115,6 @@ try: import jax from jax.interpreters import ad, batching - except ImportError: has_jax = False @@ -99,7 +185,9 @@ def _get_qnode_prim(): # pylint: disable=too-many-arguments, unused-argument @qnode_prim.def_impl - def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): + def qnode_impl( + *args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None + ): if shots != device.shots: raise NotImplementedError( "override shots are not yet supported with the program capture execution." From e2910ba3808da69a4bf7d50f0515a5230b1348b6 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 29 Nov 2024 10:45:45 -0500 Subject: [PATCH 40/45] add workflow developement status --- pennylane/workflow/_capture_qnode.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 19a73f8edfa..dbf21021e62 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -97,6 +97,9 @@ need to figure out where to apply a ``jax.pure_callback`` for devices like lightning qubit that are not jittable. +*Result caching*. The new workflow is not capable of caching the results of executions, and we have +not even started thinking about how it might be possible to do so. + *Unknown other features*. The workflow currently has limited testing, so this list of unsupported features is non-exhaustive. From 9b9d9399a619d3b840deece82fe5f8bc8d145e59 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 29 Nov 2024 10:49:30 -0500 Subject: [PATCH 41/45] add workflow developement status --- pennylane/workflow/_capture_qnode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index dbf21021e62..e35d5d53d77 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -32,6 +32,8 @@ *MCM methods other than single branch statistics*. Mid circuit measurements are only handled via a "single branch statistics" algorithm, which will lead to unexpected results. Even on analytic devices, once branch will be randomly chosen on each execution. +Returning measurments based on mid circuit measurements, ``qml.sample(m0)``, +is also not yet supported on default qubit or lightning. >>> @qml.qnode(qml.device('default.qubit', wires=1)) >>> def circuit(x): From ce5d8c349ce64a0687de7d0121079b6cb549aea7 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 29 Nov 2024 12:57:14 -0500 Subject: [PATCH 42/45] jit circuits on dq --- pennylane/workflow/_capture_qnode.py | 4 +++- tests/capture/test_capture_qnode.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index e35d5d53d77..4809921405f 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -119,7 +119,7 @@ has_jax = True try: import jax - from jax.interpreters import ad, batching + from jax.interpreters import ad, batching, mlir except ImportError: has_jax = False @@ -300,6 +300,8 @@ def _qnode_jvp(args, tangents, **impl_kwargs): batching.primitive_batchers[qnode_prim] = _qnode_batching_rule + mlir.register_lowering(qnode_prim, mlir.lower_fun(qnode_impl, multiple_results=True)) + return qnode_prim diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 5178b4b8524..da7ed488b66 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -375,6 +375,19 @@ def circuit(x): assert qml.math.allclose(jvp, (qml.math.cos(x), -qml.math.sin(x) * xt)) +def test_qnode_jit(): + """Test that executions on default qubit can be jitted.""" + + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(x): + qml.RX(x, 0) + return qml.expval(qml.Z(0)) + + x = jax.numpy.array(-0.5) + res = jax.jit(circuit)(0.5) + assert qml.math.allclose(res, jax.numpy.cos(x)) + + # pylint: disable=too-many-public-methods class TestQNodeVmapIntegration: """Tests for integrating JAX vmap with the QNode primitive.""" From a06951b1ac36e40e0724acca242d69f99b7e7a73 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 5 Dec 2024 10:36:23 -0500 Subject: [PATCH 43/45] Update tests/capture/test_capture_mid_measure.py --- tests/capture/test_capture_mid_measure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/capture/test_capture_mid_measure.py b/tests/capture/test_capture_mid_measure.py index a4e10368363..afc23d09347 100644 --- a/tests/capture/test_capture_mid_measure.py +++ b/tests/capture/test_capture_mid_measure.py @@ -309,7 +309,7 @@ class TestMidMeasureExecute: @pytest.mark.xfail(strict=False) # single branch statistics sometimes gives good results @pytest.mark.parametrize("reset", [True, False]) - @pytest.mark.parametrize("postselect", [pytest.param(None, marks=pytest.mark.xfail), 0, 1]) + @pytest.mark.parametrize("postselect", [None, 0, 1]) @pytest.mark.parametrize("phi", jnp.arange(1.0, 2 * jnp.pi, 1.5)) def test_simple_circuit_execution(self, phi, reset, postselect, get_device, shots, mp_fn, seed): """Test that circuits with mid-circuit measurements can be executed in a QNode.""" From 2885933014748a4a83035ac5c370d08bfbd60415 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 5 Dec 2024 10:38:10 -0500 Subject: [PATCH 44/45] oops --- pennylane/workflow/_capture_qnode.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 9b52bdcf8e5..41b8b37e175 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -190,7 +190,9 @@ def _get_qnode_prim(): # pylint: disable=too-many-arguments, unused-argument @qnode_prim.def_impl - def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): + def qnode_impl( + *args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None + ): if shots != device.shots: raise NotImplementedError( "Overriding shots is not yet supported with the program capture execution." From 5cfd98fb0b3e3afffbd5bad21cec2f8b0706e6dd Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Thu, 5 Dec 2024 14:25:06 -0500 Subject: [PATCH 45/45] Update doc/releases/changelog-dev.md Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com> --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 2ea8372a8e3..2db3ecb4582 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -165,7 +165,7 @@ featuring a `simulate` function for simulating mixed states in analytic mode. * Execution with capture enabled now follows a new execution pipeline and natively passes the captured jaxpr to the device. Since it no longer falls back to the old pipeline, execution only works with a reduced feature set. - [(#6496)](https://github.com/PennyLaneAI/pennylane/pull/6596) + [(#6655)](https://github.com/PennyLaneAI/pennylane/pull/6655) [(#6596)](https://github.com/PennyLaneAI/pennylane/pull/6596) * PennyLane transforms can now be captured as primitives with experimental program capture enabled.