Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] First PR for enabling dynamic decompositions with PLxPR enabled #6859

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
696acf3
E.C.
PietropaoloFrisoni Jan 20, 2025
9cdfae9
Creating an empty `DynamicDecomposeInterpreter` c;ass
PietropaoloFrisoni Jan 21, 2025
1e138fa
Sbattendo la testa contro il muro tante volte
PietropaoloFrisoni Jan 21, 2025
76c9250
Current prototype version
PietropaoloFrisoni Jan 22, 2025
c821ac9
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 22, 2025
cee6ec4
Fixing one more problem
PietropaoloFrisoni Jan 22, 2025
a208dc5
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 22, 2025
d16a9e0
Moving tests to separate file
PietropaoloFrisoni Jan 23, 2025
a8b9283
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 23, 2025
18bc43c
Pylint fixes (although premature)
PietropaoloFrisoni Jan 23, 2025
e2e8fd0
Removing reundandt tuple calls
PietropaoloFrisoni Jan 23, 2025
0abd620
Tests with dynamic wires
PietropaoloFrisoni Jan 23, 2025
1e3ffb6
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 23, 2025
1ae399b
Adding Autograph test
PietropaoloFrisoni Jan 23, 2025
9a54c3e
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 24, 2025
c5f2ae5
Removing unused parameters and adding a few tests
PietropaoloFrisoni Jan 24, 2025
497440c
Adding a few more tests
PietropaoloFrisoni Jan 24, 2025
c7da133
Removing import
PietropaoloFrisoni Jan 24, 2025
2f0417c
Pylint
PietropaoloFrisoni Jan 24, 2025
3c8bc37
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 24, 2025
e9ff110
Adding test with hyperparameters
PietropaoloFrisoni Jan 27, 2025
b9f5d03
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 27, 2025
f2437b4
Black
PietropaoloFrisoni Jan 27, 2025
8c615c5
A few more tests
PietropaoloFrisoni Jan 27, 2025
9fef95b
Changelog
PietropaoloFrisoni Jan 27, 2025
4a56150
Removing redundant operations
PietropaoloFrisoni Jan 28, 2025
e733762
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 29, 2025
9792df6
Pre-binding hyperparameters [ci skip]
PietropaoloFrisoni Jan 29, 2025
aa422b3
Removing redundant method
PietropaoloFrisoni Jan 30, 2025
b7a18cd
Pylint
PietropaoloFrisoni Jan 30, 2025
97cba03
Testing CI failures (JAX imports)
PietropaoloFrisoni Jan 30, 2025
e05963e
isort
PietropaoloFrisoni Jan 30, 2025
ec73757
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,6 @@

<h3>Improvements 🛠</h3>

