Skip to content

Commit

Permalink
Scatter table refactoring (#1319)
Browse files Browse the repository at this point in the history
### Summary

This PR modifies `ScatterTable` which is introduced in #1253.

This change resolves some code issues in #1315 and #1243.

### Details and comments

In the original design `ScatterTable` is tied to the fit models, and the
columns contains `model_name` (str) and `model_id` (int). Also the fit
module only allows to have three categorical data; "processed",
"formatted", "fitted". However, #1243 breaks this assumption, namely,
the `StarkRamseyXYAmpScanAnalysis` fitter defines two fit models which
are not directly mapped to the results data. The data fed into the
models is synthesized by consuming the input results data. The fitter
needs to manage four categorical data; "raw", "ramsey" (raw results),
"phase" (synthesized data for fit), and "fitted".

This PR relaxes the tight coupling of data to the fit model. In above
example, "raw" and "ramsey" category data can fill new fields `name`
(formally model_name) and `class_id` (model_id) without indicating a
particular fit model. Usually, raw category data is just classified
according to the `data_subfit_map` definition, and the map doesn't need
to match with the fit models. The connection to fit models is only
introduced in a particular category defined by new option value
`fit_category`. This option defaults to "formatted", but
`StarkRamseyXYAmpScanAnalysis` fitter would set "phase" instead. Thus
fit model assignment is effectively delayed until the formatter
function.

Also the original scatter table is designed to store all circuit
metadata which causes some problem in data formatting, especially when
it tries to average the data over the same x value in the group.
Non-numeric data is averaged by builtin set operation, but this assumes
the metadata value is hashable object, which is not generally true. This
PR also drops all metadata from the scatter table. Note that important
metadata fields for the curve analysis are one used for model
classification (classifier fields), and other fields just decorate the
table with unnecessary memory footprint requirements. The classifier
fields and `name` (`class_id`) are sort of duplicated information. This
implies the `name` and `class_id` fields are enough for end-users to
reuse the table data for further analysis once after it's saved as an
artifact.

---------

Co-authored-by: Will Shanks <[email protected]>
  • Loading branch information
nkanazawa1989 and wshanks authored Dec 22, 2023
1 parent 726f422 commit 5bb1fb4
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 198 deletions.
111 changes: 76 additions & 35 deletions docs/tutorials/curve_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,12 @@ different sets of experiment results. A single experiment can define sub-experim
consisting of multiple circuits which are tagged with common metadata,
and curve analysis sorts the experiment results based on the circuit metadata.

This is an example of showing the abstract data structure of a typical curve analysis experiment:
This is an example showing the abstract data flow of a typical curve analysis experiment:

.. jupyter-input::
:emphasize-lines: 1,10,19

"experiment"
- circuits[0] (x=x1_A, "series_A")
- circuits[1] (x=x1_B, "series_B")
- circuits[2] (x=x2_A, "series_A")
- circuits[3] (x=x2_B, "series_B")
- circuits[4] (x=x3_A, "series_A")
- circuits[5] (x=x3_B, "series_B")
- ...

"experiment data"
- data[0] (y1_A, "series_A")
- data[1] (y1_B, "series_B")
- data[2] (y2_A, "series_A")
- data[3] (y2_B, "series_B")
- data[4] (y3_A, "series_A")
- data[5] (y3_B, "series_B")
- ...

"analysis"
- "series_A": y_A = f_A(x_A; p0, p1, p2)
- "series_B": y_B = f_B(x_B; p0, p1, p2)
- fixed parameters {p1: v}
.. figure:: images/curve_analysis_structure.png
:width: 600
:align: center
:class: no-scaled-link

Here the experiment runs two subsets of experiments, namely, series A and series B.
The analysis defines corresponding fit models :math:`f_A(x_A)` and :math:`f_B(x_B)`.
Expand Down Expand Up @@ -289,21 +268,78 @@ A developer can override this method to perform initialization of analysis-speci

Curve analysis calls the :meth:`_run_data_processing` method, where
the data processor in the analysis option is internally called.
This consumes input experiment results and creates the :class:`.CurveData` dataclass.
Then the :meth:`_format_data` method is called with the processed dataset to format it.
This consumes input experiment results and creates the :class:`.ScatterTable` dataframe.
This table may look like:

.. code-block::
xval yval yerr name class_id category shots
0 0.1 0.153659 0.011258 A 0 raw 1024
1 0.1 0.590732 0.015351 B 1 raw 1024
2 0.1 0.315610 0.014510 A 0 raw 1024
3 0.1 0.376098 0.015123 B 1 raw 1024
4 0.2 0.937073 0.007581 A 0 raw 1024
5 0.2 0.323415 0.014604 B 1 raw 1024
6 0.2 0.538049 0.015565 A 0 raw 1024
7 0.2 0.530244 0.015581 B 1 raw 1024
8 0.3 0.143902 0.010958 A 0 raw 1024
9 0.3 0.261951 0.013727 B 1 raw 1024
10 0.3 0.830732 0.011707 A 0 raw 1024
11 0.3 0.874634 0.010338 B 1 raw 1024
where the experiment consists of two subset series A and B, and the experiment parameter (xval)
is scanned from 0.1 to 0.3 in each subset. In this example, the experiment is run twice
for each condition. The role of each column is as follows:

- ``xval``: Parameter scanned in the experiment. This value must be defined in the circuit metadata.
- ``yval``: Nominal part of the outcome. The outcome is something like expectation value, which is computed from the experiment result with the data processor.
- ``yerr``: Standard error of the outcome, which is mainly due to sampling error.
- ``name``: Unique identifier of the result class. This is defined by the ``data_subfit_map`` option.
- ``class_id``: Numerical index corresponding to the result class. This number is automatically assigned.
- ``category``: The attribute of data set. The "raw" category indicates an output from the data processing.
- ``shots``: Number of measurement shots used to acquire this result.

3. Formatting
^^^^^^^^^^^^^

Next, the processed dataset is converted into another format suited for the fitting and
every valid result is assigned a class corresponding to a fit model.
By default, the formatter takes average of the outcomes in the processed dataset
over the same x values, followed by the sorting in the ascending order of x values.
This allows the analysis to easily estimate the slope of the curves to
create algorithmic initial guess of fit parameters.
A developer can inject extra data processing, for example, filtering, smoothing,
or elimination of outliers for better fitting.
The new class_id is given here so that its value corresponds to the fit model object index
in this analysis class. This index mapping is done based upon the correspondence of
the data name and the fit model name.

This is done by calling :meth:`_format_data` method.
This may return new scatter table object with the addition of rows like the following below.

.. code-block::
12 0.1 0.234634 0.009183 A 0 formatted 2048
13 0.2 0.737561 0.008656 A 0 formatted 2048
14 0.3 0.487317 0.008018 A 0 formatted 2048
15 0.1 0.483415 0.010774 B 1 formatted 2048
16 0.2 0.426829 0.010678 B 1 formatted 2048
17 0.3 0.568293 0.008592 B 1 formatted 2048
The default :meth:`_format_data` method adds its output data with the category "formatted".
This category name must be also specified in the analysis option ``fit_category``.
If overriding this method to do additional processing after the default formatting,
the ``fit_category`` analysis option can be set to choose a different category name to use to
select the data to pass to the fitting routine.
The (x, y) value in each row is passed to the corresponding fit model object
to compute residual values for the least square optimization.

3. Fitting
^^^^^^^^^^

Curve analysis calls the :meth:`_run_curve_fit` method, which is the core functionality of the fitting.
Another method :meth:`_generate_fit_guesses` is internally called to
prepare the initial guess and parameter boundary with respect to the formatted data.
Curve analysis calls the :meth:`_run_curve_fit` method with the formatted subset of the scatter table.
This internally calls :meth:`_generate_fit_guesses` to prepare
the initial guess and parameter boundary with respect to the formatted dataset.
Developers usually override this method to provide better initial guesses
tailored to the defined fit model or type of the associated experiment.
See :ref:`curve_analysis_init_guess` for more details.
Expand All @@ -314,13 +350,18 @@ custom fitting algorithms. This method must return a :class:`.CurveFitResult` da
^^^^^^^^^^^^^^^^^^

Curve analysis runs several postprocessing against the fit outcome.
It calls :meth:`._create_analysis_results` to create the :class:`.AnalysisResultData` class
When the fit is successful, it calls :meth:`._create_analysis_results` to create the :class:`.AnalysisResultData` objects
for the fitting parameters of interest. A developer can inject custom code to
compute custom quantities based on the raw fit parameters.
See :ref:`curve_analysis_results` for details.
Afterwards, figure plotting is handed over to the :doc:`Visualization </tutorials/visualization>` module via
the :attr:`~.CurveAnalysis.plotter` attribute, and a list of created analysis results and the figure are returned.

Afterwards, fit curves are computed with the fit models and optimal parameters, and the scatter table is
updated with the computed (x, y) values. This dataset is stored under the "fitted" category.

Finally, the :meth:`._create_figures` method is called with the entire scatter table data
to initialize the curve plotter instance accessible via the :attr:`~.CurveAnalysis.plotter` attribute.
The visualization is handed over to the :doc:`Visualization </tutorials/visualization>` module,
which provides a standardized image format for curve fit results.
A developer can overwrite this method to draw custom images.

.. _curve_analysis_init_guess:

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 8 additions & 2 deletions qiskit_experiments/curve_analysis/base_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _default_options(cls) -> Options:
lmfit_options (Dict[str, Any]): Options that are passed to the
LMFIT minimizer. Acceptable options depend on fit_method.
x_key (str): Circuit metadata key representing a scanned value.
fit_category (str): Name of dataset in the scatter table to fit.
result_parameters (List[Union[str, ParameterRepr]): Parameters reported in the
database as a dedicated entry. This is a list of parameter representation
which is either string or ParameterRepr object. If you provide more
Expand Down Expand Up @@ -219,6 +220,7 @@ def _default_options(cls) -> Options:
options.normalization = False
options.average_method = "shots_weighted"
options.x_key = "xval"
options.fit_category = "formatted"
options.result_parameters = []
options.extra = {}
options.fit_method = "least_squares"
Expand Down Expand Up @@ -282,11 +284,13 @@ def set_options(self, **fields):
def _run_data_processing(
self,
raw_data: List[Dict],
category: str = "raw",
) -> ScatterTable:
"""Perform data processing from the experiment result payload.
Args:
raw_data: Payload in the experiment data.
category: Category string of the output dataset.
Returns:
Processed data that will be sent to the formatter method.
Expand All @@ -296,14 +300,16 @@ def _run_data_processing(
def _format_data(
self,
curve_data: ScatterTable,
category: str = "formatted",
) -> ScatterTable:
"""Postprocessing for the processed dataset.
"""Postprocessing for preparing the fitting data.
Args:
curve_data: Processed dataset created from experiment results.
category: Category string of the output dataset.
Returns:
Formatted data.
New scatter table instance including fit data.
"""

@abstractmethod
Expand Down
60 changes: 27 additions & 33 deletions qiskit_experiments/curve_analysis/composite_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,32 +230,32 @@ def _create_figures(
A list of figures.
"""
for analysis in self.analyses():
sub_data = curve_data[curve_data.model_name.str.endswith(f"_{analysis.name}")]
for model_id, data in list(sub_data.groupby("model_id")):
model_name = analysis._models[model_id]._name
sub_data = curve_data[curve_data.group == analysis.name]
for name, data in list(sub_data.groupby("name")):
full_name = f"{name}_{analysis.name}"
# Plot raw data scatters
if analysis.options.plot_raw_data:
raw_data = data.filter(like="processed", axis="index")
raw_data = data[data.category == "raw"]
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x=raw_data.xval.to_numpy(),
y=raw_data.yval.to_numpy(),
)
# Plot formatted data scatters
formatted_data = data.filter(like="formatted", axis="index")
formatted_data = data[data.category == analysis.options.fit_category]
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x_formatted=formatted_data.xval.to_numpy(),
y_formatted=formatted_data.yval.to_numpy(),
y_formatted_err=formatted_data.yerr.to_numpy(),
)
# Plot fit lines
line_data = data.filter(like="fitted", axis="index")
line_data = data[data.category == "fitted"]
if len(line_data) == 0:
continue
fit_stdev = line_data.yerr.to_numpy()
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x_interp=line_data.xval.to_numpy(),
y_interp=line_data.yval.to_numpy(),
y_interp_err=fit_stdev if np.isfinite(fit_stdev).all() else None,
Expand Down Expand Up @@ -353,21 +353,16 @@ def _run_analysis(
metadata = analysis.options.extra.copy()
metadata["group"] = analysis.name

curve_data = analysis._format_data(
analysis._run_data_processing(experiment_data.data())
)
fit_data = analysis._run_curve_fit(curve_data.filter(like="formatted", axis="index"))
table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
formatted_subset = table[table.category == analysis.options.fit_category]
fit_data = analysis._run_curve_fit(formatted_subset)
fit_dataset[analysis.name] = fit_data

if fit_data.success:
quality = analysis._evaluate_quality(fit_data)
else:
quality = "bad"

# After the quality is determined, plot can become a boolean flag for whether
# to generate the figure
plot_bool = plot == "always" or (plot == "selective" and quality == "bad")

if self.options.return_fit_parameters:
# Store fit status overview entry regardless of success.
# This is sometime useful when debugging the fitting code.
Expand All @@ -382,10 +377,9 @@ def _run_analysis(
if fit_data.success:
# Add fit data to curve data table
fit_curves = []
formatted = curve_data.filter(like="formatted", axis="index")
columns = list(curve_data.columns)
for i, sub_data in list(formatted.groupby("model_id")):
name = analysis._models[i]._name
columns = list(table.columns)
model_names = analysis.model_names()
for i, sub_data in list(formatted_subset.groupby("class_id")):
xval = sub_data.xval.to_numpy()
if len(xval) == 0:
# If data is empty, skip drawing this model.
Expand All @@ -404,12 +398,10 @@ def _run_analysis(
model_fit[:, columns.index("yval")] = unp.nominal_values(yval_fit)
if fit_data.covar is not None:
model_fit[:, columns.index("yerr")] = unp.std_devs(yval_fit)
model_fit[:, columns.index("model_name")] = name
model_fit[:, columns.index("model_id")] = i
curve_data = curve_data.append_list_values(
other=np.vstack(fit_curves),
prefix="fitted",
)
model_fit[:, columns.index("name")] = model_names[i]
model_fit[:, columns.index("class_id")] = i
model_fit[:, columns.index("category")] = "fitted"
table = table.append_list_values(other=np.vstack(fit_curves))
analysis_results.extend(
analysis._create_analysis_results(
fit_data=fit_data,
Expand All @@ -421,18 +413,20 @@ def _run_analysis(
if self.options.return_data_points:
# Add raw data points
analysis_results.extend(
analysis._create_curve_data(
curve_data=curve_data.filter(like="formatted", axis="index"),
**metadata,
)
analysis._create_curve_data(curve_data=formatted_subset, **metadata)
)

curve_data.model_name += f"_{analysis.name}"
curve_data_set.append(curve_data)
# Add extra column to identify the fit model
table["group"] = analysis.name
curve_data_set.append(table)

combined_curve_data = pd.concat(curve_data_set)
total_quality = self._evaluate_quality(fit_dataset)

# After the quality is determined, plot can become a boolean flag for whether
# to generate the figure
plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")

# Create analysis results by combining all fit data
if all(fit_data.success for fit_data in fit_dataset.values()):
composite_results = self._create_analysis_results(
Expand Down
Loading

0 comments on commit 5bb1fb4

Please sign in to comment.