Skip to content

Commit

Permalink
Add marking misclassified points to IQPlotter. (#962)
Browse files Browse the repository at this point in the history
* This PR adds the functionality to flag misclassified IQ points.

Co-authored-by: Daniel J. Egger <[email protected]>
  • Loading branch information
conradhaupt and eggerdj authored Nov 2, 2022
1 parent 1eccb16 commit 95664a1
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 25 deletions.
77 changes: 65 additions & 12 deletions qiskit_experiments/visualization/plotters/iq_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ class IQPlotter(BasePlotter):
:class:`IQPlotter` plots results from experiments which used measurement-level 1, i.e. IQ data. This
class also supports plotting predictions from a discriminator (subclass of
:class:`BaseDiscriminator`), which is used to classify IQ results into measurement labels. The
canonical application of :class:`IQPlotter` is for classification of single-qubit readout for
different prepared states.
:class:`BaseDiscriminator`), which is used to classify IQ results into labels. The discriminator
labels are matched with the series-names to generate an image of the predictions. Points that are
misclassified by the discriminator are flagged in the figure (see ``flag_misclassified``
:attr:`option`). A canonical application of :class:`IQPlotter` is for classification of
single-qubit readout for different prepared states.
Example:
.. code-block:: python
Expand Down Expand Up @@ -61,7 +63,7 @@ class also supports plotting predictions from a discriminator (subclass of
...
# Optional: Add trained discriminator.
discrim = MyIQDiscriminator()
discrim.fit(train_data,train_labels)
discrim.fit(train_data,train_labels) # Labels are the same as series-names.
plotter.set_supplementary_data(discriminator=discrim)
...
# Plot figure.
Expand Down Expand Up @@ -116,9 +118,9 @@ def _compute_extent(self) -> Optional[ExtentTuple]:
(points,) = self.data_for(series, "points")
ext_calc.register_data(points)
has_registered_data = True
if self.data_exists_for(series, "centroids"):
if self.data_exists_for(series, "centroid"):
(centroid,) = self.data_for(series, "centroid")
ext_calc.register_data(centroid)
ext_calc.register_data(np.asarray(centroid).reshape(1, 2))
has_registered_data = True
if self.figure_options.xlim:
ext_calc.register_data(self.figure_options.xlim, dim=0)
Expand Down Expand Up @@ -154,6 +156,11 @@ def _compute_discriminator_image(
if extent is None:
return None, None

# Get the discriminator and check if it is trained. If not, return.
discrim: BaseDiscriminator = self.supplementary_data["discriminator"]
if not discrim.is_trained():
return None, None

# Compute discriminator resolution.
extent_range = np.diff(np.asarray(extent).reshape(2, 2), axis=1).flatten()
resolution = (
Expand All @@ -169,10 +176,7 @@ def _compute_discriminator_image(
)
]

# Get predictions for coordinates from the discriminator, if the discriminator is trained.
discrim: BaseDiscriminator = self.supplementary_data["discriminator"]
if not discrim.is_trained():
return None, None
# Get predictions for coordinates from the discriminator.
predictions = discrim.predict(coords)

# Unwrap predictions into a 2D array
Expand Down Expand Up @@ -202,7 +206,15 @@ def _default_options(cls) -> Options:
discriminator_extent (Optional[ExtentTuple]): An optional tuple defining the extent of the
image created by sampling from the discriminator. If ``None``, the extent tuple is
computed using ``discriminator_multiplier``, ``discriminator_aspect_ratio``, and the
series-data ``points`` and ``centroids``. Defaults to ``None``.
series-data ``points`` and ``centroid``. Defaults to ``None``.
flag_misclassified (bool): Whether to mark misclassified IQ values from all ``points`` series
data, based on whether their series-name is not the same as the prediction from the
discriminator provided as supplementary data. If ``discriminator`` is not provided,
``flag_misclassified`` has no effect. Defaults to True.
misclassified_symbol (str): Symbol for misclassified points, as a drawer-compatible string.
Defaults to "x".
misclassified_color (str | tuple): Color for misclassified points, as a drawer-compatible
string or RGB tuple. Defaults to "r".
"""
options = super()._default_options()
Expand All @@ -213,6 +225,10 @@ def _default_options(cls) -> Options:
options.discriminator_max_resolution = 1024
options.discriminator_alpha = 0.2
options.discriminator_extent = None
# Points options
options.flag_misclassified = True
options.misclassified_symbol = "x"
options.misclassified_color = "r"
return options

@classmethod
Expand All @@ -224,6 +240,29 @@ def _default_figure_options(cls) -> Options:
fig_opts.yval_unit = "arb."
return fig_opts

def _misclassified_points(self, series_name: str, points: np.ndarray) -> Optional[np.ndarray]:
"""Returns a list of IQ coordinates for points that are misclassified by the discriminator.
Args:
series_name: The series-name to use as the expected discriminator label. If the discriminator
returns a prediction that doesn't equal ``series_name``, it is marked as misclassified.
points: The list of points to check for misclassification.
Returns:
Optional[np.ndarray]: A NumPy array of IQ points, being those that were misclassified by the
discriminator. If the discriminator isn't set and trained, then `None` is returned. The array
may be empty.
"""
# Check if we have a discriminator, and if it is trained. If not, return None.
if "discriminator" not in self.supplementary_data:
return None
discrim: BaseDiscriminator = self.supplementary_data["discriminator"]
if not discrim.is_trained():
return None
classifications = discrim.predict(points)
misclassified = np.argwhere(classifications != series_name)
return points[misclassified, :].reshape(-1, 2)

def _plot_figure(self):
"""Plots an IQ figure."""
# Plot discriminator first so that subsequent graphics change the automatic limits. This is a
Expand Down Expand Up @@ -260,7 +299,7 @@ def _plot_figure(self):
centroid[1],
name=ser,
legend=True,
zorder=3,
zorder=4,
s=20,
edgecolor="k",
marker="o",
Expand All @@ -278,3 +317,17 @@ def _plot_figure(self):
alpha=0.2,
marker=".",
)
if self.options.flag_misclassified:
misclassified_points = self._misclassified_points(ser, points)
if misclassified_points is not None:
self.drawer.scatter(
misclassified_points[:, 0],
misclassified_points[:, 1],
name="misclassified",
legend=False,
zorder=3,
s=10,
alpha=0.4,
marker=self.options.misclassified_symbol,
color=self.options.misclassified_color,
)
79 changes: 66 additions & 13 deletions test/visualization/test_iq_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Test IQ plotter.
"""

from itertools import product
from test.base import QiskitExperimentsTestCase
from typing import Any, Dict, List, Tuple

Expand All @@ -23,31 +24,48 @@
from qiskit_experiments.visualization import IQPlotter, MplDrawer


class MockDiscriminatorNotTrainedException(Exception):
"""Mock exception to be raised when :meth:`MockDiscriminator.predict` is called on an untrained
:class:`MockDiscriminator`."""

pass


class MockDiscriminator(BaseDiscriminator):
"""A mock discriminator for testing."""

def __init__(self, is_trained: bool = False):
def __init__(
self, is_trained: bool = False, n_states: int = 3, raise_predict_not_trained: bool = False
):
"""Create a MockDiscriminator instance.
Args:
is_trained: Whether the discriminator is trained or not. Defaults to False.
n_states: The number of states/labels. Defaults to 3.
raise_predict_not_trained: Whether to raise an exception if :meth:`predict` is called and
:attr:`is_trained` is ``False``. Raises
"""
super().__init__()
self._is_trained = is_trained
self._n_states = n_states
self._raise_predict_not_trained = raise_predict_not_trained
self.predict_was_called = False
"""Whether :meth:`predict` was called at least once."""

def predict(self, data: List):
"""Returns dummy predictions where everything has the label ``0``."""
"""Returns dummy predictions with random labels."""
self.predict_was_called = True
if self._raise_predict_not_trained and not self.is_trained():
raise MockDiscriminatorNotTrainedException()
if isinstance(data, list):
return [0] * len(data)
return [0] * data.shape[0]
return np.random.choice([f"{i}" for i in range(self._n_states)], len(data)).tolist()
return np.random.choice([f"{i}" for i in range(self._n_states)], data.shape[0])

def config(self) -> Dict[str, Any]:
return {
"predict_was_called": self.predict_was_called,
"is_trained": self._is_trained,
"n_states": self._n_states,
}

def is_trained(self) -> bool:
Expand All @@ -63,12 +81,15 @@ def _dummy_data(
cls,
is_trained: bool = True,
n_series: int = 3,
raise_predict_not_trained: bool = False,
) -> Tuple[List, List, BaseDiscriminator]:
"""Create dummy data for the tests.
Args:
is_trained: Whether the discriminator should be trained or not. Defaults to True.
n_series: The number of series to generate dummy data for. Defaults to 3.
raise_predict_not_trained: Passed to the discriminator :class:`MockDiscriminator` class.
Returns:
tuple: the tuple ``(points, names, discrim)`` where ``points`` is a list of NumPy arrays of
Expand All @@ -80,7 +101,9 @@ def _dummy_data(
for i in range(n_series):
points.append(np.random.rand(128, 2))
labels.append(f"{i}")
mock_discrim = MockDiscriminator(is_trained)
mock_discrim = MockDiscriminator(
is_trained, n_states=n_series, raise_predict_not_trained=raise_predict_not_trained
)
return points, labels, mock_discrim

@ddt.data(True, False)
Expand All @@ -105,15 +128,45 @@ def test_discriminator_trained(self, is_trained: bool):

# Assert that MockDiscriminator.predict() was/wasn't called, depending on whether it was trained
# or not.
self.assertEqual(is_trained, discrim.predict_was_called)

def test_end_to_end(self):
self.assertEqual(
is_trained,
discrim.predict_was_called,
msg=f"Discriminator `predict()` {'was' if is_trained else 'was not'} meant to be called, "
f"but {'was' if discrim.predict_was_called else 'was not'} called. is_trained={is_trained}.",
)

@ddt.data(*list(product([True, False], repeat=3)))
def test_end_to_end(self, args):
"""Test end-to-end functionality of IQPlotter."""
# Expand args
with_centroids, with_misclassified, with_discriminator = args

# Create plotter and add data
plotter = IQPlotter(MplDrawer())
points, labels, discrim = self._dummy_data(is_trained=True)
for points, series_name in zip(points, labels):
centroid = np.mean(points, axis=0)
plotter.set_series_data(series_name, points=points, centroid=centroid)
plotter.set_supplementary_data(discriminator=discrim)
plotter.set_options(flag_misclassified=with_misclassified)
points, labels, discrim = self._dummy_data(
is_trained=True,
)
for series_points, series_name in zip(points, labels):
plotter.set_series_data(series_name, points=series_points)
if with_centroids:
centroid = np.mean(series_points, axis=0)
plotter.set_series_data(series_name, centroid=centroid)
if with_discriminator:
plotter.set_supplementary_data(discriminator=discrim)

# Generate figure
plotter.figure()

# Verify that the correct number of series colours were created. If we are flagging misclassified
# points, we have one extra series. This assumes we are using `MplDrawer`. The discriminator
# should label each input as one of the series-names, which means we should have the same number
# of entries in `plotter.drawer._series` as colours queried from `MplDrawer._get_default_color`
# (stored in `MplDrawer._series`).
self.assertEqual(
len(plotter.drawer._series),
len(points) + (1 if with_misclassified and with_discriminator else 0),
msg="Number of series plotted by IQPlotter does not match the number of series from "
f"the dummy data. Expected {len(points)} but got {len(plotter.drawer._series)}. Series="
f"{plotter.drawer._series}.",
)

0 comments on commit 95664a1

Please sign in to comment.