Skip to content

Commit

Permalink
Operator collections arraylias integration (#291)
Browse files Browse the repository at this point in the history
Co-authored-by: Kento Ueda <[email protected]>
  • Loading branch information
DanPuzzuoli and to24toro committed Jan 24, 2024
1 parent 21d4efe commit da27bf1
Show file tree
Hide file tree
Showing 13 changed files with 921 additions and 1,227 deletions.
15 changes: 12 additions & 3 deletions qiskit_dynamics/arraylias/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
register_matmul,
register_multiply,
register_rmatmul,
register_linear_combo,
register_transpose,
register_conjugate,
)

# global NumPy and SciPy aliases
Expand Down Expand Up @@ -60,6 +63,9 @@
register_matmul(alias=DYNAMICS_NUMPY_ALIAS)
register_multiply(alias=DYNAMICS_NUMPY_ALIAS)
register_rmatmul(alias=DYNAMICS_NUMPY_ALIAS)
register_linear_combo(alias=DYNAMICS_NUMPY_ALIAS)
register_conjugate(alias=DYNAMICS_NUMPY_ALIAS)
register_transpose(alias=DYNAMICS_NUMPY_ALIAS)


ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list]
Expand All @@ -82,13 +88,16 @@ def _preferred_lib(*args, **kwargs):
"""
args = list(args) + list(kwargs.values())
if len(args) == 1:
return DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])
libs = DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])
return libs[0] if len(libs) > 0 else "numpy"

lib0 = DYNAMICS_NUMPY_ALIAS.infer_libs(args[0])[0]
lib1 = _preferred_lib(args[1:])[0]
lib0 = _preferred_lib(args[0])
lib1 = _preferred_lib(args[1:])

if lib0 == "numpy" and lib1 == "numpy":
return "numpy"
elif lib0 == "jax_sparse" or lib1 == "jax_sparse":
return "jax_sparse"
elif lib0 == "jax" or lib1 == "jax":
return "jax"

Expand Down
3 changes: 3 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@
from .matmul import register_matmul
from .rmatmul import register_rmatmul
from .multiply import register_multiply
from .linear_combo import register_linear_combo
from .conjugate import register_conjugate
from .transpose import register_transpose
37 changes: 37 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/conjugate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Registering conjugate.
"""


def register_conjugate(alias):
"""Register linear functions for each array library."""

try:
from jax.dtypes import canonicalize_dtype
import jax.numpy as jnp
from jax.experimental.sparse import sparsify

# can be changed to sparsify(jnp.conjugate) when implemented
def conj_workaround(x):
if jnp.issubdtype(x.dtype, canonicalize_dtype(jnp.complex128)):
return x.real - 1j * x.imag
return x

alias.register_function(func=sparsify(conj_workaround), lib="jax_sparse", path="conjugate")

except ImportError:
pass
51 changes: 51 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/linear_combo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Registering linear_combo functions to alias. This computes a linear combination of matrices (given
by a 3d array).
"""

import numpy as np


def register_linear_combo(alias):
"""Register linear functions for each array library."""

@alias.register_default(path="linear_combo")
def _(coeffs, mats):
return np.tensordot(coeffs, mats, axes=1)

@alias.register_function(lib="numpy", path="linear_combo")
def _(coeffs, mats):
return np.tensordot(coeffs, mats, axes=1)

try:
import jax.numpy as jnp

@alias.register_function(lib="jax", path="linear_combo")
def _(coeffs, mats):
return jnp.tensordot(coeffs, mats, axes=1)

from jax.experimental.sparse import sparsify

jsparse_sum = sparsify(jnp.sum)

@alias.register_function(lib="jax_sparse", path="linear_combo")
def _(coeffs, mats):
# pylint: disable=unexpected-keyword-arg
return jsparse_sum(jnp.broadcast_to(coeffs[:, None, None], mats.shape) * mats, axis=0)

except ImportError:
pass
31 changes: 31 additions & 0 deletions qiskit_dynamics/arraylias/register_functions/transpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Registering transpose.
"""


def register_transpose(alias):
"""Register linear functions for each array library."""

try:
from jax.experimental.sparse import bcoo_transpose

@alias.register_function(lib="jax_sparse", path="transpose")
def _(arr, axes=None):
return bcoo_transpose(arr, permutation=axes)

