Skip to content

Commit

Permalink
add validating type of output of funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
to24toro committed Nov 17, 2023
1 parent 8b674fe commit 6567b8f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
4 changes: 3 additions & 1 deletion test/dynamics/arraylias/register_functions/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
from ...common import QiskitDynamicsTestCase


class TestMultiplyFunction(QiskitDynamicsTestCase):
class TestMatmulFunction(QiskitDynamicsTestCase):
"""Test cases for matmul functions registered in dynamics_numpy_alias."""

def test_scipy_sparse(self):
"""Test matmul for scipy_sparse."""
x = csr_matrix([[1, 0], [0, 1]])
y = csr_matrix([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.matmul(x, y), csr_matrix))
self.assertAllClose(csr_matrix.toarray(unp.matmul(x, y)), [[2, 2], [2, 2]])

def test_jax_sparse(self):
Expand All @@ -39,6 +40,7 @@ def test_jax_sparse(self):

x = BCOO.fromdense([[1, 0], [0, 1]])
y = BCOO.fromdense([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.matmul(x, y), BCOO))
self.assertAllClose(BCOO.todense(unp.matmul(x, y)), [[2, 2], [2, 2]])
except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err
3 changes: 3 additions & 0 deletions test/dynamics/arraylias/register_functions/test_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ def test_register_fallback(self):
"""Test register_fallback."""
x = np.array([1, 0])
y = np.array([1, 0])
self.assertTrue(isinstance(unp.multiply(x, y), np.ndarray))
self.assertAllClose(unp.multiply(x, y), [1, 0])

def test_scipy_sparse(self):
"""Test multiply for scipy_sparse."""
x = csr_matrix([[1, 0], [0, 1]])
y = csr_matrix([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.multiply(x, y), csr_matrix))
self.assertAllClose(csr_matrix.toarray(unp.multiply(x, y)), [[2, 0], [0, 2]])

def test_jax_sparse(self):
Expand All @@ -45,6 +47,7 @@ def test_jax_sparse(self):

x = BCOO.fromdense([[1, 0], [0, 1]])
y = BCOO.fromdense([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.multiply(x, y), BCOO))
self.assertAllClose(BCOO.todense(unp.multiply(x, y)), [[2, 0], [0, 2]])
except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err
6 changes: 5 additions & 1 deletion test/dynamics/arraylias/register_functions/test_rmatmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@
from ...common import QiskitDynamicsTestCase


class TestMultiplyFunction(QiskitDynamicsTestCase):
class TestRmatmulFunction(QiskitDynamicsTestCase):
"""Test cases for rmatmul functions registered in dynamics_numpy_alias."""

def test_numpy(self):
"""Test rmatmul for numpy."""
x = np.array([[1, 1], [1, 1]])
y = np.array([[1, 2], [3, 4]])
self.assertTrue(isinstance(unp.rmatmul(x, y), np.ndarray))
self.assertAllClose(unp.rmatmul(x, y), [[3, 3], [7, 7]])

def test_scipy_sparse(self):
"""Test rmatmul for scipy_sparse."""
x = csr_matrix([[1, 1], [1, 1]])
y = csr_matrix([[1, 2], [3, 4]])
self.assertTrue(isinstance(unp.rmatmul(x, y), csr_matrix))
self.assertAllClose(csr_matrix.toarray(unp.rmatmul(x, y)), [[3, 3], [7, 7]])

def test_jax(self):
Expand All @@ -45,6 +47,7 @@ def test_jax(self):

x = jnp.array([[1, 1], [1, 1]])
y = jnp.array([[1, 2], [3, 4]])
self.assertTrue(isinstance(unp.rmatmul(x, y), jnp.ndarray))
self.assertAllClose(unp.rmatmul(x, y), [[3, 3], [7, 7]])
except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err
Expand All @@ -56,6 +59,7 @@ def test_jax_sparse(self):

x = BCOO.fromdense([[1, 1], [1, 1]])
y = BCOO.fromdense([[1, 2], [3, 4]])
self.assertTrue(isinstance(unp.rmatmul(x, y), BCOO))
self.assertAllClose(BCOO.todense(unp.rmatmul(x, y)), [[3, 3], [7, 7]])
except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err

0 comments on commit 6567b8f

Please sign in to comment.