Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling jax.vmap for program capture #6349

Merged
merged 64 commits into from
Nov 8, 2024
Merged

Enabling jax.vmap for program capture #6349

merged 64 commits into from
Nov 8, 2024

Conversation

PietropaoloFrisoni
Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni commented Oct 7, 2024

Context: This PR is the first step to enable the usage of jax.vmap with quantum circuits when qml.capture.enabled() is True.

Description of the Change: We implemented a batching rule for the captured qnode and modified the abstract evaluation to keep track of the batch dimension.

Benefits: Allows vectorization (although with several limitations at this stage) with qml.capture.enabled() via jax.vmap

Possible Drawbacks: Right now, there are 2 main limitations that I see.

  • Limitation 1: Multidimensional arrays

Right now, the implementation does not work with multidimensional arrays. For example, the following does not work:

qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x):
    qml.RX(x[1], 0)
    return qml.expval(qml.Z(0))

# jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) Generates an error

This limitation will be removed in the following PR.

  • Limitation 2 How to prevent the user from bypassing jax.vmap ?
qml.capture.enable()


@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    return qml.expval(qml.Z(0))

x = jax.numpy.array([0.1, 0.2])
y = jax.numpy.array([0.1, 0.2])

jax.vmap(circuit, in_axes=(0, None))(x, y) 

This works by accident, but the second argument is vectorized along 2 dimensions, although it shouldn't!

With in_axes=(0, None) we should raise an error if qml.RY receives something that is not a scalar (as it happens in Catalyst).

At this stage, we simply raise a warning because we don't have a way to check this inside the QNode batching rule. To fix this behavior, I think we need to add more properties to the AbstractOperator class and improve the integration with the captured QNode. This is most probably also necessary for capturing parameter broadcasting in PL.

Related GitHub Issues: None.

Related Shortcut Stories: [sc-73779] [sc-73782] [sc-73783]

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, no blocker on my side!

Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay in the review.

I think most of my hesitance comes forbidding perfectly valid behavior because we don't want to trust the user. For example,

mat = jax.numpy.eye(2)

@qml.qnode(qml.device('default.qubit', wires=2))
def circuit(x):
    qml.QubitUnitary(mat, 0)
    qml.RX(x,0)
    return qml.expval(qml.Z(0))

jax.vmap(circuit)(jax.numpy.array([1.2, 2.3]) )
ValueError: One argument has more than one dimension. Currently, only single-dimension batching is supported.

or

@qml.qnode(qml.device('default.qubit', wires=2))
def circuit(x):
    qml.StatePrep(jax.numpy.array([1.0, 0.0]), wires=0)
    qml.RX(x,0)
    return qml.expval(qml.Z(0))

jax.vmap(circuit)(jax.numpy.array([1.2, 2.3]) )
ValueError: ('Constant argument at index 0 is not scalar. ', 'Only scalar constants are currently supported with jax.vmap.')

Those both should be a perfectly valid thing to do. We are not technically limited in being able to support them. The only reason we can't support them is that we have chosen to explicitly forbid such a situation. Yes, we can't promise that there won't be errors later on, but that is always the case. We need to find a balance between supporting as many features we can, trusting the users are telling us the truth when we ask them a specific question, and having sensible, informative errors when things have gone wrong. I guess I fall more on the side of an increased feature set at the cost of the informative errors happening later in the pipeline.

But ultimately, this is a fuzzy judgment call we can change later on. I'm happy to get this in as is and update as needed.

@PietropaoloFrisoni
Copy link
Contributor Author

Thanks @albi3ro

I think most of my hesitance comes from forbidding perfectly valid behavior because we don't want to trust the user.

I see. For the first example you provided, this limitation has already been removed in this PR. As for the second one, I agree with you that a warning might be more appropriate than a value error. I can change that in the aforementioned PR 👍

@PietropaoloFrisoni PietropaoloFrisoni added the merge-ready ✔️ All tests pass and the PR is ready to be merged. label Nov 7, 2024
@albi3ro albi3ro enabled auto-merge (squash) November 8, 2024 16:48
@albi3ro albi3ro disabled auto-merge November 8, 2024 17:08
@albi3ro albi3ro merged commit 34179a0 into master Nov 8, 2024
44 of 45 checks passed
@albi3ro albi3ro deleted the vmap_program_capture branch November 8, 2024 17:09
mudit2812 pushed a commit that referenced this pull request Nov 11, 2024
**Context:** This PR is the first step to enable the usage of `jax.vmap`
with quantum circuits when `qml.capture.enabled()` is `True`.

**Description of the Change:** We implemented a batching rule for the
captured `qnode` and modified the abstract evaluation to keep track of
the batch dimension.

**Benefits:** Allows vectorization (although with several limitations at
this stage) with `qml.capture.enabled()` via `jax.vmap`

**Possible Drawbacks:** Right now, there are 2 main limitations that I
see.

- Limitation 1: *Multidimensional arrays*

Right now, the implementation does not work with multidimensional
arrays. For example, the following does not work:

```
qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x):
    qml.RX(x[1], 0)
    return qml.expval(qml.Z(0))

# jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) Generates an error
```

This limitation will be removed in the [following
PR](#6422).

- Limitation 2 *How to prevent the user from bypassing `jax.vmap` ?*

```
qml.capture.enable()


@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(x, y):
    qml.RX(x, 0)
    qml.RY(y, 0)
    return qml.expval(qml.Z(0))

x = jax.numpy.array([0.1, 0.2])
y = jax.numpy.array([0.1, 0.2])

jax.vmap(circuit, in_axes=(0, None))(x, y) 

```

This works by accident, but the second argument is vectorized along 2
dimensions, although it shouldn't!

With `in_axes=(0, None)` we should raise an error if `qml.RY` receives
something that is not a scalar (as [it happens in
Catalyst](https://github.com/PennyLaneAI/catalyst/blob/7c5b828d5173cdaa52073d30a5f3a7df660b37d6/frontend/catalyst/jax_primitives.py#L1206)).

At this stage, we simply raise a warning because we don't have a way to
check this inside the QNode batching rule. To fix this behavior, I think
we need to add more properties to the `AbstractOperator` class and
improve the integration with the captured `QNode`. This is most probably
also necessary for capturing parameter broadcasting in PL.


**Related GitHub Issues:** None.

**Related Shortcut Stories:** [sc-73779] [sc-73782] [sc-73783]

---------

Co-authored-by: Christina Lee <[email protected]>
PietropaoloFrisoni added a commit that referenced this pull request Nov 18, 2024
**Context:** This PR extends the captured `jax.vmap` version to the
multidimensional input case. For further details, we refer to the
description of the [first
PR](#6349).

**Description of the Change:** As above. For 'multidimensional input
case' we mean something like the following:

```
qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=...))
...

jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) 

```

**Benefits:** Now `jax.vmap` can be used with captured enabled if the
input parameter is an array with a shape greater than 1.

**Possible Drawbacks:** None that I can think of.

**Related GitHub Issues:** None.

**Related Shortcut Stories:** [sc-76381]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
merge-ready ✔️ All tests pass and the PR is ready to be merged.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants