Skip to content

Commit

Permalink
Add serialization test for plotter and nested drawer
Browse files Browse the repository at this point in the history
  • Loading branch information
conradhaupt committed Sep 16, 2022
1 parent 9f17386 commit c40594c
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 25 deletions.
11 changes: 10 additions & 1 deletion qiskit_experiments/visualization/drawers/base_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,15 @@ def figure(self):
def config(self) -> Dict:
"""Return the config dictionary for this drawer."""
options = dict((key, getattr(self._options, key)) for key in self._set_options)
plot_options = dict(
(key, getattr(self._plot_options, key)) for key in self._set_plot_options
)

return {"cls": type(self), "options": options}
return {
"cls": type(self),
"options": options,
"plot_options": plot_options,
}

def __json_encode__(self):
return self.config()
Expand All @@ -318,4 +325,6 @@ def __json_decode__(cls, value):
instance = cls()
if "options" in value:
instance.set_options(**value["options"])
if "plot_options" in value:
instance.set_plot_options(**value["plot_options"])
return instance
12 changes: 9 additions & 3 deletions qiskit_experiments/visualization/plotters/base_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,25 +420,31 @@ def _initialize_drawer(self):

def config(self) -> Dict:
"""Return the config dictionary for this drawing."""
# FIXME: Figure out how self._drawer should be serialized?
options = dict((key, getattr(self._options, key)) for key in self._set_options)
plot_options = dict(
(key, getattr(self._plot_options, key)) for key in self._set_plot_options
)
drawer = self.drawer.__json_encode__()

return {
"cls": type(self),
"options": options,
"plot_options": plot_options,
"drawer": drawer,
}

def __json_encode__(self):
return self.config()

@classmethod
def __json_decode__(cls, value):
# FIXME: Figure out how self._drawer:BaseDrawer be serialized?
drawer = value["drawer"]
## Process drawer as it's needed to create a plotter
drawer_values = value["drawer"]
# We expect a subclass of BaseDrawer
drawer_cls: BaseDrawer = drawer_values["cls"]
drawer = drawer_cls.__json_decode__(drawer_values)

# Create plotter instance
instance = cls(drawer)
if "options" in value:
instance.set_options(**value["options"])
Expand Down
26 changes: 25 additions & 1 deletion qiskit_experiments/visualization/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Configurable stylesheet for :class:`BasePlotter` and :class:`BaseDrawer`.
"""
from copy import copy
from typing import Tuple
from typing import Dict, Tuple

from qiskit_experiments.framework import Options

Expand Down Expand Up @@ -92,3 +92,27 @@ def merge(cls, style1: "PlotStyle", style2: "PlotStyle") -> "PlotStyle":
new_style = copy(style1)
new_style.update(style2)
return new_style

def config(self) -> Dict:
"""Return the config dictionary for this PlotStyle instance.
.. Note::
Validators are not currently supported
Returns:
dict: A dictionary containing the config of the plot style.
"""
return {
"cls": type(self),
**self._fields,
}

def __json_encode__(self):
return self.config()

@classmethod
def __json_decode__(cls, value):
kwargs = value
kwargs.pop("cls")
inst = cls(**kwargs)
return inst
43 changes: 43 additions & 0 deletions test/visualization/test_mpldrawer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Test Matplotlib Drawer.
"""

from copy import copy
from test.base import QiskitExperimentsTestCase

import matplotlib
from qiskit_experiments.visualization import MplDrawer


class TestMplDrawer(QiskitExperimentsTestCase):
"""Test MplDrawer."""

def test_end_to_end(self):
"""Test that MplDrawer generates something."""
drawer = MplDrawer()

# Draw dummy data
drawer.initialize_canvas()
drawer.draw_raw_data([0, 1, 2], [0, 1, 2], "seriesA")
drawer.draw_formatted_data([0, 1, 2], [0, 1, 2], [0.1, 0.1, 0.1], "seriesA")
drawer.draw_line([3, 2, 1], [1, 2, 3], "seriesB")
drawer.draw_confidence_interval([0, 1, 2, 3], [1, 2, 1, 2], [-1, -2, -1, -2], "seriesB")
drawer.draw_report(r"Dummy report text with LaTex $\beta$")

# Get result
fig = drawer.figure

# Check that
self.assertTrue(fig is not None)
self.assertTrue(isinstance(fig, matplotlib.pyplot.Figure))
120 changes: 100 additions & 20 deletions test/visualization/test_plotter_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,84 @@
from copy import copy
from test.base import QiskitExperimentsTestCase

from qiskit_experiments.visualization import PlotStyle
from qiskit_experiments.framework import Options
from qiskit_experiments.visualization import BasePlotter, PlotStyle

from .mock_drawer import MockDrawer
from .mock_plotter import MockPlotter


