Skip to content
This repository has been archived by the owner on Jul 13, 2022. It is now read-only.

Commit

Permalink
Curve analysis drawer refactoring (qiskit-community#738)
Browse files Browse the repository at this point in the history
This PR replaces existing curve drawer with JSON serializable new drawer instance. 

- New drawer has dedicated to options to simplify the curve analysis options.
- New drawer is stateful (axis object is internally kept after initialized) so that we can draw different data immediately.
- Drawer code can be easily tracked by IDE.

Co-authored-by: Daniel J. Egger <[email protected]>
  • Loading branch information
nkanazawa1989 and eggerdj authored Apr 4, 2022
1 parent 4d39ee5 commit 64e9859
Show file tree
Hide file tree
Showing 21 changed files with 919 additions and 126 deletions.
6 changes: 5 additions & 1 deletion qiskit_experiments/curve_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FitData
ParameterRepr
FitOptions
MplCurveDrawer
Standard Analysis
=================
Expand Down Expand Up @@ -119,7 +120,7 @@
process_curve_data,
process_multi_curve_data,
)
from .visualization import plot_curve_fit, plot_errorbar, plot_scatter, FitResultPlotters
from .visualization import MplCurveDrawer
from . import guess
from . import fit_function

Expand All @@ -132,3 +133,6 @@
GaussianAnalysis,
ErrorAmplificationAnalysis,
)

# deprecated
from .visualization import plot_curve_fit, plot_errorbar, plot_scatter, FitResultPlotters
241 changes: 186 additions & 55 deletions qiskit_experiments/curve_analysis/curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from uncertainties import unumpy as unp

from qiskit.providers import Backend
from qiskit.utils import detach_prefix
from qiskit_experiments.curve_analysis.curve_data import (
CurveData,
SeriesDef,
Expand All @@ -37,7 +38,7 @@
)
from qiskit_experiments.curve_analysis.curve_fit import multi_curve_fit
from qiskit_experiments.curve_analysis.data_processing import multi_mean_xy_data, data_sort
from qiskit_experiments.curve_analysis.visualization import FitResultPlotters, PlotterStyle
from qiskit_experiments.curve_analysis.visualization import MplCurveDrawer, BaseCurveDrawer
from qiskit_experiments.data_processing import DataProcessor
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.data_processing.processor_library import get_processor
Expand Down Expand Up @@ -290,11 +291,22 @@ def parameters(self) -> List[str]:
"""Return parameters of this curve analysis."""
return [s for s in self._fit_params() if s not in self.options.fixed_parameters]

@property
def drawer(self) -> BaseCurveDrawer:
"""A short-cut for curve drawer instance."""
return self._options.curve_plotter

