Skip to content

Commit

Permalink
Add sparse modules to update dynamics by arraylias (#286)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Puzzuoli <[email protected]>
  • Loading branch information
to24toro and DanPuzzuoli authored Nov 17, 2023
1 parent c4cb96b commit f44c721
Show file tree
Hide file tree
Showing 11 changed files with 463 additions and 0 deletions.
26 changes: 26 additions & 0 deletions qiskit_dynamics/arraylias/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,21 @@

from typing import Union

from scipy.sparse import spmatrix

from arraylias import numpy_alias, scipy_alias

from qiskit import QiskitError

from qiskit_dynamics.array import Array

from .register_functions import (
register_asarray,
register_matmul,
register_multiply,
register_rmatmul,
)

# global NumPy and SciPy aliases
DYNAMICS_NUMPY_ALIAS = numpy_alias()
DYNAMICS_SCIPY_ALIAS = scipy_alias()
Expand All @@ -35,6 +44,23 @@
DYNAMICS_NUMPY = DYNAMICS_NUMPY_ALIAS()
DYNAMICS_SCIPY = DYNAMICS_SCIPY_ALIAS()

# register required custom versions of functions for sparse type here
DYNAMICS_NUMPY_ALIAS.register_type(spmatrix, lib="scipy_sparse")

try:
from jax.experimental.sparse import BCOO

# register required custom versions of functions for BCOO type here
DYNAMICS_NUMPY_ALIAS.register_type(BCOO, lib="jax_sparse")
except ImportError:
pass

# register custom functions for numpy_alias
register_asarray(alias=DYNAMICS_NUMPY_ALIAS)
register_matmul(alias=DYNAMICS_NUMPY_ALIAS)
register_multiply(alias=DYNAMICS_NUMPY_ALIAS)
register_rmatmul(alias=DYNAMICS_NUMPY_ALIAS)


ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list]

Expand Down
21 changes: 21 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Register custom functions using alias
"""

from .asarray import register_asarray
from .matmul import register_matmul
from .rmatmul import register_rmatmul
from .multiply import register_multiply
46 changes: 46 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/asarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Registering asarray functions to alias
"""

import numpy as np
from scipy.sparse import csr_matrix, issparse


def register_asarray(alias):
"""register asarray functions to each array libraries"""

@alias.register_default(path="asarray")
def _(arr):
return np.asarray(arr)

@alias.register_function(lib="scipy_sparse", path="asarray")
def _(arr):
if issparse(arr):
return arr
return csr_matrix(arr)

try:
from jax.experimental.sparse import BCOO

@alias.register_function(lib="jax_sparse", path="asarray")
def _(arr):
if type(arr).__name__ == "BCOO":
return arr
return BCOO.fromdense(arr)

except ImportError:
pass
38 changes: 38 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Register matmul functions to alias.
"""


def register_matmul(alias):
"""Register matmul functions to required array libraries."""

@alias.register_function(lib="scipy_sparse", path="matmul")
def _(x, y):
return x * y

try:
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

jsparse_matmul = jsparse.sparsify(jnp.matmul)

@alias.register_function(lib="jax_sparse", path="matmul")
def _(x, y):
return jsparse_matmul(x, y)

except ImportError:
pass
42 changes: 42 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/multiply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Register multiply functions to alias.
"""


def register_multiply(alias):
"""Register multiply functions to each array library."""

@alias.register_fallback(path="multiply")
def _(x, y):
return x * y

@alias.register_function(lib="scipy_sparse", path="multiply")
def _(x, y):
return x.multiply(y)

try:
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

jsparse_multiply = jsparse.sparsify(jnp.multiply)

@alias.register_function(lib="jax_sparse", path="multiply")
def _(x, y):
return jsparse_multiply(x, y)

except ImportError:
pass
49 changes: 49 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/rmatmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.


"""
Register rmatmul functions to alias.
"""

import numpy as np


def register_rmatmul(alias):
"""Register rmatmul functions to each array library."""

@alias.register_function(lib="numpy", path="rmatmul")
def _(x, y):
return np.matmul(y, x)

@alias.register_function(lib="scipy_sparse", path="rmatmul")
def _(x, y):
return y * x

try:
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

jsparse_matmul = jsparse.sparsify(jnp.matmul)

@alias.register_function(lib="jax", path="rmatmul")
def _(x, y):
return jnp.matmul(y, x)

@alias.register_function(lib="jax_sparse", path="rmatmul")
def _(x, y):
return jsparse_matmul(y, x)

except ImportError:
pass
53 changes: 53 additions & 0 deletions test/dynamics/arraylias/register_functions/test_asarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Test asarray functions
"""
import unittest
import numpy as np
from scipy.sparse import csr_matrix
from qiskit.quantum_info.operators import Operator


from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS
from qiskit_dynamics import DYNAMICS_NUMPY as unp


class TestAsarrayFunction(unittest.TestCase):
"""Test cases for asarray functions registered in dynamics_numpy_alias."""

def test_register_default(self):
"""Test register_default."""
arr = Operator.from_label("X")
self.assertTrue(isinstance(unp.asarray(arr), np.ndarray))

def test_scipy_sparse(self):
"""Test asarray for scipy_sparse."""
arr = np.array([[1, 0], [0, 1]])
sparse_arr = csr_matrix([[1, 0], [0, 1]])
self.assertTrue(isinstance(unp.asarray(sparse_arr), csr_matrix))
self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), csr_matrix))

def test_jax_sparse(self):
"""Test asarray for jax_sparse."""
try:
from jax.experimental.sparse import BCOO

arr = np.array([[1, 0], [0, 1]])
sparse_arr = BCOO.fromdense([[1, 0], [0, 1]])
self.assertTrue(isinstance(unp.asarray(sparse_arr), BCOO))
self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), BCOO))
except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err
46 changes: 46 additions & 0 deletions test/dynamics/arraylias/register_functions/test_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Test matmul functions
"""
import unittest
import numpy as np
from scipy.sparse import csr_matrix

from qiskit_dynamics import DYNAMICS_NUMPY as unp
from ...common import QiskitDynamicsTestCase


class TestMatmulFunction(QiskitDynamicsTestCase):
"""Test cases for matmul functions registered in dynamics_numpy_alias."""

def test_scipy_sparse(self):
"""Test matmul for scipy_sparse."""
x = csr_matrix([[1, 0], [0, 1]])
y = csr_matrix([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.matmul(x, y), csr_matrix))
self.assertAllClose(csr_matrix.toarray(unp.matmul(x, y)), [[2, 2], [2, 2]])

def test_jax_sparse(self):
"""Test matmul for jax_sparse."""
try:
from jax.experimental.sparse import BCOO

x = BCOO.fromdense([[1, 0], [0, 1]])
y = BCOO.fromdense([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.matmul(x, y), BCOO))
self.assertAllClose(BCOO.todense(unp.matmul(x, y)), [[2, 2], [2, 2]])
except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err
Loading

0 comments on commit f44c721

Please sign in to comment.