Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] add eval_jaxpr method to DefaultQubit #6594

Merged
merged 46 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
e827977
exploring an idea
albi3ro Jun 21, 2024
33b946c
lightning interpreter
albi3ro Jun 21, 2024
bdc2b1e
add call jaxpr function
albi3ro Jun 21, 2024
5185112
add demo notebook
albi3ro Jun 21, 2024
786caf3
call is function transform
albi3ro Aug 9, 2024
d8f7e59
Merge branch 'master' into plxpr-interpreter
albi3ro Aug 13, 2024
13b091d
assorted improvements
albi3ro Aug 13, 2024
3f7b5ff
more moving things around
albi3ro Aug 13, 2024
6e2b334
something?
albi3ro Aug 15, 2024
4f3efb6
Merge branch 'master' into plxpr-interpreter
albi3ro Aug 21, 2024
a61acd2
Add PLxprInterpreter base class
albi3ro Aug 26, 2024
0d68f07
Merge branch 'master' into plxpr-interpreter-base
albi3ro Aug 26, 2024
1609ed3
qnode fix
albi3ro Aug 26, 2024
f4a6ba8
Merge branch 'plxpr-interpreter-base' of https://github.com/PennyLane…
albi3ro Aug 26, 2024
f593d45
starting to write tests
albi3ro Aug 27, 2024
0defdde
trying to improve op math handling
albi3ro Aug 28, 2024
0ecf1a8
testing
albi3ro Aug 30, 2024
65202fa
improvementS
albi3ro Aug 30, 2024
22e32f7
more fixes
albi3ro Sep 3, 2024
93724c4
Merge branch 'master' into plxpr-interpreter-base
albi3ro Sep 9, 2024
de6f4b3
Merge branch 'master' into plxpr-interpreter-base
albi3ro Sep 30, 2024
ba298eb
adding tests
albi3ro Oct 1, 2024
d9d9fd7
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 1, 2024
1cc4e79
more tests
albi3ro Oct 1, 2024
1a8f410
more test changes
albi3ro Oct 2, 2024
769570f
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 2, 2024
e8a5c5a
some more tests and polishing
albi3ro Oct 2, 2024
c42090c
Merge branch 'master' into plxpr-interpreter-base
albi3ro Oct 2, 2024
f806708
add default qubit interpreter
albi3ro Oct 2, 2024
751d66c
use copy not child
albi3ro Oct 9, 2024
c232286
merge in master
albi3ro Nov 11, 2024
bfbe68d
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 11, 2024
d3f729e
starting to write tests
albi3ro Nov 12, 2024
879bdad
more testing
albi3ro Nov 12, 2024
7741113
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 12, 2024
b80d290
more tests
albi3ro Nov 13, 2024
d0298ad
Apply suggestions from code review
albi3ro Nov 15, 2024
1d4bf51
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 15, 2024
07fae7b
update initial key each execution
albi3ro Nov 18, 2024
e076a48
changelog
albi3ro Nov 18, 2024
2c736bd
Merge branch 'master' into default-qubit-interpreter
albi3ro Nov 18, 2024
53dfac0
add eval_jaxpr method to DefaultQubit
albi3ro Nov 18, 2024
cf22d05
Update doc/releases/changelog-dev.md
albi3ro Nov 18, 2024
eb3e277
fix no prng key
albi3ro Nov 18, 2024
03644f7
Update tests/devices/default_qubit/test_default_qubit_plxpr.py
albi3ro Nov 20, 2024
e9beec7
Merge branch 'master' into dq-eval-jaxpr
albi3ro Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
pennylane variant jaxpr.
[(#6141)](https://github.com/PennyLaneAI/pennylane/pull/6141)

* A `DefaultQubitInterpreter` class has been added to provide plxpr execution using python based tools.
* A `DefaultQubitInterpreter` class has been added to provide plxpr execution using python based tools,
and the `DefaultQubit.eval_jaxpr` method is now implemented.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
[(#6594)](https://github.com/PennyLaneAI/pennylane/pull/6594)
[(#6328)](https://github.com/PennyLaneAI/pennylane/pull/6328)

* An optional method `eval_jaxpr` is added to the device API for native execution of plxpr programs.
Expand Down
24 changes: 23 additions & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pennylane.tape import QuantumScript, QuantumScriptBatch, QuantumScriptOrBatch
from pennylane.transforms import convert_to_numpy_parameters
from pennylane.transforms.core import TransformProgram
from pennylane.typing import PostprocessingFn, Result, ResultBatch
from pennylane.typing import PostprocessingFn, Result, ResultBatch, TensorLike

from . import Device
from .execution_config import DefaultExecutionConfig, ExecutionConfig
Expand Down Expand Up @@ -891,6 +891,28 @@ def execute_and_compute_vjp(

return tuple(zip(*results))

# pylint: disable=import-outside-toplevel
def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args
) -> list[TensorLike]:
from .qubit.dq_interpreter import DefaultQubitInterpreter

if self.wires is None:
raise qml.DeviceError("Device wires are required for jaxpr execution.")
if self.shots.has_partitioned_shots:
raise qml.DeviceError("Shot vectors are unsupported with jaxpr execution.")
if self._prng_key is not None:
key = self.get_prng_keys()[0]
else:
import jax

key = jax.random.PRNGKey(self._rng.integers(100000))

interpreter = DefaultQubitInterpreter(
num_wires=len(self.wires), shots=self.shots.total_shots, key=key
)
return interpreter.eval(jaxpr, consts, *args)


def _simulate_wrapper(circuit, kwargs):
return simulate(circuit, **kwargs)
Expand Down
97 changes: 97 additions & 0 deletions tests/devices/default_qubit/test_default_qubit_plxpr.py
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for default qubit executing jaxpr."""

import pytest

import pennylane as qml

jax = pytest.importorskip("jax")
pytestmark = pytest.mark.jax


@pytest.fixture(autouse=True)
def enable_disable_plxpr():
qml.capture.enable()
yield
qml.capture.disable()


def test_requires_wires():
"""Test that a device error is raised if device wires are not specified."""

jaxpr = jax.make_jaxpr(lambda x: x + 1)(0.1)
dev = qml.device("default.qubit")

with pytest.raises(qml.DeviceError, match="Device wires are required."):
dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.2)


def test_no_partitioned_shots():
"""Test that an error is raised if the device has partitioned shots."""

jaxpr = jax.make_jaxpr(lambda x: x + 1)(0.1)
dev = qml.device("default.qubit", wires=1, shots=(100, 100))

with pytest.raises(qml.DeviceError, match="Shot vectors are unsupported with jaxpr execution."):
dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.2)


def test_use_device_prng():
"""Test that sampling depends on the device prng."""

key1 = jax.random.PRNGKey(1234)
key2 = jax.random.PRNGKey(1234)

dev1 = qml.device("default.qubit", wires=1, shots=100, seed=key1)
dev2 = qml.device("default.qubit", wires=1, shots=100, seed=key2)

def f():
qml.H(0)
return qml.sample(wires=0)

jaxpr = jax.make_jaxpr(f)()

samples1 = dev1.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
samples2 = dev2.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)

assert qml.math.allclose(samples1, samples2)


def test_no_prng_key():
"""Test that that sampling works without a provided prng key."""

dev = qml.device("default.qubit", wires=1, shots=100)

def f():
return qml.sample(wires=0)

jaxpr = jax.make_jaxpr(f)()
res = dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts)
assert qml.math.allclose(res, jax.numpy.zeros(100))


def test_simple_execution():
"""Test the execution, jitting, and gradient of a simple quantum circuit."""

def f(x):
qml.RX(x, 0)
return qml.expval(qml.Z(0))

jaxpr = jax.make_jaxpr(f)(0.123)

dev = qml.device("default.qubit", wires=1)

res = dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.5)
assert qml.math.allclose(res, jax.numpy.cos(0.5))