Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Curve analysis drawer refactoring #738

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6621de6
Introduce curve drawer mixin and remove the use of curve drawer callb…
nkanazawa1989 Mar 16, 2022
0e56448
copy from qiskit-experiments/#726
nkanazawa1989 Mar 17, 2022
34a3d97
fix subclass init
nkanazawa1989 Mar 22, 2022
434b894
move comment
nkanazawa1989 Mar 23, 2022
b9d91ed
remove analysis option from Rabi experiment
nkanazawa1989 Mar 23, 2022
7e85988
update failing code
nkanazawa1989 Mar 23, 2022
09a16bf
mixin -> serializable curve drawer instance
nkanazawa1989 Mar 28, 2022
8080a51
update reno
nkanazawa1989 Mar 28, 2022
bb44d59
fix unittest (add eq check for drawer)
nkanazawa1989 Mar 29, 2022
311d10b
define own axis formatter class to avoid pickling lambda/local scoped…
nkanazawa1989 Mar 29, 2022
680a0d2
update initialization of default drawer
nkanazawa1989 Mar 29, 2022
fc38419
minor update for CR Hamiltonian axis labels
nkanazawa1989 Mar 29, 2022
d56efae
Update qiskit_experiments/curve_analysis/visualization/base_drawer.py
nkanazawa1989 Mar 30, 2022
88f71ab
Update qiskit_experiments/curve_analysis/visualization/base_drawer.py
nkanazawa1989 Mar 30, 2022
c8e49cc
Update qiskit_experiments/curve_analysis/visualization/base_drawer.py
nkanazawa1989 Mar 30, 2022
4c388bf
Update qiskit_experiments/curve_analysis/visualization/base_drawer.py
nkanazawa1989 Mar 30, 2022
4f16997
Update qiskit_experiments/curve_analysis/visualization/mpl_drawer.py
nkanazawa1989 Mar 30, 2022
bcd8a6e
review comments
nkanazawa1989 Mar 30, 2022
5e5bf5e
remove dependency on fit model from drawer
nkanazawa1989 Apr 4, 2022
acb75f3
remove dependency on analysis result data object from drawer
nkanazawa1989 Apr 4, 2022
0784e74
Update qiskit_experiments/curve_analysis/curve_analysis.py
nkanazawa1989 Apr 4, 2022
e7c02e9
Update qiskit_experiments/curve_analysis/visualization/base_drawer.py
nkanazawa1989 Apr 4, 2022
2f2c77f
Merge branch 'main' into curve_analysis_refactor/drawer
nkanazawa1989 Apr 4, 2022
42e060d
black
nkanazawa1989 Apr 4, 2022
6418a25
update documentation
nkanazawa1989 Apr 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion qiskit_experiments/curve_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FitData
ParameterRepr
FitOptions
MplCurveDrawer

Standard Analysis
=================
Expand Down Expand Up @@ -119,7 +120,7 @@
process_curve_data,
process_multi_curve_data,
)
from .visualization import plot_curve_fit, plot_errorbar, plot_scatter, FitResultPlotters
from .visualization import MplCurveDrawer
from . import guess
from . import fit_function

Expand All @@ -132,3 +133,6 @@
GaussianAnalysis,
ErrorAmplificationAnalysis,
)

# deprecated
from .visualization import plot_curve_fit, plot_errorbar, plot_scatter, FitResultPlotters
241 changes: 186 additions & 55 deletions qiskit_experiments/curve_analysis/curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from uncertainties import unumpy as unp

from qiskit.providers import Backend
from qiskit.utils import detach_prefix
from qiskit_experiments.curve_analysis.curve_data import (
CurveData,
SeriesDef,
Expand All @@ -37,7 +38,7 @@
)
from qiskit_experiments.curve_analysis.curve_fit import multi_curve_fit
from qiskit_experiments.curve_analysis.data_processing import multi_mean_xy_data, data_sort
from qiskit_experiments.curve_analysis.visualization import FitResultPlotters, PlotterStyle
from qiskit_experiments.curve_analysis.visualization import MplCurveDrawer, BaseCurveDrawer
from qiskit_experiments.data_processing import DataProcessor
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.data_processing.processor_library import get_processor
Expand Down Expand Up @@ -290,11 +291,22 @@ def parameters(self) -> List[str]:
"""Return parameters of this curve analysis."""
return [s for s in self._fit_params() if s not in self.options.fixed_parameters]

