Skip to content

Commit

Permalink
refactor expectation_estimation_shadow
Browse files Browse the repository at this point in the history
  • Loading branch information
natestemen committed Jan 9, 2024
1 parent 2bd4adb commit 4c38a82
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 71 deletions.
73 changes: 26 additions & 47 deletions mitiq/shadows/classical_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ def shadow_state_reconstruction(

def expectation_estimation_shadow(
measurement_outcomes: Tuple[List[str], List[str]],
pauli_str: mitiq.PauliString,
k_shadows: int,
f_est: Optional[Dict[str, float]] = None,
pauli: mitiq.PauliString,
batch_size: int,
fidelities: Optional[Dict[str, float]] = None,
) -> float:
"""Calculate the expectation value of an observable from classical shadows.
Use median of means to ameliorate the effects of outliers.
Expand All @@ -227,62 +227,41 @@ def expectation_estimation_shadow(
`z_basis_measurement`.
pauli_str: Single mitiq observable consisting of
Pauli operators.
k_shadows: number of splits in the median of means estimator.
batch_size: Size of batches to process measurement outcomes in.
f_est: The estimated Pauli fidelities to use for calibration if
available.
Returns:
Float corresponding to the estimate of the observable
expectation value.
Float corresponding to the estimate of the observable expectation
value.
"""
num_qubits = len(measurement_outcomes[0][0])
obs = pauli_str._pauli
coeff = pauli_str.coeff

target_obs, target_locs = [], []
for qubit, pauli in obs.items():
target_obs.append(str(pauli))
target_locs.append(int(qubit))

# classical values stored in classical computer
b_lists_shadow = np.array([list(u) for u in measurement_outcomes[0]])[
:, target_locs
bitstrings, paulistrings = measurement_outcomes
num_qubits = len(bitstrings[0])

qubits = sorted(pauli.support())
filtered_bitstrings = [
"".join([bitstring[q] for q in qubits]) for bitstring in bitstrings
]
u_lists_shadow = np.array([list(u) for u in measurement_outcomes[1]])[
:, target_locs
filtered_paulis = [
"".join([pauli[q] for q in qubits]) for pauli in paulistrings
]
filtered_data = (filtered_bitstrings, filtered_paulis)

means = []

# loop over the splits of the shadow:
group_idxes = np.array_split(np.arange(len(b_lists_shadow)), k_shadows)

# loop over the splits of the shadow:
for idxes in group_idxes:
matching_indexes = np.nonzero(
np.all(u_lists_shadow[idxes] == target_obs, axis=1)
)

if len(matching_indexes[0]):
product = (-1) ** np.sum(
b_lists_shadow[idxes][matching_indexes].astype(int),
axis=1,
)

if f_est:
b = create_string(num_qubits, target_locs)
f_val = f_est.get(b, np.inf)
# product becomes an array of snapshot expectation values
# witch satisfy condition (1) and (2)
product = (1 / f_val) * product
for bits, paulis in batch_calibration_data(filtered_data, batch_size):
matching_indices = [i for i, p in enumerate(paulis) if p == pauli.spec]
if matching_indices:
product = (-1) ** sum(bit.count("1") for bit in bits)

if fidelities:
b = create_string(num_qubits, qubits)
product /= fidelities.get(b, np.inf)
else:
product = 3 ** len(target_locs) * product
product *= 3 ** len(qubits)

else:
product = 0.0

# append the mean of the product in each split
means.append(np.sum(product) / len(idxes))
means.append(product / len(bits))

# return the median of means
return float(np.real(np.median(means) * coeff))
return np.real(np.median(means) * pauli.coeff)
4 changes: 2 additions & 2 deletions mitiq/shadows/shadows.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def classical_post_processing(
expectation_values = expectation_estimation_shadow(
shadow_outcomes,
obs,
k_shadows=k_shadows,
f_est=calibration_results,
batch_size=k_shadows,
fidelities=calibration_results,
)
output[str(obs)] = expectation_values
return output
37 changes: 15 additions & 22 deletions mitiq/shadows/test/test_classical_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,25 +177,21 @@ def test_shadow_state_reconstruction_cal():


def test_expectation_estimation_shadow():
b_lists = ["0101", "0110"]
u_lists = ["ZZXX", "ZZXX"]

measurement_outcomes = (b_lists, u_lists)
observable = mitiq.PauliString("ZZ", support=(0, 1))
k = 1
measurement_outcomes = ["0101", "0110"], ["ZZXX", "ZZXX"]
pauli = mitiq.PauliString("ZZ")
batch_size = 1
expected_result = -9

result = expectation_estimation_shadow(
measurement_outcomes, observable, k, False
measurement_outcomes, pauli, batch_size
)
assert isinstance(result, float), f"Expected a float, got {type(result)}"
assert np.isclose(result, expected_result)


def test_expectation_estimation_shadow_cal():
b_lists = ["0101", "0110"]
u_lists = ["YXZZ", "XXXX"]
f_est = {
bitstrings = ["0101", "0110"]
paulistrings = ["YXZZ", "XXXX"]
fidelities = {
"0000": 1,
"0001": 1 / 3,
"0010": 1 / 3,
Expand All @@ -214,16 +210,14 @@ def test_expectation_estimation_shadow_cal():
"1111": 1 / 81,
}

measurement_outcomes = b_lists, u_lists
observable = mitiq.PauliString("YXZZ", support=(0, 1, 2, 3))
k = 1
measurement_outcomes = bitstrings, paulistrings
pauli = mitiq.PauliString("YXZZ")
batch_size = 1
expected_result = 81 / 2
print("expected_result", expected_result)

result = expectation_estimation_shadow(
measurement_outcomes, observable, k, f_est
measurement_outcomes, pauli, batch_size, fidelities
)
assert isinstance(result, float), f"Expected a float, got {type(result)}"
assert np.isclose(result, expected_result)


Expand All @@ -232,13 +226,12 @@ def test_expectation_estimation_shadow_no_indices():
Test expectation estimation for a shadow with no matching indices.
The result should be 0 as there are no matching
"""
q0, q1, q2 = cirq.LineQubit.range(3)
observable = mitiq.PauliString("XYZ", support=(0, 1, 2))
pauli = mitiq.PauliString("XYZ")
measurement_outcomes = ["101", "010", "101"], ["ZXY", "YZX", "ZZY"]
k_shadows = 1
batch_size = 1

result = expectation_estimation_shadow(
measurement_outcomes, observable, k_shadows, False
measurement_outcomes, pauli, batch_size
)

assert result == 0.0
assert result == 0

0 comments on commit 4c38a82

Please sign in to comment.