Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up how interface is handled in QNode and qml.execute #6225

Merged
merged 38 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c62ce82
Add _to_autograd to autograd execute
astralcai Sep 5, 2024
26111be
stop treating numpy as autograd internally
astralcai Sep 5, 2024
5e8740b
Merge branch 'master' into autograd-bug
astralcai Sep 5, 2024
26d16bc
bug fixes
astralcai Sep 6, 2024
c45b750
uncomment line
astralcai Sep 6, 2024
dfacd7f
Merge branch 'master' into autograd-bug
astralcai Sep 6, 2024
0633dc0
fix tiny bug
astralcai Sep 6, 2024
f1213cb
Merge branch 'master' into autograd-bug
astralcai Sep 6, 2024
ffb9d3c
more fix
astralcai Sep 9, 2024
5d6e8eb
fix bug
astralcai Sep 9, 2024
6776c01
fix tests
astralcai Sep 9, 2024
eca0f14
fix black
astralcai Sep 9, 2024
33f4f63
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Sep 9, 2024
2e4f55c
revert change
astralcai Sep 9, 2024
8c4a72a
add sparse matrix to Hermitian
astralcai Sep 9, 2024
3c8a691
bug fix
astralcai Sep 9, 2024
71561f9
bug fix
astralcai Sep 9, 2024
676d85b
bug fix
astralcai Sep 9, 2024
63525a2
Merge branch 'master' into autograd-bug
astralcai Sep 9, 2024
b34b48c
clean up handling of interface
astralcai Sep 10, 2024
9810b68
Merge branch 'master' into autograd-bug
astralcai Sep 10, 2024
c21eee3
fix isort
astralcai Sep 10, 2024
ffd41a9
update
astralcai Sep 10, 2024
cee3976
fix some tests
astralcai Sep 10, 2024
087dc22
fix tests
astralcai Sep 10, 2024
535e66b
make pylint happy
astralcai Sep 10, 2024
18c5fa6
update name
astralcai Sep 10, 2024
7fde693
fix isort
astralcai Sep 10, 2024
9e86111
fix tests
astralcai Sep 10, 2024
c4415a8
Update pennylane/workflow/qnode.py
astralcai Sep 11, 2024
6ac5315
Merge branch 'master' into autograd-bug
astralcai Sep 11, 2024
b6ad427
changelog
astralcai Sep 12, 2024
c41572e
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Sep 12, 2024
fee0af9
Merge branch 'master' into autograd-bug
astralcai Sep 12, 2024
61f0dc1
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Sep 13, 2024
a8b11cc
fix bug
astralcai Sep 13, 2024
a44ef16
add test
astralcai Sep 13, 2024
99d393c
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Optional, Union

from pennylane.workflow import SUPPORTED_INTERFACES
from pennylane.workflow import SUPPORTED_INTERFACE_NAMES


@dataclass
Expand Down Expand Up @@ -110,9 +110,9 @@ def __post_init__(self):

