Skip to content

Commit

Permalink
Add a marginal_distribution function (Qiskit#8026)
Browse files Browse the repository at this point in the history
* Add a marginal_distribution function

This commit adds a new function marginal_distribution which performs
marginalization similar to the existing marginal_counts() function. This
new function however differs in a few key ways. Most importantly the order
of the bit indices is significant for the purpose of permuting the bits
while marginalizing, while marginal_counts() doesn't do this.
Additionally, this function only works with dictionaries and not Result
objects and is written in Rust for performance. While the purposed of
this function is mostly identical to marginal_counts(), because making
the bit indices order significant is a breaking change this was done as
a separate function to avoid that. Once this new function is released
we can look at deprecated and eventually removing the existing
marginal_counts() function.

Fixes Qiskit#6230

* Fix typos in python function

* Handle missing indices

In case a bit string has stripped leading zeros this commit adds
handling to treat a missing index from the bit string as a '0'.

* Apply suggestions from code review

Co-authored-by: Jake Lishman <[email protected]>

* Remove unecessary rust index

* Update test

Co-authored-by: Jake Lishman <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 16, 2022
1 parent 4d459c8 commit 84cfd5c
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 1 deletion.
2 changes: 2 additions & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
sys.modules["qiskit._accelerate.pauli_expval"] = qiskit._accelerate.pauli_expval
sys.modules["qiskit._accelerate.dense_layout"] = qiskit._accelerate.dense_layout
sys.modules["qiskit._accelerate.sparse_pauli_op"] = qiskit._accelerate.sparse_pauli_op
sys.modules["qiskit._accelerate.results"] = qiskit._accelerate.results


# qiskit errors operator
from qiskit.exceptions import QiskitError, MissingOptionalLibraryError
Expand Down
2 changes: 2 additions & 0 deletions qiskit/result/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ResultError
Counts
marginal_counts
marginal_distribution
Distributions
=============
Expand All @@ -48,6 +49,7 @@
from .result import Result
from .exceptions import ResultError
from .utils import marginal_counts
from .utils import marginal_distribution
from .counts import Counts

from .distributions.probability import ProbDistribution
Expand Down
55 changes: 54 additions & 1 deletion qiskit/result/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=c-extension-no-member

"""Utility functions for working with Results."""

from typing import List, Union, Optional, Dict
Expand All @@ -18,8 +20,15 @@

from qiskit.exceptions import QiskitError
from qiskit.result.result import Result
from qiskit.result.counts import Counts
from qiskit.result.distributions.probability import ProbDistribution
from qiskit.result.distributions.quasi import QuasiDistribution

from qiskit.result.postprocess import _bin_to_hex, _hex_to_bin

# pylint: disable=import-error, no-name-in-module
from qiskit._accelerate import results as results_rs


def marginal_counts(
result: Union[dict, Result],
Expand Down Expand Up @@ -118,10 +127,54 @@ def _adjust_creg_sizes(creg_sizes, indices):
return new_creg_sizes


def marginal_distribution(
counts: dict, indices: Optional[List[int]] = None, format_marginal: bool = False
) -> Dict[str, int]:
"""Marginalize counts from an experiment over some indices of interest.
Unlike :func:`~.marginal_counts` this function respects the order of
the input ``indices``. If the input ``indices`` list is specified, the order
the bit indices will be the output order of the bitstrings
in the marginalized output.
Args:
counts: result to be marginalized
indices: The bit positions of interest
to marginalize over. If ``None`` (default), do not marginalize at all.
format_marginal: Default: False. If True, takes the output of
marginalize and formats it with placeholders between cregs and
for non-indices.
Returns:
dict(str, int): A marginalized dictionary
Raises:
QiskitError: If any value in ``indices`` is invalid or the ``counts`` dict
is invalid.
"""
num_clbits = len(max(counts.keys()).replace(" ", ""))
if indices is not None and (not indices or not set(indices).issubset(range(num_clbits))):
raise QiskitError(f"indices must be in range [0, {num_clbits - 1}].")

if isinstance(counts, Counts):
res = results_rs.marginal_counts(counts, indices)
elif isinstance(counts, (ProbDistribution, QuasiDistribution)):
res = results_rs.marginal_distribution(counts, indices)
else:
first_value = next(iter(counts.values()))
if isinstance(first_value, int):
res = results_rs.marginal_counts(counts, indices)
elif isinstance(first_value, float):
res = results_rs.marginal_distribution(counts, indices)
else:
raise QiskitError("Values of counts must be an int or float")

if format_marginal and indices is not None:
return _format_marginal(counts, res, indices)
return res


def _marginalize(counts, indices=None):
"""Get the marginal counts for the given set of indices"""
num_clbits = len(next(iter(counts)).replace(" ", ""))

# Check if we do not need to marginalize and if so, trim
# whitespace and '_' and return
if (indices is None) or set(range(num_clbits)) == set(indices):
Expand Down
16 changes: 16 additions & 0 deletions releasenotes/notes/add-marginal-distribution-21060de506ed9cfc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
features:
- |
Added a new function, :func:`~.marginal_distribution`, which is used to
marginalize an input dictionary of bitstrings to an integer (such as
:class:`~.Counts`). This is similar in functionality to the existing
:func:`~.marginal_counts` function with three key differences. The first
is that :func:`~.marginal_counts` works with either a counts dictionary
or a :class:`~.Results` object while :func:`~.marginal_distribution` only
works with a dictionary. The second is that :func:`~.marginal_counts` does
not respect the order of indices in its ``indices`` argument while
:func:`~.marginal_distribution` does and will permute the output bits
based on the ``indices`` order. The third difference is that
:func:`~.marginal_distribution` should be faster as its implementation
is written in Rust and streamlined for just marginalizing a dictionary
input.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod dense_layout;
mod edge_collections;
mod nlayout;
mod pauli_exp_val;
mod results;
mod sparse_pauli_op;
mod stochastic_swap;

Expand All @@ -42,5 +43,6 @@ fn _accelerate(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(pauli_exp_val::pauli_expval))?;
m.add_wrapped(wrap_pymodule!(dense_layout::dense_layout))?;
m.add_wrapped(wrap_pymodule!(sparse_pauli_op::sparse_pauli_op))?;
m.add_wrapped(wrap_pymodule!(results::results))?;
Ok(())
}
70 changes: 70 additions & 0 deletions src/results/marginalization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2022
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use hashbrown::HashMap;
use pyo3::prelude::*;

fn marginalize<T: std::ops::AddAssign + Copy>(
counts: HashMap<String, T>,
indices: Option<Vec<usize>>,
) -> HashMap<String, T> {
let mut out_counts: HashMap<String, T> = HashMap::with_capacity(counts.len());
let clbit_size = counts.keys().next().unwrap().replace(&['_', ' '], "").len();
let all_indices: Vec<usize> = (0..clbit_size).collect();
counts
.iter()
.map(|(k, v)| (k.replace(&['_', ' '], ""), *v))
.for_each(|(k, v)| match &indices {
Some(indices) => {
if all_indices == *indices {
out_counts.insert(k, v);
} else {
let key_arr = k.as_bytes();
let new_key: String = indices
.iter()
.map(|bit| {
let index = clbit_size - *bit - 1;
match key_arr.get(index) {
Some(bit) => *bit as char,
None => '0',
}
})
.rev()
.collect();
out_counts
.entry(new_key)
.and_modify(|e| *e += v)
.or_insert(v);
}
}
None => {
out_counts.insert(k, v);
}
});
out_counts
}

#[pyfunction]
pub fn marginal_counts(
counts: HashMap<String, u64>,
indices: Option<Vec<usize>>,
) -> HashMap<String, u64> {
marginalize(counts, indices)
}

#[pyfunction]
pub fn marginal_distribution(
counts: HashMap<String, f64>,
indices: Option<Vec<usize>>,
) -> HashMap<String, f64> {
marginalize(counts, indices)
}
23 changes: 23 additions & 0 deletions src/results/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2022
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

pub mod marginalization;

use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

#[pymodule]
pub fn results(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(marginalization::marginal_counts))?;
m.add_wrapped(wrap_pyfunction!(marginalization::marginal_distribution))?;
Ok(())
}
36 changes: 36 additions & 0 deletions test/python/result/test_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def test_marginal_counts(self):
result = utils.marginal_counts(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_marginal_distribution(self):
raw_counts = {"0x0": 4, "0x1": 7, "0x2": 10, "0x6": 5, "0x9": 11, "0xD": 9, "0xE": 8}
expected = {"00": 4, "01": 27, "10": 23}
counts_obj = counts.Counts(raw_counts, creg_sizes=[["c0", 4]], memory_slots=4)
result = utils.marginal_distribution(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_int_outcomes(self):
raw_counts = {"0x0": 21, "0x2": 12, "0x3": 5, "0x2E": 265}
expected = {0: 21, 2: 12, 3: 5, 46: 265}
Expand Down Expand Up @@ -90,6 +97,13 @@ def test_marginal_int_counts(self):
result = utils.marginal_counts(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_marginal_distribution_int_counts(self):
raw_counts = {0: 4, 1: 7, 2: 10, 6: 5, 9: 11, 13: 9, 14: 8}
expected = {"00": 4, "01": 27, "10": 23}
counts_obj = counts.Counts(raw_counts, creg_sizes=[["c0", 4]], memory_slots=4)
result = utils.marginal_distribution(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_int_outcomes_with_int_counts(self):
raw_counts = {0: 21, 2: 12, 3: 5, 46: 265}
counts_obj = counts.Counts(raw_counts)
Expand Down Expand Up @@ -139,6 +153,13 @@ def test_marginal_bitstring_counts(self):
result = utils.marginal_counts(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_marginal_distribution_bitstring_counts(self):
raw_counts = {"0": 4, "1": 7, "10": 10, "110": 5, "1001": 11, "1101": 9, "1110": 8}
expected = {"00": 4, "01": 27, "10": 23}
counts_obj = counts.Counts(raw_counts, creg_sizes=[["c0", 4]], memory_slots=4)
result = utils.marginal_distribution(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_int_outcomes_with_bitstring_counts(self):
raw_counts = {"0": 21, "10": 12, "11": 5, "101110": 265}
expected = {0: 21, 2: 12, 3: 5, 46: 265}
Expand Down Expand Up @@ -268,6 +289,21 @@ def test_marginal_0b_string_counts(self):
result = utils.marginal_counts(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_marginal_distribution_0b_string_counts(self):
raw_counts = {
"0b0": 4,
"0b1": 7,
"0b10": 10,
"0b110": 5,
"0b1001": 11,
"0b1101": 9,
"0b1110": 8,
}
expected = {"00": 4, "01": 27, "10": 23}
counts_obj = counts.Counts(raw_counts, creg_sizes=[["c0", 4]], memory_slots=4)
result = utils.marginal_distribution(counts_obj, [0, 1])
self.assertEqual(expected, result)

def test_int_outcomes_with_0b_bitstring_counts(self):
raw_counts = {"0b0": 21, "0b10": 12, "0b11": 5, "0b101110": 265}
expected = {0: 21, 2: 12, 3: 5, 46: 265}
Expand Down
29 changes: 29 additions & 0 deletions test/python/result/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from qiskit.result import models
from qiskit.result import marginal_counts
from qiskit.result import marginal_distribution
from qiskit.result import Result
from qiskit.qobj import QobjExperimentHeader
from qiskit.test import QiskitTestCase
Expand Down Expand Up @@ -181,6 +182,34 @@ def test_marginal_counts(self):
self.assertEqual(marginal_counts(result.get_counts(), [0, 1]), expected_marginal_counts)
self.assertEqual(marginal_counts(result.get_counts(), [1, 0]), expected_marginal_counts)

def test_marginal_distribution(self):
"""Test that counts are marginalized correctly."""
raw_counts = {"0x0": 4, "0x1": 7, "0x2": 10, "0x6": 5, "0x9": 11, "0xD": 9, "0xE": 8}
data = models.ExperimentResultData(counts=raw_counts)
exp_result_header = QobjExperimentHeader(creg_sizes=[["c0", 4]], memory_slots=4)
exp_result = models.ExperimentResult(
shots=54, success=True, data=data, header=exp_result_header
)
result = Result(results=[exp_result], **self.base_result_args)
expected_marginal_counts = {"00": 4, "01": 27, "10": 23}
expected_reverse = {"00": 4, "10": 27, "01": 23}

self.assertEqual(
marginal_distribution(result.get_counts(), [0, 1]), expected_marginal_counts
)
self.assertEqual(marginal_distribution(result.get_counts(), [1, 0]), expected_reverse)
# test with register spacing, bitstrings are in form of "00 00" for register split
data = models.ExperimentResultData(counts=raw_counts)
exp_result_header = QobjExperimentHeader(creg_sizes=[["c0", 2], ["c1", 2]], memory_slots=4)
exp_result = models.ExperimentResult(
shots=54, success=True, data=data, header=exp_result_header
)
result = Result(results=[exp_result], **self.base_result_args)
self.assertEqual(
marginal_distribution(result.get_counts(), [0, 1]), expected_marginal_counts
)
self.assertEqual(marginal_distribution(result.get_counts(), [1, 0]), expected_reverse)

def test_marginal_counts_result(self):
"""Test that a Result object containing counts marginalizes correctly."""
raw_counts_1 = {"0x0": 4, "0x1": 7, "0x2": 10, "0x6": 5, "0x9": 11, "0xD": 9, "0xE": 8}
Expand Down

0 comments on commit 84cfd5c

Please sign in to comment.