@classmethod
def _default_options(cls) -> Options:
"""Return default analysis options.
Analysis Options:
curve_plotter (BaseCurveDrawer): A curve drawer instance to visualize
the analysis result.
plot_raw_data (bool): Set ``True`` to draw un-formatted data points on canvas.
This is ``True`` by default.
plot (bool): Set ``True`` to create figure for fit result.
This is ``False`` by default.
curve_fitter (Callable): A callback function to perform fitting with formatted data.
See :func:`~qiskit_experiments.analysis.multi_curve_fit` for example.
data_processor (Callable): A callback function to format experiment data.
Expand All @@ -306,23 +318,6 @@ def _default_options(cls) -> Options:
bounds (Dict[str, Tuple[float, float]]): Array-like or dictionary
of (min, max) tuple of fit parameter boundaries.
x_key (str): Circuit metadata key representing a scanned value.
plot (bool): Set ``True`` to create figure for fit result.
axis (AxesSubplot): Optional. A matplotlib axis object to draw.
xlabel (str): X label of fit result figure.
ylabel (str): Y label of fit result figure.
xlim (Tuple[float, float]): Min and max value of horizontal axis of the fit plot.
ylim (Tuple[float, float]): Min and max value of vertical axis of the fit plot.
xval_unit (str): SI unit of x values. No prefix is needed here.
For example, when the x values represent time, this option will be just "s"
rather than "ms". In the fit result plot, the prefix is automatically selected
based on the maximum value. If your x values are in [1e-3, 1e-4], they
are displayed as [1 ms, 10 ms]. This option is likely provided by the
analysis class rather than end-users. However, users can still override
if they need different unit notation. By default, this option is set to ``None``,
and no scaling is applied. X axis will be displayed in the scientific notation.
yval_unit (str): Unit of y values. Same as ``xval_unit``.
This value is not provided in most experiments, because y value is usually
population or expectation values.
result_parameters (List[Union[str, ParameterRepr]): Parameters reported in the
database as a dedicated entry. This is a list of parameter representation
which is either string or ParameterRepr object. If you provide more
Expand All @@ -331,13 +326,6 @@ def _default_options(cls) -> Options:
The parameter name should be defined in the series definition.
Representation should be printable in standard output, i.e. no latex syntax.
return_data_points (bool): Set ``True`` to return formatted XY data.
curve_plotter (str): A name of plotter function used to generate
the curve fit result figure. This refers to the mapper
:py:class:`~qiskit_experiments.curve_analysis.visualization.FitResultPlotters`
to retrieve the corresponding callback function.
style (PlotterStyle): An instance of
:py:class:`~qiskit_experiments.curve_analysis.visualization.style.PlotterStyle`
that contains a set of configurations to create a fit plot.
extra (Dict[str, Any]): A dictionary that is appended to all database entries
as extra information.
curve_fitter_options (Dict[str, Any]) Options that are passed to the
Expand All @@ -348,22 +336,15 @@ def _default_options(cls) -> Options:
"""
options = super()._default_options()

options.curve_plotter = MplCurveDrawer()
options.plot_raw_data = False
options.plot = True
options.curve_fitter = multi_curve_fit
options.data_processor = None
options.normalization = False
options.x_key = "xval"
options.plot = True
options.axis = None
options.xlabel = None
options.ylabel = None
options.xlim = None
options.ylim = None
options.xval_unit = None
options.yval_unit = None
options.result_parameters = None
options.return_data_points = False
options.curve_plotter = "mpl_single_canvas"
options.style = PlotterStyle()
options.extra = dict()
options.curve_fitter_options = dict()
options.p0 = {}
Expand All @@ -372,6 +353,58 @@ def _default_options(cls) -> Options:

return options

def set_options(self, **fields):
"""Set the analysis options for :meth:`run` method.
Args:
fields: The fields to update the options
Raises:
KeyError: When removed option ``curve_fitter`` is set.
TypeError: When invalid drawer instance is provided.
"""
# TODO remove this in Qiskit Experiments v0.4
if "curve_plotter" in fields and isinstance(fields["curve_plotter"], str):
plotter_str = fields["curve_plotter"]
warnings.warn(
f"The curve plotter '{plotter_str}' has been deprecated. "
"The option is replaced with 'MplCurveDrawer' instance. "
"If this is a loaded analysis, please save this instance again to update option value. "
"This warning will be removed with backport in Qiskit Experiments 0.4.",
DeprecationWarning,
stacklevel=2,
)
fields["curve_plotter"] = MplCurveDrawer()

if "curve_plotter" in fields and not isinstance(fields["curve_plotter"], BaseCurveDrawer):
plotter_obj = fields["curve_plotter"]
raise TypeError(
f"'{plotter_obj.__class__.__name__}' object is not valid curve drawer instance."
)

