From 95664a182445963c4602cf47e8a972b7b6da3877 Mon Sep 17 00:00:00 2001 From: Conrad Haupt Date: Wed, 2 Nov 2022 07:53:33 +0100 Subject: [PATCH] Add marking misclassified points to IQPlotter. (#962) * This PR adds the functionality to flag misclassified IQ points. Co-authored-by: Daniel J. Egger <38065505+eggerdj@users.noreply.github.com> --- .../visualization/plotters/iq_plotter.py | 77 +++++++++++++++--- test/visualization/test_iq_plotter.py | 79 ++++++++++++++++--- 2 files changed, 131 insertions(+), 25 deletions(-) diff --git a/qiskit_experiments/visualization/plotters/iq_plotter.py b/qiskit_experiments/visualization/plotters/iq_plotter.py index 419d746d3f..e8b0c7ebe6 100644 --- a/qiskit_experiments/visualization/plotters/iq_plotter.py +++ b/qiskit_experiments/visualization/plotters/iq_plotter.py @@ -28,9 +28,11 @@ class IQPlotter(BasePlotter): :class:`IQPlotter` plots results from experiments which used measurement-level 1, i.e. IQ data. This class also supports plotting predictions from a discriminator (subclass of - :class:`BaseDiscriminator`), which is used to classify IQ results into measurement labels. The - canonical application of :class:`IQPlotter` is for classification of single-qubit readout for - different prepared states. + :class:`BaseDiscriminator`), which is used to classify IQ results into labels. The discriminator + labels are matched with the series-names to generate an image of the predictions. Points that are + misclassified by the discriminator are flagged in the figure (see ``flag_misclassified`` + :attr:`option`). A canonical application of :class:`IQPlotter` is for classification of + single-qubit readout for different prepared states. Example: .. code-block:: python @@ -61,7 +63,7 @@ class also supports plotting predictions from a discriminator (subclass of ... # Optional: Add trained discriminator. discrim = MyIQDiscriminator() - discrim.fit(train_data,train_labels) + discrim.fit(train_data,train_labels) # Labels are the same as series-names. plotter.set_supplementary_data(discriminator=discrim) ... # Plot figure. @@ -116,9 +118,9 @@ def _compute_extent(self) -> Optional[ExtentTuple]: (points,) = self.data_for(series, "points") ext_calc.register_data(points) has_registered_data = True - if self.data_exists_for(series, "centroids"): + if self.data_exists_for(series, "centroid"): (centroid,) = self.data_for(series, "centroid") - ext_calc.register_data(centroid) + ext_calc.register_data(np.asarray(centroid).reshape(1, 2)) has_registered_data = True if self.figure_options.xlim: ext_calc.register_data(self.figure_options.xlim, dim=0) @@ -154,6 +156,11 @@ def _compute_discriminator_image( if extent is None: return None, None + # Get the discriminator and check if it is trained. If not, return. + discrim: BaseDiscriminator = self.supplementary_data["discriminator"] + if not discrim.is_trained(): + return None, None + # Compute discriminator resolution. extent_range = np.diff(np.asarray(extent).reshape(2, 2), axis=1).flatten() resolution = ( @@ -169,10 +176,7 @@ def _compute_discriminator_image( ) ] - # Get predictions for coordinates from the discriminator, if the discriminator is trained. - discrim: BaseDiscriminator = self.supplementary_data["discriminator"] - if not discrim.is_trained(): - return None, None + # Get predictions for coordinates from the discriminator. 
         predictions = discrim.predict(coords)
 
         # Unwrap predictions into a 2D array
@@ -202,7 +206,15 @@ def _default_options(cls) -> Options:
             discriminator_extent (Optional[ExtentTuple]): An optional tuple defining the extent of
                 the image created by sampling from the discriminator. If ``None``, the extent tuple
                 is computed using ``discriminator_multiplier``, ``discriminator_aspect_ratio``, and the
-                series-data ``points`` and ``centroids``. Defaults to ``None``.
+                series-data ``points`` and ``centroid``. Defaults to ``None``.
+            flag_misclassified (bool): Whether to mark misclassified IQ points from all ``points``
+                series data, i.e. those whose prediction from the discriminator (provided as
+                supplementary data) does not match their series-name. If ``discriminator`` is not
+                provided, ``flag_misclassified`` has no effect. Defaults to ``True``.
+            misclassified_symbol (str): Symbol for misclassified points, as a drawer-compatible string.
+                Defaults to "x".
+            misclassified_color (str | tuple): Color for misclassified points, as a drawer-compatible
+                string or RGB tuple. Defaults to "r".
         """
         options = super()._default_options()
@@ -213,6 +225,10 @@
         options.discriminator_max_resolution = 1024
         options.discriminator_alpha = 0.2
         options.discriminator_extent = None
+        # Points options
+        options.flag_misclassified = True
+        options.misclassified_symbol = "x"
+        options.misclassified_color = "r"
         return options
 
     @classmethod
@@ -224,6 +240,29 @@ def _default_figure_options(cls) -> Options:
         fig_opts.yval_unit = "arb."
         return fig_opts
 
+    def _misclassified_points(self, series_name: str, points: np.ndarray) -> Optional[np.ndarray]:
+        """Returns the IQ coordinates of the points that are misclassified by the discriminator.
+
+        Args:
+            series_name: The series-name to use as the expected discriminator label. If the
+                discriminator returns a prediction that doesn't equal ``series_name``, the point is
+                marked as misclassified.
+            points: The list of points to check for misclassification.
+
+        Returns:
+            Optional[np.ndarray]: A NumPy array of the IQ points that were misclassified by the
+                discriminator. If the discriminator isn't set or isn't trained, ``None`` is returned.
+                The array may be empty.
+        """
+        # Check if we have a discriminator, and if it is trained. If not, return None.
+        if "discriminator" not in self.supplementary_data:
+            return None
+        discrim: BaseDiscriminator = self.supplementary_data["discriminator"]
+        if not discrim.is_trained():
+            return None
+        classifications = discrim.predict(points)
+        misclassified = np.argwhere(classifications != series_name)
+        return points[misclassified, :].reshape(-1, 2)
+
     def _plot_figure(self):
         """Plots an IQ figure."""
         # Plot discriminator first so that subsequent graphics change the automatic limits.
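For illustration only (not part of the patch): the check implemented by ``_misclassified_points`` above reduces to predicting a label for every IQ point and keeping the points whose prediction differs from their series-name. A minimal NumPy sketch of that filtering, where ``ThresholdDiscriminator`` is a hypothetical stand-in for a trained :class:`BaseDiscriminator`:

.. code-block:: python

    import numpy as np

    class ThresholdDiscriminator:
        """Hypothetical stand-in: any object with a ``predict(points) -> labels`` method works here."""

        def predict(self, points):
            # Label "1" if the in-phase (I) component is positive, otherwise "0".
            return np.where(np.asarray(points)[:, 0] > 0, "1", "0")

    series_name = "0"
    # 128 dummy IQ points nominally belonging to series "0".
    points = np.random.normal(loc=(-1.0, 0.0), scale=0.5, size=(128, 2))

    discrim = ThresholdDiscriminator()
    classifications = np.asarray(discrim.predict(points))

    # Keep only the points whose predicted label differs from the series-name,
    # mirroring the ``np.argwhere`` + ``reshape(-1, 2)`` pattern used in the patch.
    misclassified = points[np.argwhere(classifications != series_name), :].reshape(-1, 2)
    print(f"{len(misclassified)} of {len(points)} points were misclassified")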
@@ -260,7 +299,7 @@ def _plot_figure(self):
                     centroid[1],
                     name=ser,
                     legend=True,
-                    zorder=3,
+                    zorder=4,
                     s=20,
                     edgecolor="k",
                     marker="o",
@@ -278,3 +317,17 @@ def _plot_figure(self):
                     alpha=0.2,
                     marker=".",
                 )
+                if self.options.flag_misclassified:
+                    misclassified_points = self._misclassified_points(ser, points)
+                    if misclassified_points is not None:
+                        self.drawer.scatter(
+                            misclassified_points[:, 0],
+                            misclassified_points[:, 1],
+                            name="misclassified",
+                            legend=False,
+                            zorder=3,
+                            s=10,
+                            alpha=0.4,
+                            marker=self.options.misclassified_symbol,
+                            color=self.options.misclassified_color,
+                        )
diff --git a/test/visualization/test_iq_plotter.py b/test/visualization/test_iq_plotter.py
index c9d7c04dc5..883e8210b6 100644
--- a/test/visualization/test_iq_plotter.py
+++ b/test/visualization/test_iq_plotter.py
@@ -13,6 +13,7 @@
 Test IQ plotter.
 """
 
+from itertools import product
 from test.base import QiskitExperimentsTestCase
 from typing import Any, Dict, List, Tuple
 
@@ -23,31 +24,48 @@
 from qiskit_experiments.visualization import IQPlotter, MplDrawer
 
 
+class MockDiscriminatorNotTrainedException(Exception):
+    """Mock exception to be raised when :meth:`MockDiscriminator.predict` is called on an untrained
+    :class:`MockDiscriminator`."""
+
+    pass
+
+
 class MockDiscriminator(BaseDiscriminator):
     """A mock discriminator for testing."""
 
-    def __init__(self, is_trained: bool = False):
+    def __init__(
+        self, is_trained: bool = False, n_states: int = 3, raise_predict_not_trained: bool = False
+    ):
         """Create a MockDiscriminator instance.
 
         Args:
             is_trained: Whether the discriminator is trained or not. Defaults to False.
+            n_states: The number of states/labels. Defaults to 3.
+            raise_predict_not_trained: Whether to raise a
+                :class:`MockDiscriminatorNotTrainedException` if :meth:`predict` is called while
+                :attr:`is_trained` is ``False``. Defaults to ``False``.
         """
         super().__init__()
         self._is_trained = is_trained
+        self._n_states = n_states
+        self._raise_predict_not_trained = raise_predict_not_trained
         self.predict_was_called = False
         """Whether :meth:`predict` was called at least once."""
 
     def predict(self, data: List):
-        """Returns dummy predictions where everything has the label ``0``."""
+        """Returns dummy predictions with random labels."""
         self.predict_was_called = True
+        if self._raise_predict_not_trained and not self.is_trained():
+            raise MockDiscriminatorNotTrainedException()
         if isinstance(data, list):
-            return [0] * len(data)
-        return [0] * data.shape[0]
+            return np.random.choice([f"{i}" for i in range(self._n_states)], len(data)).tolist()
+        return np.random.choice([f"{i}" for i in range(self._n_states)], data.shape[0])
 
     def config(self) -> Dict[str, Any]:
         return {
             "predict_was_called": self.predict_was_called,
             "is_trained": self._is_trained,
+            "n_states": self._n_states,
         }
 
     def is_trained(self) -> bool:
@@ -63,12 +81,15 @@ def _dummy_data(
         cls,
         is_trained: bool = True,
         n_series: int = 3,
+        raise_predict_not_trained: bool = False,
     ) -> Tuple[List, List, BaseDiscriminator]:
         """Create dummy data for the tests.
 
         Args:
             is_trained: Whether the discriminator should be trained or not. Defaults to True.
             n_series: The number of series to generate dummy data for. Defaults to 3.
+            raise_predict_not_trained: Passed to the :class:`MockDiscriminator` constructor.
+
         Returns:
             tuple: the tuple ``(points, names, discrim)`` where ``points`` is a list of NumPy arrays of
@@ -80,7 +101,9 @@ def _dummy_data(
         for i in range(n_series):
             points.append(np.random.rand(128, 2))
             labels.append(f"{i}")
-        mock_discrim = MockDiscriminator(is_trained)
+        mock_discrim = MockDiscriminator(
+            is_trained, n_states=n_series, raise_predict_not_trained=raise_predict_not_trained
+        )
         return points, labels, mock_discrim
 
     @ddt.data(True, False)
@@ -105,15 +128,45 @@ def test_discriminator_trained(self, is_trained: bool):
 
         # Assert that MockDiscriminator.predict() was/wasn't called, depending on whether it was trained
         # or not.
-        self.assertEqual(is_trained, discrim.predict_was_called)
-
-    def test_end_to_end(self):
+        self.assertEqual(
+            is_trained,
+            discrim.predict_was_called,
+            msg=f"Discriminator `predict()` {'was' if is_trained else 'was not'} meant to be called, "
+            f"but {'was' if discrim.predict_was_called else 'was not'} called. is_trained={is_trained}.",
+        )
+
+    @ddt.data(*list(product([True, False], repeat=3)))
+    def test_end_to_end(self, args):
         """Test end-to-end functionality of IQPlotter."""
+        # Expand args
+        with_centroids, with_misclassified, with_discriminator = args
+
+        # Create plotter and add data
         plotter = IQPlotter(MplDrawer())
-        points, labels, discrim = self._dummy_data(is_trained=True)
-        for points, series_name in zip(points, labels):
-            centroid = np.mean(points, axis=0)
-            plotter.set_series_data(series_name, points=points, centroid=centroid)
-        plotter.set_supplementary_data(discriminator=discrim)
+        plotter.set_options(flag_misclassified=with_misclassified)
+        points, labels, discrim = self._dummy_data(
+            is_trained=True,
+        )
+        for series_points, series_name in zip(points, labels):
+            plotter.set_series_data(series_name, points=series_points)
+            if with_centroids:
+                centroid = np.mean(series_points, axis=0)
+                plotter.set_series_data(series_name, centroid=centroid)
+        if with_discriminator:
+            plotter.set_supplementary_data(discriminator=discrim)
+        # Generate figure
         plotter.figure()
+
+        # Verify that the correct number of series colours were created. If we are flagging
+        # misclassified points, we have one extra series. This assumes we are using `MplDrawer`, which
+        # records every series-name passed to `MplDrawer._get_default_color` in `MplDrawer._series`.
+        # The discriminator labels each input as one of the series-names, so the number of entries in
+        # `plotter.drawer._series` should equal the number of dummy-data series, plus one for the
+        # "misclassified" series when it is plotted.
+        self.assertEqual(
+            len(plotter.drawer._series),
+            len(points) + (1 if with_misclassified and with_discriminator else 0),
+            msg="Number of series plotted by IQPlotter does not match the number expected from the "
+            f"dummy data. Expected "
+            f"{len(points) + (1 if with_misclassified and with_discriminator else 0)} but got "
+            f"{len(plotter.drawer._series)}. Series={plotter.drawer._series}.",
+        )
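For reference (not part of the patch), a minimal usage sketch of the new options, modelled on the Example block in the :class:`IQPlotter` docstring above. ``MyIQDiscriminator``, ``train_data``, and ``train_labels`` are the same placeholders used in that Example and are not real library objects:

.. code-block:: python

    import numpy as np

    from qiskit_experiments.visualization import IQPlotter, MplDrawer

    plotter = IQPlotter(MplDrawer())

    # New options from this patch. ``flag_misclassified=True`` is already the default;
    # the symbol and colour are changed here purely for illustration.
    plotter.set_options(
        flag_misclassified=True,
        misclassified_symbol="+",
        misclassified_color="tab:red",
    )

    # Two series of dummy IQ points whose series-names ("0" and "1") must match the
    # labels produced by the discriminator.
    for series_name, center in [("0", (-1.0, 0.0)), ("1", (1.0, 0.0))]:
        points = np.random.normal(loc=center, scale=0.4, size=(256, 2))
        plotter.set_series_data(series_name, points=points, centroid=np.mean(points, axis=0))

    # Optional: attach a trained discriminator so misclassified points get flagged.
    # discrim = MyIQDiscriminator()          # placeholder from the class docstring Example
    # discrim.fit(train_data, train_labels)  # labels must equal the series-names
    # plotter.set_supplementary_data(discriminator=discrim)

    fig = plotter.figure()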