Make qml.math.jax_argnums_to_tape_trainable private #6609

Merged (6 commits) on Nov 20, 2024
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -75,6 +75,10 @@

<h3>Breaking changes 💔</h3>

* `qml.math.jax_argnums_to_tape_trainable` has been moved and made private to avoid a QNode
  dependency in the math module.
  [(#6609)](https://github.com/PennyLaneAI/pennylane/pull/6609)

* Gradient transforms are now applied after the user's transform program.
[(#6590)](https://github.com/PennyLaneAI/pennylane/pull/6590)

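As a reading aid for the breaking change above, here is a minimal sketch of how downstream code can detect the removal. The `hasattr` check is illustrative only and is not migration guidance from the PR.

```python
import pennylane as qml

# Before this PR the helper was part of the public math module and could be called as
#     qml.math.jax_argnums_to_tape_trainable(qnode, argnums, program, args, kwargs)
# After this PR it is a private function of the transform-program machinery (see the
# diff further below), so user code should not call it. The check below only detects
# that the public attribute is gone.
if not hasattr(qml.math, "jax_argnums_to_tape_trainable"):
    print("qml.math.jax_argnums_to_tape_trainable has been removed in this version")
```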
1 change: 0 additions & 1 deletion pennylane/math/__init__.py
@@ -50,7 +50,6 @@
    gammainc,
    get_trainable_indices,
    iscomplex,
    jax_argnums_to_tape_trainable,
    kron,
    matmul,
    multi_dispatch,
36 changes: 0 additions & 36 deletions pennylane/math/multi_dispatch.py
@@ -23,8 +23,6 @@
from autoray import numpy as np
from numpy import ndarray

import pennylane as qml

from . import single_dispatch # pylint:disable=unused-import
from .utils import cast, cast_like, get_interface, requires_grad

@@ -1006,40 +1004,6 @@ def detach(tensor, like=None):
    return tensor


def jax_argnums_to_tape_trainable(qnode, argnums, program, args, kwargs):
    """This function gets the tape parameters from the QNode construction given some argnums (JAX only).

    The tape parameters coming from `argnums` are wrapped in `JVPTracer`s; this imitates JAX's
    behaviour in order to mark the trainable parameters.

    Args:
        qnode (qml.QNode): the quantum node.
        argnums (int, list[int]): the parameters to set as trainable (at the QNode level).
        program (qml.transforms.core.TransformProgram): the transform program to apply to the tape.
        args (tuple): the positional arguments passed to the QNode.
        kwargs (dict): the keyword arguments passed to the QNode.

    Returns:
        tuple[list[float, jax.JVPTracer]]: the parameters of each transformed tape, where the
        trainable ones are `JVPTracer`s.
    """
    import jax

    with jax.core.new_main(jax.interpreters.ad.JVPTrace) as main:
        trace = jax.interpreters.ad.JVPTrace(main, 0)

        # Wrap only the arguments selected by argnums in JVP tracers (with zero tangents).
        args_jvp = [
            (
                jax.interpreters.ad.JVPTracer(trace, arg, jax.numpy.zeros(arg.shape))
                if i in argnums
                else arg
            )
            for i, arg in enumerate(args)
        ]

        # Build the tape at level=0 (no transforms applied) and run the transform program on it.
        tape = qml.workflow.construct_tape(qnode, level=0)(*args_jvp, **kwargs)
        tapes, _ = program((tape,))
        del trace
    return tuple(tape.get_parameters(trainable_only=False) for tape in tapes)


@multi_dispatch(tensor_list=[1])
def set_index(array, idx, val, like=None):
"""Set the value at a specified index in an array.
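For context on the docstring of the helper being moved, the following is a minimal standalone sketch of the JAX behaviour it imitates, using the public `jax.grad` API instead of the internal `new_main`/`JVPTrace` machinery shown in the diffs; the exact tracer types printed can vary across JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Show which arguments JAX wrapped in tracers for differentiation.
    print("x traced:", isinstance(x, jax.core.Tracer), "| y traced:", isinstance(y, jax.core.Tracer))
    return jnp.sin(x) * y


# argnums=0 differentiates with respect to x only: inside f, x becomes a tracer while y
# stays a concrete array. The moved helper reproduces this marking for QNode arguments so
# that the resulting tape parameters can later be classified as trainable or not.
jax.grad(f, argnums=0)(jnp.array(0.3), jnp.array(2.0))  # expected: x traced: True | y traced: False
```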
36 changes: 35 additions & 1 deletion pennylane/transforms/core/transform_program.py
@@ -25,6 +25,40 @@
from .transform_dispatcher import TransformContainer, TransformDispatcher, TransformError


def _jax_argnums_to_tape_trainable(qnode, argnums, program, args, kwargs):
    """This function gets the tape parameters from the QNode construction given some argnums (JAX only).

    The tape parameters coming from `argnums` are wrapped in `JVPTracer`s; this imitates JAX's
    behaviour in order to mark the trainable parameters.

    Args:
        qnode (qml.QNode): the quantum node.
        argnums (int, list[int]): the parameters to set as trainable (at the QNode level).
        program (qml.transforms.core.TransformProgram): the transform program to apply to the tape.
        args (tuple): the positional arguments passed to the QNode.
        kwargs (dict): the keyword arguments passed to the QNode.

    Returns:
        tuple[list[float, jax.JVPTracer]]: the parameters of each transformed tape, where the
        trainable ones are `JVPTracer`s.
    """
    import jax  # pylint: disable=import-outside-toplevel

    with jax.core.new_main(jax.interpreters.ad.JVPTrace) as main:
        trace = jax.interpreters.ad.JVPTrace(main, 0)

        # Wrap only the arguments selected by argnums in JVP tracers (with zero tangents).
        args_jvp = [
            (
                jax.interpreters.ad.JVPTracer(trace, arg, jax.numpy.zeros(arg.shape))
                if i in argnums
                else arg
            )
            for i, arg in enumerate(args)
        ]

        # Build the tape at level=0 (no transforms applied) and run the transform program on it.
        tape = qml.workflow.construct_tape(qnode, level=0)(*args_jvp, **kwargs)
        tapes, _ = program((tape,))
        del trace
    return tuple(tape.get_parameters(trainable_only=False) for tape in tapes)


def _batch_postprocessing(
    results: ResultBatch, individual_fns: list[PostprocessingFn], slices: list[slice]
) -> ResultBatch:
@@ -477,7 +511,7 @@ def _set_all_argnums(self, qnode, args, kwargs, argnums):
            argnums = [0] if qnode.interface in ["jax", "jax-jit"] and argnums is None else argnums
            # pylint: disable=protected-access
            if (transform._use_argnum or transform.classical_cotransform) and argnums:
                params = qml.math.jax_argnums_to_tape_trainable(
                params = _jax_argnums_to_tape_trainable(
                    qnode, argnums, TransformProgram(self[0:index]), args, kwargs
                )
                argnums_list.append([qml.math.get_trainable_indices(param) for param in params])
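Finally, a minimal sketch of the two PennyLane calls the private helper relies on, `qml.workflow.construct_tape(qnode, level=0)` and `tape.get_parameters(trainable_only=False)`. The device and circuit are illustrative, and the printed values are the expected ones rather than output taken from the PR.

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)


@qml.qnode(dev)
def circuit(x, y):
    qml.RX(x, wires=0)
    qml.RY(y, wires=0)
    return qml.expval(qml.PauliZ(0))


# level=0 builds the tape from the raw QNode construction, before any transforms run;
# this is exactly how the private helper captures the tape parameters.
tape = qml.workflow.construct_tape(circuit, level=0)(0.1, 0.2)
print(tape.get_parameters(trainable_only=False))  # expected: [0.1, 0.2]
```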