except ImportError:
pass
19 changes: 10 additions & 9 deletions qiskit_dynamics/models/generator_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@
from qiskit import QiskitError
from qiskit.quantum_info.operators import Operator
from qiskit_dynamics.models.operator_collections import (
BaseOperatorCollection,
DenseOperatorCollection,
SparseOperatorCollection,
JAXSparseOperatorCollection,
OperatorCollection,
ScipySparseOperatorCollection,
)
from qiskit_dynamics.array import Array
from qiskit_dynamics.signals import Signal, SignalList
Expand Down Expand Up @@ -538,7 +536,7 @@ def construct_operator_collection(
evaluation_mode: str,
static_operator: Union[None, Array, csr_matrix],
operators: Union[None, Array, List[csr_matrix]],
) -> BaseOperatorCollection:
) -> Union[OperatorCollection, ScipySparseOperatorCollection]:
"""Construct an operator collection for :class:`GeneratorModel`.
Args:
Expand All @@ -547,14 +545,15 @@ def construct_operator_collection(
operators: Operators for the model.
Returns:
BaseOperatorCollection: The relevant operator collection.
Union[OperatorCollection, ScipySparseOperatorCollection]: The relevant operator collection.
Raises:
NotImplementedError: If the ``evaluation_mode`` is invalid.
"""

if evaluation_mode == "dense":
return DenseOperatorCollection(static_operator=static_operator, operators=operators)
# return DenseOperatorCollection(static_operator=static_operator, operators=operators)
pass
if evaluation_mode == "sparse" and Array.default_backend() == "jax":
# warn that sparse mode when using JAX is primarily recommended for use on CPU
if jax.default_backend() != "cpu":
Expand All @@ -563,9 +562,11 @@ def construct_operator_collection(
stacklevel=2,
)

return JAXSparseOperatorCollection(static_operator=static_operator, operators=operators)
# return JAXSparseOperatorCollection(static_operator=static_operator, operators=operators)
pass
if evaluation_mode == "sparse":
return SparseOperatorCollection(static_operator=static_operator, operators=operators)
# return SparseOperatorCollection(static_operator=static_operator, operators=operators)
pass

raise NotImplementedError(
f"Evaluation mode '{evaluation_mode}' is not supported. Call "
Expand Down
36 changes: 22 additions & 14 deletions qiskit_dynamics/models/lindblad_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,10 @@
)
from .hamiltonian_model import HamiltonianModel, is_hermitian
from .operator_collections import (
BaseLindbladOperatorCollection,
DenseLindbladCollection,
DenseVectorizedLindbladCollection,
SparseLindbladCollection,
JAXSparseLindbladCollection,
SparseVectorizedLindbladCollection,
JAXSparseVectorizedLindbladCollection,
LindbladCollection,
ScipySparseLindbladCollection,
VectorizedLindbladCollection,
ScipySparseVectorizedLindbladCollection,
)
from .rotating_frame import RotatingFrame

Expand Down Expand Up @@ -654,7 +651,12 @@ def construct_lindblad_operator_collection(
hamiltonian_operators: Union[None, Array, List[csr_matrix]],
static_dissipators: Union[None, Array, csr_matrix],
dissipator_operators: Union[None, Array, List[csr_matrix]],
) -> BaseLindbladOperatorCollection:
) -> Union[
LindbladCollection,
ScipySparseLindbladCollection,
VectorizedLindbladCollection,
ScipySparseVectorizedLindbladCollection,
]:
"""Construct a Lindblad operator collection.
Args:
Expand Down Expand Up @@ -685,19 +687,25 @@ def construct_lindblad_operator_collection(
)

if evaluation_mode == "dense":
CollectionClass = DenseLindbladCollection
# CollectionClass = DenseLindbladCollection
pass
elif evaluation_mode == "sparse":
if Array.default_backend() == "jax":
CollectionClass = JAXSparseLindbladCollection
# CollectionClass = JAXSparseLindbladCollection
pass
else:
CollectionClass = SparseLindbladCollection
# CollectionClass = SparseLindbladCollection
pass
elif evaluation_mode == "dense_vectorized":
CollectionClass = DenseVectorizedLindbladCollection
# CollectionClass = DenseVectorizedLindbladCollection
pass
elif evaluation_mode == "sparse_vectorized":
if Array.default_backend() == "jax":
CollectionClass = JAXSparseVectorizedLindbladCollection
# CollectionClass = JAXSparseVectorizedLindbladCollection
pass
else:
CollectionClass = SparseVectorizedLindbladCollection
# CollectionClass = SparseVectorizedLindbladCollection
pass
else:
raise NotImplementedError(
f"Evaluation mode '{evaluation_mode}' is not supported. See "
Expand Down
Loading

0 comments on commit da27bf1

Please sign in to comment.