From 6567b8f637372e5cf20bfb57c8edabc5e9b6932c Mon Sep 17 00:00:00 2001 From: to24toro Date: Fri, 17 Nov 2023 10:31:03 +0900 Subject: [PATCH] add validating type of output of funcs --- test/dynamics/arraylias/register_functions/test_matmul.py | 4 +++- test/dynamics/arraylias/register_functions/test_multiply.py | 3 +++ test/dynamics/arraylias/register_functions/test_rmatmul.py | 6 +++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/test/dynamics/arraylias/register_functions/test_matmul.py b/test/dynamics/arraylias/register_functions/test_matmul.py index 3d727f338..1aae7ab44 100644 --- a/test/dynamics/arraylias/register_functions/test_matmul.py +++ b/test/dynamics/arraylias/register_functions/test_matmul.py @@ -23,13 +23,14 @@ from ...common import QiskitDynamicsTestCase -class TestMultiplyFunction(QiskitDynamicsTestCase): +class TestMatmulFunction(QiskitDynamicsTestCase): """Test cases for matmul functions registered in dynamics_numpy_alias.""" def test_scipy_sparse(self): """Test matmul for scipy_sparse.""" x = csr_matrix([[1, 0], [0, 1]]) y = csr_matrix([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.matmul(x, y), csr_matrix)) self.assertAllClose(csr_matrix.toarray(unp.matmul(x, y)), [[2, 2], [2, 2]]) def test_jax_sparse(self): @@ -39,6 +40,7 @@ def test_jax_sparse(self): x = BCOO.fromdense([[1, 0], [0, 1]]) y = BCOO.fromdense([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.matmul(x, y), BCOO)) self.assertAllClose(BCOO.todense(unp.matmul(x, y)), [[2, 2], [2, 2]]) except ImportError as err: raise unittest.SkipTest("Skipping jax tests.") from err diff --git a/test/dynamics/arraylias/register_functions/test_multiply.py b/test/dynamics/arraylias/register_functions/test_multiply.py index 311dbdf0a..bf763df63 100644 --- a/test/dynamics/arraylias/register_functions/test_multiply.py +++ b/test/dynamics/arraylias/register_functions/test_multiply.py @@ -30,12 +30,14 @@ def test_register_fallback(self): """Test register_fallback.""" x = np.array([1, 0]) y = np.array([1, 0]) + self.assertTrue(isinstance(unp.multiply(x, y), np.ndarray)) self.assertAllClose(unp.multiply(x, y), [1, 0]) def test_scipy_sparse(self): """Test multiply for scipy_sparse.""" x = csr_matrix([[1, 0], [0, 1]]) y = csr_matrix([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.multiply(x, y), csr_matrix)) self.assertAllClose(csr_matrix.toarray(unp.multiply(x, y)), [[2, 0], [0, 2]]) def test_jax_sparse(self): @@ -45,6 +47,7 @@ def test_jax_sparse(self): x = BCOO.fromdense([[1, 0], [0, 1]]) y = BCOO.fromdense([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.multiply(x, y), BCOO)) self.assertAllClose(BCOO.todense(unp.multiply(x, y)), [[2, 0], [0, 2]]) except ImportError as err: raise unittest.SkipTest("Skipping jax tests.") from err diff --git a/test/dynamics/arraylias/register_functions/test_rmatmul.py b/test/dynamics/arraylias/register_functions/test_rmatmul.py index f24e39b35..2abe11c6e 100644 --- a/test/dynamics/arraylias/register_functions/test_rmatmul.py +++ b/test/dynamics/arraylias/register_functions/test_rmatmul.py @@ -23,19 +23,21 @@ from ...common import QiskitDynamicsTestCase -class TestMultiplyFunction(QiskitDynamicsTestCase): +class TestRmatmulFunction(QiskitDynamicsTestCase): """Test cases for rmatmul functions registered in dynamics_numpy_alias.""" def test_numpy(self): """Test rmatmul for numpy.""" x = np.array([[1, 1], [1, 1]]) y = np.array([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), np.ndarray)) self.assertAllClose(unp.rmatmul(x, y), [[3, 3], [7, 7]]) def test_scipy_sparse(self): """Test rmatmul for scipy_sparse.""" x = csr_matrix([[1, 1], [1, 1]]) y = csr_matrix([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), csr_matrix)) self.assertAllClose(csr_matrix.toarray(unp.rmatmul(x, y)), [[3, 3], [7, 7]]) def test_jax(self): @@ -45,6 +47,7 @@ def test_jax(self): x = jnp.array([[1, 1], [1, 1]]) y = jnp.array([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), jnp.ndarray)) self.assertAllClose(unp.rmatmul(x, y), [[3, 3], [7, 7]]) except ImportError as err: raise unittest.SkipTest("Skipping jax tests.") from err @@ -56,6 +59,7 @@ def test_jax_sparse(self): x = BCOO.fromdense([[1, 1], [1, 1]]) y = BCOO.fromdense([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), BCOO)) self.assertAllClose(BCOO.todense(unp.rmatmul(x, y)), [[3, 3], [7, 7]]) except ImportError as err: raise unittest.SkipTest("Skipping jax tests.") from err