Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CurveAnalysis base class #765

Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3a4f605
wip migration to curve analysis baseclass
nkanazawa1989 Apr 6, 2022
f9373d4
Merge branch 'main' of github.com:Qiskit/qiskit-experiments into feat…
nkanazawa1989 Apr 21, 2022
220a421
integration into main branch
nkanazawa1989 Apr 21, 2022
9e7ad5b
review comment
nkanazawa1989 Apr 21, 2022
759eaea
developer document in module doc
nkanazawa1989 Apr 22, 2022
3672723
update subclasses
nkanazawa1989 Apr 22, 2022
eb22e39
update curvefit unittests
nkanazawa1989 Apr 22, 2022
62d1609
finalize
nkanazawa1989 Apr 22, 2022
79e8704
remove validation
nkanazawa1989 Apr 22, 2022
3a86a25
review comments
nkanazawa1989 Apr 25, 2022
3d7a5e4
review comments
nkanazawa1989 Apr 25, 2022
ef9e61b
readd method documentation and update reno
nkanazawa1989 Apr 25, 2022
d7e25a4
docs and option name
nkanazawa1989 Apr 25, 2022
72d82e5
update reno with link
nkanazawa1989 Apr 26, 2022
e714fc0
Update releasenotes/notes/cleanup-curve-analysis-96d7ff706cae5b4e.yaml
nkanazawa1989 Apr 26, 2022
6509b4a
Update releasenotes/notes/cleanup-curve-analysis-96d7ff706cae5b4e.yaml
nkanazawa1989 Apr 26, 2022
9a92d1e
Merge branch 'main' into feature/curve_analysis_baseclass
nkanazawa1989 Apr 26, 2022
5f20a7b
Merge branch 'feature/curve_analysis_baseclass' of github.com:nkanaza…
nkanazawa1989 Apr 26, 2022
67c942b
test and lint fix
nkanazawa1989 Apr 26, 2022
0c9cb7c
Merge branch 'main' into feature/curve_analysis_baseclass
nkanazawa1989 Apr 26, 2022
d9013ad
minor typo fix
nkanazawa1989 Apr 26, 2022
8c585b2
Merge branch 'feature/curve_analysis_baseclass' of github.com:nkanaza…
nkanazawa1989 Apr 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
498 changes: 461 additions & 37 deletions qiskit_experiments/curve_analysis/__init__.py

Large diffs are not rendered by default.

547 changes: 547 additions & 0 deletions qiskit_experiments/curve_analysis/base_curve_analysis.py

Large diffs are not rendered by default.

956 changes: 123 additions & 833 deletions qiskit_experiments/curve_analysis/curve_analysis.py

Large diffs are not rendered by default.

148 changes: 99 additions & 49 deletions qiskit_experiments/curve_analysis/curve_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,94 +25,138 @@

@dataclasses.dataclass(frozen=True)
class SeriesDef:
"""Description of curve."""
"""A dataclass to describe the definition of the curve.

Attributes:
fit_func: A callable that defines the fit model of this curve. The argument names
in the callable are parsed to create the fit parameter list, which will appear
in the analysis results. The first argument should be ``x`` that represents
X-values that the experiment sweeps.
filter_kwargs: Optional. Dictionary of properties that uniquely identifies this series.
This dictionary is used for data processing.
This must be provided when the curve analysis consists of multiple series.
name: Optional. Name of this series.
plot_color: Optional. String representation of the color that is used to draw fit data
and data points in the output figure. This depends on the drawer class
being set to the curve analysis options. Usually this conforms to the
Matplotlib color names.
plot_symbol: Optional. String representation of the marker shape that is used to draw
data points in the output figure. This depends on the drawer class
being set to the curve analysis options. Usually this conforms to the
Matplotlib symbol names.
canvas: Optional. Index of sub-axis in the output figure that draws this curve.
This option is valid only when the drawer instance provides multi-axis drawing.
model_description: Optional. Arbitrary string representation of this fit model.
This string will appear in the analysis results as a part of metadata.
"""

# Arbitrary callback to define the fit function. First argument should be x.
fit_func: Callable

# Keyword dictionary to define the series with circuit metadata
filter_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

# Name of this series. This name will appear in the figure and raw x-y value report.
name: str = "Series-0"

# Color of this line.
plot_color: str = "black"

# Symbol to represent data points of this line.
plot_symbol: str = "o"

# Latex description of this fit model
model_description: Optional[str] = None

# Index of canvas if the result figure is multi-panel
canvas: Optional[int] = None

# Automatically extracted signature of the fit function
signature: List[str] = dataclasses.field(init=False)
model_description: Optional[str] = None
signature: Tuple[str, ...] = dataclasses.field(init=False)

