Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
nkanazawa1989 committed Dec 2, 2021
1 parent 2b19a10 commit 2ad7511
Show file tree
Hide file tree
Showing 24 changed files with 143 additions and 123 deletions.
16 changes: 6 additions & 10 deletions qiskit_experiments/curve_analysis/curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,13 +582,8 @@ def _is_target_series(datum, **filters):
) from ex

if isinstance(data_processor, DataProcessor):
y_data = data_processor(data)
ydata = data_processor(data)
else:
# Can we replace this with error or still worth supporting?
# Looks like this is too much flexibility.
warnings.warn(
"Use of non DataProcessor instance has been deprecated.", DeprecationWarning
)
y_nominals, y_stderrs = zip(*map(data_processor, data))
ydata = unp.uarray(y_nominals, y_stderrs)

Expand All @@ -599,7 +594,7 @@ def _is_target_series(datum, **filters):
shots = np.asarray([datum.get("shots", np.nan) for datum in data])

# Find series (invalid data is labeled as -1)
data_index = np.full(x_values.size, -1, dtype=int)
data_index = np.full(xdata.size, -1, dtype=int)
for idx, series_def in enumerate(self.__series__):
data_matched = np.asarray(
[_is_target_series(datum, **series_def.filter_kwargs) for datum in data], dtype=bool
Expand All @@ -611,7 +606,7 @@ def _is_target_series(datum, **filters):
label="raw_data",
x=xdata,
y=unp.nominal_values(ydata),
y_err=unp.nominal_values(ydata),
y_err=unp.std_devs(ydata),
shots=shots,
data_index=data_index,
metadata=metadata,
Expand Down Expand Up @@ -980,10 +975,11 @@ def _run_analysis(
analysis_results.append(
AnalysisResultData(
name=PARAMS_ENTRY_PREFIX + self.__class__.__name__,
value=fit_result.parameters,
value=[p.nominal_value for p in fit_result.popt],
chisq=fit_result.reduced_chisq,
quality=quality,
extra={
"popt_keys": fit_result.popt_keys,
"dof": fit_result.dof,
"covariance_mat": fit_result.pcov,
"fit_models": fit_models,
Expand All @@ -1006,7 +1002,7 @@ def _run_analysis(
unit = None
result_entry = AnalysisResultData(
name=p_repr,
value=fit_result.fit_val(p_name),
value=fit_result.fitval(p_name),
unit=unit,
chisq=fit_result.reduced_chisq,
quality=quality,
Expand Down
14 changes: 9 additions & 5 deletions qiskit_experiments/curve_analysis/curve_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# pylint: disable = invalid-name

from typing import List, Dict, Tuple, Callable, Optional, Union
from uncertainties import ufloat, correlated_values
from uncertainties import correlated_values, ufloat

import numpy as np
import scipy.optimize as opt
Expand Down Expand Up @@ -63,8 +63,9 @@ def curve_fit(
``xrange`` the range of xdata values used for fit.
Raises:
AnalysisError: if the number of degrees of freedom of the fit is
less than 1, or the curve fitting fails.
AnalysisError:
When the number of degrees of freedom of the fit is
less than 1, or the curve fitting fails.
.. note::
``sigma`` is assumed to be specified in the same units as ``ydata``
Expand Down Expand Up @@ -135,8 +136,11 @@ def fit_func(x, *params):
"scipy.optimize.curve_fit failed with error: {}".format(str(ex))
) from ex

# Keep parameter correlations in following analysis steps
fit_params = correlated_values(nom_values=popt, covariance_mat=pcov)
if np.isfinite(pcov).all():
# Keep parameter correlations in following analysis steps
fit_params = correlated_values(nom_values=popt, covariance_mat=pcov, tags=param_keys)
else:
fit_params = [ufloat(nom, np.nan) for nom in popt]

# Calculate the reduced chi-squared for fit
yfits = fit_func(xdata, *popt)
Expand Down
42 changes: 28 additions & 14 deletions qiskit_experiments/curve_analysis/fit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,38 @@
"""
# pylint: disable=invalid-name

import numpy as np
import functools
from typing import Callable

import numpy as np
from uncertainties import unumpy as unp


def uncertainties(fit_func):
def calc_uncertainties(fit_func: Callable) -> Callable:
"""Decolator that typecast y values to float array if input parameters have no error.
Args:
fit_func: Fit function that may return ufloat array.
Returns:
Fit function with typecast.
"""

@functools.wraps(fit_func)
def wrapper(*ags, **kwargs) -> np.ndarray:
yvals = wrapper(*args, **kwargs)
def _wrapper(x, *args, **kwargs) -> np.ndarray:
yvals = fit_func(x, *args, **kwargs)
try:
if isinstance(x, float):
# single value
return float(yvals)
return yvals.astype(float)
except TypeError:
return yvals
return wrapper

return _wrapper


@uncertainties
@calc_uncertainties
def cos(
x: np.ndarray,
amp: float = 1.0,
Expand All @@ -49,7 +63,7 @@ def cos(
return amp * unp.cos(2 * np.pi * freq * x + phase) + baseline


@uncertainties
@calc_uncertainties
def sin(
x: np.ndarray,
amp: float = 1.0,
Expand All @@ -66,7 +80,7 @@ def sin(
return amp * unp.sin(2 * np.pi * freq * x + phase) + baseline


@uncertainties
@calc_uncertainties
def exponential_decay(
x: np.ndarray,
amp: float = 1.0,
Expand All @@ -83,7 +97,7 @@ def exponential_decay(
return amp * base ** (-lamb * x + x0) + baseline


@uncertainties
@calc_uncertainties
def gaussian(
x: np.ndarray, amp: float = 1.0, sigma: float = 1.0, x0: float = 0.0, baseline: float = 0.0
) -> np.ndarray:
Expand All @@ -95,7 +109,7 @@ def gaussian(
return amp * unp.exp(-((x - x0) ** 2) / (2 * sigma ** 2)) + baseline


@uncertainties
@calc_uncertainties
def cos_decay(
x: np.ndarray,
amp: float = 1.0,
Expand All @@ -113,7 +127,7 @@ def cos_decay(
return exponential_decay(x, lamb=1 / tau) * cos(x, amp=amp, freq=freq, phase=phase) + baseline


@uncertainties
@calc_uncertainties
def sin_decay(
x: np.ndarray,
amp: float = 1.0,
Expand All @@ -131,7 +145,7 @@ def sin_decay(
return exponential_decay(x, lamb=1 / tau) * sin(x, amp=amp, freq=freq, phase=phase) + baseline


@uncertainties
@calc_uncertainties
def bloch_oscillation_x(
x: np.ndarray, px: float = 0.0, py: float = 0.0, pz: float = 0.0, baseline: float = 0.0
):
Expand All @@ -149,7 +163,7 @@ def bloch_oscillation_x(
return (-pz * px + pz * px * unp.cos(w * x) + w * py * unp.sin(w * x)) / (w ** 2) + baseline


@uncertainties
@calc_uncertainties
def bloch_oscillation_y(
x: np.ndarray, px: float = 0.0, py: float = 0.0, pz: float = 0.0, baseline: float = 0.0
):
Expand All @@ -167,7 +181,7 @@ def bloch_oscillation_y(
return (pz * py - pz * py * unp.cos(w * x) - w * px * unp.sin(w * x)) / (w ** 2) + baseline


@uncertainties
@calc_uncertainties
def bloch_oscillation_z(
x: np.ndarray, px: float = 0.0, py: float = 0.0, pz: float = 0.0, baseline: float = 0.0
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:

criteria = [
fit_data.reduced_chisq < 3,
tau.stderr is None or tau.stderr < tau.value,
tau.std_dev is None or tau.std_dev < tau.nominal_value,
]

if all(criteria):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
- a measured angle error that is smaller than the allowed maximum good angle error.
This quantity is set in the analysis options.
"""
fit_d_theta = fit_data.fitval("d_theta").value
fit_d_theta = fit_data.fitval("d_theta")
max_good_angle_error = self._get_option("max_good_angle_error")

criteria = [
fit_data.reduced_chisq < 3,
abs(fit_d_theta) < abs(max_good_angle_error),
abs(fit_d_theta.nominal_value) < abs(max_good_angle_error),
]

if all(criteria):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,12 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
- less than 10 full periods, and
- an error on the fit frequency lower than the fit frequency.
"""
fit_freq = fit_data.fitval("freq").value
fit_freq_err = fit_data.fitval("freq").stderr
fit_freq = fit_data.fitval("freq")

criteria = [
fit_data.reduced_chisq < 3,
1.0 / 4.0 < fit_freq < 10.0,
(fit_freq_err is None or (fit_freq_err < fit_freq)),
1.0 / 4.0 < fit_freq.nominal_value < 10.0,
(np.isnan(fit_freq.std_dev) or (fit_freq.std_dev < fit_freq.nominal_value)),
]

if all(criteria):
Expand Down Expand Up @@ -264,8 +263,8 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:

criteria = [
fit_data.reduced_chisq < 3,
tau.stderr is None or tau.stderr < tau.value,
freq.stderr is None or freq.stderr < freq.value,
np.isnan(tau.std_dev) or tau.std_dev < tau.nominal_value,
np.isnan(freq.std_dev) or freq.std_dev < freq.nominal_value,
]

if all(criteria):
Expand Down
19 changes: 9 additions & 10 deletions qiskit_experiments/curve_analysis/standard_analysis/resonance.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,20 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
min_freq = np.min(curve_data.x)
freq_increment = np.mean(np.diff(curve_data.x))

fit_a = fit_data.fitval("a").value
fit_b = fit_data.fitval("b").value
fit_freq = fit_data.fitval("freq").value
fit_sigma = fit_data.fitval("sigma").value
fit_sigma_err = fit_data.fitval("sigma").stderr
fit_a = fit_data.fitval("a")
fit_b = fit_data.fitval("b")
fit_freq = fit_data.fitval("freq")
fit_sigma = fit_data.fitval("sigma")

snr = abs(fit_a) / np.sqrt(abs(np.median(curve_data.y) - fit_b))
fit_width_ratio = fit_sigma / (max_freq - min_freq)
snr = abs(fit_a.nominal_value) / np.sqrt(abs(np.median(curve_data.y) - fit_b.nominal_value))
fit_width_ratio = fit_sigma.nominal_value / (max_freq - min_freq)

criteria = [
min_freq <= fit_freq <= max_freq,
1.5 * freq_increment < fit_sigma,
min_freq <= fit_freq.nominal_value <= max_freq,
1.5 * freq_increment < fit_sigma.nominal_value,
fit_width_ratio < 0.25,
fit_data.reduced_chisq < 3,
(fit_sigma_err is None or fit_sigma_err < fit_sigma),
(np.isnan(fit_sigma.std_dev) or fit_sigma.std_dev < fit_sigma.nominal_value),
snr > 2,
]

Expand Down
16 changes: 9 additions & 7 deletions qiskit_experiments/curve_analysis/visualization/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ def plot_curve_fit(
ax.plot(xs, unp.nominal_values(ys_fit_with_error), **plot_opts)

# Confidence interval of 1 sigma
ax.fill_between(
xs,
unp.nominal_values(ys_fit_with_error) - unp.std_devs(ys_fit_with_error),
unp.nominal_values(ys_fit_with_error) + unp.std_devs(ys_fit_with_error),
alpha=0.1,
color=plot_opts["color"],
)
stdev_arr = unp.std_devs(ys_fit_with_error)
if np.isfinite(stdev_arr).all():
ax.fill_between(
xs,
y1=unp.nominal_values(ys_fit_with_error) - stdev_arr,
y2=unp.nominal_values(ys_fit_with_error) + stdev_arr,
alpha=0.1,
color=plot_opts["color"],
)

# Formatting
ax.tick_params(labelsize=labelsize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from uncertainties.core import UFloat

from qiskit_experiments.curve_analysis.curve_data import SeriesDef, FitData, CurveData
from qiskit_experiments.framework import AnalysisResultDataF
from qiskit_experiments.framework import AnalysisResultData
from qiskit_experiments.framework.matplotlib import get_non_gui_ax


Expand Down Expand Up @@ -409,7 +409,7 @@ def format_val(float_val: float) -> str:
value_repr = f"{val: .3g}"

# write error bar if it is finite value
if fitval.std_dev is not None and not np.isinf(fitval.std_dev):
if fitval.std_dev is not None and np.isfinite(fitval.std_dev):
# with stderr
err, err_prefix = detach_prefix(fitval.std_dev, decimal=3)
err_unit = err_prefix + res.unit
Expand Down
2 changes: 2 additions & 0 deletions qiskit_experiments/database_service/db_fitval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class FitVal:
This data is serializable with the Qiskit Experiment json serializer.
"""

# TODO deprecate this (replace with UFloat)

value: float
stderr: Optional[float] = None
unit: Optional[str] = None
Expand Down
1 change: 1 addition & 0 deletions qiskit_experiments/framework/base_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _format_analysis_result(self, data, experiment_id, experiment_components=Non
device_components = experiment_components

# Convert ufloat to FitVal so that database service can parse
# TODO completely deprecate FitVal. We can store UFloat in database.
if isinstance(data.value, UFloat):
value = FitVal(
value=data.value.nominal_value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,9 @@ def _extra_database_entry(self, fit_data: curve.FitData) -> List[AnalysisResultD
p1_val = fit_data.fitval(f"p{target}1")

if control == "z":
coef_val = 0.5 * (p0_val.value - p1_val.value) / (2 * np.pi)
coef_val = 0.5 * (p0_val - p1_val) / (2 * np.pi)
else:
coef_val = 0.5 * (p0_val.value + p1_val.value) / (2 * np.pi)
coef_val = 0.5 * (p0_val + p1_val) / (2 * np.pi)

extra_entries.append(
AnalysisResultData(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,13 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
- a DRAG parameter value within the first period of the lowest number of repetitions,
- an error on the drag beta smaller than the beta.
"""
fit_beta = fit_data.fitval("beta").value
fit_beta_err = fit_data.fitval("beta").stderr
fit_freq0 = fit_data.fitval("freq0").value
fit_beta = fit_data.fitval("beta")
fit_freq0 = fit_data.fitval("freq0")

criteria = [
fit_data.reduced_chisq < 3,
fit_beta < 1 / fit_freq0,
fit_beta_err < abs(fit_beta),
fit_beta.nominal_value < 1 / fit_freq0.nominal_value,
fit_beta.std_dev < abs(fit_beta.nominal_value),
]

if all(criteria):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,11 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
- a reduced chi-squared lower than three,
- an error on the frequency smaller than the frequency.
"""
fit_freq = fit_data.fitval("freq").value
fit_freq_err = fit_data.fitval("freq").stderr
fit_freq = fit_data.fitval("freq")

criteria = [
fit_data.reduced_chisq < 3,
fit_freq_err < abs(fit_freq),
fit_freq.std_dev < abs(fit_freq.nominal_value),
]

if all(criteria):
Expand Down
Loading

0 comments on commit 2ad7511

Please sign in to comment.