diff --git a/qiskit_experiments/curve_analysis/__init__.py b/qiskit_experiments/curve_analysis/__init__.py index 1eb7333dd6..d36bbe7f18 100644 --- a/qiskit_experiments/curve_analysis/__init__.py +++ b/qiskit_experiments/curve_analysis/__init__.py @@ -318,8 +318,8 @@ class AnalysisB(CurveAnalysis): compute custom quantities based on the raw fit parameters. See :ref:`curve_analysis_results` for details. Afterwards, the analysis draws several curves in the Matplotlib figure. -User can set custom drawer to the option ``curve_drawer``. -The drawer defaults to the :class:`MplCurveDrawer`. +Users can set a custom plotter in :class:`CurveAnalysis` classes, to customize +figures, by setting the :attr:`~CurveAnalysis.plotter` attribute. Finally, it returns the list of created analysis results and Matplotlib figure. diff --git a/qiskit_experiments/curve_analysis/base_curve_analysis.py b/qiskit_experiments/curve_analysis/base_curve_analysis.py index 2ef78d3970..a8befb06cc 100644 --- a/qiskit_experiments/curve_analysis/base_curve_analysis.py +++ b/qiskit_experiments/curve_analysis/base_curve_analysis.py @@ -16,15 +16,28 @@ import warnings from abc import ABC, abstractmethod -from typing import List, Dict, Union +from typing import Dict, List, Union import lmfit from qiskit_experiments.data_processing import DataProcessor from qiskit_experiments.data_processing.processor_library import get_processor -from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, Options, ExperimentData -from .curve_data import CurveData, ParameterRepr, CurveFitResult -from .visualization import MplCurveDrawer, BaseCurveDrawer +from qiskit_experiments.framework import ( + AnalysisResultData, + BaseAnalysis, + ExperimentData, + Options, +) +from qiskit_experiments.visualization import ( + BaseDrawer, + BasePlotter, + CurvePlotter, + LegacyCurveCompatDrawer, + MplDrawer, +) +from qiskit_experiments.warnings import deprecated_function + +from .curve_data import CurveData, CurveFitResult, ParameterRepr PARAMS_ENTRY_PREFIX = "@Parameters_" DATA_ENTRY_PREFIX = "@Data_" @@ -113,16 +126,28 @@ def models(self) -> List[lmfit.Model]: """Return fit models.""" @property - def drawer(self) -> BaseCurveDrawer: - """A short-cut for curve drawer instance.""" - return self._options.curve_drawer + def plotter(self) -> BasePlotter: + """A short-cut to the curve plotter instance.""" + return self._options.plotter + + @property + @deprecated_function( + last_version="0.6", + msg="Replaced by `plotter` from the new visualization submodule.", + ) + def drawer(self) -> BaseDrawer: + """A short-cut for curve drawer instance, if set. ``None`` otherwise.""" + if isinstance(self.plotter.drawer, LegacyCurveCompatDrawer): + return self.plotter.drawer._curve_drawer + else: + return None @classmethod def _default_options(cls) -> Options: """Return default analysis options. Analysis Options: - curve_drawer (BaseCurveDrawer): A curve drawer instance to visualize + plotter (BasePlotter): A curve plotter instance to visualize the analysis result. plot_raw_data (bool): Set ``True`` to draw processed data points, dataset without formatting, on canvas. This is ``False`` by default. @@ -168,7 +193,7 @@ def _default_options(cls) -> Options: """ options = super()._default_options() - options.curve_drawer = MplCurveDrawer() + options.plotter = CurvePlotter(MplDrawer()) options.plot_raw_data = False options.plot = True options.return_fit_parameters = True @@ -187,7 +212,7 @@ def _default_options(cls) -> Options: # Set automatic validator for particular option values options.set_validator(field="data_processor", validator_value=DataProcessor) - options.set_validator(field="curve_drawer", validator_value=BaseCurveDrawer) + options.set_validator(field="plotter", validator_value=BasePlotter) return options @@ -211,6 +236,27 @@ def set_options(self, **fields): ) fields["lmfit_options"] = fields.pop("curve_fitter_options") + # TODO remove this in Qiskit Experiments 0.6 + if "curve_drawer" in fields: + warnings.warn( + "The option 'curve_drawer' is replaced with 'plotter'. " + "This option will be removed in Qiskit Experiments 0.6.", + DeprecationWarning, + stacklevel=2, + ) + # Set the plotter drawer to `curve_drawer`. If `curve_drawer` is the right type, set it + # directly. If not, wrap it in a compatibility drawer. + if isinstance(fields["curve_drawer"], BaseDrawer): + plotter = self.options.plotter + plotter.drawer = fields.pop("curve_drawer") + fields["plotter"] = plotter + else: + drawer = fields["curve_drawer"] + compat_drawer = LegacyCurveCompatDrawer(drawer) + plotter = self.options.plotter + plotter.drawer = compat_drawer + fields["plotter"] = plotter + super().set_options(**fields) @abstractmethod diff --git a/qiskit_experiments/curve_analysis/composite_curve_analysis.py b/qiskit_experiments/curve_analysis/composite_curve_analysis.py index 38a76b06f3..3726f3e51c 100644 --- a/qiskit_experiments/curve_analysis/composite_curve_analysis.py +++ b/qiskit_experiments/curve_analysis/composite_curve_analysis.py @@ -15,17 +15,31 @@ """ # pylint: disable=invalid-name import warnings -from typing import Dict, List, Tuple, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import lmfit import numpy as np -from uncertainties import unumpy as unp, UFloat - -from qiskit_experiments.framework import BaseAnalysis, ExperimentData, AnalysisResultData, Options -from .base_curve_analysis import BaseCurveAnalysis, PARAMS_ENTRY_PREFIX +from uncertainties import UFloat +from uncertainties import unumpy as unp + +from qiskit_experiments.framework import ( + AnalysisResultData, + BaseAnalysis, + ExperimentData, + Options, +) +from qiskit_experiments.visualization import ( + BaseDrawer, + BasePlotter, + CurvePlotter, + LegacyCurveCompatDrawer, + MplDrawer, +) +from qiskit_experiments.warnings import deprecated_function + +from .base_curve_analysis import PARAMS_ENTRY_PREFIX, BaseCurveAnalysis from .curve_data import CurveFitResult from .utils import analysis_result_to_repr, eval_with_uncertainties -from .visualization import MplCurveDrawer, BaseCurveDrawer class CompositeCurveAnalysis(BaseAnalysis): @@ -124,9 +138,21 @@ def models(self) -> Dict[str, List[lmfit.Model]]: return models @property - def drawer(self) -> BaseCurveDrawer: - """A short-cut for curve drawer instance.""" - return self._options.curve_drawer + def plotter(self) -> BasePlotter: + """A short-cut to the plotter instance.""" + return self._options.plotter + + @property + @deprecated_function( + last_version="0.6", + msg="Replaced by `plotter` from the new visualization submodule.", + ) + def drawer(self) -> BaseDrawer: + """A short-cut for curve drawer instance, if set. ``None`` otherwise.""" + if hasattr(self._options, "curve_drawer"): + return self._options.curve_drawer + else: + return None def analyses( self, index: Optional[Union[str, int]] = None @@ -187,7 +213,7 @@ def _default_options(cls) -> Options: """Default analysis options. Analysis Options: - curve_drawer (BaseCurveDrawer): A curve drawer instance to visualize + plotter (BasePlotter): A plotter instance to visualize the analysis result. plot (bool): Set ``True`` to create figure for fit result. This is ``True`` by default. @@ -200,7 +226,7 @@ def _default_options(cls) -> Options: """ options = super()._default_options() options.update_options( - curve_drawer=MplCurveDrawer(), + plotter=CurvePlotter(MplDrawer()), plot=True, return_fit_parameters=True, return_data_points=False, @@ -208,11 +234,32 @@ def _default_options(cls) -> Options: ) # Set automatic validator for particular option values - options.set_validator(field="curve_drawer", validator_value=BaseCurveDrawer) + options.set_validator(field="plotter", validator_value=BasePlotter) return options def set_options(self, **fields): + # TODO remove this in Qiskit Experiments 0.6 + if "curve_drawer" in fields: + warnings.warn( + "The option 'curve_drawer' is replaced with 'plotter'. " + "This option will be removed in Qiskit Experiments 0.6.", + DeprecationWarning, + stacklevel=2, + ) + # Set the plotter drawer to `curve_drawer`. If `curve_drawer` is the right type, set it + # directly. If not, wrap it in a compatibility drawer. + if isinstance(fields["curve_drawer"], BaseDrawer): + plotter = self.options.plotter + plotter.drawer = fields.pop("curve_drawer") + fields["plotter"] = plotter + else: + drawer = fields["curve_drawer"] + compat_drawer = LegacyCurveCompatDrawer(drawer) + plotter = self.options.plotter + plotter.drawer = compat_drawer + fields["plotter"] = plotter + for field in fields: if not hasattr(self.options, field): warnings.warn( @@ -232,10 +279,6 @@ def _run_analysis( analysis_results = [] - # Initialize canvas - if self.options.plot: - self.drawer.initialize_canvas() - fit_dataset = {} for analysis in self._analyses: analysis._initialize(experiment_data) @@ -251,10 +294,10 @@ def _run_analysis( if self.options.plot and analysis.options.plot_raw_data: for model in analysis.models: sub_data = processed_data.get_subset_of(model._name) - self.drawer.draw_raw_data( - x_data=sub_data.x, - y_data=sub_data.y, - name=model._name + f"_{analysis.name}", + self.plotter.set_series_data( + model._name + f"_{analysis.name}", + x=sub_data.x, + y=sub_data.y, ) # Format data @@ -262,11 +305,11 @@ def _run_analysis( if self.options.plot: for model in analysis.models: sub_data = formatted_data.get_subset_of(model._name) - self.drawer.draw_formatted_data( - x_data=sub_data.x, - y_data=sub_data.y, - y_err_data=sub_data.y_err, - name=model._name + f"_{analysis.name}", + self.plotter.set_series_data( + model._name + f"_{analysis.name}", + x_formatted=sub_data.x, + y_formatted=sub_data.y, + y_formatted_err=sub_data.y_err, ) # Run fitting @@ -299,34 +342,30 @@ def _run_analysis( # Draw fit result if self.options.plot: - interp_x = np.linspace( + x_interp = np.linspace( np.min(formatted_data.x), np.max(formatted_data.x), num=100 ) for model in analysis.models: y_data_with_uncertainty = eval_with_uncertainties( - x=interp_x, + x=x_interp, model=model, params=fit_data.ufloat_params, ) - y_mean = unp.nominal_values(y_data_with_uncertainty) - # Draw fit line - self.drawer.draw_fit_line( - x_data=interp_x, - y_data=y_mean, - name=model._name + f"_{analysis.name}", + y_interp = unp.nominal_values(y_data_with_uncertainty) + # Add fit line data + self.plotter.set_series_data( + model._name + f"_{analysis.name}", + x_interp=x_interp, + y_interp=y_interp, ) if fit_data.covar is not None: - # Draw confidence intervals with different n_sigma - sigmas = unp.std_devs(y_data_with_uncertainty) - if np.isfinite(sigmas).all(): - for n_sigma, alpha in self.drawer.options.plot_sigma: - self.drawer.draw_confidence_interval( - x_data=interp_x, - y_ub=y_mean + n_sigma * sigmas, - y_lb=y_mean - n_sigma * sigmas, - name=model._name + f"_{analysis.name}", - alpha=alpha, - ) + # Add confidence interval data + y_interp_err = unp.std_devs(y_data_with_uncertainty) + if np.isfinite(y_interp_err).all(): + self.plotter.set_series_data( + model._name + f"_{analysis.name}", + y_interp_err=y_interp_err, + ) # Add raw data points if self.options.return_data_points: @@ -355,10 +394,8 @@ def _run_analysis( for group, fit_data in fit_dataset.items(): chisqs.append(r"reduced-$\chi^2$ = " + f"{fit_data.reduced_chisq: .4g} ({group})") report += "\n".join(chisqs) - self.drawer.draw_fit_report(description=report) + self.plotter.set_supplementary_data(report_text=report) - # Finalize canvas - self.drawer.format_canvas() - return analysis_results, [self.drawer.figure] + return analysis_results, [self.plotter.figure()] return analysis_results, [] diff --git a/qiskit_experiments/curve_analysis/curve_analysis.py b/qiskit_experiments/curve_analysis/curve_analysis.py index 78a3dae891..6c3022c12a 100644 --- a/qiskit_experiments/curve_analysis/curve_analysis.py +++ b/qiskit_experiments/curve_analysis/curve_analysis.py @@ -140,7 +140,7 @@ def __init__( ) # pylint: disable=no-member models = [] - plot_options = {} + series_params = {} for series_def in self.__series__: models.append( lmfit.Model( @@ -149,12 +149,13 @@ def __init__( data_sort_key=series_def.filter_kwargs, ) ) - plot_options[series_def.name] = { + series_params[series_def.name] = { "color": series_def.plot_color, "symbol": series_def.plot_symbol, "canvas": series_def.canvas, + "label": series_def.name, } - self.drawer.set_options(plot_options=plot_options) + self.plotter.set_figure_options(series_params=series_params) self._models = models or [] self._name = name or self.__class__.__name__ @@ -467,10 +468,6 @@ def _run_analysis( self._initialize(experiment_data) analysis_results = [] - # Initialize canvas - if self.options.plot: - self.drawer.initialize_canvas() - # Run data processing processed_data = self._run_data_processing( raw_data=experiment_data.data(), @@ -480,10 +477,10 @@ def _run_analysis( if self.options.plot and self.options.plot_raw_data: for model in self._models: sub_data = processed_data.get_subset_of(model._name) - self.drawer.draw_raw_data( - x_data=sub_data.x, - y_data=sub_data.y, - name=model._name, + self.plotter.set_series_data( + model._name, + x=sub_data.x, + y=sub_data.y, ) # for backward compatibility, will be removed in 0.4. self.__processed_data_set["raw_data"] = processed_data @@ -493,11 +490,11 @@ def _run_analysis( if self.options.plot: for model in self._models: sub_data = formatted_data.get_subset_of(model._name) - self.drawer.draw_formatted_data( - x_data=sub_data.x, - y_data=sub_data.y, - y_err_data=sub_data.y_err, - name=model._name, + self.plotter.set_series_data( + model._name, + x_formatted=sub_data.x, + y_formatted=sub_data.y, + y_formatted_err=sub_data.y_err, ) # for backward compatibility, will be removed in 0.4. self.__processed_data_set["fit_ready"] = formatted_data @@ -553,32 +550,28 @@ def _run_analysis( # This is the case when fit model exist but no data to fit is provided. # For example, experiment may omit experimenting with some setting. continue - interp_x = np.linspace(np.min(sub_data.x), np.max(sub_data.x), num=100) + x_interp = np.linspace(np.min(sub_data.x), np.max(sub_data.x), num=100) y_data_with_uncertainty = eval_with_uncertainties( - x=interp_x, + x=x_interp, model=model, params=fit_data.ufloat_params, ) - y_mean = unp.nominal_values(y_data_with_uncertainty) - # Draw fit line - self.drawer.draw_fit_line( - x_data=interp_x, - y_data=y_mean, - name=model._name, + y_interp = unp.nominal_values(y_data_with_uncertainty) + # Add fit line data + self.plotter.set_series_data( + model._name, + x_interp=x_interp, + y_interp=y_interp, ) if fit_data.covar is not None: - # Draw confidence intervals with different n_sigma - sigmas = unp.std_devs(y_data_with_uncertainty) - if np.isfinite(sigmas).all(): - for n_sigma, alpha in self.drawer.options.plot_sigma: - self.drawer.draw_confidence_interval( - x_data=interp_x, - y_ub=y_mean + n_sigma * sigmas, - y_lb=y_mean - n_sigma * sigmas, - name=model._name, - alpha=alpha, - ) + # Add confidence interval data + y_interp_err = unp.std_devs(y_data_with_uncertainty) + if np.isfinite(y_interp_err).all(): + self.plotter.set_series_data( + model._name, + y_interp_err=y_interp_err, + ) # Write fitting report report_description = "" @@ -586,7 +579,7 @@ def _run_analysis( if isinstance(res.value, (float, UFloat)): report_description += f"{analysis_result_to_repr(res)}\n" report_description += r"reduced-$\chi^2$ = " + f"{fit_data.reduced_chisq: .4g}" - self.drawer.draw_fit_report(description=report_description) + self.plotter.set_supplementary_data(report_text=report_description) # Add raw data points if self.options.return_data_points: @@ -596,8 +589,7 @@ def _run_analysis( # Finalize plot if self.options.plot: - self.drawer.format_canvas() - return analysis_results, [self.drawer.figure] + return analysis_results, [self.plotter.figure()] return analysis_results, [] diff --git a/qiskit_experiments/curve_analysis/standard_analysis/bloch_trajectory.py b/qiskit_experiments/curve_analysis/standard_analysis/bloch_trajectory.py index 098f7e9a70..f3c8909a25 100644 --- a/qiskit_experiments/curve_analysis/standard_analysis/bloch_trajectory.py +++ b/qiskit_experiments/curve_analysis/standard_analysis/bloch_trajectory.py @@ -138,7 +138,7 @@ def _default_options(cls): input_key="counts", data_actions=[dp.Probability("1"), dp.BasisExpectationValue()], ) - default_options.curve_drawer.set_options( + default_options.plotter.set_figure_options( xlabel="Flat top width", ylabel="Pauli expectation values", xval_unit="s", diff --git a/qiskit_experiments/curve_analysis/standard_analysis/error_amplification_analysis.py b/qiskit_experiments/curve_analysis/standard_analysis/error_amplification_analysis.py index d3aa479698..116430f2d9 100644 --- a/qiskit_experiments/curve_analysis/standard_analysis/error_amplification_analysis.py +++ b/qiskit_experiments/curve_analysis/standard_analysis/error_amplification_analysis.py @@ -105,7 +105,7 @@ def _default_options(cls): considered as good. Defaults to :math:`\pi/2`. """ default_options = super()._default_options() - default_options.curve_drawer.set_options( + default_options.plotter.set_figure_options( xlabel="Number of gates (n)", ylabel="Population", ylim=(0, 1.0), diff --git a/qiskit_experiments/curve_analysis/standard_analysis/gaussian.py b/qiskit_experiments/curve_analysis/standard_analysis/gaussian.py index 2c7f0cb442..2a17f54ac0 100644 --- a/qiskit_experiments/curve_analysis/standard_analysis/gaussian.py +++ b/qiskit_experiments/curve_analysis/standard_analysis/gaussian.py @@ -76,7 +76,7 @@ def __init__( @classmethod def _default_options(cls) -> Options: options = super()._default_options() - options.curve_drawer.set_options( + options.plotter.set_figure_options( xlabel="Frequency", ylabel="Signal (arb. units)", xval_unit="Hz", diff --git a/qiskit_experiments/curve_analysis/standard_analysis/resonance.py b/qiskit_experiments/curve_analysis/standard_analysis/resonance.py index a07fc5a67e..558de514d8 100644 --- a/qiskit_experiments/curve_analysis/standard_analysis/resonance.py +++ b/qiskit_experiments/curve_analysis/standard_analysis/resonance.py @@ -76,7 +76,7 @@ def __init__( @classmethod def _default_options(cls) -> Options: options = super()._default_options() - options.curve_drawer.set_options( + options.plotter.set_figure_options( xlabel="Frequency", ylabel="Signal (arb. units)", xval_unit="Hz", diff --git a/qiskit_experiments/curve_analysis/visualization/__init__.py b/qiskit_experiments/curve_analysis/visualization/__init__.py index 0c85169e32..42a7c838f2 100644 --- a/qiskit_experiments/curve_analysis/visualization/__init__.py +++ b/qiskit_experiments/curve_analysis/visualization/__init__.py @@ -10,22 +10,17 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. """ -Visualization functions -""" +Deprecated Visualization Functions. -from enum import Enum +.. note:: + This module is deprecated and replaced by :mod:`qiskit_experiments.visualization`. The new + visualization module contains classes to manage drawing to a figure canvas and plotting data + obtained from an experiment or analysis. +""" +from . import fit_result_plotters from .base_drawer import BaseCurveDrawer +from .curves import plot_curve_fit, plot_errorbar, plot_scatter +from .fit_result_plotters import FitResultPlotters from .mpl_drawer import MplCurveDrawer - -from . import fit_result_plotters -from .curves import plot_scatter, plot_errorbar, plot_curve_fit from .style import PlotterStyle - - -# pylint: disable=invalid-name -class FitResultPlotters(Enum): - """Map the plotter name to the plotters.""" - - mpl_single_canvas = fit_result_plotters.MplDrawSingleCanvas - mpl_multiv_canvas = fit_result_plotters.MplDrawMultiCanvasVstack diff --git a/qiskit_experiments/curve_analysis/visualization/base_drawer.py b/qiskit_experiments/curve_analysis/visualization/base_drawer.py index 2534663efe..25af22c764 100644 --- a/qiskit_experiments/curve_analysis/visualization/base_drawer.py +++ b/qiskit_experiments/curve_analysis/visualization/base_drawer.py @@ -13,11 +13,17 @@ """Curve drawer abstract class.""" from abc import ABC, abstractmethod -from typing import Dict, Sequence, Optional +from typing import Dict, Optional, Sequence from qiskit_experiments.framework import Options +from qiskit_experiments.warnings import deprecated_class +@deprecated_class( + "0.6", + msg="Plotting and drawing of analysis figures has been moved to the new " + "`qiskit_experiments.visualization` module.", +) class BaseCurveDrawer(ABC): """Abstract class for the serializable Qiskit Experiments curve drawer. diff --git a/qiskit_experiments/curve_analysis/visualization/curves.py b/qiskit_experiments/curve_analysis/visualization/curves.py index 6c7ef16ec6..d03e8d1dcf 100644 --- a/qiskit_experiments/curve_analysis/visualization/curves.py +++ b/qiskit_experiments/curve_analysis/visualization/curves.py @@ -12,14 +12,21 @@ """ Plotting functions for experiment analysis """ -from typing import Callable, List, Tuple, Optional +from typing import Callable, List, Optional, Tuple + import numpy as np from uncertainties import unumpy as unp from qiskit_experiments.curve_analysis.curve_data import FitData from qiskit_experiments.framework.matplotlib import get_non_gui_ax +from qiskit_experiments.warnings import deprecated_function +@deprecated_function( + "0.6", + msg="Plotting and drawing functionality has been moved to the new " + "`qiskit_experiments.visualization` module.", +) def plot_curve_fit( func: Callable, result: FitData, @@ -94,6 +101,11 @@ def plot_curve_fit( return ax +@deprecated_function( + "0.6", + msg="Plotting and drawing functionality has been moved to the new " + "`qiskit_experiments.visualization` module.", +) def plot_scatter( xdata: np.ndarray, ydata: np.ndarray, @@ -138,6 +150,11 @@ def plot_scatter( return ax +@deprecated_function( + "0.6", + msg="Plotting and drawing functionality has been moved to the new " + "`qiskit_experiments.visualization` module.", +) def plot_errorbar( xdata: np.ndarray, ydata: np.ndarray, diff --git a/qiskit_experiments/curve_analysis/visualization/fit_result_plotters.py b/qiskit_experiments/curve_analysis/visualization/fit_result_plotters.py index 537c3ed759..b597368fd7 100644 --- a/qiskit_experiments/curve_analysis/visualization/fit_result_plotters.py +++ b/qiskit_experiments/curve_analysis/visualization/fit_result_plotters.py @@ -22,20 +22,28 @@ """ from collections import defaultdict -from typing import List, Dict, Optional +from enum import Enum +from typing import Dict, List, Optional -import uncertainties import numpy as np +import uncertainties from matplotlib.ticker import FuncFormatter from qiskit.utils import detach_prefix -from qiskit_experiments.curve_analysis.curve_data import SeriesDef, FitData, CurveData +from qiskit_experiments.curve_analysis.curve_data import CurveData, FitData, SeriesDef from qiskit_experiments.framework import AnalysisResultData from qiskit_experiments.framework.matplotlib import get_non_gui_ax -from .curves import plot_scatter, plot_errorbar, plot_curve_fit +from qiskit_experiments.warnings import deprecated_class, deprecated_function + +from .curves import plot_curve_fit, plot_errorbar, plot_scatter from .style import PlotterStyle +@deprecated_class( + "0.6", + msg="Plotting and drawing of analysis figures has been moved to the new " + "`qiskit_experiments.visualization` module.", +) class MplDrawSingleCanvas: """A plotter to draw a single canvas figure for fit result.""" @@ -136,6 +144,11 @@ def draw( return figure +@deprecated_class( + "0.6", + msg="Plotting and drawing of analysis figures has been replaced with the new" + "`qiskit_experiments.visualization` module.", +) class MplDrawMultiCanvasVstack: """A plotter to draw a vertically stacked multi canvas figure for fit result.""" @@ -288,6 +301,11 @@ def draw( return figure +@deprecated_function( + "0.6", + msg="Plotting and drawing of analysis figures has been replaced with the new" + "`qiskit_experiments.visualization` module.", +) def draw_single_curve_mpl( axis: "matplotlib.axes.Axes", series_def: SeriesDef, @@ -341,6 +359,11 @@ def draw_single_curve_mpl( ) +@deprecated_function( + "0.6", + msg="Plotting and drawing of analysis figures has been replaced with the new" + "`qiskit_experiments.visualization` module.", +) def write_fit_report(result_entries: List[AnalysisResultData]) -> str: """A function that generates fit reports documentation from list of data. @@ -401,3 +424,16 @@ def format_val(float_val: float) -> str: analysis_description += f"{res.name} = {value_repr}\n" return analysis_description + + +# pylint: disable=invalid-name +@deprecated_class( + "0.6", + msg="Plotting and drawing of analysis figures has been moved to the new " + "`qiskit_experiments.visualization` module.", +) +class FitResultPlotters(Enum): + """Map the plotter name to the plotters.""" + + mpl_single_canvas = MplDrawSingleCanvas + mpl_multiv_canvas = MplDrawMultiCanvasVstack diff --git a/qiskit_experiments/curve_analysis/visualization/mpl_drawer.py b/qiskit_experiments/curve_analysis/visualization/mpl_drawer.py index 5d9dfbe65c..439e562255 100644 --- a/qiskit_experiments/curve_analysis/visualization/mpl_drawer.py +++ b/qiskit_experiments/curve_analysis/visualization/mpl_drawer.py @@ -12,21 +12,27 @@ """Curve drawer for matplotlib backend.""" -from typing import Sequence, Optional, Tuple +from typing import Optional, Sequence, Tuple import numpy as np from matplotlib.axes import Axes -from matplotlib.figure import Figure -from matplotlib.ticker import ScalarFormatter, Formatter from matplotlib.cm import tab10 +from matplotlib.figure import Figure from matplotlib.markers import MarkerStyle - +from matplotlib.ticker import Formatter, ScalarFormatter from qiskit.utils import detach_prefix + from qiskit_experiments.framework.matplotlib import get_non_gui_ax +from qiskit_experiments.warnings import deprecated_class from .base_drawer import BaseCurveDrawer +@deprecated_class( + "0.6", + msg="Plotting and drawing of analysis figures has been replaced with the new" + "`qiskit_experiments.visualization` module.", +) class MplCurveDrawer(BaseCurveDrawer): """Curve drawer for MatplotLib backend.""" diff --git a/qiskit_experiments/curve_analysis/visualization/style.py b/qiskit_experiments/curve_analysis/visualization/style.py index c63e0766bb..2248ed6fc2 100644 --- a/qiskit_experiments/curve_analysis/visualization/style.py +++ b/qiskit_experiments/curve_analysis/visualization/style.py @@ -13,9 +13,16 @@ Configurable stylesheet. """ import dataclasses -from typing import Tuple, List +from typing import List, Tuple +from qiskit_experiments.warnings import deprecated_class + +@deprecated_class( + "0.6", + msg="Plotting and drawing of analysis figures has been replaced with the new" + "`qiskit_experiments.visualization` module.", +) @dataclasses.dataclass class PlotterStyle: """A stylesheet for curve analysis figure.""" diff --git a/qiskit_experiments/library/characterization/analysis/cr_hamiltonian_analysis.py b/qiskit_experiments/library/characterization/analysis/cr_hamiltonian_analysis.py index c6dd75c29e..62831f3a8e 100644 --- a/qiskit_experiments/library/characterization/analysis/cr_hamiltonian_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/cr_hamiltonian_analysis.py @@ -17,6 +17,7 @@ import qiskit_experiments.curve_analysis as curve from qiskit_experiments.framework import AnalysisResultData +from qiskit_experiments.visualization import PlotStyle class CrossResonanceHamiltonianAnalysis(curve.CompositeCurveAnalysis): @@ -60,8 +61,17 @@ def __init__(self): def _default_options(cls): """Return the default analysis options.""" default_options = super()._default_options() - default_options.curve_drawer.set_options( + default_options.plotter.set_options( subplots=(3, 1), + style=PlotStyle( + { + "figsize": (8, 10), + "legend_loc": "lower right", + "textbox_rel_pos": (0.28, -0.10), + } + ), + ) + default_options.plotter.set_figure_options( xlabel="Flat top width", ylabel=[ r"$\langle$X(t)$\rangle$", @@ -69,17 +79,44 @@ def _default_options(cls): r"$\langle$Z(t)$\rangle$", ], xval_unit="s", - figsize=(8, 10), - legend_loc="lower right", - fit_report_rpos=(0.28, -0.10), ylim=(-1, 1), - plot_options={ - "x_ctrl0": {"color": "blue", "symbol": "o", "canvas": 0}, - "y_ctrl0": {"color": "blue", "symbol": "o", "canvas": 1}, - "z_ctrl0": {"color": "blue", "symbol": "o", "canvas": 2}, - "x_ctrl1": {"color": "red", "symbol": "^", "canvas": 0}, - "y_ctrl1": {"color": "red", "symbol": "^", "canvas": 1}, - "z_ctrl1": {"color": "red", "symbol": "^", "canvas": 2}, + series_params={ + "x_ctrl0": { + "canvas": 0, + "color": "blue", + "label": "X (ctrl0)", + "symbol": "o", + }, + "y_ctrl0": { + "canvas": 1, + "color": "blue", + "label": "Y (ctrl0)", + "symbol": "o", + }, + "z_ctrl0": { + "canvas": 2, + "color": "blue", + "label": "Z (ctrl0)", + "symbol": "o", + }, + "x_ctrl1": { + "canvas": 0, + "color": "red", + "label": "X (ctrl1)", + "symbol": "^", + }, + "y_ctrl1": { + "canvas": 1, + "color": "red", + "label": "Y (ctrl1)", + "symbol": "^", + }, + "z_ctrl1": { + "canvas": 2, + "color": "red", + "label": "Z (ctrl1)", + "symbol": "^", + }, }, ) diff --git a/qiskit_experiments/library/characterization/analysis/drag_analysis.py b/qiskit_experiments/library/characterization/analysis/drag_analysis.py index 35ad86537b..1ca7c6a168 100644 --- a/qiskit_experiments/library/characterization/analysis/drag_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/drag_analysis.py @@ -84,7 +84,7 @@ def _default_options(cls): descriptions of analysis options. """ default_options = super()._default_options() - default_options.curve_drawer.set_options( + default_options.plotter.set_figure_options( xlabel="Beta", ylabel="Signal (arb. units)", ) diff --git a/qiskit_experiments/library/characterization/analysis/ramsey_xy_analysis.py b/qiskit_experiments/library/characterization/analysis/ramsey_xy_analysis.py index 5765b53f7c..22e093a30b 100644 --- a/qiskit_experiments/library/characterization/analysis/ramsey_xy_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/ramsey_xy_analysis.py @@ -88,7 +88,7 @@ def _default_options(cls): descriptions of analysis options. """ default_options = super()._default_options() - default_options.curve_drawer.set_options( + default_options.plotter.set_figure_options( xlabel="Delay", ylabel="Signal (arb. units)", xval_unit="s", diff --git a/qiskit_experiments/library/characterization/analysis/resonator_spectroscopy_analysis.py b/qiskit_experiments/library/characterization/analysis/resonator_spectroscopy_analysis.py index 0e306b988f..86760195d4 100644 --- a/qiskit_experiments/library/characterization/analysis/resonator_spectroscopy_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/resonator_spectroscopy_analysis.py @@ -49,7 +49,8 @@ def _run_analysis( if self.options.plot_iq_data: axis = get_non_gui_ax() figure = axis.get_figure() - figure.set_size_inches(*self.drawer.options.figsize) + # TODO: Move plotting to a new IQPlotter class. + figure.set_size_inches(*self.plotter.drawer.style["figsize"]) iqs = [] @@ -68,12 +69,12 @@ def _run_analysis( iqs = np.vstack(iqs) axis.scatter(iqs[:, 0], iqs[:, 1], color="b") axis.set_xlabel( - "In phase [arb. units]", fontsize=self.drawer.options.axis_label_size + "In phase [arb. units]", fontsize=self.plotter.drawer.style["axis_label_size"] ) axis.set_ylabel( - "Quadrature [arb. units]", fontsize=self.drawer.options.axis_label_size + "Quadrature [arb. units]", fontsize=self.plotter.drawer.style["axis_label_size"] ) - axis.tick_params(labelsize=self.drawer.options.tick_label_size) + axis.tick_params(labelsize=self.plotter.drawer.style["tick_label_size"]) axis.grid(True) figures.append(figure) diff --git a/qiskit_experiments/library/characterization/analysis/t1_analysis.py b/qiskit_experiments/library/characterization/analysis/t1_analysis.py index f88020acb7..91a69cc2a5 100644 --- a/qiskit_experiments/library/characterization/analysis/t1_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/t1_analysis.py @@ -34,7 +34,7 @@ class T1Analysis(curve.DecayAnalysis): def _default_options(cls) -> Options: """Default analysis options.""" options = super()._default_options() - options.curve_drawer.set_options( + options.plotter.set_figure_options( xlabel="Delay", ylabel="P(1)", xval_unit="s", @@ -85,7 +85,7 @@ class T1KerneledAnalysis(curve.DecayAnalysis): def _default_options(cls) -> Options: """Default analysis options.""" options = super()._default_options() - options.curve_drawer.set_options( + options.plotter.set_figure_options( xlabel="Delay", ylabel="Normalized Projection on the Main Axis", xval_unit="s", diff --git a/qiskit_experiments/library/characterization/analysis/t2hahn_analysis.py b/qiskit_experiments/library/characterization/analysis/t2hahn_analysis.py index 09e1839cb4..fa1f3c958e 100644 --- a/qiskit_experiments/library/characterization/analysis/t2hahn_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/t2hahn_analysis.py @@ -34,7 +34,7 @@ class T2HahnAnalysis(curve.DecayAnalysis): def _default_options(cls) -> Options: """Default analysis options.""" options = super()._default_options() - options.curve_drawer.set_options( + options.plotter.set_figure_options( xlabel="Delay", ylabel="P(0)", xval_unit="s", diff --git a/qiskit_experiments/library/characterization/analysis/t2ramsey_analysis.py b/qiskit_experiments/library/characterization/analysis/t2ramsey_analysis.py index 3073e1793c..d33d60a478 100644 --- a/qiskit_experiments/library/characterization/analysis/t2ramsey_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/t2ramsey_analysis.py @@ -29,7 +29,7 @@ class T2RamseyAnalysis(curve.DampedOscillationAnalysis): def _default_options(cls) -> Options: """Default analysis options.""" options = super()._default_options() - options.curve_drawer.set_options( + options.plotter.set_figure_options( xlabel="Delay", ylabel="P(1)", xval_unit="s", diff --git a/qiskit_experiments/library/characterization/rabi.py b/qiskit_experiments/library/characterization/rabi.py index 5a99bcd8ac..5d880fe755 100644 --- a/qiskit_experiments/library/characterization/rabi.py +++ b/qiskit_experiments/library/characterization/rabi.py @@ -110,7 +110,7 @@ def __init__( result_parameters=[ParameterRepr("freq", self.__outcome__)], normalization=True, ) - self.analysis.drawer.set_options( + self.analysis.plotter.set_figure_options( xlabel="Amplitude", ylabel="Signal (arb. units)", ) diff --git a/qiskit_experiments/library/quantum_volume/qv_analysis.py b/qiskit_experiments/library/quantum_volume/qv_analysis.py index 94d3624db2..c250bf9765 100644 --- a/qiskit_experiments/library/quantum_volume/qv_analysis.py +++ b/qiskit_experiments/library/quantum_volume/qv_analysis.py @@ -20,7 +20,7 @@ import numpy as np import uncertainties from qiskit_experiments.exceptions import AnalysisError -from qiskit_experiments.curve_analysis import plot_scatter, plot_errorbar +from qiskit_experiments.curve_analysis.visualization import plot_scatter, plot_errorbar from qiskit_experiments.framework import ( BaseAnalysis, AnalysisResultData, diff --git a/qiskit_experiments/library/randomized_benchmarking/rb_analysis.py b/qiskit_experiments/library/randomized_benchmarking/rb_analysis.py index 0fb9c83764..2bc67c8bd5 100644 --- a/qiskit_experiments/library/randomized_benchmarking/rb_analysis.py +++ b/qiskit_experiments/library/randomized_benchmarking/rb_analysis.py @@ -96,7 +96,7 @@ def _default_options(cls): 2Q RB is corrected to exclude the depolarization of underlying 1Q channels. """ default_options = super()._default_options() - default_options.curve_drawer.set_options( + default_options.plotter.set_figure_options( xlabel="Clifford Length", ylabel="P(0)", ) diff --git a/qiskit_experiments/visualization/__init__.py b/qiskit_experiments/visualization/__init__.py new file mode 100644 index 0000000000..3f35e221f8 --- /dev/null +++ b/qiskit_experiments/visualization/__init__.py @@ -0,0 +1,70 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2021. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +r""" +========================================================= +Visualization (:mod:`qiskit_experiments.visualization`) +========================================================= + +.. currentmodule:: qiskit_experiments.visualization + +Visualization provides plotting functionality for creating figures from experiment and analysis results. +This includes plotter and drawer classes to plot data in :py:class:`CurveAnalysis` and its subclasses. +Plotters inherit from :class:`BasePlotter` and define a type of figure that may be generated from +experiment or analysis data. For example, the results from :class:`CurveAnalysis` --- or any other +experiment where results are plotted against a single parameter (i.e., :math:`x`) --- can be plotted +using the :class:`CurvePlotter` class, which plots X-Y-like values. + +These plotter classes act as a bridge (from the common bridge pattern in software development) between +analysis classes (or even users) and plotting backends such as Matplotlib. Drawers are the backends, with +a common interface defined in :class:`BaseDrawer`. Though Matplotlib is the only officially supported +plotting backend in Qiskit Experiments (i.e., through :class:`MplDrawer`), custom drawers can be +implemented by users to use alternative backends. As long as the backend is a subclass of +:class:`BaseDrawer`, and implements all the necessary functionality, all plotters should be able to +generate figures with the alternative backend. + +To collate style parameters together, plotters and drawers store instances of the :class:`PlotStyle` +class. These instances can be merged and updated, so that default styles can have their values +overwritten. + +Plotter Library +============== + +.. autosummary:: + :toctree: ../stubs/ + :template: autosummary/class.rst + + BasePlotter + CurvePlotter + +Drawer Library +============== + +.. autosummary:: + :toctree: ../stubs/ + :template: autosummary/class.rst + + BaseDrawer + MplDrawer + +Plotting Style +============== + +.. autosummary:: + :toctree: ../stubs/ + :template: autosummary/class.rst + + PlotStyle +""" + +from .drawers import BaseDrawer, LegacyCurveCompatDrawer, MplDrawer +from .plotters import BasePlotter, CurvePlotter +from .style import PlotStyle diff --git a/qiskit_experiments/visualization/drawers/__init__.py b/qiskit_experiments/visualization/drawers/__init__.py new file mode 100644 index 0000000000..1e67ce2688 --- /dev/null +++ b/qiskit_experiments/visualization/drawers/__init__.py @@ -0,0 +1,16 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +"""Drawers submodule, defining interfaces to figure backends.""" + +from .base_drawer import BaseDrawer +from .legacy_curve_compat_drawer import LegacyCurveCompatDrawer +from .mpl_drawer import MplDrawer diff --git a/qiskit_experiments/visualization/drawers/base_drawer.py b/qiskit_experiments/visualization/drawers/base_drawer.py new file mode 100644 index 0000000000..61f2917f9f --- /dev/null +++ b/qiskit_experiments/visualization/drawers/base_drawer.py @@ -0,0 +1,454 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2021, 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Drawer abstract class.""" + +from abc import ABC, abstractmethod +from typing import Dict, Optional, Sequence, Tuple + +from qiskit_experiments.framework import Options + +from ..style import PlotStyle + + +class BaseDrawer(ABC): + """Abstract class for the serializable Qiskit Experiments figure drawer. + + A drawer may be implemented by different drawer backends such as matplotlib or Plotly. Sub-classes + that wrap these backends by subclassing :class:`BaseDrawer` must implement the following abstract + methods. + + initialize_canvas + + This method should implement a protocol to initialize a drawer canvas with user input ``axis`` + object. Note that ``drawer`` supports visualization of experiment results in multiple canvases + tiled into N (row) x M (column) inset grids, which is specified in the option ``subplots``. By + default, this is N=1, M=1 and thus no inset grid will be initialized. The data points to draw + might be provided with a canvas number defined in :attr:`SeriesDef.canvas` which defaults to + ``None``, i.e. no-inset grids. + + This method should first check the drawer options (:attr:`options`) for the axis object and + initialize the axis only when it is not provided by the options. Once axis is initialized, this + is set to the instance member ``self._axis``. + + format_canvas + + This method formats the appearance of the canvas. Typically, it updates axis and tick labels. + Note that the axis SI unit may be specified in the drawer figure_options. In this case, axis + numbers should be auto-scaled with the unit prefix. + + Drawing Methods: + + scatter + + This method draws scatter points on the canvas, like a scatter-plot, with optional error-bars + in both the X and Y axes. + + line + + This method plots a line from provided X and Y values. + + filled_y_area + + This method plots a shaped region bounded by upper and lower Y-values. This method is + typically called with interpolated x and a pair of y values that represent the upper and + lower bound within certain confidence interval. If this is called multiple times, it may be + necessary to set the transparency so that overlapping regions can be distinguished. + + filled_x_area + + This method plots a shaped region bounded by upper and lower X-values, as a function of + Y-values. This method is a rotated analogue of :meth:`filled_y_area`. + + textbox + + This method draws a text-box on the canvas, which is a rectangular region containing some + text. + + Options and Figure Options + ========================== + + Drawers have both :attr:`options` and :attr:`figure_options` available to set parameters that define + how to draw and what is drawn, respectively. :class:`BasePlotter` is similar in that it also has + ``options`` and ``figure_options`. The former contains class-specific variables that define how an + instance behaves. The latter contains figure-specific variables that typically contain values that + are drawn on the canvas, such as text. For details on the difference between the two sets of options, + see the documentation for :class:`BasePlotter`. + + .. note:: + If a drawer instance is used with a plotter, then there is the potential for any figure-option + to be overwritten with their value from the plotter. This means that the drawer instance would + be modified indirectly when the :meth:`BasePlotter.figure` method is called. This must be kept + in mind when creating subclasses of :class:`BaseDrawer`. + + Legends + ======= + + Legends are generated based off of drawn graphics and their labels or names. These are managed by + individual drawer subclasses, and generated when the :meth:`format_canvas` method is called. Legend + entries are created when any drawing function is called with ``legend=True``. There are three + parameters in drawing functions that are relevant to legend generation: ``name``, ``label``, and + ``legend``. If a user would like the graphics drawn onto a canvas to be used as the graphical + component of a legend entry; they should set ``legend=True``. The legend entry label can be defined + in three locations: the ``label`` parameter of drawing functions, the ``"label"`` entry in + ``series_params``, and the ``name`` parameter of drawing functions. These three possible label + variables have a search hierarchy given by the order in the aforementioned list. If one of the label + variables is ``None``, the next is used. If all are ``None``, a legend entry is not generated for the + given series. + + The recommended way to customize the legend entries is as follows: + 1. Set the labels in the ``series_params`` option, keyed on the series names. + 2. Initialize the canvas. + 3. Call relevant drawing methods to create the figure. When calling the drawing method that + creates the graphic you would like to use in the legend, set ``legend=True``. For example, + ``drawer.scatter(...,legend=True)`` would use the scatter points as the legend graphics for + the given series. + 4. Format the canvas and call :meth:`figure` to get the figure. + """ + + def __init__(self): + """Create a BaseDrawer instance.""" + # Normal options. Which includes the drawer axis, subplots, and default style. + self._options = self._default_options() + # A set of changed options for serialization. + self._set_options = set() + + # Figure options which are typically updated by a plotter instance. Figure-options include the + # axis labels, figure title, and a custom style instance. + self._figure_options = self._default_figure_options() + # A set of changed figure-options for serialization. + self._set_figure_options = set() + + # The initialized axis/axes, set by `initialize_canvas`. + self._axis = None + + @property + def options(self) -> Options: + """Return the drawer options.""" + return self._options + + @property + def figure_options(self) -> Options: + """Return the figure options. + + These are typically updated by a plotter instance, and thus may change. It is recommended to set + figure options in a parent :class:`BasePlotter` instance that contains the :class:`BaseDrawer` + instance. + """ + return self._figure_options + + @classmethod + def _default_options(cls) -> Options: + """Return default drawer options. + + Drawer Options: + axis (Any): Arbitrary object that can be used as a canvas. + subplots (Tuple[int, int]): Number of rows and columns when the experimental + result is drawn in the multiple windows. + default_style (PlotStyle): The default style for drawer. + This must contain all required style parameters for :class:`drawer`, as is defined in + :meth:`PlotStyle.default_style()`. Subclasses can add extra required style parameters by + overriding :meth:`_default_style`. + """ + return Options( + axis=None, + subplots=(1, 1), + default_style=cls._default_style(), + ) + + @classmethod + def _default_style(cls) -> PlotStyle: + return PlotStyle.default_style() + + @classmethod + def _default_figure_options(cls) -> Options: + """Return default figure options. + + Figure Options: + xlabel (Union[str, List[str]]): X-axis label string of the output figure. + If there are multiple columns in the canvas, this could be a list of labels. + ylabel (Union[str, List[str]]): Y-axis label string of the output figure. + If there are multiple rows in the canvas, this could be a list of labels. + xlim (Tuple[float, float]): Min and max value of the horizontal axis. + If not provided, it is automatically scaled based on the input data points. + ylim (Tuple[float, float]): Min and max value of the vertical axis. + If not provided, it is automatically scaled based on the input data points. + xval_unit (str): SI unit of x values. No prefix is needed here. + For example, when the x values represent time, this option will be just "s" + rather than "ms". In the output figure, the prefix is automatically selected + based on the maximum value in this axis. If your x values are in [1e-3, 1e-4], + they are displayed as [1 ms, 10 ms]. This option is likely provided by the + analysis class rather than end-users. However, users can still override + if they need different unit notation. By default, this option is set to ``None``, + and no scaling is applied. If nothing is provided, the axis numbers will be + displayed in the scientific notation. + yval_unit (str): Unit of y values. See ``xval_unit`` for details. + figure_title (str): Title of the figure. Defaults to None, i.e. nothing is shown. + series_params (Dict[str, Dict[str, Any]]): A dictionary of parameters for each series. + This is keyed on the name for each series. Sub-dictionary is expected to have the + following three configurations, "canvas", "color", "symbol" and "label"; "canvas" is the + integer index of axis (when multi-canvas plot is set), "color" is the color of the drawn + graphics, "symbol" is the series marker style for scatter plots, and "label" is a user + provided series label that appears in the legend. + custom_style (PlotStyle): The style definition to use when drawing. This overwrites style + parameters in ``default_style`` in :attr:`options`. Defaults to an empty PlotStyle + instance (i.e., :code-block:`PlotStyle()`). + """ + return Options( + xlabel=None, + ylabel=None, + xlim=None, + ylim=None, + xval_unit=None, + yval_unit=None, + figure_title=None, + series_params={}, + custom_style=PlotStyle(), + ) + + def set_options(self, **fields): + """Set the drawer options. + + Args: + fields: The fields to update the options + + Raises: + AttributeError: if an unknown options is encountered. + """ + for field in fields: + if not hasattr(self._options, field): + raise AttributeError( + f"Options field {field} is not valid for {type(self).__name__}" + ) + + self._options.update_options(**fields) + self._set_options = self._set_options.union(fields) + + def set_figure_options(self, **fields): + """Set the figure options. + + Args: + fields: The fields to update the figure options + + Raises: + AttributeError: if an unknown figure-option is encountered. + """ + for field in fields: + if not hasattr(self._figure_options, field): + raise AttributeError( + f"Figure options field {field} is not valid for {type(self).__name__}" + ) + self._figure_options.update_options(**fields) + self._set_figure_options = self._set_figure_options.union(fields) + + @property + def style(self) -> PlotStyle: + """The combined plot style for this drawer. + + The returned style instance is a combination of :attr:`options.default_style` and + :attr:`figure_options.custom_style`. Style parameters set in ``custom_style`` override those set + in ``default_style``. If ``custom_style`` is not an instance of :class:`PlotStyle`, the returned + style is equivalent to ``default_style``. + + Returns: + PlotStyle: The plot style for this drawer. + """ + if isinstance(self.figure_options.custom_style, PlotStyle): + return PlotStyle.merge(self.options.default_style, self.figure_options.custom_style) + return self.options.default_style + + @abstractmethod + def initialize_canvas(self): + """Initialize the drawer canvas.""" + + @abstractmethod + def format_canvas(self): + """Final cleanup for the canvas appearance.""" + + def label_for(self, name: Optional[str], label: Optional[str]) -> Optional[str]: + """Get the legend label for the given series, with optional overrides. + + This method determines the legend label for a series, with optional overrides ``label`` and the + ``"label"`` entry in the ``series_params`` option (see :attr:`options`). ``label`` is returned if + it is not ``None``, as this is the override with the highest priority. If it is ``None``, then + the drawer will look for a ``"label"`` entry in ``series_params`, for the series identified by + ``name``. If this entry doesn't exist, or is ``None``, then ``name`` is used as the label. If all + these options are ``None``, then ``None`` is returned; signifying that a legend entry for the + provided series should not be generated. + + Args: + name: The name of the series. + label: Optional label override. + + Returns: + Optional[str]: The legend entry label, or ``None``. + """ + if label is not None: + return label + + if name: + return self.figure_options.series_params.get(name, {}).get("label", name) + return None + + @abstractmethod + def scatter( + self, + x_data: Sequence[float], + y_data: Sequence[float], + x_err: Optional[Sequence[float]] = None, + y_err: Optional[Sequence[float]] = None, + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Draw scatter points, with optional error-bars. + + Args: + x_data: X values. + y_data: Y values. + x_err: Optional error for X values. + y_err: Optional error for Y values. + name: Name of this series. + label: Optional legend label to override ``name`` and ``series_params``. + legend: Whether the drawn area must have a legend entry. Defaults to False. + The series label in the legend will be ``label`` if it is not None. If it is, then + ``series_params`` is searched for a "label" entry for the series identified by ``name``. + If this is also ``None``, then ``name`` is used as the fallback. If no ``name`` is + provided, then no legend entry is generated. + options: Valid options for the drawer backend API. + """ + + @abstractmethod + def line( + self, + x_data: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Draw a line. + + Args: + x_data: X values. + y_data: Y values. + name: Name of this series. + label: Optional legend label to override ``name`` and ``series_params``. + legend: Whether the drawn area must have a legend entry. Defaults to False. + The series label in the legend will be ``label`` if it is not None. If it is, then + ``series_params`` is searched for a "label" entry for the series identified by ``name``. + If this is also ``None``, then ``name`` is used as the fallback. If no ``name`` is + provided, then no legend entry is generated. + options: Valid options for the drawer backend API. + """ + + @abstractmethod + def filled_y_area( + self, + x_data: Sequence[float], + y_ub: Sequence[float], + y_lb: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Draw filled area as a function of x-values. + + Args: + x_data: X values. + y_ub: The upper boundary of Y values. + y_lb: The lower boundary of Y values. + name: Name of this series. + label: Optional legend label to override ``name`` and ``series_params``. + legend: Whether the drawn area must have a legend entry. Defaults to False. + The series label in the legend will be ``label`` if it is not None. If it is, then + ``series_params`` is searched for a "label" entry for the series identified by ``name``. + If this is also ``None``, then ``name`` is used as the fallback. If no ``name`` is + provided, then no legend entry is generated. + options: Valid options for the drawer backend API. + """ + + @abstractmethod + def filled_x_area( + self, + x_ub: Sequence[float], + x_lb: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Draw filled area as a function of y-values. + + Args: + x_ub: The upper boundary of X values. + x_lb: The lower boundary of X values. + y_data: Y values. + name: Name of this series. + label: Optional legend label to override ``name`` and ``series_params``. + legend: Whether the drawn area must have a legend entry. Defaults to False. + The series label in the legend will be ``label`` if it is not None. If it is, then + ``series_params`` is searched for a "label" entry for the series identified by ``name``. + If this is also ``None``, then ``name`` is used as the fallback. If no ``name`` is + provided, then no legend entry is generated. + options: Valid options for the drawer backend API. + """ + + @abstractmethod + def textbox( + self, + description: str, + rel_pos: Optional[Tuple[float, float]] = None, + **options, + ): + """Draw text box. + + Args: + description: A string to be drawn inside a report box. + rel_pos: Relative position of the text-box. If None, the default ``textbox_rel_pos`` from + the style is used. + options: Valid options for the drawer backend API. + """ + + @property + @abstractmethod + def figure(self): + """Return figure object handler to be saved in the database.""" + + def config(self) -> Dict: + """Return the config dictionary for this drawer.""" + options = dict((key, getattr(self._options, key)) for key in self._set_options) + figure_options = dict( + (key, getattr(self._figure_options, key)) for key in self._set_figure_options + ) + + return { + "cls": type(self), + "options": options, + "figure_options": figure_options, + } + + def __json_encode__(self): + return self.config() + + @classmethod + def __json_decode__(cls, value): + instance = cls() + if "options" in value: + instance.set_options(**value["options"]) + if "figure_options" in value: + instance.set_figure_options(**value["figure_options"]) + return instance diff --git a/qiskit_experiments/visualization/drawers/legacy_curve_compat_drawer.py b/qiskit_experiments/visualization/drawers/legacy_curve_compat_drawer.py new file mode 100644 index 0000000000..8081eac436 --- /dev/null +++ b/qiskit_experiments/visualization/drawers/legacy_curve_compat_drawer.py @@ -0,0 +1,187 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Compatibility wrapper for legacy BaseCurveDrawer.""" + +import warnings +from typing import Optional, Sequence, Tuple + +from qiskit_experiments.curve_analysis.visualization import BaseCurveDrawer +from qiskit_experiments.warnings import deprecated_class + +from .base_drawer import BaseDrawer + + +@deprecated_class( + "0.6", + msg="Legacy drawers from `.curve_analysis.visualization are deprecated. This compatibility wrapper " + "will be removed alongside the deprecated modules removal", +) +class LegacyCurveCompatDrawer(BaseDrawer): + """A compatibility wrapper for the legacy and deprecated :class:`BaseCurveDrawer`. + + :mod:`qiskit_experiments.curve_analysis.visualization` is deprecated and will be replaced with the + new :mod:`qiskit_experiments.visualization` module. Analysis classes instead use subclasses of + :class:`BasePlotter` to generate figures. This class wraps the legacy :class:`BaseCurveDrawer` class + so it can be used by analysis classes, such as :class:`CurveAnalysis`, until it is removed. + + .. note:: + As :class:`BaseCurveDrawer` doesn't support customizing legend entries, the ``legend`` and + ``label`` parameters in drawing methods (such as :meth:`scatter`) are unsupported and + do nothing. + """ + + def __init__(self, curve_drawer: BaseCurveDrawer): + """Create a LegacyCurveCompatDrawer instance. + + Args: + curve_drawer: A legacy BaseCurveDrawer to wrap in the compatibility drawer. + """ + super().__init__() + self._curve_drawer = curve_drawer + + def initialize_canvas(self): + self._curve_drawer.initialize_canvas() + + def format_canvas(self): + self._curve_drawer.format_canvas() + + # pylint: disable=unused-argument + def scatter( + self, + x_data: Sequence[float], + y_data: Sequence[float], + x_err: Optional[Sequence[float]] = None, + y_err: Optional[Sequence[float]] = None, + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Draws scatter points with optional Y errorbars. + + Args: + x_data: X values. + y_data: Y values. + x_err: Unsupported as :class:`BaseCurveDrawer` doesn't support X errorbars. Defaults to None. + y_err: Optional error for Y values. + name: Name of this series. + label: Unsupported as :class:`BaseCurveDrawer` doesn't support customizing legend entries. + legend: Unsupported as :class:`BaseCurveDrawer` doesn't support toggling legend entries. + options: Valid options for the drawer backend API. + """ + if x_err is not None: + warnings.warn(f"{self.__class__.__name__} doesn't support x_err.") + + if y_err is not None: + self._curve_drawer.draw_formatted_data(x_data, y_data, y_err, name, **options) + else: + self._curve_drawer.draw_raw_data(x_data, y_data, name, **options) + + # pylint: disable=unused-argument + def line( + self, + x_data: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Draw fit line. + + Args: + x_data: X values. + y_data: Fit Y values. + name: Name of this series. + label: Unsupported as :class:`BaseCurveDrawer` doesn't support customizing legend entries. + legend: Unsupported as :class:`BaseCurveDrawer` doesn't support toggling legend entries. + options: Valid options for the drawer backend API. + """ + self._curve_drawer.draw_fit_line(x_data, y_data, name, **options) + + # pylint: disable=unused-argument + def filled_y_area( + self, + x_data: Sequence[float], + y_ub: Sequence[float], + y_lb: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Draw filled area as a function of x-values. + + Args: + x_data: X values. + y_ub: The upper boundary of Y values. + y_lb: The lower boundary of Y values. + name: Name of this series. + label: Unsupported as :class:`BaseCurveDrawer` doesn't support customizing legend entries. + legend: Unsupported as :class:`BaseCurveDrawer` doesn't support toggling legend entries. + options: Valid options for the drawer backend API. + """ + + self._curve_drawer.draw_confidence_interval(x_data, y_ub, y_lb, name, **options) + + # pylint: disable=unused-argument + def filled_x_area( + self, + x_ub: Sequence[float], + x_lb: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Does nothing as this is functionality not supported by :class:`BaseCurveDrawer`.""" + warnings.warn(f"{self.__class__.__name__}.filled_x_area is not supported.") + + # pylint: disable=unused-argument + def textbox( + self, + description: str, + rel_pos: Optional[Tuple[float, float]] = None, + **options, + ): + """Draw textbox. + + Args: + description: A string to be drawn inside a text box. + rel_pos: Unsupported as :class:`BaseCurveDrawer` doesn't support modifying the location of + text in :meth:`textbox` or :meth:`BaseCurveDrawer.draw_fit_report`. + options: Valid options for the drawer backend API. + """ + + self._curve_drawer.draw_fit_report(description, **options) + + @property + def figure(self): + return self._curve_drawer.figure + + def set_options(self, **fields): + ## Handle option name changes + # BaseCurveDrawer used `plot_options` instead of `series_params` + if "series_params" in fields: + fields["plot_options"] = fields.pop("series_params") + # PlotStyle parameters are normal options in BaseCurveDrawer. + if "custom_style" in fields: + custom_style = fields.pop("custom_style") + for key, value in custom_style.items(): + fields[key] = value + + self._curve_drawer.set_options(**fields) + + def set_figure_options(self, **fields): + self.set_options(**fields) diff --git a/qiskit_experiments/visualization/drawers/mpl_drawer.py b/qiskit_experiments/visualization/drawers/mpl_drawer.py new file mode 100644 index 0000000000..4509e60f5e --- /dev/null +++ b/qiskit_experiments/visualization/drawers/mpl_drawer.py @@ -0,0 +1,472 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Curve drawer for matplotlib backend.""" + +from typing import Any, Dict, Optional, Sequence, Tuple + +import numpy as np +from matplotlib.axes import Axes +from matplotlib.cm import tab10 +from matplotlib.figure import Figure +from matplotlib.markers import MarkerStyle +from matplotlib.ticker import Formatter, ScalarFormatter +from qiskit.utils import detach_prefix + +from qiskit_experiments.framework.matplotlib import get_non_gui_ax + +from .base_drawer import BaseDrawer + + +class MplDrawer(BaseDrawer): + """Drawer for MatplotLib backend.""" + + DefaultMarkers = MarkerStyle.filled_markers + DefaultColors = tab10.colors + + class PrefixFormatter(Formatter): + """Matplotlib axis formatter to detach prefix. + + If a value is, e.g., x=1000.0 and the factor is 1000, then it will be shown + as 1.0 in the ticks and its unit will be shown with the prefactor 'k' + in the axis label. + """ + + def __init__(self, factor: float): + """Create a PrefixFormatter instance. + + Args: + factor: factor by which to scale tick values. + """ + self.factor = factor + + def __call__(self, x: Any, pos: int = None) -> str: + """Returns the formatted string for tick position ``pos`` and value ``x``. + + Args: + x: the tick value to format. + pos: the tick label position. + + Returns: + str: the formatted tick label. + """ + return self.fix_minus("{:.3g}".format(x * self.factor)) + + def __init__(self): + super().__init__() + # Used to track which series have already been plotted. Needed for _get_default_marker and + # _get_default_color. + self._series = list() + + def initialize_canvas(self): + # Create axis if empty + if not self.options.axis: + axis = get_non_gui_ax() + figure = axis.get_figure() + figure.set_size_inches(*self.style["figsize"]) + else: + axis = self.options.axis + + n_rows, n_cols = self.options.subplots + n_subplots = n_cols * n_rows + if n_subplots > 1: + # Add inset axis. User may provide a single axis object via the analysis option, + # while this analysis tries to draw its result in multiple canvases, + # especially when the analysis consists of multiple curves. + # Inset axis is experimental implementation of matplotlib 3.0 so maybe unstable API. + # This draws inset axes with shared x and y axis. + inset_ax_h = 1 / n_rows + inset_ax_w = 1 / n_cols + for i in range(n_rows): + for j in range(n_cols): + # x0, y0, width, height + bounds = [ + inset_ax_w * j, + 1 - inset_ax_h * (i + 1), + inset_ax_w, + inset_ax_h, + ] + sub_ax = axis.inset_axes(bounds, transform=axis.transAxes, zorder=1) + if j != 0: + # remove y axis except for most-left plot + sub_ax.set_yticklabels([]) + else: + # this axis locates at left, write y-label + if self.figure_options.ylabel: + label = self.figure_options.ylabel + if isinstance(label, list): + # Y label can be given as a list for each sub axis + label = label[i] + sub_ax.set_ylabel(label, fontsize=self.style["axis_label_size"]) + if i != n_rows - 1: + # remove x axis except for most-bottom plot + sub_ax.set_xticklabels([]) + else: + # this axis locates at bottom, write x-label + if self.figure_options.xlabel: + label = self.figure_options.xlabel + if isinstance(label, list): + # X label can be given as a list for each sub axis + label = label[j] + sub_ax.set_xlabel(label, fontsize=self.style["axis_label_size"]) + if j == 0 or i == n_rows - 1: + # Set label size for outer axes where labels are drawn + sub_ax.tick_params(labelsize=self.style["tick_label_size"]) + sub_ax.grid() + + # Remove original axis frames + axis.axis("off") + else: + axis.set_xlabel(self.figure_options.xlabel, fontsize=self.style["axis_label_size"]) + axis.set_ylabel(self.figure_options.ylabel, fontsize=self.style["axis_label_size"]) + axis.tick_params(labelsize=self.style["tick_label_size"]) + axis.grid() + + self._axis = axis + + def format_canvas(self): + if self._axis.child_axes: + # Multi canvas mode + all_axes = self._axis.child_axes + else: + all_axes = [self._axis] + + # Add data labels if there are multiple labels registered per sub_ax. + for sub_ax in all_axes: + _, labels = sub_ax.get_legend_handles_labels() + if len(labels) > 1: + sub_ax.legend(loc=self.style["legend_loc"]) + + # Format x and y axis + for ax_type in ("x", "y"): + # Get axis formatter from drawing options + if ax_type == "x": + lim = self.figure_options.xlim + unit = self.figure_options.xval_unit + else: + lim = self.figure_options.ylim + unit = self.figure_options.yval_unit + + # Compute data range from auto scale + if not lim: + v0 = np.nan + v1 = np.nan + for sub_ax in all_axes: + if ax_type == "x": + this_v0, this_v1 = sub_ax.get_xlim() + else: + this_v0, this_v1 = sub_ax.get_ylim() + v0 = np.nanmin([v0, this_v0]) + v1 = np.nanmax([v1, this_v1]) + lim = (v0, v1) + + # Format axis number notation + if unit: + # If value is specified, automatically scale axis magnitude + # and write prefix to axis label, i.e. 1e3 Hz -> 1 kHz + maxv = max(np.abs(lim[0]), np.abs(lim[1])) + try: + scaled_maxv, prefix = detach_prefix(maxv, decimal=3) + prefactor = scaled_maxv / maxv + except ValueError: + prefix = "" + prefactor = 1 + + formatter = MplDrawer.PrefixFormatter(prefactor) + units_str = f" [{prefix}{unit}]" + else: + # Use scientific notation with 3 digits, 1000 -> 1e3 + formatter = ScalarFormatter() + formatter.set_scientific(True) + formatter.set_powerlimits((-3, 3)) + + units_str = "" + + for sub_ax in all_axes: + if ax_type == "x": + ax = getattr(sub_ax, "xaxis") + tick_labels = sub_ax.get_xticklabels() + else: + ax = getattr(sub_ax, "yaxis") + tick_labels = sub_ax.get_yticklabels() + + if tick_labels: + # Set formatter only when tick labels exist + ax.set_major_formatter(formatter) + if units_str: + # Add units to label if both exist + label_txt_obj = ax.get_label() + label_str = label_txt_obj.get_text() + if label_str: + label_txt_obj.set_text(label_str + units_str) + + # Auto-scale all axes to the first sub axis + if ax_type == "x": + # get_shared_y_axes() is immutable from matplotlib>=3.6.0. Must use Axis.sharey() + # instead, but this can only be called once per axis. Here we call sharey on all axes in + # a chain, which should have the same effect. + if len(all_axes) > 1: + for ax1, ax2 in zip(all_axes[1:], all_axes[0:-1]): + ax1.sharex(ax2) + all_axes[0].set_xlim(lim) + else: + # get_shared_y_axes() is immutable from matplotlib>=3.6.0. Must use Axis.sharey() + # instead, but this can only be called once per axis. Here we call sharey on all axes in + # a chain, which should have the same effect. + if len(all_axes) > 1: + for ax1, ax2 in zip(all_axes[1:], all_axes[0:-1]): + ax1.sharey(ax2) + all_axes[0].set_ylim(lim) + # Add title + if self.figure_options.figure_title is not None: + self._axis.set_title( + label=self.figure_options.figure_title, + fontsize=self.style["axis_label_size"], + ) + + def _get_axis(self, index: Optional[int] = None) -> Axes: + """A helper method to get inset axis. + + Args: + index: Index of inset axis. If nothing is provided, it returns the entire axis. + + Returns: + Corresponding axis object. + + Raises: + IndexError: When axis index is specified but no inset axis is found. + """ + if index is not None: + try: + return self._axis.child_axes[index] + except IndexError as ex: + raise IndexError( + f"Canvas index {index} is out of range. " + f"Only {len(self._axis.child_axes)} subplots are initialized." + ) from ex + else: + return self._axis + + def _get_default_color(self, name: str) -> Tuple[float, ...]: + """A helper method to get default color for the series. + + Args: + name: Name of the series. + + Returns: + Default color available in matplotlib. + """ + if name not in self._series: + self._series.append(name) + + ind = self._series.index(name) % len(self.DefaultColors) + return self.DefaultColors[ind] + + def _get_default_marker(self, name: str) -> str: + """A helper method to get default marker for the scatter plot. + + Args: + name: Name of the series. + + Returns: + Default marker available in matplotlib. + """ + if name not in self._series: + self._series.append(name) + + ind = self._series.index(name) % len(self.DefaultMarkers) + return self.DefaultMarkers[ind] + + def _update_label_in_options( + self, + options: Dict[str, any], + name: Optional[str], + label: Optional[str] = None, + legend: bool = False, + ): + """Helper function to set the label entry in ``options`` based on given arguments. + + This method uses :meth:`label_for` to get the label for the series identified by ``name``. If + :meth:`label_for` returns ``None``, then ``_update_label_in_options`` doesn't add a `"label"` + entry into ``options``. I.e., a label entry is added to ``options`` only if it is not ``None``. + + Args: + options: The options dictionary being modified. + name: The name of the series being labelled. Used as a fall-back label if ``label`` is None + and no label exists in ``series_params`` for this series. + label: Optional legend label to override ``name`` and ``series_params``. + legend: Whether a label entry should be added to ``options``. USed as an easy toggle to + disable adding a label entry. Defaults to False. + """ + if legend: + _label = self.label_for(name, label) + if _label: + options["label"] = _label + + def scatter( + self, + x_data: Sequence[float], + y_data: Sequence[float], + x_err: Optional[Sequence[float]] = None, + y_err: Optional[Sequence[float]] = None, + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + + series_params = self.figure_options.series_params.get(name, {}) + marker = series_params.get("symbol", self._get_default_marker(name)) + color = series_params.get("color", self._get_default_color(name)) + axis = series_params.get("canvas", None) + + draw_options = { + "color": color, + "marker": marker, + "alpha": 0.8, + "zorder": 2, + } + self._update_label_in_options(draw_options, name, label, legend) + draw_options.update(**options) + + if x_err is None and y_err is None: + self._get_axis(axis).scatter(x_data, y_data, **draw_options) + else: + # Check for invalid error values. + if y_err is not None and not np.all(np.isfinite(y_err)): + y_err = None + if x_err is not None and not np.all(np.isfinite(x_err)): + x_err = None + + # `errorbar` has extra default draw_options to set, but we want to accept any overrides from + # `options`, and thus draw_options. + errorbar_options = { + "linestyle": "", + "markersize": 9, + } + errorbar_options.update(draw_options) + + self._get_axis(axis).errorbar( + x_data, y_data, yerr=y_err, xerr=x_err, **errorbar_options + ) + + def line( + self, + x_data: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + series_params = self.figure_options.series_params.get(name, {}) + axis = series_params.get("canvas", None) + color = series_params.get("color", self._get_default_color(name)) + + draw_ops = { + "color": color, + "linestyle": "-", + "linewidth": 2, + } + self._update_label_in_options(draw_ops, name, label, legend) + draw_ops.update(**options) + self._get_axis(axis).plot(x_data, y_data, **draw_ops) + + def filled_y_area( + self, + x_data: Sequence[float], + y_ub: Sequence[float], + y_lb: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + series_params = self.figure_options.series_params.get(name, {}) + axis = series_params.get("canvas", None) + color = series_params.get("color", self._get_default_color(name)) + + draw_ops = { + "alpha": 0.1, + "color": color, + } + self._update_label_in_options(draw_ops, name, label, legend) + draw_ops.update(**options) + self._get_axis(axis).fill_between(x_data, y1=y_lb, y2=y_ub, **draw_ops) + + def filled_x_area( + self, + x_ub: Sequence[float], + x_lb: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + series_params = self.figure_options.series_params.get(name, {}) + axis = series_params.get("canvas", None) + color = series_params.get("color", self._get_default_color(name)) + + draw_ops = { + "alpha": 0.1, + "color": color, + } + self._update_label_in_options(draw_ops, name, label, legend) + draw_ops.update(**options) + self._get_axis(axis).fill_betweenx(y_data, x1=x_lb, x2=x_ub, **draw_ops) + + def textbox( + self, + description: str, + rel_pos: Optional[Tuple[float, float]] = None, + **options, + ): + bbox_props = { + "boxstyle": "square, pad=0.3", + "fc": "white", + "ec": "black", + "lw": 1, + "alpha": 0.8, + } + bbox_props.update(**options) + + if rel_pos is None: + rel_pos = self.style["textbox_rel_pos"] + + text_box_handler = self._axis.text( + *rel_pos, + s=description, + ha="center", + va="top", + size=self.style["textbox_text_size"], + transform=self._axis.transAxes, + zorder=1000, # Very large zorder to draw over other graphics. + ) + text_box_handler.set_bbox(bbox_props) + + @property + def figure(self) -> Figure: + """Return figure object handler to be saved in the database. + + In the MatplotLib the ``Figure`` and ``Axes`` are different object. + User can pass a part of the figure (i.e. multi-axes) to the drawer option ``axis``. + For example, a user wants to combine two different experiment results in the + same figure, one can call ``pyplot.subplots`` with two rows and pass one of the + generated two axes to each experiment drawer. Once all the experiments complete, + the user will obtain the single figure collecting all experimental results. + + Note that this method returns the entire figure object, rather than a single axis. + Thus, the experiment data saved in the database might have a figure + collecting all child axes drawings. + """ + return self._axis.get_figure() diff --git a/qiskit_experiments/visualization/plotters/__init__.py b/qiskit_experiments/visualization/plotters/__init__.py new file mode 100644 index 0000000000..c8f2133501 --- /dev/null +++ b/qiskit_experiments/visualization/plotters/__init__.py @@ -0,0 +1,15 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +"""Plotters submodule, defining interfaces to draw figures.""" + +from .base_plotter import BasePlotter +from .curve_plotter import CurvePlotter diff --git a/qiskit_experiments/visualization/plotters/base_plotter.py b/qiskit_experiments/visualization/plotters/base_plotter.py new file mode 100644 index 0000000000..d7375995a9 --- /dev/null +++ b/qiskit_experiments/visualization/plotters/base_plotter.py @@ -0,0 +1,521 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +"""Base plotter abstract class""" + +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, Union + +from qiskit_experiments.framework import Options +from qiskit_experiments.visualization.drawers import BaseDrawer + +from ..style import PlotStyle + + +class BasePlotter(ABC): + """An abstract class for the serializable figure plotters. + + A plotter takes data from an experiment analysis class or experiment and plots a given figure using a + drawing backend. Sub-classes define the kind of figure created and the expected data. + + Data is split into series and supplementary data. Series data is grouped by series name (str). For + :class:`CurveAnalysis`, this is the model name for a curve fit. For series data associated with a + single series name and supplementary data, data-values are identified by a data-key (str). Different + data per series and figure must have a different data-key to avoid overwriting values. Experiment and + analysis results can be passed to the plotter so appropriate graphics can be drawn on the figure + canvas. Series data is added to the plotter using :meth:`set_series_data` whereas supplementary data + is added using :meth:`set_supplementary_data`. Series and supplementary data are retrieved using + :meth:`data_for` and :attr:`supplementary_data` respectively. + + Series data contains values to be plotted on a canvas, such that the data can be grouped into subsets + identified by their series name. Series names can be thought of as legend labels for the plotted + data, and as curve names for a curve-fit. Supplementary data is not associated with a series or curve + and is instead only associated with the figure. Examples include analysis reports or other text that + is drawn onto the figure canvas. + + Options and Figure Options + ========================== + + Plotters have both :attr:`options` and :attr:`figure_options` available to set parameters that define + how to plot and what is plotted. :class:`BaseDrawer` is similar in that it also has ``options`` and + ``figure_options`. The former contains class-specific variables that define how an instance behaves. + The latter contains figure-specific variables that typically contain values that are drawn on the + canvas, such as text. + + For example, :class:`BasePlotter` has an ``axis`` option that can be set to the canvas on which the + figure should be drawn. This changes how the plotter works in that it changes where the figure is + drawn. :class:`BasePlotter` has an ``xlabel`` figure-option that can be set to change the text drawn + next to the X-axis in the final figure. As the value of this option will be drawn on the figure, it + is a figure-option. + + As plotters need a drawer to generate a figure, and the drawer needs to know what to draw, + figure-options are passed to :attr:`drawer` when the :meth:`figure` method is called. Any + figure-options that are defined in both the plotters :attr:`figure_options` attribute and the drawers + ``figure_options`` attribute are copied to the drawer: i.e., :meth:`BaseDrawer.set_figure_options` is + called for each common figure-option, setting the value of the option to the value stored in the + plotter. + + .. note:: + If a figure-option called "foo" is not set in the drawers figure-options (:attr:`~BaseDrawer. + figure_options`), but is set in the plotters figure-options (:attr:`figure_options`), it will + not be copied over to the drawer when the :meth:`figure` method is called. This means that some + figure-options from the plotter may be unused by the drawer. :class:`BasePlotter` and its + subclasses filter these options before setting them in the drawer as subclasses of + :class:`BaseDrawer` may add additional figure-options. To make validation easier and the code + cleaner, the :meth:`figure` method conducts this check before setting figure-options in the + drawer. + + Example: + .. code-block:: python + plotter = MyPlotter(MyDrawer()) + + # MyDrawer contains the following figure_options with default values. + plotter.drawer.figure_options.xlabel + plotter.drawer.figure_options.ylabel + + # MyDrawer does NOT contain the following figure-option + # plotter.drawer.figure_options.unknown_variable # Raises an error as it does not exist in + # `plotter.drawer`. + + # If we set the following figure-options, they will be set in the drawer. + plotter.set_figure_options(xlabel="Frequency", ylabel="Fidelity") + + # During a call to `plotter.figure()`, the drawer figure-options are updated. + # The following values would be returned from the drawer. + plotter.drawer.figure_options.xlabel # returns "Frequency" + plotter.drawer.figure_options.ylabel # returns "Fidelity" + + # If we set the following option and figure-option will NOT be set in the drawer. + plotter.set_options(plot_fit=False) # Example plotter option + plotter.set_figure_options(unknown_variable=5e9) # Example figure-option + + # As `plot_fit` is not a figure-option, it is not set in the drawer. + plotter.drawer.options.plot_fit # Would raise an error if no default exists, or return a + # different value to `plotter.options.plot_fit`. + + # As `unknown_variable` is not set in the drawers figure-options, it is not set during a call + # to the `figure()` method. + # plotter.drawer.figure_options.unknown_variable # Raises an error as it does not exist + # in `plotter.drawer.figure_options`. + """ + + def __init__(self, drawer: BaseDrawer): + """Create a new plotter instance. + + Args: + drawer: The drawer to use when creating the figure. + """ + # Data to be plotted, such as scatter points, interpolated fits, and confidence intervals + self._series_data: Dict[str, Dict[str, Any]] = {} + # Data that isn't directly associated with a single series, such as text or fit reports. + self._supplementary_data: Dict[str, Any] = {} + + # Options for the plotter + self._options = self._default_options() + # Plotter options that have changed, for serialization. + self._set_options = set() + + # Figure options that are updated in the drawer when `plotter.figure()` is called + self._figure_options = self._default_figure_options() + # Figure options that have changed, for serialization. + self._set_figure_options = set() + + # The drawer backend to use for plotting. + self.drawer = drawer + + @property + def supplementary_data(self) -> Dict[str, Any]: + """Additional data for the figure being plotted, that isn't associated with a series. + + Supplementary data includes text, fit reports, or other data that is associated with the figure + but not an individual series. It is typically data additional to the direct results of an + experiment. + """ + return self._supplementary_data + + @property + def series_data(self) -> Dict[str, Dict[str, Any]]: + """Data for series being plotted. + + Series data includes data such as scatter points, interpolated fit values, and + standard-deviations. Series data is grouped by series-name and then by a data-key, both strings. + Though series data can be accessed through :meth:`series_data`, it is recommended to access them + with :meth:`data_for` and :meth:`data_exists_for` as they allow for easier access to nested + values and can handle multiple data-keys in one query. + + Returns: + dict: A dictionary containing series data. + """ + return self._series_data + + @property + def series(self) -> List[str]: + """Series names that have been added to this plotter.""" + return list(self._series_data.keys()) + + def data_keys_for(self, series_name: str) -> List[str]: + """Returns a list of data-keys for the given series. + + Args: + series_name: The series name for which to return the data-keys, i.e., the types of data for + each series. + + Returns: + list: The list of data-keys for data in the plotter associated with the given series. If the + series has not been added to the plotter, an empty list is returned. + """ + return list(self._series_data.get(series_name, [])) + + def data_for(self, series_name: str, data_keys: Union[str, List[str]]) -> Tuple[Optional[Any]]: + """Returns data associated with the given series. + + The returned tuple contains the data, associated with ``data_keys``, in the same orders as they + are provided. For example, + + .. code-example::python + plotter.set_series_data("seriesA", x=data.x, y=data.y, yerr=data.yerr) + + # The following calls are equivalent. + x, y, yerr = plotter.series_data_for("seriesA", ["x", "y", "yerr"]) + x, y, yerr = data.x, data.y, data.yerr + + :meth:`data_for` is intended to be used by sub-classes of :class:`BasePlotter` when plotting in + the :meth:`_plot_figure` method. + + Args: + series_name: The series name for the given series. + data_keys: List of data-keys for the data to be returned. If a single data-key is given as a + string, it is wrapped in a list. + + Returns: + tuple: A tuple of data associated with the given series, identified by ``data_keys``. If no + data has been set for a data-key, None is returned for the associated tuple entry. + """ + + # We may be given a single data-key, but we need a list for the rest of the function. + if not isinstance(data_keys, list): + data_keys = [data_keys] + + # The series doesn't exist in the plotter data, return None for each data-key in the output. + if series_name not in self._series_data: + return (None,) * len(data_keys) + + return tuple(self._series_data[series_name].get(key, None) for key in data_keys) + + def set_series_data(self, series_name: str, **data_kwargs): + """Sets data for the given series. + + Note that if data has already been assigned for the given series and data-key, it will be + overwritten with the new values. ``set_series_data`` will warn if the data-key is unexpected; + i.e., not within those returned by :meth:`expected_series_data_keys`. + + Args: + series_name: The name of the given series. + data_kwargs: The data to be added, where the keyword is the data-key. + """ + # Warn if the data-keys are not expected. + unknown_data_keys = [ + data_key for data_key in data_kwargs if data_key not in self.expected_series_data_keys() + ] + for unknown_data_key in unknown_data_keys: + warnings.warn( + f"{self.__class__.__name__} encountered an unknown data-key {unknown_data_key}. It may " + "not be used by the plotter class." + ) + + # Set data + if series_name not in self._series_data: + self._series_data[series_name] = {} + self._series_data[series_name].update(**data_kwargs) + + def clear_series_data(self, series_name: Optional[str] = None): + """Clear series data for this plotter. + + Args: + series_name: The series name identifying which data should be cleared. If None, all series + data is cleared. Defaults to None. + """ + if series_name is None: + self._series_data = {} + elif series_name in self._series_data: + self._series_data.pop(series_name) + + def set_supplementary_data(self, **data_kwargs): + """Sets supplementary data for the plotter. + + Supplementary data differs from series data in that it is not associate with a series name. Fit + reports are examples of supplementary data as they contain fit results from an analysis class, + such as the "goodness" of a curve-fit. + + Note that if data has already been assigned for the given data-key, it will be overwritten with + the new values. ``set_supplementary_data`` will warn if the data-key is unexpected; i.e., not + within those returned by :meth:`expected_supplementary_data_keys`. + + """ + + # Warn if any data-keys are not expected. + unknown_data_keys = [ + data_key + for data_key in data_kwargs + if data_key not in self.expected_supplementary_data_keys() + ] + for unknown_data_key in unknown_data_keys: + warnings.warn( + f"{self.__class__.__name__} encountered an unknown data-key {unknown_data_key}. It may " + "not be used by the plotter class." + ) + + self._supplementary_data.update(**data_kwargs) + + def clear_supplementary_data(self): + """Clears supplementary data.""" + self._supplementary_data = {} + + def data_exists_for(self, series_name: str, data_keys: Union[str, List[str]]) -> bool: + """Returns whether the given data-keys exist for the given series. + + Args: + series_name: The name of the given series. + data_keys: The data-keys to be checked. + + Returns: + bool: True if all data-keys have values assigned for the given series. False if at least one + does not have a value assigned. + """ + if not isinstance(data_keys, list): + data_keys = [data_keys] + + # Handle non-existent series name + if series_name not in self._series_data: + return False + + return all(key in self._series_data[series_name] for key in data_keys) + + @abstractmethod + def _plot_figure(self): + """Generates a figure using :attr:`drawer` and :meth:`data`. + + Sub-classes must override this function to plot data using the drawer. This function is called by + :meth:`figure` when :attr:`drawer` can be used to draw on the canvas. + """ + + def figure(self) -> Any: + """Generates and returns a figure for the already provided series and supplementary data. + + :meth:`figure` calls :meth:`_plot_figure`, which is overridden by sub-classes. Before and after + calling :meth:`_plot_figure`; :func:`_configure_drawer`, :func:`initialize_canvas` and + :func:`format_canvas` are called on the drawer respectively. + + Returns: + Any: A figure generated by :attr:`drawer`, of the same type as ``drawer.figure``. + """ + # Initialize drawer, to copy axis, subplots, style, and figure-options across. + self._configure_drawer() + + # Initialize canvas, which creates subplots, assigns axis labels, etc. + self.drawer.initialize_canvas() + + # Plot figure for given subclass. This is the core of BasePlotter subclasses. + self._plot_figure() + + # Final formatting of canvas, which sets axis limits etc. + self.drawer.format_canvas() + + # Return whatever figure is created by the drawer. + return self.drawer.figure + + @classmethod + @abstractmethod + def expected_series_data_keys(cls) -> List[str]: + """Returns the expected series data-keys supported by this plotter.""" + + @classmethod + @abstractmethod + def expected_supplementary_data_keys(cls) -> List[str]: + """Returns the expected supplementary data-keys supported by this plotter.""" + + @property + def options(self) -> Options: + """Options for the plotter. + + Options for a plotter modify how the class generates a figure. This includes an optional axis + object, being the drawer canvas. Make sure verify whether the option you want to set is in + :attr:`options` or :attr:`figure_options`. + """ + return self._options + + @property + def figure_options(self) -> Options: + """Figure options for the plotter and its drawer. + + Figure options differ from normal options (:attr:`options`) in that the plotter passes figure + options on to the drawer when creating a figure (when :meth:`figure` is called). This way + :attr:`drawer` can draw an appropriate figure. An example of a figure option is the x-axis label. + """ + return self._figure_options + + @classmethod + def _default_options(cls) -> Options: + """Return default plotter options. + + Options: + axis (Any): Arbitrary object that can be used as a drawing canvas. + subplots (Tuple[int, int]): Number of rows and columns when the experimental + result is drawn in the multiple windows. + style (PlotStyle): The style definition to use when plotting. + This overwrites figure-option `custom_style` set in :attr:`drawer`. The default is an + empty style object, and such the default :attr:`drawer` plotting style will be used. + """ + return Options( + axis=None, + subplots=(1, 1), + style=PlotStyle(), + ) + + @classmethod + def _default_figure_options(cls) -> Options: + """Return default figure options. + + Figure Options: + xlabel (Union[str, List[str]]): X-axis label string of the output figure. + If there are multiple columns in the canvas, this could be a list of labels. + ylabel (Union[str, List[str]]): Y-axis label string of the output figure. + If there are multiple rows in the canvas, this could be a list of labels. + xlim (Tuple[float, float]): Min and max value of the horizontal axis. + If not provided, it is automatically scaled based on the input data points. + ylim (Tuple[float, float]): Min and max value of the vertical axis. + If not provided, it is automatically scaled based on the input data points. + xval_unit (str): SI unit of x values. No prefix is needed here. + For example, when the x values represent time, this option will be just "s" rather than + "ms". In the output figure, the prefix is automatically selected based on the maximum + value in this axis. If your x values are in [1e-3, 1e-4], they are displayed as [1 ms, 10 + ms]. This option is likely provided by the analysis class rather than end-users. However, + users can still override if they need different unit notation. By default, this option is + set to ``None``, and no scaling is applied. If nothing is provided, the axis numbers will + be displayed in the scientific notation. + yval_unit (str): Unit of y values. See ``xval_unit`` for details. + figure_title (str): Title of the figure. Defaults to None, i.e. nothing is shown. + series_params (Dict[str, Dict[str, Any]]): A dictionary of plot parameters for each series. + This is keyed on the name for each series. Sub-dictionary is expected to have following + three configurations, "canvas", "color", and "symbol"; "canvas" is the integer index of + axis (when multi-canvas plot is set), "color" is the color of the curve, and "symbol" is + the marker style of the curve for scatter plots. + """ + return Options( + xlabel=None, + ylabel=None, + xlim=None, + ylim=None, + xval_unit=None, + yval_unit=None, + figure_title=None, + series_params={}, + ) + + def set_options(self, **fields): + """Set the plotter options. + + Args: + fields: The fields to update in options. + + Raises: + AttributeError: if an unknown option is encountered. + """ + for field in fields: + if not hasattr(self._options, field): + raise AttributeError( + f"Options field {field} is not valid for {type(self).__name__}" + ) + self._options.update_options(**fields) + self._set_options = self._set_options.union(fields) + + def set_figure_options(self, **fields): + """Set the figure options. + + Args: + fields: The fields to update in figure options. + """ + # Don't check if any option in fields already exists (like with `set_options`), as figure options + # are passed to `.drawer` which may have other figure-options. Any figure-option that isn't set + # in `.drawer.figure_options` won't be set anyway. Setting `.drawer.figure_options` only occurs + # in `.figure()`, so we can't compare to `.drawer.figure_options` now as `.drawer` may be changed + # between now and the call to `.figure()`. + self._figure_options.update_options(**fields) + self._set_figure_options = self._set_figure_options.union(fields) + + def _configure_drawer(self): + """Configures :attr:`drawer` before plotting. + + The following actions are taken: + 1. ``axis``, ``subplots``, and ``style`` are passed to :attr:`drawer`. + 2. ``figure_options`` in :attr:`drawer` are updated based on values set in the plotter + :attr:`figure_options` + + These steps are different as all figure-options could be passed to :attr:`drawer`, if the drawer + already has a figure-option with the same name. ``axis``, ``subplots``, and ``style`` are the + only plotter options (from :attr:`options`) passed to :attr:`drawer` in + :meth:`_configure_drawer`. This is done as these options make more sense as an option for a + plotter, given the interface of :class:`BasePlotter`. + """ + ## Axis, subplots, and style + if self.options.axis: + self.drawer.set_options(axis=self.options.axis) + if self.options.subplots: + self.drawer.set_options(subplots=self.options.subplots) + self.drawer.set_figure_options(custom_style=self.options.style) + + # Convert options to dictionaries for easy comparison of all options/fields. + _drawer_figure_options = self.drawer.figure_options.__dict__ + _plotter_figure_options = self.figure_options.__dict__ + + # If an option exists in drawer.figure_options AND in self.figure_options, set the drawers + # figure-option value to that from the plotter. + for opt_key in _drawer_figure_options: + if opt_key in _plotter_figure_options: + _drawer_figure_options[opt_key] = _plotter_figure_options[opt_key] + + # Use drawer.set_figure_options so figure-options are serialized. + self.drawer.set_figure_options(**_drawer_figure_options) + + def config(self) -> Dict: + """Return the config dictionary for this drawing.""" + options = dict((key, getattr(self._options, key)) for key in self._set_options) + figure_options = dict( + (key, getattr(self._figure_options, key)) for key in self._set_figure_options + ) + drawer = self.drawer.__json_encode__() + + return { + "cls": type(self), + "options": options, + "figure_options": figure_options, + "drawer": drawer, + } + + def __json_encode__(self): + return self.config() + + @classmethod + def __json_decode__(cls, value): + ## Process drawer as it's needed to create a plotter + drawer_values = value["drawer"] + # We expect a subclass of BaseDrawer + drawer_cls: BaseDrawer = drawer_values["cls"] + drawer = drawer_cls.__json_decode__(drawer_values) + + # Create plotter instance + instance = cls(drawer) + if "options" in value: + instance.set_options(**value["options"]) + if "figure_options" in value: + instance.set_figure_options(**value["figure_options"]) + return instance diff --git a/qiskit_experiments/visualization/plotters/curve_plotter.py b/qiskit_experiments/visualization/plotters/curve_plotter.py new file mode 100644 index 0000000000..97be6837b1 --- /dev/null +++ b/qiskit_experiments/visualization/plotters/curve_plotter.py @@ -0,0 +1,140 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +"""Plotter for curve-fits, specifically from :class:`CurveAnalysis`.""" +from typing import List + +from qiskit_experiments.framework import Options + +from .base_plotter import BasePlotter + + +class CurvePlotter(BasePlotter): + """A plotter class to plot results from :class:`CurveAnalysis`. + + :class:`CurvePlotter` plots results from curve-fits, which includes: + Raw results as a scatter plot. + Processed results with standard-deviations/confidence intervals. + Interpolated fit-results from the curve analysis. + Confidence interval for the fit-results. + A report on the performance of the fit. + """ + + @classmethod + def expected_series_data_keys(cls) -> List[str]: + """Returns the expected series data-keys supported by this plotter. + + Data Keys: + x: X-values for raw results. + y: Y-values for raw results. Goes with ``x``. + x_formatted: X-values for processed results. + y_formatted: Y-values for processed results. Goes with ``x_formatted``. + y_formatted_err: Error in ``y_formatted``, to be plotted as error-bars. + x_interp: Interpolated X-values for a curve-fit. + y_interp: Y-values corresponding to the fit for ``y_interp`` X-values. + y_interp_err: The standard-deviations of the fit for each X-value in ``y_interp``. + This data-key relates to the option ``plot_sigma``. + """ + return [ + "x", + "y", + "x_formatted", + "y_formatted", + "y_formatted_err", + "x_interp", + "y_interp", + "y_interp_err", + ] + + @classmethod + def expected_supplementary_data_keys(cls) -> List[str]: + """Returns the expected figures data-keys supported by this plotter. + + Data Keys: + report_text: A string containing any fit report information to be drawn in a box. + The style and position of the report is controlled by ``textbox_rel_pos`` and + ``textbox_text_size`` style parameters in :class:`PlotStyle`. + """ + return [ + "report_text", + ] + + @classmethod + def _default_options(cls) -> Options: + """Return curve-plotter specific default plotter options. + + Options: + plot_sigma (List[Tuple[float, float]]): A list of two number tuples + showing the configuration to write confidence intervals for the fit curve. + The first argument is the relative sigma (n_sigma), and the second argument is + the transparency of the interval plot in ``[0, 1]``. + Multiple n_sigma intervals can be drawn for the same curve. + + """ + options = super()._default_options() + options.plot_sigma = [(1.0, 0.7), (3.0, 0.3)] + return options + + def _plot_figure(self): + """Plots a curve-fit figure.""" + for ser in self.series: + # Scatter plot with error-bars + plotted_formatted_data = False + if self.data_exists_for(ser, ["x_formatted", "y_formatted", "y_formatted_err"]): + x, y, yerr = self.data_for(ser, ["x_formatted", "y_formatted", "y_formatted_err"]) + self.drawer.scatter(x, y, y_err=yerr, name=ser, zorder=2, legend=True) + plotted_formatted_data = True + + # Scatter plot + if self.data_exists_for(ser, ["x", "y"]): + x, y = self.data_for(ser, ["x", "y"]) + options = { + "zorder": 1, + } + # If we plotted formatted data, differentiate scatter points by setting normal X-Y + # markers to gray. + if plotted_formatted_data: + options["color"] = "gray" + # If we didn't plot formatted data, the X-Y markers should be used for the legend. We add + # it to ``options`` so it's easier to pass to ``scatter``. + if not plotted_formatted_data: + options["legend"] = True + self.drawer.scatter( + x, + y, + name=ser, + **options, + ) + + # Line plot for fit + if self.data_exists_for(ser, ["x_interp", "y_interp"]): + x, y = self.data_for(ser, ["x_interp", "y_interp"]) + self.drawer.line(x, y, name=ser, zorder=3) + + # Confidence interval plot + if self.data_exists_for(ser, ["x_interp", "y_interp", "y_interp_err"]): + x, y_interp, y_interp_err = self.data_for( + ser, ["x_interp", "y_interp", "y_interp_err"] + ) + for n_sigma, alpha in self.options.plot_sigma: + self.drawer.filled_y_area( + x, + y_interp + n_sigma * y_interp_err, + y_interp - n_sigma * y_interp_err, + name=ser, + alpha=alpha, + zorder=5, + ) + + # Fit report + if "report_text" in self.supplementary_data: + report_text = self.supplementary_data["report_text"] + self.drawer.textbox(report_text) diff --git a/qiskit_experiments/visualization/style.py b/qiskit_experiments/visualization/style.py new file mode 100644 index 0000000000..1a0a4bc077 --- /dev/null +++ b/qiskit_experiments/visualization/style.py @@ -0,0 +1,91 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Configurable stylesheet for :class:`BasePlotter` and :class:`BaseDrawer`. +""" + + +class PlotStyle(dict): + """A stylesheet for :class:`BasePlotter` and :class:`BaseDrawer`. + + This style class is used by :class:`BasePlotter` and :class:`BaseDrawer`. The default style for + Qiskit Experiments is defined in :meth:`default_style`. Style parameters are stored as dictionary + entries, grouped by graphics or figure component. For example, style parameters relating to textboxes + have the prefix ``textbox_``. For default style parameter names and their values, see the + :meth:`default_style` method. + + Example: + .. code-block:: python + # Create custom style + custom_style = PlotStyle( + { + "legend_loc": "upper right", + "textbox_rel_pos": (1, 1), + "textbox_text_size": 14, + } + ) + + # Create full style, using PEP448 to combine with default style. + full_style = PlotStyle.merge(PlotStyle.default_style(), custom_style) + + # Query style parameters + full_style["legend_loc"] # Returns "upper right" + full_style["axis_label_size"] # Returns the value provided in ``PlotStyle.default_style()`` + """ + + @classmethod + def default_style(cls) -> "PlotStyle": + """The default style across Qiskit Experiments. + + Style Parameters: + figsize (Tuple[int,int]): The size of the figure ``(width, height)``, in inches. + legend_loc (str): The location of the legend. + tick_label_size (int): The font size for tick labels. + axis_label_size (int): The font size for axis labels. + textbox_rel_pos (Tuple[float,float]): The relative position ``(horizontal, vertical)`` of + textboxes, as a percentage of the canvas dimensions. + textbox_text_size (int): The font size for textboxes. + + Returns: + PlotStyle: The default plot style used by Qiskit Experiments. + """ + style = { + # size of figure (width, height) + "figsize": (8, 5), # Tuple[int, int] + # legend location (vertical, horizontal) + "legend_loc": "center right", # str + # size of tick label + "tick_label_size": 14, # int + # size of axis label + "axis_label_size": 16, # int + # relative position of a textbox + "textbox_rel_pos": (0.6, 0.95), # Tuple[float, float] + # size of textbox text + "textbox_text_size": 14, # int + } + return cls(**style) + + @classmethod + def merge(cls, style1: "PlotStyle", style2: "PlotStyle") -> "PlotStyle": + """Merge ``style2`` into ``style1`` as a new PlotStyle instance. + + This method merges an additional style ``style2`` into a base instance ``style1``, returning the + merged style instance instead of modifying the inputs. + + Args: + style1: Base PlotStyle instance. + style2: Additional PlotStyle instance. + + Returns: + PlotStyle: merged style instance. + """ + return PlotStyle({**style1, **style2}) diff --git a/releasenotes/notes/add-new-visualization-module-9c6a84f2813459a7.yaml b/releasenotes/notes/add-new-visualization-module-9c6a84f2813459a7.yaml new file mode 100644 index 0000000000..2a0ded7126 --- /dev/null +++ b/releasenotes/notes/add-new-visualization-module-9c6a84f2813459a7.yaml @@ -0,0 +1,11 @@ +--- +features: + - | + Added new visualization module to plot figures and draw onto figure canvases. The new module contains + plotters and drawers, which integrate with CurveAnalysis but can be used independently of the + analysis classes. This module replaces the old and now deprecated + `qiskit_experiments.curve_analysis.visualization` submodule. +deprecations: + - | + Deprecated `qiskit_experiments.curve_analysis.visualization` submodule as it is replaced by the new + `qiskit_experiments.visualization` submodule. diff --git a/test/base.py b/test/base.py index d234bc4215..41915c2b8d 100644 --- a/test/base.py +++ b/test/base.py @@ -32,13 +32,30 @@ BaseExperiment, BaseAnalysis, ) -from qiskit_experiments.curve_analysis.visualization.base_drawer import BaseCurveDrawer +from qiskit_experiments.visualization import BaseDrawer from qiskit_experiments.curve_analysis.curve_data import CurveFitResult class QiskitExperimentsTestCase(QiskitTestCase): """Qiskit Experiments specific extra functionality for test cases.""" + @classmethod + def setUpClass(cls): + """Set-up test class.""" + super().setUpClass() + + # Some functionality may be deprecated in Qiskit Experiments. If the deprecation warnings aren't + # filtered, the tests will fail as ``QiskitTestCase`` sets all warnings to be treated as an error + # by default. + # pylint: disable=invalid-name + allow_deprecationwarning_message = [ + # TODO: Remove in 0.6, when submodule `.curve_analysis.visualization` is removed. + r".*Plotting and drawing functionality has been moved", + r".*Legacy drawers from `.curve_analysis.visualization are deprecated", + ] + for msg in allow_deprecationwarning_message: + warnings.filterwarnings("default", category=DeprecationWarning, message=msg) + def assertExperimentDone( self, experiment_data: ExperimentData, @@ -109,7 +126,7 @@ def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None) def json_equiv(cls, data1, data2) -> bool: """Check if two experiments are equivalent by comparing their configs""" # pylint: disable = too-many-return-statements - configurable_type = (BaseExperiment, BaseAnalysis, BaseCurveDrawer) + configurable_type = (BaseExperiment, BaseAnalysis, BaseDrawer) compare_repr = (DataAction, DataProcessor) list_type = (list, tuple, set) skipped = tuple() diff --git a/test/curve_analysis/test_baseclass.py b/test/curve_analysis/test_baseclass.py index 55d75a46fd..fd8033ff25 100644 --- a/test/curve_analysis/test_baseclass.py +++ b/test/curve_analysis/test_baseclass.py @@ -205,7 +205,7 @@ class InvalidClass: analysis.set_options(data_processor=InvalidClass()) with self.assertRaises(TypeError): - analysis.set_options(curve_drawer=InvalidClass()) + analysis.set_options(plotter=InvalidClass()) def test_end_to_end_single_function(self): """Integration test for single function.""" diff --git a/test/visualization/__init__.py b/test/visualization/__init__.py new file mode 100644 index 0000000000..441ae06e7f --- /dev/null +++ b/test/visualization/__init__.py @@ -0,0 +1,12 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +"""Test cases for visualization plotting.""" diff --git a/test/visualization/mock_drawer.py b/test/visualization/mock_drawer.py new file mode 100644 index 0000000000..3221dc67d2 --- /dev/null +++ b/test/visualization/mock_drawer.py @@ -0,0 +1,110 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Mock drawer for testing. +""" + +from typing import Optional, Sequence, Tuple + +from qiskit_experiments.visualization import BaseDrawer, PlotStyle + + +class MockDrawer(BaseDrawer): + """Mock drawer for visualization tests. + + Most methods of this class do nothing. + """ + + @property + def figure(self): + """Does nothing.""" + pass + + @classmethod + def _default_style(cls) -> PlotStyle: + """Default style. + + Style Param: + overwrite_param: A test style parameter to be overwritten by a test. + """ + style = super()._default_style() + style["overwrite_param"] = "overwrite_param" + return style + + def initialize_canvas(self): + """Does nothing.""" + pass + + def format_canvas(self): + """Does nothing.""" + pass + + def scatter( + self, + x_data: Sequence[float], + y_data: Sequence[float], + x_err: Optional[Sequence[float]] = None, + y_err: Optional[Sequence[float]] = None, + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Does nothing.""" + pass + + def line( + self, + x_data: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Does nothing.""" + pass + + def filled_y_area( + self, + x_data: Sequence[float], + y_ub: Sequence[float], + y_lb: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Does nothing.""" + pass + + def filled_x_area( + self, + x_ub: Sequence[float], + x_lb: Sequence[float], + y_data: Sequence[float], + name: Optional[str] = None, + label: Optional[str] = None, + legend: bool = False, + **options, + ): + """Does nothing.""" + pass + + def textbox( + self, + description: str, + rel_pos: Optional[Tuple[float, float]] = None, + **options, + ): + """Does nothing.""" + pass diff --git a/test/visualization/mock_plotter.py b/test/visualization/mock_plotter.py new file mode 100644 index 0000000000..c693c9ba44 --- /dev/null +++ b/test/visualization/mock_plotter.py @@ -0,0 +1,72 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Mock plotter for testing. +""" + +from typing import List + +from qiskit_experiments.visualization import BaseDrawer, BasePlotter + + +class MockPlotter(BasePlotter): + """Mock plotter for visualization tests. + + If :attr:`plotting_enabled` is true, :class:`MockPlotter` will plot formatted data. + :attr:`plotting_enabled` defaults to false as most test usage of the class uses :class:`MockDrawer`, + which doesn't generate a useful figure. + """ + + def __init__(self, drawer: BaseDrawer, plotting_enabled: bool = False): + """Construct a mock plotter instance for testing. + + Args: + drawer: The drawer to use for plotting + plotting_enabled: Whether to actually plot using :attr:`drawer` or not. Defaults to False. + """ + super().__init__(drawer) + self._plotting_enabled = plotting_enabled + + @property + def plotting_enabled(self): + """Whether :class:`MockPlotter` should plot data. + + Defaults to False during construction. + """ + return self._plotting_enabled + + def _plot_figure(self): + """Plots a figure if :attr:`plotting_enabled` is True. + + If :attr:`plotting_enabled` is True, :class:`MockPlotter` calls + :meth:`~BaseDrawer.scatter` for a series titled ``seriesA`` with ``x``, ``y``, and + ``z`` data-keys assigned to the x and y values and the y-error/standard deviation respectively. + If :attr:`drawer` generates a figure, then :meth:`figure` should return a scatterplot figure with + error-bars. + """ + if self.plotting_enabled: + self.drawer.scatter(*self.data_for("seriesA", ["x", "y", "z"]), "seriesA") + + @classmethod + def expected_series_data_keys(cls) -> List[str]: + """Dummy data-keys. + + Data Keys: + x: Dummy value. + y: Dummy value. + z: Dummy value. + """ + return ["x", "y", "z"] + + @classmethod + def expected_supplementary_data_keys(cls) -> List[str]: + return [] diff --git a/test/visualization/test_mpldrawer.py b/test/visualization/test_mpldrawer.py new file mode 100644 index 0000000000..2ecb377fb9 --- /dev/null +++ b/test/visualization/test_mpldrawer.py @@ -0,0 +1,44 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Test Matplotlib Drawer. +""" + +from test.base import QiskitExperimentsTestCase + +import matplotlib + +from qiskit_experiments.visualization import MplDrawer + + +class TestMplDrawer(QiskitExperimentsTestCase): + """Test MplDrawer.""" + + def test_end_to_end(self): + """Test that MplDrawer generates something.""" + drawer = MplDrawer() + + # Draw dummy data + drawer.initialize_canvas() + drawer.scatter([0, 1, 2], [0, 1, 2], name="seriesA") + drawer.scatter([0, 1, 2], [0, 1, 2], [0.1, 0.1, 0.1], None, name="seriesA") + drawer.line([3, 2, 1], [1, 2, 3], name="seriesB") + drawer.filled_x_area([0, 1, 2, 3], [1, 2, 1, 2], [-1, -2, -1, -2], name="seriesB") + drawer.filled_y_area([-1, 0, 1, 2], [-1, -2, -1, -2], [1, 2, 1, 2], name="seriesB") + drawer.textbox(r"Dummy report text with LaTex $\beta$") + + # Get result + fig = drawer.figure + + # Check that + self.assertTrue(fig is not None) + self.assertTrue(isinstance(fig, matplotlib.pyplot.Figure)) diff --git a/test/visualization/test_plotter.py b/test/visualization/test_plotter.py new file mode 100644 index 0000000000..e8cd29baf9 --- /dev/null +++ b/test/visualization/test_plotter.py @@ -0,0 +1,87 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Test integration of plotter. +""" + +from copy import copy +from test.base import QiskitExperimentsTestCase + +from .mock_drawer import MockDrawer +from .mock_plotter import MockPlotter + + +class TestPlotter(QiskitExperimentsTestCase): + """Test the generic plotter interface.""" + + def test_warn_unknown_data_key(self): + """Test that setting an unknown data-key raises a warning.""" + plotter = MockPlotter(MockDrawer()) + + # TODO: Add check for no-warnings. assertNoWarns only available from Python 3.10+ + + # An unknown data-key must raise a warning if it is used to set series data. + with self.assertWarns(UserWarning): + plotter.set_series_data("dummy_series", unknown_data_key=[0, 1, 2]) + + def test_series_data_end_to_end(self): + """Test end-to-end for series data setting and retrieving.""" + plotter = MockPlotter(MockDrawer()) + + series_data = { + "seriesA": { + "x": 0, + "y": "1", + "z": [2], + }, + "seriesB": { + "x": 1, + "y": 0.5, + }, + } + unexpected_data = ["a", True, 0] + expected_series_data = copy(series_data) + expected_series_data["seriesA"]["unexpected_data"] = unexpected_data + + for series, data in series_data.items(): + plotter.set_series_data(series, **data) + + with self.assertWarns(UserWarning): + plotter.set_series_data("seriesA", unexpected_data=unexpected_data) + + for series, data in expected_series_data.items(): + self.assertTrue(series in plotter.series) + self.assertTrue(plotter.data_exists_for(series, list(data.keys()))) + for data_key, value in data.items(): + # Must index [0] for `data_for` as it returns a tuple. + self.assertEqual(value, plotter.data_for(series, data_key)[0]) + + def test_supplementary_data_end_to_end(self): + """Test end-to-end for figure data setting and retrieval.""" + plotter = MockPlotter(MockDrawer()) + + expected_supplementary_data = { + "report_text": "Lorem ipsum", + "another_data_key": 3e9, + } + + plotter.set_supplementary_data(**expected_supplementary_data) + + # Check if figure data has been stored and can be retrieved + for key, expected_value in expected_supplementary_data.items(): + actual_value = plotter.supplementary_data[key] + self.assertEqual( + expected_value, + actual_value, + msg=f"Actual figure data value for {key} data-key is not as expected: {actual_value} " + f"(actual) vs {expected_value} (expected)", + ) diff --git a/test/visualization/test_plotter_drawer.py b/test/visualization/test_plotter_drawer.py new file mode 100644 index 0000000000..1d7dd8c617 --- /dev/null +++ b/test/visualization/test_plotter_drawer.py @@ -0,0 +1,159 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Test integration of plotters and drawers. +""" + +from copy import copy +from test.base import QiskitExperimentsTestCase + +from qiskit_experiments.framework import Options +from qiskit_experiments.visualization import BasePlotter, PlotStyle + +from .mock_drawer import MockDrawer +from .mock_plotter import MockPlotter + + +def dummy_plotter() -> BasePlotter: + """Return a MockPlotter with dummy option values. + + Returns: + BasePlotter: A dummy plotter. + """ + plotter = MockPlotter(MockDrawer()) + # Set dummy plot options to update + plotter.set_figure_options( + xlabel="xlabel", + ylabel="ylabel", + figure_title="figure_title", + non_drawer_options="should not be set", + ) + plotter.set_options( + style=PlotStyle(test_param="test_param", overwrite_param="new_overwrite_param_value") + ) + return plotter + + +class TestPlotterAndDrawerIntegration(QiskitExperimentsTestCase): + """Test plotter and drawer integration.""" + + def assertOptionsEqual( + self, + options1: Options, + options2: Options, + msg_prefix: str = "", + only_assert_for_intersection: bool = False, + ): + """Asserts that two options are the same by checking each individual option. + + This method is easier to read than a standard equality assertion as individual option names are + printed. + + Args: + options1: The first Options instance to check. + options2: The second Options instance to check. + msg_prefix: A prefix to add before assert messages. + only_assert_for_intersection: If True, will only check options that are in both Options + instances. Defaults to False. + """ + # Get combined field names + if only_assert_for_intersection: + fields = set(options1._fields.keys()).intersection(set(options2._fields.keys())) + else: + fields = set(options1._fields.keys()).union(set(options2._fields.keys())) + + # Check individual options. + for key in fields: + # Check if the option exists in both + self.assertTrue( + hasattr(options1, key), + msg=f"[{msg_prefix}] Expected field {key} in both, but only found in one: not in " + f"{options1}.", + ) + self.assertTrue( + hasattr(options2, key), + msg=f"[{msg_prefix}] Expected field {key} in both, but only found in one: not in " + f"{options2}.", + ) + self.assertEqual( + getattr(options1, key), + getattr(options2, key), + msg=f"[{msg_prefix}] Expected equal values for option '{key}': " + f"{getattr(options1, key),} vs {getattr(options2,key)}", + ) + + def test_figure_options(self): + """Test copying and passing of plot-options between plotter and drawer.""" + plotter = dummy_plotter() + + # Expected options + expected_figure_options = copy(plotter.drawer.figure_options) + expected_figure_options.xlabel = "xlabel" + expected_figure_options.ylabel = "ylabel" + expected_figure_options.figure_title = "figure_title" + + # Expected style + expected_custom_style = PlotStyle( + test_param="test_param", overwrite_param="new_overwrite_param_value" + ) + plotter.set_options(style=expected_custom_style) + expected_full_style = PlotStyle.merge( + plotter.drawer.options.default_style, expected_custom_style + ) + expected_figure_options.custom_style = expected_custom_style + + # Call plotter.figure() to force passing of figure_options to drawer + plotter.figure() + + ## Test values + # Check style as this is a more detailed plot-option than others. + self.assertEqual(expected_full_style, plotter.drawer.style) + + # Check individual plot-options, but only the intersection as those are the ones we expect to be + # updated. + self.assertOptionsEqual(expected_figure_options, plotter.drawer.figure_options, True) + + # Coarse equality check of figure_options + self.assertEqual( + expected_figure_options, + plotter.drawer.figure_options, + msg=rf"expected_figure_options = {expected_figure_options}\nactual_figure_options =" + rf"{plotter.drawer.figure_options}", + ) + + def test_serializable(self): + """Test that plotter is serializable.""" + original_plotter = dummy_plotter() + + def check_options(original, new): + """Verifies that ``new`` plotter has the same options as ``original`` plotter.""" + self.assertOptionsEqual(original.options, new.options, "options") + self.assertOptionsEqual(original.figure_options, new.figure_options, "figure_options") + self.assertOptionsEqual(original.drawer.options, new.drawer.options, "drawer.options") + self.assertOptionsEqual( + original.drawer.figure_options, new.drawer.figure_options, "drawer.figure_options" + ) + + ## Check that plotter, BEFORE PLOTTING, survives serialization correctly. + # HACK: A dedicated JSON encoder and decoder class would be better. + # __json___ are not typically called, instead json.dumps etc. is called + encoded = original_plotter.__json_encode__() + decoded_plotter = original_plotter.__class__.__json_decode__(encoded) + check_options(original_plotter, decoded_plotter) + + ## Check that plotter, AFTER PLOTTING, survives serialization correctly. + original_plotter.figure() + # HACK: A dedicated JSON encoder and decoder class would be better. + # __json___ are not typically called, instead json.dumps etc. is called + encoded = original_plotter.__json_encode__() + decoded_plotter = original_plotter.__class__.__json_decode__(encoded) + check_options(original_plotter, decoded_plotter) diff --git a/test/visualization/test_plotter_mpldrawer.py b/test/visualization/test_plotter_mpldrawer.py new file mode 100644 index 0000000000..5c76298233 --- /dev/null +++ b/test/visualization/test_plotter_mpldrawer.py @@ -0,0 +1,40 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Test integration of plotter with Matplotlib drawer. +""" + +from test.base import QiskitExperimentsTestCase + +import matplotlib + +from qiskit_experiments.visualization import MplDrawer + +from .mock_plotter import MockPlotter + + +class TestPlotterAndMplDrawer(QiskitExperimentsTestCase): + """Test generic plotter with Matplotlib drawer.""" + + def test_end_to_end(self): + """Test whether plotter with MplDrawer returns a figure.""" + plotter = MockPlotter(MplDrawer()) + plotter.set_series_data( + "seriesA", x=[0, 1, 2, 3, 4, 5], y=[0, 1, 0, 1, 0, 1], z=[0.1, 0.1, 0.3, 0.4, 0.0] + ) + fig = plotter.figure() + + # Expect something + self.assertTrue(fig is not None) + + # Expect a specific type + self.assertTrue(isinstance(fig, matplotlib.pyplot.Figure)) diff --git a/test/visualization/test_style.py b/test/visualization/test_style.py new file mode 100644 index 0000000000..a52bcfd82e --- /dev/null +++ b/test/visualization/test_style.py @@ -0,0 +1,154 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Test visualization plotter. +""" + +from copy import copy +from test.base import QiskitExperimentsTestCase +from typing import Tuple + +from qiskit_experiments.visualization import PlotStyle + + +class TestPlotStyle(QiskitExperimentsTestCase): + """Test PlotStyle""" + + @classmethod + def _dummy_styles(cls) -> Tuple[PlotStyle, PlotStyle, PlotStyle, PlotStyle]: + """Returns dummy input styles for PlotStyle tests. + + Returns: + PlotStyle: First input style. + PlotStyle: Second input style. + PlotStyle: Expected style combining second into first. + PlotStyle: Expected style combining first into second. + """ + custom_1 = PlotStyle(overwrite_field=0, unchanged_field_A="Test", none_field=[0, 1, 2, 3]) + custom_2 = PlotStyle(overwrite_field=6, unchanged_field_B=0.5, none_field=None) + expected_12 = PlotStyle( + overwrite_field=6, + unchanged_field_A="Test", + unchanged_field_B=0.5, + none_field=None, + ) + expected_21 = PlotStyle( + overwrite_field=0, + unchanged_field_A="Test", + unchanged_field_B=0.5, + none_field=[0, 1, 2, 3], + ) + return custom_1, custom_2, expected_12, expected_21 + + def test_default_only_contains_expected_fields(self): + """Test that only expected fields are set in the default style. + + This enforces two things: + 1. The expected style fields are not None. + 2. No extra fields are set. + + The second property being enforced is to make sure that this test fails if new default style + parameters are added to :meth:`PlotStyle.default_style` but not to this test. + """ + default = PlotStyle.default_style() + expected_not_none_fields = [ + "figsize", + "legend_loc", + "tick_label_size", + "axis_label_size", + "textbox_rel_pos", + "textbox_text_size", + ] + for field in expected_not_none_fields: + self.assertIsNotNone(default.get(field, None)) + # Check that default style keys are as expected, ignoring order. + self.assertCountEqual(expected_not_none_fields, list(default.keys())) + + def test_update(self): + """Test that styles can be updated.""" + custom_1, custom_2, expected_12, expected_21 = self._dummy_styles() + + # copy(...) is needed as .update() modifies the style instance + actual_12 = copy(custom_1) + actual_12.update(**custom_2) + actual_21 = copy(custom_2) + actual_21.update(**custom_1) + + self.assertDictEqual(actual_12, expected_12) + self.assertDictEqual(actual_21, expected_21) + + def test_merge_in_init(self): + """Test that styles can be merged.""" + custom_1, custom_2, expected_12, expected_21 = self._dummy_styles() + + self.assertDictEqual(PlotStyle.merge(custom_1, custom_2), expected_12) + self.assertDictEqual(PlotStyle.merge(custom_2, custom_1), expected_21) + + def test_field_access(self): + """Test that fields are accessed correctly""" + dummy_style = PlotStyle( + x="x", + # y isn't assigned and therefore doesn't exist in dummy_style + ) + + self.assertEqual(dummy_style["x"], "x") + + # This should throw as we haven't assigned y + with self.assertRaises(KeyError): + # Disable pointless-statement as accessing style fields can raise an exception. + # pylint: disable = pointless-statement + dummy_style["y"] + + def test_dict(self): + """Test that PlotStyle can be treated as a dictionary.""" + dummy_style = PlotStyle( + a="a", + b=0, + c=[1, 2, 3], + ) + expected_dict = { + "a": "a", + "b": 0, + "c": [1, 2, 3], + } + actual_dict = dict(dummy_style) + self.assertDictEqual(actual_dict, expected_dict, msg="PlotStyle dict is not as expected.") + + # Add a new variable + dummy_style["new_variable"] = 5e9 + expected_dict["new_variable"] = 5e9 + actual_dict = dict(dummy_style) + self.assertDictEqual( + actual_dict, + expected_dict, + msg="PlotStyle dict is not as expected, with post-init variables.", + ) + + def test_update_dict(self): + """Test that PlotStyle dictionary is correct when updated.""" + custom_1, custom_2, expected_12, expected_21 = self._dummy_styles() + + # copy(...) is needed as .update() modifies the style instance + actual_12 = copy(custom_1) + actual_12.update(custom_2) + actual_21 = copy(custom_2) + actual_21.update(custom_1) + + self.assertDictEqual(actual_12, expected_12) + self.assertDictEqual(actual_21, expected_21) + + def test_merge_dict(self): + """Test that PlotStyle dictionary is correct when merged.""" + custom_1, custom_2, expected_12, expected_21 = self._dummy_styles() + + self.assertDictEqual(PlotStyle.merge(custom_1, custom_2), expected_12) + self.assertDictEqual(PlotStyle.merge(custom_2, custom_1), expected_21)