@property
def drawer(self) -> BaseCurveDrawer:
"""A short-cut for curve drawer instance."""
return self._options.curve_plotter

@classmethod
def _default_options(cls) -> Options:
"""Return default analysis options.
Analysis Options:
curve_plotter (BaseCurveDrawer): A curve drawer instance to visualize
the analysis result.
plot_raw_data (bool): Set ``True`` to draw un-formatted data points on canvas.
This is ``True`` by default.
plot (bool): Set ``True`` to create figure for fit result.
This is ``False`` by default.
curve_fitter (Callable): A callback function to perform fitting with formatted data.
See :func:`~qiskit_experiments.analysis.multi_curve_fit` for example.
data_processor (Callable): A callback function to format experiment data.
Expand All @@ -306,23 +318,6 @@ def _default_options(cls) -> Options:
bounds (Dict[str, Tuple[float, float]]): Array-like or dictionary
of (min, max) tuple of fit parameter boundaries.
x_key (str): Circuit metadata key representing a scanned value.
plot (bool): Set ``True`` to create figure for fit result.
axis (AxesSubplot): Optional. A matplotlib axis object to draw.
xlabel (str): X label of fit result figure.
ylabel (str): Y label of fit result figure.
xlim (Tuple[float, float]): Min and max value of horizontal axis of the fit plot.
ylim (Tuple[float, float]): Min and max value of vertical axis of the fit plot.
xval_unit (str): SI unit of x values. No prefix is needed here.
For example, when the x values represent time, this option will be just "s"
rather than "ms". In the fit result plot, the prefix is automatically selected
based on the maximum value. If your x values are in [1e-3, 1e-4], they
are displayed as [1 ms, 10 ms]. This option is likely provided by the
analysis class rather than end-users. However, users can still override
if they need different unit notation. By default, this option is set to ``None``,
and no scaling is applied. X axis will be displayed in the scientific notation.
yval_unit (str): Unit of y values. Same as ``xval_unit``.
This value is not provided in most experiments, because y value is usually
population or expectation values.
result_parameters (List[Union[str, ParameterRepr]): Parameters reported in the
database as a dedicated entry. This is a list of parameter representation
which is either string or ParameterRepr object. If you provide more
Expand All @@ -331,13 +326,6 @@ def _default_options(cls) -> Options:
The parameter name should be defined in the series definition.
Representation should be printable in standard output, i.e. no latex syntax.
return_data_points (bool): Set ``True`` to return formatted XY data.
curve_plotter (str): A name of plotter function used to generate
the curve fit result figure. This refers to the mapper
:py:class:`~qiskit_experiments.curve_analysis.visualization.FitResultPlotters`
to retrieve the corresponding callback function.
style (PlotterStyle): An instance of
:py:class:`~qiskit_experiments.curve_analysis.visualization.style.PlotterStyle`
that contains a set of configurations to create a fit plot.
extra (Dict[str, Any]): A dictionary that is appended to all database entries
as extra information.
curve_fitter_options (Dict[str, Any]) Options that are passed to the
Expand All @@ -348,22 +336,15 @@ def _default_options(cls) -> Options:
"""
options = super()._default_options()

options.curve_plotter = MplCurveDrawer()
options.plot_raw_data = False
options.plot = True
options.curve_fitter = multi_curve_fit
options.data_processor = None
options.normalization = False
options.x_key = "xval"
options.plot = True
options.axis = None
options.xlabel = None
options.ylabel = None
options.xlim = None
options.ylim = None
options.xval_unit = None
options.yval_unit = None
options.result_parameters = None
options.return_data_points = False
options.curve_plotter = "mpl_single_canvas"
options.style = PlotterStyle()
options.extra = dict()
options.curve_fitter_options = dict()
options.p0 = {}
Expand All @@ -372,6 +353,58 @@ def _default_options(cls) -> Options:

return options

def set_options(self, **fields):
"""Set the analysis options for :meth:`run` method.
Args:
fields: The fields to update the options
Raises:
KeyError: When removed option ``curve_fitter`` is set.
TypeError: When invalid drawer instance is provided.
"""
# TODO remove this in Qiskit Experiments v0.4
if "curve_plotter" in fields and isinstance(fields["curve_plotter"], str):
plotter_str = fields["curve_plotter"]
warnings.warn(
f"The curve plotter '{plotter_str}' has been deprecated. "
"The option is replaced with 'MplCurveDrawer' instance. "
"If this is a loaded analysis, please save this instance again to update option value. "
"This warning will be removed with backport in Qiskit Experiments 0.4.",
DeprecationWarning,
stacklevel=2,
)
fields["curve_plotter"] = MplCurveDrawer()

