diff --git a/qiskit_dynamics/arraylias/alias.py b/qiskit_dynamics/arraylias/alias.py index 6115d1e72..8868b98c1 100644 --- a/qiskit_dynamics/arraylias/alias.py +++ b/qiskit_dynamics/arraylias/alias.py @@ -18,12 +18,21 @@ from typing import Union +from scipy.sparse import spmatrix + from arraylias import numpy_alias, scipy_alias from qiskit import QiskitError from qiskit_dynamics.array import Array +from .register_functions import ( + register_asarray, + register_matmul, + register_multiply, + register_rmatmul, +) + # global NumPy and SciPy aliases DYNAMICS_NUMPY_ALIAS = numpy_alias() DYNAMICS_SCIPY_ALIAS = scipy_alias() @@ -35,6 +44,23 @@ DYNAMICS_NUMPY = DYNAMICS_NUMPY_ALIAS() DYNAMICS_SCIPY = DYNAMICS_SCIPY_ALIAS() +# register required custom versions of functions for sparse type here +DYNAMICS_NUMPY_ALIAS.register_type(spmatrix, lib="scipy_sparse") + +try: + from jax.experimental.sparse import BCOO + + # register required custom versions of functions for BCOO type here + DYNAMICS_NUMPY_ALIAS.register_type(BCOO, lib="jax_sparse") +except ImportError: + pass + +# register custom functions for numpy_alias +register_asarray(alias=DYNAMICS_NUMPY_ALIAS) +register_matmul(alias=DYNAMICS_NUMPY_ALIAS) +register_multiply(alias=DYNAMICS_NUMPY_ALIAS) +register_rmatmul(alias=DYNAMICS_NUMPY_ALIAS) + ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list] diff --git a/qiskit_dynamics/arraylias/register_functions/__init__.py b/qiskit_dynamics/arraylias/register_functions/__init__.py new file mode 100644 index 000000000..4da5fc49c --- /dev/null +++ b/qiskit_dynamics/arraylias/register_functions/__init__.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Register custom functions using alias +""" + +from .asarray import register_asarray +from .matmul import register_matmul +from .rmatmul import register_rmatmul +from .multiply import register_multiply diff --git a/qiskit_dynamics/arraylias/register_functions/asarray.py b/qiskit_dynamics/arraylias/register_functions/asarray.py new file mode 100644 index 000000000..14cbb4452 --- /dev/null +++ b/qiskit_dynamics/arraylias/register_functions/asarray.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Registering asarray functions to alias +""" + +import numpy as np +from scipy.sparse import csr_matrix, issparse + + +def register_asarray(alias): + """register asarray functions to each array libraries""" + + @alias.register_default(path="asarray") + def _(arr): + return np.asarray(arr) + + @alias.register_function(lib="scipy_sparse", path="asarray") + def _(arr): + if issparse(arr): + return arr + return csr_matrix(arr) + + try: + from jax.experimental.sparse import BCOO + + @alias.register_function(lib="jax_sparse", path="asarray") + def _(arr): + if type(arr).__name__ == "BCOO": + return arr + return BCOO.fromdense(arr) + + except ImportError: + pass diff --git a/qiskit_dynamics/arraylias/register_functions/matmul.py b/qiskit_dynamics/arraylias/register_functions/matmul.py new file mode 100644 index 000000000..23e727b71 --- /dev/null +++ b/qiskit_dynamics/arraylias/register_functions/matmul.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Register matmul functions to alias. +""" + + +def register_matmul(alias): + """Register matmul functions to required array libraries.""" + + @alias.register_function(lib="scipy_sparse", path="matmul") + def _(x, y): + return x * y + + try: + from jax.experimental import sparse as jsparse + import jax.numpy as jnp + + jsparse_matmul = jsparse.sparsify(jnp.matmul) + + @alias.register_function(lib="jax_sparse", path="matmul") + def _(x, y): + return jsparse_matmul(x, y) + + except ImportError: + pass diff --git a/qiskit_dynamics/arraylias/register_functions/multiply.py b/qiskit_dynamics/arraylias/register_functions/multiply.py new file mode 100644 index 000000000..88765821d --- /dev/null +++ b/qiskit_dynamics/arraylias/register_functions/multiply.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Register multiply functions to alias. +""" + + +def register_multiply(alias): + """Register multiply functions to each array library.""" + + @alias.register_fallback(path="multiply") + def _(x, y): + return x * y + + @alias.register_function(lib="scipy_sparse", path="multiply") + def _(x, y): + return x.multiply(y) + + try: + from jax.experimental import sparse as jsparse + import jax.numpy as jnp + + jsparse_multiply = jsparse.sparsify(jnp.multiply) + + @alias.register_function(lib="jax_sparse", path="multiply") + def _(x, y): + return jsparse_multiply(x, y) + + except ImportError: + pass diff --git a/qiskit_dynamics/arraylias/register_functions/rmatmul.py b/qiskit_dynamics/arraylias/register_functions/rmatmul.py new file mode 100644 index 000000000..708022f76 --- /dev/null +++ b/qiskit_dynamics/arraylias/register_functions/rmatmul.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + + +""" +Register rmatmul functions to alias. +""" + +import numpy as np + + +def register_rmatmul(alias): + """Register rmatmul functions to each array library.""" + + @alias.register_function(lib="numpy", path="rmatmul") + def _(x, y): + return np.matmul(y, x) + + @alias.register_function(lib="scipy_sparse", path="rmatmul") + def _(x, y): + return y * x + + try: + from jax.experimental import sparse as jsparse + import jax.numpy as jnp + + jsparse_matmul = jsparse.sparsify(jnp.matmul) + + @alias.register_function(lib="jax", path="rmatmul") + def _(x, y): + return jnp.matmul(y, x) + + @alias.register_function(lib="jax_sparse", path="rmatmul") + def _(x, y): + return jsparse_matmul(y, x) + + except ImportError: + pass diff --git a/test/dynamics/arraylias/register_functions/test_asarray.py b/test/dynamics/arraylias/register_functions/test_asarray.py new file mode 100644 index 000000000..942183449 --- /dev/null +++ b/test/dynamics/arraylias/register_functions/test_asarray.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Test asarray functions +""" +import unittest +import numpy as np +from scipy.sparse import csr_matrix +from qiskit.quantum_info.operators import Operator + + +from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS +from qiskit_dynamics import DYNAMICS_NUMPY as unp + + +class TestAsarrayFunction(unittest.TestCase): + """Test cases for asarray functions registered in dynamics_numpy_alias.""" + + def test_register_default(self): + """Test register_default.""" + arr = Operator.from_label("X") + self.assertTrue(isinstance(unp.asarray(arr), np.ndarray)) + + def test_scipy_sparse(self): + """Test asarray for scipy_sparse.""" + arr = np.array([[1, 0], [0, 1]]) + sparse_arr = csr_matrix([[1, 0], [0, 1]]) + self.assertTrue(isinstance(unp.asarray(sparse_arr), csr_matrix)) + self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), csr_matrix)) + + def test_jax_sparse(self): + """Test asarray for jax_sparse.""" + try: + from jax.experimental.sparse import BCOO + + arr = np.array([[1, 0], [0, 1]]) + sparse_arr = BCOO.fromdense([[1, 0], [0, 1]]) + self.assertTrue(isinstance(unp.asarray(sparse_arr), BCOO)) + self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), BCOO)) + except ImportError as err: + raise unittest.SkipTest("Skipping jax tests.") from err diff --git a/test/dynamics/arraylias/register_functions/test_matmul.py b/test/dynamics/arraylias/register_functions/test_matmul.py new file mode 100644 index 000000000..1aae7ab44 --- /dev/null +++ b/test/dynamics/arraylias/register_functions/test_matmul.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Test matmul functions +""" +import unittest +import numpy as np +from scipy.sparse import csr_matrix + +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from ...common import QiskitDynamicsTestCase + + +class TestMatmulFunction(QiskitDynamicsTestCase): + """Test cases for matmul functions registered in dynamics_numpy_alias.""" + + def test_scipy_sparse(self): + """Test matmul for scipy_sparse.""" + x = csr_matrix([[1, 0], [0, 1]]) + y = csr_matrix([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.matmul(x, y), csr_matrix)) + self.assertAllClose(csr_matrix.toarray(unp.matmul(x, y)), [[2, 2], [2, 2]]) + + def test_jax_sparse(self): + """Test matmul for jax_sparse.""" + try: + from jax.experimental.sparse import BCOO + + x = BCOO.fromdense([[1, 0], [0, 1]]) + y = BCOO.fromdense([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.matmul(x, y), BCOO)) + self.assertAllClose(BCOO.todense(unp.matmul(x, y)), [[2, 2], [2, 2]]) + except ImportError as err: + raise unittest.SkipTest("Skipping jax tests.") from err diff --git a/test/dynamics/arraylias/register_functions/test_multiply.py b/test/dynamics/arraylias/register_functions/test_multiply.py new file mode 100644 index 000000000..bf763df63 --- /dev/null +++ b/test/dynamics/arraylias/register_functions/test_multiply.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Test multiply functions +""" +import unittest +import numpy as np +from scipy.sparse import csr_matrix + +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from ...common import QiskitDynamicsTestCase + + +class TestMultiplyFunction(QiskitDynamicsTestCase): + """Test cases for multiply functions registered in dynamics_numpy_alias.""" + + def test_register_fallback(self): + """Test register_fallback.""" + x = np.array([1, 0]) + y = np.array([1, 0]) + self.assertTrue(isinstance(unp.multiply(x, y), np.ndarray)) + self.assertAllClose(unp.multiply(x, y), [1, 0]) + + def test_scipy_sparse(self): + """Test multiply for scipy_sparse.""" + x = csr_matrix([[1, 0], [0, 1]]) + y = csr_matrix([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.multiply(x, y), csr_matrix)) + self.assertAllClose(csr_matrix.toarray(unp.multiply(x, y)), [[2, 0], [0, 2]]) + + def test_jax_sparse(self): + """Test multiply for jax_sparse.""" + try: + from jax.experimental.sparse import BCOO + + x = BCOO.fromdense([[1, 0], [0, 1]]) + y = BCOO.fromdense([[2, 2], [2, 2]]) + self.assertTrue(isinstance(unp.multiply(x, y), BCOO)) + self.assertAllClose(BCOO.todense(unp.multiply(x, y)), [[2, 0], [0, 2]]) + except ImportError as err: + raise unittest.SkipTest("Skipping jax tests.") from err diff --git a/test/dynamics/arraylias/register_functions/test_rmatmul.py b/test/dynamics/arraylias/register_functions/test_rmatmul.py new file mode 100644 index 000000000..2abe11c6e --- /dev/null +++ b/test/dynamics/arraylias/register_functions/test_rmatmul.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +# This code is part of Qiskit. +# +# (C) Copyright IBM 2023. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Test rmatmul functions +""" +import unittest +import numpy as np +from scipy.sparse import csr_matrix + +from qiskit_dynamics import DYNAMICS_NUMPY as unp +from ...common import QiskitDynamicsTestCase + + +class TestRmatmulFunction(QiskitDynamicsTestCase): + """Test cases for rmatmul functions registered in dynamics_numpy_alias.""" + + def test_numpy(self): + """Test rmatmul for numpy.""" + x = np.array([[1, 1], [1, 1]]) + y = np.array([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), np.ndarray)) + self.assertAllClose(unp.rmatmul(x, y), [[3, 3], [7, 7]]) + + def test_scipy_sparse(self): + """Test rmatmul for scipy_sparse.""" + x = csr_matrix([[1, 1], [1, 1]]) + y = csr_matrix([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), csr_matrix)) + self.assertAllClose(csr_matrix.toarray(unp.rmatmul(x, y)), [[3, 3], [7, 7]]) + + def test_jax(self): + """Test rmatmul for jax.""" + try: + import jax.numpy as jnp + + x = jnp.array([[1, 1], [1, 1]]) + y = jnp.array([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), jnp.ndarray)) + self.assertAllClose(unp.rmatmul(x, y), [[3, 3], [7, 7]]) + except ImportError as err: + raise unittest.SkipTest("Skipping jax tests.") from err + + def test_jax_sparse(self): + """Test rmatmul for jax_sparse.""" + try: + from jax.experimental.sparse import BCOO + + x = BCOO.fromdense([[1, 1], [1, 1]]) + y = BCOO.fromdense([[1, 2], [3, 4]]) + self.assertTrue(isinstance(unp.rmatmul(x, y), BCOO)) + self.assertAllClose(BCOO.todense(unp.rmatmul(x, y)), [[3, 3], [7, 7]]) + except ImportError as err: + raise unittest.SkipTest("Skipping jax tests.") from err diff --git a/test/dynamics/arraylias/test_alias.py b/test/dynamics/arraylias/test_alias.py index fdba7fae1..c342d0b3a 100644 --- a/test/dynamics/arraylias/test_alias.py +++ b/test/dynamics/arraylias/test_alias.py @@ -16,10 +16,13 @@ """ from functools import partial +import unittest import numpy as np import scipy as sp +from scipy.sparse import csr_matrix +from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS from qiskit_dynamics import DYNAMICS_NUMPY as unp from qiskit_dynamics import DYNAMICS_SCIPY as usp @@ -52,3 +55,24 @@ def test_simple_case(self): expected = sp.fft.dct(np.array([1.0, 2.0, 3.0])) self.assertAllClose(output, expected) + + +class TestDynamicsNumpyAliasType(unittest.TestCase): + """Test cases for which types are registered in dynamics_numpy_alias.""" + + def test_spmatrix_type(self): + """Test spmatrix is registered as scipy_sparse.""" + sp_matrix = csr_matrix([[0.0, 1.0], [1.0, 0.0]]) + registered_type_name = "scipy_sparse" + self.assertTrue(registered_type_name in DYNAMICS_NUMPY_ALIAS.infer_libs(sp_matrix)) + + def test_bcoo_type(self): + """Test bcoo is registered.""" + try: + from jax.experimental.sparse import BCOO + + bcoo = BCOO.fromdense([[0.0, 1.0], [1.0, 0.0]]) + registered_type_name = "jax_sparse" + self.assertTrue(registered_type_name in DYNAMICS_NUMPY_ALIAS.infer_libs(bcoo)[0]) + except ImportError as err: + raise unittest.SkipTest("Skipping jax tests.") from err