Skip to content

Commit

Permalink
update curvefit unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
nkanazawa1989 committed Apr 22, 2022
1 parent 3672723 commit eb22e39
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 153 deletions.
3 changes: 3 additions & 0 deletions qiskit_experiments/curve_analysis/base_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def _default_options(cls) -> Options:
# Set automatic validator for particular option values
options.set_validator(field="data_processor", validator_value=DataProcessor)
options.set_validator(field="curve_plotter", validator_value=BaseCurveDrawer)
options.set_validator(field="p0", validator_value=dict)
options.set_validator(field="bounds", validator_value=dict)
options.set_validator(field="fixed_parameters", validator_value=dict)

return options

Expand Down
22 changes: 12 additions & 10 deletions qiskit_experiments/curve_analysis/curve_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,20 @@ def get_subset_of(self, index: Union[str, int]) -> "CurveData":
A subset of data corresponding to a particular series.
"""
if isinstance(index, int):
inds = self.data_allocation == index
name = self.labels[index]
_index = index
_name = self.labels[index]
else:
inds = self.data_allocation == self.labels.index(index)
name = index
_index = self.labels.index(index)
_name = index

locs = self.data_allocation == _index
return CurveData(
x=self.x[inds],
y=self.y[inds],
y_err=self.y_err[inds],
shots=self.shots[inds],
data_allocation=np.full(np.count_nonzero(inds), index),
labels=[name],
x=self.x[locs],
y=self.y[locs],
y_err=self.y_err[locs],
shots=self.shots[locs],
data_allocation=np.full(np.count_nonzero(locs), _index),
labels=[_name],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ParameterRepr,
FitOptions,
)
from qiskit_experiments.curve_analysis.data_processing import probability
from qiskit_experiments.data_processing import DataProcessor, Probability
from qiskit_experiments.exceptions import AnalysisError
from qiskit_experiments.framework import ExperimentData

Expand Down Expand Up @@ -76,85 +76,42 @@ def _default_options(cls):
return TestAnalysis()


class TestFitData(QiskitExperimentsTestCase):
    """Unittest for fit data dataclass."""

    def test_get_value(self):
        """Get fit value from fit data object."""
        # Three uncorrelated parameters: identity covariance over the estimates.
        covariance = np.diag(np.ones(3))
        estimates = np.asarray([1.0, 2.0, 3.0])
        fit_params = correlated_values(estimates, covariance)

        data = FitData(
            popt=fit_params,
            popt_keys=["a", "b", "c"],
            pcov=covariance,
            reduced_chisq=0.0,
            dof=0,
            x_range=(0, 0),
            y_range=(0, 0),
        )

        # Looking a parameter up by name must return the matching correlated value.
        for position, key in enumerate(["a", "b", "c"]):
            self.assertEqual(data.fitval(key), fit_params[position])


class TestCurveAnalysisUnit(QiskitExperimentsTestCase):
"""Unittest for curve fit analysis."""

def setUp(self):
super().setUp()
self.xvalues = np.linspace(1.0, 5.0, 10)

# Description of test setting
#
# - This model contains three curves, namely, curve1, curve2, curve3
# - Each curve can be represented by the same function
# - Parameter amp and baseline are shared among all curves
# - Each curve has unique lamb
# - In total 5 parameters in the fit, namely, p0, p1, p2, p3
#
self.analysis = create_new_analysis(
series=[
SeriesDef(
name="curve1",
fit_func=lambda x, par0, par1, par2, par3, par4: fit_function.exponential_decay(
x, amp=par0, lamb=par1, baseline=par4
),
filter_kwargs={"type": 1, "valid": True},
model_description=r"p_0 * \exp(p_1 x) + p4",
class TestAnalysis(CurveAnalysis):
"""Fake analysis class for unittest."""
__series__ = [
SeriesDef(
name="curve1",
fit_func=lambda x, par0, par1, par2, par3, par4: fit_function.exponential_decay(
x, amp=par0, lamb=par1, baseline=par4
),
SeriesDef(
name="curve2",
fit_func=lambda x, par0, par1, par2, par3, par4: fit_function.exponential_decay(
x, amp=par0, lamb=par2, baseline=par4
),
filter_kwargs={"type": 2, "valid": True},
model_description=r"p_0 * \exp(p_2 x) + p4",
filter_kwargs={"op1": 1, "op2": True},
model_description=r"p_0 * \exp(p_1 x) + p4",
),
SeriesDef(
name="curve2",
fit_func=lambda x, par0, par1, par2, par3, par4: fit_function.exponential_decay(
x, amp=par0, lamb=par2, baseline=par4
),
SeriesDef(
name="curve3",
fit_func=lambda x, par0, par1, par2, par3, par4: fit_function.exponential_decay(
x, amp=par0, lamb=par3, baseline=par4
),
filter_kwargs={"type": 3, "valid": True},
model_description=r"p_0 * \exp(p_3 x) + p4",
filter_kwargs={"op1": 2, "op2": True},
model_description=r"p_0 * \exp(p_2 x) + p4",
),
SeriesDef(
name="curve3",
fit_func=lambda x, par0, par1, par2, par3, par4: fit_function.exponential_decay(
x, amp=par0, lamb=par3, baseline=par4
),
],
)
self.err_decimal = 3
filter_kwargs={"op1": 3, "op2": True},
model_description=r"p_0 * \exp(p_3 x) + p4",
),
]

def test_parsed_fit_params(self):
"""Test parsed fit params."""
self.assertSetEqual(
set(self.analysis._fit_params()), {"par0", "par1", "par2", "par3", "par4"}
)
analysis = self.TestAnalysis()
self.assertSetEqual(set(analysis.parameters), {"par0", "par1", "par2", "par3", "par4"})

def test_cannot_create_invalid_series_fit(self):
"""Test we cannot create invalid analysis instance."""
Expand All @@ -176,100 +133,137 @@ def test_cannot_create_invalid_series_fit(self):

def test_data_extraction(self):
"""Test data extraction method."""
self.analysis.set_options(x_key="xval")
xvalues = np.linspace(1.0, 5.0, 10)

analysis = self.TestAnalysis()
analysis.set_options(
data_processor=DataProcessor("counts", [Probability("1")])
)

# data to analyze
test_data0 = simulate_output_data(
func=fit_function.exponential_decay,
xvals=self.xvalues,
xvals=xvalues,
param_dict={"amp": 1.0},
type=1,
valid=True,
op1=1,
op2=True,
)

# fake data
test_data1 = simulate_output_data(
func=fit_function.exponential_decay,
xvals=self.xvalues,
param_dict={"amp": 0.5},
type=2,
valid=False,
curve_data = analysis._run_data_processing(
raw_data=test_data0.data(),
series=analysis.__series__,
)

# merge two experiment data
for datum in test_data1.data():
test_data0.add_data(datum)
# check x values
ref_x = xvalues
np.testing.assert_array_almost_equal(curve_data.x, ref_x)

# check y values
ref_y = fit_function.exponential_decay(xvalues, amp=1.0)
np.testing.assert_array_almost_equal(curve_data.y, ref_y, decimal=3)

self.analysis._extract_curves(
experiment_data=test_data0, data_processor=probability(outcome="1")
# check data allocation
ref_alloc = np.zeros(10, dtype=int)
self.assertListEqual(list(curve_data.data_allocation), list(ref_alloc))

def test_data_extraction_with_subset(self):
"""Test data extraction method with multiple series."""
xvalues = np.linspace(1.0, 5.0, 10)

analysis = self.TestAnalysis()
analysis.set_options(
data_processor=DataProcessor("counts", [Probability("1")])
)

raw_data = self.analysis._data(label="raw_data")
# data to analyze
test_data0 = simulate_output_data(
func=fit_function.exponential_decay,
xvals=xvalues,
param_dict={"amp": 1.0},
op1=1,
op2=True,
)

xdata = raw_data.x
ydata = raw_data.y
sigma = raw_data.y_err
d_index = raw_data.data_index
test_data1 = simulate_output_data(
func=fit_function.exponential_decay,
xvals=xvalues,
param_dict={"amp": 0.5},
op1=2,
op2=True,
)

# check if the module filter off data: valid=False
self.assertEqual(len(xdata), 20)
# get subset
curve_data_of_1 = analysis._run_data_processing(
raw_data=test_data0.data() + test_data1.data(),
series=analysis.__series__,
).get_subset_of("curve1")

# check x values
ref_x = np.concatenate((self.xvalues, self.xvalues))
np.testing.assert_array_almost_equal(xdata, ref_x)
ref_x = xvalues
np.testing.assert_array_almost_equal(curve_data_of_1.x, ref_x)

# check y values
ref_y = np.concatenate(
(
fit_function.exponential_decay(self.xvalues, amp=1.0),
fit_function.exponential_decay(self.xvalues, amp=0.5),
)
ref_y = fit_function.exponential_decay(xvalues, amp=1.0)
np.testing.assert_array_almost_equal(curve_data_of_1.y, ref_y, decimal=3)

# check data allocation
ref_alloc = np.zeros(10, dtype=int)
self.assertListEqual(list(curve_data_of_1.data_allocation), list(ref_alloc))

def test_create_results(self):
"""Test creating analysis results."""
analysis = self.TestAnalysis()
analysis.set_options(
result_parameters=["par0", ParameterRepr("par1", "Param1", "SomeUnit")],
)

pcov = np.diag(np.ones(5))
popt = np.asarray([1.0, 2.0, 3.0, 4.0, 5.0])
fit_params = correlated_values(popt, pcov)

fit_data = FitData(
popt=fit_params,
popt_keys=["par0", "par1", "par2", "par3", "par4", "par5"],
pcov=pcov,
reduced_chisq=2.0,
dof=0,
x_data=np.arange(5),
y_data=np.arange(5),
)
np.testing.assert_array_almost_equal(ydata, ref_y, decimal=self.err_decimal)

# check series
ref_series = np.concatenate((np.zeros(10, dtype=int), -1 * np.ones(10, dtype=int)))
self.assertListEqual(list(d_index), list(ref_series))
outcomes = analysis._create_analysis_results(fit_data, quality="good", test_val=1)

# check y errors
ref_yerr = ref_y * (1 - ref_y) / 100000
np.testing.assert_array_almost_equal(sigma, ref_yerr, decimal=self.err_decimal)
# entry name
self.assertEqual(outcomes[0].name, "@Parameters_TestAnalysis")
self.assertEqual(outcomes[1].name, "par0")
self.assertEqual(outcomes[2].name, "Param1")

def test_get_subset(self):
"""Test that get subset data from full data array."""
# data to analyze
fake_data = [
{"data": 1, "metadata": {"xval": 1, "type": 1, "valid": True}},
{"data": 2, "metadata": {"xval": 2, "type": 2, "valid": True}},
{"data": 3, "metadata": {"xval": 3, "type": 1, "valid": True}},
{"data": 4, "metadata": {"xval": 4, "type": 3, "valid": True}},
{"data": 5, "metadata": {"xval": 5, "type": 3, "valid": True}},
{"data": 6, "metadata": {"xval": 6, "type": 4, "valid": True}},  # this is fake
]
expdata = ExperimentData(experiment=FakeExperiment())
for datum in fake_data:
expdata.add_data(datum)

def _processor(datum):
return datum["data"], datum["data"] * 2

self.analysis.set_options(x_key="xval")
self.analysis._extract_curves(expdata, data_processor=_processor)

filt_data = self.analysis._data(series_name="curve1")
np.testing.assert_array_equal(filt_data.x, np.asarray([1, 3], dtype=float))
np.testing.assert_array_equal(filt_data.y, np.asarray([1, 3], dtype=float))
np.testing.assert_array_equal(filt_data.y_err, np.asarray([2, 6], dtype=float))

filt_data = self.analysis._data(series_name="curve2")
np.testing.assert_array_equal(filt_data.x, np.asarray([2], dtype=float))
np.testing.assert_array_equal(filt_data.y, np.asarray([2], dtype=float))
np.testing.assert_array_equal(filt_data.y_err, np.asarray([4], dtype=float))

filt_data = self.analysis._data(series_name="curve3")
np.testing.assert_array_equal(filt_data.x, np.asarray([4, 5], dtype=float))
np.testing.assert_array_equal(filt_data.y, np.asarray([4, 5], dtype=float))
np.testing.assert_array_equal(filt_data.y_err, np.asarray([8, 10], dtype=float))
# entry value
self.assertEqual(outcomes[1].value, fit_params[0])
self.assertEqual(outcomes[2].value, fit_params[1])

# other metadata
self.assertEqual(outcomes[2].quality, "good")
self.assertEqual(outcomes[2].chisq, 2.0)
ref_meta = {
"test_val": 1,
"unit": "SomeUnit",
}
self.assertDictEqual(outcomes[2].extra, ref_meta)

def test_invalid_options(self):
    """Test that setting an option value of the wrong type raises.

    ``data_processor`` and ``curve_plotter`` carry automatic type
    validators, so assigning an arbitrary object must fail with
    ``TypeError`` — the analysis object is left unmodified.
    """
    analysis = self.TestAnalysis()

    class InvalidClass:
        """Dummy class that matches none of the validated option types."""

        pass

    # Each validated option must reject a value of the wrong type.
    for option_name in ("data_processor", "curve_plotter"):
        with self.subTest(option=option_name):
            with self.assertRaises(TypeError):
                analysis.set_options(**{option_name: InvalidClass()})


class TestCurveAnalysisIntegration(QiskitExperimentsTestCase):
Expand Down

0 comments on commit eb22e39

Please sign in to comment.