Note that this hook is automatically called after init via the dataclass integration.
"""
if self.interface not in SUPPORTED_INTERFACES:
if self.interface not in SUPPORTED_INTERFACE_NAMES:
raise ValueError(
f"Unknown interface. interface must be in {SUPPORTED_INTERFACES}, got {self.interface} instead."
f"Unknown interface. interface must be in {SUPPORTED_INTERFACE_NAMES}, got {self.interface} instead."
)

if self.grad_on_execution not in {True, False, None}:
Expand Down
18 changes: 9 additions & 9 deletions pennylane/devices/legacy_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pennylane as qml
from pennylane.measurements import MidMeasureMP, Shots
from pennylane.transforms.core.transform_program import TransformProgram
from pennylane.workflow.execution import INTERFACE_MAP
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

from .device_api import Device
from .execution_config import DefaultExecutionConfig
Expand Down Expand Up @@ -328,18 +329,18 @@ def _create_temp_device(self, batch):
if interface == "numpy":
return self._device

mapped_interface = qml.workflow.execution.INTERFACE_MAP.get(interface, interface)
interface = INTERFACE_MAP.get(interface, interface)

backprop_interface = self._device.capabilities().get("passthru_interface", None)
if mapped_interface == backprop_interface:
if interface == backprop_interface:
return self._device

backprop_devices = self._device.capabilities().get("passthru_devices", None)

if backprop_devices is None:
raise qml.DeviceError(f"Device {self} does not support backpropagation.")

if backprop_devices[mapped_interface] == self._device.short_name:
if backprop_devices[interface] == self._device.short_name:
return self._device

if self.target_device.short_name != "default.qubit.legacy":
Expand Down Expand Up @@ -367,7 +368,7 @@ def _create_temp_device(self, batch):
)
# we already warned about backprop device switching
new_device = qml.device(
backprop_devices[mapped_interface],
backprop_devices[interface],
astralcai marked this conversation as resolved.
Show resolved Hide resolved
wires=self._device.wires,
shots=self._device.shots,
).target_device
Expand Down Expand Up @@ -396,25 +397,24 @@ def _validate_backprop_method(self, tape):
return False
params = tape.get_parameters(trainable_only=False)
interface = qml.math.get_interface(*params)
if interface != "numpy":
interface = INTERFACE_MAP.get(interface, interface)

if tape and any(isinstance(m.obs, qml.SparseHamiltonian) for m in tape.measurements):
return False
if interface == "numpy":
interface = None
mapped_interface = qml.workflow.execution.INTERFACE_MAP.get(interface, interface)

# determine if the device supports backpropagation
backprop_interface = self._device.capabilities().get("passthru_interface", None)

if backprop_interface is not None:
# device supports backpropagation natively
return mapped_interface in [backprop_interface, "Numpy"]
return interface in [backprop_interface, "numpy"]
# determine if the device has any child devices that support backpropagation
backprop_devices = self._device.capabilities().get("passthru_devices", None)

if backprop_devices is None:
return False
return mapped_interface in backprop_devices or mapped_interface == "Numpy"
return interface in backprop_devices or interface == "numpy"

def _validate_adjoint_method(self, tape):
# The conditions below provide a minimal set of requirements that we can likely improve upon in
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def _(original_measurement: ExpectationMP, measures): # pylint: disable=unused-
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += v[0] * v[1]
cum_value += qml.math.multiply(v[0], v[1])
total_counts += v[0]
return cum_value / total_counts

Expand All @@ -935,7 +935,7 @@ def _(original_measurement: ProbabilityMP, measures): # pylint: disable=unused-
for v in measures.values():
if not v[0] or v[1] is tuple():
continue
cum_value += v[0] * v[1]
cum_value += qml.math.multiply(v[0], v[1])
total_counts += v[0]
return cum_value / total_counts

Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/qubit/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def compute_matrix(A: TensorLike) -> TensorLike: # pylint: disable=arguments-di
Hermitian._validate_input(A)
return A

@staticmethod
def compute_sparse_matrix(A) -> csr_matrix: # pylint: disable=arguments-differ
astralcai marked this conversation as resolved.
Show resolved Hide resolved
return csr_matrix(Hermitian.compute_matrix(A))

@property
def eigendecomposition(self) -> dict[str, TensorLike]:
"""Return the eigendecomposition of the matrix specified by the Hermitian observable.
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@