def __post_init__(self):
"""Parse the fit function signature to extract the names of the variables.

Fit functions take arguments F(x, p0, p1, p2, ...) thus the first value should be excluded.
"""
signature = list(inspect.signature(self.fit_func).parameters.keys())
fitparams = signature[1:]
fitparams = tuple(signature[1:])

# Note that this dataclass is frozen
object.__setattr__(self, "signature", fitparams)


@dataclasses.dataclass(frozen=True)
class CurveData:
"""Set of extracted experiment data."""

# Name of this data set
label: str
"""A dataclass that manages the multiple arrays comprising the dataset for fitting.

This dataset can consist of X, Y values from multiple series.
To extract curve data of the particular series, :meth:`get_subset_of` can be used.

Attributes:
x: X-values that experiment sweeps.
y: Y-values that observed and processed by the data processor.
y_err: Uncertainty of the Y-values which is created by the data processor.
Usually this assumes standard error.
shots: Number of shots used in the experiment to obtain the Y-values.
data_allocation: List with identical size with other arrays.
The value indicates the series index of the corresponding element.
This is classified based upon the matching of :attr:`SeriesDef.filter_kwargs`
with the circuit metadata of the corresponding data index.
If metadata doesn't match with any series definition, element is filled with ``-1``.
labels: List of curve labels. The list index corresponds to the series index.
"""

# X data
x: np.ndarray

# Y data (measured data)
y: np.ndarray

# Error bar
y_err: np.ndarray

# Shots number
shots: np.ndarray
data_allocation: np.ndarray
labels: List[str]

# Maping of data index to series index
data_index: Union[np.ndarray, int]
def get_subset_of(self, index: Union[str, int]) -> "CurveData":
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved
"""Filter data by series name or index.

# Metadata associated with each data point. Generated from the circuit metadata.
metadata: np.ndarray = None
Args:
index: Series index of name.

Returns:
A subset of data corresponding to a particular series.
"""
if isinstance(index, int):
_index = index
_name = self.labels[index]
else:
_index = self.labels.index(index)
_name = index

locs = self.data_allocation == _index
return CurveData(
x=self.x[locs],
y=self.y[locs],
y_err=self.y_err[locs],
shots=self.shots[locs],
data_allocation=np.full(np.count_nonzero(locs), _index),
labels=[_name],
)


@dataclasses.dataclass(frozen=True)
class FitData:
"""Set of data generated by the fit function."""
"""A dataclass to store the outcome of the fitting.

Attributes:
popt: List of optimal parameter values with uncertainties if available.
popt_keys: List of parameter names being fit.
pcov: Covariance matrix from the least square fitting.
reduced_chisq: Reduced Chi-squared value for the fit curve.
dof: Degree of freedom in this fit model.
x_data: X-values provided to the fitter.
y_data: Y-values provided to the fitter.
"""

# Order sensitive fit parameter values
popt: List[uncertainties.UFloat]

# Order sensitive parameter name list
popt_keys: List[str]

# Covariance matrix
pcov: np.ndarray

# Reduced Chi-squared value of fit curve
reduced_chisq: float

# Degree of freedom
dof: int
x_data: np.ndarray
y_data: np.ndarray

# X data range
x_range: Tuple[float, float]
@property
def x_range(self) -> Tuple[float, float]:
"""Range of x values."""
return np.min(self.x_data), np.max(self.x_data)

# Y data range
y_range: Tuple[float, float]
@property
def y_range(self) -> Tuple[float, float]:
"""Range of y values."""
return np.min(self.y_data), np.max(self.y_data)

def fitval(self, key: str) -> uncertainties.UFloat:
"""A helper method to get fit value object from parameter key name.
Expand All @@ -136,7 +180,13 @@ def fitval(self, key: str) -> uncertainties.UFloat:

@dataclasses.dataclass
class ParameterRepr:
"""Detailed description of fitting parameter."""
"""Detailed description of fitting parameter.

Attributes:
name: Original name of the fit parameter being defined in the fit model.
repr: Optional. Human-readable parameter name shown in the analysis result and in the figure.
unit: Optional. Physical unit of this parameter if applicable.
"""

# Fitter argument name
name: str
Expand Down
8 changes: 2 additions & 6 deletions qiskit_experiments/curve_analysis/curve_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,14 @@ def fit_func(x, *params):
residues = residues / (sigma**2)
reduced_chisq = np.sum(residues) / dof

# Compute data range for fit
xdata_range = np.min(xdata), np.max(xdata)
ydata_range = np.min(ydata), np.max(ydata)

return FitData(
popt=list(fit_params),
popt_keys=list(param_keys),
pcov=pcov,
reduced_chisq=reduced_chisq,
dof=dof,
x_range=xdata_range,
y_range=ydata_range,
x_data=xdata,
y_data=ydata,
)


Expand Down
12 changes: 5 additions & 7 deletions qiskit_experiments/curve_analysis/standard_analysis/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,19 @@ class DecayAnalysis(curve.CurveAnalysis):
]

def _generate_fit_guesses(
self, user_opt: curve.FitOptions
self,
user_opt: curve.FitOptions,
curve_data: curve.CurveData,
) -> Union[curve.FitOptions, List[curve.FitOptions]]:
"""Compute the initial guesses.
"""Create algorithmic guess with analysis options and curve data.

Args:
user_opt: Fit options filled with user provided guess and bounds.
curve_data: Formatted data collection to fit.

Returns:
List of fit options that are passed to the fitter function.

Raises:
AnalysisError: When the y data is likely constant.
"""
curve_data = self._data()

