Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise RuntimeWarning if jax > 0.4.28 is installed #6864

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

<h3>Improvements 🛠</h3>

* A `RuntimeWarning` is now raised by `qml.QNode` and `qml.execute` if executing JAX workflows and the installed version of JAX
is greater than `0.4.28`.
[(#6864)](https://github.com/PennyLaneAI/pennylane/pull/6864)

* `QNode` objects now have an `update` method that allows for re-configuring settings like `diff_method`, `mcm_method`, and more. This allows for easier on-the-fly adjustments to workflows. Any arguments not specified will retain their original value.
[(#6803)](https://github.com/PennyLaneAI/pennylane/pull/6803)

Expand Down Expand Up @@ -105,4 +109,5 @@ Diksha Dhawan,
Pietropaolo Frisoni,
Marcus Gisslén,
Christina Lee,
Mudit Pandey,
Andrija Paurevic
41 changes: 36 additions & 5 deletions pennylane/math/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@

import warnings
from enum import Enum
from importlib.metadata import version
from importlib.util import find_spec
from typing import Literal, Union
from warnings import warn

import autoray as ar
from packaging.version import Version


class Interface(Enum):
Expand Down Expand Up @@ -212,23 +216,50 @@ def get_deep_interface(value):
return _get_interface_of_single_tensor(itr)


def get_canonical_interface_name(user_input: InterfaceLike) -> Interface:
def _check_supported_jax() -> None:
"""Checks if the installed version of JAX is supported. If an unsupported version of
JAX is installed, a ``RuntimeWarning`` is raised."""
if not find_spec("jax"):
return

jax_version = version("jax")
if Version(jax_version) > Version("0.4.28"): # pragma: no cover
warn(
"PennyLane is currently not compatible with versions of JAX > 0.4.28. "
f"You have version {jax_version} installed.",
RuntimeWarning,
)


def get_canonical_interface_name(
user_input: InterfaceLike, _validate_jax_version=False
) -> Interface:
"""Helper function to get the canonical interface name.

Args:
interface (str, Interface): reference interface
interface (str, Interface): Reference interface
_validate_jax_version (bool): Whether we should check if a supported version of
JAX is installed. If ``True``, a ``RuntimeWarning`` is raised if the canonical
interface is found to be JAX. ``False`` by default.

Raises:
ValueError: key does not exist in the interface map
ValueError: Key does not exist in the interface map

Returns:
Interface: canonical interface
Interface: Canonical interface
"""

if user_input in SUPPORTED_INTERFACE_NAMES:
if _validate_jax_version and user_input in (Interface.JAX, Interface.JAX_JIT):
_check_supported_jax()

return user_input
try:
return INTERFACE_MAP[user_input]
out = INTERFACE_MAP[user_input]

if out in (Interface.JAX, Interface.JAX_JIT):
_check_supported_jax()
return out
except KeyError as exc:
raise ValueError(
f"Unknown interface {user_input}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}."
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def __init__(
# input arguments
self.func = func
self.device = device
self._interface = get_canonical_interface_name(interface)
self._interface = get_canonical_interface_name(interface, _validate_jax_version=True)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
self.diff_method = diff_method
mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode)
cache = (max_diff > 1) if cache == "auto" else cache
Expand Down
4 changes: 2 additions & 2 deletions pennylane/workflow/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat
Interface: resolved interface
"""

interface = get_canonical_interface_name(interface)
interface = get_canonical_interface_name(interface, _validate_jax_version=True)

if interface == Interface.AUTO:
params = []
Expand All @@ -106,7 +106,7 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat
interface = get_interface(*params)
if interface != Interface.NUMPY:
try:
interface = get_canonical_interface_name(interface)
interface = get_canonical_interface_name(interface, _validate_jax_version=True)
except ValueError:
interface = Interface.NUMPY
if interface == Interface.TF and _use_tensorflow_autograph():
Expand Down
Loading