Upgrade unittest equality method (#1132)
### Summary

The current equality check, `QiskitExperimentsTestCase.json_equiv`, is hard
to read and does not scale because it implements the comparison logic for
every supported type in a single method. This PR adds a new test module,
`test/extended_equality.py`, which implements a new equality check
dispatcher, `is_equivalent`.

Developers no longer need to specify `check_func` in the
`assertRoundTripSerializable` and `assertRoundTripPickle` methods unless
they define a custom class for a specific unittest. This simplifies the
unittests and makes the equality check logic more readable (and the tests
more trustworthy).

This PR adds a new development dependency:
[multimethod](https://pypi.org/project/multimethod/)

Among several similar packages, it was chosen for
- its license type (Apache License, Version 2.0)
- syntax compatibility with `functools.singledispatch`
- support for subscripted generics in `typing`, e.g. `Union`
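
The dispatch idea behind `is_equivalent` can be sketched with the standard library alone. The block below is only an illustration and not the contents of `test/extended_equality.py`: it uses `functools.singledispatch`, which dispatches on the first argument only, in place of multimethod. The name `is_equivalent_sketch` and the set of handled types are assumptions made for this example.

```python
import math
from functools import singledispatch
from typing import Any


@singledispatch
def is_equivalent_sketch(first: Any, second: Any) -> bool:
    """Fallback: defer to the objects' own ``__eq__``."""
    return first == second


@is_equivalent_sketch.register
def _(first: float, second: Any) -> bool:
    # Floats are compared with a tolerance instead of exact equality.
    return isinstance(second, (int, float)) and math.isclose(first, second)


@is_equivalent_sketch.register
def _(first: dict, second: Any) -> bool:
    # Dictionaries must share the same keys; values are compared recursively.
    if not isinstance(second, dict) or set(first) != set(second):
        return False
    return all(is_equivalent_sketch(first[k], second[k]) for k in first)


@is_equivalent_sketch.register(list)
@is_equivalent_sketch.register(tuple)
def _(first, second: Any) -> bool:
    # Lists and tuples are treated alike and compared element by element.
    if not isinstance(second, (list, tuple)) or len(first) != len(second):
        return False
    return all(is_equivalent_sketch(a, b) for a, b in zip(first, second))
```

Each type gets its own small handler, so supporting a new type means registering one more function rather than extending a long `if`/`elif` chain.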

---------

Co-authored-by: Helena Zhang <[email protected]>
nkanazawa1989 and coruscating authored May 11, 2023
1 parent 4038556 commit 2278679
Showing 32 changed files with 564 additions and 257 deletions.
27 changes: 27 additions & 0 deletions releasenotes/notes/add-test-equality-checker-dbe5762d2b6a967f.yaml
@@ -0,0 +1,27 @@
---
developer:
- |
Added the :meth:`QiskitExperimentsTestCase.assertEqualExtended` method for generic equality checks
of Qiskit Experiments class instances in unittests. This is a drop-in replacement for
calling ``assertTrue`` with :meth:`QiskitExperimentsTestCase.json_equiv`.
Note that some Qiskit Experiments classes may not officially implement equality check logic,
although their instances may be compared in unittests. The extended equality check is used
in such situations.
- |
The following unittest test case methods will be deprecated:
* :meth:`QiskitExperimentsTestCase.json_equiv`
* :meth:`QiskitExperimentsTestCase.ufloat_equiv`
* :meth:`QiskitExperimentsTestCase.analysis_result_equiv`
* :meth:`QiskitExperimentsTestCase.curve_fit_data_equiv`
* :meth:`QiskitExperimentsTestCase.experiment_data_equiv`
One can now use the :func:`~test.extended_equality.is_equivalent` function instead.
This function internally dispatches the logic for equality check.
- |
The default behavior of :meth:`QiskitExperimentsTestCase.assertRoundTripSerializable` and
:meth:`QiskitExperimentsTestCase.assertRoundTripPickle` when `check_func` is not
provided was upgraded. These methods now compare the decoded instance with
:func:`~test.extended_equality.is_equivalent`, rather than
delegating to the native `assertEqual` unittest method.
Developers writing a unittest for serialization no longer need to set a checker function explicitly.
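
The change of default described in the last note can be illustrated without Qiskit Experiments. The sketch below rests on stated assumptions: it uses the plain `json` module rather than `ExperimentEncoder` and `ExperimentDecoder`, and `loosely_equal`, `RoundTripSketch`, and `assert_round_trip` are hypothetical stand-ins for `is_equivalent` and the real test case methods. With plain `json`, a tuple is decoded as a list, so a strict `assertEqual` would fail where a more tolerant checker passes.

```python
import json
import unittest
from typing import Any, Callable, Optional


def loosely_equal(first: Any, second: Any, *, strict_type: bool = False) -> bool:
    """Hypothetical stand-in for ``is_equivalent``; treats tuples and lists alike."""
    if strict_type and type(first) is not type(second):
        return False
    if isinstance(first, (list, tuple)) and isinstance(second, (list, tuple)):
        return len(first) == len(second) and all(
            loosely_equal(a, b, strict_type=strict_type) for a, b in zip(first, second)
        )
    return first == second


class RoundTripSketch(unittest.TestCase):
    """Hypothetical test case mirroring the keyword-only signature in this commit."""

    def assert_round_trip(
        self,
        obj: Any,
        *,
        check_func: Optional[Callable] = None,
        strict_type: bool = False,
    ):
        decoded = json.loads(json.dumps(obj))
        if check_func is not None:
            # A user-supplied checker still takes precedence.
            self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
        else:
            # Default path: use the tolerant checker instead of assertEqual.
            self.assertTrue(
                loosely_equal(obj, decoded, strict_type=strict_type),
                msg=f"{obj} != {decoded}",
            )

    def test_tuple_round_trip(self):
        # No check_func is needed even though the tuple comes back as a list.
        self.assert_round_trip((1, 2, 3))
```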
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -16,6 +16,7 @@ qiskit-aer>=0.11.0
pandas>=1.1.5
cvxpy>=1.1.15
pylatexenc
multimethod
scikit-learn
sphinx-copybutton
# Pin versions below because of build errors
244 changes: 98 additions & 146 deletions test/base.py
@@ -13,27 +13,22 @@
Qiskit Experiments test case class
"""

import dataclasses
import json
import pickle
import warnings
from typing import Any, Callable, Optional

import numpy as np
import uncertainties
from lmfit import Model
from qiskit.test import QiskitTestCase
from qiskit_experiments.data_processing import DataAction, DataProcessor
from qiskit_experiments.framework.experiment_data import ExperimentStatus
from qiskit.utils.deprecation import deprecate_func

from qiskit_experiments.framework import (
ExperimentDecoder,
ExperimentEncoder,
ExperimentData,
BaseExperiment,
BaseAnalysis,
)
from qiskit_experiments.visualization import BaseDrawer
from qiskit_experiments.curve_analysis.curve_data import CurveFitResult
from qiskit_experiments.framework.experiment_data import ExperimentStatus
from .extended_equality import is_equivalent


class QiskitExperimentsTestCase(QiskitTestCase):
@@ -76,15 +71,52 @@ def assertExperimentDone(
msg="All threads are executed but status is not DONE. " + experiment_data.errors(),
)

def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] = None):
def assertEqualExtended(
self,
first: Any,
second: Any,
*,
msg: Optional[str] = None,
strict_type: bool = False,
):
"""Extended equality assertion which covers Qiskit Experiments classes.
.. note::
Some Qiskit Experiments classes may intentionally avoid implementing
the equality dunder method, or may be used in some unusual situations.
These cases mainly arise from JSON round trips, since some custom classes
do not guarantee object equality after a round trip.
This assertion function forcibly compares the two input objects with
the custom equality checker, which is implemented for unittest purposes.
Args:
first: First object to compare.
second: Second object to compare.
msg: Optional. Custom error message issued when first and second object are not equal.
strict_type: Set True to enforce type check before comparison.
"""
default_msg = f"{first} != {second}"

self.assertTrue(
is_equivalent(first, second, strict_type=strict_type),
msg=msg or default_msg,
)

def assertRoundTripSerializable(
self,
obj: Any,
*,
check_func: Optional[Callable] = None,
strict_type: bool = False,
):
"""Assert that an object is round trip serializable.
Args:
obj: the object to be serialized.
check_func: Optional, a custom function ``check_func(a, b) -> bool``
to check equality of the original object with the decoded
object. If None the ``__eq__`` method of the original
object will be used.
to check equality of the original object with the decoded
object. If None :meth:`.assertEqualExtended` is called.
strict_type: Set True to enforce type check before comparison.
"""
try:
encoded = json.dumps(obj, cls=ExperimentEncoder)
@@ -94,20 +126,27 @@ def assertRoundTripSerializable(self, obj: Any, check_func: Optional[Callable] =
decoded = json.loads(encoded, cls=ExperimentDecoder)
except TypeError:
self.fail("JSON deserialization raised unexpectedly.")
if check_func is None:
self.assertEqual(obj, decoded)
else:

if check_func is not None:
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
else:
self.assertEqualExtended(obj, decoded, strict_type=strict_type)

def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None):
def assertRoundTripPickle(
self,
obj: Any,
*,
check_func: Optional[Callable] = None,
strict_type: bool = False,
):
"""Assert that an object is round trip serializable using pickle module.
Args:
obj: the object to be serialized.
check_func: Optional, a custom function ``check_func(a, b) -> bool``
to check equality of the original object with the decoded
object. If None the ``__eq__`` method of the original
object will be used.
to check equality of the original object with the decoded
object. If None :meth:`.assertEqualExtended` is called.
strict_type: Set True to enforce type check before comparison.
"""
try:
encoded = pickle.dumps(obj)
@@ -117,150 +156,63 @@ def assertRoundTripPickle(self, obj: Any, check_func: Optional[Callable] = None)
decoded = pickle.loads(encoded)
except TypeError:
self.fail("pickle deserialization raised unexpectedly.")
if check_func is None:
self.assertEqual(obj, decoded)
else:

if check_func is not None:
self.assertTrue(check_func(obj, decoded), msg=f"{obj} != {decoded}")
else:
self.assertEqualExtended(obj, decoded, strict_type=strict_type)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def json_equiv(cls, data1, data2) -> bool:
"""Check if two experiments are equivalent by comparing their configs"""
# pylint: disable = too-many-return-statements
configurable_type = (BaseExperiment, BaseAnalysis, BaseDrawer)
compare_repr = (DataAction, DataProcessor)
list_type = (list, tuple, set)
skipped = tuple()

if isinstance(data1, skipped) and isinstance(data2, skipped):
warnings.warn(f"Equivalence check for data {data1.__class__.__name__} is skipped.")
return True
elif isinstance(data1, configurable_type) and isinstance(data2, configurable_type):
return cls.json_equiv(data1.config(), data2.config())
elif dataclasses.is_dataclass(data1) and dataclasses.is_dataclass(data2):
# not using asdict. this copies all objects.
return cls.json_equiv(data1.__dict__, data2.__dict__)
elif isinstance(data1, dict) and isinstance(data2, dict):
if set(data1) != set(data2):
return False
return all(cls.json_equiv(data1[k], data2[k]) for k in data1.keys())
elif isinstance(data1, np.ndarray) or isinstance(data2, np.ndarray):
return np.allclose(data1, data2)
elif isinstance(data1, list_type) and isinstance(data2, list_type):
return all(cls.json_equiv(e1, e2) for e1, e2 in zip(data1, data2))
elif isinstance(data1, uncertainties.UFloat) and isinstance(data2, uncertainties.UFloat):
return cls.ufloat_equiv(data1, data2)
elif isinstance(data1, Model) and isinstance(data2, Model):
return cls.json_equiv(data1.dumps(), data2.dumps())
elif isinstance(data1, CurveFitResult) and isinstance(data2, CurveFitResult):
return cls.curve_fit_data_equiv(data1, data2)
elif isinstance(data1, compare_repr) and isinstance(data2, compare_repr):
# otherwise compare instance representation
return repr(data1) == repr(data2)

return data1 == data2
return is_equivalent(data1, data2)

@staticmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def ufloat_equiv(data1: uncertainties.UFloat, data2: uncertainties.UFloat) -> bool:
"""Check if two values with uncertainties are equal. No correlation is considered."""
return data1.n == data2.n and data1.s == data2.s
return is_equivalent(data1, data2)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def analysis_result_equiv(cls, result1, result2):
"""Test two analysis results are equivalent"""
# Check basic attributes skipping service which is not serializable
for att in [
"name",
"value",
"extra",
"device_components",
"result_id",
"experiment_id",
"chisq",
"quality",
"verified",
"tags",
"auto_save",
"source",
]:
if not cls.json_equiv(getattr(result1, att), getattr(result2, att)):
return False
return True
return is_equivalent(result1, result2)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def curve_fit_data_equiv(cls, data1, data2):
"""Test two curve fit result are equivalent."""
for att in [
"method",
"model_repr",
"success",
"nfev",
"message",
"dof",
"init_params",
"chisq",
"reduced_chisq",
"aic",
"bic",
"params",
"var_names",
"x_data",
"y_data",
"covar",
]:
if not cls.json_equiv(getattr(data1, att), getattr(data2, att)):
return False
return True
return is_equivalent(data1, data2)

@classmethod
@deprecate_func(
since="0.6",
additional_msg="Use test.extended_equality.is_equivalent instead.",
pending=True,
package_name="qiskit-experiments",
)
def experiment_data_equiv(cls, data1, data2):
"""Check two experiment data containers are equivalent"""

# Check basic attributes
# Skip non-compatible backend
for att in [
"experiment_id",
"experiment_type",
"parent_id",
"tags",
"job_ids",
"figure_names",
"share_level",
"metadata",
]:
if not cls.json_equiv(getattr(data1, att), getattr(data2, att)):
return False

# Check length of data, results, child_data
# check for child data attribute so this method still works for
# DbExperimentData
if hasattr(data1, "child_data"):
child_data1 = data1.child_data()
else:
child_data1 = []
if hasattr(data2, "child_data"):
child_data2 = data2.child_data()
else:
child_data2 = []

if (
len(data1.data()) != len(data2.data())
or len(data1.analysis_results()) != len(data2.analysis_results())
or len(child_data1) != len(child_data2)
):
return False

# Check data
if not cls.json_equiv(data1.data(), data2.data()):
return False

# Check analysis results
for result1, result2 in zip(data1.analysis_results(), data2.analysis_results()):
if not cls.analysis_result_equiv(result1, result2):
return False

# Check child data
for child1, child2 in zip(child_data1, child_data2):
if not cls.experiment_data_equiv(child1, child2):
return False

return True
return is_equivalent(data1, data2)
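
The pattern applied to the old helpers in `test/base.py` (keep the public name, warn, and delegate to the new function) can be sketched with the standard library. This is an assumption-laden illustration, not the implementation of Qiskit's `deprecate_func`: the decorator `pending_deprecation` and the two stub functions are hypothetical, and the use of `PendingDeprecationWarning` is assumed from the `pending=True` argument in the diff.

```python
import functools
import warnings


def pending_deprecation(replacement: str):
    """Hypothetical decorator; not the Qiskit ``deprecate_func`` implementation."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Warn on every call, then run the wrapped function unchanged.
            warnings.warn(
                f"{func.__name__} is pending deprecation. Use {replacement} instead.",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


def is_equivalent_stub(first, second) -> bool:
    """Hypothetical stand-in for the new dispatcher."""
    return first == second


@pending_deprecation("is_equivalent_stub")
def json_equiv_stub(data1, data2) -> bool:
    """Old entry point kept as a thin wrapper around the new function."""
    return is_equivalent_stub(data1, data2)
```

Existing callers keep working because the old name still returns the same result, while the warning points them to the replacement.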
2 changes: 1 addition & 1 deletion test/calibration/test_calibrations.py
@@ -1754,7 +1754,7 @@ def test_serialization(self):
cals = Calibrations.from_backend(backend, libraries=[library])
cals.add_parameter_value(0.12345, "amp", 3, "x")

self.assertRoundTripSerializable(cals, self.json_equiv)
self.assertRoundTripSerializable(cals)

def test_equality(self):
"""Test the equal method on calibrations."""
2 changes: 1 addition & 1 deletion test/curve_analysis/test_baseclass.py
@@ -85,7 +85,7 @@ class TestCurveAnalysis(CurveAnalysisTestCase):
def test_roundtrip_serialize(self):
"""A testcase for serializing analysis instance."""
analysis = CurveAnalysis(models=[ExpressionModel(expr="par0 * x + par1", name="test")])
self.assertRoundTripSerializable(analysis, check_func=self.json_equiv)
self.assertRoundTripSerializable(analysis)

def test_parameters(self):
"""A testcase for getting fit parameters with attribute."""
6 changes: 3 additions & 3 deletions test/data_processing/test_data_processing.py
@@ -387,14 +387,14 @@ def test_json_single_node(self):
"""Check if the data processor is serializable."""
node = MinMaxNormalize()
processor = DataProcessor("counts", [node])
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
self.assertRoundTripSerializable(processor)

def test_json_multi_node(self):
"""Check if the data processor with multiple nodes is serializable."""
node1 = MinMaxNormalize()
node2 = AverageData(axis=2)
processor = DataProcessor("counts", [node1, node2])
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
self.assertRoundTripSerializable(processor)

def test_json_trained(self):
"""Check if trained data processor is serializable and still work."""
@@ -405,7 +405,7 @@ def test_json_trained(self):
main_axes=np.array([[1, 0]]), scales=[1.0], i_means=[0.0], q_means=[0.0]
)
processor = DataProcessor("memory", data_actions=[node])
self.assertRoundTripSerializable(processor, check_func=self.json_equiv)
self.assertRoundTripSerializable(processor)

serialized = json.dumps(processor, cls=ExperimentEncoder)
loaded_processor = json.loads(serialized, cls=ExperimentDecoder)