-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add sparse modules to update dynamics by arraylias #286
Merged
DanPuzzuoli
merged 22 commits into
qiskit-community:main
from
to24toro:arraylias/sparse_module
Nov 17, 2023
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
56f5e7a
add sparse module
to24toro 1f0f136
Update qiskit_dynamics/arraylias/register_functions/to_sparse.py
to24toro a56088f
Update qiskit_dynamics/arraylias/register_functions/to_sparse.py
to24toro ab5ec7d
Update qiskit_dynamics/arraylias/register_functions/rmatmul.py
to24toro 907a878
Update qiskit_dynamics/arraylias/register_functions/multiply.py
to24toro 025454e
Update qiskit_dynamics/arraylias/register_functions/to_dense.py
to24toro d9e537d
Update qiskit_dynamics/arraylias/register_functions/to_dense.py
to24toro 6687bb6
Update qiskit_dynamics/arraylias/register_functions/matmul.py
to24toro 41d8fc7
Update qiskit_dynamics/arraylias/register_functions/rmatmul.py
to24toro d075ded
Update qiskit_dynamics/arraylias/register_functions/multiply.py
to24toro 9fe669e
Update qiskit_dynamics/arraylias/register_functions/matmul.py
to24toro c962a34
Update qiskit_dynamics/arraylias/register_functions/asarray.py
to24toro 64cbf56
Update qiskit_dynamics/arraylias/register_functions/to_dense.py
to24toro 8dccad1
change the order of the funtion places
to24toro 44c1c4f
remove to_dense to_sparse to_numeric_matrix_type
to24toro ced1dc4
lint
to24toro 9fdce08
Update alias.py
DanPuzzuoli 903b42f
Merge branch 'main' into arraylias/sparse_module
DanPuzzuoli 12ddc08
Update alias.py
DanPuzzuoli f6f109e
addtest
to24toro 8b674fe
remove fallback
to24toro 6567b8f
add validating type of output of funcs
to24toro File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
""" | ||
Register custom functions using alias | ||
""" | ||
|
||
from .asarray import register_asarray | ||
from .matmul import register_matmul | ||
from .rmatmul import register_rmatmul | ||
from .multiply import register_multiply |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" | ||
Registering asarray functions to alias | ||
""" | ||
|
||
import numpy as np | ||
from scipy.sparse import csr_matrix, issparse | ||
|
||
|
||
def register_asarray(alias): | ||
"""register asarray functions to each array libraries""" | ||
|
||
@alias.register_default(path="asarray") | ||
def _(arr): | ||
return np.asarray(arr) | ||
|
||
@alias.register_function(lib="scipy_sparse", path="asarray") | ||
def _(arr): | ||
if issparse(arr): | ||
return arr | ||
return csr_matrix(arr) | ||
|
||
try: | ||
from jax.experimental.sparse import BCOO | ||
|
||
@alias.register_function(lib="jax_sparse", path="asarray") | ||
def _(arr): | ||
if type(arr).__name__ == "BCOO": | ||
return arr | ||
return BCOO.fromdense(arr) | ||
|
||
except ImportError: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" | ||
Register matmul functions to alias. | ||
""" | ||
|
||
|
||
def register_matmul(alias): | ||
"""Register matmul functions to required array libraries.""" | ||
|
||
@alias.register_function(lib="scipy_sparse", path="matmul") | ||
def _(x, y): | ||
return x * y | ||
|
||
try: | ||
from jax.experimental import sparse as jsparse | ||
import jax.numpy as jnp | ||
|
||
jsparse_matmul = jsparse.sparsify(jnp.matmul) | ||
|
||
@alias.register_function(lib="jax_sparse", path="matmul") | ||
def _(x, y): | ||
return jsparse_matmul(x, y) | ||
|
||
except ImportError: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" | ||
Register multiply functions to alias. | ||
""" | ||
|
||
|
||
def register_multiply(alias): | ||
"""Register multiply functions to each array library.""" | ||
|
||
@alias.register_fallback(path="multiply") | ||
def _(x, y): | ||
return x * y | ||
|
||
@alias.register_function(lib="scipy_sparse", path="multiply") | ||
def _(x, y): | ||
return x.multiply(y) | ||
|
||
try: | ||
from jax.experimental import sparse as jsparse | ||
import jax.numpy as jnp | ||
|
||
jsparse_multiply = jsparse.sparsify(jnp.multiply) | ||
|
||
@alias.register_function(lib="jax_sparse", path="multiply") | ||
def _(x, y): | ||
return jsparse_multiply(x, y) | ||
|
||
except ImportError: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
|
||
""" | ||
Register rmatmul functions to alias. | ||
""" | ||
|
||
import numpy as np | ||
|
||
|
||
def register_rmatmul(alias): | ||
"""Register rmatmul functions to each array library.""" | ||
|
||
@alias.register_function(lib="numpy", path="rmatmul") | ||
def _(x, y): | ||
return np.matmul(y, x) | ||
|
||
@alias.register_function(lib="scipy_sparse", path="rmatmul") | ||
def _(x, y): | ||
return y * x | ||
|
||
try: | ||
from jax.experimental import sparse as jsparse | ||
import jax.numpy as jnp | ||
|
||
jsparse_matmul = jsparse.sparsify(jnp.matmul) | ||
|
||
@alias.register_function(lib="jax", path="rmatmul") | ||
def _(x, y): | ||
return jnp.matmul(y, x) | ||
|
||
@alias.register_function(lib="jax_sparse", path="rmatmul") | ||
def _(x, y): | ||
return jsparse_matmul(y, x) | ||
|
||
except ImportError: | ||
pass |
53 changes: 53 additions & 0 deletions
53
test/dynamics/arraylias/register_functions/test_asarray.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" | ||
Test asarray functions | ||
""" | ||
import unittest | ||
import numpy as np | ||
from scipy.sparse import csr_matrix | ||
from qiskit.quantum_info.operators import Operator | ||
|
||
|
||
from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS | ||
from qiskit_dynamics import DYNAMICS_NUMPY as unp | ||
|
||
|
||
class TestAsarrayFunction(unittest.TestCase): | ||
"""Test cases for asarray functions registered in dynamics_numpy_alias.""" | ||
|
||
def test_register_default(self): | ||
"""Test register_default.""" | ||
arr = Operator.from_label("X") | ||
self.assertTrue(isinstance(unp.asarray(arr), np.ndarray)) | ||
|
||
def test_scipy_sparse(self): | ||
"""Test asarray for scipy_sparse.""" | ||
arr = np.array([[1, 0], [0, 1]]) | ||
sparse_arr = csr_matrix([[1, 0], [0, 1]]) | ||
self.assertTrue(isinstance(unp.asarray(sparse_arr), csr_matrix)) | ||
self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), csr_matrix)) | ||
|
||
def test_jax_sparse(self): | ||
"""Test asarray for jax_sparse.""" | ||
try: | ||
from jax.experimental.sparse import BCOO | ||
|
||
arr = np.array([[1, 0], [0, 1]]) | ||
sparse_arr = BCOO.fromdense([[1, 0], [0, 1]]) | ||
self.assertTrue(isinstance(unp.asarray(sparse_arr), BCOO)) | ||
self.assertTrue(isinstance(DYNAMICS_NUMPY_ALIAS(like=sparse_arr).asarray(arr), BCOO)) | ||
except ImportError as err: | ||
raise unittest.SkipTest("Skipping jax tests.") from err |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2023. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" | ||
Test matmul functions | ||
""" | ||
import unittest | ||
import numpy as np | ||
from scipy.sparse import csr_matrix | ||
|
||
from qiskit_dynamics import DYNAMICS_NUMPY as unp | ||
from ...common import QiskitDynamicsTestCase | ||
|
||
|
||
class TestMatmulFunction(QiskitDynamicsTestCase): | ||
"""Test cases for matmul functions registered in dynamics_numpy_alias.""" | ||
|
||
def test_scipy_sparse(self): | ||
"""Test matmul for scipy_sparse.""" | ||
x = csr_matrix([[1, 0], [0, 1]]) | ||
y = csr_matrix([[2, 2], [2, 2]]) | ||
self.assertTrue(isinstance(unp.matmul(x, y), csr_matrix)) | ||
self.assertAllClose(csr_matrix.toarray(unp.matmul(x, y)), [[2, 2], [2, 2]]) | ||
|
||
def test_jax_sparse(self): | ||
"""Test matmul for jax_sparse.""" | ||
try: | ||
from jax.experimental.sparse import BCOO | ||
|
||
x = BCOO.fromdense([[1, 0], [0, 1]]) | ||
y = BCOO.fromdense([[2, 2], [2, 2]]) | ||
self.assertTrue(isinstance(unp.matmul(x, y), BCOO)) | ||
self.assertAllClose(BCOO.todense(unp.matmul(x, y)), [[2, 2], [2, 2]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same thing here, validate type of |
||
except ImportError as err: | ||
raise unittest.SkipTest("Skipping jax tests.") from err |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
validate type of
unp.matmul(x, y)