Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shots.bins() generator method #5476

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c7338ff
Implement Shots.bins()
Tarun-Kumar07 Apr 5, 2024
00f6770
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 5, 2024
9c37690
Improve docstring
Tarun-Kumar07 Apr 5, 2024
5eca670
Add entry to changelog-dev.md
Tarun-Kumar07 Apr 5, 2024
bb615f4
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 5, 2024
130c07a
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 9, 2024
6c24bf1
Sample once in `_measure_with_samples_diagonalizing_gates`
Tarun-Kumar07 Apr 9, 2024
f47844f
Merge branch 'master' into Add-Shots.bins()-generator-method
albi3ro Apr 9, 2024
0ca1341
Handle broadcasting
Tarun-Kumar07 Apr 10, 2024
70c01c3
Format test_sampling.py
Tarun-Kumar07 Apr 10, 2024
33cf4c1
Refactor sampling logic
Tarun-Kumar07 Apr 10, 2024
65070c2
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 10, 2024
5140130
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 16, 2024
5314abc
Fix jax tests
Tarun-Kumar07 Apr 16, 2024
c1180e7
Implement code review suggestions
Tarun-Kumar07 Apr 17, 2024
dbc4fe7
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 17, 2024
5044b1d
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 17, 2024
f12dd3e
Implement code review changes
Tarun-Kumar07 Apr 18, 2024
1d1a59e
Merge branch 'master' into Add-Shots.bins()-generator-method
Tarun-Kumar07 Apr 18, 2024
c9e8348
Merge branch 'master' into Add-Shots.bins()-generator-method
co9olguy Apr 18, 2024
a9209ae
Merge branch 'master' into Add-Shots.bins()-generator-method
albi3ro Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@
* Removed the warning that an observable might not be hermitian in `qnode` executions. This enables jit-compilation.
[(#5506)](https://github.com/PennyLaneAI/pennylane/pull/5506)

* Implement `Shots.bins()` method.
[(#5476)](https://github.com/PennyLaneAI/pennylane/pull/5476)

<h3>Breaking changes 💔</h3>

* Operator dunder methods now combine like-operator arithmetic classes via `lazy=False`. This reduces the chance of `RecursionError` and makes nested
Expand Down
38 changes: 12 additions & 26 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,31 +268,6 @@ def _process_single_shot(samples):

return tuple(processed)

# if there is a shot vector, build a list containing results for each shot entry
if shots.has_partitioned_shots:
processed_samples = []
for s in shots:
# currently we call sample_state for each shot entry, but it may be
# better to call sample_state just once with total_shots, then use
# the shot_range keyword argument
try:
samples = sample_state(
state,
shots=s,
is_state_batched=is_state_batched,
wires=wires,
rng=rng,
prng_key=prng_key,
)
except ValueError as e:
if str(e) != "probabilities contain NaN":
raise e
samples = qml.math.full((s, len(wires)), 0)

processed_samples.append(_process_single_shot(samples))

return tuple(zip(*processed_samples))

try:
samples = sample_state(
state,
Expand All @@ -307,7 +282,18 @@ def _process_single_shot(samples):
raise e
samples = qml.math.full((shots.total_shots, len(wires)), 0)

return _process_single_shot(samples)
processed_samples = []
for lower, upper in shots.bins():
if len(samples.shape) == 3:
Tarun-Kumar07 marked this conversation as resolved.
Show resolved Hide resolved
# Handle broadcasting
processed_samples.append(_process_single_shot(samples[:, lower:upper, :]))
else:
processed_samples.append(_process_single_shot(samples[lower:upper]))

if shots.has_partitioned_shots:
return tuple(zip(*processed_samples))

return processed_samples[0]


def _measure_classical_shadow(
Expand Down
17 changes: 17 additions & 0 deletions pennylane/measurements/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,20 @@ def has_partitioned_shots(self):
def num_copies(self):
"""The total number of copies of any shot quantity."""
return sum(s.copies for s in self.shot_vector)

def bins(self):
"""
Yields:
tuple: A tuple containing the lower and upper bounds for each shot quantity in shot_vector.

Example:
>>> shots = Shots((1, 1, 2, 3))
>>> list(shots.bins())
[(0,1), (1,2), (2,4), (4,7)]
"""
lower_bound = 0
for sc in self.shot_vector:
for _ in range(sc.copies):
upper_bound = lower_bound + sc.shots
yield lower_bound, upper_bound
lower_bound = upper_bound
Tarun-Kumar07 marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions tests/devices/qubit/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def test_nonsample_measure_shot_vector(self, shots, measurement, expected):
r = r[0]

assert r.shape == expected.shape
assert np.allclose(r, expected, atol=0.01)
assert np.allclose(r, expected, atol=0.02)


@pytest.mark.jax
Expand Down Expand Up @@ -1071,7 +1071,7 @@ def test_nonsample_measure_shot_vector(self, mocker, shots, measurement, expecte
r = r[0]

assert r.shape == expected.shape
assert np.allclose(r, expected, atol=0.01)
assert np.allclose(r, expected, atol=0.03)


class TestHamiltonianSamples:
Expand Down
20 changes: 20 additions & 0 deletions tests/measurements/test_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,23 @@ def test_shots_rmul(self):
scaled_sh1 = 2 * sh1
rev_scaled_sh1 = sh1 * 2
assert scaled_sh1.total_shots == rev_scaled_sh1.total_shots


class TestShotsBins:
"""Tests Shots.bins() method."""

def test_when_shots_is_none(self):
"""Tests that the method returns an empty list when shots is None."""
Tarun-Kumar07 marked this conversation as resolved.
Show resolved Hide resolved
shots = Shots(None)
assert not list(shots.bins())

def test_when_shots_is_int(self):
"""Tests that the method returns the correct bins when shots is an int."""
shots = Shots(10)
assert list(shots.bins()) == [(0, 10)]

@pytest.mark.parametrize("sequence", [[1, 1, 3, 4], [(1, 2), 3, 4]])
def test_when_shots_is_sequence_with_copies(self, sequence):
"""Tests that the method returns the correct bins when shots is a sequence with copies."""
shots = Shots(sequence)
assert list(shots.bins()) == [(0, 1), (1, 2), (2, 5), (5, 9)]
Loading