# pylint: disable=no-member
draw_options = set(self.drawer.options.__dict__.keys()) | {"style"}
deprecated = draw_options & fields.keys()
if any(deprecated):
warnings.warn(
f"Option(s) {deprecated} have been moved to draw_options and will be removed soon. "
"Use self.drawer.set_options instead. "
"If this is a loaded analysis, please save this instance again to update option value. "
"This warning will be removed with backport in Qiskit Experiments 0.4.",
DeprecationWarning,
stacklevel=2,
)
draw_options = dict()
for depopt in deprecated:
if depopt == "style":
for k, v in fields.pop("style").items():
draw_options[k] = v
else:
draw_options[depopt] = fields.pop(depopt)
self.drawer.set_options(**draw_options)

super().set_options(**fields)

def _generate_fit_guesses(self, user_opt: FitOptions) -> Union[FitOptions, List[FitOptions]]:
"""Create algorithmic guess with analysis options and curve data.
Expand Down Expand Up @@ -786,6 +819,7 @@ def _run_analysis(
for series_def in self.__series__:
dict_def = dataclasses.asdict(series_def)
dict_def["fit_func"] = functools.partial(series_def.fit_func, **assigned_params)
del dict_def["signature"]
assigned_series.append(SeriesDef(**dict_def))
self.__series__ = assigned_series

Expand Down Expand Up @@ -946,8 +980,8 @@ def _run_analysis(
name=DATA_ENTRY_PREFIX + self.__class__.__name__,
value=raw_data_dict,
extra={
"x-unit": self.options.xval_unit,
"y-unit": self.options.yval_unit,
"x-unit": self.drawer.options.xval_unit,
"y-unit": self.drawer.options.yval_unit,
},
)
analysis_results.append(raw_data_entry)
Expand All @@ -956,24 +990,73 @@ def _run_analysis(
# 6. Create figures
#
if self.options.plot:
fit_figure = FitResultPlotters[self.options.curve_plotter].value.draw(
series_defs=self.__series__,
raw_samples=[self._data(ser.name, "raw_data") for ser in self.__series__],
fit_samples=[self._data(ser.name, "fit_ready") for ser in self.__series__],
tick_labels={
"xval_unit": self.options.xval_unit,
"yval_unit": self.options.yval_unit,
"xlabel": self.options.xlabel,
"ylabel": self.options.ylabel,
"xlim": self.options.xlim,
"ylim": self.options.ylim,
},
fit_data=fit_result,
result_entries=analysis_results,
style=self.options.style,
axis=self.options.axis,
)
figures = [fit_figure]
# Initialize axis
self.drawer.initialize_canvas()
# Write raw data
if self.options.plot_raw_data:
for s in self.__series__:
raw_data = self._data(label="raw_data", series_name=s.name)
self.drawer.draw_raw_data(
x_data=raw_data.x,
y_data=raw_data.y,
ax_index=s.canvas,
)
# Write data points
for s in self.__series__:
curve_data = self._data(label="fit_ready", series_name=s.name)
self.drawer.draw_formatted_data(
x_data=curve_data.x,
y_data=curve_data.y,
y_err_data=curve_data.y_err,
name=s.name,
ax_index=s.canvas,
color=s.plot_color,
marker=s.plot_symbol,
)
# Write fit results if fitting succeeded
if fit_result:
for s in self.__series__:
interp_x = np.linspace(*fit_result.x_range, 100)

params = {}
for fitpar in s.signature:
if fitpar in self.options.fixed_parameters:
params[fitpar] = self.options.fixed_parameters[fitpar]
else:
params[fitpar] = fit_result.fitval(fitpar)

y_data_with_uncertainty = s.fit_func(interp_x, **params)
y_mean = unp.nominal_values(y_data_with_uncertainty)
y_std = unp.std_devs(y_data_with_uncertainty)
# Draw fit line
self.drawer.draw_fit_line(
x_data=interp_x,
y_data=y_mean,
ax_index=s.canvas,
color=s.plot_color,
)
# Draw confidence intervals with different n_sigma
sigmas = unp.std_devs(y_data_with_uncertainty)
if np.isfinite(sigmas).all():
for n_sigma, alpha in self.drawer.options.plot_sigma:
self.drawer.draw_confidence_interval(
x_data=interp_x,
y_ub=y_mean + n_sigma * y_std,
y_lb=y_mean - n_sigma * y_std,
ax_index=s.canvas,
alpha=alpha,
color=s.plot_color,
)

# Write fitting report
report_description = ""
for res in analysis_results:
if isinstance(res.value, (float, uncertainties.UFloat)):
report_description += f"{analysis_result_to_repr(res)}\n"
report_description += r"Fit $\chi^2$ = " + f"{fit_result.reduced_chisq: .4g}"
self.drawer.draw_fit_report(description=report_description)
self.drawer.format_canvas()
figures = [self.drawer.figure]
else:
figures = []

Expand Down Expand Up @@ -1034,3 +1117,51 @@ def is_error_not_significant(
return True

return False


def analysis_result_to_repr(result: AnalysisResultData) -> str:
"""A helper function to create string representation from analysis result data object.
Args:
result: Analysis result data.
Returns:
String representation of the data.
"""
if not isinstance(result.value, (float, uncertainties.UFloat)):
return AnalysisError(f"Result data {result.name} is not a valid fit parameter data type.")

unit = result.extra.get("unit", None)

def _format_val(value):
# Return value with unit with prefix, i.e. 1000 Hz -> 1 kHz.
if unit:
try:
val, val_prefix = detach_prefix(value, decimal=3)
except ValueError:
val = value
val_prefix = ""
return f"{val: .3g}", f" {val_prefix}{unit}"
if np.abs(value) < 1e-3 or np.abs(value) > 1e3:
return f"{value: .4e}", ""
return f"{value: .4g}", ""

if isinstance(result.value, float):
# Only nominal part
n_repr, n_unit = _format_val(result.value)
value_repr = n_repr + n_unit
else:
# Nominal part
n_repr, n_unit = _format_val(result.value.nominal_value)

# Standard error part
if result.value.std_dev is not None and np.isfinite(result.value.std_dev):
s_repr, s_unit = _format_val(result.value.std_dev)
if n_unit == s_unit:
value_repr = f" {n_repr} \u00B1 {s_repr}{n_unit}"
else:
value_repr = f" {n_repr + n_unit} \u00B1 {s_repr + s_unit}"
else:
value_repr = n_repr + n_unit

return f"{result.name} = {value_repr}"
15 changes: 15 additions & 0 deletions qiskit_experiments/curve_analysis/curve_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import dataclasses
import inspect
from typing import Any, Dict, Callable, Union, List, Tuple, Optional, Iterable

import numpy as np
Expand Down Expand Up @@ -47,6 +48,20 @@ class SeriesDef:
# Index of canvas if the result figure is multi-panel
canvas: Optional[int] = None

# Automatically extracted signature of the fit function
signature: List[str] = dataclasses.field(init=False)

def __post_init__(self):
"""Parse the fit function signature to extract the names of the variables.
Fit functions take arguments F(x, p0, p1, p2, ...) thus the first value should be excluded.
"""
signature = list(inspect.signature(self.fit_func).parameters.keys())
fitparams = signature[1:]

# Note that this dataclass is frozen
object.__setattr__(self, "signature", fitparams)


@dataclasses.dataclass(frozen=True)
class CurveData:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ def _default_options(cls):
considered as good. Defaults to :math:`\pi/2`.
"""
default_options = super()._default_options()
default_options.curve_plotter.set_options(
xlabel="Number of gates (n)",
ylabel="Population",
ylim=(0, 1.0),
)
default_options.result_parameters = ["d_theta"]
default_options.xlabel = "Number of gates (n)"
default_options.ylabel = "Population"
default_options.ylim = [0, 1.0]
default_options.max_good_angle_error = np.pi / 2

return default_options
Expand Down
Loading

0 comments on commit 64e9859

Please sign in to comment.