-
Notifications
You must be signed in to change notification settings - Fork 615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enabling jax.vmap
for program capture
#6349
Conversation
…o vmap_program_capture
…com/PennyLaneAI/pennylane into vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…/pennylane into experim_param_broad_capture
…o experim_param_broad_capture
…o vmap_program_capture
…o experim_param_broad_capture
…yLaneAI/pennylane into vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
…o vmap_program_capture
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, no blocker on my side!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay in the review.
I think most of my hesitance comes forbidding perfectly valid behavior because we don't want to trust the user. For example,
mat = jax.numpy.eye(2)
@qml.qnode(qml.device('default.qubit', wires=2))
def circuit(x):
qml.QubitUnitary(mat, 0)
qml.RX(x,0)
return qml.expval(qml.Z(0))
jax.vmap(circuit)(jax.numpy.array([1.2, 2.3]) )
ValueError: One argument has more than one dimension. Currently, only single-dimension batching is supported.
or
@qml.qnode(qml.device('default.qubit', wires=2))
def circuit(x):
qml.StatePrep(jax.numpy.array([1.0, 0.0]), wires=0)
qml.RX(x,0)
return qml.expval(qml.Z(0))
jax.vmap(circuit)(jax.numpy.array([1.2, 2.3]) )
ValueError: ('Constant argument at index 0 is not scalar. ', 'Only scalar constants are currently supported with jax.vmap.')
Those both should be a perfectly valid thing to do. We are not technically limited in being able to support them. The only reason we can't support them is that we have chosen to explicitly forbid such a situation. Yes, we can't promise that there won't be errors later on, but that is always the case. We need to find a balance between supporting as many features we can, trusting the users are telling us the truth when we ask them a specific question, and having sensible, informative errors when things have gone wrong. I guess I fall more on the side of an increased feature set at the cost of the informative errors happening later in the pipeline.
But ultimately, this is a fuzzy judgment call we can change later on. I'm happy to get this in as is and update as needed.
Thanks @albi3ro
I see. For the first example you provided, this limitation has already been removed in this PR. As for the second one, I agree with you that a warning might be more appropriate than a value error. I can change that in the aforementioned PR 👍 |
**Context:** This PR is the first step to enable the usage of `jax.vmap` with quantum circuits when `qml.capture.enabled()` is `True`. **Description of the Change:** We implemented a batching rule for the captured `qnode` and modified the abstract evaluation to keep track of the batch dimension. **Benefits:** Allows vectorization (although with several limitations at this stage) with `qml.capture.enabled()` via `jax.vmap` **Possible Drawbacks:** Right now, there are 2 main limitations that I see. - Limitation 1: *Multidimensional arrays* Right now, the implementation does not work with multidimensional arrays. For example, the following does not work: ``` qml.capture.enable() @qml.qnode(qml.device("default.qubit", wires=2)) def circuit(x): qml.RX(x[1], 0) return qml.expval(qml.Z(0)) # jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) Generates an error ``` This limitation will be removed in the [following PR](#6422). - Limitation 2 *How to prevent the user from bypassing `jax.vmap` ?* ``` qml.capture.enable() @qml.qnode(qml.device("default.qubit", wires=2)) def circuit(x, y): qml.RX(x, 0) qml.RY(y, 0) return qml.expval(qml.Z(0)) x = jax.numpy.array([0.1, 0.2]) y = jax.numpy.array([0.1, 0.2]) jax.vmap(circuit, in_axes=(0, None))(x, y) ``` This works by accident, but the second argument is vectorized along 2 dimensions, although it shouldn't! With `in_axes=(0, None)` we should raise an error if `qml.RY` receives something that is not a scalar (as [it happens in Catalyst](https://github.com/PennyLaneAI/catalyst/blob/7c5b828d5173cdaa52073d30a5f3a7df660b37d6/frontend/catalyst/jax_primitives.py#L1206)). At this stage, we simply raise a warning because we don't have a way to check this inside the QNode batching rule. To fix this behavior, I think we need to add more properties to the `AbstractOperator` class and improve the integration with the captured `QNode`. This is most probably also necessary for capturing parameter broadcasting in PL. **Related GitHub Issues:** None. **Related Shortcut Stories:** [sc-73779] [sc-73782] [sc-73783] --------- Co-authored-by: Christina Lee <[email protected]>
**Context:** This PR extends the captured `jax.vmap` version to the multidimensional input case. For further details, we refer to the description of the [first PR](#6349). **Description of the Change:** As above. For 'multidimensional input case' we mean something like the following: ``` qml.capture.enable() @qml.qnode(qml.device("default.qubit", wires=...)) ... jax.vmap(circuit)(jax.numpy.array([[0.1, 0.2], [0.3, 0.4]])) ``` **Benefits:** Now `jax.vmap` can be used with captured enabled if the input parameter is an array with a shape greater than 1. **Possible Drawbacks:** None that I can think of. **Related GitHub Issues:** None. **Related Shortcut Stories:** [sc-76381]
Context: This PR is the first step to enable the usage of
jax.vmap
with quantum circuits whenqml.capture.enabled()
isTrue
.Description of the Change: We implemented a batching rule for the captured
qnode
and modified the abstract evaluation to keep track of the batch dimension.Benefits: Allows vectorization (although with several limitations at this stage) with
qml.capture.enabled()
viajax.vmap
Possible Drawbacks: Right now, there are 2 main limitations that I see.
Right now, the implementation does not work with multidimensional arrays. For example, the following does not work:
This limitation will be removed in the following PR.
jax.vmap
?This works by accident, but the second argument is vectorized along 2 dimensions, although it shouldn't!
With
in_axes=(0, None)
we should raise an error ifqml.RY
receives something that is not a scalar (as it happens in Catalyst).At this stage, we simply raise a warning because we don't have a way to check this inside the QNode batching rule. To fix this behavior, I think we need to add more properties to the
AbstractOperator
class and improve the integration with the capturedQNode
. This is most probably also necessary for capturing parameter broadcasting in PL.Related GitHub Issues: None.
Related Shortcut Stories: [sc-73779] [sc-73782] [sc-73783]