Skip to content

Commit

Permalink
removing old test class TestJaxBase
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Feb 23, 2024
1 parent cd2f0fb commit 9b0f3e1
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 65 deletions.
50 changes: 0 additions & 50 deletions test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,53 +336,3 @@ def setUpClass(cls):
import qutip # pylint: disable=import-outside-toplevel,unused-import
except Exception as err:
raise unittest.SkipTest("Skipping qutip tests.") from err


# to be removed for 0.5.0
class TestJaxBase(unittest.TestCase):
"""Base class with setUpClass and tearDownClass for setting jax as the
default backend.
Test cases that inherit from this class will automatically work with jax
backend.
"""

@classmethod
def setUpClass(cls):
try:
# pylint: disable=import-outside-toplevel
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
except Exception as err:
raise unittest.SkipTest("Skipping jax tests.") from err

Array.set_default_backend("jax")

@classmethod
def tearDownClass(cls):
"""Set numpy back to the default backend."""
Array.set_default_backend("numpy")

def jit_wrap(self, func_to_test: Callable) -> Callable:
"""Wraps and jits func_to_test.
Args:
func_to_test: The function to be jited.
Returns:
Wrapped and jitted function."""
wf = wrap(jit, decorator=True)
return wf(wrap(func_to_test))

def jit_grad_wrap(self, func_to_test: Callable) -> Callable:
"""Tests whether a function can be graded. Converts
all functions to scalar, real functions if they are not
already.
Args:
func_to_test: The function whose gradient will be graded.
Returns:
JIT-compiled gradient of function.
"""
return wrap(lambda f: jit(grad(f)), decorator=True)(
lambda *args: np.sum(func_to_test(*args)).real
)
18 changes: 6 additions & 12 deletions test/dynamics/signals/test_signals_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from qiskit_dynamics.signals import Signal, DiscreteSignal, SignalSum, DiscreteSignalSum

from ..common import QiskitDynamicsTestCase, TestJaxBase
from ..common import test_array_backends, JAXTestBase

try:
from jax import jit, grad
Expand All @@ -28,7 +28,8 @@
pass


class TestSignalAddition(QiskitDynamicsTestCase):
@test_array_backends
class TestSignalAddition:
"""Testing special handling of signal addition."""

def test_SignalSum_construction(self):
Expand Down Expand Up @@ -81,7 +82,8 @@ def test_scalar_addition(self):
self.assertAllClose(sig_sum.envelope(1.5), np.array([3.0, -1.0]))


class TestSignalMultiplication(QiskitDynamicsTestCase):
@test_array_backends
class TestSignalMultiplication:
"""Test special handling of signal multiplication."""

def test_DiscreteSignal_products(self):
Expand Down Expand Up @@ -198,15 +200,7 @@ def test_signal_signal_product(self):
self.assertAllClose(sig_prod(t_vals), expected)


class TestSignalAdditionJax(TestSignalAddition, TestJaxBase):
"""Jax version of TestSignalAddition."""


class TestSignalMultiplicationJax(TestSignalMultiplication, TestJaxBase):
"""Jax version of TestSignalMultiplication."""


class TestSignalAlgebraJaxTransformations(QiskitDynamicsTestCase, TestJaxBase):
class TestSignalAlgebraJaxTransformations(JAXTestBase):
"""Test cases for jax transformations through signal algebraic operations."""

def setUp(self):
Expand Down
2 changes: 1 addition & 1 deletion test/dynamics/solvers/test_dyson_magnus_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def setUp(self):

@classmethod
def setUpClass(cls):
# calls TestJaxBase setUpClass
# calls JAXTestBase setUpClass
super().setUpClass()
# builds common objects
Test_PerturbativeSolver.build_testing_objects(cls, integration_method="jax_odeint")
Expand Down
4 changes: 2 additions & 2 deletions test/dynamics/test_jax_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from qiskit_dynamics.signals import Signal
from qiskit_dynamics import solve_lmde

from .common import TestJaxBase
from .common import JAXTestBase

try:
from jax import jit, grad
Expand All @@ -32,7 +32,7 @@
pass


class TestJaxTransformations(TestJaxBase):
class TestJaxTransformations(JAXTestBase):
"""Class for testing jax transformations of integrated use cases."""

def setUp(self):
Expand Down

0 comments on commit 9b0f3e1

Please sign in to comment.