Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defer calculating classical cotransform and argnums till needed #6716

Merged
merged 287 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
287 commits
Select commit Hold shift + click to select a range
40dead4
Merge branch 'master' into add-resolve-exec-config
andrijapau Nov 18, 2024
6d7753b
update get_gradient_fn login
andrijapau Nov 18, 2024
a3a7cf5
refactor get_gradient_fn to resolve_diff_method
andrijapau Nov 18, 2024
be604d7
add new resolve_diff_method files
andrijapau Nov 18, 2024
2a6bfc8
Merge branch 'master' into add-resolve-exec-config
andrijapau Nov 18, 2024
7fe5535
fix language in error message
andrijapau Nov 18, 2024
34fc29c
Update pennylane/workflow/resolve_diff_method.py
andrijapau Nov 18, 2024
8637526
only filter for lightning.qubit
andrijapau Nov 18, 2024
c7a62bd
Merge branch 'add-resolve-exec-config' of github.com:PennyLaneAI/penn…
andrijapau Nov 18, 2024
d0feae8
return to filter all lightning
andrijapau Nov 18, 2024
79ab797
fix errors
andrijapau Nov 19, 2024
dc53839
Merge branch 'master' into add-resolve-exec-config
andrijapau Nov 19, 2024
1617c1d
Merge branch 'master' into add-resolve-exec-config
andrijapau Nov 20, 2024
9f9161a
initial work
andrijapau Nov 20, 2024
34afc7c
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 20, 2024
16cab30
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 20, 2024
c908a20
fix jax jit tests
andrijapau Nov 20, 2024
9dfe04e
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 20, 2024
5bbf619
update signature of function to fix tests
andrijapau Nov 20, 2024
b402cce
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 20, 2024
90c04ae
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 21, 2024
d1848cf
fix test_cache_transform.py
andrijapau Nov 21, 2024
cb64e49
Update workflow/_init__.py
andrijapau Nov 21, 2024
d42a114
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 22, 2024
e88da49
move setup transform to a different file
andrijapau Nov 22, 2024
6a8ce29
resolve circular imports
andrijapau Nov 22, 2024
15ce596
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 22, 2024
bce6077
add new test files
andrijapau Nov 22, 2024
2e0aa0f
add new test files
andrijapau Nov 22, 2024
ca23125
repair test
andrijapau Nov 22, 2024
4972655
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 25, 2024
7e59c6d
Update _setup_transform_program as per Christina's comment
andrijapau Nov 25, 2024
10e0a57
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 25, 2024
4364105
FIX IMPORTS
andrijapau Nov 25, 2024
79f647f
Add tests for _setup_transform_program
andrijapau Nov 25, 2024
31c3326
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 25, 2024
6614969
Add tests for _resolve_execution_config
andrijapau Nov 25, 2024
23739dc
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 25, 2024
6254931
add jax mark so tests dont fail
andrijapau Nov 25, 2024
ae7c7de
import _setup_transform_program from top
andrijapau Nov 25, 2024
b70167f
Address Christina's comments
andrijapau Nov 26, 2024
283bf94
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 26, 2024
aad3d07
oops forgot to update tests
andrijapau Nov 26, 2024
952e103
fix tests
andrijapau Nov 26, 2024
899b174
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 26, 2024
51b8143
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 26, 2024
5fd0f5a
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 27, 2024
02db9b1
remove protected access
andrijapau Nov 27, 2024
73271d0
move _setup_transform_program to top-level import
andrijapau Nov 27, 2024
18e535d
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 27, 2024
89bf09c
initial attempt
andrijapau Nov 27, 2024
7aff28e
Merge branch 'master' into add-interface-enum
andrijapau Nov 27, 2024
c3d1054
update import path
andrijapau Nov 27, 2024
bfa708a
update import paths
andrijapau Nov 27, 2024
e01934d
remove unneccesary error check
andrijapau Nov 27, 2024
eab6e13
fix qutrit simulate
andrijapau Nov 27, 2024
753c4d6
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 27, 2024
bb2ec01
Merge branch 'add-setup-transform-program' into add-interface-enum
andrijapau Nov 27, 2024
3a67d08
oops, I put the wrong argument
andrijapau Nov 27, 2024
3823a26
add Lillian's suggestion for null_postprocessing
andrijapau Nov 27, 2024
63b244c
add Lillian's suggestion for _setup_transform_program
andrijapau Nov 27, 2024
b581d64
Merge branch 'add-setup-transform-program' into add-interface-enum
andrijapau Nov 27, 2024
2c529d3
oops, forgot to remove a function that wasn't being used
andrijapau Nov 27, 2024
7fb784c
Merge branch 'add-setup-transform-program' into add-interface-enum
andrijapau Nov 27, 2024
c320875
fix mocking in test_resolve_execution_config.py
andrijapau Nov 27, 2024
1f4d069
fix mocking in test_resolve_execution_config.py
andrijapau Nov 27, 2024
621496f
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 27, 2024
47307ea
fix test that had the wrong regex
andrijapau Nov 27, 2024
b85bccf
attempt fix for tests
andrijapau Nov 27, 2024
1039e14
Merge branch 'master' into add-interface-enum
andrijapau Nov 27, 2024
51796a7
remove mocks from test_setup_transform_program.py
andrijapau Nov 28, 2024
d917691
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 28, 2024
7b6bcb6
improve testing
andrijapau Nov 28, 2024
7687f65
improve testing
andrijapau Nov 28, 2024
32a5651
add doc string for _setup_transform_programs
andrijapau Nov 28, 2024
216ba9c
fix circular import :(
andrijapau Nov 28, 2024
0b7d5ac
improve test_setup_transform_program
andrijapau Nov 28, 2024
d05222b
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 28, 2024
a0ce8c1
improve test_setup_transform_program
andrijapau Nov 28, 2024
0db66e8
improve test_setup_transform_program
andrijapau Nov 28, 2024
822371b
improve test_setup_transform_program
andrijapau Nov 28, 2024
fecb884
fix tests
andrijapau Nov 28, 2024
a456608
fix tests
andrijapau Nov 28, 2024
eb74414
fix documentation tests and ml tests
andrijapau Nov 28, 2024
db1daf9
remove one line fixtures
andrijapau Nov 28, 2024
9aa368f
initial movement of files :ok_hand:
andrijapau Nov 28, 2024
88a2c98
add new file
andrijapau Nov 28, 2024
f5556c4
fix tests and update changelog
andrijapau Nov 28, 2024
67b4da6
removed protected-access pylint disable
andrijapau Nov 28, 2024
cc98eba
shuffle things around between utils and interface
andrijapau Nov 28, 2024
f1049d1
fix imports
andrijapau Nov 28, 2024
1210d12
fix imports
andrijapau Nov 28, 2024
c8d8ffd
improve existing code
andrijapau Nov 29, 2024
e7fb3ef
Merge branch 'master' into move-interface-logic
andrijapau Nov 29, 2024
7d1c01d
change file name
andrijapau Nov 29, 2024
c493260
docs: Update interfaces.rst
andrijapau Nov 29, 2024
a708ed9
Merge branch 'add-setup-transform-program' into add-interface-enum
andrijapau Nov 29, 2024
a6fa61c
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 29, 2024
60a14b8
Merge branch 'move-interface-logic' into add-interface-enum
andrijapau Nov 29, 2024
c1d22eb
doc: Update changelog-dev.md
andrijapau Nov 29, 2024
34a863e
doc: Update changelog-dev.md
andrijapau Nov 29, 2024
fe5af9b
Merge branch 'master' into add-setup-transform-program
andrijapau Nov 29, 2024
4051bad
Merge branch 'master' into move-interface-logic
andrijapau Nov 29, 2024
c6c9d96
Merge branch 'add-setup-transform-program' into move-interface-logic
andrijapau Nov 29, 2024
0d66855
fix: Update resolution.py
andrijapau Nov 29, 2024
7e887ef
Merge branch 'master' into move-interface-logic
andrijapau Nov 29, 2024
f147ece
fix: Update resolution.py to make isort happy
andrijapau Nov 29, 2024
de7345c
Merge branch 'move-interface-logic' into add-interface-enum
andrijapau Nov 29, 2024
9ff483b
fix: Update interface_utils.py to make formatting happy
andrijapau Nov 29, 2024
2ea7f04
doc: Update interface_utils.py doc strings
andrijapau Nov 29, 2024
8be289e
fix: Update qnode.py and execution.py
andrijapau Nov 29, 2024
e16ff90
refactor: Update test_execution_config.py
andrijapau Nov 29, 2024
109fde3
fix: Update interface_utils.py to remove circular import warnings
andrijapau Nov 29, 2024
af6af5c
fix: :ambulance: test_qnode_legacy.py change error type
andrijapau Nov 29, 2024
633106a
feat: Add run.py
andrijapau Dec 2, 2024
6cf910d
feat: Add run.py
andrijapau Dec 2, 2024
f54bf33
feat: Add test_run.py file
andrijapau Dec 2, 2024
c9fa462
fix: Update execution.py config handling
andrijapau Dec 2, 2024
23f4f8c
fix: use config for max_diff info
andrijapau Dec 2, 2024
945ae5c
fix: get cache info from inner transform program
andrijapau Dec 2, 2024
4fd30c5
fix: remove jpc_interfaces from codebase
andrijapau Dec 2, 2024
4c30749
Merge branch 'master' into move-interface-logic
andrijapau Dec 2, 2024
89ccde9
fix: remove jpc_interfaces and also refactor diff_method
andrijapau Dec 2, 2024
d1f4c5f
refactor: move pennylane dep fxns back to workflow
andrijapau Dec 2, 2024
5c907f1
fix: update pylint disable
andrijapau Dec 2, 2024
71a3ac2
fix: Update __init__.py to import _resolve_interface
andrijapau Dec 2, 2024
26bdbca
fix: Update execution.py for tf tests
andrijapau Dec 2, 2024
dc0c6a3
refactor: Remove print statement in run.py
andrijapau Dec 2, 2024
b9e5413
refactor: Update spacing in run.py
andrijapau Dec 2, 2024
85ee629
feat: Add test_resolve_interface.py
andrijapau Dec 3, 2024
8a66e26
Merge branch 'master' into move-interface-logic
andrijapau Dec 3, 2024
3cefd06
fix: Remove duplicate _resolve_interface function in resolution.py
andrijapau Dec 3, 2024
b94179b
Merge branch 'move-interface-logic' into add-interface-enum
andrijapau Dec 3, 2024
0fb3bb6
fix: Update execution.py to use Interface.TF_AUTOGRAPH
andrijapau Dec 3, 2024
567b10d
fix: Update simulate.py for devices to use canonical interface names
andrijapau Dec 3, 2024
f9ffb4c
fix: Revert measurements.py and conversion.py
andrijapau Dec 3, 2024
453d980
doc: Update documentation in execution.py and resolution.py
andrijapau Dec 3, 2024
6309611
doc: Update doc string in run.py
andrijapau Dec 3, 2024
ac4bf97
fix: Make get_canonical_interface_name a public helper function
andrijapau Dec 3, 2024
30d8861
Merge branch 'master' into move-interface-logic
andrijapau Dec 3, 2024
7f63f88
fix: Get rid of INTERFACE_MAP in codebase and use get_canonical_inter…
andrijapau Dec 3, 2024
99b33ba
fix: Fix handling of INTERFACE_MAP defaulting to None before
andrijapau Dec 3, 2024
9e52490
fix: Update tests in test_qnode*.py to correct error type
andrijapau Dec 3, 2024
37877c0
fix: Update qnode.py to handle case where INTERFACE_MAP defaulted to …
andrijapau Dec 3, 2024
af561c5
Merge branch 'master' into move-interface-logic
andrijapau Dec 3, 2024
b17aebb
fix: Update qnode.py to raise ValueError instead of QFE
andrijapau Dec 4, 2024
a7cfcb6
Merge branch 'master' into move-interface-logic
andrijapau Dec 4, 2024
7cd52e5
fix: Update test_qnode.py to catch ValueError instead of QFE
andrijapau Dec 4, 2024
c412b8d
Merge branch 'move-interface-logic' into add-interface-enum
andrijapau Dec 4, 2024
cd2d853
fix: Update qnode.py _impl_call logic
andrijapau Dec 4, 2024
7ad32e5
fix: Update qnode.py _impl_call logic
andrijapau Dec 4, 2024
9ba89e9
fix: Update test_logging_autograd.py
andrijapau Dec 4, 2024
6a994b7
feat: Add autograd test to test_resolve_interface.py
andrijapau Dec 4, 2024
0e53dac
Merge branch 'master' into move-interface-logic
andrijapau Dec 4, 2024
398bcab
Merge branch 'move-interface-logic' into add-interface-enum
andrijapau Dec 4, 2024
3cec1b5
Merge branch 'add-interface-enum' into add-dev-run-fxn
andrijapau Dec 4, 2024
c08e1d4
fix: Update test_resolve_interface.py with Interface ENUM
andrijapau Dec 4, 2024
ab0e917
fix: Update path of _resolve_interface in execution.py
andrijapau Dec 4, 2024
525c4fd
Trigger CI
andrijapau Dec 4, 2024
7c4a99a
Trigger CI
andrijapau Dec 4, 2024
36e8430
Trigger CI
andrijapau Dec 4, 2024
0ac33c8
Merge branch 'master' into add-interface-enum
andrijapau Dec 4, 2024
ec454f6
fix: Update _get_ml_boundary_execute doc string
andrijapau Dec 4, 2024
2ed1fcd
feat: Update _get_ml_boundary_execute to use match case instead of if…
andrijapau Dec 4, 2024
48750b2
fix: Change interface=None from test_resolve_mcm_config.py
andrijapau Dec 4, 2024
f459d64
feat: Try using different sphinx-action branch rather than master
andrijapau Dec 4, 2024
addc46c
feat: Try using different sphinx-action branch rather than master
andrijapau Dec 4, 2024
b9982f7
Revert "feat: Try using different sphinx-action branch rather than ma…
andrijapau Dec 4, 2024
0658f4d
Revert "feat: Try using different sphinx-action branch rather than ma…
andrijapau Dec 4, 2024
64e5582
Revert "feat: Update _get_ml_boundary_execute to use match case inste…
andrijapau Dec 4, 2024
c09af25
fix: Update _get_ml_boundary_execute redundancy for JAX_JIT case
andrijapau Dec 4, 2024
fb04b18
feat: Try removing internal interface maps like INTERFACE_TO_LIKE
andrijapau Dec 4, 2024
c68d7cb
Merge branch 'master' into add-interface-enum
andrijapau Dec 4, 2024
35399de
fix: Update simulate.py for qubit and qutrit_mixed to fix interface h…
andrijapau Dec 5, 2024
1de05d4
fix: Update null_qubit.py to use Interface enum
andrijapau Dec 5, 2024
a7c2832
fix: Update simulate.py for qubit and qutrit_mixed to use get_canonic…
andrijapau Dec 5, 2024
cae716b
fix: Revert null_qubit.py logic
andrijapau Dec 5, 2024
68fa758
feat: Add get_like() method to Interface enum
andrijapau Dec 5, 2024
8a8a25a
Merge branch 'master' into add-interface-enum
andrijapau Dec 5, 2024
ea45c86
feat: Update simulate.py to use get_like()
andrijapau Dec 5, 2024
5711839
doc: Fix doc string in Interface enum
andrijapau Dec 5, 2024
8909474
fix: Update null_qubit.py to use get_like()
andrijapau Dec 5, 2024
bc51a14
Merge branch 'master' into add-interface-enum
andrijapau Dec 5, 2024
dd003e0
feat: Add tests to test_run.py
andrijapau Dec 5, 2024
76a57d3
feat: Update call signature to _get_ml_boundary_execute
andrijapau Dec 5, 2024
43eaabd
style: Clean-up cache logic in run.py
andrijapau Dec 5, 2024
24a58b7
feat: Add _construct_ml_execution_pipeline to run.py
andrijapau Dec 5, 2024
584db7f
fix: Add torch mark to test in test_run.py
andrijapau Dec 5, 2024
e0f90cf
doc: Update doc string for _consrtruct_ml_execution_pipeline helper
andrijapau Dec 5, 2024
83a3d71
fix: Update _construct_ml_execution_pipeline handling of execute_fn a…
andrijapau Dec 5, 2024
2ce22b4
fix: Update interface_utils.py to address Lillian's comment
andrijapau Dec 5, 2024
d157970
Merge branch 'master' into add-interface-enum
andrijapau Dec 5, 2024
6654db9
Merge branch 'add-interface-enum' into add-dev-run-fxn
andrijapau Dec 5, 2024
e41e7b3
doc: Add changelog-dev.md entry
andrijapau Dec 5, 2024
9bf8887
feat: Move pure callback JAX interface update after the _construct_ml…
andrijapau Dec 5, 2024
45f195f
fix: Move pure callback JAX interface update after the _construct_ml_…
andrijapau Dec 5, 2024
6b4c0ea
feat: Split logic for execution pipeline for tf-autograph and jpc int…
andrijapau Dec 5, 2024
609c569
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 5, 2024
76b0fe6
feat: Update test_run.py with more tests
andrijapau Dec 6, 2024
d8ae164
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 6, 2024
66b37e4
feat: Remove unnecessary pylint disable from execution.py
andrijapau Dec 6, 2024
b48617f
doc: Update docstrings in run.py
andrijapau Dec 6, 2024
b2874ce
feat: Improve test_run.py with base interface tests
andrijapau Dec 10, 2024
818098f
feat: Add more testing
andrijapau Dec 11, 2024
4a6ea34
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 11, 2024
d2628ad
feat: Add TestJaxJitRun to test_run.py
andrijapau Dec 11, 2024
5fc2135
feat: add TestTFAutograhRun
andrijapau Dec 11, 2024
add9f25
feat: Add test to check for grad_on_execution ValueError
andrijapau Dec 11, 2024
e7036b9
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 11, 2024
284fc80
feat: Move tests to their own files and add conftest.py to store test…
andrijapau Dec 11, 2024
39a58b5
fix: Rename the test files for CI is happy
andrijapau Dec 11, 2024
b97fdb1
fix: Not sure why test_torch_run.py wasn't updated
andrijapau Dec 11, 2024
6a56540
fix: Add pylint disable for conftest import
andrijapau Dec 11, 2024
4793374
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 11, 2024
db3f375
fix: Clean-up type hinting and if-else logic in execution pipeline ge…
andrijapau Dec 11, 2024
17b30a2
fix: Remove *__ from inner_execute construction
andrijapau Dec 11, 2024
37989f4
fix: Remove getattr from test_autograd_run.py
andrijapau Dec 11, 2024
8e22211
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 11, 2024
565c6e9
fix: Update resolved_execution_config -> config for brevity in run() fxn
andrijapau Dec 11, 2024
6a2be9f
fix: Join together logic for each type of pipeline
andrijapau Dec 11, 2024
39e5c0a
fix: Update run.py to fix call back errors
andrijapau Dec 11, 2024
f57a674
fix: Remove print lol
andrijapau Dec 11, 2024
8174489
refactor: Re-work tf-autograph logic in run.py so it's easier to prune
andrijapau Dec 11, 2024
ea98b68
Update pennylane/workflow/run.py
andrijapau Dec 12, 2024
2bba865
Update tests/workflow/interfaces/run/test_jax_jit_run.py
andrijapau Dec 12, 2024
fc875c8
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 12, 2024
1af33b9
fix: use jax.jit(...) in test_jax_jit.py
andrijapau Dec 12, 2024
41ccad5
fix: typo in TODO comment lol
andrijapau Dec 12, 2024
1370399
fix: xfail adjoint with tf-autograph
andrijapau Dec 12, 2024
2598cd3
Merge branch 'master' into add-dev-run-fxn
andrijapau Dec 12, 2024
f2b1ca2
first attempt
albi3ro Dec 13, 2024
6ed7530
tidy up handling of classical part of transform program
albi3ro Dec 13, 2024
f637caa
Merge branch 'add-dev-run-fxn' into transform-program-caching
albi3ro Dec 13, 2024
e2ae226
move some stuff in qnode around
albi3ro Dec 13, 2024
3d653e3
merging
albi3ro Dec 13, 2024
0b330f1
fix mistake
albi3ro Dec 13, 2024
0b9296d
make construct return the tape
albi3ro Dec 13, 2024
e15ed26
fixing failures
albi3ro Dec 13, 2024
77ecf45
fixing failiures
albi3ro Dec 13, 2024
98c1199
Merge branch 'master' into transform-program-caching
albi3ro Dec 13, 2024
a8896f2
adding testing
albi3ro Dec 31, 2024
1ab70b2
fix test
albi3ro Dec 31, 2024
478897a
final test for coverage
albi3ro Dec 31, 2024
c17680f
Merge branch 'master' into transform-program-caching
albi3ro Dec 31, 2024
7a48ab9
test with rx graph
albi3ro Dec 31, 2024
43dab37
merging
albi3ro Dec 31, 2024
23aa398
Merge branch 'master' into transform-program-caching
albi3ro Jan 2, 2025
513ec49
minor updates from reviews
albi3ro Jan 3, 2025
d37cbd3
Apply suggestions from code review
albi3ro Jan 3, 2025
a65dcea
Merge branch 'master' into transform-program-caching
albi3ro Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,11 @@ such as `shots`, `rng` and `prng_key`.
* Moved all interface handling logic to `interface_utils.py` in the `qml.math` module.
[(#6649)](https://github.com/PennyLaneAI/pennylane/pull/6649)

* `qml.execute` can now be used with `diff_method="best"`.
Classical cotransform information is now handled lazily by the workflow. Gradient method
validation and program setup is now handled inside of `qml.execute`, instead of in `QNode`.
[(#6716)](https://github.com/PennyLaneAI/pennylane/pull/6716)

* Added PyTree support for measurements in a circuit.
[(#6378)](https://github.com/PennyLaneAI/pennylane/pull/6378)

Expand Down
2 changes: 1 addition & 1 deletion pennylane/math/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def get_deep_interface(value):
return _get_interface_of_single_tensor(itr)


def get_canonical_interface_name(user_input: Union[str, Interface]) -> Interface:
def get_canonical_interface_name(user_input: Union[str, Interface, None]) -> Interface:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Helper function to get the canonical interface name.

Args:
Expand Down
112 changes: 71 additions & 41 deletions pennylane/transforms/core/transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
This module contains the ``TransformProgram`` class.
"""
from collections import namedtuple
from collections.abc import Sequence
from functools import partial
from typing import Optional, overload
Expand All @@ -24,6 +25,20 @@

from .transform_dispatcher import TransformContainer, TransformDispatcher, TransformError

CotransfromCache = namedtuple("CotransformCache", ("qnode", "args", "kwargs"))
albi3ro marked this conversation as resolved.
Show resolved Hide resolved


def _get_interface(qnode, args, kwargs):
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
if qnode.interface == "auto":
interface = qml.math.get_interface(*args, *list(kwargs.values()))
try:
interface = qml.math.get_canonical_interface_name(interface).value
except ValueError:
interface = "numpy"
else:
interface = qnode.interface
return interface


def _numpy_jac(*_, **__) -> qml.typing.TensorLike:
raise qml.QuantumFunctionError("No trainable parameters.")
Expand Down Expand Up @@ -132,7 +147,7 @@ def _batch_postprocessing(

Keyword Args:
individual_fns (List[Callable]): postprocessing functions converting a batch of results into a single result
corresponding to only a single :class:`~.QuantumTape`.
corresponding to only a single :class:`~.QuantumTape`.
slices (List[slice]): the indices for the results that correspond to each individual post processing function.

>>> results = (1.0, 2.0, 3.0, 4.0)
Expand Down Expand Up @@ -200,6 +215,12 @@ class TransformProgram:

The order of execution is the order in the list containing the containers.

Args:
initial_program (Optional[Sequence[TransformContainer]]): A sequence of transforms to
initialize the program with
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
cotransform_cache (Optional[CotransformCache]): A named tuple containing the ``qnode``,
``args``, and ``kwargs`` required to compute classical cotransforms.

The main case where one would have to interact directly with a transform program is when developing a
:class:`Device <pennylane.devices.Device>`. In this case, the pre-processing method of a device
returns a transform program. You should directly refer to the device API documentation for more details.
Expand Down Expand Up @@ -243,10 +264,13 @@ class TransformProgram:

"""

def __init__(self, initial_program: Optional[Sequence] = None):
def __init__(
self,
initial_program: Optional[Sequence[TransformContainer]] = None,
cotransform_cache: Optional[CotransfromCache] = None,
):
self._transform_program = list(initial_program) if initial_program else []
self._classical_jacobians = None
self._argnums = None
self.cotransform_cache = cotransform_cache

def __iter__(self):
"""list[TransformContainer]: Return an iterator to the underlying transform program."""
Expand All @@ -270,15 +294,22 @@ def __getitem__(self, idx):
def __bool__(self) -> bool:
return bool(self._transform_program)

def __add__(self, other):
def __add__(self, other: "TransformProgram") -> "TransformProgram":
if self.has_final_transform and other.has_final_transform:
raise TransformError("The transform program already has a terminal transform.")

transforms = self._transform_program + other._transform_program
if self.has_final_transform:
transforms.append(transforms.pop(len(self) - 1))

return TransformProgram(transforms)
cotransform_cache = None
if self.cotransform_cache:
if other.cotransform_cache:
raise ValueError("Cannot add two transform programs with cotransform caches.")
cotransform_cache = self.cotransform_cache
elif other.cotransform_cache:
cotransform_cache = other.cotransform_cache
return TransformProgram(transforms, cotransform_cache=cotransform_cache)

def __repr__(self):
"""The string representation of the transform program class."""
Expand Down Expand Up @@ -444,17 +475,9 @@ def has_classical_cotransform(self) -> bool:

def set_classical_component(self, qnode, args, kwargs):
"""Set the classical jacobians and argnums if the transform is hybrid with a classical cotransform."""
if not self.has_classical_cotransform():
return
hybrid = self[-1].kwargs.pop("hybrid", True) # pylint: disable=no-member

if hybrid:
argnums = self[-1].kwargs.pop("argnums", None) # pylint: disable=no-member
self._classical_jacobians = [
self._get_classical_jacobian(index, qnode, args, kwargs, argnums)
for index, _ in enumerate(self)
]
self._set_all_argnums(qnode, args, kwargs, argnums)
# pylint: disable=no-member
if self.has_classical_cotransform() and self[-1].kwargs.get("hybrid", True):
self.cotransform_cache = CotransfromCache(qnode, args, kwargs)

def prune_dynamic_transform(self, type_to_keep=1):
"""Ensures that only one or none ``dynamic_one_shot`` is applied.
Expand Down Expand Up @@ -484,16 +507,20 @@ def prune_dynamic_transform(self, type_to_keep=1):
return found

# pylint: disable=too-many-arguments, too-many-positional-arguments
def _get_classical_jacobian(self, index: int, qnode, args, kwargs, argnums):
if not self[index].classical_cotransform:
def _get_classical_jacobian(self, index: int):
if self.cotransform_cache is None or not self[index].classical_cotransform:
return None
if qnode.interface == "jax" and "argnum" in self[index].kwargs:
argnums = self[-1].kwargs.get("argnums", None) # pylint: disable=no-member
qnode, args, kwargs = self.cotransform_cache

interface = _get_interface(qnode, args, kwargs)
if interface == "jax" and "argnum" in self[index].kwargs:
raise qml.QuantumFunctionError(
"argnum does not work with the Jax interface. You should use argnums instead."
)

f = partial(_classical_preprocessing, qnode, self[:index])
classical_jacobian = _jac_map[qnode.interface](f, argnums, *args, **kwargs)
classical_jacobian = _jac_map[interface](f, argnums, *args, **kwargs)

# autograd and tf cant handle pytrees, so need to unsqueeze the squeezing
# done in _classical_preprocessing
Expand All @@ -504,22 +531,22 @@ def _get_classical_jacobian(self, index: int, qnode, args, kwargs, argnums):
classical_jacobian = [classical_jacobian]
return classical_jacobian

def _set_all_argnums(self, qnode, args, kwargs, argnums):
def _get_argnums(self, index):
"""It can be used inside the QNode to set all argnums (tape level) using argnums from the argnums at the QNode
level.
"""

argnums_list = []
for index, transform in enumerate(self):
argnums = [0] if qnode.interface in ["jax", "jax-jit"] and argnums is None else argnums
# pylint: disable=protected-access
if (transform._use_argnum or transform.classical_cotransform) and argnums:
params = _jax_argnums_to_tape_trainable(qnode, argnums, self[:index], args, kwargs)
argnums_list.append([qml.math.get_trainable_indices(param) for param in params])
else:
argnums_list.append(None)

self._argnums = argnums_list
if self.cotransform_cache is None:
return None
qnode, args, kwargs = self.cotransform_cache
interface = _get_interface(qnode, args, kwargs)
transform = self[index]
argnums = self[-1].kwargs.get("argnums", None) # pylint: disable=no-member
argnums = [0] if interface in ["jax", "jax-jit"] and argnums is None else argnums
# pylint: disable=protected-access
if (transform._use_argnum or transform.classical_cotransform) and argnums:
params = _jax_argnums_to_tape_trainable(qnode, argnums, self[:index], args, kwargs)
return [qml.math.get_trainable_indices(param) for param in params]
return None

def __call__(
self, tapes: QuantumScriptBatch
Expand All @@ -531,7 +558,9 @@ def __call__(

for i, transform_container in enumerate(self):
transform, targs, tkwargs, cotransform, _, _, _ = transform_container

tkwargs = {
key: value for key, value in tkwargs.items() if key not in {"argnums", "hybrid"}
}
execution_tapes = []
fns = []
slices = []
Expand All @@ -541,9 +570,11 @@ def __call__(

start = 0
start_classical = 0
classical_jacobians = self._get_classical_jacobian(i)
argnums = self._get_argnums(i)
for j, tape in enumerate(tapes):
if self._argnums is not None and self._argnums[i] is not None:
tape.trainable_params = self._argnums[i][j]
if argnums is not None:
tape.trainable_params = argnums[j]
new_tapes, fn = transform(tape, *targs, **tkwargs)
execution_tapes.extend(new_tapes)

Expand All @@ -552,14 +583,14 @@ def __call__(
slices.append(slice(start, end))
start = end

if cotransform and self._classical_jacobians:
if cotransform and classical_jacobians:
classical_fns.append(
partial(cotransform, cjac=self._classical_jacobians[i][j], tape=tape)
partial(cotransform, cjac=classical_jacobians[j], tape=tape)
)
slices_classical.append(slice(start_classical, start_classical + 1))
start_classical += 1

if cotransform and self._classical_jacobians:
if cotransform and classical_jacobians:
batch_postprocessing_classical = partial(
_batch_postprocessing, individual_fns=classical_fns, slices=slices_classical
)
Expand All @@ -581,5 +612,4 @@ def __call__(
postprocessing_fn.__doc__ = _apply_postprocessing_stack.__doc__

# Reset classical jacobians
self._classical_jacobians = []
return tuple(tapes), postprocessing_fn
4 changes: 3 additions & 1 deletion pennylane/workflow/_setup_transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def _setup_transform_program(

device_transform_program = device.preprocess_transforms(resolved_execution_config)

full_transform_program = qml.transforms.core.TransformProgram(user_transform_program)
full_transform_program = qml.transforms.core.TransformProgram(
user_transform_program, cotransform_cache=user_transform_program.cotransform_cache
)
inner_transform_program = qml.transforms.core.TransformProgram()

# Add the gradient expand to the program if necessary
Expand Down
21 changes: 8 additions & 13 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pennylane.typing import ResultBatch

from ._setup_transform_program import _setup_transform_program
from .resolution import _resolve_interface
from .resolution import _resolve_execution_config, _resolve_interface
from .run import run

logger = logging.getLogger(__name__)
Expand All @@ -39,7 +39,7 @@


# pylint: disable=too-many-arguments
def execute(

Check notice on line 42 in pennylane/workflow/execution.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/workflow/execution.py#L42

Too many positional arguments (15/5) (too-many-positional-arguments)
tapes: QuantumScriptBatch,
device: Union["qml.devices.LegacyDevice", "qml.devices.Device"],
diff_method: Optional[Union[Callable, str, qml.transforms.core.TransformDispatcher]] = None,
Expand Down Expand Up @@ -187,21 +187,13 @@
"::L".join(str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]),
)

if not tapes:
return ()

### Specifying and preprocessing variables ####

interface = _resolve_interface(interface, tapes)
# Only need to calculate derivatives with jax when we know it will be executed later.
if interface in {Interface.JAX, Interface.JAX_JIT}:
grad_on_execution = grad_on_execution if isinstance(diff_method, Callable) else False

if (
device_vjp
and isinstance(device, qml.devices.LegacyDeviceFacade)
and "lightning" not in getattr(device, "short_name", "").lower()
):
raise qml.QuantumFunctionError(
"device provided jacobian products are not compatible with the old device interface."
)

gradient_kwargs = gradient_kwargs or {}
mcm_config = mcm_config or {}
Expand All @@ -215,7 +207,9 @@
gradient_keyword_arguments=gradient_kwargs,
derivative_order=max_diff,
)
config = device.setup_execution_config(config)
config = _resolve_execution_config(
config, device, tapes, transform_program=transform_program
)

config = replace(
config,
Expand All @@ -224,6 +218,7 @@
)

if transform_program is None or inner_transform is None:
transform_program = transform_program or qml.transforms.core.TransformProgram()
transform_program, inner_transform = _setup_transform_program(
transform_program, device, config, cache, cachesize
)
Expand Down
Loading
Loading