* The higher order primitives in program capture can now accept inputs with abstract shapes.
[(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786)

* The `PlxprInterpreter` classes can now handle creating dynamic arrays via `jnp.ones`, `jnp.zeros`,
`jnp.arange`, and `jnp.full`.
[#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865)

* `QNode` objects now have an `update` method that allows for re-configuring settings like `diff_method`, `mcm_method`, and more. This allows for easier on-the-fly adjustments to workflows. Any arguments not specified will retain their original value.
[(#6803)](https://github.com/PennyLaneAI/pennylane/pull/6803)

Expand Down Expand Up @@ -61,6 +54,19 @@
* The requested `diff_method` is now validated when program capture is enabled.
[(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852)


<h4>Capturing and representing hybrid programs</h4>

* Implemented a new `DynamicDecomposeInterpreter` to capture decompositions of operators with control-flow instructions.
[(#6859)](https://github.com/PennyLaneAI/pennylane/pull/6859)

* The higher order primitives in program capture can now accept inputs with abstract shapes.
[(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786)

* The `PlxprInterpreter` classes can now handle creating dynamic arrays via `jnp.ones`, `jnp.zeros`,
`jnp.arange`, and `jnp.full`.
[#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865)

<h3>Breaking changes 💔</h3>

* `MultiControlledX` no longer accepts strings as control values.
Expand Down
25 changes: 25 additions & 0 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@
import warnings
from collections.abc import Hashable, Iterable
from enum import IntEnum
from functools import partial
from typing import Any, Callable, Literal, Optional, Union

import numpy as np
Expand Down Expand Up @@ -1322,6 +1323,30 @@ def decomposition(self) -> list["Operator"]:
*self.parameters, wires=self.wires, **self.hyperparameters
)

@classproperty
def _has_plxpr_decomposition(cls) -> bool:
"""Whether or not the Operator returns a defined plxpr decomposition."""

return (
cls._compute_plxpr_decomposition != Operator._compute_plxpr_decomposition
or cls._plxpr_decomposition != Operator._plxpr_decomposition
)

def _plxpr_decomposition(self) -> "jax.core.Jaxpr":
"""Representation of the operator as a plxpr decomposition."""

args = (*self.parameters, *self.wires)
jaxpr_decomp = qml.capture.make_plxpr(
partial(self._compute_plxpr_decomposition, **self.hyperparameters)
)(*args)

return jaxpr_decomp

@staticmethod
def _compute_plxpr_decomposition(*args, **hyperparameters):
"""Experimental method to compute the plxpr decomposition of the operator."""
raise DecompositionUndefinedError

@staticmethod
def compute_decomposition(
*params: TensorLike,
Expand Down
82 changes: 82 additions & 0 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,88 @@ def wrapper(*inner_args):
DecomposeInterpreter, decompose_plxpr_to_plxpr = _get_plxpr_decompose()


@lru_cache
def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring
try:
# pylint: disable=import-outside-toplevel
# pylint: disable=unused-import
import jax

from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator
except ImportError: # pragma: no cover
return None, None

# pylint: disable=redefined-outer-name

class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter):
"""
Experimental Plxpr Interpreter for applying a dynamic decomposition to operations when program capture is enabled.
"""

def eval_dynamic_decomposition(self, jaxpr_decomp: "jax.core.Jaxpr", *args):
"""
Evaluate a dynamic decomposition of a Jaxpr.

Args:
jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate
*args: the arguments to use in the evaluation
"""

for arg, invar in zip(args, jaxpr_decomp.invars, strict=True):
self._env[invar] = arg

for inner_eqn in jaxpr_decomp.eqns:

custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None)

if custom_handler:
invals = [self.read(invar) for invar in inner_eqn.invars]
outvals = custom_handler(self, *invals, **inner_eqn.params)

elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator):
# This does not currently support nested decompositions
outvals = super().interpret_operation_eqn(inner_eqn)
elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement):
outvals = super().interpret_measurement_eqn(inner_eqn)
else:
invals = [self.read(invar) for invar in inner_eqn.invars]
outvals = inner_eqn.primitive.bind(*invals, **inner_eqn.params)

if not inner_eqn.primitive.multiple_results:
outvals = [outvals]

for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True):
self._env[inner_outvar] = inner_outval

def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"):
"""
Interpret an equation corresponding to an operator.

Args:
eqn (jax.core.JaxprEqn): a jax equation for an operator.
"""

invals = (self.read(invar) for invar in eqn.invars)
with qml.QueuingManager.stop_recording():
op = eqn.primitive.impl(*invals, **eqn.params)

if isinstance(eqn.outvars[0], jax.core.DropVar):

if op._has_plxpr_decomposition:
jaxpr_decomp = op._plxpr_decomposition()
args = (*op.parameters, *op.wires)
return self.eval_dynamic_decomposition(jaxpr_decomp.jaxpr, *args)

return super().interpret_operation(op)

return op

return DynamicDecomposeInterpreter


DynamicDecomposeInterpreter = _get_plxpr_dynamic_decompose()


@partial(transform, plxpr_transform=decompose_plxpr_to_plxpr)
def decompose(tape, gate_set=None, max_expansion=None):
"""Decomposes a quantum circuit into a user-specified gate set.
Expand Down
Loading