Skip to content

Commit

Permalink
compute batch sizes properly; move -1 into sum
Browse files Browse the repository at this point in the history
i misunderstood how `np.array_split` worked. it uses the
number of splits, as opposed to the size of each batch.

(-1)^x + (-1)^y =/= (-1)^(x + y)
  • Loading branch information
natestemen committed Jan 11, 2024
1 parent 88658d5 commit 23b4e1e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
17 changes: 9 additions & 8 deletions mitiq/shadows/classical_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get_single_shot_pauli_fidelity(

def get_pauli_fidelities(
calibration_outcomes: Tuple[List[str], List[str]],
batch_size: int,
num_batches: int,
locality: Optional[int] = None,
) -> Dict[str, complex]:
r"""
Expand All @@ -100,7 +100,7 @@ def get_pauli_fidelities(
Args:
calibration_measurement_outcomes: The `random_Pauli_measurement`
outcomes for the state :math:`|0\rangle^{\otimes n}`}` .
k_calibration: number of splits in the median of means estimator.
num_batches: The number of batches in the median of means estimator.
locality: The locality of the operator, whose expectation value is
going to be estimated by the classical shadow. E.g., if the
operator is the Ising model Hamiltonian with nearest neighbor
Expand All @@ -112,7 +112,7 @@ def get_pauli_fidelities(
"""
means = defaultdict(list)
for bitstrings, paulistrings in batch_calibration_data(
calibration_outcomes, batch_size
calibration_outcomes, num_batches
):
all_fidelities = defaultdict(list)
for bitstring, paulistring in zip(bitstrings, paulistrings):
Expand All @@ -123,7 +123,7 @@ def get_pauli_fidelities(
all_fidelities[b].append(f)

for bitstring, fids in all_fidelities.items():
means[bitstring].append(sum(fids) / batch_size)
means[bitstring].append(sum(fids) / num_batches)

return {
bitstring: median(averages) for bitstring, averages in means.items()
Expand Down Expand Up @@ -216,7 +216,7 @@ def shadow_state_reconstruction(
def expectation_estimation_shadow(
measurement_outcomes: Tuple[List[str], List[str]],
pauli: mitiq.PauliString,
batch_size: int,
num_batches: int,
fidelities: Optional[Dict[str, float]] = None,
) -> float:
"""Calculate the expectation value of an observable from classical shadows.
Expand All @@ -227,7 +227,7 @@ def expectation_estimation_shadow(
`z_basis_measurement`.
pauli_str: Single mitiq observable consisting of
Pauli operators.
batch_size: Size of batches to process measurement outcomes in.
num_batches: Number of batches to process measurement outcomes in.
f_est: The estimated Pauli fidelities to use for calibration if
available.
Expand All @@ -248,10 +248,11 @@ def expectation_estimation_shadow(
filtered_data = (filtered_bitstrings, filtered_paulis)

means = []
for bits, paulis in batch_calibration_data(filtered_data, batch_size):
for bits, paulis in batch_calibration_data(filtered_data, num_batches):
matching_indices = [i for i, p in enumerate(paulis) if p == pauli.spec]
if matching_indices:
product = (-1) ** sum(bit.count("1") for bit in bits)
matching_bits = (bits[i] for i in matching_indices)
product = sum((-1) ** bit.count("1") for bit in matching_bits)

if fidelities:
b = create_string(num_qubits, qubits)
Expand Down
2 changes: 1 addition & 1 deletion mitiq/shadows/shadows.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def classical_post_processing(
expectation_values = expectation_estimation_shadow(
shadow_outcomes,
obs,
batch_size=k_shadows,
num_batches=k_shadows,
fidelities=calibration_results,
)
output[str(obs)] = expectation_values
Expand Down
3 changes: 2 additions & 1 deletion mitiq/shadows/shadows_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def fidelity(


def batch_calibration_data(
data: Tuple[List[str], List[str]], batch_size: int
data: Tuple[List[str], List[str]], num_batches: int
) -> Generator[Tuple[List[str], List[str]], None, None]:
"""Batch calibration into chunks of size batch_size.
Expand All @@ -110,6 +110,7 @@ def batch_calibration_data(
Tuples of bit strings and pauli strings.
"""
bits, paulis = data
batch_size = len(bits) // num_batches
for i in range(0, len(bits), batch_size):
yield bits[i : i + batch_size], paulis[i : i + batch_size]

Expand Down

0 comments on commit 23b4e1e

Please sign in to comment.