Move _cache_transform to its own file in qml.workflow (#6624)
**Context:**

Introduces a new file `_cache_transform.py` in `qml.workflow` to house
the private `_cache_transform` helper. Previously it lived in
`execution.py` despite having its own unit test file
`workflow/test_cache_transform.py`.

Also updates the file name: `resolve_diff_method.py` →
`_resolve_diff_method.py`, as I should have done earlier. 😄

**Benefits:** Code consistency and clean-up
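
For downstream imports of the private helper (e.g. in the test suite), the lookup path changes as sketched below; this snippet is illustrative only, since `_cache_transform` remains private API.

```python
# Before this PR (illustrative):
# from pennylane.workflow.execution import _cache_transform

# After this PR, the helper is re-exported from the package and lives in its own module:
from pennylane.workflow import _cache_transform
from pennylane.workflow._resolve_diff_method import _resolve_diff_method  # renamed module
```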

---------

Co-authored-by: Yushao Chen (Jerry) <[email protected]>
andrijapau and JerryChen97 authored Nov 21, 2024
1 parent 21a12a1 commit 96d2587
Showing 8 changed files with 85 additions and 58 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -68,6 +68,10 @@

 <h4>Other Improvements</h4>
 
+* `_cache_transform` transform has been moved to its own file located
+  at `qml.workflow._cache_transform.py`.
+  [(#6624)](https://github.com/PennyLaneAI/pennylane/pull/6624)
+
 * `qml.BasisRotation` template is now JIT compatible.
   [(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)

3 changes: 2 additions & 1 deletion pennylane/workflow/__init__.py
@@ -46,5 +46,6 @@
 from .construct_tape import construct_tape
 from .execution import INTERFACE_MAP, SUPPORTED_INTERFACE_NAMES, execute
 from .get_best_diff_method import get_best_diff_method
-from .resolve_diff_method import _resolve_diff_method
 from .qnode import QNode, qnode
+from ._cache_transform import _cache_transform
+from ._resolve_diff_method import _resolve_diff_method
70 changes: 70 additions & 0 deletions pennylane/workflow/_cache_transform.py
@@ -0,0 +1,70 @@
# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the transform for caching the result of a ``tape``.
"""

import warnings
from collections.abc import MutableMapping

from pennylane.tape import QuantumScript
from pennylane.transforms import transform
from pennylane.typing import Result, ResultBatch

_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS = (
    "Cached execution with finite shots detected!\n"
    "Note that samples as well as all noisy quantities computed via sampling "
    "will be identical across executions. This situation arises where tapes "
    "are executed with identical operations, measurements, and parameters.\n"
    "To avoid this behaviour, provide 'cache=False' to the QNode or execution "
    "function."
)
"""str: warning message to display when cached execution is used with finite shots"""


@transform
def _cache_transform(tape: QuantumScript, cache: MutableMapping):
    """Caches the result of ``tape`` using the provided ``cache``.

    .. note::

        This function makes use of :attr:`.QuantumTape.hash` to identify unique tapes.
    """

    def cache_hit_postprocessing(_results: ResultBatch) -> Result:
        result = cache[tape.hash]
        if result is not None:
            if tape.shots and getattr(cache, "_persistent_cache", True):
                warnings.warn(_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS, UserWarning)
            return result

        raise RuntimeError(
            "Result for tape is missing from the execution cache. "
            "This is likely the result of a race condition."
        )

    if tape.hash in cache:
        return [], cache_hit_postprocessing

    def cache_miss_postprocessing(results: ResultBatch) -> Result:
        result = results[0]
        cache[tape.hash] = result
        return result

    # Adding a ``None`` entry to the cache indicates that a result will eventually be available for
    # the tape. This assumes that post-processing functions are called in the same order in which
    # the transforms are invoked. Otherwise, ``cache_hit_postprocessing()`` may be called before the
    # result of the corresponding tape is placed in the cache by ``cache_miss_postprocessing()``.
    cache[tape.hash] = None
    return [tape], cache_miss_postprocessing
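
Below is a minimal usage sketch (not part of the diff) showing the caching protocol implemented above, assuming the standard behaviour of `qml.transforms.transform`-dispatched functions of returning a `(tapes, processing_fn)` pair; the plain-dict cache and the toy tape are illustrative only, and `_cache_transform` remains private API.

```python
import pennylane as qml
from pennylane.tape import QuantumScript
from pennylane.workflow import _cache_transform

tape = QuantumScript([qml.RX(0.5, wires=0)], [qml.expval(qml.PauliZ(0))])
cache = {}

# Cache miss: the tape is returned for execution and a ``None`` placeholder
# is stored under ``tape.hash``.
tapes, miss_fn = _cache_transform(tape, cache=cache)
assert len(tapes) == 1 and cache[tape.hash] is None

# Feeding the (simulated) device result to the post-processing fills the cache.
miss_fn([0.88])
assert cache[tape.hash] == 0.88

# Cache hit: no tapes are returned and the stored result is reused without execution.
tapes, hit_fn = _cache_transform(tape, cache=cache)
assert len(tapes) == 0 and hit_fn([]) == 0.88
```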
pennylane/workflow/{resolve_diff_method.py → _resolve_diff_method.py}
File renamed without changes.
56 changes: 4 additions & 52 deletions pennylane/workflow/execution.py
@@ -23,19 +23,18 @@
 
 import inspect
 import logging
-import warnings
-from collections.abc import Callable, MutableMapping
+from collections.abc import Callable
 from functools import partial
 from typing import Literal, Optional, Union, get_args
 from warnings import warn
 
 from cachetools import Cache, LRUCache
 
 import pennylane as qml
-from pennylane.tape import QuantumScript, QuantumScriptBatch
-from pennylane.transforms import transform
-from pennylane.typing import Result, ResultBatch
+from pennylane.tape import QuantumScriptBatch
+from pennylane.typing import ResultBatch
 
+from ._cache_transform import _cache_transform
 from .jacobian_products import DeviceDerivatives, DeviceJacobianProducts, TransformJacobianProducts
 
 logger = logging.getLogger(__name__)
@@ -94,16 +93,6 @@
 SUPPORTED_INTERFACE_NAMES = list(INTERFACE_MAP)
 """list[str]: allowed interface strings"""
 
-_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS = (
-    "Cached execution with finite shots detected!\n"
-    "Note that samples as well as all noisy quantities computed via sampling "
-    "will be identical across executions. This situation arises where tapes "
-    "are executed with identical operations, measurements, and parameters.\n"
-    "To avoid this behaviour, provide 'cache=False' to the QNode or execution "
-    "function."
-)
-"""str: warning message to display when cached execution is used with finite shots"""
-
 
 def _use_tensorflow_autograph():
     """Checks if TensorFlow is in graph mode, allowing Autograph for optimized execution"""
@@ -217,43 +206,6 @@ def inner_execute(tapes: QuantumScriptBatch, **_) -> ResultBatch:
     return inner_execute
 
 
-@transform
-def _cache_transform(tape: QuantumScript, cache: MutableMapping):
-    """Caches the result of ``tape`` using the provided ``cache``.
-
-    .. note::
-
-        This function makes use of :attr:`.QuantumTape.hash` to identify unique tapes.
-    """
-
-    def cache_hit_postprocessing(_results: ResultBatch) -> Result:
-        result = cache[tape.hash]
-        if result is not None:
-            if tape.shots and getattr(cache, "_persistent_cache", True):
-                warnings.warn(_CACHED_EXECUTION_WITH_FINITE_SHOTS_WARNINGS, UserWarning)
-            return result
-
-        raise RuntimeError(
-            "Result for tape is missing from the execution cache. "
-            "This is likely the result of a race condition."
-        )
-
-    if tape.hash in cache:
-        return [], cache_hit_postprocessing
-
-    def cache_miss_postprocessing(results: ResultBatch) -> Result:
-        result = results[0]
-        cache[tape.hash] = result
-        return result
-
-    # Adding a ``None`` entry to the cache indicates that a result will eventually be available for
-    # the tape. This assumes that post-processing functions are called in the same order in which
-    # the transforms are invoked. Otherwise, ``cache_hit_postprocessing()`` may be called before the
-    # result of the corresponding tape is placed in the cache by ``cache_miss_postprocessing()``.
-    cache[tape.hash] = None
-    return [tape], cache_miss_postprocessing
-
-
 def _get_interface_name(tapes, interface):
     """Helper function to get the interface name of a list of tapes
6 changes: 3 additions & 3 deletions tests/interfaces/test_jax_jit.py
@@ -181,7 +181,7 @@ class TestCaching:
     def test_cache_maxsize(self, mocker):
         """Test the cachesize property of the cache"""
         dev = qml.device("default.qubit", wires=1)
-        spy = mocker.spy(qml.workflow.execution._cache_transform, "_transform")
+        spy = mocker.spy(qml.workflow._cache_transform, "_transform")
 
         def cost(a, cachesize):
             with qml.queuing.AnnotatedQueue() as q:
@@ -209,7 +209,7 @@ def cost(a, cachesize):
     def test_custom_cache(self, mocker):
         """Test the use of a custom cache object"""
         dev = qml.device("default.qubit", wires=1)
-        spy = mocker.spy(qml.workflow.execution._cache_transform, "_transform")
+        spy = mocker.spy(qml.workflow._cache_transform, "_transform")
 
         def cost(a, cache):
             with qml.queuing.AnnotatedQueue() as q:
@@ -236,7 +236,7 @@ def cost(a, cache):
     def test_custom_cache_multiple(self, mocker):
         """Test the use of a custom cache object with multiple tapes"""
        dev = qml.device("default.qubit", wires=1)
-        spy = mocker.spy(qml.workflow.execution._cache_transform, "_transform")
+        spy = mocker.spy(qml.workflow._cache_transform, "_transform")
 
         a = jax.numpy.array(0.1)
         b = jax.numpy.array(0.2)
2 changes: 1 addition & 1 deletion tests/logging/test_logging_autograd.py
@@ -80,7 +80,7 @@ def circuit(params):
             ["Calling <construct(self=<QNode: device='<default.qubit device"],
         ),
         (
-            "pennylane.workflow.resolve_diff_method",
+            "pennylane.workflow._resolve_diff_method",
             ["Calling <_resolve_diff_method("],
         ),
         (
2 changes: 1 addition & 1 deletion tests/workflow/test_cache_transform.py
@@ -22,7 +22,7 @@
 
 import pennylane as qml
 from pennylane.tape import QuantumScript
-from pennylane.workflow.execution import _cache_transform
+from pennylane.workflow import _cache_transform
 
 
 @pytest.fixture
