Skip to content

Commit

Permalink
Handle broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
Tarun-Kumar07 committed Apr 10, 2024
1 parent f47844f commit bbae7d4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
10 changes: 7 additions & 3 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,13 @@ def _process_single_shot(samples):
raise e
samples = qml.math.full((shots.total_shots, len(wires)), 0)

processed_samples = [
_process_single_shot(samples[lower:upper]) for lower, upper in shots.bins()
]
processed_samples = []
for lower, upper in shots.bins():
if len(samples.shape) == 3:
# Handle broadcasting
processed_samples.append(_process_single_shot(samples[:, lower:upper, :]))
else:
processed_samples.append(_process_single_shot(samples[lower:upper]))

return tuple(zip(*processed_samples))

Expand Down
6 changes: 3 additions & 3 deletions tests/devices/qubit/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,8 @@ def test_sample_measure_shot_vector(self, shots):
"measurement, expected",
[
(
qml.probs(wires=[0, 1]),
np.array([[0, 0, 0, 1], [1 / 2, 0, 1 / 2, 0], [1 / 4, 1 / 4, 1 / 4, 1 / 4]]),
qml.probs(wires=[0, 1]),
np.array([[0, 0, 0, 1], [1 / 2, 0, 1 / 2, 0], [1 / 4, 1 / 4, 1 / 4, 1 / 4]]),
),
(qml.expval(qml.PauliZ(1)), np.array([-1, 1, 0])),
(qml.var(qml.PauliZ(1)), np.array([0, 0, 1])),
Expand Down Expand Up @@ -856,7 +856,7 @@ def test_nonsample_measure_shot_vector(self, shots, measurement, expected):
r = r[0]

assert r.shape == expected.shape
assert np.allclose(r, expected, atol=0.01)
assert np.allclose(r, expected, atol=0.02)


@pytest.mark.jax
Expand Down

0 comments on commit bbae7d4

Please sign in to comment.