def dummy_plotter() -> BasePlotter:
"""Return a MockPlotter with dummy option values.
Returns:
BasePlotter: A dummy plotter.
"""
plotter = MockPlotter(MockDrawer())
# Set dummy plot options to update
plotter.set_plot_options(
xlabel="xlabel",
ylabel="ylabel",
figure_title="figure_title",
non_drawer_options="should not be set",
)
plotter.set_options(
style=PlotStyle(test_param="test_param", overwrite_param="new_overwrite_param_value")
)
return plotter


class TestPlotterAndDrawerIntegration(QiskitExperimentsTestCase):
"""Test plotter and drawer integration."""

def assertOptionsEqual(
self,
options1: Options,
options2: Options,
msg_prefix: str = "",
only_assert_for_intersection=False,
):
"""Asserts that two options are the same by checking each individual option.
This method is easier to read than a standard equality assertion as individual option names are
printed.
Args:
options1: The first Options instance to check.
options2: The second Options instance to check.
msg_prefix: A prefix to add before assert messages.
only_assert_for_intersection: If True, will only check options that are in both Options
instances. Defaults to False.
"""
# Get combined field names
if only_assert_for_intersection:
fields = set(options1._fields.keys()).intersection(set(options2._fields.keys()))
else:
fields = set(options1._fields.keys()).union(set(options2._fields.keys()))

# Check individual options.
for key in fields:
# Check if the option exists in both
self.assertTrue(
hasattr(options1, key),
msg=f"[{msg_prefix}] Expected field {key} in both, but only found in one: not in "
f"{options1}.",
)
self.assertTrue(
hasattr(options2, key),
msg=f"[{msg_prefix}] Expected field {key} in both, but only found in one: not in "
f"{options2}.",
)
self.assertEqual(
getattr(options1, key),
getattr(options2, key),
msg=f"[{msg_prefix}] Expected equal values for option '{key}': "
f"{getattr(options1, key),} vs {getattr(options2,key)}",
)

def test_plot_options(self):
"""Test copying and passing of plot-options between plotter and drawer."""
plotter = MockPlotter(MockDrawer())
plotter = dummy_plotter()

# Expected options
expected_plot_options = copy(plotter.drawer.plot_options)
Expand All @@ -39,36 +105,22 @@ def test_plot_options(self):
expected_custom_style = PlotStyle(
test_param="test_param", overwrite_param="new_overwrite_param_value"
)
plotter.set_options(style=expected_custom_style)
expected_full_style = PlotStyle.merge(
plotter.drawer.options.default_style, expected_custom_style
)
expected_plot_options.custom_style = expected_custom_style

# Set dummy plot options to update
plotter.set_plot_options(
xlabel="xlabel",
ylabel="ylabel",
figure_title="figure_title",
non_drawer_options="should not be set",
)
plotter.set_options(
style=PlotStyle(test_param="test_param", overwrite_param="new_overwrite_param_value")
)

# Call plotter.figure() to force passing of plot_options to drawer
plotter.figure()

## Test values
# Check style as this is a more detailed plot-option than others.
self.assertEqual(expected_full_style, plotter.drawer.style)

# Check individual plot-options.
for key, value in expected_plot_options._fields.items():
self.assertEqual(
getattr(plotter.drawer.plot_options, key),
value,
msg=f"Expected equal values for plot option '{key}'",
)
# Check individual plot-options, but only the intersection as those are the ones we expect to be
# updated.
self.assertOptionsEqual(expected_plot_options, plotter.drawer.plot_options, True)

# Coarse equality check of plot_options
self.assertEqual(
Expand All @@ -77,3 +129,31 @@ def test_plot_options(self):
msg=rf"expected_plot_options = {expected_plot_options}\nactual_plot_options ="
rf"{plotter.drawer.plot_options}",
)

def test_serializable(self):
"""Test that plotter is serializable."""
original_plotter = dummy_plotter()

def check_options(original, new):
"""Verifies that ``new`` plotter has the same options as ``original`` plotter."""
self.assertOptionsEqual(original.options, new.options, "options")
self.assertOptionsEqual(original.plot_options, new.plot_options, "plot_options")
self.assertOptionsEqual(original.drawer.options, new.drawer.options, "drawer.options")
self.assertOptionsEqual(
original.drawer.plot_options, new.drawer.plot_options, "drawer.plot_options"
)

## Check that plotter, BEFORE PLOTTING, survives serialization correctly.
# HACK: A dedicated JSON encoder and decoder class would be better.
# __json_<encode/decode>__ are not typically called, instead json.dumps etc. is called
encoded = original_plotter.__json_encode__()
decoded_plotter = original_plotter.__class__.__json_decode__(encoded)
check_options(original_plotter, decoded_plotter)

## Check that plotter, AFTER PLOTTING, survives serialization correctly.
original_plotter.figure()
# HACK: A dedicated JSON encoder and decoder class would be better.
# __json_<encode/decode>__ are not typically called, instead json.dumps etc. is called
encoded = original_plotter.__json_encode__()
decoded_plotter = original_plotter.__class__.__json_decode__(encoded)
check_options(original_plotter, decoded_plotter)

0 comments on commit c40594c

Please sign in to comment.