diff --git a/test/dynamics/common.py b/test/dynamics/common.py index ec8cd5115..11bdf26e7 100644 --- a/test/dynamics/common.py +++ b/test/dynamics/common.py @@ -54,7 +54,7 @@ from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS -from qiskit_dynamics.array import Array, wrap +from qiskit_dynamics.array import Array def _is_sparse_object_array(A): @@ -336,53 +336,3 @@ def setUpClass(cls): import qutip # pylint: disable=import-outside-toplevel,unused-import except Exception as err: raise unittest.SkipTest("Skipping qutip tests.") from err - - -# to be removed for 0.5.0 -class TestJaxBase(unittest.TestCase): - """Base class with setUpClass and tearDownClass for setting jax as the - default backend. - - Test cases that inherit from this class will automatically work with jax - backend. - """ - - @classmethod - def setUpClass(cls): - try: - # pylint: disable=import-outside-toplevel - import jax - - jax.config.update("jax_enable_x64", True) - jax.config.update("jax_platform_name", "cpu") - except Exception as err: - raise unittest.SkipTest("Skipping jax tests.") from err - - Array.set_default_backend("jax") - - @classmethod - def tearDownClass(cls): - """Set numpy back to the default backend.""" - Array.set_default_backend("numpy") - - def jit_wrap(self, func_to_test: Callable) -> Callable: - """Wraps and jits func_to_test. - Args: - func_to_test: The function to be jited. - Returns: - Wrapped and jitted function.""" - wf = wrap(jit, decorator=True) - return wf(wrap(func_to_test)) - - def jit_grad_wrap(self, func_to_test: Callable) -> Callable: - """Tests whether a function can be graded. Converts - all functions to scalar, real functions if they are not - already. - Args: - func_to_test: The function whose gradient will be graded. - Returns: - JIT-compiled gradient of function. - """ - return wrap(lambda f: jit(grad(f)), decorator=True)( - lambda *args: np.sum(func_to_test(*args)).real - ) diff --git a/test/dynamics/signals/test_signals_algebra.py b/test/dynamics/signals/test_signals_algebra.py index 434046756..0bb98f2b5 100644 --- a/test/dynamics/signals/test_signals_algebra.py +++ b/test/dynamics/signals/test_signals_algebra.py @@ -9,7 +9,7 @@ # Any modifications or derivative works of this code must retain this # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,no-member """ Tests for algebraic operations on signals. @@ -19,7 +19,7 @@ from qiskit_dynamics.signals import Signal, DiscreteSignal, SignalSum, DiscreteSignalSum -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import test_array_backends, JAXTestBase try: from jax import jit, grad @@ -28,7 +28,8 @@ pass -class TestSignalAddition(QiskitDynamicsTestCase): +@test_array_backends +class TestSignalAddition: """Testing special handling of signal addition.""" def test_SignalSum_construction(self): @@ -81,7 +82,8 @@ def test_scalar_addition(self): self.assertAllClose(sig_sum.envelope(1.5), np.array([3.0, -1.0])) -class TestSignalMultiplication(QiskitDynamicsTestCase): +@test_array_backends +class TestSignalMultiplication: """Test special handling of signal multiplication.""" def test_DiscreteSignal_products(self): @@ -198,15 +200,7 @@ def test_signal_signal_product(self): self.assertAllClose(sig_prod(t_vals), expected) -class TestSignalAdditionJax(TestSignalAddition, TestJaxBase): - """Jax version of TestSignalAddition.""" - - -class TestSignalMultiplicationJax(TestSignalMultiplication, TestJaxBase): - """Jax version of TestSignalMultiplication.""" - - -class TestSignalAlgebraJaxTransformations(QiskitDynamicsTestCase, TestJaxBase): +class TestSignalAlgebraJaxTransformations(JAXTestBase): """Test cases for jax transformations through signal algebraic operations.""" def setUp(self): diff --git a/test/dynamics/solvers/test_dyson_magnus_solvers.py b/test/dynamics/solvers/test_dyson_magnus_solvers.py index f1b5b780d..dec10154c 100644 --- a/test/dynamics/solvers/test_dyson_magnus_solvers.py +++ b/test/dynamics/solvers/test_dyson_magnus_solvers.py @@ -354,7 +354,7 @@ def setUp(self): @classmethod def setUpClass(cls): - # calls TestJaxBase setUpClass + # calls JAXTestBase setUpClass super().setUpClass() # builds common objects Test_PerturbativeSolver.build_testing_objects(cls, integration_method="jax_odeint") diff --git a/test/dynamics/test_jax_transformations.py b/test/dynamics/test_jax_transformations.py index 7c3f03a34..feb59eefa 100644 --- a/test/dynamics/test_jax_transformations.py +++ b/test/dynamics/test_jax_transformations.py @@ -22,7 +22,7 @@ from qiskit_dynamics.signals import Signal from qiskit_dynamics import solve_lmde -from .common import TestJaxBase +from .common import JAXTestBase try: from jax import jit, grad @@ -32,7 +32,7 @@ pass -class TestJaxTransformations(TestJaxBase): +class TestJaxTransformations(JAXTestBase): """Class for testing jax transformations of integrated use cases.""" def setUp(self):