Skip to content

Commit

Permalink
Add new plotting module for experiment and analysis figures (#902)
Browse files Browse the repository at this point in the history
* The visualization code has been moved to its own module. It now also implements a bridge interface 
where a plotter determines what data is plotted and how while a drawer implements a plotting backend
(such as matplotlib).

Co-authored-by: Daniel J. Egger <[email protected]>
Co-authored-by: Naoki Kanazawa <[email protected]>
  • Loading branch information
3 people authored Oct 6, 2022
1 parent 8c11b2b commit cd6f92a
Show file tree
Hide file tree
Showing 44 changed files with 3,006 additions and 154 deletions.
4 changes: 2 additions & 2 deletions qiskit_experiments/curve_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ class AnalysisB(CurveAnalysis):
compute custom quantities based on the raw fit parameters.
See :ref:`curve_analysis_results` for details.
Afterwards, the analysis draws several curves in the Matplotlib figure.
User can set custom drawer to the option ``curve_drawer``.
The drawer defaults to the :class:`MplCurveDrawer`.
Users can set a custom plotter in :class:`CurveAnalysis` classes, to customize
figures, by setting the :attr:`~CurveAnalysis.plotter` attribute.
Finally, it returns the list of created analysis results and Matplotlib figure.
Expand Down
66 changes: 56 additions & 10 deletions qiskit_experiments/curve_analysis/base_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,28 @@

import warnings
from abc import ABC, abstractmethod
from typing import List, Dict, Union
from typing import Dict, List, Union

import lmfit

from qiskit_experiments.data_processing import DataProcessor
from qiskit_experiments.data_processing.processor_library import get_processor
from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, Options, ExperimentData
from .curve_data import CurveData, ParameterRepr, CurveFitResult
from .visualization import MplCurveDrawer, BaseCurveDrawer
from qiskit_experiments.framework import (
AnalysisResultData,
BaseAnalysis,
ExperimentData,
Options,
)
from qiskit_experiments.visualization import (
BaseDrawer,
BasePlotter,
CurvePlotter,
LegacyCurveCompatDrawer,
MplDrawer,
)
from qiskit_experiments.warnings import deprecated_function

from .curve_data import CurveData, CurveFitResult, ParameterRepr

PARAMS_ENTRY_PREFIX = "@Parameters_"
DATA_ENTRY_PREFIX = "@Data_"
Expand Down Expand Up @@ -113,16 +126,28 @@ def models(self) -> List[lmfit.Model]:
"""Return fit models."""

@property
def drawer(self) -> BaseCurveDrawer:
"""A short-cut for curve drawer instance."""
return self._options.curve_drawer
def plotter(self) -> BasePlotter:
"""A short-cut to the curve plotter instance."""
return self._options.plotter

@property
@deprecated_function(
last_version="0.6",
msg="Replaced by `plotter` from the new visualization submodule.",
)
def drawer(self) -> BaseDrawer:
"""A short-cut for curve drawer instance, if set. ``None`` otherwise."""
if isinstance(self.plotter.drawer, LegacyCurveCompatDrawer):
return self.plotter.drawer._curve_drawer
else:
return None

@classmethod
def _default_options(cls) -> Options:
"""Return default analysis options.
Analysis Options:
curve_drawer (BaseCurveDrawer): A curve drawer instance to visualize
plotter (BasePlotter): A curve plotter instance to visualize
the analysis result.
plot_raw_data (bool): Set ``True`` to draw processed data points,
dataset without formatting, on canvas. This is ``False`` by default.
Expand Down Expand Up @@ -168,7 +193,7 @@ def _default_options(cls) -> Options:
"""
options = super()._default_options()

options.curve_drawer = MplCurveDrawer()
options.plotter = CurvePlotter(MplDrawer())
options.plot_raw_data = False
options.plot = True
options.return_fit_parameters = True
Expand All @@ -187,7 +212,7 @@ def _default_options(cls) -> Options:

# Set automatic validator for particular option values
options.set_validator(field="data_processor", validator_value=DataProcessor)
options.set_validator(field="curve_drawer", validator_value=BaseCurveDrawer)
options.set_validator(field="plotter", validator_value=BasePlotter)

return options

Expand All @@ -211,6 +236,27 @@ def set_options(self, **fields):
)
fields["lmfit_options"] = fields.pop("curve_fitter_options")

