Skip to content

Commit

Permalink
Fixing issues with JAX type hints when not installed (#323)
Browse files Browse the repository at this point in the history
add _isArrayLike function for ArrayLike type checking in 3.9 (#325)
moving jax test
fixing linting
  • Loading branch information
DanPuzzuoli committed Feb 20, 2024
1 parent 2638a82 commit 2b7d237
Show file tree
Hide file tree
Showing 15 changed files with 61 additions and 120 deletions.
85 changes: 2 additions & 83 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ confidence=
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=no-self-use, # disabled as it is too verbose
fixme, # disabled as TODOs would show up as warnings
disable=fixme, # disabled as TODOs would show up as warnings
protected-access, # disabled as we don't follow the public vs private
# convention strictly
duplicate-code, # disabled as it is too verbose
Expand All @@ -70,8 +69,7 @@ disable=no-self-use, # disabled as it is too verbose
unnecessary-pass, # allow for methods with just "pass", for clarity
no-else-return, # relax "elif" after a clause with a return
docstring-first-line-empty, # relax docstring style
import-outside-toplevel,
bad-continuation, bad-whitespace # differences of opinion with black
import-outside-toplevel



Expand All @@ -82,12 +80,6 @@ disable=no-self-use, # disabled as it is too verbose
# mypackage.mymodule.MyReporterClass.
output-format=text

# Put messages in a separate file for each module / package specified on the
# command line instead of printing them on stdout. Reports (if any) will be
# written in a file name "pylint_global.[txt|html]". This option is deprecated
# and it will be removed in Pylint 2.0.
files-output=no

# Tells whether to display a full report or only the messages
reports=yes

Expand Down Expand Up @@ -133,66 +125,6 @@ include-naming-hint=no
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty

# Regular expression matching correct module names
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$

# Naming hint for module names
module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$

# Regular expression matching correct constant names
const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$

# Naming hint for constant names
const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$

# Regular expression matching correct class names
class-rgx=[A-Z_][a-zA-Z0-9]+$

# Naming hint for class names
class-name-hint=[A-Z_][a-zA-Z0-9]+$

# Regular expression matching correct function names
function-rgx=[a-z_][a-z0-9_]{2,30}$

# Naming hint for function names
function-name-hint=[a-z_][a-z0-9_]{2,30}$

# Regular expression matching correct method names
method-rgx=(([a-z_][a-z0-9_]{2,49})|(assert[A-Z][a-zA-Z0-9]{2,43})|(test_[_a-zA-Z0-9]{2,}))$

# Naming hint for method names
method-name-hint=[a-z_][a-z0-9_]{2,30}$ or camelCase `assert*` in tests.

# Regular expression matching correct attribute names
attr-rgx=[a-z_][a-z0-9_]{2,30}$

# Naming hint for attribute names
attr-name-hint=[a-z_][a-z0-9_]{2,30}$

# Regular expression matching correct argument names
argument-rgx=[a-z_][a-z0-9_]{2,30}|ax|dt$

# Naming hint for argument names
argument-name-hint=[a-z_][a-z0-9_]{2,30}$

# Regular expression matching correct variable names
variable-rgx=[a-z_][a-z0-9_]{2,30}$

# Naming hint for variable names
variable-name-hint=[a-z_][a-z0-9_]{2,30}$

# Regular expression matching correct class attribute names
class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$

# Naming hint for class attribute names
class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$

# Regular expression matching correct inline iteration names
inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$

# Naming hint for inline iteration names
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
Expand Down Expand Up @@ -220,12 +152,6 @@ ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# else.
single-line-if-stmt=no

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,dict-separator

# Maximum number of lines in a module
max-module-lines=1000

Expand Down Expand Up @@ -416,10 +342,3 @@ known-third-party=enchant
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
6 changes: 6 additions & 0 deletions docs/tutorials/optimizing_pulse_sequence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ for a more detailed explanation of why this step is necessary.

.. jupyter-execute::

#################################################################################
# Remove this
#################################################################################
import warnings
warnings.filterwarnings("ignore")

import jax
jax.config.update("jax_enable_x64", True)

Expand Down
8 changes: 8 additions & 0 deletions qiskit_dynamics/arraylias/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name

"""
Global alias instances.
Expand Down Expand Up @@ -76,6 +77,13 @@
ArrayLike = Union[Union[DYNAMICS_NUMPY_ALIAS.registered_types()], list]


def _isArrayLike(x: any) -> bool:
"""Return true if x is an ArrayLike object. Equivalent to isinstance(x, ArrayLike), which does
not work in Python 3.9.
"""
return isinstance(x, (DYNAMICS_NUMPY_ALIAS.registered_types(), list))


def _preferred_lib(*args, **kwargs):
"""Given a list of args and kwargs with potentially mixed array types, determine the appropriate
library to dispatch to.
Expand Down
4 changes: 2 additions & 2 deletions qiskit_dynamics/models/hamiltonian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS as numpy_alias
from qiskit_dynamics.arraylias.alias import ArrayLike
from qiskit_dynamics.arraylias.alias import ArrayLike, _isArrayLike
from qiskit_dynamics.signals import Signal, SignalList
from .generator_model import GeneratorModel
from .rotating_frame import RotatingFrame
Expand Down Expand Up @@ -170,7 +170,7 @@ def is_hermitian(operator: ArrayLike, tol: Optional[float] = 1e-10) -> bool:
elif type(operator).__name__ == "BCOO":
# fall back on array case for BCOO
return is_hermitian(operator.todense())
elif isinstance(operator, ArrayLike):
elif _isArrayLike(operator):
adj = None
adj = unp.transpose(unp.conjugate(operator))
return np.linalg.norm(adj - operator) < tol
Expand Down
15 changes: 10 additions & 5 deletions qiskit_dynamics/perturbation/array_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@

from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS as numpy_alias
from qiskit_dynamics.arraylias.alias import _preferred_lib, _numpy_multi_dispatch, ArrayLike
from qiskit_dynamics.arraylias.alias import (
_preferred_lib,
_numpy_multi_dispatch,
ArrayLike,
_isArrayLike,
)

from qiskit_dynamics.perturbation.multiset_utils import (
_validate_non_negative_ints,
Expand Down Expand Up @@ -367,7 +372,7 @@ def add(
QiskitError: if other cannot be cast as an ArrayPolynomial.
"""

if isinstance(other, ArrayLike):
if _isArrayLike(other):
other = ArrayPolynomial(constant_term=other)

if isinstance(other, ArrayPolynomial):
Expand Down Expand Up @@ -398,7 +403,7 @@ def matmul(
Raises:
QiskitError: if other cannot be cast as an ArrayPolynomial.
"""
if isinstance(other, ArrayLike):
if _isArrayLike(other):
other = ArrayPolynomial(constant_term=other)

if isinstance(other, ArrayPolynomial):
Expand Down Expand Up @@ -430,7 +435,7 @@ def mul(
QiskitError: if other cannot be cast as an ArrayPolynomial.
"""

if isinstance(other, ArrayLike):
if _isArrayLike(other):
other = ArrayPolynomial(constant_term=other)

if isinstance(other, ArrayPolynomial):
Expand Down Expand Up @@ -485,7 +490,7 @@ def __matmul__(self, other: Union["ArrayPolynomial", ArrayLike]) -> "ArrayPolyno

def __rmatmul__(self, other: Union["ArrayPolynomial", ArrayLike]) -> "ArrayPolynomial":
"""Dunder method for rmatmul."""
if isinstance(other, ArrayLike):
if _isArrayLike(other):
other = ArrayPolynomial(constant_term=other)

if isinstance(other, ArrayPolynomial):
Expand Down
10 changes: 5 additions & 5 deletions qiskit_dynamics/perturbation/custom_binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,11 @@ def _compute_linear_combos(


def _compute_unique_evaluations_jax(
A: jnp.ndarray,
B: jnp.ndarray,
A: np.ndarray,
B: np.ndarray,
unique_evaluation_pairs: np.array,
binary_op: Callable,
) -> jnp.ndarray:
) -> np.ndarray:
"""JAX version of a single loop step of :meth:`linear_combos`. Note that in this function
binary_op is assumed to be vectorized."""
A = jnp.append(A, jnp.zeros((1,) + A[0].shape, dtype=complex), axis=0)
Expand All @@ -282,8 +282,8 @@ def _compute_unique_evaluations_jax(


def _compute_single_linear_combo_jax(
unique_evaluations: jnp.ndarray, single_combo_rule: Tuple[np.array, np.array]
) -> jnp.ndarray:
unique_evaluations: np.ndarray, single_combo_rule: Tuple[np.array, np.array]
) -> np.ndarray:
"""JAX version of :meth:`unique_products`."""
coeffs, indices = single_combo_rule
return jnp.tensordot(coeffs, unique_evaluations[indices], axes=1)
Expand Down
1 change: 1 addition & 0 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name

"""
Pulse schedule to Signals converter.
Expand Down
8 changes: 4 additions & 4 deletions qiskit_dynamics/solvers/lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def lanczos_expm(


@requires_array_library("jax")
def jax_lanczos_basis(A: jnp.ndarray, y0: jnp.ndarray, k_dim: int):
def jax_lanczos_basis(A: np.ndarray, y0: np.ndarray, k_dim: int):
"""JAX version of lanczos_basis."""

data_type = jnp.result_type(A.dtype, y0.dtype)
Expand Down Expand Up @@ -206,7 +206,7 @@ def zeros_func(_):


@requires_array_library("jax")
def jax_lanczos_eigh(A: jnp.ndarray, y0: jnp.ndarray, k_dim: int):
def jax_lanczos_eigh(A: np.ndarray, y0: np.ndarray, k_dim: int):
"""JAX version of lanczos_eigh."""

tridiagonal, q_basis = jax_lanczos_basis(A, y0, k_dim)
Expand All @@ -217,8 +217,8 @@ def jax_lanczos_eigh(A: jnp.ndarray, y0: jnp.ndarray, k_dim: int):

@requires_array_library("jax")
def jax_lanczos_expm(
A: jnp.ndarray,
y0: jnp.ndarray,
A: np.ndarray,
y0: np.ndarray,
k_dim: int,
scale_factor: Optional[float] = 1,
):
Expand Down
3 changes: 2 additions & 1 deletion qiskit_dynamics/solvers/solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from qiskit_dynamics import ArrayLike
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import DYNAMICS_NUMPY_ALIAS as numpy_alias
from qiskit_dynamics.arraylias.alias import _isArrayLike

from qiskit_dynamics.models import (
HamiltonianModel,
Expand Down Expand Up @@ -725,7 +726,7 @@ def initial_state_converter(obj: Any) -> Tuple[ArrayLike, Type, Callable]:
"""
# pylint: disable=invalid-name
y0_cls = None
if isinstance(obj, ArrayLike):
if _isArrayLike(obj):
y0, y0_cls, wrapper = obj, None, lambda x: x
if isinstance(obj, QuantumState):
y0, y0_cls = obj.data, obj.__class__
Expand Down
4 changes: 2 additions & 2 deletions qiskit_dynamics/solvers/solver_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def trim_t_results(
return results


def merge_t_args_jax(t_span: ArrayLike, t_eval: Optional[ArrayLike] = None) -> jnp.ndarray:
def merge_t_args_jax(t_span: ArrayLike, t_eval: Optional[ArrayLike] = None) -> np.ndarray:
"""JAX-compilable version of merge_t_args.
Rather than raise errors, sets return values to ``jnp.nan`` to signal errors.
Expand All @@ -135,7 +135,7 @@ def merge_t_args_jax(t_span: ArrayLike, t_eval: Optional[ArrayLike] = None) -> j
t_eval: Time points to include in returned results.
Returns:
jnp.ndarray: Combined list of times.
np.ndarray: Combined list of times.
Raises:
ValueError: If either argument is not one dimensional.
Expand Down
4 changes: 2 additions & 2 deletions test/dynamics/models/test_generator_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from qiskit import QiskitError
from qiskit.quantum_info.operators import Operator
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics.arraylias.alias import ArrayLike
from qiskit_dynamics.arraylias.alias import _isArrayLike
from qiskit_dynamics.models import GeneratorModel, RotatingFrame
from qiskit_dynamics.models.generator_model import (
_static_operator_into_frame_basis,
Expand Down Expand Up @@ -198,7 +198,7 @@ def _basic_frame_evaluate_test(self, frame_operator, t):
if isinstance(frame_operator, Operator):
frame_operator = frame_operator.data

if isinstance(frame_operator, ArrayLike) and frame_operator.ndim == 1:
if _isArrayLike(frame_operator) and frame_operator.ndim == 1:
frame_operator = np.diag(frame_operator)

value = basic_model(t)
Expand Down
4 changes: 2 additions & 2 deletions test/dynamics/models/test_hamiltonian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from qiskit import QiskitError
from qiskit.quantum_info.operators import Operator
from qiskit_dynamics.arraylias.alias import ArrayLike
from qiskit_dynamics.arraylias.alias import _isArrayLike
from qiskit_dynamics.models import HamiltonianModel
from qiskit_dynamics.models.hamiltonian_model import is_hermitian
from qiskit_dynamics.signals import Signal, SignalList
Expand Down Expand Up @@ -116,7 +116,7 @@ def _basic_frame_evaluate_test(self, frame_operator, t):
# convert to 2d array
if isinstance(frame_operator, Operator):
frame_operator = frame_operator.data
if isinstance(frame_operator, ArrayLike) and frame_operator.ndim == 1:
if _isArrayLike(frame_operator) and frame_operator.ndim == 1:
frame_operator = np.diag(frame_operator)

value = basic_hamiltonian(t) / -1j
Expand Down
1 change: 1 addition & 0 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name
"""
Tests to convert from pulse schedules to signals.
"""
Expand Down
Loading

0 comments on commit 2b7d237

Please sign in to comment.