Skip to content

Commit

Permalink
Raise user warning when numbers contain multiple series
Browse files Browse the repository at this point in the history
  • Loading branch information
nkanazawa1989 committed Feb 2, 2024
1 parent 8dc6c4f commit 144127a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
46 changes: 43 additions & 3 deletions qiskit_experiments/curve_analysis/scatter_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Iterator
from typing import Any
from itertools import groupby
Expand Down Expand Up @@ -116,6 +117,7 @@ def get_x(
kind: int | str | None = None,
category: str | None = None,
analysis: str | None = None,
check_unique: bool = True,
) -> np.ndarray:
"""Get subset of X values.
Expand All @@ -125,11 +127,16 @@ def get_x(
kind: Identifier of the data, either data UID or name.
category: Name of data category.
analysis: Name of analysis.
check_unique: Set True to check if multiple series are contained.
When multiple series are contained, it raises a user warning.
Returns:
Numpy array of X values.
"""
return self.filter(kind, category, analysis).x
sub_table = self.filter(kind, category, analysis)
if check_unique:
self._warn_composite_data(sub_table)
return sub_table.x

@property
def y(self) -> np.ndarray:
Expand All @@ -145,6 +152,7 @@ def get_y(
kind: int | str | None = None,
category: str | None = None,
analysis: str | None = None,
check_unique: bool = True,
) -> np.ndarray:
"""Get subset of Y values.
Expand All @@ -154,11 +162,16 @@ def get_y(
kind: Identifier of the data, either data UID or name.
category: Name of data category.
analysis: Name of analysis.
check_unique: Set True to check if multiple series are contained.
When multiple series are contained, it raises a user warning.
Returns:
Numpy array of Y values.
"""
return self.filter(kind, category, analysis).y
sub_table = self.filter(kind, category, analysis)
if check_unique:
self._warn_composite_data(sub_table)
return sub_table.y

@property
def y_err(self) -> np.ndarray:
Expand All @@ -174,6 +187,7 @@ def get_y_err(
kind: int | str | None = None,
category: str | None = None,
analysis: str | None = None,
check_unique: bool = True,
) -> np.ndarray:
"""Get subset of standard deviation of Y values.
Expand All @@ -183,11 +197,16 @@ def get_y_err(
kind: Identifier of the data, either data UID or name.
category: Name of data category.
analysis: Name of analysis.
check_unique: Set True to check if multiple series are contained.
When multiple series are contained, it raises a user warning.
Returns:
Numpy array of Y error values.
"""
return self.filter(kind, category, analysis).y_err
sub_table = self.filter(kind, category, analysis)
if check_unique:
self._warn_composite_data(sub_table)
return sub_table.y_err

@property
def name(self) -> np.ndarray:
Expand Down Expand Up @@ -339,6 +358,27 @@ def _format_table(cls, data: pd.DataFrame) -> pd.DataFrame:
.reset_index(drop=True)
)

@staticmethod
def _warn_composite_data(table: ScatterTable):
if len(table._data.name.unique()) > 1:
warnings.warn(
"Returned data contains multiple series. "
"You may want to filter the data by a specific kind identifier.",
UserWarning,
)
if len(table._data.category.unique()) > 1:
warnings.warn(
"Returned data contains multiple categories. "
"You may want to filter the data by a specific category name.",
UserWarning,
)
if len(table._data.analysis.unique()) > 1:
warnings.warn(
"Returned data contains multiple datasets from different component analyses. "
"You may want to filter the data by a specific analysis name.",
UserWarning,
)

@property
@deprecate_func(
since="0.6",
Expand Down
22 changes: 14 additions & 8 deletions test/curve_analysis/test_scatter_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"""Test scatter table."""

from test.base import QiskitExperimentsTestCase

import pandas as pd
import numpy as np

Expand Down Expand Up @@ -152,16 +151,23 @@ def test_get_subset_numbers(self):
"""Test end-user shortcut for getting the subset of x, y, y_err data."""
obj = ScatterTable.from_dataframe(self.reference)

np.testing.assert_array_equal(obj.get_x("model1", "raw", "Fit1"), np.array([0.100, 0.200]))
np.testing.assert_array_equal(obj.get_y("model1", "raw", "Fit1"), np.array([0.192, 0.854]))
np.testing.assert_array_equal(
obj.get_x("model1", "raw"), np.array([0.100, 0.200, 0.100, 0.200])
)
np.testing.assert_array_equal(
obj.get_y("model1", "raw"), np.array([0.192, 0.854, 0.567, 0.488])
)
np.testing.assert_array_equal(
obj.get_y_err("model1", "raw"), np.array([0.002, 0.090, 0.033, 0.038])
obj.get_y_err("model1", "raw", "Fit1"), np.array([0.002, 0.090])
)

def test_warn_composite_values(self):
"""Test raise warning when returned x, y, y_err data contains multiple data series."""
obj = ScatterTable.from_dataframe(self.reference)

with self.assertWarns(UserWarning):
obj.get_x()
with self.assertWarns(UserWarning):
obj.get_y()
with self.assertWarns(UserWarning):
obj.get_y_err()

def test_filter_data_by_class_id(self):
"""Test filter table data with data UID."""
obj = ScatterTable.from_dataframe(self.reference)
Expand Down

0 comments on commit 144127a

Please sign in to comment.