Skip to content

Commit

Permalink
initial attempt at updating testing infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Oct 1, 2023
1 parent d39b96b commit dccbb99
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 0 deletions.
15 changes: 15 additions & 0 deletions test/dynamics/arraylias/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Dynamics arraylias module tests.
"""
38 changes: 38 additions & 0 deletions test/dynamics/arraylias/test_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name

"""
Test global alias instances.
"""

from functools import partial

from ..common import QiskitDynamicsTestCase, test_array_backends

import numpy as np

from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import DYNAMICS_SCIPY as usp


@partial(test_array_backends, backends=["numpy", "jax", "array_numpy", "array_jax"])
class TestDynamicsNumpy(QiskitDynamicsTestCase):

def test_simple_case(self):
"""Validate correct type and output."""
a = self.asarray([1., 2., 3.])
output = unp.exp(a)
self.assertTrue(isinstance(output, type(a)))

expected = np.exp(np.array([1., 2., 3.]))
self.assertAllClose(output, expected)
105 changes: 105 additions & 0 deletions test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
Shared functionality and helpers for the unit tests.
"""

import sys
import warnings
import unittest
import inspect

from typing import Callable, Iterable
import numpy as np
from scipy.sparse import issparse

try:
from jax import jit, grad
import jax.numpy as jnp
except ImportError:
pass

Expand Down Expand Up @@ -53,6 +57,107 @@ def assertAllCloseSparse(self, A, B, rtol=1e-8, atol=1e-8):
self.assertTrue(np.allclose(A, B, rtol=rtol, atol=atol))


class NumpyBase(unittest.TestCase):
"""#############################################################################################"""

def lib(self):
return "numpy"

def asarray(self, a):
return np.array(a)

def assertArrayType(self, a):
return isinstance(a, np.ndarray)


class JaxBase(unittest.TestCase):
"""#############################################################################################"""
@classmethod
def setUpClass(cls):
try:
# pylint: disable=import-outside-toplevel
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
except Exception as err:
raise unittest.SkipTest("Skipping jax tests.") from err

def lib(self):
return "jax"

def asarray(self, a):
return jnp.array(a)

def assertArrayType(self, a):
return isinstance(a, jnp.ndarray)


class ArrayNumpyBase(unittest.TestCase):

def lib(self):
return "array_numpy"

def asarray(self, a):
return Array(a)

def assertArrayType(self, a):
return isinstance(a, Array) and a.backend == "numpy"


class ArrayJaxBase(unittest.TestCase):

@classmethod
def setUpClass(cls):
try:
# pylint: disable=import-outside-toplevel
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
except Exception as err:
raise unittest.SkipTest("Skipping jax tests.") from err

Array.set_default_backend("jax")

@classmethod
def tearDownClass(cls):
"""Set numpy back to the default backend."""
Array.set_default_backend("numpy")

def lib(self):
return "array_jax"

def asarray(self, a):
return Array(a)

def assertArrayType(self, a):
return isinstance(a, Array) and a.backend == "jax"


def test_array_backends(test_class, backends=["numpy", "jax"]):
"""Test class decorator for different array backends.
Creates subclasses of ``test_class`` with the method ``asarray`` for creating arrays of the
appropriate type, the ``lib`` method to inspect library, in addition to special setup and
teardown methods. The original ``test_class`` is then deleted so that it is no longer
accessible by unittest.
"""

# reference to module that called this function
module = inspect.getmodule(inspect.stack()[1][0])

libs = ["numpy", "jax", "array_numpy", "array_jax"]
base_classes = [NumpyBase, JaxBase, ArrayNumpyBase, ArrayJaxBase]
for lib, base_class in zip(libs, base_classes):
if lib in backends:
class_name = f"{test_class.__name__}_{lib}"
setattr(module, class_name, type(class_name, (test_class, base_class), dict()))

del test_class


# to be removed for 0.5.0
class TestJaxBase(unittest.TestCase):
"""Base class with setUpClass and tearDownClass for setting jax as the
default backend.
Expand Down

0 comments on commit dccbb99

Please sign in to comment.