"""
from .construct_batch import construct_batch, get_transform_program
from .execution import INTERFACE_MAP, SUPPORTED_INTERFACES, execute
from .execution import INTERFACE_MAP, SUPPORTED_INTERFACE_NAMES, execute
from .qnode import QNode, qnode
from .set_shots import set_shots
75 changes: 43 additions & 32 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,9 @@
"autograd",
"numpy",
"torch",
"pytorch",
"jax",
"jax-python",
"jax-jit",
"tf",
"tensorflow",
}

SupportedInterfaceUserInput = Literal[
Expand All @@ -78,30 +75,29 @@
]

_mapping_output = (
"Numpy",
"numpy",
astralcai marked this conversation as resolved.
Show resolved Hide resolved
"auto",
"autograd",
"autograd",
"numpy",
"jax",
"jax",
"jax-jit",
"jax",
"jax",
"torch",
"torch",
"tf",
"tf",
"tf",
"tf",
"tf-autograph",
"tf-autograph",
astralcai marked this conversation as resolved.
Show resolved Hide resolved
)

INTERFACE_MAP = dict(zip(get_args(SupportedInterfaceUserInput), _mapping_output))
"""dict[str, str]: maps an allowed interface specification to its canonical name."""

#: list[str]: allowed interface strings
SUPPORTED_INTERFACES = list(INTERFACE_MAP)
SUPPORTED_INTERFACE_NAMES = list(INTERFACE_MAP)
"""list[str]: allowed interface strings"""


_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS = (
"Cached execution with finite shots detected!\n"
"Note that samples as well as all noisy quantities computed via sampling "
Expand Down Expand Up @@ -135,41 +131,41 @@ def _get_ml_boundary_execute(
pennylane.QuantumFunctionError if the required package is not installed.

"""
mapped_interface = INTERFACE_MAP[interface]
astralcai marked this conversation as resolved.
Show resolved Hide resolved
try:
if mapped_interface == "autograd":
if interface == "autograd":
from .interfaces.autograd import autograd_execute as ml_boundary

elif mapped_interface == "tf":
if "autograph" in interface:
from .interfaces.tensorflow_autograph import execute as ml_boundary
elif interface == "tf-autograph":
from .interfaces.tensorflow_autograph import execute as ml_boundary

ml_boundary = partial(ml_boundary, grad_on_execution=grad_on_execution)
ml_boundary = partial(ml_boundary, grad_on_execution=grad_on_execution)

else:
from .interfaces.tensorflow import tf_execute as full_ml_boundary
elif interface == "tf":
from .interfaces.tensorflow import tf_execute as full_ml_boundary

ml_boundary = partial(full_ml_boundary, differentiable=differentiable)
ml_boundary = partial(full_ml_boundary, differentiable=differentiable)

elif mapped_interface == "torch":
elif interface == "torch":
from .interfaces.torch import execute as ml_boundary

elif interface == "jax-jit":
if device_vjp:
from .interfaces.jax_jit import jax_jit_vjp_execute as ml_boundary
else:
from .interfaces.jax_jit import jax_jit_jvp_execute as ml_boundary
else: # interface in {"jax", "jax-python", "JAX"}:

else: # interface is jax
if device_vjp:
from .interfaces.jax_jit import jax_jit_vjp_execute as ml_boundary
else:
from .interfaces.jax import jax_jvp_execute as ml_boundary

except ImportError as e: # pragma: no-cover
raise qml.QuantumFunctionError(
f"{mapped_interface} not found. Please install the latest "
f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
f"{interface} not found. Please install the latest "
f"version of {interface} to enable the '{interface}' interface."
) from e

return ml_boundary


Expand Down Expand Up @@ -263,12 +259,22 @@ def _get_interface_name(tapes, interface):