# TODO remove this in Qiskit Experiments 0.6
if "curve_drawer" in fields:
warnings.warn(
"The option 'curve_drawer' is replaced with 'plotter'. "
"This option will be removed in Qiskit Experiments 0.6.",
DeprecationWarning,
stacklevel=2,
)
# Set the plotter drawer to `curve_drawer`. If `curve_drawer` is the right type, set it
# directly. If not, wrap it in a compatibility drawer.
if isinstance(fields["curve_drawer"], BaseDrawer):
plotter = self.options.plotter
plotter.drawer = fields.pop("curve_drawer")
fields["plotter"] = plotter
else:
drawer = fields["curve_drawer"]
compat_drawer = LegacyCurveCompatDrawer(drawer)
plotter = self.options.plotter
plotter.drawer = compat_drawer
fields["plotter"] = plotter

super().set_options(**fields)

@abstractmethod
Expand Down
133 changes: 85 additions & 48 deletions qiskit_experiments/curve_analysis/composite_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,31 @@
"""
# pylint: disable=invalid-name
import warnings
from typing import Dict, List, Tuple, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import lmfit
import numpy as np
from uncertainties import unumpy as unp, UFloat

from qiskit_experiments.framework import BaseAnalysis, ExperimentData, AnalysisResultData, Options
from .base_curve_analysis import BaseCurveAnalysis, PARAMS_ENTRY_PREFIX
from uncertainties import UFloat
from uncertainties import unumpy as unp

from qiskit_experiments.framework import (
AnalysisResultData,
BaseAnalysis,
ExperimentData,
Options,
)
from qiskit_experiments.visualization import (
BaseDrawer,
BasePlotter,
CurvePlotter,
LegacyCurveCompatDrawer,
MplDrawer,
)
from qiskit_experiments.warnings import deprecated_function

from .base_curve_analysis import PARAMS_ENTRY_PREFIX, BaseCurveAnalysis
from .curve_data import CurveFitResult
from .utils import analysis_result_to_repr, eval_with_uncertainties
from .visualization import MplCurveDrawer, BaseCurveDrawer


class CompositeCurveAnalysis(BaseAnalysis):
Expand Down Expand Up @@ -124,9 +138,21 @@ def models(self) -> Dict[str, List[lmfit.Model]]:
return models

@property
def drawer(self) -> BaseCurveDrawer:
"""A short-cut for curve drawer instance."""
return self._options.curve_drawer
def plotter(self) -> BasePlotter:
"""A short-cut to the plotter instance."""
return self._options.plotter

@property
@deprecated_function(
last_version="0.6",
msg="Replaced by `plotter` from the new visualization submodule.",
)
def drawer(self) -> BaseDrawer:
"""A short-cut for curve drawer instance, if set. ``None`` otherwise."""
if hasattr(self._options, "curve_drawer"):
return self._options.curve_drawer
else:
return None

def analyses(
self, index: Optional[Union[str, int]] = None
Expand Down Expand Up @@ -187,7 +213,7 @@ def _default_options(cls) -> Options:
"""Default analysis options.
Analysis Options:
curve_drawer (BaseCurveDrawer): A curve drawer instance to visualize
plotter (BasePlotter): A plotter instance to visualize
the analysis result.
plot (bool): Set ``True`` to create figure for fit result.
This is ``True`` by default.
Expand All @@ -200,19 +226,40 @@ def _default_options(cls) -> Options:
"""
options = super()._default_options()
options.update_options(
curve_drawer=MplCurveDrawer(),
plotter=CurvePlotter(MplDrawer()),
plot=True,
return_fit_parameters=True,
return_data_points=False,
extra={},
)

# Set automatic validator for particular option values
options.set_validator(field="curve_drawer", validator_value=BaseCurveDrawer)
options.set_validator(field="plotter", validator_value=BasePlotter)

return options

def set_options(self, **fields):
# TODO remove this in Qiskit Experiments 0.6
if "curve_drawer" in fields:
warnings.warn(
"The option 'curve_drawer' is replaced with 'plotter'. "
"This option will be removed in Qiskit Experiments 0.6.",
DeprecationWarning,
stacklevel=2,
)
# Set the plotter drawer to `curve_drawer`. If `curve_drawer` is the right type, set it
# directly. If not, wrap it in a compatibility drawer.
if isinstance(fields["curve_drawer"], BaseDrawer):
plotter = self.options.plotter
plotter.drawer = fields.pop("curve_drawer")
fields["plotter"] = plotter
else:
drawer = fields["curve_drawer"]
compat_drawer = LegacyCurveCompatDrawer(drawer)
plotter = self.options.plotter
plotter.drawer = compat_drawer
fields["plotter"] = plotter

for field in fields:
if not hasattr(self.options, field):
warnings.warn(
Expand All @@ -232,10 +279,6 @@ def _run_analysis(

analysis_results = []

# Initialize canvas
if self.options.plot:
self.drawer.initialize_canvas()

fit_dataset = {}
for analysis in self._analyses:
analysis._initialize(experiment_data)
Expand All @@ -251,22 +294,22 @@ def _run_analysis(
if self.options.plot and analysis.options.plot_raw_data:
for model in analysis.models:
sub_data = processed_data.get_subset_of(model._name)
self.drawer.draw_raw_data(
x_data=sub_data.x,
y_data=sub_data.y,
name=model._name + f"_{analysis.name}",
self.plotter.set_series_data(
model._name + f"_{analysis.name}",
x=sub_data.x,
y=sub_data.y,
)

# Format data
formatted_data = analysis._format_data(processed_data)
if self.options.plot:
for model in analysis.models:
sub_data = formatted_data.get_subset_of(model._name)
self.drawer.draw_formatted_data(
x_data=sub_data.x,
y_data=sub_data.y,
y_err_data=sub_data.y_err,
name=model._name + f"_{analysis.name}",
self.plotter.set_series_data(
model._name + f"_{analysis.name}",
x_formatted=sub_data.x,
y_formatted=sub_data.y,
y_formatted_err=sub_data.y_err,
)

# Run fitting
Expand Down Expand Up @@ -299,34 +342,30 @@ def _run_analysis(

# Draw fit result
if self.options.plot:
interp_x = np.linspace(
x_interp = np.linspace(
np.min(formatted_data.x), np.max(formatted_data.x), num=100
)
for model in analysis.models:
y_data_with_uncertainty = eval_with_uncertainties(
x=interp_x,
x=x_interp,
model=model,
params=fit_data.ufloat_params,
)
y_mean = unp.nominal_values(y_data_with_uncertainty)
# Draw fit line
self.drawer.draw_fit_line(
x_data=interp_x,
y_data=y_mean,
name=model._name + f"_{analysis.name}",
y_interp = unp.nominal_values(y_data_with_uncertainty)
# Add fit line data
self.plotter.set_series_data(
model._name + f"_{analysis.name}",
x_interp=x_interp,
y_interp=y_interp,
)
if fit_data.covar is not None:
# Draw confidence intervals with different n_sigma
sigmas = unp.std_devs(y_data_with_uncertainty)
if np.isfinite(sigmas).all():
for n_sigma, alpha in self.drawer.options.plot_sigma:
self.drawer.draw_confidence_interval(
x_data=interp_x,
y_ub=y_mean + n_sigma * sigmas,
y_lb=y_mean - n_sigma * sigmas,
name=model._name + f"_{analysis.name}",
alpha=alpha,
)
# Add confidence interval data
y_interp_err = unp.std_devs(y_data_with_uncertainty)
if np.isfinite(y_interp_err).all():
self.plotter.set_series_data(
model._name + f"_{analysis.name}",
y_interp_err=y_interp_err,
)

# Add raw data points
if self.options.return_data_points:
Expand Down Expand Up @@ -355,10 +394,8 @@ def _run_analysis(
for group, fit_data in fit_dataset.items():
chisqs.append(r"reduced-$\chi^2$ = " + f"{fit_data.reduced_chisq: .4g} ({group})")
report += "\n".join(chisqs)
self.drawer.draw_fit_report(description=report)
self.plotter.set_supplementary_data(report_text=report)

# Finalize canvas
self.drawer.format_canvas()
return analysis_results, [self.drawer.figure]
return analysis_results, [self.plotter.figure()]

return analysis_results, []
Loading

0 comments on commit cd6f92a

Please sign in to comment.