Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
- Remove context of fit model from data name and index; model_name -> name, model_id -> class_id
- Remove extra metadata from the scatter table
  • Loading branch information
nkanazawa1989 committed Nov 14, 2023
1 parent 2307d0e commit 7ecee38
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 158 deletions.
10 changes: 8 additions & 2 deletions qiskit_experiments/curve_analysis/base_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _default_options(cls) -> Options:
lmfit_options (Dict[str, Any]): Options that are passed to the
LMFIT minimizer. Acceptable options depend on fit_method.
x_key (str): Circuit metadata key representing a scanned value.
fit_category (str): Name of dataset in the scatter table to fit.
result_parameters (List[Union[str, ParameterRepr]]): Parameters reported in the
database as a dedicated entry. This is a list of parameter representation
which is either string or ParameterRepr object. If you provide more
Expand Down Expand Up @@ -219,6 +220,7 @@ def _default_options(cls) -> Options:
options.normalization = False
options.average_method = "shots_weighted"
options.x_key = "xval"
options.fit_category = "formatted"
options.result_parameters = []
options.extra = {}
options.fit_method = "least_squares"
Expand Down Expand Up @@ -282,11 +284,13 @@ def set_options(self, **fields):
def _run_data_processing(
self,
raw_data: List[Dict],
category: str = "raw",
) -> ScatterTable:
"""Perform data processing from the experiment result payload.
Args:
raw_data: Payload in the experiment data.
category: Category string of the output dataset.
Returns:
Processed data that will be sent to the formatter method.
Expand All @@ -296,14 +300,16 @@ def _run_data_processing(
def _format_data(
self,
curve_data: ScatterTable,
category: str = "formatted",
) -> ScatterTable:
"""Postprocessing for the processed dataset.
"""Postprocessing for preparing the fitting data.
Args:
curve_data: Processed dataset created from experiment results.
category: Category string of the output dataset.
Returns:
Formatted data.
New scatter table instance including fit data.
"""

@abstractmethod
Expand Down
60 changes: 27 additions & 33 deletions qiskit_experiments/curve_analysis/composite_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,32 +230,32 @@ def _create_figures(
A list of figures.
"""
for analysis in self.analyses():
sub_data = curve_data[curve_data.model_name.str.endswith(f"_{analysis.name}")]
for model_id, data in list(sub_data.groupby("model_id")):
model_name = analysis._models[model_id]._name
sub_data = curve_data[curve_data.group == analysis.name]
for name, data in list(sub_data.groupby("name")):
full_name = f"{name}_{analysis.name}"
# Plot raw data scatters
if analysis.options.plot_raw_data:
raw_data = data.filter(like="processed", axis="index")
raw_data = data[data.category == "raw"]
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x=raw_data.xval.to_numpy(),
y=raw_data.yval.to_numpy(),
)
# Plot formatted data scatters
formatted_data = data.filter(like="formatted", axis="index")
formatted_data = data[data.category == analysis.options.fit_category]
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x_formatted=formatted_data.xval.to_numpy(),
y_formatted=formatted_data.yval.to_numpy(),
y_formatted_err=formatted_data.yerr.to_numpy(),
)
# Plot fit lines
line_data = data.filter(like="fitted", axis="index")
line_data = data[data.category == "fitted"]
if len(line_data) == 0:
continue
fit_stdev = line_data.yerr.to_numpy()
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x_interp=line_data.xval.to_numpy(),
y_interp=line_data.yval.to_numpy(),
y_interp_err=fit_stdev if np.isfinite(fit_stdev).all() else None,
Expand Down Expand Up @@ -353,21 +353,16 @@ def _run_analysis(
metadata = analysis.options.extra.copy()
metadata["group"] = analysis.name

curve_data = analysis._format_data(
analysis._run_data_processing(experiment_data.data())
)
fit_data = analysis._run_curve_fit(curve_data.filter(like="formatted", axis="index"))
table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
formatted_subset = table[table.category == analysis.options.fit_category]
fit_data = analysis._run_curve_fit(formatted_subset)
fit_dataset[analysis.name] = fit_data

if fit_data.success:
quality = analysis._evaluate_quality(fit_data)
else:
quality = "bad"

# After the quality is determined, plot can become a boolean flag for whether
# to generate the figure
plot_bool = plot == "always" or (plot == "selective" and quality == "bad")

if self.options.return_fit_parameters:
# Store fit status overview entry regardless of success.
# This is sometimes useful when debugging the fitting code.
Expand All @@ -382,10 +377,9 @@ def _run_analysis(
if fit_data.success:
# Add fit data to curve data table
fit_curves = []
formatted = curve_data.filter(like="formatted", axis="index")
columns = list(curve_data.columns)
for i, sub_data in list(formatted.groupby("model_id")):
name = analysis._models[i]._name
columns = list(table.columns)
model_names = analysis.model_names()
for i, sub_data in list(formatted_subset.groupby("class_id")):
xval = sub_data.xval.to_numpy()
if len(xval) == 0:
# If data is empty, skip drawing this model.
Expand All @@ -404,12 +398,10 @@ def _run_analysis(
model_fit[:, columns.index("yval")] = unp.nominal_values(yval_fit)
if fit_data.covar is not None:
model_fit[:, columns.index("yerr")] = unp.std_devs(yval_fit)
model_fit[:, columns.index("model_name")] = name
model_fit[:, columns.index("model_id")] = i
curve_data = curve_data.append_list_values(
other=np.vstack(fit_curves),
prefix="fitted",
)
model_fit[:, columns.index("name")] = model_names[i]
model_fit[:, columns.index("class_id")] = i
model_fit[:, columns.index("category")] = "fitted"
table = table.append_list_values(other=np.vstack(fit_curves))
analysis_results.extend(
analysis._create_analysis_results(
fit_data=fit_data,
Expand All @@ -421,18 +413,20 @@ def _run_analysis(
if self.options.return_data_points:
# Add raw data points
analysis_results.extend(
analysis._create_curve_data(
curve_data=curve_data.filter(like="formatted", axis="index"),
**metadata,
)
analysis._create_curve_data(curve_data=formatted_subset, **metadata)
)

curve_data.model_name += f"_{analysis.name}"
curve_data_set.append(curve_data)
# Add extra column to identify the fit model
table["group"] = analysis.name
curve_data_set.append(table)

combined_curve_data = pd.concat(curve_data_set)
total_quality = self._evaluate_quality(fit_dataset)

# After the quality is determined, plot can become a boolean flag for whether
# to generate the figure
plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")

# Create analysis results by combining all fit data
if all(fit_data.success for fit_data in fit_dataset.values()):
composite_results = self._create_analysis_results(
Expand Down
Loading

0 comments on commit 7ecee38

Please sign in to comment.