Returns:
str: Interface name"""

if interface not in SUPPORTED_INTERFACE_NAMES:
raise qml.QuantumFunctionError(
f"Unknown interface {interface}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}."
)

interface = INTERFACE_MAP[interface]

if interface == "auto":
params = []
for tape in tapes:
params.extend(tape.get_parameters(trainable_only=False))
interface = qml.math.get_interface(*params)
if INTERFACE_MAP.get(interface, "") == "tf" and _use_tensorflow_autograph():
if interface != "numpy":
interface = INTERFACE_MAP[interface]
if interface == "tf" and _use_tensorflow_autograph():
interface = "tf-autograph"
if interface == "jax":
try: # pragma: no cover
Expand Down Expand Up @@ -439,6 +445,7 @@ def cost_fn(params, x):

### Specifying and preprocessing variables ####

_interface_user_input = interface
interface = _get_interface_name(tapes, interface)
# Only need to calculate derivatives with jax when we know it will be executed later.
if interface in {"jax", "jax-jit"}:
Expand All @@ -460,7 +467,11 @@ def cost_fn(params, x):
)

# Mid-circuit measurement configuration validation
mcm_interface = interface or _get_interface_name(tapes, "auto")
# If the user specifies `interface=None`, regular execution considers it numpy, but the mcm
# workflow still needs to know if jax-jit is used
mcm_interface = (
_get_interface_name(tapes, "auto") if _interface_user_input is None else interface
)
finite_shots = any(tape.shots for tape in tapes)
_update_mcm_config(config.mcm_config, mcm_interface, finite_shots)

Expand All @@ -479,12 +490,12 @@ def cost_fn(params, x):
cache = None

# changing this set of conditions causes a bunch of tests to break.
no_interface_boundary_required = interface is None or config.gradient_method in {
no_interface_boundary_required = interface == "numpy" or config.gradient_method in {
astralcai marked this conversation as resolved.
Show resolved Hide resolved
None,
"backprop",
}
device_supports_interface_data = no_interface_boundary_required and (
interface is None
interface == "numpy"
or config.gradient_method == "backprop"
or getattr(device, "short_name", "") == "default.mixed"
)
Expand All @@ -497,9 +508,9 @@ def cost_fn(params, x):
numpy_only=not device_supports_interface_data,
)

# moved to its own explicit step so it will be easier to remove
# moved to its own explicit step so that it will be easier to remove
def inner_execute_with_empty_jac(tapes, **_):
return (inner_execute(tapes), [])
return inner_execute(tapes), []

if interface in jpc_interfaces:
execute_fn = inner_execute
Expand All @@ -522,7 +533,7 @@ def inner_execute_with_empty_jac(tapes, **_):
and getattr(device, "short_name", "") in ("lightning.gpu", "lightning.kokkos")
and interface in jpc_interfaces
): # pragma: no cover
if INTERFACE_MAP[interface] == "jax" and "use_device_state" in gradient_kwargs:
if "jax" in interface and "use_device_state" in gradient_kwargs:
gradient_kwargs["use_device_state"] = False

jpc = LightningVJPs(device, gradient_kwargs=gradient_kwargs)
Expand Down Expand Up @@ -563,7 +574,7 @@ def execute_fn(internal_tapes) -> tuple[ResultBatch, tuple]:
config: the ExecutionConfig that specifies how to perform the simulations.
"""
numpy_tapes, _ = qml.transforms.convert_to_numpy_parameters(internal_tapes)
return (device.execute(numpy_tapes, config), tuple())
return device.execute(numpy_tapes, config), tuple()

def gradient_fn(internal_tapes):
"""A partial function that wraps compute_derivatives method of the device.
Expand Down Expand Up @@ -612,7 +623,7 @@ def gradient_fn(internal_tapes):

# trainable parameters can only be set on the first pass for jax
# not higher order passes for higher order derivatives
if interface in {"jax", "jax-python", "jax-jit"}:
if "jax" in interface:
for tape in tapes:
params = tape.get_parameters(trainable_only=False)
tape.trainable_params = qml.math.get_trainable_indices(params)
Expand Down
17 changes: 16 additions & 1 deletion pennylane/workflow/interfaces/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ def autograd_execute(
return _execute(parameters, tuple(tapes), execute_fn, jpc)


def _to_autograd(result: qml.typing.ResultBatch) -> qml.typing.ResultBatch:
"""Converts an arbitrary result batch to one with autograd arrays.
Args:
result (ResultBatch): a nested structure of lists, tuples, dicts, and numpy arrays
Returns:
ResultBatch: a nested structure of tuples, dicts, and jax arrays
"""
if isinstance(result, dict):
return result
# pylint: disable=no-member
if isinstance(result, (list, tuple, autograd.builtins.tuple, autograd.builtins.list)):
return tuple(_to_autograd(r) for r in result)
return autograd.numpy.array(result)


@autograd.extend.primitive
def _execute(
parameters,
Expand All @@ -165,7 +180,7 @@ def _execute(
for the input tapes.

"""
return execute_fn(tapes)
return _to_autograd(execute_fn(tapes))


# pylint: disable=unused-argument
Expand Down
Loading
Loading