From 0aac47149a5064cab204411ba70ae61cf9a50c7d Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 4 Dec 2024 09:56:00 -0500 Subject: [PATCH 1/2] Fixing test and synthax --- doc/releases/changelog-dev.md | 1 + pennylane/workflow/_capture_qnode.py | 2 +- tests/capture/test_capture_qnode.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index aa4b78bfc5d..0b808e83969 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -150,6 +150,7 @@ added `binary_mapping()` function to map `BoseWord` and `BoseSentence` to qubit * `jax.vmap` can be captured with `qml.capture.make_plxpr` and is compatible with quantum circuits. [(#6349)](https://github.com/PennyLaneAI/pennylane/pull/6349) [(#6422)](https://github.com/PennyLaneAI/pennylane/pull/6422) + [(#...)](https://github.com/PennyLaneAI/pennylane/pull/...) * `qml.capture.PlxprInterpreter` base class has been added for easy transformation and execution of pennylane variant jaxpr. diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index ffe62d13a5b..a3c8e790430 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -112,7 +112,7 @@ def qfunc(*inner_args): # pylint: disable=protected-access return jax.vmap(partial(qnode._impl_call, shots=shots), batch_dims[n_consts:])( - *jax.tree_util.tree_leaves(non_const_args) + *non_const_args ) # pylint: disable=protected-access diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 41b1854956f..84e1554e6b1 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -909,7 +909,7 @@ def workflow4(y, x, z): assert len(eqn.outvars) == 1 assert eqn.outvars[0].aval.shape == (3,) - result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x, y, 1) + result = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x["arr"], y, 1) expected = jax.numpy.array([0.93005586, 0.00498127, -0.88789978]) * y assert jax.numpy.allclose(result[0], expected) assert jax.numpy.allclose(result[1], expected) From 80a19ceccd71fde173f6a322b47e43bb224a4bff Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Wed, 4 Dec 2024 09:58:31 -0500 Subject: [PATCH 2/2] Linking pull request number --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 0b808e83969..066acb325aa 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -150,7 +150,7 @@ added `binary_mapping()` function to map `BoseWord` and `BoseSentence` to qubit * `jax.vmap` can be captured with `qml.capture.make_plxpr` and is compatible with quantum circuits. [(#6349)](https://github.com/PennyLaneAI/pennylane/pull/6349) [(#6422)](https://github.com/PennyLaneAI/pennylane/pull/6422) - [(#...)](https://github.com/PennyLaneAI/pennylane/pull/...) + [(#6668)](https://github.com/PennyLaneAI/pennylane/pull/6668) * `qml.capture.PlxprInterpreter` base class has been added for easy transformation and execution of pennylane variant jaxpr.