Skip to content

Commit

Permalink
Merge pull request #2 from DanPuzzuoli/potential-test-fix
Browse files Browse the repository at this point in the history
potential fix for using test_array_backends
  • Loading branch information
to24toro authored Oct 17, 2023
2 parents 02edfce + 693e378 commit de2596f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
8 changes: 4 additions & 4 deletions test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def assertAllCloseSparse(self, A, B, rtol=1e-8, atol=1e-8):
self.assertTrue(np.allclose(A, B, rtol=rtol, atol=atol))


class NumpyTestBase(unittest.TestCase):
class NumpyTestBase(QiskitDynamicsTestCase):
"""Base class for tests working with numpy arrays."""

@classmethod
Expand All @@ -93,7 +93,7 @@ def assertArrayType(self, a):
return isinstance(a, np.ndarray)


class JAXTestBase(unittest.TestCase):
class JAXTestBase(QiskitDynamicsTestCase):
"""Base class for tests working with JAX arrays."""

@classmethod
Expand Down Expand Up @@ -121,7 +121,7 @@ def assertArrayType(self, a):
return isinstance(a, jnp.ndarray)


class ArrayNumpyTestBase(unittest.TestCase):
class ArrayNumpyTestBase(QiskitDynamicsTestCase):
"""Base class for tests working with qiskit_dynamics Arrays with numpy backend."""

@classmethod
Expand All @@ -138,7 +138,7 @@ def assertArrayType(self, a):
return isinstance(a, Array) and a.backend == "numpy"


class ArrayJaxTestBase(unittest.TestCase):
class ArrayJaxTestBase(QiskitDynamicsTestCase):
"""Base class for tests working with qiskit_dynamics Arrays with jax backend."""

@classmethod
Expand Down
20 changes: 11 additions & 9 deletions test/dynamics/signals/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
class TestSignal(QiskitDynamicsTestCase):
class TestSignal:
"""Tests for Signal object."""

def setUp(self):
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_conjugate(self):


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
class TestConstant(QiskitDynamicsTestCase):
class TestConstant:
"""Tests for constant signal object."""

def setUp(self):
Expand Down Expand Up @@ -384,7 +384,7 @@ def test_conjugate(self):


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
class TestDiscreteSignal(QiskitDynamicsTestCase):
class TestDiscreteSignal:
"""Tests for DiscreteSignal object."""

def setUp(self):
Expand Down Expand Up @@ -520,8 +520,7 @@ def test_add_samples(self):
self.assertAllClose(discrete3.samples, [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0])


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
class TestSignalSum(QiskitDynamicsTestCase):
class TestSignalSum:
"""Test evaluation functions for ``SignalSum``."""

def setUp(self):
Expand Down Expand Up @@ -738,7 +737,6 @@ def test_conjugate(self):
)


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
class TestDiscreteSignalSum(TestSignalSum):
"""Tests for DiscreteSignalSum."""

Expand Down Expand Up @@ -768,8 +766,12 @@ def test_empty_DiscreteSignal_to_sum(self):
self.assertTrue(empty_sum.samples.shape == (1, 0))


test_array_backends(TestSignalSum, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
test_array_backends(TestDiscreteSignalSum, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
class TestSignalList(QiskitDynamicsTestCase):
class TestSignalList:
"""Test cases for SignalList class."""

def setUp(self):
Expand Down Expand Up @@ -834,7 +836,7 @@ def test_construction_with_numbers(self):


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
class TestSignalCollection(QiskitDynamicsTestCase):
class TestSignalCollection:
"""Test cases for SignalCollection functionality."""

def setUp(self):
Expand Down Expand Up @@ -886,7 +888,7 @@ def test_DiscreteSignalSum_iterator(self):
self.assertAllClose(sum_val, self.discrete_sig_sum(3.0))


class TestSignalsJaxTransformations(QiskitDynamicsTestCase, JAXTestBase):
class TestSignalsJaxTransformations(JAXTestBase):
"""Test cases for jax transformations of signals."""

def setUp(self):
Expand Down

0 comments on commit de2596f

Please sign in to comment.