user_opt.p0.set_if_empty(base=curve.guess.min_height(curve_data.y)[0])

alpha = curve.guess.exp_decay(curve_data.x, curve_data.y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ class ErrorAmplificationAnalysis(curve.CurveAnalysis):
often correspond to symmetry points of the fit function. Furthermore,
this type of analysis is intended for values of :math:`d\theta` close to zero.

# section: note

Different analysis classes may subclass this class to fix some of the fit parameters.
"""

__series__ = [
Expand Down Expand Up @@ -109,7 +106,7 @@ def _default_options(cls):
considered as good. Defaults to :math:`\pi/2`.
"""
default_options = super()._default_options()
default_options.curve_plotter.set_options(
default_options.curve_drawer.set_options(
xlabel="Number of gates (n)",
ylabel="Population",
ylim=(0, 1.0),
Expand All @@ -120,22 +117,21 @@ def _default_options(cls):
return default_options

def _generate_fit_guesses(
self, user_opt: curve.FitOptions
self,
user_opt: curve.FitOptions,
curve_data: curve.CurveData,
) -> Union[curve.FitOptions, List[curve.FitOptions]]:
"""Compute the initial guesses.
"""Create algorithmic guess with analysis options and curve data.

Args:
user_opt: Fit options filled with user provided guess and bounds.
curve_data: Formatted data collection to fit.

Returns:
List of fit options that are passed to the fitter function.

Raises:
CalibrationError: When ``angle_per_gate`` is missing.
"""
fixed_params = self.options.fixed_parameters

curve_data = self._data()
max_abs_y, _ = curve.guess.max_height(curve_data.y, absolute=True)
max_y, min_y = np.max(curve_data.y), np.min(curve_data.y)

Expand Down
22 changes: 10 additions & 12 deletions qiskit_experiments/curve_analysis/standard_analysis/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GaussianAnalysis(curve.CurveAnalysis):
@classmethod
def _default_options(cls) -> Options:
options = super()._default_options()
options.curve_plotter.set_options(
options.curve_drawer.set_options(
xlabel="Frequency",
ylabel="Signal (arb. units)",
xval_unit="Hz",
Expand All @@ -81,17 +81,19 @@ def _default_options(cls) -> Options:
return options

def _generate_fit_guesses(
self, user_opt: curve.FitOptions
self,
user_opt: curve.FitOptions,
curve_data: curve.CurveData,
) -> Union[curve.FitOptions, List[curve.FitOptions]]:
"""Compute the initial guesses.
"""Create algorithmic guess with analysis options and curve data.

Args:
user_opt: Fit options filled with user provided guess and bounds.
curve_data: Formatted data collection to fit.

Returns:
List of fit options that are passed to the fitter function.
"""
curve_data = self._data()
max_abs_y, _ = curve.guess.max_height(curve_data.y, absolute=True)

user_opt.bounds.set_if_empty(
Expand Down Expand Up @@ -128,22 +130,18 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
threshold of two, and
- a standard error on the sigma of the Gaussian that is smaller than the sigma.
"""
curve_data = self._data()

max_freq = np.max(curve_data.x)
min_freq = np.min(curve_data.x)
freq_increment = np.mean(np.diff(curve_data.x))
freq_increment = np.mean(np.diff(fit_data.x_data))

fit_a = fit_data.fitval("a")
fit_b = fit_data.fitval("b")
fit_freq = fit_data.fitval("freq")
fit_sigma = fit_data.fitval("sigma")

snr = abs(fit_a.n) / np.sqrt(abs(np.median(curve_data.y) - fit_b.n))
fit_width_ratio = fit_sigma.n / (max_freq - min_freq)
snr = abs(fit_a.n) / np.sqrt(abs(np.median(fit_data.y_data) - fit_b.n))
fit_width_ratio = fit_sigma.n / np.ptp(fit_data.x_data)

criteria = [
min_freq <= fit_freq.n <= max_freq,
fit_data.x_range[0] <= fit_freq.n <= fit_data.x_range[1],
1.5 * freq_increment < fit_sigma.n,
fit_width_ratio < 0.25,
fit_data.reduced_chisq < 3,
Expand Down
Loading