Skip to content

Commit

Permalink
Merge branch 'main' into fix-complex-cast-warning
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli authored Feb 26, 2024
2 parents cfd5a46 + f99d651 commit 33b69c6
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 67 deletions.
52 changes: 1 addition & 51 deletions test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS

from qiskit_dynamics.array import Array, wrap
from qiskit_dynamics.array import Array


def _is_sparse_object_array(A):
Expand Down Expand Up @@ -336,53 +336,3 @@ def setUpClass(cls):
import qutip # pylint: disable=import-outside-toplevel,unused-import
except Exception as err:
raise unittest.SkipTest("Skipping qutip tests.") from err


# to be removed for 0.5.0
class TestJaxBase(unittest.TestCase):
"""Base class with setUpClass and tearDownClass for setting jax as the
default backend.
Test cases that inherit from this class will automatically work with jax
backend.
"""

@classmethod
def setUpClass(cls):
try:
# pylint: disable=import-outside-toplevel
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
except Exception as err:
raise unittest.SkipTest("Skipping jax tests.") from err

Array.set_default_backend("jax")

@classmethod
def tearDownClass(cls):
"""Set numpy back to the default backend."""
Array.set_default_backend("numpy")

def jit_wrap(self, func_to_test: Callable) -> Callable:
"""Wraps and jits func_to_test.
Args:
func_to_test: The function to be jited.
Returns:
Wrapped and jitted function."""
wf = wrap(jit, decorator=True)
return wf(wrap(func_to_test))

def jit_grad_wrap(self, func_to_test: Callable) -> Callable:
"""Tests whether a function can be graded. Converts
all functions to scalar, real functions if they are not
already.
Args:
func_to_test: The function whose gradient will be graded.
Returns:
JIT-compiled gradient of function.
"""
return wrap(lambda f: jit(grad(f)), decorator=True)(
lambda *args: np.sum(func_to_test(*args)).real
)
20 changes: 7 additions & 13 deletions test/dynamics/signals/test_signals_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name
# pylint: disable=invalid-name,no-member

"""
Tests for algebraic operations on signals.
Expand All @@ -19,7 +19,7 @@

from qiskit_dynamics.signals import Signal, DiscreteSignal, SignalSum, DiscreteSignalSum

from ..common import QiskitDynamicsTestCase, TestJaxBase
from ..common import test_array_backends, JAXTestBase

try:
from jax import jit, grad
Expand All @@ -28,7 +28,8 @@
pass


class TestSignalAddition(QiskitDynamicsTestCase):
@test_array_backends
class TestSignalAddition:
"""Testing special handling of signal addition."""

def test_SignalSum_construction(self):
Expand Down Expand Up @@ -81,7 +82,8 @@ def test_scalar_addition(self):
self.assertAllClose(sig_sum.envelope(1.5), np.array([3.0, -1.0]))


class TestSignalMultiplication(QiskitDynamicsTestCase):
@test_array_backends
class TestSignalMultiplication:
"""Test special handling of signal multiplication."""

def test_DiscreteSignal_products(self):
Expand Down Expand Up @@ -198,15 +200,7 @@ def test_signal_signal_product(self):
self.assertAllClose(sig_prod(t_vals), expected)


class TestSignalAdditionJax(TestSignalAddition, TestJaxBase):
"""Jax version of TestSignalAddition."""


class TestSignalMultiplicationJax(TestSignalMultiplication, TestJaxBase):
"""Jax version of TestSignalMultiplication."""


class TestSignalAlgebraJaxTransformations(QiskitDynamicsTestCase, TestJaxBase):
class TestSignalAlgebraJaxTransformations(JAXTestBase):
"""Test cases for jax transformations through signal algebraic operations."""

def setUp(self):
Expand Down
2 changes: 1 addition & 1 deletion test/dynamics/solvers/test_dyson_magnus_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def setUp(self):

@classmethod
def setUpClass(cls):
# calls TestJaxBase setUpClass
# calls JAXTestBase setUpClass
super().setUpClass()
# builds common objects
Test_PerturbativeSolver.build_testing_objects(cls, integration_method="jax_odeint")
Expand Down
4 changes: 2 additions & 2 deletions test/dynamics/test_jax_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from qiskit_dynamics.signals import Signal
from qiskit_dynamics import solve_lmde

from .common import TestJaxBase
from .common import JAXTestBase

try:
from jax import jit, grad
Expand All @@ -32,7 +32,7 @@
pass


class TestJaxTransformations(TestJaxBase):
class TestJaxTransformations(JAXTestBase):
"""Class for testing jax transformations of integrated use cases."""

def setUp(self):
Expand Down

0 comments on commit 33b69c6

Please sign in to comment.