Skip to content

Commit

Permalink
add tests for added util functions
Browse files Browse the repository at this point in the history
  • Loading branch information
natestemen committed Jan 11, 2024
1 parent 4cc124e commit d32a301
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions mitiq/shadows/test/test_shadows_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
# This source code is licensed under the GPL license (v3) found in the
# LICENSE file in the root directory of this source tree.

"""Defines utility functions for classical shadows protocol."""
import math

import mitiq
from mitiq.shadows.shadows_utils import (
create_string,
n_measurements_opts_expectation_bound,
n_measurements_tomography_bound,
batch_calibration_data,
valid_bitstrings,
fidelity,
)
import numpy as np


def test_create_string():
Expand All @@ -19,16 +23,33 @@ def test_create_string():
assert create_string(str_len, loc_list) == "01010"


def test_valid_bitstrings():
num_qubits = 5
bitstrings_on_5_qubits = valid_bitstrings(num_qubits)
assert len(bitstrings_on_5_qubits) == 2**num_qubits
assert all(b == "0" or b == "1" for b in bitstrings_on_5_qubits.pop())

num_qubits = 4
max_hamming_weight = 2
bitstrings_on_3_qubits_hamming_2 = valid_bitstrings(
num_qubits, max_hamming_weight
)
assert len(bitstrings_on_3_qubits_hamming_2) == sum(
math.comb(num_qubits, i) for i in range(max_hamming_weight + 1)
) # sum_{i == 0}^{max_hamming_weight} (num_qubits choose i)


def test_batch_calibration_data():
data = (["010", "110", "000", "001"], ["XXY", "ZYY", "ZZZ", "XYZ"])
num_batches = 2
for bits, paulis in batch_calibration_data(data, num_batches):
assert len(bits) == len(paulis) == num_batches


def test_n_measurements_tomography_bound():
assert (
n_measurements_tomography_bound(0.5, 2) == 2176
), f"Expected 2176, got {n_measurements_tomography_bound(0.5, 2)}"
assert (
n_measurements_tomography_bound(1.0, 1) == 136
), f"Expected 136, got {n_measurements_tomography_bound(1.0, 1)}"
assert (
n_measurements_tomography_bound(0.1, 3) == 217599
), f"Expected 217599, got {n_measurements_tomography_bound(0.1, 3)}"
assert n_measurements_tomography_bound(0.5, 2) == 2176
assert n_measurements_tomography_bound(1.0, 1) == 136
assert n_measurements_tomography_bound(0.1, 3) == 217599


def test_n_measurements_opts_expectation_bound():
Expand All @@ -38,5 +59,13 @@ def test_n_measurements_opts_expectation_bound():
mitiq.PauliString("Z"),
]
N, K = n_measurements_opts_expectation_bound(0.5, observables, 0.1)
assert isinstance(N, int), f"Expected int, got {type(N)}"
assert isinstance(K, int), f"Expected int, got {type(K)}"
assert isinstance(N, int)
assert isinstance(K, int)


def test_fidelity():
state_vector = np.array([0.5, 0.5, 0.5, 0.5])
rho = np.eye(4) / 4
assert np.isclose(
fidelity(state_vector, rho), 0.25
), f"Expected 0.25, got {fidelity(state_vector, rho)}"

0 comments on commit d32a301

Please sign in to comment.