Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse modules to update dynamics by arraylias #286

Merged
merged 22 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
56f5e7a
add sparse module
to24toro Nov 7, 2023
1f0f136
Update qiskit_dynamics/arraylias/register_functions/to_sparse.py
to24toro Nov 8, 2023
a56088f
Update qiskit_dynamics/arraylias/register_functions/to_sparse.py
to24toro Nov 8, 2023
ab5ec7d
Update qiskit_dynamics/arraylias/register_functions/rmatmul.py
to24toro Nov 8, 2023
907a878
Update qiskit_dynamics/arraylias/register_functions/multiply.py
to24toro Nov 8, 2023
025454e
Update qiskit_dynamics/arraylias/register_functions/to_dense.py
to24toro Nov 8, 2023
d9e537d
Update qiskit_dynamics/arraylias/register_functions/to_dense.py
to24toro Nov 8, 2023
6687bb6
Update qiskit_dynamics/arraylias/register_functions/matmul.py
to24toro Nov 8, 2023
41d8fc7
Update qiskit_dynamics/arraylias/register_functions/rmatmul.py
to24toro Nov 8, 2023
d075ded
Update qiskit_dynamics/arraylias/register_functions/multiply.py
to24toro Nov 8, 2023
9fe669e
Update qiskit_dynamics/arraylias/register_functions/matmul.py
to24toro Nov 8, 2023
c962a34
Update qiskit_dynamics/arraylias/register_functions/asarray.py
to24toro Nov 13, 2023
64cbf56
Update qiskit_dynamics/arraylias/register_functions/to_dense.py
to24toro Nov 13, 2023
8dccad1
change the order of the funtion places
to24toro Nov 15, 2023
44c1c4f
remove to_dense to_sparse to_numeric_matrix_type
to24toro Nov 15, 2023
ced1dc4
lint
to24toro Nov 15, 2023
9fdce08
Update alias.py
DanPuzzuoli Nov 15, 2023
903b42f
Merge branch 'main' into arraylias/sparse_module
DanPuzzuoli Nov 15, 2023
12ddc08
Update alias.py
DanPuzzuoli Nov 15, 2023
f6f109e
addtest
to24toro Nov 16, 2023
8b674fe
remove fallback
to24toro Nov 16, 2023
6567b8f
add validating type of output of funcs
to24toro Nov 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions qiskit_dynamics/arraylias/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,21 @@

from typing import Union

from scipy.sparse import spmatrix

from arraylias import numpy_alias, scipy_alias

from qiskit import QiskitError

from qiskit_dynamics.array import Array

from .register_functions import (
register_asarray,
register_matmul,
register_multiply,
register_rmatmul,
)

# global NumPy and SciPy aliases
DYNAMICS_NUMPY_ALIAS = numpy_alias()
DYNAMICS_SCIPY_ALIAS = scipy_alias()
Expand All @@ -35,6 +44,23 @@
DYNAMICS_NUMPY = DYNAMICS_NUMPY_ALIAS()
DYNAMICS_SCIPY = DYNAMICS_SCIPY_ALIAS()

# register required custom versions of functions for sparse type here
DYNAMICS_NUMPY_ALIAS.register_type(spmatrix, lib="scipy_sparse")

try:
from jax.experimental.sparse import BCOO

# register required custom versions of functions for BCOO type here
DYNAMICS_NUMPY_ALIAS.register_type(BCOO, lib="jax_sparse")
except ImportError:
pass

# register custom functions for numpy_alias
register_asarray(alias=DYNAMICS_NUMPY_ALIAS)
register_matmul(alias=DYNAMICS_NUMPY_ALIAS)
register_multiply(alias=DYNAMICS_NUMPY_ALIAS)
register_rmatmul(alias=DYNAMICS_NUMPY_ALIAS)


ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list]

Expand Down
21 changes: 21 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Register custom functions using alias
"""

from .asarray import register_asarray
from .matmul import register_matmul
from .rmatmul import register_rmatmul
from .multiply import register_multiply
46 changes: 46 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/asarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Registering asarray functions to alias
"""

import numpy as np
from scipy.sparse import csr_matrix, issparse


def register_asarray(alias):
"""register asarray functions to each array libraries"""

@alias.register_default(path="asarray")
def _(arr):
return np.asarray(arr)

@alias.register_function(lib="scipy_sparse", path="asarray")
def _(arr):
if issparse(arr):
return arr
return csr_matrix(arr)

try:
from jax.experimental.sparse import BCOO

@alias.register_function(lib="jax_sparse", path="asarray")
def _(arr):
if type(arr).__name__ == "BCOO":
return arr
return BCOO.fromdense(arr)

