Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise RuntimeWarning if jax > 0.4.28 is installed #6864

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

<h3>Improvements 🛠</h3>

* A `RuntimeWarning` is now raised by `qml.QNode` and `qml.execute` if executing JAX workflows and the installed version of JAX
is greater than `0.4.28`.
[(#6864)](https://github.com/PennyLaneAI/pennylane/pull/6864)

* Added the `qml.workflow.construct_execution_config(qnode)(*args,**kwargs)` helper function.
Users can now construct the execution configuration from a particular `QNode` instance.
[(#6901)](https://github.com/PennyLaneAI/pennylane/pull/6901)
Expand Down
6 changes: 3 additions & 3 deletions pennylane/math/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,13 @@ def get_canonical_interface_name(user_input: InterfaceLike) -> Interface:
"""Helper function to get the canonical interface name.

Args:
interface (str, Interface): reference interface
interface (str, Interface): Reference interface

Raises:
ValueError: key does not exist in the interface map
ValueError: Key does not exist in the interface map

Returns:
Interface: canonical interface
Interface: Canonical interface
"""

if isinstance(user_input, Interface) and user_input in SUPPORTED_INTERFACE_NAMES:
Expand Down
5 changes: 4 additions & 1 deletion pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pennylane.tape import QuantumScript
from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram

from .resolution import SupportedDiffMethods
from .resolution import SupportedDiffMethods, _validate_jax_version

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -566,6 +566,9 @@ def __init__(
self.func = func
self.device = device
self._interface = get_canonical_interface_name(interface)
if self._interface in (Interface.JAX, Interface.JAX_JIT):
_validate_jax_version()

self.diff_method = diff_method
cache = (max_diff > 1) if cache == "auto" else cache

Expand Down
24 changes: 24 additions & 0 deletions pennylane/workflow/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
"""This module contains the necessary helper functions for setting up the workflow for execution."""
from collections.abc import Callable
from dataclasses import replace
from importlib.metadata import version
from importlib.util import find_spec
from typing import Literal, Optional, Union, get_args
from warnings import warn

from packaging.version import Version

import pennylane as qml
from pennylane.logging import debug_logger
Expand Down Expand Up @@ -58,6 +63,21 @@ def _use_tensorflow_autograph():
return not tf.executing_eagerly()


def _validate_jax_version():
"""Checks if the installed version of JAX is supported. If an unsupported version of
JAX is installed, a ``RuntimeWarning`` is raised."""
if not find_spec("jax"):
return

jax_version = version("jax")
if Version(jax_version) > Version("0.4.28"): # pragma: no cover
warn(
"PennyLane is currently not compatible with versions of JAX > 0.4.28. "
f"You have version {jax_version} installed.",
RuntimeWarning,
)


def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBatch) -> Interface:
"""Helper function to resolve an interface based on a set of tapes.

Expand All @@ -69,6 +89,8 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat
Interface: resolved interface
"""
interface = get_canonical_interface_name(interface)
if interface in (Interface.JAX, Interface.JAX_JIT):
_validate_jax_version()

if interface == Interface.AUTO:
params = []
Expand All @@ -77,6 +99,8 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat
interface = get_interface(*params)
try:
interface = get_canonical_interface_name(interface)
if interface in (Interface.JAX, Interface.JAX_JIT):
_validate_jax_version()
except ValueError:
# If the interface is not recognized, default to numpy, like networkx
interface = Interface.NUMPY
Expand Down