if "curve_plotter" in fields and not isinstance(fields["curve_plotter"], BaseCurveDrawer):
plotter_obj = fields["curve_plotter"]
raise TypeError(
f"'{plotter_obj.__class__.__name__}' object is not valid curve drawer instance."
)

# pylint: disable=no-member
draw_options = set(self.drawer.options.__dict__.keys()) | {"style"}
deprecated = draw_options & fields.keys()
if any(deprecated):
warnings.warn(
f"Option(s) {deprecated} have been moved to draw_options and will be removed soon. "
"Use self.drawer.set_options instead. "
"If this is a loaded analysis, please save this instance again to update option value. "
"This warning will be removed with backport in Qiskit Experiments 0.4.",
DeprecationWarning,
stacklevel=2,
)
draw_options = dict()
for depopt in deprecated:
if depopt == "style":
for k, v in fields.pop("style").items():
draw_options[k] = v
else:
draw_options[depopt] = fields.pop(depopt)
self.drawer.set_options(**draw_options)

super().set_options(**fields)

def _generate_fit_guesses(self, user_opt: FitOptions) -> Union[FitOptions, List[FitOptions]]:
"""Create algorithmic guess with analysis options and curve data.
Expand Down Expand Up @@ -786,6 +819,7 @@ def _run_analysis(
for series_def in self.__series__:
dict_def = dataclasses.asdict(series_def)
dict_def["fit_func"] = functools.partial(series_def.fit_func, **assigned_params)
del dict_def["signature"]
assigned_series.append(SeriesDef(**dict_def))
self.__series__ = assigned_series

Expand Down Expand Up @@ -946,8 +980,8 @@ def _run_analysis(
name=DATA_ENTRY_PREFIX + self.__class__.__name__,
value=raw_data_dict,
extra={
"x-unit": self.options.xval_unit,
"y-unit": self.options.yval_unit,
"x-unit": self.drawer.options.xval_unit,
"y-unit": self.drawer.options.yval_unit,
},
)
analysis_results.append(raw_data_entry)
Expand All @@ -956,24 +990,73 @@ def _run_analysis(
# 6. Create figures
#
if self.options.plot:
fit_figure = FitResultPlotters[self.options.curve_plotter].value.draw(
series_defs=self.__series__,
raw_samples=[self._data(ser.name, "raw_data") for ser in self.__series__],
fit_samples=[self._data(ser.name, "fit_ready") for ser in self.__series__],
tick_labels={
"xval_unit": self.options.xval_unit,
"yval_unit": self.options.yval_unit,
"xlabel": self.options.xlabel,
"ylabel": self.options.ylabel,
"xlim": self.options.xlim,
"ylim": self.options.ylim,
},
fit_data=fit_result,
result_entries=analysis_results,
style=self.options.style,
axis=self.options.axis,
)
figures = [fit_figure]
# Initialize axis
self.drawer.initialize_canvas()
# Write raw data
if self.options.plot_raw_data:
for s in self.__series__:
raw_data = self._data(label="raw_data", series_name=s.name)
self.drawer.draw_raw_data(
x_data=raw_data.x,
y_data=raw_data.y,
ax_index=s.canvas,
)
# Write data points
for s in self.__series__:
curve_data = self._data(label="fit_ready", series_name=s.name)
self.drawer.draw_formatted_data(
x_data=curve_data.x,
y_data=curve_data.y,
y_err_data=curve_data.y_err,
name=s.name,
ax_index=s.canvas,
color=s.plot_color,
marker=s.plot_symbol,
)
# Write fit results if fitting succeeded
if fit_result:
for s in self.__series__:
interp_x = np.linspace(*fit_result.x_range, 100)

params = {}
for fitpar in s.signature:
if fitpar in self.options.fixed_parameters:
params[fitpar] = self.options.fixed_parameters[fitpar]
else:
params[fitpar] = fit_result.fitval(fitpar)

y_data_with_uncertainty = s.fit_func(interp_x, **params)
y_mean = unp.nominal_values(y_data_with_uncertainty)
y_std = unp.std_devs(y_data_with_uncertainty)
# Draw fit line
self.drawer.draw_fit_line(
x_data=interp_x,
y_data=y_mean,
ax_index=s.canvas,
color=s.plot_color,
)
# Draw confidence intervals with different n_sigma
sigmas = unp.std_devs(y_data_with_uncertainty)
if np.isfinite(sigmas).all():
for n_sigma, alpha in self.drawer.options.plot_sigma:
self.drawer.draw_confidence_interval(
x_data=interp_x,
y_ub=y_mean + n_sigma * y_std,
y_lb=y_mean - n_sigma * y_std,
ax_index=s.canvas,
alpha=alpha,
color=s.plot_color,
)

# Write fitting report
report_description = ""
for res in analysis_results:
if isinstance(res.value, (float, uncertainties.UFloat)):
report_description += f"{analysis_result_to_repr(res)}\n"
report_description += r"Fit $\chi^2$ = " + f"{fit_result.reduced_chisq: .4g}"
self.drawer.draw_fit_report(description=report_description)
self.drawer.format_canvas()
figures = [self.drawer.figure]
else:
figures = []

Expand Down Expand Up @@ -1034,3 +1117,51 @@ def is_error_not_significant(
return True

return False


def analysis_result_to_repr(result: AnalysisResultData) -> str:
"""A helper function to create string representation from analysis result data object.
Args:
result: Analysis result data.
Returns:
String representation of the data.
"""
if not isinstance(result.value, (float, uncertainties.UFloat)):
return AnalysisError(f"Result data {result.name} is not a valid fit parameter data type.")

unit = result.extra.get("unit", None)

def _format_val(value):
# Return value with unit with prefix, i.e. 1000 Hz -> 1 kHz.
if unit:
try:
val, val_prefix = detach_prefix(value, decimal=3)
except ValueError:
val = value
val_prefix = ""
return f"{val: .3g}", f" {val_prefix}{unit}"
if np.abs(value) < 1e-3 or np.abs(value) > 1e3:
return f"{value: .4e}", ""
return f"{value: .4g}", ""

if isinstance(result.value, float):
# Only nominal part
n_repr, n_unit = _format_val(result.value)
value_repr = n_repr + n_unit
else:
# Nominal part
n_repr, n_unit = _format_val(result.value.nominal_value)

# Standard error part
if result.value.std_dev is not None and np.isfinite(result.value.std_dev):
s_repr, s_unit = _format_val(result.value.std_dev)
if n_unit == s_unit:
value_repr = f" {n_repr} \u00B1 {s_repr}{n_unit}"
else:
value_repr = f" {n_repr + n_unit} \u00B1 {s_repr + s_unit}"
else:
value_repr = n_repr + n_unit

return f"{result.name} = {value_repr}"
15 changes: 15 additions & 0 deletions qiskit_experiments/curve_analysis/curve_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import dataclasses
import inspect
from typing import Any, Dict, Callable, Union, List, Tuple, Optional, Iterable

import numpy as np
Expand Down Expand Up @@ -47,6 +48,20 @@ class SeriesDef:
# Index of canvas if the result figure is multi-panel
canvas: Optional[int] = None

# Automatically extracted signature of the fit function
signature: List[str] = dataclasses.field(init=False)

def __post_init__(self):
"""Parse the fit function signature to extract the names of the variables.

Fit functions take arguments F(x, p0, p1, p2, ...) thus the first value should be excluded.
"""
signature = list(inspect.signature(self.fit_func).parameters.keys())
fitparams = signature[1:]

# Note that this dataclass is frozen
object.__setattr__(self, "signature", fitparams)


@dataclasses.dataclass(frozen=True)
class CurveData:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ def _default_options(cls):
considered as good. Defaults to :math:`\pi/2`.
"""
default_options = super()._default_options()
default_options.curve_plotter.set_options(
xlabel="Number of gates (n)",
ylabel="Population",
ylim=(0, 1.0),
)
default_options.result_parameters = ["d_theta"]
default_options.xlabel = "Number of gates (n)"
default_options.ylabel = "Population"
default_options.ylim = [0, 1.0]
default_options.max_good_angle_error = np.pi / 2
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved

return default_options
Expand Down
Loading