except ImportError:
pass
38 changes: 38 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Register matmul functions to alias.
"""


def register_matmul(alias):
"""Register matmul functions to required array libraries."""

@alias.register_function(lib="scipy_sparse", path="matmul")
def _(x, y):
return x * y

try:
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

jsparse_matmul = jsparse.sparsify(jnp.matmul)

@alias.register_function(lib="jax_sparse", path="matmul")
def _(x, y):
return jsparse_matmul(x, y)

except ImportError:
pass
42 changes: 42 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/multiply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Register multiply functions to alias.
"""


def register_multiply(alias):
"""Register multiply functions to each array library."""

@alias.register_fallback(path="multiply")
def _(x, y):
return x * y

@alias.register_function(lib="scipy_sparse", path="multiply")
def _(x, y):
return x.multiply(y)

try:
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

jsparse_multiply = jsparse.sparsify(jnp.multiply)

@alias.register_function(lib="jax_sparse", path="multiply")
def _(x, y):
return jsparse_multiply(x, y)

except ImportError:
pass
49 changes: 49 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/rmatmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.


"""
Register rmatmul functions to alias.
"""

import numpy as np


def register_rmatmul(alias):
"""Register rmatmul functions to each array library."""

@alias.register_function(lib="numpy", path="rmatmul")
def _(x, y):
return np.matmul(y, x)

@alias.register_function(lib="scipy_sparse", path="rmatmul")
def _(x, y):
return y * x

try:
from jax.experimental import sparse as jsparse
import jax.numpy as jnp

jsparse_matmul = jsparse.sparsify(jnp.matmul)

@alias.register_function(lib="jax", path="rmatmul")
def _(x, y):
return jnp.matmul(y, x)

@alias.register_function(lib="jax_sparse", path="rmatmul")
def _(x, y):
return jsparse_matmul(y, x)

except ImportError:
pass
53 changes: 53 additions & 0 deletions test/dynamics/arraylias/register_functions/test_asarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Test asarray functions
"""
import unittest
import numpy as np
from scipy.sparse import csr_matrix
from qiskit.quantum_info.operators import Operator


from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS
from qiskit_dynamics import DYNAMICS_NUMPY as unp


class TestAsarrayFunction(unittest.TestCase):
"""Test cases for asarray functions registered in dynamics_numpy_alias."""

def test_register_default(self):
"""Test register_default."""
arr = Operator.from_label("X")
self.assertTrue(isinstance(unp.asarray(arr), np.ndarray))

def test_scipy_sparse(self):
"""Test asarray for scipy_sparse."""
arr = np.array([[1, 0], [0, 1]])
sparse_arr = csr_matrix([[1, 0], [0, 1]])
self.assertTrue(isinstance(unp.asarray(sparse_arr), csr_matrix))
self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), csr_matrix))

def test_jax_sparse(self):
"""Test asarray for jax_sparse."""
try:
from jax.experimental.sparse import BCOO

arr = np.array([[1, 0], [0, 1]])
sparse_arr = BCOO.fromdense([[1, 0], [0, 1]])
self.assertTrue(isinstance(unp.asarray(sparse_arr), BCOO))
self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), BCOO))
except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err
46 changes: 46 additions & 0 deletions test/dynamics/arraylias/register_functions/test_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Test matmul functions
"""
import unittest
import numpy as np
from scipy.sparse import csr_matrix

from qiskit_dynamics import DYNAMICS_NUMPY as unp
from ...common import QiskitDynamicsTestCase


class TestMatmulFunction(QiskitDynamicsTestCase):
"""Test cases for matmul functions registered in dynamics_numpy_alias."""

def test_scipy_sparse(self):
"""Test matmul for scipy_sparse."""
x = csr_matrix([[1, 0], [0, 1]])
y = csr_matrix([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.matmul(x, y), csr_matrix))
self.assertAllClose(csr_matrix.toarray(unp.matmul(x, y)), [[2, 2], [2, 2]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validate type of unp.matmul(x, y)


def test_jax_sparse(self):
"""Test matmul for jax_sparse."""
try:
from jax.experimental.sparse import BCOO

x = BCOO.fromdense([[1, 0], [0, 1]])
y = BCOO.fromdense([[2, 2], [2, 2]])
self.assertTrue(isinstance(unp.matmul(x, y), BCOO))
self.assertAllClose(BCOO.todense(unp.matmul(x, y)), [[2, 2], [2, 2]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same thing here, validate type of unp.matmul(x, y)

except ImportError as err:
raise unittest.SkipTest("Skipping jax tests.") from err
Loading
Loading