From 9b0f3e1c6bfd8184bb8dba9ebabafdb971082cf2 Mon Sep 17 00:00:00 2001 From: DanPuzzuoli Date: Fri, 23 Feb 2024 10:39:25 -0800 Subject: [PATCH] removing old test class TestJaxBase --- test/dynamics/common.py | 50 ------------------- test/dynamics/signals/test_signals_algebra.py | 18 +++---- .../solvers/test_dyson_magnus_solvers.py | 2 +- test/dynamics/test_jax_transformations.py | 4 +- 4 files changed, 9 insertions(+), 65 deletions(-) diff --git a/test/dynamics/common.py b/test/dynamics/common.py index ec8cd5115..18e70d305 100644 --- a/test/dynamics/common.py +++ b/test/dynamics/common.py @@ -336,53 +336,3 @@ def setUpClass(cls): import qutip # pylint: disable=import-outside-toplevel,unused-import except Exception as err: raise unittest.SkipTest("Skipping qutip tests.") from err - - -# to be removed for 0.5.0 -class TestJaxBase(unittest.TestCase): - """Base class with setUpClass and tearDownClass for setting jax as the - default backend. - - Test cases that inherit from this class will automatically work with jax - backend. - """ - - @classmethod - def setUpClass(cls): - try: - # pylint: disable=import-outside-toplevel - import jax - - jax.config.update("jax_enable_x64", True) - jax.config.update("jax_platform_name", "cpu") - except Exception as err: - raise unittest.SkipTest("Skipping jax tests.") from err - - Array.set_default_backend("jax") - - @classmethod - def tearDownClass(cls): - """Set numpy back to the default backend.""" - Array.set_default_backend("numpy") - - def jit_wrap(self, func_to_test: Callable) -> Callable: - """Wraps and jits func_to_test. - Args: - func_to_test: The function to be jited. - Returns: - Wrapped and jitted function.""" - wf = wrap(jit, decorator=True) - return wf(wrap(func_to_test)) - - def jit_grad_wrap(self, func_to_test: Callable) -> Callable: - """Tests whether a function can be graded. Converts - all functions to scalar, real functions if they are not - already. - Args: - func_to_test: The function whose gradient will be graded. - Returns: - JIT-compiled gradient of function. - """ - return wrap(lambda f: jit(grad(f)), decorator=True)( - lambda *args: np.sum(func_to_test(*args)).real - ) diff --git a/test/dynamics/signals/test_signals_algebra.py b/test/dynamics/signals/test_signals_algebra.py index 434046756..f10059961 100644 --- a/test/dynamics/signals/test_signals_algebra.py +++ b/test/dynamics/signals/test_signals_algebra.py @@ -19,7 +19,7 @@ from qiskit_dynamics.signals import Signal, DiscreteSignal, SignalSum, DiscreteSignalSum -from ..common import QiskitDynamicsTestCase, TestJaxBase +from ..common import test_array_backends, JAXTestBase try: from jax import jit, grad @@ -28,7 +28,8 @@ pass -class TestSignalAddition(QiskitDynamicsTestCase): +@test_array_backends +class TestSignalAddition: """Testing special handling of signal addition.""" def test_SignalSum_construction(self): @@ -81,7 +82,8 @@ def test_scalar_addition(self): self.assertAllClose(sig_sum.envelope(1.5), np.array([3.0, -1.0])) -class TestSignalMultiplication(QiskitDynamicsTestCase): +@test_array_backends +class TestSignalMultiplication: """Test special handling of signal multiplication.""" def test_DiscreteSignal_products(self): @@ -198,15 +200,7 @@ def test_signal_signal_product(self): self.assertAllClose(sig_prod(t_vals), expected) -class TestSignalAdditionJax(TestSignalAddition, TestJaxBase): - """Jax version of TestSignalAddition.""" - - -class TestSignalMultiplicationJax(TestSignalMultiplication, TestJaxBase): - """Jax version of TestSignalMultiplication.""" - - -class TestSignalAlgebraJaxTransformations(QiskitDynamicsTestCase, TestJaxBase): +class TestSignalAlgebraJaxTransformations(JAXTestBase): """Test cases for jax transformations through signal algebraic operations.""" def setUp(self): diff --git a/test/dynamics/solvers/test_dyson_magnus_solvers.py b/test/dynamics/solvers/test_dyson_magnus_solvers.py index f1b5b780d..dec10154c 100644 --- a/test/dynamics/solvers/test_dyson_magnus_solvers.py +++ b/test/dynamics/solvers/test_dyson_magnus_solvers.py @@ -354,7 +354,7 @@ def setUp(self): @classmethod def setUpClass(cls): - # calls TestJaxBase setUpClass + # calls JAXTestBase setUpClass super().setUpClass() # builds common objects Test_PerturbativeSolver.build_testing_objects(cls, integration_method="jax_odeint") diff --git a/test/dynamics/test_jax_transformations.py b/test/dynamics/test_jax_transformations.py index 7c3f03a34..feb59eefa 100644 --- a/test/dynamics/test_jax_transformations.py +++ b/test/dynamics/test_jax_transformations.py @@ -22,7 +22,7 @@ from qiskit_dynamics.signals import Signal from qiskit_dynamics import solve_lmde -from .common import TestJaxBase +from .common import JAXTestBase try: from jax import jit, grad @@ -32,7 +32,7 @@ pass -class TestJaxTransformations(TestJaxBase): +class TestJaxTransformations(JAXTestBase): """Class for testing jax transformations of integrated use cases.""" def setUp(self):